Spark 2.0 introduced a new, higher-level streaming API called Structured Streaming. This article walks through the Spark source code to understand how a Structured Streaming query is actually executed. The code shown here is from Spark 2.4.0; the Structured Streaming sources live under the sql module.

Let's start with a typical use of the Structured Streaming API:

val q = query_result
   .writeStream
   .outputMode("append")
   .format("console")
   .trigger(Trigger.ProcessingTime(100, TimeUnit.MILLISECONDS))
   .start()
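
For context, query_result is just a streaming DataFrame/Dataset; everything chained onto writeStream above works the same regardless of how it was built. As a minimal, purely illustrative sketch of how it might be produced (the socket source, host and port are assumptions for the example, not part of the original query):

// imports needed by the snippet above
import java.util.concurrent.TimeUnit
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder()
   .appName("structured-streaming-walkthrough")
   .master("local[*]")
   .getOrCreate()
import spark.implicits._

// a streaming Dataset read from a socket source; any streaming source would do
val lines = spark.readStream
   .format("socket")
   .option("host", "localhost")
   .option("port", 9999)
   .load()

// one word per row; no aggregation, so outputMode("append") above is valid
val query_result = lines.as[String].flatMap(_.split(" ")).toDF("word")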

First, query_result has type org.apache.spark.sql.DataFrame. Looking at the package object (package.scala) under org.apache.spark.sql, we find that DataFrame is just an alias:

//org.apache.spark.sql.DataFrame

type DataFrame = Dataset[Row]

In Dataset.scala, also under org.apache.spark.sql, we find the writeStream method, which simply constructs and returns a DataStreamWriter:

// org.apache.spark.sql.Dataset

@InterfaceStability.Evolving
def writeStream: DataStreamWriter[T] = {
   if (!isStreaming) {
      logicalPlan.failAnalysis(
         "'writeStream' can be called only on streaming Dataset/DataFrame")
   }
   new DataStreamWriter[T](this)
}
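
The isStreaming check above means writeStream fails fast on a plain batch Dataset. A quick illustration, assuming any existing SparkSession named spark (for example the one from the sketch earlier):

val batchDf = spark.range(10).toDF()
batchDf.isStreaming    // false
// batchDf.writeStream // would throw AnalysisException:
//                     // "'writeStream' can be called only on streaming Dataset/DataFrame"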

Looking at DataStreamWriter.scala in org.apache.spark.sql.streaming, we see that outputMode(), format() and trigger() are all setters: each one stores a value on the writer and returns this, which is what allows the fluent chaining above.

//org.apache.spark.sql.streaming.DataStreamWriter

final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
   ...
   
   def outputMode(outputMode: String): DataStreamWriter[T] = {
      this.outputMode = InternalOutputModes(outputMode)
      this
   }

   def format(source: String): DataStreamWriter[T] = {
      this.source = source
      this
   }

   def trigger(trigger: Trigger): DataStreamWriter[T] = {
      this.trigger = trigger
      this
   }

   ...
}
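
As an aside, the Trigger stored here is what later decides which execution path is taken (we will see this in createQuery() below). The three factory methods available in Spark 2.4 are used roughly like this (the intervals are arbitrary):

import java.util.concurrent.TimeUnit
import org.apache.spark.sql.streaming.Trigger

// micro-batches every 100 ms -> MicroBatchExecution driven by ProcessingTimeExecutor
val processing = Trigger.ProcessingTime(100, TimeUnit.MILLISECONDS)

// run a single micro-batch and stop -> MicroBatchExecution driven by OneTimeExecutor
val once = Trigger.Once()

// continuous processing with 1-second checkpoints -> ContinuousExecution
val continuous = Trigger.Continuous("1 second")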

The final start() call is the key step: it kicks off all of the computation defined on query_result. start() in DataStreamWriter.scala is relatively involved, but its core logic is to configure and invoke df.sparkSession.sessionState.streamingQueryManager.startQuery():

//org.apache.spark.sql.streaming.DataStreamWriter

def start(): StreamingQuery = {
    if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
      throw new AnalysisException("Hive data source can only be used with tables, you can not " +
        "write files of Hive data source directly.")
    }

    if (source == "memory") {
      ...
      val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
        extraOptions.get("queryName"),
        chkpointLoc,
        df,
        extraOptions.toMap,
        sink,
        outputMode,
        useTempCheckpointLocation = true,
        recoverFromCheckpointLocation = recoverFromChkpoint,
        trigger = trigger)
      resultDf.createOrReplaceTempView(query.name)
      query
    } else if (source == "foreach") {
      assertNotPartitioned("foreach")
      val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc)
      df.sparkSession.sessionState.streamingQueryManager.startQuery(
        extraOptions.get("queryName"),
        extraOptions.get("checkpointLocation"),
        df,
        extraOptions.toMap,
        sink,
        outputMode,
        useTempCheckpointLocation = true,
        trigger = trigger)
    } else if (source == "foreachBatch") {
      assertNotPartitioned("foreachBatch")
      if (trigger.isInstanceOf[ContinuousTrigger]) {
        throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
      }
      val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
      df.sparkSession.sessionState.streamingQueryManager.startQuery(
        extraOptions.get("queryName"),
        extraOptions.get("checkpointLocation"),
        df,
        extraOptions.toMap,
        sink,
        outputMode,
        useTempCheckpointLocation = true,
        trigger = trigger)
    } else {
      ...
      df.sparkSession.sessionState.streamingQueryManager.startQuery(
        options.get("queryName"),
        options.get("checkpointLocation"),
        df,
        options,
        sink,
        outputMode,
        useTempCheckpointLocation = source == "console",
        recoverFromCheckpointLocation = true,
        trigger = trigger)
    }
}
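
As a user-level illustration of the "foreachBatch" branch above, this is roughly how that path gets exercised from the public API (the function body and parquet path are hypothetical):

// each micro-batch is handed to the function as a regular batch DataFrame
val writeBatch: (org.apache.spark.sql.DataFrame, Long) => Unit =
   (batchDf, batchId) => batchDf.write.mode("append").parquet(s"/tmp/wordcount/batch=$batchId")

query_result.writeStream
   .foreachBatch(writeBatch)
   .start()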

Next, we need to look at how sessionState is defined. In org.apache.spark.sql.SparkSession.scala we find:

//org.apache.spark.sql.SparkSession

@InterfaceStability.Unstable
@transient
lazy val sessionState: SessionState = {
   parentSessionState
      .map(_.clone(this))
      .getOrElse {
         val state = SparkSession.instantiateSessionState(
            SparkSession.sessionStateClassName(sparkContext.conf), 
            self)
         initialSessionOptions.foreach {case (k, v) => state.conf.setConfString(k, v)}
         state
      }
}

SessionState is defined in org.apache.spark.sql.internal.SessionState, and its streamingQueryManager member is an org.apache.spark.sql.streaming.StreamingQueryManager, whose startQuery() method is the core API for starting a query. Its key code path boils down to query.streamingQuery.start():

//org.apache.spark.sql.streaming.StreamingQueryManager#startQuery()

val query = createQuery(
   userSpecifiedName,
   userSpecifiedCheckpointLocation,
   df,
   extraOptions,
   sink,
   outputMode,
   useTempCheckpointLocation,
   recoverFromCheckpointLocation,
   trigger,
   triggerClock)

...

try {
   // When starting a query, it will call `StreamingQueryListener.onQueryStarted` synchronously.
   // As it's provided by the user and can run arbitrary codes, we must not hold any lock here.
   // Otherwise, it's easy to cause dead-lock, or block too long if the user codes take a long
   // time to finish.
   query.streamingQuery.start()
} catch {
   case e: Throwable =>
      activeQueriesLock.synchronized {
         activeQueries -= query.id
      }
      throw e
}
query
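
Incidentally, the StreamingQueryManager used here is also exposed to user code as spark.streams, so the queries registered by startQuery() can be inspected directly:

// list the queries currently registered with this session's StreamingQueryManager
spark.streams.active.foreach(q => println(s"${q.name}: id=${q.id}, active=${q.isActive}"))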

The query variable here is produced by createQuery(), whose return type is org.apache.spark.sql.execution.streaming.StreamingQueryWrapper (defined in the Scala file of the same name). The streamingQuery field inside StreamingQueryWrapper is the execution abstraction of Spark Structured Streaming, org.apache.spark.sql.execution.streaming.StreamExecution. The start() method of StreamExecution is straightforward: it just starts the execution thread queryExecutionThread:

//org.apache.spark.sql.execution.streaming.StreamExecution

/**
* Starts the execution. This returns only after the thread has started and [[QueryStartedEvent]]
* has been posted to all the listeners.
*/
def start(): Unit = {
   queryExecutionThread.setDaemon(true)
   queryExecutionThread.start()
   // Wait until thread started and QueryStart event has been posted
   startLatch.await()
}
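
The pattern here, starting a daemon thread and then blocking on a latch until the thread signals that it is up, can be illustrated in isolation with plain JDK primitives (a generic sketch, not Spark code):

import java.util.concurrent.CountDownLatch

val startLatch = new CountDownLatch(1)

val worker = new Thread("stream-execution-thread") {
   override def run(): Unit = {
      // post the "query started" event, then let the caller return from start()
      startLatch.countDown()
      // ...long-running stream loop would go here...
   }
}

worker.setDaemon(true)   // do not keep the JVM alive just for this thread
worker.start()
startLatch.await()       // returns only after the worker has started and signalled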

queryExecutionThread is also defined in StreamExecution; it overrides run() and calls the core method runStream() there. In other words, once queryExecutionThread.start() is called, runStream() will eventually be executed.

//org.apache.spark.sql.execution.streaming.StreamExecution
val queryExecutionThread: QueryExecutionThread =
   new QueryExecutionThread(s"stream execution thread for $prettyIdString") {
      override def run(): Unit = {
         sparkSession.sparkContext.setCallSite(callSite)
         runStream()
      }
   }

runStream() is likewise defined in StreamExecution; its key structure is as follows:

//org.apache.spark.sql.execution.streaming.StreamExecution#runStream()

try {
   // preparation before running the stream query:
   // post the QueryStartedEvent, count down the start latch, set up the streaming configuration, etc.
   runActivatedStream(sparkSessionForStream) // run the stream
} catch {
   // error handling
} finally {
   // cleanup after the stream query finishes:
   // stop the sources, post the stream-stopped event, delete the checkpoint, and so on
}

runActivatedStream() in StreamExecution is an abstract method, left to be implemented by the concrete subclasses of StreamExecution:

//org.apache.spark.sql.execution.streaming.StreamExecution

/**
* Run the activated stream until stopped.
*/
protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit

So the question becomes: which classes extend StreamExecution, and where do they get instantiated?

The answer lies in org.apache.spark.sql.streaming.StreamingQueryManager#createQuery(), which contains the following snippet:

//org.apache.spark.sql.streaming.StreamingQueryManager#createQuery()

(sink, trigger) match {
   // a ContinuousTrigger yields a ContinuousExecution
   case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) =>
      if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
         UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
      }
      new StreamingQueryWrapper(new ContinuousExecution(
         sparkSession,
         userSpecifiedName.orNull,
         checkpointLocation,
         analyzedPlan,
         v2Sink,
         trigger,
         triggerClock,
         outputMode,
         extraOptions,
         deleteCheckpointOnStop))
   // a ProcessingTime trigger or OneTimeTrigger yields a MicroBatchExecution
   case _ =>
      new StreamingQueryWrapper(new MicroBatchExecution(
         sparkSession,
         userSpecifiedName.orNull,
         checkpointLocation,
         analyzedPlan,
         sink,
         trigger,
         triggerClock,
         outputMode,
         extraOptions,
         deleteCheckpointOnStop))
}

This shows that which StreamExecution to use is already decided inside createQuery(), based on the trigger: ContinuousExecution and MicroBatchExecution are the two classes that extend StreamExecution, and each provides its own runActivatedStream().

MicroBatchExecution为例,其具体定义在org.apache.spark.sql.execution.streaming.MicroBatchExecution.scala。其中的核心语句为:

//org.apache.spark.sql.execution.streaming.MicroBatchExecution

/**
* Repeatedly attempts to run batches as data arrives.
*/
protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
   ...
   triggerExecutor.execute(() => {
      // plan and run one micro-batch of the query; the handler returns whether the stream is still active
      ...
   })
   ...
}

where triggerExecutor is defined in MicroBatchExecution.scala as follows:

//org.apache.spark.sql.execution.streaming.MicroBatchExecution

private val triggerExecutor = trigger match {
   case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock)
   case OneTimeTrigger => OneTimeExecutor()
   case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger")
}
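
For comparison, OneTimeExecutor (defined alongside ProcessingTimeExecutor in TriggerExecutor.scala) simply invokes the trigger handler a single time; paraphrased, it looks roughly like this:

// paraphrased sketch: run exactly one batch, then return
case class OneTimeExecutor() extends TriggerExecutor {
   override def execute(batchRunner: () => Boolean): Unit = batchRunner()
}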

So when Trigger.ProcessingTime() is used, the executor is a ProcessingTimeExecutor, defined in org.apache.spark.sql.execution.streaming.TriggerExecutor.scala together with the execute() method that runActivatedStream() ultimately calls. In other words, a ProcessingTime trigger uses ProcessingTimeExecutor to run micro-batch queries periodically:

//org.apache.spark.sql.execution.streaming.TriggerExecutor

/**
 * A trigger executor that runs a batch every `intervalMs` milliseconds.
 */
case class ProcessingTimeExecutor(processingTime: ProcessingTime, 
                                  clock: Clock = new SystemClock())
   extends TriggerExecutor with Logging {

   private val intervalMs = processingTime.intervalMs
   require(intervalMs >= 0)

   override def execute(triggerHandler: () => Boolean): Unit = {
      while (true) {
         val triggerTimeMs = clock.getTimeMillis
         val nextTriggerTimeMs = nextBatchTime(triggerTimeMs)
         val terminated = !triggerHandler()
         if (intervalMs > 0) {
            val batchElapsedTimeMs = clock.getTimeMillis - triggerTimeMs
            if (batchElapsedTimeMs > intervalMs) {
               notifyBatchFallingBehind(batchElapsedTimeMs)
            }
            if (terminated) {
               return
            }
            clock.waitTillTime(nextTriggerTimeMs)
         } else {
            if (terminated) {
               return
            }
         }
      }
   }

   /** Called when a batch falls behind */
   def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
      logWarning("Current batch is falling behind. The trigger interval is " +
         s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds")
   }

   def nextBatchTime(now: Long): Long = {
      if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs
   }
}
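
To make nextBatchTime() concrete, here is a quick worked example (the values are arbitrary):

// same arithmetic as above, extracted as a standalone function
def nextBatchTime(now: Long, intervalMs: Long): Long =
   if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs

nextBatchTime(1234L, 100L)   // 1300: triggers align to the next multiple of the interval
nextBatchTime(1300L, 100L)   // 1400: exactly on a boundary still moves to the following one
nextBatchTime(1234L, 0L)     // 1234: interval 0 means fire the next batch immediately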