Spark SQL(9)-Spark SQL JOIN操作源码总结
本文根据Spark SQL的源码,总结JOIN操作的具体实现;大体流程是:从SQL语句到逻辑算子树,再到analyzed -> optimized -> 物理计划及其处理逻辑,依次进行大致的总结。
/**
 * Builds the logical plan for a FROM clause.
 *
 * The relations are folded from left to right into inner joins, each one extended with its own
 * join relations, and the LATERAL VIEW clauses are then applied on top of the result.
 */
override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
  // The fold starts from a null plan because there is nothing to join the first relation with.
  val seed: LogicalPlan = null
  val joined = ctx.relation.asScala.foldLeft(seed) { (acc, rel) =>
    val primary = plan(rel.relationPrimary)
    // NOTE(review): optionalMap presumably skips the join when `acc` is null — confirm.
    val combined = primary.optionalMap(acc)((l, r) => Join(l, r, Inner, None))
    withJoinRelations(combined, rel)
  }
  ctx.lateralView.asScala.foldLeft(joined)(withGenerate)
}
/**
 * Joins `base` with every join relation of `ctx`, folding from left to right.
 *
 * @param base the plan on the left-hand side of the first join
 * @param ctx  the relation context whose `joinRelation` entries are applied
 * @return the resulting tree of `Join` nodes (or `base` when there is no join relation)
 * @throws ParseException for a NATURAL CROSS JOIN
 */
private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
  ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
    withOrigin(join) {
      // A missing or unrecognised join type falls back to an inner join.
      val baseJoinType = join.joinType match {
        case null => Inner
        case jt if jt.CROSS != null => Cross
        case jt if jt.FULL != null => FullOuter
        case jt if jt.SEMI != null => LeftSemi
        case jt if jt.ANTI != null => LeftAnti
        case jt if jt.LEFT != null => LeftOuter
        case jt if jt.RIGHT != null => RightOuter
        case _ => Inner
      }

      // Resolve the join type and join condition
      val (joinType, condition) = Option(join.joinCriteria) match {
        case Some(c) if c.USING != null =>
          // NOTE(review): the USING column list was lost in the excerpt; restored from the
          // Spark 2.3 AstBuilder — confirm against the Spark version being discussed.
          (UsingJoin(baseJoinType, c.identifier.asScala.map(_.getText)), None)
        case Some(c) if c.booleanExpression != null =>
          (baseJoinType, Option(expression(c.booleanExpression)))
        case None if join.NATURAL != null =>
          if (baseJoinType == Cross) {
            throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
          }
          (NaturalJoin(baseJoinType), None)
        case None =>
          (baseJoinType, None)
      }
      Join(left, plan(join.right), joinType, condition)
    }
  }
}
之后就是Join的analyzed以及optimized操作。在这两步中,分析阶段主要是添加子查询别名等操作,优化阶段则进行算子下推、消除子查询别名等优化;这里面涉及的规则比较多,感兴趣的同学可以查看源码多研究研究;
这一步主要涉及到 SparkPlanner 中配置的各种strategies,在这些策略中主要关注JoinSelection部分就行,它的apply方法如下:
/**
 * Selects the physical join operator for a logical join.
 *
 * Cases are tried from top to bottom, so the order below is also the priority order: broadcast
 * hash join, shuffled hash join, sort merge join for joins with equi-join keys, then broadcast
 * nested loop join / cartesian product for joins without keys.
 *
 * @param plan the logical plan to translate
 * @return a single-element sequence holding the chosen operator, or `Nil` when `plan` is not a
 *         join handled by this strategy
 */
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {

  // --- BroadcastHashJoin --------------------------------------------------------------------

  // broadcast hints were specified
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
      if canBroadcastByHints(joinType, left, right) =>
    val buildSide = broadcastSideByHints(joinType, left, right)
    Seq(joins.BroadcastHashJoinExec(
      leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))

  // broadcast hints were not specified, so need to infer it from size and configuration.
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
      if canBroadcastBySizes(joinType, left, right) =>
    val buildSide = broadcastSideBySizes(joinType, left, right)
    Seq(joins.BroadcastHashJoinExec(
      leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))

  // --- ShuffledHashJoin ---------------------------------------------------------------------

  // `&&` binds tighter than `||`: the guard holds either when all size/config checks pass, or
  // when the join keys are not orderable (in which case sort merge join is not an option).
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
      if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
        && muchSmaller(right, left) ||
        !RowOrdering.isOrderable(leftKeys) =>
    Seq(joins.ShuffledHashJoinExec(
      leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
      if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
        && muchSmaller(left, right) ||
        !RowOrdering.isOrderable(leftKeys) =>
    Seq(joins.ShuffledHashJoinExec(
      leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))

  // --- SortMergeJoin ------------------------------------------------------------

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
      if RowOrdering.isOrderable(leftKeys) =>
    joins.SortMergeJoinExec(
      leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil

  // --- Without joining keys ------------------------------------------------------------

  // Pick BroadcastNestedLoopJoin if one side could be broadcast
  case j @ logical.Join(left, right, joinType, condition)
      if canBroadcastByHints(joinType, left, right) =>
    val buildSide = broadcastSideByHints(joinType, left, right)
    joins.BroadcastNestedLoopJoinExec(
      planLater(left), planLater(right), buildSide, joinType, condition) :: Nil

  case j @ logical.Join(left, right, joinType, condition)
      if canBroadcastBySizes(joinType, left, right) =>
    val buildSide = broadcastSideBySizes(joinType, left, right)
    joins.BroadcastNestedLoopJoinExec(
      planLater(left), planLater(right), buildSide, joinType, condition) :: Nil

  // Pick CartesianProduct for InnerJoin
  case logical.Join(left, right, _: InnerLike, condition) =>
    joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil

  case logical.Join(left, right, joinType, condition) =>
    val buildSide = broadcastSide(
      left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
    // This join could be very slow or OOM
    joins.BroadcastNestedLoopJoinExec(
      planLater(left), planLater(right), buildSide, joinType, condition) :: Nil

  // --- Cases where this strategy does not apply ---------------------------------------------

  case _ => Nil
}
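2、ShuffledHashJoinExec 通过shuffle之后在内存中保存join构建表来实现join操作;其生成的条件是:可以构建左表或者右表(canBuildLeft/canBuildRight);构建表的大小小于分区数和配置的广播参数的乘积(canBuildLocalHashMap,保证可以加载到本地内存进行计算);没有开启优先使用sort merge join的开关(conf.preferSortMergeJoin为false,即优先考虑基于hash的join);并且构建表足够小(muchSmaller,构建表*3小于流表);此外,如果join key不可排序(!RowOrdering.isOrderable(leftKeys)),也会直接选择ShuffledHashJoinExec。其主要思想就是对流表进行迭代,之后和内存中的构建表数据匹配生成join之后的行数据。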
3、SortMergeJoinExec 通过shuffle操作之后进行排序,然后进行基于排序的join操作;如果上述两种方式都不满足,就会进行基于排序的join(前提是join key可以排序);基于排序的join就是先对数据进行shuffle分区,保证相同的key分到相同的分区,之后进行排序操作,保证数据有序,最后进行merge join操作,同时读取流表和构建表;因为数据有序,所以只要顺序遍历流表和构建表,匹配生成join行数据就行。
4、BroadcastNestedLoopJoinExec 主要针对的是没有等值join key(无法提取join key)的连接操作;暂时不做研究;
/**
 * Builds the in-memory hashed relation for one partition of the build side.
 *
 * Updates the `buildTime` (milliseconds) and `buildDataSize` metrics and registers a task
 * completion listener that closes the relation.
 *
 * @param iter rows of the build side for the current partition
 * @return the hashed relation keyed by `buildKeys`
 */
private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
  val buildDataSize = longMetric("buildDataSize")
  val buildTime = longMetric("buildTime")
  val start = System.nanoTime()
  val context = TaskContext.get()
  val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
  // Nanoseconds to milliseconds.
  buildTime += (System.nanoTime() - start) / 1000000
  buildDataSize += relation.estimatedSize
  // This relation is usually used until the end of task.
  context.addTaskCompletionListener(_ => relation.close())
  relation
}

/**
 * Zips each streamed-side partition with the matching build-side partition, builds the hashed
 * relation from the build side and joins the streamed rows against it.
 */
protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  val avgHashProbe = longMetric("avgHashProbe")
  streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
    val hashed = buildHashedRelation(buildIter)
    join(streamIter, hashed, numOutputRows, avgHashProbe)
  }
}
private[execution] object HashedRelation {

  /**
   * Builds a HashedRelation out of an iterator of InternalRow.
   *
   * A single join key of LongType is served by LongHashedRelation; every other key shape is
   * served by UnsafeHashedRelation. When no task memory manager is supplied, a stand-alone
   * on-heap one is created.
   */
  def apply(
      input: Iterator[InternalRow],
      key: Seq[Expression],
      sizeEstimate: Int = 64,
      taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
    val memoryManager =
      if (taskMemoryManager != null) {
        taskMemoryManager
      } else {
        val onHeapConf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false")
        val staticManager = new StaticMemoryManager(onHeapConf, Long.MaxValue, Long.MaxValue, 1)
        new TaskMemoryManager(staticManager, 0)
      }

    key match {
      case Seq(single) if single.dataType == LongType =>
        LongHashedRelation(input, key, sizeEstimate, memoryManager)
      case _ =>
        UnsafeHashedRelation(input, key, sizeEstimate, memoryManager)
    }
  }
}
private[joins] object UnsafeHashedRelation {

  /**
   * Builds an UnsafeHashedRelation by appending every input row to a BytesToBytesMap under its
   * join key. Rows whose key contains a null are skipped.
   *
   * @param input             rows of the build side; each row is cast to UnsafeRow
   * @param key               expressions that produce the join key of a row
   * @param sizeEstimate      expected number of entries, used to size the map
   * @param taskMemoryManager memory manager backing the map
   * @throws SparkException when a row cannot be appended to the map
   */
  def apply(
      input: Iterator[InternalRow],
      key: Seq[Expression],
      sizeEstimate: Int,
      taskMemoryManager: TaskMemoryManager): HashedRelation = {

    val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
      .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))

    val binaryMap = new BytesToBytesMap(
      taskMemoryManager,
      // Only 70% of the slots can be used before growing, more capacity help to reduce collision
      (sizeEstimate * 1.5 + 1).toInt,
      pageSizeBytes,
      true)

    // Create a mapping of buildKeys -> rows
    val keyGenerator = UnsafeProjection.create(key)
    var numFields = 0
    while (input.hasNext) {
      val row = input.next().asInstanceOf[UnsafeRow]
      numFields = row.numFields()
      val key = keyGenerator(row)
      if (!key.anyNull) {
        val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
        val success = loc.append(
          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
          row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
        if (!success) {
          throw new SparkException("There is no enough memory to build hash map")
        }
      }
    }

    new UnsafeHashedRelation(numFields, binaryMap)
  }
}
/**
 * Runs the shuffled hash join: every streamed-side partition is paired with the build-side
 * partition at the same position, the build side is turned into a hashed relation, and the
 * streamed rows are joined against it.
 */
protected override def doExecute(): RDD[InternalRow] = {
  val rowCounter = longMetric("numOutputRows")
  val probeMetric = longMetric("avgHashProbe")
  val streamedRdd = streamedPlan.execute()
  val buildRdd = buildPlan.execute()
  streamedRdd.zipPartitions(buildRdd) { (streamRows, buildRows) =>
    join(streamRows, buildHashedRelation(buildRows), rowCounter, probeMetric)
  }
}
/**
 * Joins the streamed rows against the hashed build side according to `joinType`.
 *
 * @param streamedIter  rows of the streamed side
 * @param hashed        hashed relation built from the build side
 * @param numOutputRows metric incremented once per emitted row
 * @param avgHashProbe  metric set to the average probes per lookup when the task completes
 * @return the joined rows, passed through the result projection
 * @throws IllegalArgumentException for a join type not handled here
 */
protected def join(
    streamedIter: Iterator[InternalRow],
    hashed: HashedRelation,
    numOutputRows: SQLMetric,
    avgHashProbe: SQLMetric): Iterator[InternalRow] = {

  val joinedIter = joinType match {
    case _: InnerLike =>
      innerJoin(streamedIter, hashed)
    case LeftOuter | RightOuter =>
      outerJoin(streamedIter, hashed)
    case LeftSemi =>
      semiJoin(streamedIter, hashed)
    case LeftAnti =>
      antiJoin(streamedIter, hashed)
    case j: ExistenceJoin =>
      existenceJoin(streamedIter, hashed)
    case x =>
      throw new IllegalArgumentException(
        s"BroadcastHashJoin should not take $x as the JoinType")
  }

  // At the end of the task, we update the avg hash probe.
  TaskContext.get().addTaskCompletionListener(_ =>
    avgHashProbe.set(hashed.getAverageProbesPerLookup))

  val resultProj = createResultProjection
  joinedIter.map { r =>
    numOutputRows += 1
    resultProj(r)
  }
}
/**
 * Inner join of the streamed rows against the hashed build side.
 *
 * For every streamed row the build rows with the same join key are looked up; each match is
 * combined with the streamed row and kept only if it passes the bound join condition.
 *
 * @param streamIter     rows of the streamed side
 * @param hashedRelation hashed relation built from the build side
 * @return the joined rows; a single JoinedRow instance is reused across results
 */
private def innerJoin(
    streamIter: Iterator[InternalRow],
    hashedRelation: HashedRelation): Iterator[InternalRow] = {
  val joinRow = new JoinedRow
  val joinKeys = streamSideKeyGenerator()
  streamIter.flatMap { srow =>
    joinRow.withLeft(srow)
    val matches = hashedRelation.get(joinKeys(srow))
    if (matches != null) {
      matches.map(joinRow.withRight(_)).filter(boundCondition)
    } else {
      // No build row shares this key.
      Seq.empty
    }
  }
}
可以看出,遍历流表,从构建表获取相同的key,如果不为空就构建joinRow,并应用join的条件进行筛选。到这里整个hash join的实现就算是完成了。对于其他类型的join可以自己跟代码阅读。
/**
 * Runs the sort merge join: the left and right partitions are zipped and, for each pair, a row
 * iterator matching `joinType` is built on top of a merge-join scanner.
 *
 * @throws IllegalArgumentException for a join type not handled here
 */
protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  val spillThreshold = getSpillThreshold
  val inMemoryThreshold = getInMemoryThreshold
  left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
    // Without a join condition every pair of rows with matching keys is accepted.
    val boundCondition: (InternalRow) => Boolean = {
      condition.map { cond =>
        newPredicate(cond, left.output ++ right.output).eval _
      }.getOrElse {
        (r: InternalRow) => true
      }
    }

    // An ordering that can be used to compare keys from both sides.
    val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
    val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)

    joinType match {
      case _: InnerLike =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
          private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter),
            inMemoryThreshold,
            spillThreshold
          )
          private[this] val joinRow = new JoinedRow

          // Position the iterator on the first group of matching rows, if any.
          if (smjScanner.findNextInnerJoinRows()) {
            currentRightMatches = smjScanner.getBufferedMatches
            currentLeftRow = smjScanner.getStreamedRow
            rightMatchesIterator = currentRightMatches.generateIterator()
          }

          override def advanceNext(): Boolean = {
            while (rightMatchesIterator != null) {
              if (!rightMatchesIterator.hasNext) {
                // Current group is exhausted: move on to the next streamed row with matches.
                if (smjScanner.findNextInnerJoinRows()) {
                  currentRightMatches = smjScanner.getBufferedMatches
                  currentLeftRow = smjScanner.getStreamedRow
                  rightMatchesIterator = currentRightMatches.generateIterator()
                } else {
                  currentRightMatches = null
                  currentLeftRow = null
                  rightMatchesIterator = null
                  return false
                }
              }
              joinRow(currentLeftRow, rightMatchesIterator.next())
              if (boundCondition(joinRow)) {
                numOutputRows += 1
                return true
              }
            }
            false
          }

          override def getRow: InternalRow = resultProj(joinRow)
        }.toScala

      case LeftOuter =>
        val smjScanner = new SortMergeJoinScanner(
          streamedKeyGenerator = createLeftKeyGenerator(),
          bufferedKeyGenerator = createRightKeyGenerator(),
          keyOrdering,
          streamedIter = RowIterator.fromScala(leftIter),
          bufferedIter = RowIterator.fromScala(rightIter),
          inMemoryThreshold,
          spillThreshold
        )
        val rightNullRow = new GenericInternalRow(right.output.length)
        new LeftOuterIterator(
          smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala

      case RightOuter =>
        // The right side is streamed and the left side is buffered.
        val smjScanner = new SortMergeJoinScanner(
          streamedKeyGenerator = createRightKeyGenerator(),
          bufferedKeyGenerator = createLeftKeyGenerator(),
          keyOrdering,
          streamedIter = RowIterator.fromScala(rightIter),
          bufferedIter = RowIterator.fromScala(leftIter),
          inMemoryThreshold,
          spillThreshold
        )
        val leftNullRow = new GenericInternalRow(left.output.length)
        new RightOuterIterator(
          smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala

      case FullOuter =>
        val leftNullRow = new GenericInternalRow(left.output.length)
        val rightNullRow = new GenericInternalRow(right.output.length)
        val smjScanner = new SortMergeFullOuterJoinScanner(
          leftKeyGenerator = createLeftKeyGenerator(),
          rightKeyGenerator = createRightKeyGenerator(),
          keyOrdering,
          leftIter = RowIterator.fromScala(leftIter),
          rightIter = RowIterator.fromScala(rightIter),
          boundCondition,
          leftNullRow,
          rightNullRow)

        new FullOuterIterator(
          smjScanner,
          resultProj,
          numOutputRows).toScala

      case LeftSemi =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter),
            inMemoryThreshold,
            spillThreshold
          )
          private[this] val joinRow = new JoinedRow

          override def advanceNext(): Boolean = {
            while (smjScanner.findNextInnerJoinRows()) {
              val currentRightMatches = smjScanner.getBufferedMatches
              currentLeftRow = smjScanner.getStreamedRow
              if (currentRightMatches != null && currentRightMatches.length > 0) {
                val rightMatchesIterator = currentRightMatches.generateIterator()
                // The left row is emitted as soon as one right row satisfies the condition.
                while (rightMatchesIterator.hasNext) {
                  joinRow(currentLeftRow, rightMatchesIterator.next())
                  if (boundCondition(joinRow)) {
                    numOutputRows += 1
                    return true
                  }
                }
              }
            }
            false
          }

          override def getRow: InternalRow = currentLeftRow
        }.toScala

      case LeftAnti =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter),
            inMemoryThreshold,
            spillThreshold
          )
          private[this] val joinRow = new JoinedRow

          override def advanceNext(): Boolean = {
            while (smjScanner.findNextOuterJoinRows()) {
              currentLeftRow = smjScanner.getStreamedRow
              val currentRightMatches = smjScanner.getBufferedMatches
              // No buffered rows for this key: the left row is part of the anti join result.
              if (currentRightMatches == null || currentRightMatches.length == 0) {
                numOutputRows += 1
                return true
              }
              var found = false
              val rightMatchesIterator = currentRightMatches.generateIterator()
              while (!found && rightMatchesIterator.hasNext) {
                joinRow(currentLeftRow, rightMatchesIterator.next())
                if (boundCondition(joinRow)) {
                  found = true
                }
              }
              if (!found) {
                numOutputRows += 1
                return true
              }
            }
            false
          }

          override def getRow: InternalRow = currentLeftRow
        }.toScala

      case j: ExistenceJoin =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          // Single-column row holding the "a match exists" flag for the current left row.
          private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null))
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter),
            inMemoryThreshold,
            spillThreshold
          )
          private[this] val joinRow = new JoinedRow

          override def advanceNext(): Boolean = {
            while (smjScanner.findNextOuterJoinRows()) {
              currentLeftRow = smjScanner.getStreamedRow
              val currentRightMatches = smjScanner.getBufferedMatches
              var found = false
              if (currentRightMatches != null && currentRightMatches.length > 0) {
                val rightMatchesIterator = currentRightMatches.generateIterator()
                while (!found && rightMatchesIterator.hasNext) {
                  joinRow(currentLeftRow, rightMatchesIterator.next())
                  if (boundCondition(joinRow)) {
                    found = true
                  }
                }
              }
              // Every left row is emitted, together with the flag.
              result.setBoolean(0, found)
              numOutputRows += 1
              return true
            }
            false
          }

          override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result))
        }.toScala

      case x =>
        throw new IllegalArgumentException(
          s"SortMergeJoin should not take $x as the JoinType")
    }
  }
}
/**
 * Advances the streamed side to its next row with a null-free join key and looks for buffered
 * rows with the same key.
 *
 * @return true if the current streamed row has matching buffered rows (available in
 *         `bufferedMatches`), false once no further match is possible
 */
final def findNextInnerJoinRows(): Boolean = {
  while (advancedStreamed() && streamedRowKey.anyNull) {
    // Advance the streamed side of the join until we find the next row whose join key contains
    // no nulls or we hit the end of the streamed iterator.
  }
  if (streamedRow == null) {
    // We have consumed the entire streamed iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
    // The new streamed row has the same join key as the previous row, so return the same matches.
    true
  } else if (bufferedRow == null) {
    // The streamed row's join key does not match the current batch of buffered rows and there are
    // no more rows to read from the buffered iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else {
    // Advance both the streamed and buffered iterators to find the next pair of matching rows.
    var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
    do {
      if (streamedRowKey.anyNull) {
        advancedStreamed()
      } else {
        assert(!bufferedRowKey.anyNull)
        comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
        // Move whichever side currently holds the smaller key.
        if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
        else if (comp < 0) advancedStreamed()
      }
    } while (streamedRow != null && bufferedRow != null && comp != 0)
    if (streamedRow == null || bufferedRow == null) {
      // We have either hit the end of one of the iterators, so there can be no more matches.
      matchJoinKey = null
      bufferedMatches.clear()
      false
    } else {
      // The streamed row's join key matches the current buffered row's join, so walk through the
      // buffered iterator to buffer the rest of the matching rows.
      assert(comp == 0)
      bufferMatchingRows()
      true
    }
  }
}
之后具体逻辑在do while中,首先还是校验key是否含有null;之后对流表和构建表数据的key进行比对:如果大于0,则推进构建表,重新拿构建表的数据;如果小于0,就推进流表,拿流表的下一条数据;如此循环,直到两个key相同,或者其中一侧的数据已经读完;key相同时会调用bufferMatchingRows添加bufferedMatches(相当于对拥有同一个key的构建表数据进行append操作,加入bufferedMatches中);
/**
 * Row iterator for the full outer sort merge join.
 *
 * Advancing is delegated to the scanner; each successful advance bumps the row-count metric,
 * and the current row is the scanner's joined row passed through the result projection.
 */
private class FullOuterIterator(
    smjScanner: SortMergeFullOuterJoinScanner,
    resultProj: InternalRow => InternalRow,
    numRows: SQLMetric) extends RowIterator {

  private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()

  override def advanceNext(): Boolean = {
    val advanced = smjScanner.advanceNext()
    if (advanced) {
      numRows += 1
    }
    advanced
  }

  override def getRow: InternalRow = resultProj(joinedRow)
}
总结,这里主要简单总结了下spark join的实现思想。具体的实现细节还是要深入代码去了解,比如SortMergeJoinExec中,他的溢出是基于什么的?这个其实在SortMergeJoinScanner