All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{UserDefinedType, _}
import org.apache.spark.unsafe.Platform

/**
 * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, a consumer
 * should copy the row if it is being buffered. This class is not thread safe.
 *
 * @param expressions that produces the resulting fields. These expressions must be bound
 *                    to a schema.
 */
class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection {
  import InterpretedUnsafeProjection._

  /** Number of (top level) fields in the resulting row. */
  private[this] val numFields = expressions.length

  /** Array that expression results. */
  private[this] val values = new Array[Any](numFields)

  /** The row representing the expression results. */
  private[this] val intermediate = new GenericInternalRow(values)

  /* The row writer for UnsafeRow result */
  private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32)

  /** The writer that writes the intermediate result to the result row. */
  private[this] val writer: InternalRow => Unit = {
    val baseWriter = generateStructWriter(
      rowWriter,
      expressions.map(e => StructField("", e.dataType, e.nullable)))
    if (!expressions.exists(_.nullable)) {
      // No nullable fields. The top-level null bit mask will always be zeroed out.
      baseWriter
    } else {
      // Zero out the null bit mask before we write the row.
      row => {
        rowWriter.zeroOutNullBytes()
        baseWriter(row)
      }
    }
  }

  override def initialize(partitionIndex: Int): Unit = {
    expressions.foreach(_.foreach {
      case n: Nondeterministic => n.initialize(partitionIndex)
      case _ =>
    })
  }

  override def apply(row: InternalRow): UnsafeRow = {
    // Put the expression results in the intermediate row.
    var i = 0
    while (i < numFields) {
      values(i) = expressions(i).eval(row)
      i += 1
    }

    // Write the intermediate row to an unsafe row.
    rowWriter.reset()
    writer(intermediate)
    rowWriter.getRow()
  }
}

/**
 * Helper functions for creating an [[InterpretedUnsafeProjection]].
 */
