  • [Original] Big Data Fundamentals of Spark (8): How Join Is Implemented in Spark

    There are two kinds of join in Spark: the join on RDDs and the join in Spark SQL. Let's look at each in turn:

    1 RDD join

    org.apache.spark.rdd.PairRDDFunctions

      /**
       * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
       * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
       * (k, v2) is in `other`. Performs a hash join across the cluster.
       */
      def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = self.withScope {
        join(other, defaultPartitioner(self, other))
      }
    
      /**
       * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
       * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
       * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
       */
      def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = self.withScope {
        this.cogroup(other, partitioner).flatMapValues( pair =>
          for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, w)
        )
      }
    
      /**
       * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
       * list of values for that key in `this` as well as `other`.
       */
      def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner)
          : RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope {
        if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
          throw new SparkException("HashPartitioner cannot partition array keys.")
        }
        val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
        cg.mapValues { case Array(vs, w1s) =>
          (vs.asInstanceOf[Iterable[V]], w1s.asInstanceOf[Iterable[W]])
        }
      }

    The join operation is built on a CoGroupedRDD (returned by cogroup), whose constructor takes a Seq of the RDDs to be joined.
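    As a quick illustration of the join API above (a minimal sketch; it assumes an existing SparkContext named sc):

      val orders = sc.parallelize(Seq((1, "order-a"), (2, "order-b"), (1, "order-c")))
      val users  = sc.parallelize(Seq((1, "alice"), (2, "bob")))

      // join goes through cogroup + flatMapValues, exactly as shown in PairRDDFunctions above
      val joined: org.apache.spark.rdd.RDD[(Int, (String, String))] = orders.join(users)
      joined.collect().foreach(println)
      // e.g. (1,(order-a,alice)), (1,(order-c,alice)), (2,(order-b,bob))

    Now look at CoGroupedRDD: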

    org.apache.spark.rdd.CoGroupedRDD

    class CoGroupedRDD[K: ClassTag](
        @transient var rdds: Seq[RDD[_ <: Product2[K, _]]],
        part: Partitioner)
      extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) {
    
      override def getDependencies: Seq[Dependency[_]] = {
        rdds.map { rdd: RDD[_] =>
          if (rdd.partitioner == Some(part)) {
            logDebug("Adding one-to-one dependency with " + rdd)
            new OneToOneDependency(rdd)
          } else {
            logDebug("Adding shuffle dependency with " + rdd)
            new ShuffleDependency[K, Any, CoGroupCombiner](
              rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer)
          }
        }
      }
    
      override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = {
        val split = s.asInstanceOf[CoGroupPartition]
        val numRdds = dependencies.length
    
        // A list of (rdd iterator, dependency number) pairs
        val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
        for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
          case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked =>
            val dependencyPartition = split.narrowDeps(depNum).get.split
            // Read them from the parent
            val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
            rddIterators += ((it, depNum))
    
          case shuffleDependency: ShuffleDependency[_, _, _] =>
            // Read map outputs of shuffle
            val it = SparkEnv.get.shuffleManager
              .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
              .read()
            rddIterators += ((it, depNum))
        }
    
        val map = createExternalMap(numRdds)
        for ((it, depNum) <- rddIterators) {
          map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
        }
        context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes)
        new InterruptibleIterator(context,
          map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
      }
    
      private def createExternalMap(numRdds: Int)
        : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = {
    
        val createCombiner: (CoGroupValue => CoGroupCombiner) = value => {
          val newCombiner = Array.fill(numRdds)(new CoGroup)
          newCombiner(value._2) += value._1
          newCombiner
        }
        val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner =
          (combiner, value) => {
          combiner(value._2) += value._1
          combiner
        }
        val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner =
          (combiner1, combiner2) => {
            var depNum = 0
            while (depNum < numRdds) {
              combiner1(depNum) ++= combiner2(depNum)
              depNum += 1
            }
            combiner1
          }
        new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner](
          createCombiner, mergeValue, mergeCombiners)
      }

    CoGroupedRDD first converts each parent RDD into a Dependency, then turns each dependency into an iterator (rddIterators), and finally merges them all through an ExternalAppendOnlyMap;
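    To see the intermediate shape that join flattens, here is a cogroup sketch (plain usage code, again assuming a SparkContext named sc):

      val left  = sc.parallelize(Seq(("k1", 1), ("k1", 2), ("k2", 3)))
      val right = sc.parallelize(Seq(("k1", "x"), ("k3", "y")))

      // cogroup keeps one Iterable per input RDD for every key
      val grouped: org.apache.spark.rdd.RDD[(String, (Iterable[Int], Iterable[String]))] =
        left.cogroup(right)

      // join is then just the per-key cartesian product of the two Iterables,
      // so keys with an empty side ("k2", "k3") produce no output rows
      val joined = grouped.flatMapValues { case (vs, ws) => for (v <- vs; w <- ws) yield (v, w) }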

    When one of the RDDs needs to be shuffled, the shuffle goes through the ShuffleManager, whose implementation is SortShuffleManager; the shuffle process is covered in detail at https://www.cnblogs.com/barneywill/p/10158457.html

    Appendix: the Dependency hierarchy in Spark, i.e. the commonly mentioned wide and narrow dependencies:

    org.apache.spark.Dependency

    Dependency

             NarrowDependency

                      OneToOneDependency

                      RangeDependency

             ShuffleDependency

    The difference comes down to shuffling: a dependency that requires no shuffle is a NarrowDependency, and one that does is a ShuffleDependency;
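    This ties back to getDependencies above: when a parent RDD is already partitioned with the partitioner the CoGroupedRDD uses, the dependency becomes a OneToOneDependency and no shuffle is needed. A small sketch (assuming the same sc as before):

      import org.apache.spark.HashPartitioner

      val p = new HashPartitioner(4)
      val a = sc.parallelize(Seq((1, "a"), (2, "b"))).partitionBy(p)
      val b = sc.parallelize(Seq((1, "x"), (2, "y"))).partitionBy(p)

      // defaultPartitioner(a, b) reuses p here, so rdd.partitioner == Some(part) holds for
      // both parents and getDependencies picks OneToOneDependency (narrow, no shuffle)
      val joined = a.join(b)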

    2 SQL join

    Joins in Spark SQL go through a selection strategy:

    org.apache.spark.sql.execution.SparkStrategies.JoinSelection

        def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    
          // --- BroadcastHashJoin --------------------------------------------------------------------
    
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
            if canBuildRight(joinType) && canBroadcast(right) =>
            Seq(joins.BroadcastHashJoinExec(
              leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
    
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
            if canBuildLeft(joinType) && canBroadcast(left) =>
            Seq(joins.BroadcastHashJoinExec(
              leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
    
          // --- ShuffledHashJoin ---------------------------------------------------------------------
    
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
             if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
               && muchSmaller(right, left) ||
               !RowOrdering.isOrderable(leftKeys) =>
            Seq(joins.ShuffledHashJoinExec(
              leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
    
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
             if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
               && muchSmaller(left, right) ||
               !RowOrdering.isOrderable(leftKeys) =>
            Seq(joins.ShuffledHashJoinExec(
              leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
    
          // --- SortMergeJoin ------------------------------------------------------------
    
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
            if RowOrdering.isOrderable(leftKeys) =>
            joins.SortMergeJoinExec(
              leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
    ...

    where conf.preferSortMergeJoin is defined as:

    org.apache.spark.sql.internal.SQLConf

      val PREFER_SORTMERGEJOIN = SQLConfigBuilder("spark.sql.join.preferSortMergeJoin")
        .internal()
        .doc("When true, prefer sort merge join over shuffle hash join.")
        .booleanConf
        .createWithDefault(true)

    The configuration spark.sql.join.preferSortMergeJoin defaults to true and controls whether sort merge join is preferred over shuffled hash join;

    As the strategy shows, there are three main join implementations, BroadcastHashJoinExec, ShuffledHashJoinExec, and SortMergeJoinExec, chosen with the following priority (a usage sketch follows the list):

    • 1 If one side canBroadcast, use BroadcastHashJoinExec;
    • 2 Else, if spark.sql.join.preferSortMergeJoin=false (and the build side is small enough for a local hash map), use ShuffledHashJoinExec;
    • 3 Otherwise, use SortMergeJoinExec;
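    A hedged sketch of how these choices can be steered from user code (it assumes an existing SparkSession named spark and two DataFrames named big and small, which are hypothetical):

      import org.apache.spark.sql.functions.broadcast

      // the broadcast hint nudges the planner toward BroadcastHashJoinExec
      val hinted = big.join(broadcast(small), Seq("id"))

      // canBroadcast is driven by the estimated-size threshold; "-1" disables auto-broadcast
      spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

      // with broadcast ruled out, setting this to false lets ShuffledHashJoinExec be considered
      // (subject to canBuildLocalHashMap / muchSmaller above); otherwise SortMergeJoinExec is used
      spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")

      // inspect which *Exec node appears in the physical plan
      big.join(small, Seq("id")).explain()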

    Both BroadcastHashJoinExec and ShuffledHashJoinExec delegate to HashJoin, so let's look at HashJoin first:

    2.1 HashJoin

    org.apache.spark.sql.execution.joins.HashJoin

      protected def join(
          streamedIter: Iterator[InternalRow],
          hashed: HashedRelation,
          numOutputRows: SQLMetric): Iterator[InternalRow] = {
    
        val joinedIter = joinType match {
          case _: InnerLike =>
            innerJoin(streamedIter, hashed)
          case LeftOuter | RightOuter =>
            outerJoin(streamedIter, hashed)
          case LeftSemi =>
            semiJoin(streamedIter, hashed)
          case LeftAnti =>
            antiJoin(streamedIter, hashed)
          case j: ExistenceJoin =>
            existenceJoin(streamedIter, hashed)
          case x =>
            throw new IllegalArgumentException(
              s"BroadcastHashJoin should not take $x as the JoinType")
        }
    
        val resultProj = createResultProjection
        joinedIter.map { r =>
          numOutputRows += 1
          resultProj(r)
        }
      }
    
      private def innerJoin(
          streamIter: Iterator[InternalRow],
          hashedRelation: HashedRelation): Iterator[InternalRow] = {
        val joinRow = new JoinedRow
        val joinKeys = streamSideKeyGenerator()
        streamIter.flatMap { srow =>
          joinRow.withLeft(srow)
          val matches = hashedRelation.get(joinKeys(srow))
          if (matches != null) {
            matches.map(joinRow.withRight(_)).filter(boundCondition)
          } else {
            Seq.empty
          }
        }
      }

    Only the inner join (innerJoin) is shown here; the code is straightforward. Note that this is an in-memory operation carried out inside a single partition;
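    The core idea, stripped of Spark's HashedRelation and InternalRow machinery, is just build-then-probe; a plain-Scala sketch:

      // build a hash map on the (small) build side, then stream the other side through it
      val buildSide    = Seq((1, "alice"), (2, "bob"))
      val streamedSide = Seq((1, "order-a"), (2, "order-b"), (3, "order-c"))

      val hashed: Map[Int, Seq[String]] =
        buildSide.groupBy(_._1).map { case (k, rows) => k -> rows.map(_._2) }

      // inner join: streamed rows with no match on the build side are dropped, mirroring innerJoin above
      val joined = streamedSide.flatMap { case (k, v) =>
        hashed.getOrElse(k, Seq.empty).map(w => (k, (v, w)))
      }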

    2.2 BroadcastHashJoinExec

    org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

      protected override def doExecute(): RDD[InternalRow] = {
        val numOutputRows = longMetric("numOutputRows")
    
        val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
        streamedPlan.execute().mapPartitions { streamedIter =>
          val hashed = broadcastRelation.value.asReadOnlyCopy()
          TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
          join(streamedIter, hashed, numOutputRows)
        }
      }

    Here buildPlan is broadcast to the executors as a HashedRelation, and the join then runs inside each partition of streamedPlan via mapPartitions; the join method is the one from HashJoin;

    2.3 ShuffledHashJoinExec

    org.apache.spark.sql.execution.joins.ShuffledHashJoinExec

      protected override def doExecute(): RDD[InternalRow] = {
        val numOutputRows = longMetric("numOutputRows")
        streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
          val hashed = buildHashedRelation(buildIter)
          join(streamIter, hashed, numOutputRows)
        }
      }

    The join first pairs up the partitions of the two RDDs (from streamedPlan and buildPlan) with zipPartitions, then joins inside each partition; the join method is again the one from HashJoin;
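    zipPartitions pairs partition i of one RDD with partition i of the other and runs the supplied function locally on each pair; a small sketch with plain RDDs (assuming the same sc as before):

      val x = sc.parallelize(1 to 6, numSlices = 3)
      val y = sc.parallelize(Seq("a", "b", "c", "d", "e", "f"), numSlices = 3)

      // the function runs once per partition pair, entirely locally,
      // just like the per-partition hash join above
      val zipped = x.zipPartitions(y) { (xs, ys) =>
        Iterator.single((xs.toList, ys.toList))
      }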

    2.4 SortMergeJoinExec

    org.apache.spark.sql.execution.joins.SortMergeJoinExec

      protected override def doExecute(): RDD[InternalRow] = {
        val numOutputRows = longMetric("numOutputRows")
    
        left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
          val boundCondition: (InternalRow) => Boolean = {
            condition.map { cond =>
              newPredicate(cond, left.output ++ right.output).eval _
            }.getOrElse {
              (r: InternalRow) => true
            }
          }
    
          // An ordering that can be used to compare keys from both sides.
          val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
          val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
    
          joinType match {
            case _: InnerLike =>
              new RowIterator {
                private[this] var currentLeftRow: InternalRow = _
                private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
                private[this] var currentMatchIdx: Int = -1
                private[this] val smjScanner = new SortMergeJoinScanner(
                  createLeftKeyGenerator(),
                  createRightKeyGenerator(),
                  keyOrdering,
                  RowIterator.fromScala(leftIter),
                  RowIterator.fromScala(rightIter)
                )
                private[this] val joinRow = new JoinedRow
    
                if (smjScanner.findNextInnerJoinRows()) {
                  currentRightMatches = smjScanner.getBufferedMatches
                  currentLeftRow = smjScanner.getStreamedRow
                  currentMatchIdx = 0
                }
    
                override def advanceNext(): Boolean = {
                  while (currentMatchIdx >= 0) {
                    if (currentMatchIdx == currentRightMatches.length) {
                      if (smjScanner.findNextInnerJoinRows()) {
                        currentRightMatches = smjScanner.getBufferedMatches
                        currentLeftRow = smjScanner.getStreamedRow
                        currentMatchIdx = 0
                      } else {
                        currentRightMatches = null
                        currentLeftRow = null
                        currentMatchIdx = -1
                        return false
                      }
                    }
                    joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
                    currentMatchIdx += 1
                    if (boundCondition(joinRow)) {
                      numOutputRows += 1
                      return true
                    }
                  }
                  false
                }
    
                override def getRow: InternalRow = resultProj(joinRow)
              }.toScala
    ...

    Like ShuffledHashJoinExec, it first combines the two sides with zipPartitions, and then inside each partition returns a different RowIterator implementation depending on joinType. The code above shows the inner-join case; most of the work is delegated to SortMergeJoinScanner:

    org.apache.spark.sql.execution.joins.SortMergeJoinScanner

      final def findNextInnerJoinRows(): Boolean = {
        while (advancedStreamed() && streamedRowKey.anyNull) {
          // Advance the streamed side of the join until we find the next row whose join key contains
          // no nulls or we hit the end of the streamed iterator.
        }
        if (streamedRow == null) {
          // We have consumed the entire streamed iterator, so there can be no more matches.
          matchJoinKey = null
          bufferedMatches.clear()
          false
        } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
          // The new streamed row has the same join key as the previous row, so return the same matches.
          true
        } else if (bufferedRow == null) {
          // The streamed row's join key does not match the current batch of buffered rows and there are
          // no more rows to read from the buffered iterator, so there can be no more matches.
          matchJoinKey = null
          bufferedMatches.clear()
          false
        } else {
          // Advance both the streamed and buffered iterators to find the next pair of matching rows.
          var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
          do {
            if (streamedRowKey.anyNull) {
              advancedStreamed()
            } else {
              assert(!bufferedRowKey.anyNull)
              comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
              if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
              else if (comp < 0) advancedStreamed()
            }
          } while (streamedRow != null && bufferedRow != null && comp != 0)
          if (streamedRow == null || bufferedRow == null) {
            // We have either hit the end of one of the iterators, so there can be no more matches.
            matchJoinKey = null
            bufferedMatches.clear()
            false
          } else {
            // The streamed row's join key matches the current buffered row's join, so walk through the
            // buffered iterator to buffer the rest of the matching rows.
            assert(comp == 0)
            bufferMatchingRows()
            true
          }
        }
      }
    
      /**
       * Advance the streamed iterator and compute the new row's join key.
       * @return true if the streamed iterator returned a row and false otherwise.
       */
      private def advancedStreamed(): Boolean = {
        if (streamedIter.advanceNext()) {
          streamedRow = streamedIter.getRow
          streamedRowKey = streamedKeyGenerator(streamedRow)
          true
        } else {
          streamedRow = null
          streamedRowKey = null
          false
        }
      }
    
      /**
       * Advance the buffered iterator until we find a row with join key that does not contain nulls.
       * @return true if the buffered iterator returned a row and false otherwise.
       */
      private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
        var foundRow: Boolean = false
        while (!foundRow && bufferedIter.advanceNext()) {
          bufferedRow = bufferedIter.getRow
          bufferedRowKey = bufferedKeyGenerator(bufferedRow)
          foundRow = !bufferedRowKey.anyNull
        }
        if (!foundRow) {
          bufferedRow = null
          bufferedRowKey = null
          false
        } else {
          true
        }
      }
    
      /**
       * Called when the streamed and buffered join keys match in order to buffer the matching rows.
       */
      private def bufferMatchingRows(): Unit = {
        assert(streamedRowKey != null)
        assert(!streamedRowKey.anyNull)
        assert(bufferedRowKey != null)
        assert(!bufferedRowKey.anyNull)
        assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
        // This join key may have been produced by a mutable projection, so we need to make a copy:
        matchJoinKey = streamedRowKey.copy()
        bufferedMatches.clear()
        do {
          bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
          advancedBufferedToRowWithNullFreeJoinKey()
        } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
      }

    As you can see, the process is essentially a two-way merge over key-sorted inputs, much like the merge step of merge sort;
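    Stripped of the scanner's bookkeeping, the merge looks roughly like this plain-Scala sketch (the real scanner buffers all matches for the current key, as bufferMatchingRows does):

      // inner merge join over two sequences already sorted by key
      def mergeJoin[K: Ordering, V, W](left: Seq[(K, V)], right: Seq[(K, W)]): Seq[(K, (V, W))] = {
        val ord = implicitly[Ordering[K]]
        val out = scala.collection.mutable.ArrayBuffer.empty[(K, (V, W))]
        var i = 0
        var j = 0
        while (i < left.length && j < right.length) {
          val c = ord.compare(left(i)._1, right(j)._1)
          if (c < 0) i += 1                      // advance the streamed side
          else if (c > 0) j += 1                 // advance the buffered side
          else {
            // keys match: pair the current left row with every right row sharing the key
            var jj = j
            while (jj < right.length && ord.compare(left(i)._1, right(jj)._1) == 0) {
              out += ((left(i)._1, (left(i)._2, right(jj)._2)))
              jj += 1
            }
            i += 1                               // keep j so the next left row with the same key matches again
          }
        }
        out.toSeq
      }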

    Appendix: RowIterator is an abstract class that is essentially an interface, with a fairly standard iterator definition, as follows:

    org.apache.spark.sql.execution.RowIterator

    abstract class RowIterator {
      /**
       * Advance this iterator by a single row. Returns `false` if this iterator has no more rows
       * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling
       * [[getRow]].
       */
      def advanceNext(): Boolean
    
      /**
       * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this
       * method after [[advanceNext()]] has returned `false`.
       */
      def getRow: InternalRow
    
      /**
       * Convert this RowIterator into a [[scala.collection.Iterator]].
       */
      def toScala: Iterator[InternalRow] = new RowIteratorToScala(this)
    }
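    The contract is the usual advanceNext-then-getRow pattern; a hedged sketch of a trivial implementation (SeqRowIterator is a made-up name, not part of Spark):

      import org.apache.spark.sql.catalyst.InternalRow
      import org.apache.spark.sql.execution.RowIterator

      // iterate over a pre-built buffer of rows
      class SeqRowIterator(rows: IndexedSeq[InternalRow]) extends RowIterator {
        private[this] var idx = -1
        override def advanceNext(): Boolean = { idx += 1; idx < rows.length }
        override def getRow: InternalRow = rows(idx)   // only valid after advanceNext() returned true
      }

      // toScala (defined above) then wraps it into a standard Iterator[InternalRow]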
    Original post: https://www.cnblogs.com/barneywill/p/10187751.html