All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.execution.joins.HashedRelation.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.joins

import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}

import org.apache.spark.memory.{TaskMemoryManager, StaticMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.execution.local.LocalNode
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.MemoryLocation
import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils}
import org.apache.spark.util.collection.CompactBuffer
import org.apache.spark.{SparkConf, SparkEnv}


/**
 * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
 * object.
 */
private[execution] sealed trait HashedRelation {
  def get(key: InternalRow): Seq[InternalRow]

  // This is a helper method to implement Externalizable, and is used by
  // GeneralHashedRelation and UniqueKeyHashedRelation
  protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
    out.writeInt(serialized.length) // Write the length of serialized bytes first
    out.write(serialized)
  }

  // This is a helper method to implement Externalizable, and is used by
  // GeneralHashedRelation and UniqueKeyHashedRelation
  protected def readBytes(in: ObjectInput): Array[Byte] = {
    val serializedSize = in.readInt() // Read the length of serialized bytes first
    val bytes = new Array[Byte](serializedSize)
    in.readFully(bytes)
    bytes
  }
}


/**
 * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
 */
private[joins] final class GeneralHashedRelation(
    private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
  extends HashedRelation with Externalizable {

  // Needed for serialization (it is public to make Java serialization work)
  def this() = this(null)

  override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)

  override def writeExternal(out: ObjectOutput): Unit = {
    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
  }

  override def readExternal(in: ObjectInput): Unit = {
    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
  }
}


/**
 * A specialized [[HashedRelation]] that maps key into a single value. This implementation
 * assumes the key is unique.
 */
private[joins]
final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
  extends HashedRelation with Externalizable {

  // Needed for serialization (it is public to make Java serialization work)
  def this() = this(null)

  override def get(key: InternalRow): Seq[InternalRow] = {
    val v = hashTable.get(key)
    if (v eq null) null else CompactBuffer(v)
  }

  def getValue(key: InternalRow): InternalRow = hashTable.get(key)

  override def writeExternal(out: ObjectOutput): Unit = {
    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
  }

  override def readExternal(in: ObjectInput): Unit = {
    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
  }
}

// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.


private[execution] object HashedRelation {

  def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = {
    apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator)
  }

  def apply(
      input: Iterator[InternalRow],
      numInputRows: LongSQLMetric,
      keyGenerator: Projection,
      sizeEstimate: Int = 64): HashedRelation = {

    if (keyGenerator.isInstanceOf[UnsafeProjection]) {
      return UnsafeHashedRelation(
        input, numInputRows, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
    }

    // TODO: Use Spark's HashMap implementation.
    val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
    var currentRow: InternalRow = null

    // Whether the join key is unique. If the key is unique, we can convert the underlying
    // hash map into one specialized for this.
    var keyIsUnique = true

    // Create a mapping of buildKeys -> rows
    while (input.hasNext) {
      currentRow = input.next()
      numInputRows += 1
      val rowKey = keyGenerator(currentRow)
      if (!rowKey.anyNull) {
        val existingMatchList = hashTable.get(rowKey)
        val matchList = if (existingMatchList == null) {
          val newMatchList = new CompactBuffer[InternalRow]()
          hashTable.put(rowKey.copy(), newMatchList)
          newMatchList
        } else {
          keyIsUnique = false
          existingMatchList
        }
        matchList += currentRow.copy()
      }
    }

    if (keyIsUnique) {
      val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size)
      val iter = hashTable.entrySet().iterator()
      while (iter.hasNext) {
        val entry = iter.next()
        uniqHashTable.put(entry.getKey, entry.getValue()(0))
      }
      new UniqueKeyHashedRelation(uniqHashTable)
    } else {
      new GeneralHashedRelation(hashTable)
    }
  }
}

/**
 * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
 * into a sequence of values.
 *
 * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
 * BytesToBytesMap for better memory performance (multiple values for the same are stored as a
 * continuous byte array.
 *
 * It's serialized in the following format:
 *  [number of keys]
 *  [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
 *  ...
 *
 * All the values are serialized as following:
 *   [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
 *   ...
 */
