All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.expressions.Projection.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*
 * Changes for SnappyData data platform.
 *
 * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You
 * may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License. See accompanying
 * LICENSE file.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.types.{DataType, Decimal, StructType, _}
import org.apache.spark.sql.types.{DataType, StructType}

/**
 * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
 * Every call to `apply` allocates a new [[GenericInternalRow]] for its result.
 *
 * @param expressions a sequence of expressions that determine the value of each column of the
 *                    output row.
 */
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
  def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
    this(expressions.map(BindReferences.bindReference(_, inputSchema)))

  // `expressions` may be null (see the Kryo note on `exprArray` below). Guard the
  // initialization of nondeterministic sub-expressions with the same check, otherwise such a
  // construction fails here with a NullPointerException before `exprArray` is ever reached.
  if (expressions != null) {
    expressions.foreach(_.foreach {
      case n: Nondeterministic => n.setInitialValues()
      case _ =>
    })
  }

  // null check is required for when Kryo invokes the no-arg constructor.
  protected val exprArray = if (expressions != null) expressions.toArray else null

  /**
   * Evaluates every expression against `input`.
   *
   * @param input the row the expressions are evaluated against.
   * @return a new row whose i-th column holds the result of the i-th expression.
   */
  def apply(input: InternalRow): InternalRow = {
    val outputArray = new Array[Any](exprArray.length)
    var i = 0
    while (i < exprArray.length) {
      outputArray(i) = exprArray(i).eval(input)
      i += 1
    }
    new GenericInternalRow(outputArray)
  }

  override def toString(): String = s"Row => [${exprArray.mkString(",")}]"
}

/**
 * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
 * expressions.
 *
 * All results are first stored in an internal buffer and only then copied into the target row,
 * which makes the projection atomic (needed by aggregation). When the target is an
 * [[UnsafeRow]], values are written through per-column typed setters rather than the generic
 * `update`.
 *
 * @param expressions a sequence of expressions that determine the value of each column of the
 *                    output row.
 */
case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
  def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
    this(expressions.map(BindReferences.bindReference(_, inputSchema)))

  private[this] val buffer = new Array[Any](expressions.size)

  expressions.foreach(_.foreach {
    case n: Nondeterministic => n.setInitialValues()
    case _ =>
  })
  // True while the current target is an UnsafeRow that must be written via `setters`.
  private var targetUnsafe = false
  /** Writes one evaluated value into a fixed column of an [[UnsafeRow]]. */
  type UnsafeSetter = (UnsafeRow, Any) => Unit
  // One setter per output column; built the first time an UnsafeRow target is set.
  private var setters: Array[UnsafeSetter] = _
  private[this] val exprArray = expressions.toArray
  private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length)
  def currentValue: InternalRow = mutableRow

  /**
   * Returns a setter that writes a value of `dataType` into column `ordinal` of an UnsafeRow.
   *
   * A null value marks the column as null. Unboxing a null `Any` with `asInstanceOf[Int]` (or
   * any other primitive) silently yields 0/false, which would otherwise lose the null.
   *
   * @throws UnsupportedOperationException if `dataType` has no typed setter here.
   */
  private def unsafeSetter(ordinal: Int, dataType: DataType): UnsafeSetter = {
    def nullSafe(write: (UnsafeRow, Any) => Unit): UnsafeSetter =
      (target: UnsafeRow, value: Any) =>
        if (value == null) target.setNullAt(ordinal) else write(target, value)

    dataType match {
      // DateType and TimestampType are held internally as Int and Long respectively.
      case IntegerType | DateType =>
        nullSafe((target, value) => target.setInt(ordinal, value.asInstanceOf[Int]))
      case LongType | TimestampType =>
        nullSafe((target, value) => target.setLong(ordinal, value.asInstanceOf[Long]))
      case DoubleType =>
        nullSafe((target, value) => target.setDouble(ordinal, value.asInstanceOf[Double]))
      case FloatType =>
        nullSafe((target, value) => target.setFloat(ordinal, value.asInstanceOf[Float]))
      case BooleanType =>
        nullSafe((target, value) => target.setBoolean(ordinal, value.asInstanceOf[Boolean]))
      case ByteType =>
        nullSafe((target, value) => target.setByte(ordinal, value.asInstanceOf[Byte]))
      case ShortType =>
        nullSafe((target, value) => target.setShort(ordinal, value.asInstanceOf[Short]))
      case NullType =>
        (target: UnsafeRow, _: Any) => target.setNullAt(ordinal)
      case other =>
        throw new UnsupportedOperationException(
          s"InterpretedMutableProjection cannot write a value of type $other into an UnsafeRow")
    }
  }

  /**
   * Makes `row` the destination of subsequent `apply` calls.
   *
   * @param row the row to write into; an [[UnsafeRow]] is written via typed setters.
   * @return this projection.
   */
  override def target(row: MutableRow): MutableProjection = {
    val unsafe = row match {
      case _: UnsafeRow => true
      case _ => false
    }
    // Build the setters before switching targets, so a failure leaves this projection unchanged.
    if (unsafe && setters == null) {
      setters = Array.tabulate[UnsafeSetter](exprArray.length) { i =>
        unsafeSetter(i, exprArray(i).dataType)
      }
    }
    mutableRow = row
    targetUnsafe = unsafe
    this
  }

  /**
   * Evaluates all expressions against `input` and writes the results into the current target.
   *
   * @param input the row the expressions are evaluated against.
   * @return the (reused) target row.
   */
  override def apply(input: InternalRow): InternalRow = {
    var i = 0
    while (i < exprArray.length) {
      // Store the result into buffer first, to make the projection atomic (needed by aggregation)
      buffer(i) = exprArray(i).eval(input)
      i += 1
    }
    if (targetUnsafe) {
      val unsafeRow = mutableRow.asInstanceOf[UnsafeRow]
      i = 0
      while (i < exprArray.length) {
        setters(i)(unsafeRow, buffer(i))
        i += 1
      }
    } else {
      i = 0
      while (i < exprArray.length) {
        mutableRow(i) = buffer(i)
        i += 1
      }
    }
    mutableRow
  }
}

