zoukankan      html  css  js  c++  java
  • Spark 1.1.0 源码阅读 - Executor

    1. Executor 上执行 launchTask

    /**
     * Wraps a serialized task in a [[TaskRunner]], records it in `runningTasks` under its task ID,
     * and submits it to the executor's thread pool.
     *
     * @param context        backend the runner reports status updates to
     * @param taskId         identifier the runner is registered under
     * @param taskName       human-readable name used in log messages
     * @param serializedTask task bytes, deserialized later on the pool thread
     */
    def launchTask(
        context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer): Unit = {
      val runner = new TaskRunner(context, taskId, taskName, serializedTask)
      // Registered before submission, so the runner is already in runningTasks when the pool picks it up.
      runningTasks.put(taskId, runner)
      threadPool.execute(runner)
    }

    2. Executor 上执行 TaskRunner 的 run 方法

     1  class TaskRunner(
     2       execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
     3     extends Runnable {
     4 
     5     @volatile private var killed = false
     6     @volatile var task: Task[Any] = _
     7     @volatile var attemptedTask: Option[Task[Any]] = None
     8 
     /**
      * Requests cancellation of this runner: logs the request, sets the `killed` flag, and, if the
      * task object has already been deserialized, forwards the request to it.
      *
      * @param interruptThread passed through unchanged to the task's own `kill`
      */
     def kill(interruptThread: Boolean): Unit = {
       logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
       killed = true
       // `task` is still null until run() deserializes it; run() checks `killed` right after that step.
       Option(task).foreach(_.kill(interruptThread))
     }
    16 
    17     override def run() {
    18       val startTime = System.currentTimeMillis()
    19       SparkEnv.set(env)
    20       Thread.currentThread.setContextClassLoader(replClassLoader)
    21       val ser = SparkEnv.get.closureSerializer.newInstance()
    22       logInfo(s"Running $taskName (TID $taskId)")
    23       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    24       var taskStart: Long = 0
    25       def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
    26       val startGCTime = gcTime
    27 
    28       try {
    29         SparkEnv.set(env)
    30         Accumulators.clear()
    31         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)  //反序列化出 taskFiles,taskJars,taskBytes
    32         updateDependencies(taskFiles, taskJars)
    33         task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)  //反序列化出task对象
    34 
    35         // If this task has been killed before we deserialized it, let's quit now. Otherwise,
    36         // continue executing the task.
    37         if (killed) {
    38           // Throw an exception rather than returning, because returning within a try{} block
    39           // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
    40           // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
    41           // for the task.
    42           throw new TaskKilledException
    43         }
    44 
    45         attemptedTask = Some(task)
    46         logDebug("Task " + taskId + "'s epoch is " + task.epoch)
    47         env.mapOutputTracker.updateEpoch(task.epoch)
    48 
    49         // Run the actual task and measure its runtime.
    50         taskStart = System.currentTimeMillis()
    51         val value = task.run(taskId.toInt)
    52         val taskFinish = System.currentTimeMillis()
    53 
    54         // If the task has been killed, let's fail it.
    55         if (task.killed) {
    56           throw new TaskKilledException
    57         }

    3. Task.run

     1 private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
     2 
     /**
      * Executes this task attempt. It installs a fresh TaskContext, records the local host name in
      * its metrics, and remembers the executing thread. If a kill was already flagged, it re-issues
      * that kill, then delegates to the subclass's `runTask`.
      *
      * @param attemptId attempt identifier stored in the new TaskContext
      * @return the value produced by `runTask`
      */
     final def run(attemptId: Long): T = {
       val ctx = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
       context = ctx
       ctx.taskMetrics.hostname = Utils.localHostName()
       taskThread = Thread.currentThread()
       // Re-issue an earlier kill request now that context and taskThread have been assigned,
       // without asking for a thread interrupt.
       if (_killed) kill(interruptThread = false)
       runTask(ctx)
     }

    4. Task 是抽象类，其具体子类（ResultTask 和 ShuffleMapTask）会执行各自实现的 runTask 方法。

    a. ResultTask

     /**
      * Deserializes the (RDD, function) pair carried by the `taskBinary` broadcast variable and
      * applies the function to this task's partition iterator. The context is marked completed
      * whether or not the function succeeds.
      *
      * @param context context of the running attempt, passed to the function and the iterator
      * @return the function's result for this partition
      */
     override def runTask(context: TaskContext): U = {
       val closureSer = SparkEnv.get.closureSerializer.newInstance()
       val payload = ByteBuffer.wrap(taskBinary.value)
       val loader = Thread.currentThread.getContextClassLoader
       val (rdd, func) = closureSer.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](payload, loader)

       metrics = Some(context.taskMetrics)
       try func(context, rdd.iterator(partition, context))
       finally context.markTaskCompleted()
     }

    b. ShuffleMapTask（下面依次为 ShuffleMapTask.runTask、HashShuffleWriter.write 和 ShuffleBlockManager.forMapTask）

     /**
      * Deserializes the (RDD, ShuffleDependency) pair carried by the `taskBinary` broadcast variable.
      * It then writes this partition's records through the writer obtained from the shuffle manager
      * and returns the MapStatus produced by a successful stop.
      *
      * If anything fails after the writer exists, the writer is stopped with `success = false` before
      * the original exception is rethrown. An exception raised by that cleanup is attached to the
      * original as a suppressed exception instead of replacing it. The task context is marked
      * completed in every case.
      *
      * @param context context of the running attempt
      * @return the map output status reported by `writer.stop(success = true)`
      */
     override def runTask(context: TaskContext): MapStatus = {
       // Deserialize the RDD using the broadcast variable.
       val ser = SparkEnv.get.closureSerializer.newInstance()
       val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
         ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

       metrics = Some(context.taskMetrics)
       var writer: ShuffleWriter[Any, Any] = null
       try {
         val manager = SparkEnv.get.shuffleManager
         writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
         writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
         // The try is an expression, so this is the method's result; no `return` needed.
         writer.stop(success = true).get
       } catch {
         case e: Exception =>
           if (writer != null) {
             // Best-effort abort: a failure here must not mask `e`, which is the real cause.
             try {
               writer.stop(success = false)
             } catch {
               case cleanupError: Exception =>
                 // addSuppressed rejects self-suppression, so guard against the same instance.
                 if (cleanupError ne e) e.addSuppressed(cleanupError)
             }
           }
           throw e
       } finally {
         context.markTaskCompleted()
       }
     }
     /**
      * Writes a bunch of records to this task's output.
      *
      * When the dependency requests map-side combine, records are first combined with its
      * aggregator; requesting map-side combine without an aggregator is an error. Every resulting
      * record goes to the writer whose index is the partitioner's bucket for the record's key.
      *
      * @param records key/value pairs produced by this map task
      * @throws IllegalStateException if map-side combine is requested but no aggregator is defined
      */
     override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
       val iter = (dep.aggregator, dep.mapSideCombine) match {
         case (Some(aggregator), true) =>
           aggregator.combineValuesByKey(records, context)
         case (None, true) =>
           throw new IllegalStateException("Aggregator is empty for map-side combine")
         case _ =>
           // No map-side combine requested: records pass through unchanged.
           records
       }

       iter.foreach { elem =>
         val bucketId = dep.partitioner.getPartition(elem._1)
         shuffle.writers(bucketId).write(elem)
       }
     }
     1   /**
     2    * Get a ShuffleWriterGroup for the given map task, which will register it as complete
     3    * when the writers are closed successfully
     4    */
     5   def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
     6       writeMetrics: ShuffleWriteMetrics) = {
     7     new ShuffleWriterGroup {
     8       shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
     9       private val shuffleState = shuffleStates(shuffleId)
    10       private var fileGroup: ShuffleFileGroup = null
    11 
    12       val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
    13         fileGroup = getUnusedFileGroup()
    14         Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
    15           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
    16           blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
    17             writeMetrics)
    18         }
    19       } else {
    20         Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
    21           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
    22           val blockFile = blockManager.diskBlockManager.getFile(blockId)
    23           // Because of previous failures, the shuffle file may already exist on this machine.
    24           // If so, remove it.
    25           if (blockFile.exists) {
    26             if (blockFile.delete()) {
    27               logInfo(s"Removed existing shuffle file $blockFile")
    28             } else {
    29               logWarning(s"Failed to remove existing shuffle file $blockFile")
    30             }
    31           }
    32           blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
    33         }
    34       }
  • 相关阅读:
    startx
    BIOS above 4g setting
    How to download file from google drive
    凤凰财经
    SOTA state-of-the-art model
    update gpu drivers
    /opt/nvidia/deepstream/deepstream/sources/libs/nvdsinfer_customparser/nvdsinfer_custombboxparser.cpp
    retinanet keras
    fp16 fp32 int8
    tf.gather with tf.Session() as sess:
  • 原文地址:https://www.cnblogs.com/Torstan/p/4158671.html
Copyright © 2011-2022 走看看