private[joins] final class UnsafeHashedRelation(
    private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
  extends HashedRelation
  with KnownSizeEstimation
  with Externalizable {

  private[joins] def this() = this(null)  // Needed for serialization

  // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
  // This is used in broadcast joins and distributed mode only
  @transient private[this] var binaryMap: BytesToBytesMap = _

  /**
   * Return the size of the unsafe map on the executors.
   *
   * For broadcast joins, this hashed relation is bigger on the driver because it is
   * represented as a Java hash map there. While serializing the map to the executors,
   * however, we rehash the contents in a binary map to reduce the memory footprint on
   * the executors.
   *
   * For non-broadcast joins or in local mode, return 0.
   */
  def getUnsafeSize: Long = {
    if (binaryMap != null) {
      binaryMap.getTotalMemoryConsumption
    } else {
      0
    }
  }

  override def estimatedSize: Long = {
    if (binaryMap != null) {
      binaryMap.getTotalMemoryConsumption
    } else {
      SizeEstimator.estimate(hashTable)
    }
  }

  override def get(key: InternalRow): Seq[InternalRow] = {
    val unsafeKey = key.asInstanceOf[UnsafeRow]

    if (binaryMap != null) {
      // Used in Broadcast join
      val map = binaryMap  // avoid the compiler error
      val loc = new map.Location  // this could be allocated in stack
      binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
        unsafeKey.getSizeInBytes, loc)
      if (loc.isDefined) {
        val buffer = CompactBuffer[UnsafeRow]()

        val base = loc.getValueAddress.getBaseObject
        var offset = loc.getValueAddress.getBaseOffset
        val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
        while (offset < last) {
          val numFields = Platform.getInt(base, offset)
          val sizeInBytes = Platform.getInt(base, offset + 4)
          offset += 8

          val row = new UnsafeRow
          row.pointTo(base, offset, numFields, sizeInBytes)
          buffer += row
          offset += sizeInBytes
        }
        buffer
      } else {
        null
      }

    } else {
      // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
      hashTable.get(unsafeKey)
    }
  }

  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
    if (binaryMap != null) {
      // This could happen when a cached broadcast object need to be dumped into disk to free memory
      out.writeInt(binaryMap.numElements())

      var buffer = new Array[Byte](64)
      def write(addr: MemoryLocation, length: Int): Unit = {
        if (buffer.length < length) {
          buffer = new Array[Byte](length)
        }
        Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
          buffer, Platform.BYTE_ARRAY_OFFSET, length)
        out.write(buffer, 0, length)
      }

      val iter = binaryMap.iterator()
      while (iter.hasNext) {
        val loc = iter.next()
        // [key size] [values size] [key bytes] [values bytes]
        out.writeInt(loc.getKeyLength)
        out.writeInt(loc.getValueLength)
        write(loc.getKeyAddress, loc.getKeyLength)
        write(loc.getValueAddress, loc.getValueLength)
      }

    } else {
      assert(hashTable != null)
      out.writeInt(hashTable.size())

      val iter = hashTable.entrySet().iterator()
      while (iter.hasNext) {
        val entry = iter.next()
        val key = entry.getKey
        val values = entry.getValue

        // write all the values as single byte array
        var totalSize = 0L
        var i = 0
        while (i < values.length) {
          totalSize += values(i).getSizeInBytes + 4 + 4
          i += 1
        }
        assert(totalSize < Integer.MAX_VALUE, "values are too big")

        // [key size] [values size] [key bytes] [values bytes]
        out.writeInt(key.getSizeInBytes)
        out.writeInt(totalSize.toInt)
        out.write(key.getBytes)
        i = 0
        while (i < values.length) {
          // [num of fields] [num of bytes] [row bytes]
          // write the integer in native order, so they can be read by UNSAFE.getInt()
          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
            out.writeInt(values(i).numFields())
            out.writeInt(values(i).getSizeInBytes)
          } else {
            out.writeInt(Integer.reverseBytes(values(i).numFields()))
            out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
          }
          out.write(values(i).getBytes)
          i += 1
        }
      }
    }
  }

  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
    val nKeys = in.readInt()
    // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
    // TODO(josh): This needs to be revisited before we merge this patch; making this change now
    // so that tests compile:
    val taskMemoryManager = new TaskMemoryManager(
      new StaticMemoryManager(
        new SparkConf().set("spark.memory.offHeap.enabled", "false"),
        Long.MaxValue,
        Long.MaxValue,
        1),
      0)

    val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
      .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))

    // TODO(josh): We won't need this dummy memory manager after future refactorings; revisit
    // during code review

    binaryMap = new BytesToBytesMap(
      taskMemoryManager,
      (nKeys * 1.5 + 1).toInt, // reduce hash collision
      pageSizeBytes)

    var i = 0
    var keyBuffer = new Array[Byte](1024)
    var valuesBuffer = new Array[Byte](1024)
    while (i < nKeys) {
      val keySize = in.readInt()
      val valuesSize = in.readInt()
      if (keySize > keyBuffer.length) {
        keyBuffer = new Array[Byte](keySize)
      }
      in.readFully(keyBuffer, 0, keySize)
      if (valuesSize > valuesBuffer.length) {
        valuesBuffer = new Array[Byte](valuesSize)
      }
      in.readFully(valuesBuffer, 0, valuesSize)

      // put it into binary map
      val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
      assert(!loc.isDefined, "Duplicated key found!")
      val putSuceeded = loc.putNewKey(
        keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
        valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
      if (!putSuceeded) {
        throw new IOException("Could not allocate memory to grow BytesToBytesMap")
      }
      i += 1
    }
  }
}

private[joins] object UnsafeHashedRelation {

  def apply(
      input: Iterator[InternalRow],
      numInputRows: LongSQLMetric,
      keyGenerator: UnsafeProjection,
      sizeEstimate: Int): HashedRelation = {

    // Use a Java hash table here because unsafe maps expect fixed size records
    val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)

    // Create a mapping of buildKeys -> rows
    while (input.hasNext) {
      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
      numInputRows += 1
      val rowKey = keyGenerator(unsafeRow)
      if (!rowKey.anyNull) {
        val existingMatchList = hashTable.get(rowKey)
        val matchList = if (existingMatchList == null) {
          val newMatchList = new CompactBuffer[UnsafeRow]()
          hashTable.put(rowKey.copy(), newMatchList)
          newMatchList
        } else {
          existingMatchList
        }
        matchList += unsafeRow.copy()
      }
    }

    new UnsafeHashedRelation(hashTable)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy