Understanding Spark Structured Streaming Execution via the Source Code
Spark 2.0 introduced a new, higher-level streaming API called Structured Streaming. This article walks through Spark's source code to understand how a Structured Streaming query is actually executed. The code referenced here is from Spark 2.4.0; the Structured Streaming sources live under the sql folder.
Let's start with a typical use of the Structured Streaming API:
val q = query_result
  .writeStream
  .outputMode("append")
  .format("console")
  .trigger(Trigger.ProcessingTime(100, TimeUnit.MILLISECONDS))
  .start()
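Since query_result itself is not defined in this snippet, here is a minimal, self-contained sketch of where it could come from, assuming a local SparkSession and the built-in rate source (the application name and rate are arbitrary):

// Minimal sketch: a streaming query_result built from the built-in "rate" source
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object StructuredStreamingDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("structured-streaming-demo")
      .getOrCreate()

    // "rate" emits (timestamp, value) rows at a fixed rate,
    // giving us a streaming DataFrame to stand in for query_result
    val query_result = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "5")
      .load()

    val q = query_result
      .writeStream
      .outputMode("append")
      .format("console")
      .trigger(Trigger.ProcessingTime(100, TimeUnit.MILLISECONDS))
      .start()

    q.awaitTermination()
  }
}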
First, the type of query_result is org.apache.spark.sql.DataFrame, and by looking at the package.scala file under org.apache.spark.sql we find:
//org.apache.spark.sql.DataFrame
type DataFrame = Dataset[Row]
In Dataset.scala, also under org.apache.spark.sql, we find the writeStream method, which creates and returns a new DataStreamWriter:
// org.apache.spark.sql.Dataset
@InterfaceStability.Evolving
def writeStream: DataStreamWriter[T] = {
  if (!isStreaming) {
    logicalPlan.failAnalysis(
      "'writeStream' can be called only on streaming Dataset/DataFrame")
  }
  new DataStreamWriter[T](this)
}
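The isStreaming guard means writeStream can only be called on a streaming Dataset/DataFrame; calling it on a batch Dataset fails analysis. A small sketch of both cases, assuming spark is an active SparkSession:

// Sketch: the isStreaming guard in action (assumes `spark` is an active SparkSession)
import org.apache.spark.sql.AnalysisException

val batchDf = spark.range(0, 10).toDF("value")         // a batch Dataset: isStreaming == false
try {
  batchDf.writeStream.format("console").start()        // fails the isStreaming check
} catch {
  case e: AnalysisException => println(e.getMessage)   // "'writeStream' can be called only on streaming Dataset/DataFrame"
}

val streamDf = spark.readStream.format("rate").load()  // a streaming Dataset: isStreaming == true
println(streamDf.isStreaming)                          // true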
Looking at DataStreamWriter.scala under org.apache.spark.sql.streaming, we can see that outputMode(), format(), and trigger() are all setter methods; they simply assign values to fields of the class:
//org.apache.spark.sql.streaming.DataStreamWriter
final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
  ...
  def outputMode(outputMode: String): DataStreamWriter[T] = {
    this.outputMode = InternalOutputModes(outputMode)
    this
  }

  def format(source: String): DataStreamWriter[T] = {
    this.source = source
    this
  }

  def trigger(trigger: Trigger): DataStreamWriter[T] = {
    this.trigger = trigger
    this
  }
  ...
}
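Because every setter returns this, the fluent chain in the opening example is just convenience; the same configuration can be written as separate calls on a single writer. A sketch, assuming the query_result from above:

// Equivalent to the fluent chain: each setter mutates and returns the same DataStreamWriter
import java.util.concurrent.TimeUnit
import org.apache.spark.sql.streaming.Trigger

val writer = query_result.writeStream
writer.outputMode("append")
writer.format("console")
writer.trigger(Trigger.ProcessingTime(100, TimeUnit.MILLISECONDS))
val q = writer.start()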
The final start() call is the crucial step: it kicks off all of the computation defined on query_result. The start() method in DataStreamWriter.scala is relatively involved, but in every branch the core logic is to configure and invoke df.sparkSession.sessionState.streamingQueryManager.startQuery().
//org.apache.spark.sql.streaming.DataStreamWriter
def start(): StreamingQuery = {
  if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
    throw new AnalysisException("Hive data source can only be used with tables, you can not " +
      "write files of Hive data source directly.")
  }

  if (source == "memory") {
    ...
    val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
      extraOptions.get("queryName"),
      chkpointLoc,
      df,
      extraOptions.toMap,
      sink,
      outputMode,
      useTempCheckpointLocation = true,
      recoverFromCheckpointLocation = recoverFromChkpoint,
      trigger = trigger)
    resultDf.createOrReplaceTempView(query.name)
    query
  } else if (source == "foreach") {
    assertNotPartitioned("foreach")
    val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc)
    df.sparkSession.sessionState.streamingQueryManager.startQuery(
      extraOptions.get("queryName"),
      extraOptions.get("checkpointLocation"),
      df,
      extraOptions.toMap,
      sink,
      outputMode,
      useTempCheckpointLocation = true,
      trigger = trigger)
  } else if (source == "foreachBatch") {
    assertNotPartitioned("foreachBatch")
    if (trigger.isInstanceOf[ContinuousTrigger]) {
      throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
    }
    val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
    df.sparkSession.sessionState.streamingQueryManager.startQuery(
      extraOptions.get("queryName"),
      extraOptions.get("checkpointLocation"),
      df,
      extraOptions.toMap,
      sink,
      outputMode,
      useTempCheckpointLocation = true,
      trigger = trigger)
  } else {
    ...
    df.sparkSession.sessionState.streamingQueryManager.startQuery(
      options.get("queryName"),
      options.get("checkpointLocation"),
      df,
      options,
      sink,
      outputMode,
      useTempCheckpointLocation = source == "console",
      recoverFromCheckpointLocation = true,
      trigger = trigger)
  }
}
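Note that the userSpecifiedName and checkpoint location passed to startQuery() come out of extraOptions, which callers populate through queryName() and option() on the writer. For example (the query name and checkpoint path below are illustrative):

// Sketch: naming the query and pinning its checkpoint location (values are illustrative)
val q = query_result
  .writeStream
  .queryName("demo_query")                              // surfaces as extraOptions("queryName")
  .option("checkpointLocation", "/tmp/demo-checkpoint") // surfaces as extraOptions("checkpointLocation")
  .outputMode("append")
  .format("console")
  .start()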
Next, we need to look at the definition of sessionState. In org.apache.spark.sql.SparkSession.scala, sessionState is defined as:
//org.apache.spark.sql.SparkSession
@InterfaceStability.Unstable
@transient
lazy val sessionState: SessionState = {
  parentSessionState
    .map(_.clone(this))
    .getOrElse {
      val state = SparkSession.instantiateSessionState(
        SparkSession.sessionStateClassName(sparkContext.conf),
        self)
      initialSessionOptions.foreach { case (k, v) => state.conf.setConfString(k, v) }
      state
    }
}
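As an aside, the StreamingQueryManager held by sessionState is also exposed to users through SparkSession.streams, so active queries can be inspected without touching the (unstable) sessionState directly:

// The same StreamingQueryManager, reached through the public API (assumes `spark` is a SparkSession)
val manager = spark.streams                        // returns sessionState.streamingQueryManager
spark.streams.active.foreach(q => println(q.id))   // ids of the currently active streaming queries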
SessionState is defined in org.apache.spark.sql.internal.SessionState, and its member streamingQueryManager is defined in org.apache.spark.sql.streaming.StreamingQueryManager, whose startQuery() method is the core API for starting a query. Its key code fragment is the call to query.streamingQuery.start():
//org.apache.spark.sql.streaming.StreamingQueryManager#startQuery()
val query = createQuery(
  userSpecifiedName,
  userSpecifiedCheckpointLocation,
  df,
  extraOptions,
  sink,
  outputMode,
  useTempCheckpointLocation,
  recoverFromCheckpointLocation,
  trigger,
  triggerClock)
...
try {
  // When starting a query, it will call `StreamingQueryListener.onQueryStarted` synchronously.
  // As it's provided by the user and can run arbitrary codes, we must not hold any lock here.
  // Otherwise, it's easy to cause dead-lock, or block too long if the user codes take a long
  // time to finish.
  query.streamingQuery.start()
} catch {
  case e: Throwable =>
    activeQueriesLock.synchronized {
      activeQueries -= query.id
    }
    throw e
}
query
}
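Before looking at what query actually is, note that it is also the object handed back to the caller of start(); from the user's side it behaves as a StreamingQuery handle. A brief sketch, assuming the q from the opening example:

// Sketch: the StreamingQuery handle returned by start() (assumes `q` from the opening example)
println(q.id)             // stable across restarts from the same checkpoint
println(q.runId)          // unique to this particular run
println(q.status)         // e.g. whether the query is waiting for data or processing a batch
q.awaitTermination(5000)  // block for at most 5 seconds
q.stop()                  // stop the query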
The query variable here is produced by createQuery(), whose return type is org.apache.spark.sql.execution.streaming.StreamingQueryWrapper (defined in the Scala file of the same name). The streamingQuery inside StreamingQueryWrapper is an instance of org.apache.spark.sql.execution.streaming.StreamExecution, the abstract class behind Structured Streaming execution. The start() method of StreamExecution is also simple: it starts the query execution thread queryExecutionThread.
//org.apache.spark.sql.execution.streaming.StreamExecution
/**
 * Starts the execution. This returns only after the thread has started and [[QueryStartedEvent]]
 * has been posted to all the listeners.
 */
def start(): Unit = {
  queryExecutionThread.setDaemon(true)
  queryExecutionThread.start()
  // Wait until thread started and QueryStart event has been posted
  startLatch.await()
}
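The startLatch.await() at the end is a classic start-up handshake: the caller blocks until the new thread signals that its initialization, here posting the QueryStartedEvent, has finished. A Spark-independent sketch of the same pattern:

// Generic sketch of the start-latch handshake used by StreamExecution.start()
import java.util.concurrent.CountDownLatch

val startLatch = new CountDownLatch(1)

val worker = new Thread(new Runnable {
  override def run(): Unit = {
    // ... initialization (analogous to posting QueryStartedEvent) ...
    startLatch.countDown()   // signal "started"
    // ... main loop (analogous to runStream()) ...
  }
})
worker.setDaemon(true)
worker.start()

startLatch.await()           // returns only after the worker has counted down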
queryExecutionThread is likewise defined in StreamExecution; it overrides run() and calls the core method runStream() from it. In other words, once queryExecutionThread.start() is invoked, runStream() will eventually be executed.
//org.apache.spark.sql.execution.streaming.StreamExecution
val queryExecutionThread: QueryExecutionThread =
  new QueryExecutionThread(s"stream execution thread for $prettyIdString") {
    override def run(): Unit = {
      sparkSession.sparkContext.setCallSite(callSite)
      runStream()
    }
  }
runStream() is also defined in StreamExecution; its key structure is as follows:
//org.apache.spark.sql.execution.streaming.StreamExecution#runStream()
try {
  // preparation for running the stream query:
  // post the QueryStartedEvent, count down the start latch, set streaming configs, etc.
  runActivatedStream(sparkSessionForStream) // run the stream
} catch {
  // exception handling
} finally {
  // cleanup after the stream query finishes:
  // stop the sources, post the stream-stopped event, delete the checkpoint, etc.
}
In StreamExecution, runActivatedStream() is an abstract method, implemented by the subclasses of StreamExecution:
//org.apache.spark.sql.execution.streaming.StreamExecution
/**
 * Run the activated stream until stopped.
 */
protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit
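This is a template-method layout: StreamExecution owns the thread and lifecycle plumbing in runStream(), while subclasses only supply the loop body. A minimal, Spark-independent sketch of the structure (the class names are made up for illustration):

// Spark-independent sketch: the base class owns the lifecycle, subclasses supply the loop
abstract class ExecutionBase {
  protected def runActivatedStream(): Unit   // the hook, analogous to StreamExecution's

  final def runStream(): Unit = {
    // ... setup (listeners, configs) ...
    runActivatedStream()
    // ... cleanup ...
  }
}

class MicroBatchLikeExecution extends ExecutionBase {
  override protected def runActivatedStream(): Unit = {
    // repeatedly plan and run micro-batches until stopped
  }
}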
So the question becomes: which classes extend StreamExecution, and where are they instantiated?
In fact, in the org.apache.spark.sql.streaming.StreamingQueryManager#createQuery() method we find the following snippet:
//org.apache.spark.sql.streaming.StreamingQueryManager#createQuery()
(sink, trigger) match {
  // with a ContinuousTrigger, use ContinuousExecution
  case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) =>
    if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
      UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
    }
    new StreamingQueryWrapper(new ContinuousExecution(
      sparkSession,
      userSpecifiedName.orNull,
      checkpointLocation,
      analyzedPlan,
      v2Sink,
      trigger,
      triggerClock,
      outputMode,
      extraOptions,
      deleteCheckpointOnStop))
  // with a ProcessingTime trigger or OneTimeTrigger, use MicroBatchExecution
  case _ =>
    new StreamingQueryWrapper(new MicroBatchExecution(
      sparkSession,
      userSpecifiedName.orNull,
      checkpointLocation,
      analyzedPlan,
      sink,
      trigger,
      triggerClock,
      outputMode,
      extraOptions,
      deleteCheckpointOnStop))
}
From this we can see that by the time createQuery() runs, the trigger has already decided which StreamExecution will be used. In other words, ContinuousExecution and MicroBatchExecution are the two classes that extend StreamExecution, each providing its own runActivatedStream().
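From the user's point of view this choice is made entirely by the trigger handed to the writer; all three factory methods below exist on org.apache.spark.sql.streaming.Trigger:

// The trigger chosen on the writer decides which StreamExecution subclass createQuery() builds
import java.util.concurrent.TimeUnit
import org.apache.spark.sql.streaming.Trigger

Trigger.ProcessingTime(100, TimeUnit.MILLISECONDS) // periodic micro-batches -> MicroBatchExecution
Trigger.Once()                                     // a single micro-batch   -> MicroBatchExecution
Trigger.Continuous("1 second")                     // continuous processing  -> ContinuousExecution (for sinks with StreamWriteSupport)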
Take MicroBatchExecution as an example; it is defined in org.apache.spark.sql.execution.streaming.MicroBatchExecution.scala. Its core logic is:
//org.apache.spark.sql.execution.streaming.MicroBatchExecution
/**
 * Repeatedly attempts to run batches as data arrives.
 */
protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
  ...
  triggerExecutor.execute(() => {
    // submit and run the work for each micro-batch of the query
  })
  ...
}
triggerExecutor is defined in MicroBatchExecution.scala as follows:
//org.apache.spark.sql.execution.streaming.MicroBatchExecution
private val triggerExecutor = trigger match {
  case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock)
  case OneTimeTrigger => OneTimeExecutor()
  case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger")
}
So if Trigger.ProcessingTime() is used, a ProcessingTimeExecutor is created. It is defined in org.apache.spark.sql.execution.streaming.TriggerExecutor.scala, including the execute() method that runActivatedStream() ultimately calls. With a ProcessingTime trigger, ProcessingTimeExecutor periodically runs batch queries:
//org.apache.spark.sql.execution.streaming.TriggerExecutor
/**
 * A trigger executor that runs a batch every `intervalMs` milliseconds.
 */
case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock())
  extends TriggerExecutor with Logging {

  private val intervalMs = processingTime.intervalMs
  require(intervalMs >= 0)

  override def execute(triggerHandler: () => Boolean): Unit = {
    while (true) {
      val triggerTimeMs = clock.getTimeMillis
      val nextTriggerTimeMs = nextBatchTime(triggerTimeMs)
      val terminated = !triggerHandler()
      if (intervalMs > 0) {
        val batchElapsedTimeMs = clock.getTimeMillis - triggerTimeMs
        if (batchElapsedTimeMs > intervalMs) {
          notifyBatchFallingBehind(batchElapsedTimeMs)
        }
        if (terminated) {
          return
        }
        clock.waitTillTime(nextTriggerTimeMs)
      } else {
        if (terminated) {
          return
        }
      }
    }
  }

  /** Called when a batch falls behind */
  def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
    logWarning("Current batch is falling behind. The trigger interval is " +
      s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds")
  }

  def nextBatchTime(now: Long): Long = {
    if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs
  }
}
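nextBatchTime() aligns every trigger to a multiple of the interval rather than simply adding intervalMs to the current time. A quick standalone re-implementation of the formula, just to trace the arithmetic:

// Worked example of nextBatchTime: batches are aligned to interval boundaries
def nextBatchTime(now: Long, intervalMs: Long): Long =
  if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs

println(nextBatchTime(250, 100)) // 300: the next 100 ms boundary after 250
println(nextBatchTime(300, 100)) // 400: triggering exactly on a boundary waits a full interval
println(nextBatchTime(999, 100)) // 1000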