zoukankan      html  css  js  c++  java
  • spark源码分析, 任务反序列化及执行

    1. ==> 接收消息, org.apache.spark.executor.CoarseGrainedExecutorBackend#receive

        case LaunchTask(data) =>
          if (executor != null) {
            // Decode the payload into a TaskDescription and hand it to the executor to run.
            val description = TaskDescription.decode(data.value)
            logInfo("Got assigned task " + description.taskId)
            executor.launchTask(this, description)
          } else {
            // Without an executor there is nothing to run the task on: exit with code 1.
            exitExecutor(1, "Received LaunchTask command but executor was null")
          }

    2. ==> org.apache.spark.executor.Executor#launchTask

      // Tasks currently running on this executor, keyed by task id.
      private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]

      /**
       * Wraps `taskDescription` in a [[TaskRunner]], records it in `runningTasks` under its
       * task id, and submits the runner to `threadPool`.
       *
       * @param context         backend the runner reports status through
       * @param taskDescription the decoded task to run
       */
      def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
        val runner = new TaskRunner(context, taskDescription)
        val id = taskDescription.taskId
        runningTasks.put(id, runner)
        threadPool.execute(runner)
      }

    3. ==> org.apache.spark.executor.Executor.TaskRunner#run

    /**
     * Runs one task on an executor thread. It fetches new dependencies, deserializes the task,
     * runs it, serializes the result and reports it to the driver through `execBackend`.
     *
     * The serialized result is returned in one of three ways, depending on its size:
     *  - Larger than `maxResultSize` (when that limit is > 0): the result is dropped, and only
     *    an [[IndirectTaskResult]] carrying its block id and size is sent.
     *  - Larger than `maxDirectResultSize`: the bytes are stored in the BlockManager
     *    (MEMORY_AND_DISK_SER), and an [[IndirectTaskResult]] pointing at that block is sent.
     *  - Otherwise: the serialized [[DirectTaskResult]] itself is sent.
     *
     * NOTE(review): this is an abridged excerpt. No catch clause is shown, so failures are not
     * reported to the driver here; confirm against the full Executor source.
     */
    override def run(): Unit = {
      threadId = Thread.currentThread.getId
      Thread.currentThread.setName(threadName)
      val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
      // Used both to deserialize the task binary and to serialize the result envelope.
      val ser = env.closureSerializer.newInstance()

      try {
        // Download any files/JARs the driver added that this executor does not have yet.
        updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)

        // Deserialize the actual Task object carried by the TaskDescription.
        task = ser.deserialize[Task[Any]](
          taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
        task.localProperties = taskDescription.properties
        task.setTaskMemoryManager(taskMemoryManager)

        // Run the task. The second (finally-style) block releases any block locks and managed
        // memory the task still holds. Leftovers are only reported when the task succeeded,
        // since a failed task is expected to leave them behind.
        var threwException = true
        val value = Utils.tryWithSafeFinally {
          val res = task.run(
            taskAttemptId = taskId,
            attemptNumber = taskDescription.attemptNumber,
            metricsSystem = env.metricsSystem)
          threwException = false
          res
        } {
          val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
          val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
          if (freedMemory > 0 && !threwException) {
            logWarning(s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId")
          }
          if (releasedLocks.nonEmpty && !threwException) {
            logInfo(s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
              releasedLocks.mkString("[", ", ", "]"))
          }
        }

        // Serialize the task's return value and record how long that took.
        val resultSer = env.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()
        task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)

        // Note: accumulator updates must be collected after TaskMetrics is updated
        val accumUpdates = task.collectAccumulatorUpdates()
        // TODO: do not serialize value twice
        val directResult = new DirectTaskResult(valueBytes, accumUpdates)
        val serializedDirectResult = ser.serialize(directResult)
        val resultSize = serializedDirectResult.limit()

        // directSend = sending directly back to the driver
        val serializedResult: ByteBuffer = {
          if (maxResultSize > 0 && resultSize > maxResultSize) {
            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
              s"dropping it.")
            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
          } else if (resultSize > maxDirectResultSize) {
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId,
              new ChunkedByteBuffer(serializedDirectResult.duplicate()),
              StorageLevel.MEMORY_AND_DISK_SER)
            logInfo(s"Finished $taskName (TID $taskId). " +
              s"$resultSize bytes result sent via BlockManager")
            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
          } else {
            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
            serializedDirectResult
          }
        }

        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
      } finally {
        // launchTask put this runner into runningTasks; drop it however run() exits so the
        // map does not grow without bound.
        runningTasks.remove(taskId)
      }
    }

    ==> org.apache.spark.executor.Executor#updateDependencies

      /**
       * Download any missing dependencies if we receive a new set of files and JARs from the
       * SparkContext. Also adds any new JARs we fetched to the class loader.
       *
       * @param newFiles file URI -> timestamp. A file is fetched when its timestamp is newer
       *                 than the one recorded in `currentFiles`.
       * @param newJars  JAR URI -> timestamp. A JAR is fetched when its timestamp is newer
       *                 than the one recorded in `currentJars` under either its full URI or
       *                 its file name. It is then added to `urlClassLoader` if not already
       *                 present.
       */
      private def updateDependencies(
          newFiles: Map[String, Long],
          newJars: Map[String, Long]): Unit = {
        // Only built if something actually has to be fetched.
        lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)

        // Download `name` into the SparkFiles root directory.
        // Fetch file with useCache mode, close cache for local mode.
        def fetch(name: String, timestamp: Long): Unit = {
          logInfo("Fetching " + name + " with timestamp " + timestamp)
          Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
            env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
        }

        synchronized {
          // Fetch missing dependencies
          for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
            fetch(name, timestamp)
            currentFiles(name) = timestamp
          }
          for ((name, timestamp) <- newJars) {
            // A JAR may be recorded under its full URI or under just its file name.
            val localName = new URI(name).getPath.split("/").last
            val currentTimeStamp = currentJars.get(name)
              .orElse(currentJars.get(localName))
              .getOrElse(-1L)
            if (currentTimeStamp < timestamp) {
              fetch(name, timestamp)
              currentJars(name) = timestamp
              // Add it to our class loader
              val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
              if (!urlClassLoader.getURLs().contains(url)) {
                logInfo("Adding " + url + " to class loader")
                urlClassLoader.addURL(url)
              }
            }
          }
        }
      }

    ==> org.apache.spark.scheduler.Task#run

     /**
      * Runs this task attempt on the calling thread.
      *
      * The method does the following, in order:
      *  - registers the attempt with the BlockManager;
      *  - builds its TaskContext (a BarrierTaskContext for barrier tasks);
      *  - installs that context and records the running thread;
      *  - applies a kill that was already requested;
      *  - sets the caller context;
      *  - delegates to the subclass's `runTask`.
      *
      * @param taskAttemptId id of this task attempt
      * @param attemptNumber attempt number, passed to the TaskContext and the CallerContext
      * @param metricsSystem metrics system, passed to the TaskContext
      * @return the value produced by `runTask`
      */
     final def run(
          taskAttemptId: Long,
          attemptNumber: Int,
          metricsSystem: MetricsSystem): T = {
        // Register this attempt with the local BlockManager before doing any work.
        SparkEnv.get.blockManager.registerTask(taskAttemptId)
    
    
        val taskContext = new TaskContextImpl(
          stageId,
          stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
          partitionId,
          taskAttemptId,
          attemptNumber,
          taskMemoryManager,
          localProperties,
          metricsSystem,
          metrics)
    
        // Barrier tasks wrap the plain TaskContextImpl in a BarrierTaskContext.
        context = if (isBarrier) {
          new BarrierTaskContext(taskContext)
        } else {
          taskContext
        }
    
        // Install `context` as the current TaskContext and remember which thread runs the task.
        TaskContext.setTaskContext(context)
        taskThread = Thread.currentThread()
    
        // A kill reason may already be recorded. If so, act on it now without interrupting
        // the thread.
        if (_reasonIfKilled != null) {
          kill(interruptThread = false, _reasonIfKilled)
        }
    
        // Attach a "TASK" caller context identifying the app, job, stage and attempt.
        new CallerContext(
          "TASK",
          SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
          appId,
          appAttemptId,
          jobId,
          Option(stageId),
          Option(stageAttemptId),
          Option(taskAttemptId),
          Option(attemptNumber)).setCurrentContext()
    
        try {
        // runTask is abstract here (Task is only a template); the concrete implementations
        // are ResultTask and ShuffleMapTask.
        // NOTE(review): this excerpt shows no catch/finally, so no failure handling or
        // TaskContext cleanup is visible here. Confirm against the full Task.run source.
          runTask(context)
        } 
      }

    ==> org.apache.spark.scheduler.ShuffleMapTask#runTask

    ShuffleMapTask将rdd的元素,切分为多个bucket, 基于ShuffleDependency指定的partitioner,默认就是HashPartitioner

    ShuffleMapTask 核心方法是 RDD.iterator[底层调用 compute 方法(fn(context,index,partition))],

    执行完成rdd之后,rdd会返回处理过后的partition数据,这些数据通过shuffleWriter再经过HashPartitioner写入对应的分区中

    // ShuffleMapTask splits the RDD's elements into multiple buckets,
    // according to the partitioner specified by the ShuffleDependency (HashPartitioner by default).
    private[spark] class ShuffleMapTask(
       ...
       // ShuffleMapTask's runTask returns a MapStatus.
      override def runTask(context: TaskContext): MapStatus = {
        // Deserialize the RDD using the broadcast variable.
        val threadMXBean = ManagementFactory.getThreadMXBean
        val deserializeStartTime = System.currentTimeMillis()
        val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L
    
        // Deserialize what this task needs in order to process its data, and record how long
        // that took (wall clock, plus thread CPU time where supported).
     
        val ser = SparkEnv.get.closureSerializer.newInstance()
        // Obtain the RDD together with its ShuffleDependency.
        val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
          ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
        _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
        _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
        } else 0L
    
        var writer: ShuffleWriter[Any, Any] = null
        try {
          // Get the ShuffleManager.
          val manager = SparkEnv.get.shuffleManager
          // Get a ShuffleWriter for this task's partition.
          writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    
          // Core logic: call rdd.iterator for the partition this task handles. The records it
          // returns are written by the ShuffleWriter, which places them (via the dependency's
          // partitioner, HashPartitioner by default) into the corresponding shuffle partitions.
          
          writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
    
          // Return the MapStatus, which records where this ShuffleMapTask's output is stored
          // (essentially BlockManager information). `.get` presumes stop(success = true)
          // returns a non-empty Option; verify.
          // NOTE(review): no catch/finally is shown, so failure handling for `writer` is not
          // visible in this excerpt.
          writer.stop(success = true).get
        } 
      }
      ...
    }

    ==> org.apache.spark.scheduler.ResultTask#runTask

      /**
       * Deserializes the (RDD, function) pair shipped in `taskBinary` and records how long
       * that took. The time is measured as wall-clock milliseconds and, when the JVM supports
       * it, as thread CPU time. It then applies the user function to this task's partition.
       *
       * @param context context of the running task, passed to both the iterator and `func`
       * @return whatever the user-supplied function returns for this partition
       */
      override def runTask(context: TaskContext): U = {
        val mxBean = ManagementFactory.getThreadMXBean
        val startWallMs = System.currentTimeMillis()
        val cpuTimeSupported = mxBean.isCurrentThreadCpuTimeSupported
        val startCpuNs = if (cpuTimeSupported) mxBean.getCurrentThreadCpuTime else 0L

        // Deserialize the RDD and the func using the broadcast variables.
        val closureSer = SparkEnv.get.closureSerializer.newInstance()
        val (rdd, func) = closureSer.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
          ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

        _executorDeserializeTime = System.currentTimeMillis() - startWallMs
        _executorDeserializeCpuTime =
          if (cpuTimeSupported) mxBean.getCurrentThreadCpuTime - startCpuNs else 0L

        // Hand the partition's iterator straight to the user-defined function.
        val partitionIter = rdd.iterator(partition, context)
        func(context, partitionIter)
      }

    ==> org.apache.spark.rdd.RDD#iterator

     /**
      * Returns an iterator over the records of partition `split`.
      *
      * If this RDD's storage level is anything other than `StorageLevel.NONE`, the partition is
      * obtained through `getOrCompute`, presumably reusing a stored copy when one exists.
      * Otherwise nothing is stored, and the partition is computed (or read from a checkpoint)
      * via `computeOrReadCheckpoint`.
      *
      * @param split   the partition to read
      * @param context context of the task doing the read
      */
     final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
       if (storageLevel != StorageLevel.NONE) {
         // The RDD has a storage level: go through the storage-aware path.
         getOrCompute(split, context)
       } else {
         // Storage level NONE: the result does not need to be stored.
         computeOrReadCheckpoint(split, context)
       }
     }

    ==> org.apache.spark.rdd.RDD#computeOrReadCheckpoint

      /**
       * Produces the records of partition `split`. If this RDD has been checkpointed and the
       * checkpoint is materialized, they come from `firstParent[T].iterator`. Otherwise they
       * are produced by this RDD's own `compute`.
       *
       * @param split   the partition to produce
       * @param context context of the task doing the work
       */
      private[spark] def computeOrReadCheckpoint(
          split: Partition,
          context: TaskContext): Iterator[T] =
        if (!isCheckpointedAndMaterialized) {
          // Core path. `compute` is implemented by each concrete RDD subclass,
          // e.g. MapPartitionsRDD or JdbcRDD.
          compute(split, context)
        } else {
          firstParent[T].iterator(split, context)
        }

    demo: 

    /**
     * An RDD that applies `f` to every partition of its parent RDD `prev`.
     *
     * @param prev                  the parent RDD; released by `clearDependencies`
     * @param f                     (TaskContext, partition index, parent iterator) => output
     * @param preservesPartitioning true if `f` keeps every record's key unchanged, so the
     *                              parent's partitioner remains valid for this RDD
     * @param isFromBarrier         NOTE(review): not used in this excerpt
     * @param isOrderSensitive      NOTE(review): not used in this excerpt
     */
    class MapPartitionsRDD[U: ClassTag, T: ClassTag](
        var prev: RDD[T],
        f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
        preservesPartitioning: Boolean = false,
        isFromBarrier: Boolean = false,
        isOrderSensitive: Boolean = false)
      extends RDD[U](prev) {

      // Only inherit the parent's partitioner when the caller promises `f` does not move keys.
      override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

      // One-to-one: each output partition corresponds to the parent partition with the same index.
      override def getPartitions: Array[Partition] = firstParent[T].partitions

      override def compute(split: Partition, context: TaskContext): Iterator[U] =
        f(context, split.index, firstParent[T].iterator(split, context))

      // Drop the reference to the parent as well, so it can be garbage collected.
      override def clearDependencies(): Unit = {
        super.clearDependencies()
        prev = null
      }
    }
    
    
    /**
     * An RDD that runs `sql` over a JDBC connection and converts each row with `mapRow`.
     *
     * The id range `[lowerBound, upperBound]` (both ends inclusive) is split into
     * `numPartitions` contiguous sub-ranges, one per partition. `sql` must contain two `?`
     * placeholders, which are bound to a partition's lower and upper bound, e.g.
     * `select title, author from books where ? <= id and id <= ?`.
     *
     * @param getConnection factory for a connection; called once each time a partition is computed
     * @param mapRow        converts the current row of the ResultSet into a `T`
     */
    class JdbcRDD[T: ClassTag](
        sc: SparkContext,
        getConnection: () => Connection,
        sql: String,
        lowerBound: Long,
        upperBound: Long,
        numPartitions: Int,
        mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
      extends RDD[T](sc, Nil) with Logging {

      override def getPartitions: Array[Partition] = {
        // bounds are inclusive, hence the + 1 here and - 1 on end
        val length = BigInt(1) + upperBound - lowerBound
        (0 until numPartitions).map { i =>
          val start = lowerBound + ((i * length) / numPartitions)
          val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
          new JdbcPartition(i, start.toLong, end.toLong)
        }.toArray
      }

      override def compute(thePart: Partition, context: TaskContext): Iterator[T] =
        new NextIterator[T] {
          // Close the JDBC resources when the task completes, whether or not all rows were read.
          context.addTaskCompletionListener[Unit] { _ => closeIfNeeded() }
          val part = thePart.asInstanceOf[JdbcPartition]
          val conn = getConnection()
          val stmt =
            conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

          // DatabaseMetaData.getURL may return null when the driver cannot build a URL.
          val isMySql = Option(conn.getMetaData.getURL).exists(_.startsWith("jdbc:mysql:"))
          if (isMySql) {
            // MySQL Connector/J only streams rows (instead of buffering the whole result set
            // in memory) when the fetch size is Integer.MIN_VALUE.
            stmt.setFetchSize(Integer.MIN_VALUE)
          } else {
            stmt.setFetchSize(100)
          }
          logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")

          // Bind this partition's inclusive id range to the two `?` placeholders in `sql`.
          stmt.setLong(1, part.lower)
          stmt.setLong(2, part.upper)
          val rs = stmt.executeQuery()

          override def getNext(): T = {
            if (rs.next()) {
              mapRow(rs)
            } else {
              finished = true
              null.asInstanceOf[T]
            }
          }

          // Release the result set, statement and connection. Any of them may still be null
          // if construction above failed part-way, and a failure closing one must not
          // prevent closing the others.
          override def close(): Unit = {
            closeQuietly(rs, "resultset")
            closeQuietly(stmt, "statement")
            closeQuietly(conn, "connection")
          }
        }

      // Close `resource` if it was created; non-fatal failures are logged, not rethrown.
      private def closeQuietly(resource: AutoCloseable, what: String): Unit = {
        import scala.util.control.NonFatal
        Option(resource).foreach { r =>
          try r.close()
          catch { case NonFatal(e) => logWarning(s"Exception closing $what", e) }
        }
      }
    }
  • 相关阅读:
    日志模块
    模块介绍3
    模块介绍2
    模块介绍
    迭代器
    Python装饰器续/三元表达式/匿名函数
    Python装饰器详解
    LATEX LIAN XI
    BELLMAN 最短路算法
    B阿狸和桃子的游戏
  • 原文地址:https://www.cnblogs.com/snow-man/p/13555013.html
Copyright © 2011-2022 走看看