All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.spark.sql.execution.joins.HashOuterJoin.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.LongSQLMetric
import org.apache.spark.util.collection.CompactBuffer


/**
 * Shared building blocks for hash-based outer joins: the output schema, the choice of build and
 * streamed sides, key projections, and the per-key iterators that emit joined rows.
 *
 * The trait must be mixed into a [[SparkPlan]] (see the self-type below).
 */
trait HashOuterJoin {
  self: SparkPlan =>

  val leftKeys: Seq[Expression]
  val rightKeys: Seq[Expression]
  val joinType: JoinType
  val condition: Option[Expression]
  val left: SparkPlan
  val right: SparkPlan

  /**
   * Output attributes of the join: left attributes followed by right attributes, with the
   * attributes of every side that can be null-padded marked as nullable.
   *
   * Throws an IllegalArgumentException for any join type other than LeftOuter, RightOuter or
   * FullOuter.
   */
  override def output: Seq[Attribute] = {
    joinType match {
      case LeftOuter =>
        left.output ++ right.output.map(_.withNullability(true))
      case RightOuter =>
        left.output.map(_.withNullability(true)) ++ right.output
      case FullOuter =>
        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
      case x =>
        throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
    }
  }

  /**
   * The plan used as the build side and the plan that is streamed. RightOuter builds on the left
   * child and LeftOuter builds on the right child; any other join type (FullOuter included)
   * raises an IllegalArgumentException when this value is first evaluated.
   */
  protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
    case RightOuter => (left, right)
    case LeftOuter => (right, left)
    case x =>
      throw new IllegalArgumentException(
        s"HashOuterJoin should not take $x as the JoinType")
  }

  /**
   * Join keys for the build side and the streamed side, selected with the same rule as
   * `buildPlan` / `streamedPlan`.
   */
  protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
    case RightOuter => (leftKeys, rightKeys)
    case LeftOuter => (rightKeys, leftKeys)
    case x =>
      throw new IllegalArgumentException(
        s"HashOuterJoin should not take $x as the JoinType")
  }

  override def outputsUnsafeRows: Boolean = true
  override def canProcessUnsafeRows: Boolean = true
  override def canProcessSafeRows: Boolean = false

  /** Creates a new projection that extracts the build-side join keys from a build-side row. */
  protected def buildKeyGenerator: Projection =
    UnsafeProjection.create(buildKeys, buildPlan.output)

  /** Creates a new projection that extracts the streamed-side join keys from a streamed row. */
  protected[this] def streamedKeyGenerator: Projection =
    UnsafeProjection.create(streamedKeys, streamedPlan.output)

  /** Creates a new projection that converts a joined row into this operator's output schema. */
  protected[this] def resultProjection: InternalRow => InternalRow =
    UnsafeProjection.create(self.schema)

  // Single-element placeholder list; fullOuterIterator filters/maps it to conditionally append
  // exactly one extra row.
  @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()

  // Rows used to pad the side that has no match; sized to the corresponding child's output.
  @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
  @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
  // Join condition bound against (left ++ right); a missing condition always passes.
  @transient private[this] lazy val boundCondition =
    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)

  // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
  // iterator for performance purpose.

  /**
   * Produces the output rows for one streamed (left) row of a left outer join.
   *
   * Only the right side of `joinedRow` is set here; its left side is used as the caller left it.
   *
   * @param key join key of the current left row; a key containing any null never matches
   * @param joinedRow reusable joined row whose right side is overwritten by this method
   * @param rightIter candidate right rows for `key`, or null when there are none
   * @param resultProjection converts the joined row into an output row
   * @param numOutputRows metric incremented once per emitted row
   * @return the matching joined rows, or a single row padded with nulls on the right when the
   *         key has a null or no candidate satisfies the join condition
   */
  protected[this] def leftOuterIterator(
      key: InternalRow,
      joinedRow: JoinedRow,
      rightIter: Iterable[InternalRow],
      resultProjection: InternalRow => InternalRow,
      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
    val ret: Iterable[InternalRow] = {
      if (!key.anyNull) {
        val temp = if (rightIter != null) {
          rightIter.collect {
            case r if boundCondition(joinedRow.withRight(r)) => {
              numOutputRows += 1
              // `collect` on an Iterable is eager, so each result is copied before the next
              // projection call (the projection is presumed to reuse its output row).
              resultProjection(joinedRow).copy()
            }
          }
        } else {
          List.empty
        }
        if (temp.isEmpty) {
          numOutputRows += 1
          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
        } else {
          temp
        }
      } else {
        numOutputRows += 1
        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
      }
    }
    ret.iterator
  }

  /**
   * Produces the output rows for one streamed (right) row of a right outer join.
   *
   * Only the left side of `joinedRow` is set here; its right side is used as the caller left it.
   *
   * @param key join key of the current right row; a key containing any null never matches
   * @param leftIter candidate left rows for `key`, or null when there are none
   * @param joinedRow reusable joined row whose left side is overwritten by this method
   * @param resultProjection converts the joined row into an output row
   * @param numOutputRows metric incremented once per emitted row
   * @return the matching joined rows, or a single row padded with nulls on the left when the
   *         key has a null or no candidate satisfies the join condition
   */
  protected[this] def rightOuterIterator(
      key: InternalRow,
      leftIter: Iterable[InternalRow],
      joinedRow: JoinedRow,
      resultProjection: InternalRow => InternalRow,
      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
    val ret: Iterable[InternalRow] = {
      if (!key.anyNull) {
        val temp = if (leftIter != null) {
          leftIter.collect {
            case l if boundCondition(joinedRow.withLeft(l)) => {
              numOutputRows += 1
              // Eager `collect`: copy each result before the projection is invoked again.
              resultProjection(joinedRow).copy()
            }
          }
        } else {
          List.empty
        }
        if (temp.isEmpty) {
          numOutputRows += 1
          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
        } else {
          temp
        }
      } else {
        numOutputRows += 1
        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
      }
    }
    ret.iterator
  }

  /**
   * Produces the output rows of a full outer join for a single join key.
   *
   * For a key without nulls the result is, in order: every (left, right) pair that satisfies
   * the join condition; one right-null-padded row per left row that matched nothing; and one
   * left-null-padded row per right row that was never matched. For a key containing a null,
   * every left row and every right row is emitted null-padded, without evaluating the
   * condition.
   *
   * @param key the join key shared by `leftIter` and `rightIter`
   * @param leftIter left rows for `key` (must not be null)
   * @param rightIter right rows for `key` (must not be null)
   * @param joinedRow reusable joined row; both sides are overwritten by this method
   * @param resultProjection converts the joined row into an output row
   * @param numOutputRows metric incremented once per emitted row
   */
  protected[this] def fullOuterIterator(
      key: InternalRow,
      leftIter: Iterable[InternalRow],
      rightIter: Iterable[InternalRow],
      joinedRow: JoinedRow,
      resultProjection: InternalRow => InternalRow,
      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
    if (!key.anyNull) {
      // Store the positions of records in right, if one of its associated row satisfy
      // the join condition.
      val rightMatchedSet = scala.collection.mutable.Set[Int]()
      leftIter.iterator.flatMap[InternalRow] { l =>
        joinedRow.withLeft(l)
        var matched = false
        rightIter.zipWithIndex.collect {
          // 1. For those matched (satisfy the join condition) records with both sides filled,
          //    append them directly

          case (r, idx) if boundCondition(joinedRow.withRight(r)) =>
            numOutputRows += 1
            matched = true
            // if the row satisfy the join condition, add its index into the matched set
            rightMatchedSet.add(idx)
            // `collect` on an Iterable materializes all matches before any is consumed, so
            // each result must be copied (as in the left/right outer iterators); otherwise
            // a projection that reuses its output row would make every element alias the
            // last match.
            resultProjection(joinedRow).copy()

        } ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
          // 2. For those unmatched records in left, append additional records with empty right.

          // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
          // as we don't know whether we need to append it until finish iterating all
          // of the records in right side.
          // If we didn't get any proper row, then append a single row with empty right.
          // No copy is needed: this is the only row produced for the current left record.
          numOutputRows += 1
          resultProjection(joinedRow.withRight(rightNullRow))
        })
      } ++ rightIter.zipWithIndex.collect {
        // 3. For those unmatched records in right, append additional records with empty left.

        // Re-visiting the records in right, and append additional row with empty left, if its not
        // in the matched set.
        case (r, idx) if !rightMatchedSet.contains(idx) =>
          numOutputRows += 1
          // Eager `collect` again: copy so that several unmatched right rows stay distinct.
          resultProjection(joinedRow(leftNullRow, r)).copy()
      }
    } else {
      // Null keys never match. Both halves are lazy iterators, so each row is consumed before
      // the next projection call and no copy is made here.
      leftIter.iterator.map[InternalRow] { l =>
        numOutputRows += 1
        resultProjection(joinedRow(l, rightNullRow))
      } ++ rightIter.iterator.map[InternalRow] { r =>
        numOutputRows += 1
        resultProjection(joinedRow(leftNullRow, r))
      }
    }
  }

  /**
   * Groups the rows of `iter` by their join key into an in-memory hash table.
   *
   * This is only used by FullOuter.
   *
   * @param iter rows to index; fully consumed by this call
   * @param numIterRows metric incremented once per consumed row
   * @param keyGenerator projection that extracts the join key from a row
   * @return a map from a copy of each distinct key to copies of the rows that produced it
   */
  protected[this] def buildHashTable(
      iter: Iterator[InternalRow],
      numIterRows: LongSQLMetric,
      keyGenerator: Projection): java.util.HashMap[InternalRow, CompactBuffer[InternalRow]] = {
    val hashTable = new java.util.HashMap[InternalRow, CompactBuffer[InternalRow]]()
    while (iter.hasNext) {
      val currentRow = iter.next()
      numIterRows += 1
      val rowKey = keyGenerator(currentRow)

      var existingMatchList = hashTable.get(rowKey)
      if (existingMatchList == null) {
        existingMatchList = new CompactBuffer[InternalRow]()
        // The key is copied only when first stored in the table.
        hashTable.put(rowKey.copy(), existingMatchList)
      }

      // Rows are copied before being retained beyond this loop iteration.
      existingMatchList += currentRow.copy()
    }

    hashTable
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy