Spark SQL (9): A Source-Code Summary of Spark SQL JOIN Operations

    This post summarizes how Spark SQL implements join operations, based on the Spark SQL source code. The walkthrough roughly follows the pipeline from the SQL statement to the logical operator tree, through the analyzed and optimized plans, and finally to the physical plan and its execution logic.

    The Join logical operator tree

    Start with a SQL statement:

    SELECT NAME FROM NAME
    LEFT JOIN NAME2 ON NAME.NAME = NAME2.NAME
    JOIN NAME3 ON NAME.NAME = NAME3.NAME

    This SQL is parsed into a left-deep tree of Join logical operators (shown as a figure in the original post).

    To see how that tree is generated, we only need to look at the join-related part; the source lives in AstBuilder:

     override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
        // Fold the comma-separated relations in the FROM clause into left-deep inner joins,
        // then attach the explicit JOINs hanging off each relation via withJoinRelations.
        val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
          val right = plan(relation.relationPrimary)
          val join = right.optionalMap(left)(Join(_, _, Inner, None))
          withJoinRelations(join, relation)
        }
        // Finally apply any LATERAL VIEWs on top of the joined plan.
        ctx.lateralView.asScala.foldLeft(from)(withGenerate)
      }
    

      

      private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
        ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
          withOrigin(join) {
            val baseJoinType = join.joinType match {
              case null => Inner
              case jt if jt.CROSS != null => Cross
              case jt if jt.FULL != null => FullOuter
              case jt if jt.SEMI != null => LeftSemi
              case jt if jt.ANTI != null => LeftAnti
              case jt if jt.LEFT != null => LeftOuter
              case jt if jt.RIGHT != null => RightOuter
              case _ => Inner
            }
    
            // Resolve the join type and join condition
            val (joinType, condition) = Option(join.joinCriteria) match {
              case Some(c) if c.USING != null =>
                (UsingJoin(baseJoinType, c.identifier.asScala.map(_.getText)), None)
              case Some(c) if c.booleanExpression != null =>
                (baseJoinType, Option(expression(c.booleanExpression)))
              case None if join.NATURAL != null =>
                if (baseJoinType == Cross) {
                  throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
                }
                (NaturalJoin(baseJoinType), None)
              case None =>
                (baseJoinType, None)
            }
            Join(left, plan(join.right), joinType, condition)
          }
        }
      }
    

     As the tree structure shows, the join relations of a query are kept as a List[JoinRelation]; each JoinRelation holds a JoinType, a relationPrimary, and a joinCriteria, where the joinCriteria is essentially a booleanExpression.

       After that come the analysis and optimization steps for the Join. These two phases mainly add subquery aliases and similar wrappers, and the optimizer then pushes operators down, eliminates the subquery aliases, and so on. Many rules are involved; interested readers can study the source further.
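       For instance, you can print the result of each of these phases for a given query via explain (a quick illustration, assuming a SparkSession named spark):

    // extended = true prints the parsed, analyzed, optimized and physical plans.
    spark.sql("SELECT NAME FROM NAME LEFT JOIN NAME2 ON NAME.NAME = NAME2.NAME")
      .explain(extended = true)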

    The physical planning phase

        This step involves the various strategies configured in SparkPlanner; among them, JoinSelection is the one to focus on. Its apply method is as follows:

       def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    
          // --- BroadcastHashJoin --------------------------------------------------------------------
    
          // broadcast hints were specified
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
            if canBroadcastByHints(joinType, left, right) =>
            val buildSide = broadcastSideByHints(joinType, left, right)
            Seq(joins.BroadcastHashJoinExec(
              leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
    
          // broadcast hints were not specified, so need to infer it from size and configuration.
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
            if canBroadcastBySizes(joinType, left, right) =>
            val buildSide = broadcastSideBySizes(joinType, left, right)
            Seq(joins.BroadcastHashJoinExec(
              leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
    
          // --- ShuffledHashJoin ---------------------------------------------------------------------
    
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
             if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
               && muchSmaller(right, left) ||
               !RowOrdering.isOrderable(leftKeys) =>
            Seq(joins.ShuffledHashJoinExec(
              leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
    
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
             if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
               && muchSmaller(left, right) ||
               !RowOrdering.isOrderable(leftKeys) =>
            Seq(joins.ShuffledHashJoinExec(
              leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
    
          // --- SortMergeJoin ------------------------------------------------------------
    
          case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
            if RowOrdering.isOrderable(leftKeys) =>
            joins.SortMergeJoinExec(
              leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
    
          // --- Without joining keys ------------------------------------------------------------
    
          // Pick BroadcastNestedLoopJoin if one side could be broadcast
          case j @ logical.Join(left, right, joinType, condition)
              if canBroadcastByHints(joinType, left, right) =>
            val buildSide = broadcastSideByHints(joinType, left, right)
            joins.BroadcastNestedLoopJoinExec(
              planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
    
          case j @ logical.Join(left, right, joinType, condition)
              if canBroadcastBySizes(joinType, left, right) =>
            val buildSide = broadcastSideBySizes(joinType, left, right)
            joins.BroadcastNestedLoopJoinExec(
              planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
    
          // Pick CartesianProduct for InnerJoin
          case logical.Join(left, right, _: InnerLike, condition) =>
            joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
    
          case logical.Join(left, right, joinType, condition) =>
            val buildSide = broadcastSide(
              left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
            // This join could be very slow or OOM
            joins.BroadcastNestedLoopJoinExec(
              planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
    
          // --- Cases where this strategy does not apply ---------------------------------------------
    
          case _ => Nil
        }
      }
    

      As the code above shows, different physical join operators are produced depending on the conditions: BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec, and BroadcastNestedLoopJoinExec.

          Before introducing these four operators, here is the general idea behind the join implementations.

          When Spark performs a join between two tables, one acts as the streamed table and the other as the build table; by default the larger table is streamed and the smaller one is built. The join iterates over the streamed table and probes the build table for matching rows, producing the joined output. Consider an extreme case: the streamed table has millions of rows while the build table has only 10. Then it suffices to iterate over the streamed table and look up matching rows in the small (build) table, emitting a joined row for each match, as the toy sketch below illustrates.
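          A toy illustration of the streamed/build idea (plain Scala, not Spark internals; Row and naiveHashJoin are made-up names):

    case class Row(key: Int, value: String)

    // Build a hash map over the small (build) side once, then stream the large side
    // through it, emitting one output pair per key match.
    def naiveHashJoin(streamed: Iterator[Row], build: Seq[Row]): Iterator[(Row, Row)] = {
      val buildMap: Map[Int, Seq[Row]] = build.groupBy(_.key)
      streamed.flatMap(s => buildMap.getOrElse(s.key, Seq.empty).map(b => (s, b)))
    }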

          Spark's joins are implemented broadly in terms of these two roles, streamed and build. Now a brief look at each of the operators above:

          1. BroadcastHashJoinExec implements the join via a broadcast. It is selected either when a broadcast hint is present and the corresponding side can be built (build-right or build-left is allowed for the join type), or when one side is smaller than the configured spark.sql.autoBroadcastJoinThreshold. Spark first pulls the build-side data to the driver and then distributes it to the worker nodes, so if the build table is fairly large this step can put pressure on the driver.
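          For reference, a broadcast join can also be requested explicitly from the DataFrame API or tuned via configuration (a sketch; largeDf and smallDf are placeholder DataFrames):

    import org.apache.spark.sql.functions.broadcast

    // Hint that smallDf should be broadcast regardless of its size estimate.
    val joined = largeDf.join(broadcast(smallDf), Seq("id"), "left_outer")

    // Tables estimated below this threshold are broadcast automatically;
    // set it to -1 to disable size-based broadcast joins.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)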

          2. ShuffledHashJoinExec implements the join by keeping the build table in memory after a shuffle. It is selected when: one side can be built; that side is smaller than the number of shuffle partitions times the broadcast threshold (so each partition's hash map fits in local memory); the switch that prefers sort-merge join is turned off; and the build side is sufficiently small relative to the streamed side (build size * 3 <= streamed size). The core idea is to iterate over the streamed table and match each row against the in-memory build-side map to produce the joined rows.
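          The two size checks mentioned above look roughly like this in the same JoinSelection strategy (paraphrased from the Spark 2.x source):

    // Each partition's hash map must fit in memory: threshold * number of partitions.
    private def canBuildLocalHashMap(plan: LogicalPlan): Boolean =
      plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions

    // The build side must be at least 3x smaller than the streamed side.
    private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean =
      a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes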

          3. SortMergeJoinExec shuffles, sorts, and then performs a sort-based merge join. If neither of the above applies (and the keys are orderable), a sort-merge join is used: the data is first shuffled so that equal keys land in the same partition, each partition is sorted, and the two sorted sides are then merge-joined by reading the streamed and buffered sides together. Because both sides are ordered, a single sequential pass over them is enough to match keys and emit joined rows.
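          The preference mentioned above is controlled by a SQLConf flag (key name from Spark 2.x; it defaults to true):

    // When true (the default), ShuffledHashJoinExec is skipped even when the build
    // side is small; set to false to let the hash-based join be considered.
    spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")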

          4. BroadcastNestedLoopJoinExec mainly handles joins without join keys; it is not covered further here.

    The rest of this post summarizes the implementation of the hash join and SortMergeJoinExec.

          ShuffledHashJoinExec

          

      private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
        val buildDataSize = longMetric("buildDataSize")
        val buildTime = longMetric("buildTime")
        val start = System.nanoTime()
        val context = TaskContext.get()
        val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
        buildTime += (System.nanoTime() - start) / 1000000
        buildDataSize += relation.estimatedSize
        // This relation is usually used until the end of task.
        context.addTaskCompletionListener(_ => relation.close())
        relation
      }
    
      protected override def doExecute(): RDD[InternalRow] = {
        val numOutputRows = longMetric("numOutputRows")
        val avgHashProbe = longMetric("avgHashProbe")
        streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
          val hashed = buildHashedRelation(buildIter)
          join(streamIter, hashed, numOutputRows, avgHashProbe)
        }
      }
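      Note the zipPartitions call in doExecute: both children have already been shuffled (hash-partitioned) by the join keys, so partition i of the streamed side only ever needs partition i of the build side. As a standalone illustration of the operator (plain RDD code, assuming a SparkContext named sc):

    // Both RDDs must have the same number of partitions; the function runs once per
    // pair of co-indexed partitions.
    val a = sc.parallelize(1 to 4, numSlices = 2)
    val b = sc.parallelize(Seq("a", "b", "c", "d"), numSlices = 2)
    val zipped = a.zipPartitions(b) { (ia, ib) => ia.zip(ib) } // RDD[(Int, String)]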
    

      Physical plans are generally executed through this doExecute method. Its main logic calls buildHashedRelation, and inside that method the key piece is HashedRelation:

    private[execution] object HashedRelation {
    
      /**
       * Create a HashedRelation from an Iterator of InternalRow.
       */
      def apply(
          input: Iterator[InternalRow],
          key: Seq[Expression],
          sizeEstimate: Int = 64,
          taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
        val mm = Option(taskMemoryManager).getOrElse {
          new TaskMemoryManager(
            new StaticMemoryManager(
              new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
              Long.MaxValue,
              Long.MaxValue,
              1),
            0)
        }
    
        if (key.length == 1 && key.head.dataType == LongType) {
          LongHashedRelation(input, key, sizeEstimate, mm)
        } else {
          UnsafeHashedRelation(input, key, sizeEstimate, mm)
        }
      }
    }
    

      Here, if there is a single join key of LongType, a LongHashedRelation is created (backed by LongToUnsafeRowMap); otherwise an UnsafeHashedRelation (backed by BytesToBytesMap). We focus on UnsafeHashedRelation:

    private[joins] object UnsafeHashedRelation {
    
      def apply(
          input: Iterator[InternalRow],
          key: Seq[Expression],
          sizeEstimate: Int,
          taskMemoryManager: TaskMemoryManager): HashedRelation = {
    
        val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
          .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
    
        val binaryMap = new BytesToBytesMap(
          taskMemoryManager,
          // Only 70% of the slots can be used before growing, more capacity help to reduce collision
          (sizeEstimate * 1.5 + 1).toInt,
          pageSizeBytes,
          true)
    
        // Create a mapping of buildKeys -> rows
        val keyGenerator = UnsafeProjection.create(key)
        var numFields = 0
        while (input.hasNext) {
          val row = input.next().asInstanceOf[UnsafeRow]
          numFields = row.numFields()
          val key = keyGenerator(row)
          if (!key.anyNull) {
            val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
            val success = loc.append(
              key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
              row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
            if (!success) {
              binaryMap.free()
              throw new SparkException("There is no enough memory to build hash map")
            }
          }
        }
    
        new UnsafeHashedRelation(numFields, binaryMap)
      }
    }
    

      As the code shows, a mapping from the buildKeys (passed in from ShuffledHashJoinExec) to the matching rows is constructed here; this map is exactly the build table mentioned earlier. With the build table ready, back in ShuffledHashJoinExec.doExecute shown above, streamIter (the streamed table) and hashed (the build table) are combined into the join operation:


      protected def join(
          streamedIter: Iterator[InternalRow],
          hashed: HashedRelation,
          numOutputRows: SQLMetric,
          avgHashProbe: SQLMetric): Iterator[InternalRow] = {
    
        val joinedIter = joinType match {
          case _: InnerLike =>
            innerJoin(streamedIter, hashed)
          case LeftOuter | RightOuter =>
            outerJoin(streamedIter, hashed)
          case LeftSemi =>
            semiJoin(streamedIter, hashed)
          case LeftAnti =>
            antiJoin(streamedIter, hashed)
          case j: ExistenceJoin =>
            existenceJoin(streamedIter, hashed)
          case x =>
            throw new IllegalArgumentException(
              s"BroadcastHashJoin should not take $x as the JoinType")
        }
    
        // At the end of the task, we update the avg hash probe.
        TaskContext.get().addTaskCompletionListener(_ =>
          avgHashProbe.set(hashed.getAverageProbesPerLookup))
    
        val resultProj = createResultProjection
        joinedIter.map { r =>
          numOutputRows += 1
          resultProj(r)
        }
      }
    

      Here is what innerJoin does:

     private def innerJoin(
          streamIter: Iterator[InternalRow],
          hashedRelation: HashedRelation): Iterator[InternalRow] = {
        val joinRow = new JoinedRow
        val joinKeys = streamSideKeyGenerator()
        streamIter.flatMap { srow =>
          joinRow.withLeft(srow)
          val matches = hashedRelation.get(joinKeys(srow))
          if (matches != null) {
            matches.map(joinRow.withRight(_)).filter(boundCondition)
          } else {
            Seq.empty
          }
        }
      }
    

       As you can see, it iterates over the streamed table, looks up rows with the same key in the build table, and when matches exist it builds the JoinedRow and filters it with the bound join condition. That completes the hash join. The other join types can be traced through the code in the same way.

         SortMergeJoinExec

          Its doExecute method is as follows:

     protected override def doExecute(): RDD[InternalRow] = {
        val numOutputRows = longMetric("numOutputRows")
        val spillThreshold = getSpillThreshold
        val inMemoryThreshold = getInMemoryThreshold
        left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
          val boundCondition: (InternalRow) => Boolean = {
            condition.map { cond =>
              newPredicate(cond, left.output ++ right.output).eval _
            }.getOrElse {
              (r: InternalRow) => true
            }
          }
    
          // An ordering that can be used to compare keys from both sides.
          val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
          val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
    
          joinType match {
            case _: InnerLike =>
              new RowIterator {
                private[this] var currentLeftRow: InternalRow = _
                private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
                private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
                private[this] val smjScanner = new SortMergeJoinScanner(
                  createLeftKeyGenerator(),
                  createRightKeyGenerator(),
                  keyOrdering,
                  RowIterator.fromScala(leftIter),
                  RowIterator.fromScala(rightIter),
                  inMemoryThreshold,
                  spillThreshold
                )
                private[this] val joinRow = new JoinedRow
    
                if (smjScanner.findNextInnerJoinRows()) {
                  currentRightMatches = smjScanner.getBufferedMatches
                  currentLeftRow = smjScanner.getStreamedRow
                  rightMatchesIterator = currentRightMatches.generateIterator()
                }
    
                override def advanceNext(): Boolean = {
                  while (rightMatchesIterator != null) {
                    if (!rightMatchesIterator.hasNext) {
                      if (smjScanner.findNextInnerJoinRows()) {
                        currentRightMatches = smjScanner.getBufferedMatches
                        currentLeftRow = smjScanner.getStreamedRow
                        rightMatchesIterator = currentRightMatches.generateIterator()
                      } else {
                        currentRightMatches = null
                        currentLeftRow = null
                        rightMatchesIterator = null
                        return false
                      }
                    }
                    joinRow(currentLeftRow, rightMatchesIterator.next())
                    if (boundCondition(joinRow)) {
                      numOutputRows += 1
                      return true
                    }
                  }
                  false
                }
    
                override def getRow: InternalRow = resultProj(joinRow)
              }.toScala
    
            case LeftOuter =>
              val smjScanner = new SortMergeJoinScanner(
                streamedKeyGenerator = createLeftKeyGenerator(),
                bufferedKeyGenerator = createRightKeyGenerator(),
                keyOrdering,
                streamedIter = RowIterator.fromScala(leftIter),
                bufferedIter = RowIterator.fromScala(rightIter),
                inMemoryThreshold,
                spillThreshold
              )
              val rightNullRow = new GenericInternalRow(right.output.length)
              new LeftOuterIterator(
                smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
    
            case RightOuter =>
              val smjScanner = new SortMergeJoinScanner(
                streamedKeyGenerator = createRightKeyGenerator(),
                bufferedKeyGenerator = createLeftKeyGenerator(),
                keyOrdering,
                streamedIter = RowIterator.fromScala(rightIter),
                bufferedIter = RowIterator.fromScala(leftIter),
                inMemoryThreshold,
                spillThreshold
              )
              val leftNullRow = new GenericInternalRow(left.output.length)
              new RightOuterIterator(
                smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
    
            case FullOuter =>
              val leftNullRow = new GenericInternalRow(left.output.length)
              val rightNullRow = new GenericInternalRow(right.output.length)
              val smjScanner = new SortMergeFullOuterJoinScanner(
                leftKeyGenerator = createLeftKeyGenerator(),
                rightKeyGenerator = createRightKeyGenerator(),
                keyOrdering,
                leftIter = RowIterator.fromScala(leftIter),
                rightIter = RowIterator.fromScala(rightIter),
                boundCondition,
                leftNullRow,
                rightNullRow)
    
              new FullOuterIterator(
                smjScanner,
                resultProj,
                numOutputRows).toScala
    
            case LeftSemi =>
              new RowIterator {
                private[this] var currentLeftRow: InternalRow = _
                private[this] val smjScanner = new SortMergeJoinScanner(
                  createLeftKeyGenerator(),
                  createRightKeyGenerator(),
                  keyOrdering,
                  RowIterator.fromScala(leftIter),
                  RowIterator.fromScala(rightIter),
                  inMemoryThreshold,
                  spillThreshold
                )
                private[this] val joinRow = new JoinedRow
    
                override def advanceNext(): Boolean = {
                  while (smjScanner.findNextInnerJoinRows()) {
                    val currentRightMatches = smjScanner.getBufferedMatches
                    currentLeftRow = smjScanner.getStreamedRow
                    if (currentRightMatches != null && currentRightMatches.length > 0) {
                      val rightMatchesIterator = currentRightMatches.generateIterator()
                      while (rightMatchesIterator.hasNext) {
                        joinRow(currentLeftRow, rightMatchesIterator.next())
                        if (boundCondition(joinRow)) {
                          numOutputRows += 1
                          return true
                        }
                      }
                    }
                  }
                  false
                }
    
                override def getRow: InternalRow = currentLeftRow
              }.toScala
    
            case LeftAnti =>
              new RowIterator {
                private[this] var currentLeftRow: InternalRow = _
                private[this] val smjScanner = new SortMergeJoinScanner(
                  createLeftKeyGenerator(),
                  createRightKeyGenerator(),
                  keyOrdering,
                  RowIterator.fromScala(leftIter),
                  RowIterator.fromScala(rightIter),
                  inMemoryThreshold,
                  spillThreshold
                )
                private[this] val joinRow = new JoinedRow
    
                override def advanceNext(): Boolean = {
                  while (smjScanner.findNextOuterJoinRows()) {
                    currentLeftRow = smjScanner.getStreamedRow
                    val currentRightMatches = smjScanner.getBufferedMatches
                    if (currentRightMatches == null || currentRightMatches.length == 0) {
                      numOutputRows += 1
                      return true
                    }
                    var found = false
                    val rightMatchesIterator = currentRightMatches.generateIterator()
                    while (!found && rightMatchesIterator.hasNext) {
                      joinRow(currentLeftRow, rightMatchesIterator.next())
                      if (boundCondition(joinRow)) {
                        found = true
                      }
                    }
                    if (!found) {
                      numOutputRows += 1
                      return true
                    }
                  }
                  false
                }
    
                override def getRow: InternalRow = currentLeftRow
              }.toScala
    
            case j: ExistenceJoin =>
              new RowIterator {
                private[this] var currentLeftRow: InternalRow = _
                private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null))
                private[this] val smjScanner = new SortMergeJoinScanner(
                  createLeftKeyGenerator(),
                  createRightKeyGenerator(),
                  keyOrdering,
                  RowIterator.fromScala(leftIter),
                  RowIterator.fromScala(rightIter),
                  inMemoryThreshold,
                  spillThreshold
                )
                private[this] val joinRow = new JoinedRow
    
                override def advanceNext(): Boolean = {
                  while (smjScanner.findNextOuterJoinRows()) {
                    currentLeftRow = smjScanner.getStreamedRow
                    val currentRightMatches = smjScanner.getBufferedMatches
                    var found = false
                    if (currentRightMatches != null && currentRightMatches.length > 0) {
                      val rightMatchesIterator = currentRightMatches.generateIterator()
                      while (!found && rightMatchesIterator.hasNext) {
                        joinRow(currentLeftRow, rightMatchesIterator.next())
                        if (boundCondition(joinRow)) {
                          found = true
                        }
                      }
                    }
                    result.setBoolean(0, found)
                    numOutputRows += 1
                    return true
                  }
                  false
                }
    
                override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result))
              }.toScala
    
            case x =>
              throw new IllegalArgumentException(
                s"SortMergeJoin should not take $x as the JoinType")
          }
    
        }
      }
    

          First look at the implementation under the InnerLike branch.

               The logic is straightforward: a SortMergeJoinScanner is instantiated, and the interesting part is the advanceNext method, which calls findNextInnerJoinRows to find the next set of joinable rows. In particular:

               1. currentLeftRow holds the streamed-side row, obtained via smjScanner.getStreamedRow;

               2. currentRightMatches holds the buffered (build-side) matches, obtained via smjScanner.getBufferedMatches;

               3. advanceNext mainly drives findNextInnerJoinRows: when it returns true there is a new match, so the two values above are reset, the JoinedRow is built, and the bound filter condition is applied;

               4. findNextInnerJoinRows:

     final def findNextInnerJoinRows(): Boolean = {
        while (advancedStreamed() && streamedRowKey.anyNull) {
          // Advance the streamed side of the join until we find the next row whose join key contains
          // no nulls or we hit the end of the streamed iterator.
        }
        if (streamedRow == null) {
          // We have consumed the entire streamed iterator, so there can be no more matches.
          matchJoinKey = null
          bufferedMatches.clear()
          false
        } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
          // The new streamed row has the same join key as the previous row, so return the same matches.
          true
        } else if (bufferedRow == null) {
          // The streamed row's join key does not match the current batch of buffered rows and there are
          // no more rows to read from the buffered iterator, so there can be no more matches.
          matchJoinKey = null
          bufferedMatches.clear()
          false
        } else {
          // Advance both the streamed and buffered iterators to find the next pair of matching rows.
          var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
          do {
            if (streamedRowKey.anyNull) {
              advancedStreamed()
            } else {
              assert(!bufferedRowKey.anyNull)
              comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
              if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
              else if (comp < 0) advancedStreamed()
            }
          } while (streamedRow != null && bufferedRow != null && comp != 0)
          if (streamedRow == null || bufferedRow == null) {
            // We have either hit the end of one of the iterators, so there can be no more matches.
            matchJoinKey = null
            bufferedMatches.clear()
            false
          } else {
            // The streamed row's join key matches the current buffered row's join, so walk through the
            // buffered iterator to buffer the rest of the matching rows.
            assert(comp == 0)
            bufferMatchingRows()
            true
          }
        }
      }
    

      The main logic:

          If the streamed side is exhausted, return false immediately.

          If the streamed row's key equals the cached matchJoinKey, return true and reuse the same buffered matches.

          If the buffered side is exhausted, return false.

          The rest happens in the do-while loop: after the null-key check, the streamed and buffered keys are compared; if the comparison is greater than zero the buffered side is advanced, if less than zero the streamed side is advanced, and the loop continues until the two keys are equal or one side runs out. Once the keys match, bufferMatchingRows appends every buffered row sharing that key into bufferedMatches.

          bufferMatchingRows also records matchJoinKey, so on the next call to findNextInnerJoinRows, a streamed row with the same key returns true right away and joins against the already-buffered matches.
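          To make the compare-and-advance loop concrete, here is a toy merge join over two key-sorted lists (plain Scala, not the Spark classes; mergeJoin is a made-up name):

    // Both inputs must be sorted by key. On a match, pair the left value with the whole
    // run of matching right rows (the analogue of bufferMatchingRows), then advance the
    // left side so that a duplicate left key reuses the same buffered run.
    def mergeJoin(left: List[(Int, String)], right: List[(Int, String)]): List[(String, String)] =
      (left, right) match {
        case ((lk, lv) :: lt, (rk, _) :: rt) =>
          if (lk < rk) mergeJoin(lt, right)     // advance the streamed side
          else if (lk > rk) mergeJoin(left, rt) // advance the buffered side
          else {
            val run = right.takeWhile(_._1 == rk).map { case (_, rv) => (lv, rv) }
            run ++ mergeJoin(lt, right)
          }
        case _ => Nil
      }

    // mergeJoin(List(1 -> "a", 2 -> "b"), List(2 -> "x", 2 -> "y"))
    //   == List(("b", "x"), ("b", "y"))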

          LeftOuter and RightOuter are implemented by LeftOuterIterator and RightOuterIterator, both concrete subclasses of OneSideOuterIterator. They rely on SortMergeJoinScanner.findNextOuterJoinRows to compare the streamed and buffered keys and handle the result; each mainly implements the setBufferedSideOutput and setStreamSideOutput methods, and the remaining logic lives in advanceStream.

          For FullOuter the main implementation is FullOuterIterator:

    private class FullOuterIterator(
        smjScanner: SortMergeFullOuterJoinScanner,
        resultProj: InternalRow => InternalRow,
        numRows: SQLMetric) extends RowIterator {
      private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
    
      override def advanceNext(): Boolean = {
        val r = smjScanner.advanceNext()
        if (r) numRows += 1
        r
      }
    
      override def getRow: InternalRow = resultProj(joinedRow)
    }
    

      Seen this way, FullOuter actually has the simplest-looking iterator.

          Because a join returns an iterator, when reading the source the place to focus on is each advanceNext implementation; from there you can trace the whole join process.
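          The advanceNext/getRow protocol can be adapted to a standard Scala Iterator roughly like this (a sketch of the idea behind RowIterator.toScala, not the actual Spark class):

    trait SimpleRowIterator[T] {
      def advanceNext(): Boolean // move to the next row; false once exhausted
      def getRow: T              // only valid after advanceNext() returned true

      def toScala: Iterator[T] = new Iterator[T] {
        private var advanced = false
        private var hasMore = false
        override def hasNext: Boolean = {
          if (!advanced) { hasMore = advanceNext(); advanced = true }
          hasMore
        }
        override def next(): T = {
          if (!hasNext) throw new NoSuchElementException("end of iterator")
          advanced = false // force a fresh advanceNext() on the following hasNext
          getRow
        }
      }
    }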

          To summarize: this post has only sketched the main ideas behind Spark's join implementations. The finer details still require reading the code. For example, what is the spilling in SortMergeJoinExec based on? It is the ExternalAppendOnlyUnsafeRowArray used inside SortMergeJoinScanner, which relies on UnsafeExternalSorter to implement the spill-to-disk behavior.