/**
 * A projection that returns UnsafeRow.
 *
 * Narrows the result type of `Projection.apply` to [[UnsafeRow]]. The `create` methods of the
 * companion object build instances via `GenerateUnsafeProjection`.
 */
abstract class UnsafeProjection extends Projection {
  /**
   * Projects `row` into an [[UnsafeRow]].
   *
   * @param row the input row.
   * @return the projected row.
   */
  override def apply(row: InternalRow): UnsafeRow
}

object UnsafeProjection {

  /**
   * Returns an UnsafeProjection for given StructType.
   */
  def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))

  /**
   * Returns an UnsafeProjection for given Array of DataTypes.
   * Column i of the input is bound (as nullable) to column i of the output.
   */
  def create(fields: Array[DataType]): UnsafeProjection = {
    create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
  }

  /**
   * Returns an UnsafeProjection for given sequence of Expressions (bounded).
   */
  def create(exprs: Seq[Expression]): UnsafeProjection = {
    GenerateUnsafeProjection.generate(toUnsafeExprs(exprs))
  }

  /**
   * Returns an UnsafeProjection for a single (bounded) Expression.
   */
  def create(expr: Expression): UnsafeProjection = create(Seq(expr))

  /**
   * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to
   * `inputSchema`.
   */
  def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
    create(exprs.map(BindReferences.bindReference(_, inputSchema)))
  }

  /**
   * Same as other create()'s but allowing enabling/disabling subexpression elimination.
   * TODO: refactor the plumbing and clean this up.
   */
  def create(
      exprs: Seq[Expression],
      inputSchema: Seq[Attribute],
      subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
    val bound = exprs.map(BindReferences.bindReference(_, inputSchema))
    GenerateUnsafeProjection.generate(toUnsafeExprs(bound), subexpressionEliminationEnabled)
  }

  /**
   * Replaces every `CreateStruct` / `CreateNamedStruct` in each expression tree with its
   * `CreateStructUnsafe` / `CreateNamedStructUnsafe` counterpart before code generation.
   * Shared by all `create` paths so they stay consistent.
   */
  private def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = exprs.map(_.transform {
    case CreateStruct(children) => CreateStructUnsafe(children)
    case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
  })
}

/**
 * Factory for projections that could turn an UnsafeRow into a GenericInternalRow; the
 * projection code is produced by [[GenerateSafeProjection]].
 */
object FromUnsafeProjection {

  /**
   * Returns a Projection that reads every field of `schema`, by position, from its input row.
   */
  def apply(schema: StructType): Projection = apply(schema.fields.map(_.dataType))

  /**
   * Returns a Projection for the given column types.
   *
   * Column i of the input (of type `fields(i)`, treated as nullable) is bound to column i of
   * the output.
   */
  def apply(fields: Seq[DataType]): Projection = {
    val boundColumns: Seq[Expression] = fields.zipWithIndex.map { case (dataType, ordinal) =>
      new BoundReference(ordinal, dataType, nullable = true)
    }
    GenerateSafeProjection.generate(boundColumns)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy