  • Spark RDD Internal Computation Mechanism

    The Partitions of an RDD are processed by different Tasks. There are two kinds of Task: ShuffleMapTask and ResultTask.

    1. Task overview

    A Task is the basic unit of computation: one Task processes one Partition of an RDD. Tasks run on an Executor, and the Executor lives inside the CoarseGrainedExecutorBackend process.

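    Spark classifies a Task by the position of the Stage it belongs to:

    ShuffleMapTask: the Task's Stage is not the last Stage.

    ResultTask: the Task's Stage is the last Stage.

    In other words, the Tasks of the last Stage are ResultTasks, and all other Tasks are ShuffleMapTasks.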
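    As a rough illustration of this rule, the Scala sketch below plans one Task per Partition for a simple linear chain of Stages and labels the Tasks of the last Stage as ResultTasks. The types (TaskKind, ToyTask) and the linear-chain assumption are hypothetical; real Stages form a DAG built by the DAGScheduler.

```scala
object TaskKindSketch {
  sealed trait TaskKind
  case object ShuffleMapTaskKind extends TaskKind
  case object ResultTaskKind extends TaskKind

  // Hypothetical task record: its stage, the partition it computes, and its kind.
  final case class ToyTask(stageId: Int, partitionId: Int, kind: TaskKind)

  /** stagePartitions(i) is the partition count of stage i; the last entry is the final stage. */
  def planTasks(stagePartitions: Seq[Int]): Seq[ToyTask] = {
    val lastStage = stagePartitions.length - 1
    for {
      (numPartitions, stageId) <- stagePartitions.zipWithIndex
      partitionId <- 0 until numPartitions // one task per partition
    } yield {
      // Only the final stage's tasks are ResultTasks; all others are ShuffleMapTasks.
      val kind = if (stageId == lastStage) ResultTaskKind else ShuffleMapTaskKind
      ToyTask(stageId, partitionId, kind)
    }
  }
}
```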

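    2. Computation process

    Before an RDD is computed, the Driver sends the Executor a message telling it to launch a Task; once the Task has been started, the Executor reports back to the Driver. The detailed steps are as follows:

    1. CoarseGrainedSchedulerBackend in the Driver sends a LaunchTask message to CoarseGrainedExecutorBackend.

    The receive method of CoarseGrainedExecutorBackend, which handles this message, is as follows: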

    override def receive: PartialFunction[Any, Unit] = {
        case RegisteredExecutor =>
          logInfo("Successfully registered with driver")
          try {
            executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
          } catch {
            case NonFatal(e) =>
              exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
          }
    
        case RegisterExecutorFailed(message) =>
          exitExecutor(1, "Slave registration failed: " + message)
    
        case LaunchTask(data) =>
          if (executor == null) {
            exitExecutor(1, "Received LaunchTask command but executor was null")
          } else {
            // Deserialize the TaskDescription
            val taskDesc = TaskDescription.decode(data.value)
            logInfo("Got assigned task " + taskDesc.taskId)
            executor.launchTask(this, taskDesc)
          }
    
        case KillTask(taskId, _, interruptThread, reason) =>
          if (executor == null) {
            exitExecutor(1, "Received KillTask command but executor was null")
          } else {
            executor.killTask(taskId, interruptThread, reason)
          }
    
        case StopExecutor =>
          stopping.set(true)
          logInfo("Driver commanded a shutdown")
          // Cannot shutdown here because an ack may need to be sent back to the caller. So send
          // a message to self to actually do the shutdown.
          self.send(Shutdown)
    
        case Shutdown =>
          stopping.set(true)
          new Thread("CoarseGrainedExecutorBackend-stop-executor") {
            override def run(): Unit = {
              // executor.stop() will call `SparkEnv.stop()` which waits until RpcEnv stops totally.
              // However, if `executor.stop()` runs in some thread of RpcEnv, RpcEnv won't be able to
              // stop until `executor.stop()` returns, which becomes a dead-lock (See SPARK-14180).
              // Therefore, we put this line in a new thread.
              executor.stop()
            }
          }.start()
      }
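    The receive method is a PartialFunction that pattern-matches on message types. The sketch below mimics that dispatch with toy messages; RegisteredExecutor, LaunchTask and KillTask here are stand-in case classes, not Spark's, and "exit" is only recorded in a log instead of terminating the process.

```scala
import scala.collection.mutable.ArrayBuffer

object ReceiveSketch {
  // Stand-in messages; the names echo the cases in receive above but are not Spark's classes.
  sealed trait Msg
  case object RegisteredExecutor extends Msg
  final case class LaunchTask(taskId: Long) extends Msg
  final case class KillTask(taskId: Long) extends Msg

  final class ToyExecutorBackend {
    val log = ArrayBuffer.empty[String]
    private var executorReady = false // plays the role of the `executor == null` check

    // Like an RPC endpoint's receive: a PartialFunction from messages to side effects.
    val receive: PartialFunction[Msg, Unit] = {
      case RegisteredExecutor =>
        executorReady = true
        log += "registered"
      case LaunchTask(id) =>
        if (executorReady) log += s"launch $id"
        else log += "exit: LaunchTask received but executor was null"
      case KillTask(id) =>
        if (executorReady) log += s"kill $id"
        else log += "exit: KillTask received but executor was null"
    }
  }
}
```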
    

    The TaskDescription.decode method is as follows:

    def decode(byteBuffer: ByteBuffer): TaskDescription = {
        val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer))
        val taskId = dataIn.readLong()
        val attemptNumber = dataIn.readInt()
        val executorId = dataIn.readUTF()
        val name = dataIn.readUTF()
        val index = dataIn.readInt()
    
        // Read files.
        val taskFiles = deserializeStringLongMap(dataIn)
    
        // Read jars.
        val taskJars = deserializeStringLongMap(dataIn)
    
        // Read properties.
        val properties = new Properties()
        val numProperties = dataIn.readInt()
        for (i <- 0 until numProperties) {
          val key = dataIn.readUTF()
          val valueLength = dataIn.readInt()
          val valueBytes = new Array[Byte](valueLength)
          dataIn.readFully(valueBytes)
          properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8))
        }
    
        // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later).
        val serializedTask = byteBuffer.slice()
    
        new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars,
          properties, serializedTask)
      }
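    decode reads the header fields in a fixed order, reads each property value as a length-prefixed UTF-8 byte array, and then calls slice() so that the remaining bytes become the serialized task. The round-trip sketch below assumes a reduced descriptor (MiniTaskDesc, with no files, jars or executorId) and a hand-written BufferStream in place of Spark's ByteBufferInputStream; both are hypothetical.

```scala
import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream, InputStream}
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

object TaskDescCodecSketch {
  // Reduced, hypothetical descriptor: no files, jars or executorId.
  final case class MiniTaskDesc(taskId: Long, attemptNumber: Int, name: String,
                                properties: Map[String, String], payload: Array[Byte])

  // Minimal InputStream that consumes (and advances) a ByteBuffer,
  // standing in for Spark's ByteBufferInputStream.
  private final class BufferStream(buf: ByteBuffer) extends InputStream {
    override def read(): Int = if (buf.hasRemaining) buf.get() & 0xff else -1
  }

  def encode(d: MiniTaskDesc): ByteBuffer = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    out.writeLong(d.taskId)
    out.writeInt(d.attemptNumber)
    out.writeUTF(d.name)
    out.writeInt(d.properties.size)
    d.properties.foreach { case (key, value) =>
      val valueBytes = value.getBytes(StandardCharsets.UTF_8)
      out.writeUTF(key)
      out.writeInt(valueBytes.length) // length prefix, as decode expects
      out.write(valueBytes)
    }
    out.write(d.payload) // the serialized task goes last, without a length prefix
    out.flush()
    ByteBuffer.wrap(bytes.toByteArray)
  }

  def decode(buf: ByteBuffer): MiniTaskDesc = {
    val in = new DataInputStream(new BufferStream(buf))
    val taskId = in.readLong()
    val attemptNumber = in.readInt()
    val name = in.readUTF()
    val numProperties = in.readInt()
    val properties = (0 until numProperties).map { _ =>
      val key = in.readUTF()
      val valueBytes = new Array[Byte](in.readInt())
      in.readFully(valueBytes)
      key -> new String(valueBytes, StandardCharsets.UTF_8)
    }.toMap
    // The header has been consumed, so slice() covers exactly the payload bytes.
    val rest = buf.slice()
    val payload = new Array[Byte](rest.remaining())
    rest.get(payload)
    MiniTaskDesc(taskId, attemptNumber, name, properties, payload)
  }
}
```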
    

    2. The Executor calls launchTask to run the Task.

    3. launchTask creates a TaskRunner instance and submits it to threadPool, which runs the actual Task. The source is as follows:

    def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
        // Create the TaskRunner
        val tr = new TaskRunner(context, taskDescription)
        // Register the TaskRunner in the runningTasks map, keyed by task ID
        runningTasks.put(taskDescription.taskId, tr)
        // Hand the TaskRunner to a thread from the pool
        threadPool.execute(tr)
      }
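    The same register-then-submit pattern can be sketched with plain java.util.concurrent classes. MiniExecutor and its members are hypothetical. The excerpt above does not show a TaskRunner removing itself from runningTasks, so that cleanup step is an assumption of this sketch.

```scala
import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

object LaunchTaskSketch {
  final class MiniExecutor(threads: Int) {
    val runningTasks = new ConcurrentHashMap[Long, Runnable]()
    val completed = new AtomicInteger(0)
    private val threadPool: ExecutorService = Executors.newFixedThreadPool(threads)

    private final class TaskRunner(taskId: Long, body: () => Unit) extends Runnable {
      override def run(): Unit =
        try body()
        finally {
          runningTasks.remove(taskId) // assumed cleanup: deregister once the task ends
          completed.incrementAndGet()
        }
    }

    def launchTask(taskId: Long, body: () => Unit): Unit = {
      val tr = new TaskRunner(taskId, body)
      runningTasks.put(taskId, tr) // register first, so the task can be looked up (e.g. to kill it)
      threadPool.execute(tr)       // then let a pool thread run it
    }

    def shutdownAndWait(): Boolean = {
      threadPool.shutdown()
      threadPool.awaitTermination(10, TimeUnit.SECONDS)
    }
  }
}
```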
    

    TaskRunner implements the Runnable interface; its run method is as follows:

    override def run(): Unit = {
          threadId = Thread.currentThread.getId
          Thread.currentThread.setName(threadName)
          val threadMXBean = ManagementFactory.getThreadMXBean
          val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
          val deserializeStartTime = System.currentTimeMillis()
          val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
            threadMXBean.getCurrentThreadCpuTime
          } else 0L
          Thread.currentThread.setContextClassLoader(replClassLoader)
          val ser = env.closureSerializer.newInstance()
          logInfo(s"Running $taskName (TID $taskId)")
          // Tell the Driver that the task is RUNNING
          execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
          var taskStart: Long = 0
          var taskStartCpu: Long = 0
          startGCTime = computeTotalGcTime()
    
          try {
            // Must be set before updateDependencies() is called, in case fetching dependencies
            // requires access to properties contained within (e.g. for access control).
            Executor.taskDeserializationProps.set(taskDescription.properties)
    
            updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
            // Deserialize the task
            task = ser.deserialize[Task[Any]](
              taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
            task.localProperties = taskDescription.properties
            task.setTaskMemoryManager(taskMemoryManager)
    
            // If this task has been killed before we deserialized it, let's quit now. Otherwise,
            // continue executing the task.
            val killReason = reasonIfKilled
            if (killReason.isDefined) {
              // Throw an exception rather than returning, because returning within a try{} block
              // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
              // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
              // for the task.
              throw new TaskKilledException(killReason.get)
            }
    
            logDebug("Task " + taskId + "'s epoch is " + task.epoch)
            env.mapOutputTracker.updateEpoch(task.epoch)
    
            // Run the actual task and measure its runtime.
            // Record the start time
            taskStart = System.currentTimeMillis()
            taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
              threadMXBean.getCurrentThreadCpuTime
            } else 0L
            var threwException = true
            // Call the Task's run method
            val value = try {
              val res = task.run(
                taskAttemptId = taskId,
                attemptNumber = taskDescription.attemptNumber,
                metricsSystem = env.metricsSystem)
              threwException = false
              res
            } finally {
              val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
              val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
    
              if (freedMemory > 0 && !threwException) {
                val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
                if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
                  throw new SparkException(errMsg)
                } else {
                  logWarning(errMsg)
                }
              }
    
              if (releasedLocks.nonEmpty && !threwException) {
                val errMsg =
                  s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
                    releasedLocks.mkString("[", ", ", "]")
                if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
                  throw new SparkException(errMsg)
                } else {
                  logInfo(errMsg)
                }
              }
            }
            task.context.fetchFailed.foreach { fetchFailure =>
              // uh-oh.  it appears the user code has caught the fetch-failure without throwing any
              // other exceptions.  Its *possible* this is what the user meant to do (though highly
              // unlikely).  So we will log an error and keep going.
              logError(s"TID ${taskId} completed successfully though internally it encountered " +
                s"unrecoverable fetch failures!  Most likely this means user code is incorrectly " +
                s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
            }
            // Record the finish time
            val taskFinish = System.currentTimeMillis()
            val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
              threadMXBean.getCurrentThreadCpuTime
            } else 0L
    .....
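    The bookkeeping in run (report RUNNING, measure wall-clock time and, when the JVM supports it, thread CPU time, then report the outcome) can be reduced to the sketch below. The state strings mirror TaskState names, but runWithStatus and Metrics are hypothetical helpers; the real method also handles kills, fetch failures and result serialization, which the excerpt elides.

```scala
import java.lang.management.ManagementFactory
import scala.collection.mutable.ListBuffer
import scala.util.control.NonFatal

object TaskRunnerTimingSketch {
  final case class Metrics(wallMillis: Long, cpuNanos: Long)

  def runWithStatus[T](statusUpdates: ListBuffer[String])(body: => T): (Option[T], Metrics) = {
    val mx = ManagementFactory.getThreadMXBean
    // Same guard as TaskRunner: fall back to 0 when thread CPU time is unsupported.
    def cpuNow(): Long =
      if (mx.isCurrentThreadCpuTimeSupported) mx.getCurrentThreadCpuTime else 0L

    statusUpdates += "RUNNING"
    val wallStart = System.currentTimeMillis()
    val cpuStart = cpuNow()
    val result =
      try {
        val v = body
        statusUpdates += "FINISHED"
        Some(v)
      } catch {
        case NonFatal(_) =>
          statusUpdates += "FAILED"
          None
      }
    val metrics = Metrics(System.currentTimeMillis() - wallStart, cpuNow() - cpuStart)
    (result, metrics)
  }
}
```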
    

    Task.run calls the runTask method; its source is as follows:

    /**
       * Called by [[org.apache.spark.executor.Executor]] to run this task.
       *
       * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
       * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
       * @return the result of the task along with updates of Accumulators.
       */
      final def run(
          taskAttemptId: Long,
          attemptNumber: Int,
          metricsSystem: MetricsSystem): T = {
        SparkEnv.get.blockManager.registerTask(taskAttemptId)
        context = new TaskContextImpl(
          stageId,
          partitionId,
          taskAttemptId,
          attemptNumber,
          taskMemoryManager,
          localProperties,
          metricsSystem,
          metrics)
        TaskContext.setTaskContext(context)
        taskThread = Thread.currentThread()
    
        if (_reasonIfKilled != null) {
          kill(interruptThread = false, _reasonIfKilled)
        }
    
        new CallerContext(
          "TASK",
          SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
          appId,
          appAttemptId,
          jobId,
          Option(stageId),
          Option(stageAttemptId),
          Option(taskAttemptId),
          Option(attemptNumber)).setCurrentContext()
    
        try {
          runTask(context)
        } catch {
          case e: Throwable =>
            // Catch all errors; run task failure callbacks, and rethrow the exception.
            try {
              context.markTaskFailed(e)
            } catch {
              case t: Throwable =>
                e.addSuppressed(t)
            }
            context.markTaskCompleted(Some(e))
            throw e
        } finally {
          try {
            // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
            // one is no-op.
            context.markTaskCompleted(None)
          } finally {
            try {
              Utils.tryLogNonFatalError {
                // Release memory used by this thread for unrolling blocks
                SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
                SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
                  MemoryMode.OFF_HEAP)
                // Notify any tasks waiting for execution memory to be freed to wake up and try to
                // acquire memory again. This makes impossible the scenario where a task sleeps forever
                // because there are no other tasks left to notify it. Since this is safe to do but may
                // not be strictly necessary, we should revisit whether we can remove this in the
                // future.
                val memoryManager = SparkEnv.get.memoryManager
                memoryManager.synchronized { memoryManager.notifyAll() }
              }
            } finally {
              // Though we unset the ThreadLocal here, the context member variable itself is still
              // queried directly in the TaskRunner to check for FetchFailedExceptions.
              TaskContext.unset()
            }
          }
        }
      }
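    The essential control flow of Task.run is: call runTask, on any error run the failure callbacks (attaching their own errors as suppressed exceptions), and make sure the completion callbacks fire exactly once. The sketch below models that with a hypothetical ToyTaskContext whose second markTaskCompleted call is a no-op, as the comment in the finally block describes.

```scala
import scala.collection.mutable.ArrayBuffer

object TaskRunSketch {
  // Hypothetical context; not Spark's TaskContextImpl.
  final class ToyTaskContext {
    val events = ArrayBuffer.empty[String]
    private var completed = false
    var onFailure: Throwable => Unit = _ => ()

    def markTaskFailed(e: Throwable): Unit = {
      events += s"failed: ${e.getMessage}"
      onFailure(e) // a failure listener may itself throw
    }

    def markTaskCompleted(error: Option[Throwable]): Unit =
      if (!completed) { // second and later calls do nothing
        completed = true
        events += s"completed (error = ${error.isDefined})"
      }
  }

  def run[T](context: ToyTaskContext)(runTask: ToyTaskContext => T): T =
    try runTask(context)
    catch {
      case e: Throwable =>
        try context.markTaskFailed(e)
        catch { case t: Throwable => e.addSuppressed(t) }
        context.markTaskCompleted(Some(e))
        throw e
    } finally {
      context.markTaskCompleted(None) // no-op if the catch block already completed the task
    }
}
```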
    

    runTask has two implementations, one in ShuffleMapTask and one in ResultTask:

    1. ShuffleMapTask

    Its runTask method is as follows:

    override def runTask(context: TaskContext): MapStatus = {
        // Deserialize the RDD using the broadcast variable.
        val threadMXBean = ManagementFactory.getThreadMXBean
        val deserializeStartTime = System.currentTimeMillis()
        val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L
        // Create a serializer instance
        val ser = SparkEnv.get.closureSerializer.newInstance()
        // Deserialize the RDD and its ShuffleDependency
        val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
          ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
        _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
        _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
        } else 0L
    
        // A ShuffleWriter obtained from the ShuffleManager writes the computed results to shuffle output
        var writer: ShuffleWriter[Any, Any] = null
        try {
          // Get the ShuffleManager from SparkEnv
          val manager = SparkEnv.get.shuffleManager
          writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
          // Write this partition's records
          writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
          writer.stop(success = true).get
        } catch {
          case e: Exception =>
            try {
              if (writer != null) {
                writer.stop(success = false)
              }
            } catch {
              case e: Exception =>
                log.debug("Could not stop writer", e)
            }
            throw e
        }
      }
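    For a hash-partitioned shuffle, the writer's core job is to route each (key, value) record to a reduce partition and summarize how much went to each one. The sketch below keeps buckets in memory instead of files and uses record counts where Spark's MapStatus tracks byte sizes; ShuffleWriteSketch is hypothetical, but the partitioning rule (null keys go to partition 0, otherwise a non-negative key.hashCode modulo numPartitions) is the one HashPartitioner uses.

```scala
object ShuffleWriteSketch {
  // HashPartitioner's rule: null keys go to 0, otherwise a non-negative hashCode mod.
  def partitionFor(key: Any, numPartitions: Int): Int =
    if (key == null) 0
    else {
      val rawMod = key.hashCode % numPartitions
      if (rawMod < 0) rawMod + numPartitions else rawMod
    }

  /** Routes records to in-memory buckets and returns a record count per reduce partition,
    * a toy stand-in for the per-reducer sizes reported in a MapStatus. */
  def write[K, V](records: Iterator[(K, V)], numPartitions: Int): (Vector[Vector[(K, V)]], Array[Long]) = {
    val buckets = Vector.fill(numPartitions)(Vector.newBuilder[(K, V)])
    val counts = new Array[Long](numPartitions)
    records.foreach { case kv @ (key, _) =>
      val p = partitionFor(key, numPartitions)
      buckets(p) += kv
      counts(p) += 1
    }
    (buckets.map(_.result()), counts)
  }
}
```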
    

    It first deserializes the RDD and its dependency, then calls rdd.iterator to compute the partition:

    /**
       * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
       * This should ''not'' be called by users directly, but is available for implementors of custom
       * subclasses of RDD.
       */
      final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
        if (storageLevel != StorageLevel.NONE) {
          getOrCompute(split, context)
        } else {
          computeOrReadCheckpoint(split, context)
        }
      }
    

    The computeOrReadCheckpoint method is as follows:

    /**
       * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
       */
      private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
      {
        if (isCheckpointedAndMaterialized) {
          firstParent[T].iterator(split, context)
        } else {
          compute(split, context)
        }
      }
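    Together, iterator and computeOrReadCheckpoint give a three-way lookup order: a persisted RDD serves the partition from storage (computing it once on a miss), a checkpointed RDD reads its checkpoint data, and anything else calls compute. The toy model below (ToyRDD, its in-memory cache map and checkpointData are all hypothetical) counts compute calls to make that order visible.

```scala
import scala.collection.mutable

object IteratorSketch {
  final class ToyRDD(data: Int => Seq[Int], val persisted: Boolean) {
    var computeCalls = 0
    var checkpointData: Option[Map[Int, Seq[Int]]] = None
    private val cache = mutable.Map.empty[Int, Seq[Int]]

    def compute(split: Int): Iterator[Int] = {
      computeCalls += 1
      data(split).iterator
    }

    // Checkpointed data wins over recomputation.
    private def computeOrReadCheckpoint(split: Int): Iterator[Int] =
      checkpointData match {
        case Some(saved) => saved(split).iterator
        case None        => compute(split)
      }

    // Persisted RDDs go through the cache; the fallback runs only on a cache miss.
    def iterator(split: Int): Iterator[Int] =
      if (persisted) cache.getOrElseUpdate(split, computeOrReadCheckpoint(split).toVector).iterator
      else computeOrReadCheckpoint(split)
  }
}
```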
    

    The compute method it falls back to is overridden by each RDD subclass, for example MapPartitionsRDD:

    /**
     * An RDD that applies the provided function to every partition of the parent RDD.
     */
    private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
        var prev: RDD[T],
        f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
        preservesPartitioning: Boolean = false)
      extends RDD[U](prev) {
    
      override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
    
      override def getPartitions: Array[Partition] = firstParent[T].partitions
    
      override def compute(split: Partition, context: TaskContext): Iterator[U] =
        f(context, split.index, firstParent[T].iterator(split, context))
    
      override def clearDependencies() {
        super.clearDependencies()
        prev = null
      }
    }
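    Because compute simply applies f to the parent's iterator, a chain of narrow transformations is evaluated lazily, record by record, inside a single Task. The sketch below rebuilds that shape with hypothetical MiniRDD classes; for brevity, f takes (partition index, iterator) and omits the TaskContext argument.

```scala
object PipelineSketch {
  abstract class MiniRDD[T] {
    def compute(split: Int): Iterator[T]
    def iterator(split: Int): Iterator[T] = compute(split) // no caching or checkpointing here

    def mapPartitionsWithIndex[U](f: (Int, Iterator[T]) => Iterator[U]): MiniRDD[U] =
      new MapPartitionsMiniRDD(this, f)
    def map[U](g: T => U): MiniRDD[U] = mapPartitionsWithIndex((_, it) => it.map(g))
    def filter(p: T => Boolean): MiniRDD[T] = mapPartitionsWithIndex((_, it) => it.filter(p))
  }

  // A source RDD whose partitions are already in memory.
  final class SourceMiniRDD[T](partitions: Vector[Seq[T]]) extends MiniRDD[T] {
    def compute(split: Int): Iterator[T] = partitions(split).iterator
  }

  final class MapPartitionsMiniRDD[T, U](prev: MiniRDD[T], f: (Int, Iterator[T]) => Iterator[U])
      extends MiniRDD[U] {
    // Same shape as MapPartitionsRDD.compute: f(index, parent's iterator)
    def compute(split: Int): Iterator[U] = f(split, prev.iterator(split))
  }
}
```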
    

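    Here f is the function passed in when the MapPartitionsRDD was created. After the partition has been computed, the ShuffleWriter obtained from the ShuffleManager writes the results to shuffle files; once writing finishes, the resulting MapStatus is sent back to the Driver, where the DAGScheduler registers it with the MapOutputTracker.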

    2. ResultTask

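    Through the MapOutputTracker maintained by the Driver-side DAGScheduler, a ResultTask learns where the ShuffleMapTasks' outputs are; it then fetches (shuffles) the previous Stage's results and computes the final result from them.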
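    The bookkeeping between the two sides can be pictured as a map from (shuffleId, mapId) to a status holding a location and a size per reduce partition. The in-memory sketch below is hypothetical (ToyMapStatus, ToyMapOutputTracker and blocksForReducer are not Spark's API); it shows map tasks registering their output and a reduce-side task asking which outputs hold data for its partition.

```scala
import scala.collection.mutable

object MapOutputTrackerSketch {
  // Hypothetical status: where a map output lives and how big each reducer's slice is.
  final case class ToyMapStatus(location: String, sizesByReducer: Array[Long])

  final class ToyMapOutputTracker {
    // shuffleId -> (mapId -> status)
    private val outputs = mutable.Map.empty[Int, mutable.Map[Int, ToyMapStatus]]

    // Map side: called once a map task's output has been written.
    def registerMapOutput(shuffleId: Int, mapId: Int, status: ToyMapStatus): Unit = {
      val byMap = outputs.getOrElseUpdate(shuffleId, mutable.Map.empty[Int, ToyMapStatus])
      byMap(mapId) = status
    }

    // Reduce side: (location, mapId, size) for every map output with data for reduceId.
    def blocksForReducer(shuffleId: Int, reduceId: Int): Seq[(String, Int, Long)] =
      outputs.get(shuffleId).toList.flatMap(_.toList).sortBy(_._1).collect {
        case (mapId, status) if status.sizesByReducer(reduceId) > 0 =>
          (status.location, mapId, status.sizesByReducer(reduceId))
      }
  }
}
```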

    The source is as follows:

    override def runTask(context: TaskContext): U = {
        // Deserialize the RDD and the func using the broadcast variables.
        val threadMXBean = ManagementFactory.getThreadMXBean
        val deserializeStartTime = System.currentTimeMillis()
        val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
        } else 0L
        // Create a serializer instance
        val ser = SparkEnv.get.closureSerializer.newInstance()
        // Deserialize the RDD and func; func computes the final result for this partition
        val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
          ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
        _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
        _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
        } else 0L
    
        func(context, rdd.iterator(partition, context))
      }
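    Each ResultTask thus applies the same func to one partition, and the Driver combines the per-partition results. The sketch below runs those "tasks" sequentially in one JVM; runJob, count and collect here are hypothetical helpers, although Spark's own RDD.count() is built the same way, as a sum of per-partition sizes returned by SparkContext.runJob.

```scala
object ResultStageSketch {
  // One "ResultTask" per partition: apply func to that partition's iterator.
  def runJob[T, U](partitions: Vector[Seq[T]], func: Iterator[T] => U): Vector[U] =
    partitions.map(part => func(part.iterator))

  // Driver side: combine the per-partition results.
  def count[T](partitions: Vector[Seq[T]]): Long =
    runJob[T, Long](partitions, it => it.size.toLong).sum

  def collect[T](partitions: Vector[Seq[T]]): Vector[T] =
    runJob[T, Vector[T]](partitions, it => it.toVector).flatten
}
```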
    
  • Original article: https://www.cnblogs.com/jordan95225/p/13458827.html