zoukankan      html  css  js  c++  java
  • Spark源码分析之八:Task运行(二)

    《Spark源码分析之七:Task运行(一)》一文中,我们详细叙述了Task运行的整体流程,最终Task被传输到Executor上,启动一个对应的TaskRunner线程,并且在线程池中被调度执行。继而,我们对TaskRunner的run()方法进行了详细的分析,总结出了其内Task执行的三个主要步骤:

            Step1:Task及其运行时需要的辅助对象构造,主要包括:

                           1、当前线程设置上下文类加载器;

                           2、获取序列化器ser;

                           3、更新任务状态TaskState;

                           4、计算垃圾回收时间;

                           5、反序列化得到Task运行的jar、文件、Task对象二进制数据;

                           6、反序列化Task对象二进制数据得到Task对象;

                           7、设置任务内存管理器;

            Step2:Task运行:调用Task的run()方法,真正执行Task,并获得运行结果value
            Step3:Task运行结果处理:

                           1、序列化Task运行结果value,得到valueBytes;

                           2、根据Task运行结果大小处理Task运行结果valueBytes:

                                2.1、如果Task运行结果大小大于所有Task运行结果的最大大小,序列化IndirectTaskResult,IndirectTaskResult为存储在Worker上BlockManager中DirectTaskResult的一个引用;

                                2.2、如果 Task运行结果大小超过Akka除去需要保留的字节外最大大小,则将结果写入BlockManager,Task运行结果比较小的话,直接返回,通过消息传递;

                                2.3、Task运行结果比较小的话,直接返回,通过消息传递。

            大体流程大概就是如此。我们先回顾到这里。那么,接下来的问题是,任务内存管理器是什么?如何计算开始垃圾回收时间?Task的run()方法的执行流程是什么?IndirectTaskResult,或者BlockManager又是如何传递任务运行结果至应用程序即客户端的?

            不要着急,我们一个一个来解决。

            关于任务内存管理器TaskMemoryManager,可以参照《Spark源码分析之九:内存管理模型》一文,只要知道它是任务运行期间各区域内存的管理者就行,这里不再赘述。

            接下来,我们重点分析下Task的run()方法,看看Task实际运行时的处理逻辑。其代码如下:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. /** 
    2.    * Called by [[Executor]] to run this task. 
    3.    * 被Executor调用以执行Task 
    4.    * 
    5.    * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. 
    6.    * @param attemptNumber how many times this task has been attempted (0 for the first attempt) 
    7.    * @return the result of the task along with updates of Accumulators. 
    8.    */  
    9.   final def run(  
    10.     taskAttemptId: Long,  
    11.     attemptNumber: Int,  
    12.     metricsSystem: MetricsSystem)  
    13.   : (T, AccumulatorUpdates) = {  
    14.     
    15.     // 创建一个Task上下文实例:TaskContextImpl类型的context  
    16.     context = new TaskContextImpl(  
    17.       stageId,  
    18.       partitionId,  
    19.       taskAttemptId,  
    20.       attemptNumber,  
    21.       taskMemoryManager,  
    22.       metricsSystem,  
    23.       internalAccumulators,  
    24.       runningLocally = false)  
    25.         
    26.     // 将context放入TaskContext的taskContext变量中  
    27.     // taskContext变量为ThreadLocal[TaskContext]  
    28.     TaskContext.setTaskContext(context)  
    29.       
    30.     // 设置主机名localHostName、内部累加器internalAccumulators等Metrics信息  
    31.     context.taskMetrics.setHostname(Utils.localHostName())  
    32.     context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)  
    33.       
    34.     // task线程为当前线程  
    35.     taskThread = Thread.currentThread()  
    36.       
    37.     if (_killed) {// 如果需要杀死task,调用kill()方法,且调用的方式为不中断线程  
    38.       kill(interruptThread = false)  
    39.     }  
    40.       
    41.     try {  
    42.       // 调用runTask()方法,传入Task上下文信息context,执行Task,并调用Task上下文的collectAccumulators()方法,收集累加器  
    43.       (runTask(context), context.collectAccumulators())  
    44.     } finally {  
    45.       // 上下文标记Task完成  
    46.       context.markTaskCompleted()  
    47.         
    48.       try {  
    49.         Utils.tryLogNonFatalError {  
    50.           // Release memory used by this thread for unrolling blocks  
    51.           // 为unrolling块释放当前线程使用的内存  
    52.           SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()  
    53.           // Notify any tasks waiting for execution memory to be freed to wake up and try to  
    54.           // acquire memory again. This makes impossible the scenario where a task sleeps forever  
    55.           // because there are no other tasks left to notify it. Since this is safe to do but may  
    56.           // not be strictly necessary, we should revisit whether we can remove this in the future.  
    57.           val memoryManager = SparkEnv.get.memoryManager  
    58.           memoryManager.synchronized { memoryManager.notifyAll() }  
    59.         }  
    60.       } finally {  
    61.         // 释放TaskContext  
    62.         TaskContext.unset()  
    63.       }  
    64.     }  
    65.   }  

            代码逻辑非常简单,概述如下:

            1、需要创建一个Task上下文实例,即TaskContextImpl类型的context,这个TaskContextImpl主要包括以下内容:Task所属Stage的stageId、Task对应数据分区的partitionId、Task执行的taskAttemptId、Task执行的序号attemptNumber、Task内存管理器taskMemoryManager、指标度量系统metricsSystem、内部累加器internalAccumulators、是否本地运行的标志位runningLocally(为false);

            2、将context放入TaskContext的taskContext变量中,这个taskContext变量为ThreadLocal[TaskContext];

            3、在任务上下文context中设置主机名localHostName、内部累加器internalAccumulators等Metrics信息;

            4、设置task线程为当前线程;

            5、如果需要杀死task,调用kill()方法,且调用的方式为不中断线程;

            6、调用runTask()方法,传入Task上下文信息context,执行Task,并调用Task上下文的collectAccumulators()方法,收集累加器;

            7、最后,任务上下文context标记Task完成,为unrolling块释放当前线程使用的内存,清楚任务上下文等。

            接下来自然要看下runTask()方法。但是Task中,runTask()方法却没有实现。我们知道,Task共分为两种类型,一个是最后一个Stage产生的ResultTask,另外一个是其parent Stage产生的ShuffleMapTask。那么,我们分开来分析下,首先看下ShuffleMapTask中的runTask()方法,定义如下:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. override def runTask(context: TaskContext): MapStatus = {  
    2.     // Deserialize the RDD using the broadcast variable.  
    3.     // 使用广播变量反序列化RDD  
    4.       
    5.     // 反序列化的起始时间  
    6.     val deserializeStartTime = System.currentTimeMillis()  
    7.       
    8.     // 获得反序列化器closureSerializer  
    9.     val ser = SparkEnv.get.closureSerializer.newInstance()  
    10.       
    11.     // 调用反序列化器closureSerializer的deserialize()进行RDD和ShuffleDependency的反序列化,数据来源于taskBinary  
    12.     val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](  
    13.       ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)  
    14.       
    15.     // 计算Executor进行反序列化的时间  
    16.     _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime  
    17.   
    18.     metrics = Some(context.taskMetrics)  
    19.     var writer: ShuffleWriter[Any, Any] = null  
    20.     try {  
    21.       // 获得shuffleManager  
    22.       val manager = SparkEnv.get.shuffleManager  
    23.         
    24.       // 通过shuffleManager的getWriter()方法,获得shuffle的writer  
    25.       // 启动的partitionId表示的是当前RDD的某个partition,也就是说write操作作用于partition之上  
    26.       writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)  
    27.         
    28.       // 针对RDD中的分区<span style="font-family: Arial, Helvetica, sans-serif;">partition</span><span style="font-family: Arial, Helvetica, sans-serif;">,调用rdd的iterator()方法后,再调用writer的write()方法,写数据</span>  
    29.       writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])  
    30.         
    31.       // 停止writer,并返回标志位  
    32.       writer.stop(success = true).get  
    33.     } catch {  
    34.       case e: Exception =>  
    35.         try {  
    36.           if (writer != null) {  
    37.             writer.stop(success = false)  
    38.           }  
    39.         } catch {  
    40.           case e: Exception =>  
    41.             log.debug("Could not stop writer", e)  
    42.         }  
    43.         throw e  
    44.     }  
    45.   }  

            运行的主要逻辑其实只有两步,如下:

            1、通过使用广播变量反序列化得到RDD和ShuffleDependency:

                  1.1、获得反序列化的起始时间deserializeStartTime;

                  1.2、通过SparkEnv获得反序列化器ser;

                  1.3、调用反序列化器ser的deserialize()进行RDD和ShuffleDependency的反序列化,数据来源于taskBinary,得到rdd、dep;

                  1.4、计算Executor进行反序列化的时间_executorDeserializeTime;

             2、利用shuffleManager的writer进行数据的写入:

                   2.1、通过SparkEnv获得shuffleManager;

                   2.2、通过shuffleManager的getWriter()方法,获得shuffle的writer,其中的partitionId表示的是当前RDD的某个partition,也就是说write操作作用于partition之上;

                   2.3、针对RDD中的分区partition,调用rdd的iterator()方法后,再调用writer的write()方法,写数据;

                   2.4、停止writer,并返回标志位。

              至于shuffle的详细内容,我会在后续的博文中深入分析。

              下面,再看下ResultTask,其runTask()方法更简单,代码如下:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. override def runTask(context: TaskContext): U = {  
    2.     // Deserialize the RDD and the func using the broadcast variables.  
    3.       
    4.     // 获取反序列化的起始时间  
    5.     val deserializeStartTime = System.currentTimeMillis()  
    6.       
    7.     // 获取反序列化器  
    8.     val ser = SparkEnv.get.closureSerializer.newInstance()  
    9.       
    10.     // 调用反序列化器ser的deserialize()方法,得到RDD和FUNC,数据来自taskBinary  
    11.     val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](  
    12.       ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)  
    13.       
    14.     // 计算反序列化时间_executorDeserializeTime  
    15.     _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime  
    16.   
    17.   
    18.     metrics = Some(context.taskMetrics)  
    19.       
    20.     // 调针对RDD中的每个分区,迭代执行func方法,执行Task  
    21.     func(context, rdd.iterator(partition, context))  
    22.   }  

            首先,获取反序列化的起始时间deserializeStartTime;

            其次,通过SparkEnv获取反序列化器ser;

            然后,调用反序列化器ser的deserialize()方法,得到RDD和FUNC,数据来自taskBinary;

            紧接着,计算反序列化时间_executorDeserializeTime;

            最后,调针对RDD中的每个分区,迭代执行func方法,执行Task。

            到了这里,读者可能会有一个很大的疑问,Task的运行就这样完了?ReusltTask还好说,它会执行反序列化后得到的func函数,那么ShuffleMapTask呢?仅仅是shuffle的数据写入吗?它的分区数据需要执行什么函数来继续转换呢?现在,我就来为大家解答下这个问题。

            首先,在ShuffleMapTask的runTask()方法中,反序列化得到rdd后,在执行writer的write()方法之前,会调用rdd的iterator()函数,对rdd的分区partition进行处理。那么我们看下RDD中的iterator()函数是如何定义的?

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. /** 
    2.    * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. 
    3.    * This should ''not'' be called by users directly, but is available for implementors of custom 
    4.    * subclasses of RDD. 
    5.    */  
    6.   final def iterator(split: Partition, context: TaskContext): Iterator[T] = {  
    7.     if (storageLevel != StorageLevel.NONE) {  
    8.       SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)  
    9.     } else {  
    10.       computeOrReadCheckpoint(split, context)  
    11.     }  
    12.   }  

            很简单,它会根据存储级别,来决定:

            1、如果存储级别storageLevel不为空,调用SparkEnv中的cacheManager的getOrCompute()方法;

            2、如果存储级别storageLevel为空,则调用computeOrReadCheckpoint()方法;
            我们先看下SparkEnv中cacheManager的定义:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. val cacheManager = new CacheManager(blockManager)  

            它是一个CacheManager类型的对象。而CacheManager中getOrCompute()方法的定义如下:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. /** Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. */  
    2.   // 获取或计算一个RDD的分区  
    3.   def getOrCompute[T](  
    4.       rdd: RDD[T],  
    5.       partition: Partition,  
    6.       context: TaskContext,  
    7.       storageLevel: StorageLevel): Iterator[T] = {  
    8.       
    9.     // 通过rdd的id和分区的索引号,获取RDDBlockId类型的key  
    10.     val key = RDDBlockId(rdd.id, partition.index)  
    11.     logDebug(s"Looking for partition $key")  
    12.       
    13.     // 在blockManager中根据key查找  
    14.     blockManager.get(key) match {  
    15.         
    16.       // 如果为blockResult,意味着分区Partition已经被物化,直接获取结果即可  
    17.       case Some(blockResult) =>  
    18.         // Partition is already materialized, so just return its values  
    19.         val existingMetrics = context.taskMetrics  
    20.           .getInputMetricsForReadMethod(blockResult.readMethod)  
    21.         existingMetrics.incBytesRead(blockResult.bytes)  
    22.   
    23.         val iter = blockResult.data.asInstanceOf[Iterator[T]]  
    24.         new InterruptibleIterator[T](context, iter) {  
    25.           override def next(): T = {  
    26.             existingMetrics.incRecordsRead(1)  
    27.             delegate.next()  
    28.           }  
    29.         }  
    30.           
    31.       // 如果没有,则需要计算  
    32.       case None =>  
    33.           
    34.         // Acquire a lock for loading this partition  
    35.         // If another thread already holds the lock, wait for it to finish return its results  
    36.           
    37.         // 首先需要为load该分区申请锁,如果其它线程已经获取对应的锁,那么该线程则会一直等待其他线程处理完毕后的返回结果,然后直接返回这个结果即可  
    38.         val storedValues = acquireLockForPartition[T](key)  
    39.         if (storedValues.isDefined) {// 如果storedValues被定义的话,直接返回结果  
    40.           return new InterruptibleIterator[T](context, storedValues.get)  
    41.         }  
    42.   
    43.         // Otherwise, we have to load the partition ourselves  
    44.         // 当获得了锁后,我们不得不自己load分区  
    45.         try {  
    46.           logInfo(s"Partition $key not found, computing it")  
    47.           // 调用RDD的computeOrReadCheckpoint()方法进行计算  
    48.           val computedValues = rdd.computeOrReadCheckpoint(partition, context)  
    49.   
    50.           // If the task is running locally, do not persist the result  
    51.           // 如果task是本地运行,不需要持久化数据,直接返回  
    52.           if (context.isRunningLocally) {  
    53.             return computedValues  
    54.           }  
    55.   
    56.           // Otherwise, cache the values and keep track of any updates in block statuses  
    57.           // 否则,需要缓存结果,并对block状态的更新保持追踪  
    58.           val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]  
    59.           val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks)  
    60.           val metrics = context.taskMetrics  
    61.           val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())  
    62.           metrics.updatedBlocks = Some(lastUpdatedBlocks ++ updatedBlocks.toSeq)  
    63.           new InterruptibleIterator(context, cachedValues)  
    64.   
    65.         } finally {  
    66.           loading.synchronized {  
    67.             loading.remove(key)  
    68.             loading.notifyAll()  
    69.           }  
    70.         }  
    71.     }  
    72.   }  

            getOrCompute()方法的大体逻辑如下:

            1、通过rdd的id和分区的索引号,获取RDDBlockId类型的key;

            2、在blockManager中根据key查找:

                  2.1、如果为blockResult,意味着分区Partition已经被物化,直接获取结果即可;

                  2.2、如果没有,则需要计算:

                           2.2.1、首先需要为load该分区申请锁,如果其它线程已经获取对应的锁,那么该线程则会一直等待其他线程处理完毕后的返回结果,然后直接返回这个结果即可;

                           2.2.2、当获得了锁后,我们不得不自己load分区:

                                        2.2.2.1、调用RDD的computeOrReadCheckpoint()方法进行计算,得到computedValues;

                                        2.2.2.2、如果task是本地运行,不需要持久化数据,直接返回;

                                        2.2.2.3、否则,需要缓存结果,并对block状态的更新保持追踪。

            然后,问题又统一性的扔给了RDD的computeOrReadCheckpoint()方法,我们来看下它的实现:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. /** 
    2.    * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. 
    3.    * 计算一个RDD分区,或者如果该RDD正在做checkpoint,直接读取 
    4.    */  
    5.   private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =  
    6.   {  
    7.     if (isCheckpointedAndMaterialized) {  
    8.       firstParent[T].iterator(split, context)  
    9.     } else {  
    10.       compute(split, context)  
    11.     }  
    12.   }  

            哦,它原来是调用RDD的compute()方法(其实,通过读了那么多Spark介绍的文章,我早就知道了,这里故作深沉,想真正探寻下它是如何调用到compute()方法的)。

            接下来,我们再深入分析下两种Task的执行流程中涉及到的公共部分:反序列化器。它是通过SparkEnv的closureSerializer来获取的,而在SparkEnv中,是如何定义closureSerializer的呢?代码如下:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. val closureSerializer = instantiateClassFromConf[Serializer](  
    2.       "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")  

            也就是说,它实际上取得是参数spark.closure.serializer配置的类,默认是org.apache.spark.serializer.JavaSerializer类。而接下来的instantiateClassFromConf()方法很简单,就是从配置中实例化class得到对象,其定义如下:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. // Create an instance of the class named by the given SparkConf property, or defaultClassName  
    2.     // if the property is not set, possibly initializing it with our conf  
    3.     def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = {  
    4.       instantiateClass[T](conf.get(propertyName, defaultClassName))  
    5.     }  

            继续看instantiateClass()方法,它是根据指定name来创建一个类的实例,代码如下:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. // Create an instance of the class with the given name, possibly initializing it with our conf  
    2.     def instantiateClass[T](className: String): T = {  
    3.       val cls = Utils.classForName(className)  
    4.       // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just  
    5.       // SparkConf, then one taking no arguments  
    6.       try {  
    7.         cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)  
    8.           .newInstance(conf, new java.lang.Boolean(isDriver))  
    9.           .asInstanceOf[T]  
    10.       } catch {  
    11.         case _: NoSuchMethodException =>  
    12.           try {  
    13.             cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]  
    14.           } catch {  
    15.             case _: NoSuchMethodException =>  
    16.               cls.getConstructor().newInstance().asInstanceOf[T]  
    17.           }  
    18.       }  
    19.     }  

           同过类名来获得类,并调用其构造方法进行对象的构造。我们看下序列化器的默认实现org.apache.spark.serializer.JavaSerializer的deserialize()方法,代码如下:

    [java] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {  
    2.     val bis = new ByteBufferInputStream(bytes)  
    3.     val in = deserializeStream(bis, loader)  
    4.     in.readObject()  
    5.   }  

            首先,通过ByteBuffer类型的bytes构造ByteBufferInputStream类型的bis;

            其次,调用deserializeStream()方法,获得反序列化输入流in;

            最后,通过反序列化输入流in的readObject()方法获得对象。

            经历了上述过程,RDD、ShuffleDependency或者RDD、FUNC就不难获取到了。

            先发表出来,余下的一些细节,或者没有讲到的部分,未完待续吧!

    博客原地址:http://blog.csdn.net/lipeng_bigdata/article/details/50752101

  • 相关阅读:
    有关 JavaScript 的 10 件让人费解的事情
    Apache ab介绍1
    Oracle Raw,number,varchar2... 转换
    Flex开发者需要知道的10件事
    linux命令之nice
    JavaIO复习和目录文件的复制
    使用php获取网页内容
    linux 安装sysstat使用iostat、mpstat、sar、sa
    SQL Injection 实战某基金
    ubuntu root锁屏工具
  • 原文地址:https://www.cnblogs.com/jirimutu01/p/5274463.html
Copyright © 2011-2022 走看看