  • 14: Spark Streaming Source Code Analysis of State Management: updateStateByKey and mapWithState Demystified

        First, what is state management? Take word count as an example. Each batchInterval computes the word counts for the current batch only; how, then, do we compute the cumulative count of each word from the start of the stream up to now? Spark Streaming provides two methods for this: updateStateByKey and mapWithState. mapWithState was added in version 1.6 and is still marked experimental; according to the official documentation it performs up to 10x better than updateStateByKey. Let's look at how the two are implemented.

    1. Analysis of updateStateByKey
    1.1 An example of using updateStateByKey
    First, a simple example that uses updateStateByKey:
        import org.apache.spark.SparkConf
        import org.apache.spark.streaming.{Seconds, StreamingContext}

        object UpdateStateByKeyDemo {
          def main(args: Array[String]) {
            val conf = new SparkConf().setAppName("UpdateStateByKeyDemo")
            val ssc = new StreamingContext(conf, Seconds(20))
            // updateStateByKey requires checkpointing to be enabled.
            ssc.checkpoint("/checkpoint/")
            val socketLines = ssc.socketTextStream("localhost", 9999)
            socketLines.flatMap(_.split(",")).map(word => (word, 1))
              .updateStateByKey((currValues: Seq[Int], preValue: Option[Int]) => {
                val currValue = currValues.sum // sum of this key's values in the current batch
                Some(currValue + preValue.getOrElse(0)) // batch sum plus the historical value
              }).print()
            ssc.start()
            ssc.awaitTermination()
            ssc.stop()
          }
        }
    The code is straightforward; the key points are explained in the comments.

    1.2 Source code analysis of updateStateByKey
        We know that map returns a MappedDStream, and MappedDStream has no updateStateByKey method; neither does its parent class DStream. But DStream's companion object defines an implicit conversion:
        implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
            (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null):
          PairDStreamFunctions[K, V] = {
          new PairDStreamFunctions[K, V](stream)
        }
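    Because this implicit lives in DStream's companion object, the compiler finds it without any explicit import. A minimal sketch of what the call in section 1.1 effectively resolves to (an illustration only, reusing names from the example above):
        val counts = socketLines.flatMap(_.split(",")).map(word => (word, 1))
        // counts itself has no updateStateByKey method; the compiler rewrites the call to roughly:
        // (updateFunc stands for the inline function from the example)
        DStream.toPairDStreamFunctions(counts).updateStateByKey(updateFunc)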
    The source of updateStateByKey in PairDStreamFunctions is as follows:
        def updateStateByKey[S: ClassTag](
            updateFunc: (Seq[V], Option[S]) => Option[S]
          ): DStream[(K, S)] = ssc.withScope {
          updateStateByKey(updateFunc, defaultPartitioner())
        }
    Here updateFunc is the function we pass in: Seq[V] holds all of the current key's new values in this batch, Option[S] is the key's previous state, and the return value is the key's new state.
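    One detail of the Option[S] return type is worth calling out: because the keys are cogrouped (see below), the update function is invoked every batch for every key ever seen, with an empty Seq[V] when a key has no new data, and returning None removes that key's state. A hedged sketch of a hypothetical variant that exploits this to drop idle keys:
        // Hypothetical variant: discard the state of any key not seen in the current batch.
        val updateFunc = (currValues: Seq[Int], preValue: Option[Int]) => {
          if (currValues.isEmpty) None // no new data for this key: drop its state
          else Some(currValues.sum + preValue.getOrElse(0))
        }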

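    After defaultPartitioner() is supplied, an intermediate overload lifts the per-key (Seq[V], Option[S]) => Option[S] function into an iterator form, roughly like this (paraphrased from the Spark 1.6 source; note how the flatMap is what makes a None return drop the key):
        def updateStateByKey[S: ClassTag](
            updateFunc: (Seq[V], Option[S]) => Option[S],
            partitioner: Partitioner
          ): DStream[(K, S)] = ssc.withScope {
          val cleanedUpdateF = sparkContext.clean(updateFunc)
          val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
            iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s)))
          }
          updateStateByKey(newUpdateFunc, partitioner, rememberPartitioner = true)
        }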
    which in turn calls the overload that actually constructs the state DStream:
        def updateStateByKey[S: ClassTag](
            updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
            partitioner: Partitioner,
            rememberPartitioner: Boolean
          ): DStream[(K, S)] = ssc.withScope {
          new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
        }

    This method creates a StateDStream object. In StateDStream's compute method, it first gets the RDD computed for the previous batch (holding the cumulative word counts from the start of the program through the previous batch), then the RDD computed by StateDStream's parent DStream for the current batch (this batch's word counts); these are prevStateRDD and parentRDD respectively. It then calls computeUsingPreviousRDD:
        private[this] def computeUsingPreviousRDD(
            parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
          // Define the function for the mapPartition operation on cogrouped RDD;
          // first map the cogrouped tuple to tuples of required type,
          // and then apply the update function
          val updateFuncLocal = updateFunc
          val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
            val i = iterator.map { t =>
              val itr = t._2._2.iterator
              val headOption = if (itr.hasNext) Some(itr.next()) else None
              (t._1, t._2._1.toSeq, headOption)
            }
            updateFuncLocal(i)
          }
          val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
          val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
          Some(stateRDD)
        }
    The two RDDs are cogrouped and the function passed to updateStateByKey is applied to each key. This is where updateStateByKey's cost lies: cogroup is relatively expensive, and since the entire historical state participates in it, each batch does work proportional to the total number of keys ever seen, not just the keys in the current batch.
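    To see why, recall that cogroup emits one group per key from either side, so every key in the accumulated state is touched on every batch, even keys with no new data. A small spark-shell illustration with hypothetical data:
        val parentRDD    = sc.parallelize(Seq(("hello", 1), ("hello", 1), ("world", 1)))
        val prevStateRDD = sc.parallelize(Seq(("hello", 5), ("spark", 3)))

        parentRDD.cogroup(prevStateRDD).collect().foreach(println)
        // (hello,(CompactBuffer(1, 1),CompactBuffer(5)))
        // (world,(CompactBuffer(1),CompactBuffer()))
        // (spark,(CompactBuffer(),CompactBuffer(3)))  <- no new data, but still shuffled and rewritten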


    2. Analysis of mapWithState
    2.1 A mapWithState usage example:
        import org.apache.spark.SparkConf
        import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

        object StatefulNetworkWordCount {
          def main(args: Array[String]) {
            if (args.length < 2) {
              System.err.println("Usage: StatefulNetworkWordCount <hostname> <port>")
              System.exit(1)
            }
            // StreamingExamples is a logging helper from the Spark examples project
            StreamingExamples.setStreamingLogLevels()
            val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
            // Create the context with a 1 second batch size
            val ssc = new StreamingContext(sparkConf, Seconds(1))
            ssc.checkpoint(".")
            // Initial state RDD for mapWithState operation
            val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
            // Create a ReceiverInputDStream on target ip:port and count the
            // words in input stream of delimited text (eg. generated by 'nc')
            val lines = ssc.socketTextStream(args(0), args(1).toInt)
            val words = lines.flatMap(_.split(" "))
            val wordDstream = words.map(x => (x, 1))
            // Update the cumulative count using mapWithState
            // This will give a DStream made of state (which is the cumulative count of the words)
            val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
              val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
              val output = (word, sum)
              state.update(sum)
              output
            }
            val stateDstream = wordDstream.mapWithState(
              StateSpec.function(mappingFunc).initialState(initialRDD))
            stateDstream.print()
            ssc.start()
            ssc.awaitTermination()
          }
        }

    mapWithState takes a StateSpec object, which encapsulates the state-management function together with its configuration.
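    Besides wrapping the mapping function, StateSpec exposes a few more knobs. A hedged sketch using the Spark 1.6 builder API and the names from the example above (the timeout value is illustrative):
        val spec = StateSpec
          .function(mappingFunc)      // the state-management function
          .initialState(initialRDD)   // seed the state with an existing RDD[(K, S)]
          .numPartitions(10)          // how many partitions the state RDDs use
          .timeout(Seconds(30))       // keys idle longer than this are timed out

        val stateDstream = wordDstream.mapWithState(spec)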

    The mapWithState method itself simply creates a MapWithStateDStreamImpl object:

        def mapWithState[StateType: ClassTag, MappedType: ClassTag](
            spec: StateSpec[K, V, StateType, MappedType]
          ): MapWithStateDStream[K, V, StateType, MappedType] = {
          new MapWithStateDStreamImpl[K, V, StateType, MappedType](
            self,
            spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
          )
        }
    MapWithStateDStreamImpl creates an InternalMapWithStateDStream object named internalStream; MapWithStateDStreamImpl's compute method calls internalStream's getOrCompute method.
        /** Internal implementation of the [[MapWithStateDStream]] */
        private[streaming] class MapWithStateDStreamImpl[
            KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
            dataStream: DStream[(KeyType, ValueType)],
            spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
          extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {

          private val internalStream =
            new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)

          override def slideDuration: Duration = internalStream.slideDuration

          override def dependencies: List[DStream[_]] = List(internalStream)

          override def compute(validTime: Time): Option[RDD[MappedType]] = {
            internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
          }
          // ... (remainder of the class omitted)
        }

    InternalMapWithStateDStream has no getOrCompute method of its own; the call resolves to its parent class DStream's getOrCompute, which eventually invokes InternalMapWithStateDStream's compute method:

        /** Method that generates a RDD for the given time */
        override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
          // Get the previous state or create a new empty state RDD
          val prevStateRDD = getOrCompute(validTime - slideDuration) match {
            case Some(rdd) =>
              if (rdd.partitioner != Some(partitioner)) {
                // If the RDD is not partitioned the right way, let us repartition it using the
                // partition index as the key. This is to ensure that state RDD is always partitioned
                // before creating another state RDD using it
                MapWithStateRDD.createFromRDD[K, V, S, E](
                  rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
              } else {
                rdd
              }
            case None =>
              MapWithStateRDD.createFromPairRDD[K, V, S, E](
                spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
                partitioner,
                validTime
              )
          }

          // Compute the new state RDD with previous state RDD and partitioned data RDD
          // Even if there is no data RDD, use an empty one to create a new state RDD
          val dataRDD = parent.getOrCompute(validTime).getOrElse {
            context.sparkContext.emptyRDD[(K, V)]
          }
          val partitionedDataRDD = dataRDD.partitionBy(partitioner)
          val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
            (validTime - interval).milliseconds
          }
          Some(new MapWithStateRDD(
            prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
        }
    This generates a MapWithStateRDD for the given batch time. The method first obtains the previous state RDD (prevStateRDD) and the current batch's data RDD (dataRDD), repartitions dataRDD with the state RDD's partitioner to get partitionedDataRDD, and finally hands prevStateRDD, partitionedDataRDD, and the user-defined mappingFunction to a new MapWithStateRDD, which it returns.
    Next, look at MapWithStateRDD's compute method:

        override def compute(
            partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
          val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
          val prevStateRDDIterator = prevStateRDD.iterator(
            stateRDDPartition.previousSessionRDDPartition, context)
          val dataIterator = partitionedDataRDD.iterator(
            stateRDDPartition.partitionedDataRDDPartition, context)
          // prevRecord holds the state data for one partition
          val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
          val newRecord = MapWithStateRDDRecord.updateRecordWithData(
            prevRecord,
            dataIterator,
            mappingFunction,
            batchTime,
            timeoutThresholdTime,
            removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
          )
          Iterator(newRecord)
        }
    A MapWithStateRDDRecord corresponds to a single partition of a MapWithStateRDD:
        private[streaming] case class MapWithStateRDDRecord[K, S, E](
            var stateMap: StateMap[K, S], var mappedData: Seq[E])
    Here stateMap stores the state for every key in the partition, while mappedData stores the values returned by the mapping function.
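    For the word-count example above, one partition's record after a batch might contain something like this (hypothetical contents, shown as comments):
        // stateMap   : "hello" -> 7, "world" -> 3        (cumulative counts, carried across batches)
        // mappedData : Seq(("hello", 7), ("world", 3))   (what the mapping function returned this batch)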
    Now look at MapWithStateRDDRecord.updateRecordWithData:
        def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
            prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
            dataIterator: Iterator[(K, V)],
            mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
            batchTime: Time,
            timeoutThresholdTime: Option[Long],
            removeTimedoutData: Boolean
          ): MapWithStateRDDRecord[K, S, E] = {

          // Create a new state map by copying the previous record's map (if one exists);
          // otherwise create an empty StateMap
          val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }

          val mappedData = new ArrayBuffer[E]
          // reusable wrapper around a single key's state
          val wrappedState = new StateImpl[S]()

          // Call the mapping function on each record in the data iterator, and accordingly
          // update the states touched, and collect the data returned by the mapping function
          dataIterator.foreach { case (key, value) =>
            // fetch this key's current state
            wrappedState.wrap(newStateMap.get(key))
            // call mappingFunction and collect its return value
            val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
            // update newStateMap according to what the mapping function did with the state
            if (wrappedState.isRemoved) {
              newStateMap.remove(key)
            } else if (wrappedState.isUpdated
                || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
              newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
            }
            mappedData ++= returned
          }

          // Get the timed out state records, call the mapping function on each and collect the
          // data returned
          if (removeTimedoutData && timeoutThresholdTime.isDefined) {
            newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
              wrappedState.wrapTimingOutState(state)
              val returned = mappingFunction(batchTime, key, None, wrappedState)
              mappedData ++= returned
              newStateMap.remove(key)
            }
          }

          MapWithStateRDDRecord(newStateMap, mappedData)
        }

    The resulting MapWithStateRDDRecord is handed back to MapWithStateRDD's compute function, which wraps it in an Iterator and returns it. Note that updateRecordWithData iterates only over the current batch's data (plus any timed-out keys), and the previous state map is copied rather than rebuilt, so per-batch cost scales with the size of the batch rather than with the total number of keys in the state. This is the source of mapWithState's performance advantage over updateStateByKey's cogroup-everything approach.
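    One closing usage note: the DStream returned by mapWithState carries only mappedData, i.e. the mapping function's return values for keys touched in each batch. To observe the full accumulated state, MapWithStateDStream also offers stateSnapshots(); applied to the example above:
        // Emits the complete (word, count) contents of the state map on every batch,
        // not just the keys that appeared in the current batch.
        val snapshots: DStream[(String, Int)] = stateDstream.stateSnapshots()
        snapshots.print()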