object InterpretedUnsafeProjection {
  /**
   * Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
   */
  def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
    // We need to make sure that we do not reuse stateful expressions.
    val cleanedExpressions = exprs.map(_.transform {
      case s: Stateful => s.freshCopy()
    })
    new InterpretedUnsafeProjection(cleanedExpressions.toArray)
  }

  /**
   * Generate a struct writer function. The generated function writes an [[InternalRow]] to the
   * given buffer using the given [[UnsafeRowWriter]].
   */
  private def generateStructWriter(
      rowWriter: UnsafeRowWriter,
      fields: Array[StructField]): InternalRow => Unit = {
    val numFields = fields.length

    // Create field writers.
    val fieldWriters = fields.map { field =>
      generateFieldWriter(rowWriter, field.dataType, field.nullable)
    }
    // Create basic writer.
    row => {
      var i = 0
      while (i < numFields) {
        fieldWriters(i).apply(row, i)
        i += 1
      }
    }
  }

  /**
   * Generate a writer function for a struct field, array element, map key or map value. The
   * generated function writes the element at an index in a [[SpecializedGetters]] object (row
   * or array) to the given buffer using the given [[UnsafeWriter]].
   */
  private def generateFieldWriter(
      writer: UnsafeWriter,
      dt: DataType,
      nullable: Boolean): (SpecializedGetters, Int) => Unit = {

    // Create the the basic writer.
    val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match {
      case BooleanType =>
        (v, i) => writer.write(i, v.getBoolean(i))

      case ByteType =>
        (v, i) => writer.write(i, v.getByte(i))

      case ShortType =>
        (v, i) => writer.write(i, v.getShort(i))

      case IntegerType | DateType =>
        (v, i) => writer.write(i, v.getInt(i))

      case LongType | TimestampType =>
        (v, i) => writer.write(i, v.getLong(i))

      case FloatType =>
        (v, i) => writer.write(i, v.getFloat(i))

      case DoubleType =>
        (v, i) => writer.write(i, v.getDouble(i))

      case DecimalType.Fixed(precision, scale) =>
        (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale)

      case CalendarIntervalType =>
        (v, i) => writer.write(i, v.getInterval(i))

      case BinaryType =>
        (v, i) => writer.write(i, v.getBinary(i))

      case StringType =>
        (v, i) => writer.write(i, v.getUTF8String(i))

      case StructType(fields) =>
        val numFields = fields.length
        val rowWriter = new UnsafeRowWriter(writer, numFields)
        val structWriter = generateStructWriter(rowWriter, fields)
        (v, i) => {
          v.getStruct(i, fields.length) match {
            case row: UnsafeRow =>
              writer.write(i, row)
            case row =>
              val previousCursor = writer.cursor()
              // Nested struct. We don't know where this will start because a row can be
              // variable length, so we need to update the offsets and zero out the bit mask.
              rowWriter.resetRowWriter()
              structWriter.apply(row)
              writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
          }
        }

      case ArrayType(elementType, containsNull) =>
        val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType))
        val elementWriter = generateFieldWriter(
          arrayWriter,
          elementType,
          containsNull)
        (v, i) => {
          val previousCursor = writer.cursor()
          writeArray(arrayWriter, elementWriter, v.getArray(i))
          writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
        }

      case MapType(keyType, valueType, valueContainsNull) =>
        val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType))
        val keyWriter = generateFieldWriter(
          keyArrayWriter,
          keyType,
          nullable = false)
        val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType))
        val valueWriter = generateFieldWriter(
          valueArrayWriter,
          valueType,
          valueContainsNull)
        (v, i) => {
          v.getMap(i) match {
            case map: UnsafeMapData =>
              writer.write(i, map)
            case map =>
              val previousCursor = writer.cursor()

              // preserve 8 bytes to write the key array numBytes later.
              valueArrayWriter.grow(8)
              valueArrayWriter.increaseCursor(8)

              // Write the keys and write the numBytes of key array into the first 8 bytes.
              writeArray(keyArrayWriter, keyWriter, map.keyArray())
              Platform.putLong(
                valueArrayWriter.getBuffer,
                previousCursor,
                valueArrayWriter.cursor - previousCursor - 8
              )

              // Write the values.
              writeArray(valueArrayWriter, valueWriter, map.valueArray())
              writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
          }
        }

      case udt: UserDefinedType[_] =>
        generateFieldWriter(writer, udt.sqlType, nullable)

      case NullType =>
        (_, _) => {}

      case _ =>
        throw new SparkException(s"Unsupported data type $dt")
    }

    // Always wrap the writer with a null safe version.
    dt match {
      case _: UserDefinedType[_] =>
        // The null wrapper depends on the sql type and not on the UDT.
        unsafeWriter
      case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS =>
        // We can't call setNullAt() for DecimalType with precision larger than 18, we call write
        // directly. We can use the unwrapped writer directly.
        unsafeWriter
      case BooleanType | ByteType =>
        (v, i) => {
          if (!v.isNullAt(i)) {
            unsafeWriter(v, i)
          } else {
            writer.setNull1Bytes(i)
          }
        }
      case ShortType =>
        (v, i) => {
          if (!v.isNullAt(i)) {
            unsafeWriter(v, i)
          } else {
            writer.setNull2Bytes(i)
          }
        }
      case IntegerType | DateType | FloatType =>
        (v, i) => {
          if (!v.isNullAt(i)) {
            unsafeWriter(v, i)
          } else {
            writer.setNull4Bytes(i)
          }
        }
      case _ =>
        (v, i) => {
          if (!v.isNullAt(i)) {
            unsafeWriter(v, i)
          } else {
            writer.setNull8Bytes(i)
          }
        }
    }
  }

  /**
   * Get the number of bytes elements of a data type will occupy in the fixed part of an
   * [[UnsafeArrayData]] object. Reference types are stored as an 8 byte combination of an
   * offset (upper 4 bytes) and a length (lower 4 bytes), these point to the variable length
   * portion of the array object. Primitives take up to 8 bytes, depending on the size of the
   * underlying data type.
   */
  private def getElementSize(dataType: DataType): Int = dataType match {
    case NullType | StringType | BinaryType | CalendarIntervalType |
         _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8
    case _ => dataType.defaultSize
  }

  /**
   * Write an array to the buffer. If the array is already in serialized form (an instance of
   * [[UnsafeArrayData]]) then we copy the bytes directly, otherwise we do an element-by-element
   * copy.
   */
  private def writeArray(
      arrayWriter: UnsafeArrayWriter,
      elementWriter: (SpecializedGetters, Int) => Unit,
      array: ArrayData): Unit = array match {
    case unsafe: UnsafeArrayData =>
      arrayWriter.write(unsafe)
    case _ =>
      val numElements = array.numElements()
      arrayWriter.initialize(numElements)
      var i = 0
      while (i < numElements) {
        elementWriter.apply(array, i)
        i += 1
      }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy