All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.execution.local.HashJoinNode.scala Maven / Gradle / Ivy

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.joins._

/**
 * An abstract node for sharing common functionality among different implementations of
 * inner hash equi-join, notably [[BinaryHashJoinNode]] and [[BroadcastHashJoinNode]].
 *
 * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]].
 *
 * Usage protocol, as implemented here: `open()` prepares the projections, each successful
 * `next()` positions the node on one (streamed row, build-side match) pair, and `fetch()`
 * returns the joined row for that pair. The hashed relation must have been supplied through
 * `withHashedRelation` before the first call to `next()`.
 */
trait HashJoinNode {

  self: LocalNode =>

  /** Expressions evaluated against each streamed row to produce its join key. */
  protected def streamedKeys: Seq[Expression]

  /** The child node whose rows are streamed and probed against the hashed relation. */
  protected def streamedNode: LocalNode

  /**
   * Which side the hashed relation represents. It decides the order in which the streamed
   * row and the matched row are combined in `fetch`.
   */
  protected def buildSide: BuildSide

  // Iteration state: the streamed row currently being joined, the build-side rows that
  // matched its key, and the index of the match that `fetch` will return.
  private[this] var currentStreamedRow: InternalRow = _
  private[this] var currentHashMatches: Seq[InternalRow] = _
  private[this] var currentMatchPosition: Int = -1

  // Reusable helpers created in `open`.
  private[this] var joinRow: JoinedRow = _
  private[this] var resultProjection: (InternalRow) => InternalRow = _

  private[this] var hashed: HashedRelation = _
  private[this] var joinKeys: Projection = _

  /** Builds the projection that extracts the join key from a streamed row. */
  private def streamSideKeyGenerator: Projection =
    UnsafeProjection.create(streamedKeys, streamedNode.output)

  /**
   * Sets the HashedRelation used by this node. This method needs to be called
   * before the first `next` gets called.
   *
   * @param hashedRelation the relation that streamed-row keys are looked up in.
   */
  protected def withHashedRelation(hashedRelation: HashedRelation): Unit = {
    hashed = hashedRelation
  }

  /**
   * Custom open implementation to be overridden by subclasses.
   * It is invoked by `open` before the join row and projections are created.
   */
  protected def doOpen(): Unit

  /**
   * Opens the node: clears any iteration state, runs the subclass-specific `doOpen`, and
   * then creates the joined-row holder, the result projection and the key projection.
   */
  override def open(): Unit = {
    // Discard state left over from an earlier, possibly partially consumed, run so that a
    // re-opened node cannot replay stale matches on its first `next`.
    currentStreamedRow = null
    currentHashMatches = null
    currentMatchPosition = -1

    doOpen()
    joinRow = new JoinedRow
    resultProjection = UnsafeProjection.create(schema)
    joinKeys = streamSideKeyGenerator
  }

  /**
   * Advances to the next joined pair.
   *
   * @return true if a pair is available for `fetch`, false once the streamed side yields
   *         no further matching row.
   */
  override def next(): Boolean = {
    currentMatchPosition += 1
    // Either nothing has been matched yet or the matches of the current streamed row are
    // used up: move on to the next streamed row that has matches.
    if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) {
      fetchNextMatch()
    } else {
      true
    }
  }

  /**
   * Populate `currentHashMatches` with build-side rows matching the next streamed row.
   * Streamed rows whose key contains a null are skipped without probing the relation.
   *
   * @return whether matches are found such that subsequent calls to `fetch` are valid.
   */
  private def fetchNextMatch(): Boolean = {
    currentHashMatches = null
    currentMatchPosition = -1

    while (currentHashMatches == null && streamedNode.next()) {
      currentStreamedRow = streamedNode.fetch()
      val key = joinKeys(currentStreamedRow)
      if (!key.anyNull) {
        val matches = hashed.get(key)
        // A null lookup result means "no match". An empty result is treated the same way,
        // otherwise `next` would report success and `fetch` would index past the end.
        if (matches != null && matches.nonEmpty) {
          currentHashMatches = matches
        }
      }
    }

    if (currentHashMatches == null) {
      false
    } else {
      currentMatchPosition = 0
      true
    }
  }

  /**
   * Returns the joined row for the current position. Only valid after `next` returned true.
   * With `BuildRight` the streamed row forms the left part of the joined row; with
   * `BuildLeft` the matched build-side row does.
   */
  override def fetch(): InternalRow = {
    val ret = buildSide match {
      case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
      case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
    }
    resultProjection(ret)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy