org.apache.spark.sql.execution.columnar.GenerateColumnAccessor.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.columnar

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator, UnsafeRowWriter}
import org.apache.spark.sql.types._

/**
 * An Iterator to walk through the InternalRows from a CachedBatch
 */
abstract class ColumnarIterator extends Iterator[InternalRow] {
  def initialize(input: Iterator[CachedBatch], columnTypes: Array[DataType],
    columnIndexes: Array[Int]): Unit
}
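
// Illustrative usage sketch (not part of the original source). It shows how a
// generated ColumnarIterator is typically driven; the names `cachedBatches`,
// `schemaTypes` and `requestedColumnIndices` are assumed stand-ins for values
// supplied by the in-memory relation, while generate() and initialize() are the
// real entry points defined in this file.
//
//   def scanPartition(cachedBatches: Iterator[CachedBatch]): Unit = {
//     val schemaTypes: Array[DataType] = Array(IntegerType, LongType)
//     val requestedColumnIndices: Array[Int] = Array(0, 1)
//     val iter = GenerateColumnAccessor.generate(schemaTypes)
//     iter.initialize(cachedBatches, schemaTypes, requestedColumnIndices)
//     while (iter.hasNext) {
//       val row: InternalRow = iter.next() // the same UnsafeRow instance is reused per call
//     }
//   }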

/**
 * A helper class to update the fields of an UnsafeRow, used by ColumnAccessor.
 *
 * WARNING: These setters MUST be called in increasing order of ordinals.
 */
class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) {

  override def isNullAt(i: Int): Boolean = writer.isNullAt(i)
  override def setNullAt(i: Int): Unit = writer.setNullAt(i)

  override def setBoolean(i: Int, v: Boolean): Unit = writer.write(i, v)
  override def setByte(i: Int, v: Byte): Unit = writer.write(i, v)
  override def setShort(i: Int, v: Short): Unit = writer.write(i, v)
  override def setInt(i: Int, v: Int): Unit = writer.write(i, v)
  override def setLong(i: Int, v: Long): Unit = writer.write(i, v)
  override def setFloat(i: Int, v: Float): Unit = writer.write(i, v)
  override def setDouble(i: Int, v: Double): Unit = writer.write(i, v)

  // For decimals and generic updates the generated code uses the UnsafeRowWriter
  // directly (avoiding wrapper objects), so these setters are intentionally unsupported.
  override def setDecimal(i: Int, v: Decimal, precision: Int): Unit =
    throw new UnsupportedOperationException
  override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException

  // all other methods inherited from GenericMutableRow are not needed
}
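
// Illustrative sketch (not part of the original source) of the write protocol that the
// generated SpecificColumnarIterator follows below: fields are written in increasing
// ordinal order and the row is finalized with setTotalSize(). The literal values are
// assumptions for illustration; BufferHolder lives in the same codegen package as
// UnsafeRowWriter.
//
//   val unsafeRow    = new UnsafeRow(2)
//   val bufferHolder = new BufferHolder(unsafeRow)
//   val rowWriter    = new UnsafeRowWriter(bufferHolder, 2)
//   val mutableRow   = new MutableUnsafeRow(rowWriter)
//
//   bufferHolder.reset()
//   rowWriter.zeroOutNullBytes()
//   mutableRow.setInt(0, 42)      // ordinal 0 first ...
//   mutableRow.setLong(1, 123L)   // ... then ordinal 1, never out of order
//   unsafeRow.setTotalSize(bufferHolder.totalSize())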

/**
 * Generates bytecode for a [[ColumnarIterator]] for the columnar cache.
 */
object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {

  protected def canonicalize(in: Seq[DataType]): Seq[DataType] = in
  protected def bind(in: Seq[DataType], inputSchema: Seq[Attribute]): Seq[DataType] = in

  protected def create(columnTypes: Seq[DataType]): ColumnarIterator = {
    val ctx = newCodeGenContext()
    val numFields = columnTypes.size
    val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) =>
      val accessorName = ctx.freshName("accessor")
      val accessorCls = dt match {
        case NullType => classOf[NullColumnAccessor].getName
        case BooleanType => classOf[BooleanColumnAccessor].getName
        case ByteType => classOf[ByteColumnAccessor].getName
        case ShortType => classOf[ShortColumnAccessor].getName
        case IntegerType | DateType => classOf[IntColumnAccessor].getName
        case LongType | TimestampType => classOf[LongColumnAccessor].getName
        case FloatType => classOf[FloatColumnAccessor].getName
        case DoubleType => classOf[DoubleColumnAccessor].getName
        case StringType => classOf[StringColumnAccessor].getName
        case BinaryType => classOf[BinaryColumnAccessor].getName
        case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
          classOf[CompactDecimalColumnAccessor].getName
        case dt: DecimalType => classOf[DecimalColumnAccessor].getName
        case struct: StructType => classOf[StructColumnAccessor].getName
        case array: ArrayType => classOf[ArrayColumnAccessor].getName
        case t: MapType => classOf[MapColumnAccessor].getName
      }
      ctx.addMutableState(accessorCls, accessorName, "")
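      // Primitive, null, string and binary accessors are constructed from the raw byte
      // buffer alone; decimal, struct, array and map accessors also receive the concrete
      // DataType (precision/scale, element types) so they can decode their values.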

      val createCode = dt match {
        case t if ctx.isPrimitiveType(dt) =>
          s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
        case NullType | StringType | BinaryType =>
          s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
        case other =>
          s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
             (${dt.getClass.getName}) columnTypes[$index]);"""
      }

      val extract = s"$accessorName.extractTo(mutableRow, $index);"
      val patch = dt match {
        case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS =>
          // A large Decimal always occupies its 16 bytes so the slot can be updated in
          // place later, even if the value is null now.
          s"""
            if (mutableRow.isNullAt($index)) {
              rowWriter.write($index, (Decimal) null, $p, $s);
            }
           """
        case other => ""
      }
      (createCode, extract + patch)
    }.unzip

    /*
     * 200 = 6000 bytes / 30 (up to 30 bytes of bytecode per call).
     * HotSpot does not JIT-compile methods whose bytecode exceeds 8000 bytes,
     * so each generated method must stay below that limit.
     */
    val numberOfStatementsThreshold = 200
    val (initializerAccessorCalls, extractorCalls) =
      if (initializeAccessors.length <= numberOfStatementsThreshold) {
        (initializeAccessors.mkString("\n"), extractors.mkString("\n"))
      } else {
        val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold)
        val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold)
        var groupedAccessorsLength = 0
        groupedAccessorsItr.zipWithIndex.foreach { case (body, i) =>
          groupedAccessorsLength += 1
          val funcName = s"accessors$i"
          val funcCode = s"""
             |private void $funcName() {
             |  ${body.mkString("\n")}
             |}
           """.stripMargin
          ctx.addNewFunction(funcName, funcCode)
        }
        groupedExtractorsItr.zipWithIndex.foreach { case (body, i) =>
          val funcName = s"extractors$i"
          val funcCode = s"""
             |private void $funcName() {
             |  ${body.mkString("\n")}
             |}
           """.stripMargin
          ctx.addNewFunction(funcName, funcCode)
        }
        ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"),
         (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n"))
      }
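    // When the column count exceeds the threshold, the two strings above become calls of
    // the form accessors0(); accessors1(); ... and extractors0(); extractors1(); ...,
    // each delegating to a private helper registered via ctx.addNewFunction, so no single
    // generated method grows past HotSpot's compilation limit.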

    val code = s"""
      import java.nio.ByteBuffer;
      import java.nio.ByteOrder;
      import scala.collection.Iterator;
      import org.apache.spark.sql.types.DataType;
      import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
      import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
      import org.apache.spark.sql.execution.columnar.MutableUnsafeRow;

      public SpecificColumnarIterator generate(Object[] references) {
        return new SpecificColumnarIterator();
      }

      class SpecificColumnarIterator extends ${classOf[ColumnarIterator].getName} {

        private ByteOrder nativeOrder = null;
        private byte[][] buffers = null;
        private UnsafeRow unsafeRow = new UnsafeRow($numFields);
        private BufferHolder bufferHolder = new BufferHolder(unsafeRow);
        private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields);
        private MutableUnsafeRow mutableRow = null;

        private int currentRow = 0;
        private int numRowsInBatch = 0;

        private scala.collection.Iterator input = null;
        private DataType[] columnTypes = null;
        private int[] columnIndexes = null;

        ${ctx.declareMutableStates()}

        public SpecificColumnarIterator() {
          this.nativeOrder = ByteOrder.nativeOrder();
          this.buffers = new byte[${columnTypes.length}][];
          this.mutableRow = new MutableUnsafeRow(rowWriter);
        }

        public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) {
          this.input = input;
          this.columnTypes = columnTypes;
          this.columnIndexes = columnIndexes;
        }

        ${ctx.declareAddedFunctions()}

        public boolean hasNext() {
          if (currentRow < numRowsInBatch) {
            return true;
          }
          if (!input.hasNext()) {
            return false;
          }

          ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next();
          currentRow = 0;
          numRowsInBatch = batch.numRows();
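          // Rebind the requested column buffers from the new batch, rebuild the accessors
          // over them, then recurse in case the batch is empty.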
          for (int i = 0; i < columnIndexes.length; i ++) {
            buffers[i] = batch.buffers()[columnIndexes[i]];
          }
          ${initializerAccessorCalls}

          return hasNext();
        }

        public InternalRow next() {
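          // Rebuild the shared UnsafeRow in place: reset the buffer, clear the null bits,
          // then let each column accessor write its current value and advance its cursor.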
          currentRow += 1;
          bufferHolder.reset();
          rowWriter.zeroOutNullBytes();
          ${extractorCalls}
          unsafeRow.setTotalSize(bufferHolder.totalSize());
          return unsafeRow;
        }
      }"""

    logDebug(s"Generated ColumnarIterator:\n${CodeFormatter.format(code)}")

    CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator]
  }
}



