org.apache.spark.sql.execution.joins.ShuffledHashJoinTopKExec.scala Maven / Gradle / Ivy
The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.spark.sql.execution.joins
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.types._
/**
 * Small facade over a score-keyed row queue.
 *
 * `ShuffledHashJoinTopKExec` implements it and hands the instance to generated code
 * through `priorityQueue()`.
 */
abstract class PriorityQueueShim {

  /** Adds `row` to the queue, keyed by `score`. */
  def insert(score: Any, row: InternalRow): Unit

  /** Returns an iterator over the rows produced from the queue's current contents. */
  def get(): Iterator[InternalRow]

  /** Removes every entry from the queue. */
  def clear(): Unit
}
/**
 * Binary physical operator that hash-joins its two children and pushes the joined rows
 * through a score-ordered priority queue before emitting them.
 *
 * The join type is fixed to `Inner` and the build side to `BuildRight` by the overrides in
 * the class body.
 *
 * @param k not referenced in the visible part of this file; presumably the number of rows
 *          kept by the queue supplied by `TopKHelper` -- TODO confirm
 * @param leftKeys join keys of the left child; also used for its required distribution
 * @param rightKeys join keys of the right child; also used for its required distribution
 * @param condition optional extra predicate evaluated on joined rows
 * @param left left child plan
 * @param right right child plan
 * @param scoreExpr expression evaluated on a joined row to obtain its score
 * @param rankAttr attributes placed in front of the score attribute in the output
 */
case class ShuffledHashJoinTopKExec(
k: Int,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)(
scoreExpr: NamedExpression,
rankAttr: Seq[Attribute])
extends BinaryExecNode with TopKHelper with HashJoin with CodegenSupport {
// SQL metrics exposed by this operator; only the output row count is tracked.
override lazy val metrics = {
  val outputRowCounter = SQLMetrics.createMetric(sparkContext, "number of output rows")
  Map("numOutputRows" -> outputRowCounter)
}
// The score's data type is whatever the score expression evaluates to.
override val scoreType: DataType = scoreExpr.dataType
// This operator always performs an inner join.
override val joinType: JoinType = Inner
override val buildSide: BuildSide = BuildRight // Only support `BuildRight`
// Projection that evaluates `scoreExpr` against a joined (left ++ right) row and yields a
// single-column UnsafeRow holding the score.
private lazy val scoreProjection: UnsafeProjection = {
  val joinedSchema = left.output ++ right.output
  UnsafeProjection.create(Seq(scoreExpr), joinedSchema)
}
// Predicate applied to a joined row laid out as streamed-side columns followed by build-side
// columns. When no join condition is given, every row passes.
//
// The predicate object is created once, when this lazy val is first forced, instead of being
// rebuilt through `newPredicate` for every row that is tested.
private lazy val boundCondition: InternalRow => Boolean = condition match {
  case Some(expr) =>
    val predicate = newPredicate(expr, streamedPlan.output ++ buildPlan.output)
    (r: InternalRow) => predicate.eval(r)
  case None =>
    (_: InternalRow) => true
}
// Leading output columns: the rank attributes followed by the score attribute.
private lazy val topKAttr = rankAttr ++ Seq(scoreExpr.toAttribute)
// Queue facade shared by the interpreted join path and the generated code (through
// `priorityQueue()`).
// NOTE(review): the text of this block is incomplete in this copy of the file -- closing
// braces and parts of `get()` are missing -- so the comments below describe only what the
// remaining lines show.
private lazy val _priorityQueue = new PriorityQueueShim {
// `queue` is not defined in this file; presumably it is supplied by TopKHelper -- TODO confirm.
private val q: InternalRowPriorityQueue = queue
private val joinedRow = new JoinedRow
// Enqueues the row together with its score.
override def insert(score: Any, row: InternalRow): Unit = {
q += Tuple2(score, row)
override def get(): Iterator[InternalRow] = {
// Snapshot of the queue contents in reversed iteration order.
val outputRows = queue.iterator.toSeq.reverse
// NOTE(review): `head` throws on an empty sequence; the generated code in doProduce calls
// get() after the match loop even when no row was inserted -- confirm this cannot be empty.
val (headScore, _) = outputRows.head
// Running rank: starts at 1 and grows by one each time the score differs from the previous
// row's score, so rows with equal adjacent scores share a rank.
val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) =>
if (prevScore == score) (rank, score) else (rank + 1, score)
// Two-field row reused as the left part of the joined output row.
val topKRow = new UnsafeRow(2)
val bufferHolder = new BufferHolder(topKRow)
val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2)
// NOTE(review): the expression feeding this closure is missing from this copy.
val scoreWriter = ScoreWriter(unsafeRowWriter, 1) { case ((score, row), index) =>
// Writes to an UnsafeRow directly
unsafeRowWriter.write(0, index)
joinedRow.apply(topKRow, row)
// Empties the underlying queue.
override def clear(): Unit = q.clear()
/**
 * Output schema: the rank and score attributes followed by the columns of the left and the
 * right child.
 */
override def output: Seq[Attribute] = joinType match {
  case Inner =>
    topKAttr ++ left.output ++ right.output
  case other =>
    // `joinType` is fixed to Inner in this class; fail with a descriptive error rather than a
    // bare MatchError should that ever change.
    throw new IllegalArgumentException(
      s"${getClass.getSimpleName} should not take $other as the JoinType")
}
/**
 * Values of the second constructor parameter list (`scoreExpr`, `rankAttr`), in declaration
 * order.
 */
override protected final def otherCopyArgs: Seq[AnyRef] = {
  scoreExpr :: rankAttr :: Nil
}
// Each child must be clustered on its own join keys (left first, then right).
override def requiredChildDistribution: Seq[Distribution] =
  Seq(ClusteredDistribution(leftKeys), ClusteredDistribution(rightKeys))
/**
 * Builds a [[HashedRelation]] over the build-side rows, keyed by `buildKeys`.
 *
 * The relation is allocated through the current task's memory manager and is closed by a
 * task-completion listener registered here.
 *
 * @param iter build-side rows of the current partition
 * @return the relation used to look up matches for stream-side join keys
 */
def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
  val context = TaskContext.get()
  val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
  context.addTaskCompletionListener(_ => relation.close())
  // Hand the relation back to the caller; the listener above only arranges its cleanup.
  relation
}
// Produces the function that converts a joined row into the operator's output row.
// NOTE(review): the start of the call on the last line (the callee and its opening
// parenthesis) and the closing braces are missing from this copy of the file.
override protected def createResultProjection(): (InternalRow) => InternalRow = joinType match {
case Inner =>
// Always put the stream side on left to simplify implementation
// both of left and right side could be null
output, (topKAttr ++ streamedPlan.output ++ buildPlan.output).map(_.withNullability(true)))
// Interpreted (non-codegen) join path: looks up each streamed row's join key in the hashed
// relation and feeds joined rows, with their scores, into `_priorityQueue`.
// NOTE(review): this block is incomplete in this copy of the file (the expression before
// `{ resultRow =>`, the else branch, the body of the final `map` and all closing braces are
// missing), so the remaining control flow cannot be documented reliably.
// `numOutputRows` is not referenced on any line that is still visible.
protected def InnerJoin(
streamedIter: Iterator[InternalRow],
hashedRelation: HashedRelation,
numOutputRows: SQLMetric): Iterator[InternalRow] = {
val joinRow = new JoinedRow
val joinKeysProj = streamSideKeyGenerator()
val joinedIter = streamedIter.flatMap { srow =>
val joinKeys = joinKeysProj(srow) // `joinKeys` is also a grouping key
// A null result from the relation lookup is treated as "no matches" by the check below.
val matches = hashedRelation.get(joinKeys)
if (matches != null) { { resultRow =>
// The score is read from column 0 of the single-column score projection.
_priorityQueue.insert(scoreProjection(resultRow).get(0, scoreType), resultRow)
val iter = _priorityQueue.get()
} else {
val resultProj = createResultProjection()
// Rows still held by `queue` are appended after the joined rows, ordered by score using
// `reverseScoreOrdering` (defined outside this file).
(joinedIter ++ queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering)
.map(_._2)).map { r =>
/**
 * Executes the join without whole-stage code generation.
 *
 * Streamed and build partitions are zipped pairwise; for each pair a hashed relation is built
 * from the build side and the streamed side is joined against it through `InnerJoin`.
 */
override protected def doExecute(): RDD[InternalRow] = {
  // Resolve the metric outside the partition closure and pass it on, instead of handing
  // `InnerJoin` a null SQLMetric.
  val numOutputRows = longMetric("numOutputRows")
  streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
    val hashed = buildHashedRelation(buildIter)
    InnerJoin(streamIter, hashed, numOutputRows)
  }
}
/**
 * Input RDDs for the generated code: the left child's RDD first, then the right child's.
 * The generated code reads them as `inputs[0]` and `inputs[1]` respectively.
 */
override def inputRDDs(): Seq[RDD[InternalRow]] = {
  left.execute() :: right.execute() :: Nil
}
// Entry point used by the generated code to obtain this operator's queue facade.
def priorityQueue(): PriorityQueueShim = {
  _priorityQueue
}
/**
 * Adds a HashedRelation as mutable state and returns the variable name for it.
 */
// NOTE(review): the delimiters of the initialisation-code string, the method's result
// expression and the closing brace are missing from this copy of the file. The `|` lines are
// the Java statements that initialise the relation from `inputs[1]` and record its estimated
// size.
private def prepareHashedRelation(ctx: CodegenContext): String = {
// create a name for HashedRelation
// Registers this operator as a reference object so the generated code can call back into it.
val joinExec = ctx.addReferenceObj("joinExec", this)
val relationTerm = ctx.freshName("relation")
// Class name of the HashedRelation companion object with the `$` characters removed.
val clsName = HashedRelation.getClass.getName.replace("$", "")
ctx.addMutableState(clsName, relationTerm,
| $relationTerm = ($clsName) $joinExec.buildHashedRelation(inputs[1]);
| incPeakExecutionMemory($relationTerm.estimatedSize());
/**
 * Creates variables for the left part of the result row.
 *
 * In order to defer the access until after the condition and to access each column only once
 * in the loop, the variables are declared separately from the column accesses; the codegen of
 * BoundReference cannot be used here.
 */
private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
// NOTE(review): the expression that supplies the (attribute, index) pairs to the closure
// below is missing from this copy of the file, as are the closing braces.
ctx.INPUT_ROW = leftRow { case (a, i) =>
val value = ctx.freshName("value")
// Java expression that reads column `i` of `leftRow` as the attribute's type.
val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
// declare it as class member, so we can access the column before or in the loop.
ctx.addMutableState(ctx.javaType(a.dataType), value, "")
if (a.nullable) {
// Nullable columns also get a boolean member recording whether the value is null.
val isNull = ctx.freshName("isNull")
ctx.addMutableState("boolean", isNull, "")
val code =
|$isNull = $leftRow.isNullAt($i);
|$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
ExprCode(code, isNull, value)
} else {
// Non-nullable columns are assigned directly, with a constant "false" null flag.
ExprCode(s"$value = $valueCode;", "false", value)
/**
 * Creates the variables for the right part of the result row, using BoundReference, since the
 * right part is accessed inside the loop.
 */
private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
// NOTE(review): the expression that supplies the (attribute, index) pairs to the closure
// below is missing from this copy of the file, as are the closing braces.
ctx.INPUT_ROW = rightRow { case (a, i) =>
// Generates the column access for ordinal `i` with the attribute's type and nullability.
BoundReference(i, a.dataType, a.nullable).genCode(ctx)
/**
 * Returns the code that generates the join key for the stream side, together with a Java
 * expression that tells whether the key contains a null.
 *
 * @param ctx codegen context; its `INPUT_ROW` is set to `leftRow`
 * @param leftRow name of the generated variable holding the current stream-side row
 */
private def genStreamSideJoinKey(ctx: CodegenContext, leftRow: String): (ExprCode, String) = {
  ctx.INPUT_ROW = leftRow
  if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
    // generate the join key as Long
    val ev = streamedKeys.head.genCode(ctx)
    (ev, ev.isNull)
  } else {
    // generate the join key as UnsafeRow
    val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
    (ev, s"${ev.value}.anyNull()")
  }
}
/**
 * Generates the code that evaluates the score expression against `row`.
 *
 * The expression is bound to the left child's output followed by the right child's output.
 *
 * @param ctx codegen context; its `INPUT_ROW` is set to `row`
 * @param row name of the generated variable holding the joined row
 */
private def createScoreVar(ctx: CodegenContext, row: String): ExprCode = {
  ctx.INPUT_ROW = row
  BindReferences.bindReference(scoreExpr, left.output ++ right.output).genCode(ctx)
}
// Declares one generated member variable per output column of `resultRow` and returns the
// code that loads each of them.
private def createResultVars(ctx: CodegenContext, resultRow: String): Seq[ExprCode] = {
// NOTE(review): the expression that supplies the (attribute, index) pairs to the closure
// below is missing from this copy of the file, as are the closing braces.
ctx.INPUT_ROW = resultRow { case (a, i) =>
val value = ctx.freshName("value")
// Java expression that reads column `i` of `resultRow` as the attribute's type.
val valueCode = ctx.getValue(resultRow, a.dataType, i.toString)
// declare it as class member, so we can access the column before or in the loop.
ctx.addMutableState(ctx.javaType(a.dataType), value, "")
if (a.nullable) {
// Nullable columns also get a boolean member recording whether the value is null.
val isNull = ctx.freshName("isNull")
ctx.addMutableState("boolean", isNull, "")
val code =
|$isNull = $resultRow.isNullAt($i);
|$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
ExprCode(code, isNull, value)
} else {
// Non-nullable columns are assigned directly, with a constant "false" null flag.
ExprCode(s"$value = $valueCode;", "false", value)
/**
 * Splits variables based on whether they are used by the condition or not, and returns the
 * code that creates these variables before the condition and after the condition.
 *
 * Usually only a few columns are used by the condition, so the columns that the condition does
 * not use need not be accessed for rows that the condition filters out.
 */
private def splitVarsByCondition(
attributes: Seq[Attribute],
variables: Seq[ExprCode]): (String, String) = {
if (condition.isDefined) {
// Attributes referenced by the join condition.
val condRefs = condition.get.references
// NOTE(review): the partitioning expression and the arguments of both evaluateVariables
// calls are missing from this copy of the file, as are the closing braces.
val (used, notUsed) ={ case (a, ev) =>
val beforeCond = evaluateVariables(
val afterCond = evaluateVariables(
(beforeCond, afterCond)
} else {
// Without a condition every variable is evaluated up front and nothing is deferred.
(evaluateVariables(variables), "")
// Generates the Java code for the whole join: iterate the stream side, look up matches in the
// hashed relation, push joined rows with their scores into the priority queue, then drain the
// queue and pass each result row to the parent via `consume`.
// NOTE(review): in this copy of the file the template-string delimiters, several `$...`
// expressions inside the templates (shown as a bare `$`) and all closing braces are missing.
override def doProduce(ctx: CodegenContext): String = {
ctx.copyResult = true
// Registers this operator so the generated code can call `priorityQueue()` on it.
val topKJoin = ctx.addReferenceObj("topKJoin", this)
// Prepare a priority queue for top-K computing
val pQueue = ctx.freshName("queue")
ctx.addMutableState(classOf[PriorityQueueShim].getName, pQueue,
s"$pQueue = $topKJoin.priorityQueue();")
// Prepare variables for a left side
val leftIter = ctx.freshName("leftIter")
ctx.addMutableState("scala.collection.Iterator", leftIter, s"$leftIter = inputs[0];")
val leftRow = ctx.freshName("leftRow")
ctx.addMutableState("InternalRow", leftRow, "")
val leftVars = createLeftVars(ctx, leftRow)
// Prepare variables for a right side
val rightRow = ctx.freshName("rightRow")
val rightVars = createRightVar(ctx, rightRow)
// Build a hashed relation from a right side
val buildRelation = prepareHashedRelation(ctx)
// Project join keys from a left side
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, leftRow)
// Prepare variables for joined rows
val joinedRow = ctx.freshName("joinedRow")
val joinedRowCls = classOf[JoinedRow].getName
ctx.addMutableState(joinedRowCls, joinedRow, s"$joinedRow = new $joinedRowCls();")
// Project score values from joined rows
val scoreVar = createScoreVar(ctx, joinedRow)
// Prepare variables for output rows
val resultRow = ctx.freshName("resultRow")
val resultVars = createResultVars(ctx, resultRow)
// `beforeLoop` is emitted ahead of each right-row fetch and `condCheck` right after it; with
// no condition, all left variables are evaluated up front and the check is empty.
val (beforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not.
val loaded = ctx.freshName("loaded")
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
// Generate code for condition
ctx.currentVars = leftVars ++ rightVars
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
// evaluate the columns those used by condition before loop
val before = s"""
|boolean $loaded = false;
val checking = s"""
|if (${cond.isNull} || !${cond.value}) continue;
|if (!$loaded) {
| $loaded = true;
| $leftAfter
(before, checking)
} else {
(evaluateVariables(leftVars), "")
val numOutput = metricTerm(ctx, "numOutputRows")
val matches = ctx.freshName("matches")
val topKRows = ctx.freshName("topKRows")
val iteratorCls = classOf[Iterator[UnsafeRow]].getName
|$leftRow = null;
|while ($leftIter.hasNext()) {
| $leftRow = (InternalRow) $;
| // Generate join key for stream side
| ${keyEv.code}
| // Find matches from HashedRelation
| $iteratorCls $matches = $anyNull? null : ($iteratorCls)$buildRelation.get(${keyEv.value});
| if ($matches == null) continue;
| // Join top-K right rows
| while ($matches.hasNext()) {
| ${beforeLoop.trim}
| InternalRow $rightRow = (InternalRow) $;
| ${condCheck.trim}
| InternalRow row = $joinedRow.apply($leftRow, $rightRow);
| // Compute a score for the `row`
| ${scoreVar.code}
| $pQueue.insert(${scoreVar.value}, row);
| }
| // Get top-K rows
| $iteratorCls $topKRows = $pQueue.get();
| $pQueue.clear();
| // Output top-K rows
| while ($topKRows.hasNext()) {
| InternalRow $resultRow = (InternalRow) $;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
© 2015 - 2025 Weber Informatics LLC | Privacy Policy