All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.spark.sql.catalyst.expressions.objects.objects.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.objects

import java.lang.reflect.{Method, Modifier}

import scala.collection.JavaConverters._
import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.Try

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils

/**
 * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]].
 *
 * It provides the shared pieces of a reflective call: argument preparation for generated code
 * ([[prepareArguments]]), interpreted evaluation of a method call ([[invoke]]) and reflective
 * method lookup ([[findMethod]]).
 */
trait InvokeLike extends Expression with NonSQLExpression {

  /** The expressions whose values are passed to the invoked method or constructor. */
  def arguments: Seq[Expression]

  /** When true, a null argument makes the call evaluate to null instead of being performed. */
  def propagateNull: Boolean

  // Null short-circuiting is only needed when it is requested and some argument can be null.
  protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable)

  /**
   * Prepares codes for arguments.
   *
   * - generate codes for argument.
   * - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments.
   * - avoid some of nullability checking which are not needed because the expression is not
   *   nullable.
   * - when needNullCheck == true, short circuit if we found one of arguments is null because
   *   preparing rest of arguments can be skipped in the case.
   *
   * @param ctx a [[CodegenContext]]
   * @return (code to prepare arguments, argument string, result of argument null check)
   */
  def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = {

    val resultIsNull = if (needNullCheck) {
      val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull")
      JavaCode.isNullGlobal(resultIsNull)
    } else {
      FalseLiteral
    }
    // Each argument value is kept in mutable state (a field of the generated class), presumably
    // so it stays visible if `splitExpressionsWithCurrentInputs` moves the preparation code into
    // separate methods.
    val argValues = arguments.map { e =>
      ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue")
    }

    val argCodes = if (needNullCheck) {
      val reset = s"$resultIsNull = false;"
      val argCodes = arguments.zipWithIndex.map { case (e, i) =>
        val expr = e.genCode(ctx)
        val updateResultIsNull = if (e.nullable) {
          s"$resultIsNull = ${expr.isNull};"
        } else {
          ""
        }
        s"""
          if (!$resultIsNull) {
            ${expr.code}
            $updateResultIsNull
            ${argValues(i)} = ${expr.value};
          }
        """
      }
      reset +: argCodes
    } else {
      arguments.zipWithIndex.map { case (e, i) =>
        val expr = e.genCode(ctx)
        s"""
          ${expr.code}
          ${argValues(i)} = ${expr.value};
        """
      }
    }
    val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes)

    (argCode, argValues.mkString(", "), resultIsNull)
  }

  /**
   * Evaluate each argument with a given row, invoke a method with a given object and arguments,
   * and cast a return value if the return type can be mapped to a Java Boxed type
   *
   * @param obj the object for the method to be called. If null, perform a static method call
   * @param method the method object to be called
   * @param arguments the arguments used for the method call
   * @param input the row used for evaluating arguments
   * @param dataType the data type of the return object
   * @return the return object of a method call, or null if null propagation is required and one
   *         of the evaluated arguments is null
   */
  def invoke(
      obj: Any,
      method: Method,
      arguments: Seq[Expression],
      input: InternalRow,
      dataType: DataType): Any = {
    val args = arguments.map(e => e.eval(input).asInstanceOf[Object])
    if (needNullCheck && args.contains(null)) {
      // return null if one of arguments is null
      null
    } else {
      val ret = method.invoke(obj, args: _*)
      // Cast to the boxed Java class when `dataType` has a mapping to one.
      ScalaReflection.typeBoxedJavaMapping.get(dataType) match {
        case Some(boxedClass) => boxedClass.cast(ret)
        case None => ret
      }
    }
  }

  /**
   * Finds the public method `functionName` on `cls` that accepts arguments of `argClasses`.
   *
   * An exact signature lookup is tried first. If that fails (e.g. an argument class is `Object`
   * and `getMethod` cannot match the declared parameter types), methods are matched by name and
   * parameter count instead, preferring non-synthetic ones when several match.
   *
   * @param cls the class to search
   * @param functionName the (JVM-level) method name
   * @param argClasses the Java classes of the arguments; its length is the required arity
   * @return the single matching method
   * @throws RuntimeException if no method matches, or the match is ambiguous
   */
  final def findMethod(cls: Class[_], functionName: String, argClasses: Seq[Class[_]]): Method = {
    // Looking with function name + argument classes first.
    try {
      cls.getMethod(functionName, argClasses: _*)
    } catch {
      case _: NoSuchMethodException =>
        // For some cases, e.g. arg class is Object, `getMethod` cannot find the method.
        // We look at function name + argument length
        val candidates = cls.getMethods.filter { m =>
          m.getName == functionName && m.getParameterCount == argClasses.length
        }
        if (candidates.isEmpty) {
          sys.error(s"Couldn't find $functionName on $cls")
        } else if (candidates.length == 1) {
          candidates.head
        } else {
          // More than one matched method signature. Exclude synthetic one, e.g. generic one.
          val realMethods = candidates.filterNot(_.isSynthetic)
          realMethods.length match {
            case 1 => realMethods.head
            // Only synthetic candidates: there is no principled choice, so fail explicitly
            // instead of throwing an opaque NoSuchElementException from `head`.
            case 0 => sys.error(s"Found ${candidates.length} synthetic $functionName on $cls")
            // Ambiguous case, we don't know which method to choose, just fail it.
            case n => sys.error(s"Found $n $functionName on $cls")
          }
        }
    }
  }
}

/**
 * Common trait for [[DecodeUsingSerializer]] and [[EncodeUsingSerializer]]
 */
trait SerializerSupport {
  /**
   * Selects the serialization framework: Kryo when true, Java serialization when false.
   */
  val kryo: Boolean

  /**
   * The serializer used by interpreted (non-codegen) execution, created on first access.
   */
  lazy val serializerInstance: SerializerInstance = SerializerSupport.newSerializer(kryo)

  /**
   * Adds an immutable state to the generated class that holds the serializer, unless a state
   * with the same name already exists.
   *
   * @param ctx the codegen context that receives the state
   * @return the name of the generated variable referencing the serializer
   */
  def addImmutableSerializerIfNeeded(ctx: CodegenContext): String = {
    val fieldName = if (kryo) "kryoSerializer" else "javaSerializer"
    val fieldClass =
      if (kryo) classOf[KryoSerializerInstance].getName else classOf[JavaSerializerInstance].getName
    // Java-side reference to the companion object's `newSerializer` (`<trait>$.MODULE$`).
    val factoryCall = s"${classOf[SerializerSupport].getName}$$.MODULE$$.newSerializer"
    // Code to initialize the serializer
    ctx.addImmutableStateIfNotExists(fieldClass, fieldName, v =>
      s"""
         |$v = ($fieldClass) $factoryCall($kryo);
       """.stripMargin)
    fieldName
  }
}

object SerializerSupport {
  /**
   * Creates a new `SerializerInstance`: a `KryoSerializerInstance` if `useKryo` is `true`,
   * otherwise a `JavaSerializerInstance`.
   *
   * The configuration is taken from the active `SparkEnv` when there is one; otherwise a new
   * `SparkConf` is created.
   */
  def newSerializer(useKryo: Boolean): SerializerInstance = {
    val conf = Option(SparkEnv.get).fold(new SparkConf)(_.conf)
    val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf)
    serializer.newInstance()
  }
}

/**
 * Invokes a static function, returning the result.  By default, any of the arguments being null
 * will result in returning null instead of calling the function.
 *
 * @param staticObject The target of the static call.  This can either be the object itself
 *                     (methods defined on scala objects), or the class object
 *                     (static methods defined in java).
 * @param dataType The expected return type of the function call
 * @param functionName The name of the method to call.
 * @param arguments An optional list of expressions to pass as arguments to the function.
 * @param propagateNull When true, and any of the arguments is null, null will be returned instead
 *                      of calling the function.
 * @param returnNullable When false, indicating the invoked method will always return
 *                       non-null value.
 */
case class StaticInvoke(
    staticObject: Class[_],
    dataType: DataType,
    functionName: String,
    arguments: Seq[Expression] = Nil,
    propagateNull: Boolean = true,
    returnNullable: Boolean = true) extends InvokeLike {

  // Class name used in the generated call `objectName.functionName(...)`. Any trailing "$" (as
  // in the runtime class name of a Scala `object`) is removed.
  val objectName = staticObject.getName.stripSuffix("$")
  // Class searched by `findMethod` in interpreted mode: `staticObject` itself when its name had
  // no "$" suffix, otherwise the class loaded under the stripped name (presumably the companion
  // class carrying static forwarders -- confirm for the objects in use).
  val cls = if (staticObject.getName == objectName) {
    staticObject
  } else {
    Utils.classForName(objectName)
  }

  // Null if a null argument can short-circuit the call, or if the method itself may return null.
  override def nullable: Boolean = needNullCheck || returnNullable
  override def children: Seq[Expression] = arguments

  // Java classes of the arguments, used by `findMethod` to select the overload.
  lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
  // `@transient` keeps the reflective `Method` out of serialized state; resolved on first use.
  @transient lazy val method = findMethod(cls, functionName, argClasses)

  // Static call: `invoke` receives null as the target object.
  override def eval(input: InternalRow): Any = {
    invoke(null, method, arguments, input, dataType)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = CodeGenerator.javaType(dataType)

    // `resultIsNull` is the constant false unless null propagation may be needed.
    val (argCode, argString, resultIsNull) = prepareArguments(ctx)

    val callFunc = s"$objectName.$functionName($argString)"

    // Declare the null flag only when this expression is nullable. Otherwise `ev.isNull` is set
    // to the literal false in place, before `ev` is copied at the end of this method.
    val prepareIsNull = if (nullable) {
      s"boolean ${ev.isNull} = $resultIsNull;"
    } else {
      ev.isNull = FalseLiteral
      ""
    }

    // When the method may return null, reference-typed results (default value "null") are
    // checked directly. Other results are first received in their boxed type so a null return
    // can be detected before assigning to `ev.value`.
    val evaluate = if (returnNullable) {
      if (CodeGenerator.defaultValue(dataType) == "null") {
        s"""
          ${ev.value} = $callFunc;
          ${ev.isNull} = ${ev.value} == null;
        """
      } else {
        val boxedResult = ctx.freshName("boxedResult")
        s"""
          ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc;
          ${ev.isNull} = $boxedResult == null;
          if (!${ev.isNull}) {
            ${ev.value} = $boxedResult;
          }
        """
      }
    } else {
      s"${ev.value} = $callFunc;"
    }

    // The call is skipped, leaving the default value, when a null argument was propagated.
    val code = code"""
      $argCode
      $prepareIsNull
      $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
      if (!$resultIsNull) {
        $evaluate
      }
     """
    ev.copy(code = code)
  }
}

/**
 * Calls the specified function on an object, optionally passing arguments.  If the `targetObject`
 * expression evaluates to null then null will be returned.
 *
 * In some cases, due to erasure, the schema may expect a primitive type when in fact the method
 * is returning java.lang.Object.  In this case, we will generate code that attempts to unbox the
 * value automatically.
 *
 * @param targetObject An expression that will return the object to call the method on.
 * @param functionName The name of the method to call.
 * @param dataType The expected return type of the function.
 * @param arguments An optional list of expressions, whose evaluation will be passed to the
 *                  function.
 * @param propagateNull When true, and any of the arguments is null, null will be returned instead
 *                      of calling the function.
 * @param returnNullable When false, indicating the invoked method will always return
 *                       non-null value.
 */
case class Invoke(
    targetObject: Expression,
    functionName: String,
    dataType: DataType,
    arguments: Seq[Expression] = Nil,
    propagateNull: Boolean = true,
    returnNullable: Boolean = true) extends InvokeLike {

  // Java classes of the arguments, used to select the method overload.
  lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)

  override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
  override def children: Seq[Expression] = targetObject +: arguments

  // JVM-level name of the method (Scala name encoding, e.g. symbolic names such as `+`). The
  // reflective lookups and the generated Java call must all use this form.
  private lazy val encodedFunctionName = TermName(functionName).encodedName.toString

  // Resolved once when the target's class is statically known; otherwise looked up per object.
  @transient lazy val method = targetObject.dataType match {
    case ObjectType(cls) =>
      Some(findMethod(cls, encodedFunctionName, argClasses))
    case _ => None
  }

  override def eval(input: InternalRow): Any = {
    val obj = targetObject.eval(input)
    if (obj == null) {
      // return null if obj is null
      null
    } else {
      // Runtime lookup on the object's class when no method was resolved statically. It uses the
      // encoded name, like `method` and the generated code, so symbolic method names also work
      // in interpreted mode.
      val invokeMethod = method.getOrElse {
        obj.getClass.getMethod(encodedFunctionName, argClasses: _*)
      }
      invoke(obj, invokeMethod, arguments, input, dataType)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = CodeGenerator.javaType(dataType)
    val obj = targetObject.genCode(ctx)

    val (argCode, argString, resultIsNull) = prepareArguments(ctx)

    // A primitive-returning method is assigned straight to `ev.value`.
    val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive
    // Java requires declared (checked) exceptions to be caught or declared, so such calls are
    // wrapped and rethrown through `Platform.throwException`.
    val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty

    def getFuncResult(resultVal: String, funcCall: String): String = if (needTryCatch) {
      s"""
        try {
          $resultVal = $funcCall;
        } catch (Exception e) {
          org.apache.spark.unsafe.Platform.throwException(e);
        }
      """
    } else {
      s"$resultVal = $funcCall;"
    }

    val evaluate = if (returnPrimitive) {
      getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)")
    } else {
      // Non-primitive (or unknown) results are received as `Object`, then cast to the boxed type.
      val funcResult = ctx.freshName("funcResult")
      // If the function can return null, we do an extra check to make sure our null bit is still
      // set correctly.
      val assignResult = if (!returnNullable) {
        s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;"
      } else {
        s"""
          if ($funcResult != null) {
            ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;
          } else {
            ${ev.isNull} = true;
          }
        """
      }
      s"""
        Object $funcResult = null;
        ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")}
        $assignResult
      """
    }

    // The result stays null unless the target is non-null and no argument null was propagated.
    val code = obj.code + code"""
      boolean ${ev.isNull} = true;
      $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
      if (!${obj.isNull}) {
        $argCode
        ${ev.isNull} = $resultIsNull;
        if (!${ev.isNull}) {
          $evaluate
        }
      }
     """
    ev.copy(code = code)
  }

  override def toString: String = s"$targetObject.$functionName"
}

object NewInstance {
  /**
   * Creates a [[NewInstance]] that needs no outer pointer (`outerPointer` is `None`).
   */
  def apply(
      cls: Class[_],
      arguments: Seq[Expression],
      dataType: DataType,
      propagateNull: Boolean = true): NewInstance = {
    new NewInstance(
      cls = cls,
      arguments = arguments,
      propagateNull = propagateNull,
      dataType = dataType,
      outerPointer = None)
  }
}

/**
 * Constructs a new instance of the given class, using the result of evaluating the specified
 * expressions as arguments.
 *
 * @param cls The class to construct.
 * @param arguments A list of expressions to use as arguments to the constructor.
 * @param propagateNull When true, if any of the arguments is null, then null will be returned
 *                      instead of trying to construct the object.
 * @param dataType The type of object being constructed, as a Spark SQL datatype.  This allows you
 *                 to manually specify the type when the object in question is a valid internal
 *                 representation (i.e. ArrayData) instead of an object.
 * @param outerPointer If the object being constructed is an inner class, the outerPointer for the
 *                     containing class must be specified. This parameter is defined as an optional
 *                     function, which allows us to get the outer pointer lazily, and it's useful
 *                     if the inner class is defined in REPL.
 */
case class NewInstance(
    cls: Class[_],
    arguments: Seq[Expression],
    propagateNull: Boolean,
    dataType: DataType,
    outerPointer: Option[() => AnyRef]) extends InvokeLike {
  private val className = cls.getName

  override def nullable: Boolean = needNullCheck

  override def children: Seq[Expression] = arguments

  override lazy val resolved: Boolean = {
    // If the class to construct is an inner class, we need to get its outer pointer, or this
    // expression should be regarded as unresolved.
    // Note that static inner classes (e.g., inner classes within Scala objects) don't need
    // outer pointer registration.
    val needOuterPointer =
      outerPointer.isEmpty && Utils.isMemberClass(cls) && !Modifier.isStatic(cls.getModifiers)
    childrenResolved && !needOuterPointer
  }

  // Function that calls the matching constructor for interpreted evaluation. For an inner class,
  // the outer instance from `outerPointer` is prepended as the first constructor argument.
  @transient private lazy val constructor: (Seq[AnyRef]) => Any = {
    val paramTypes = ScalaReflection.expressionJavaClasses(arguments)
    val getConstructor = (paramClazz: Seq[Class[_]]) => {
      ScalaReflection.findConstructor(cls, paramClazz).getOrElse {
        sys.error(s"Couldn't find a valid constructor on $cls")
      }
    }
    outerPointer.map { p =>
      val outerObj = p()
      val c = getConstructor(outerObj.getClass +: paramTypes)
      (args: Seq[AnyRef]) => {
        c.newInstance(outerObj +: args: _*)
      }
    }.getOrElse {
      val c = getConstructor(paramTypes)
      (args: Seq[AnyRef]) => {
        c.newInstance(args: _*)
      }
    }
  }

  override def eval(input: InternalRow): Any = {
    val argValues = arguments.map(_.eval(input).asInstanceOf[AnyRef])
    if (needNullCheck && argValues.contains(null)) {
      // Honour `propagateNull` like the generated code (and `nullable`) do: a null argument
      // yields null instead of constructing the object.
      null
    } else {
      constructor(argValues)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = CodeGenerator.javaType(dataType)

    val (argCode, argString, resultIsNull) = prepareArguments(ctx)

    // The outer instance is embedded in the generated code as an object literal.
    val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))

    // The result is null exactly when argument null propagation triggered.
    ev.isNull = resultIsNull

    // Inner classes are instantiated with Java's `outer.new Inner(...)` syntax.
    val constructorCall = outer.map { gen =>
      s"${gen.value}.new ${Utils.getSimpleName(cls)}($argString)"
    }.getOrElse {
      s"new $className($argString)"
    }

    val code = code"""
      $argCode
      ${outer.map(_.code).getOrElse("")}
      final $javaType ${ev.value} = ${ev.isNull} ?
        ${CodeGenerator.defaultValue(dataType)} : $constructorCall;
    """
    ev.copy(code = code)
  }

  override def toString: String = s"newInstance($cls)"
}

/**
 * Given an expression that returns an object of type `Option[_]`, this expression unwraps the
 * option into the specified Spark SQL datatype.  In the case of `None`, the nullbit is set instead.
 *
 * @param dataType The expected unwrapped option type.
 * @param child An expression that returns an `Option`
 */
case class UnwrapOption(
    dataType: DataType,
    child: Expression) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {

  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil

  // A null input and `None` both evaluate to null; `Some(x)` evaluates to `x`.
  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case opt => opt.asInstanceOf[Option[_]].orNull
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = CodeGenerator.javaType(dataType)
    val defaultValue = CodeGenerator.defaultValue(dataType)
    val boxedType = CodeGenerator.boxedType(javaType)
    val option = child.genCode(ctx)

    // Null when the child is null or the option is empty; otherwise the wrapped value is cast to
    // the boxed Java type before assignment.
    val code = option.code + code"""
      final boolean ${ev.isNull} = ${option.isNull} || ${option.value}.isEmpty();
      $javaType ${ev.value} = ${ev.isNull} ? $defaultValue :
        ($boxedType) ${option.value}.get();
    """
    ev.copy(code = code)
  }
}

/**
 * Wraps the result of evaluating `child` into a `scala.Option`.
 *
 * Interpreted evaluation maps a null result to `None`. Generated code decides based on the
 * child's null bit: `None` when it is set, `Some(value)` otherwise.
 *
 * @param child The expression to evaluate and wrap.
 * @param optType The type of this option.
 */
case class WrapOption(child: Expression, optType: DataType)
  extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {

  override def dataType: DataType = ObjectType(classOf[Option[_]])

  override def nullable: Boolean = false

  override def inputTypes: Seq[AbstractDataType] = optType :: Nil

  override def eval(input: InternalRow): Any = {
    val value = child.eval(input)
    // `Option.apply` turns a null reference into `None`.
    Option(value)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val wrapped = child.genCode(ctx)

    val code = wrapped.code + code"""
      scala.Option ${ev.value} =
        ${wrapped.isNull} ?
        scala.Option$$.MODULE$$.apply(null) : new scala.Some(${wrapped.value});
    """
    // The produced Option itself is never null.
    ev.copy(code = code, isNull = FalseLiteral)
  }
}

/**
 * A placeholder for the loop variable used in [[MapObjects]].  This should never be constructed
 * manually, but will instead be passed into the provided lambda function.
 *
 * @param value name of the Java variable that holds the current element in generated code
 * @param isNull name of the Java variable that holds the element's null flag in generated code
 * @param dataType data type of the element
 * @param nullable whether the element can be null
 */
case class LambdaVariable(
    value: String,
    isNull: String,
    dataType: DataType,
    nullable: Boolean = true) extends LeafExpression with NonSQLExpression {

  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)

  // In interpreted execution the current element is always field 0 of a single-field row.
  override def eval(input: InternalRow): Any = {
    assert(input.numFields == 1,
      "The input row of interpreted LambdaVariable should have only 1 field.")
    val elementIsNull = nullable && input.isNullAt(0)
    if (elementIsNull) null else accessor(input, 0)
  }

  // Refers to the loop variables named by `value` and `isNull`; a non-nullable element uses the
  // literal false as its null flag.
  override def genCode(ctx: CodegenContext): ExprCode = ExprCode(
    value = JavaCode.variable(value, dataType),
    isNull = if (nullable) JavaCode.isNullVariable(isNull) else FalseLiteral)

  // This won't be called as `genCode` is overridden, just overriding it to make
  // `LambdaVariable` non-abstract.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
}

/**
 * When constructing [[MapObjects]], the element type must be given, which may not be available
 * before analysis. This class acts like a placeholder for [[MapObjects]], and will be replaced by
 * [[MapObjects]] during analysis after the input data is resolved.
 * Note that, ideally we should not serialize and send unresolved expressions to executors, but
 * users may accidentally do this (e.g. mistakenly reference an encoder instance when implementing
 * Aggregator). Here we mark `function` as transient because it may reference scala Type, which is
 * not serializable. Then even if users mistakenly reference an unresolved expression and
 * serialize it, it's just a performance issue (more network traffic), and will not fail.
 */
case class UnresolvedMapObjects(
    @transient function: Expression => Expression,
    child: Expression,
    customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable {
  override lazy val resolved: Boolean = false

  // Only known up front when a custom collection class was supplied.
  override def dataType: DataType = customCollectionCls match {
    case Some(cls) => ObjectType(cls)
    case None => throw new UnsupportedOperationException("not resolved")
  }
}

object MapObjects {
  private val curId = new java.util.concurrent.atomic.AtomicInteger()

  /**
   * Construct an instance of MapObjects case class.
   *
   * @param function The function applied on the collection elements.
   * @param inputData An expression that when evaluated returns a collection object.
   * @param elementType The data type of elements in the collection.
   * @param elementNullable When false, indicating elements in the collection are always
   *                        non-null value.
   * @param customCollectionCls Class of the resulting collection (returning ObjectType)
   *                            or None (returning ArrayType)
   */
  def apply(
      function: Expression => Expression,
      inputData: Expression,
      elementType: DataType,
      elementNullable: Boolean = true,
      customCollectionCls: Option[Class[_]] = None): MapObjects = {
    // A process-wide counter makes the generated loop variable names unique per instance.
    val suffix = curId.getAndIncrement()
    val loopValue = s"MapObjects_loopValue$suffix"
    // Non-nullable elements use the Java literal `false` instead of a null-flag variable.
    val loopIsNull = if (elementNullable) s"MapObjects_loopIsNull$suffix" else "false"
    val lambda = function(LambdaVariable(loopValue, loopIsNull, elementType, elementNullable))
    MapObjects(loopValue, loopIsNull, elementType, lambda, inputData, customCollectionCls)
  }
}

/**
 * Applies the given expression to every element of a collection of items, returning the result
 * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
 * function is expressed using catalyst expressions.
 *
 * The type of the result is determined as follows:
 * - ArrayType - when customCollectionCls is None
 * - ObjectType(collection) - when customCollectionCls contains a collection class
 *
 * The following collection ObjectTypes are currently supported on input:
 *   Seq, Array, ArrayData, java.util.List
 *
 * @param loopValue the name of the loop variable that is used when iterating the collection, and
 *                  used as input for the `lambdaFunction`
 * @param loopIsNull the nullity of the loop variable that is used when iterating the collection,
 *                   and used as input for the `lambdaFunction`
 * @param loopVarDataType the data type of the loop variable that is used when iterating the
 *                        collection, and used as input for the `lambdaFunction`
 * @param lambdaFunction A function that takes the `loopVar` as input, and is used as the lambda
 *                       function to handle collection elements.
 * @param inputData An expression that when evaluated returns a collection object.
 * @param customCollectionCls Class of the resulting collection (returning ObjectType)
 *                            or None (returning ArrayType)
 */
case class MapObjects private(
    loopValue: String,
    loopIsNull: String,
    loopVarDataType: DataType,
    lambdaFunction: Expression,
    inputData: Expression,
    customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression {

  override def nullable: Boolean = inputData.nullable

  override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil

  // The data with UserDefinedType are actually stored with the data type of its sqlType.
  // When we want to apply MapObjects on it, we have to use it.
  lazy private val inputDataType = inputData.dataType match {
    case u: UserDefinedType[_] => u.sqlType
    case _ => inputData.dataType
  }

  private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
    val row = new GenericInternalRow(1)
    inputCollection.toIterator.map { element =>
      row.update(0, element)
      lambdaFunction.eval(row)
    }
  }

  private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
    case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
      _.asInstanceOf[Seq[_]]
    case ObjectType(cls) if cls.isArray =>
      _.asInstanceOf[Array[_]].toSeq
    case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
      _.asInstanceOf[java.util.List[_]].asScala
    case ObjectType(cls) if cls == classOf[Object] =>
      (inputCollection) => {
        if (inputCollection.getClass.isArray) {
          inputCollection.asInstanceOf[Array[_]].toSeq
        } else {
          inputCollection.asInstanceOf[Seq[_]]
        }
      }
    case ArrayType(et, _) =>
      _.asInstanceOf[ArrayData].toSeq[Any](et)
  }

  private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
    case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
      // Scala sequence
      executeFuncOnCollection(_).toSeq
    case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
      // Scala set
      executeFuncOnCollection(_).toSet
    case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
      // Java list
      if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
          cls == classOf[java.util.AbstractSequentialList[_]]) {
        // Specifying non concrete implementations of `java.util.List`
        executeFuncOnCollection(_).toSeq.asJava
      } else {
        val constructors = cls.getConstructors()
        val intParamConstructor = constructors.find { constructor =>
          constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
        }
        val noParamConstructor = constructors.find { constructor =>
          constructor.getParameterCount == 0
        }

        val constructor = intParamConstructor.map { intConstructor =>
          (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
        }.getOrElse {
          (_: Int) => noParamConstructor.get.newInstance()
        }

        // Specifying concrete implementations of `java.util.List`
        (inputs) => {
          val results = executeFuncOnCollection(inputs)
          val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]]
          results.foreach(builder.add(_))
          builder
        }
      }
    case None =>
      // array
      x => new GenericArrayData(executeFuncOnCollection(x).toArray)
    case Some(cls) =>
      throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " +
        "resulting collection.")
  }

  override def eval(input: InternalRow): Any = {
    val inputCollection = inputData.eval(input)

    if (inputCollection == null) {
      return null
    }
    mapElements(convertToSeq(inputCollection))
  }

  override def dataType: DataType =
    customCollectionCls.map(ObjectType.apply).getOrElse(
      ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))

  /**
   * Generates Java code for this expression.
   *
   * The generated code first evaluates `inputData`. If it is non-null, the code walks the input
   * collection in a `while` loop. For each element it:
   *  - assigns the element to the loop variable named `loopValue`,
   *  - sets `loopIsNull` when that name is not "false",
   *  - runs the code of `lambdaFunction`,
   *  - appends the result to the output collection.
   *
   * The output collection is one of:
   *  - a `GenericArrayData`, when `customCollectionCls` is None;
   *  - a Scala `Seq`/`Set` built through its companion's `newBuilder()`;
   *  - a `java.util.List` implementation.
   *
   * The result is null exactly when `inputData` evaluates to null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val elementJavaType = CodeGenerator.javaType(loopVarDataType)
    // Declared under its exact name (useFreshName = false). Presumably the lambda variable
    // inside `lambdaFunction` refers to `loopValue` by this name. TODO confirm.
    ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false)
    val genInputData = inputData.genCode(ctx)
    val genFunction = lambdaFunction.genCode(ctx)
    val dataLength = ctx.freshName("dataLength")
    val convertedArray = ctx.freshName("convertedArray")
    val loopIndex = ctx.freshName("loopIndex")

    // Boxed Java type of each converted element, used as the element type of the output array.
    val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType)

    // Because of the way Java defines nested arrays, we have to handle the syntax specially.
    // Specifically, we have to insert the [$dataLength] in between the type and any extra nested
    // array declarations (i.e. new String[1][]).
    val arrayConstructor = if (convertedType contains "[]") {
      val rawType = convertedType.takeWhile(_ != '[')
      val arrayPart = convertedType.reverse.takeWhile(c => c == '[' || c == ']').reverse
      s"new $rawType[$dataLength]$arrayPart"
    } else {
      s"new $convertedType[$dataLength]"
    }

    // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
    // of input collection at runtime for this case.
    // For a plain `Object` input, exactly one of `$seq` / `$array` is non-null at runtime. That
    // depends on `getClass().isArray()`.
    val seq = ctx.freshName("seq")
    val array = ctx.freshName("array")
    val determineCollectionType = inputData.dataType match {
      case ObjectType(cls) if cls == classOf[Object] =>
        val seqClass = classOf[Seq[_]].getName
        s"""
          $seqClass $seq = null;
          $elementJavaType[] $array = null;
          if (${genInputData.value}.getClass().isArray()) {
            $array = ($elementJavaType[]) ${genInputData.value};
          } else {
            $seq = ($seqClass) ${genInputData.value};
          }
         """
      case _ => ""
    }

    // `MapObjects` generates a while loop to traverse the elements of the input collection. We
    // need to take care of Seq and List because they may have O(n) complexity for indexed accessing
    // like `list.get(1)`. Here we use Iterator to traverse Seq and List.
    // Each case yields three pieces of Java code:
    //  - an expression for the element count,
    //  - a statement run once before the loop (may be empty),
    //  - an expression that fetches the current element.
    // NOTE(review): any other input type (e.g. an ObjectType of a non-Seq/List class) raises a
    // MatchError here.
    val (getLength, prepareLoop, getLoopVar) = inputDataType match {
      case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
        val it = ctx.freshName("it")
        (
          s"${genInputData.value}.size()",
          s"scala.collection.Iterator $it = ${genInputData.value}.toIterator();",
          s"$it.next()"
        )
      case ObjectType(cls) if cls.isArray =>
        (
          s"${genInputData.value}.length",
          "",
          s"${genInputData.value}[$loopIndex]"
        )
      case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
        val it = ctx.freshName("it")
        (
          s"${genInputData.value}.size()",
          s"java.util.Iterator $it = ${genInputData.value}.iterator();",
          s"$it.next()"
        )
      case ArrayType(et, _) =>
        (
          s"${genInputData.value}.numElements()",
          "",
          CodeGenerator.getValue(genInputData.value, et, loopIndex)
        )
      case ObjectType(cls) if cls == classOf[Object] =>
        // Relies on `$seq` / `$array` declared by `determineCollectionType` above.
        val it = ctx.freshName("it")
        (
          s"$seq == null ? $array.length : $seq.size()",
          s"scala.collection.Iterator $it = $seq == null ? null : $seq.toIterator();",
          s"$it == null ? $array[$loopIndex] : $it.next()"
        )
    }

    // Make a copy of the data if it's unsafe-backed
    // (presumably because an unsafe row/array/map produced by the lambda may reuse its backing
    // buffer on the next iteration. TODO confirm.)
    def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
      s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
    val genFunctionValue: String = lambdaFunction.dataType match {
      case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
      case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
      case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
      case _ => genFunction.value
    }

    // The literal name "false" means the loop variable is never null. In that case no is-null
    // field is declared and no per-element check is emitted.
    val loopNullCheck = if (loopIsNull != "false") {
      ctx.addMutableState(
        CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false)
      inputDataType match {
        case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
        case _ => s"$loopIsNull = $loopValue == null;"
      }
    } else {
      ""
    }

    // Each case yields three pieces:
    //  - statements that create the output collection,
    //  - a function producing the statement that appends one element,
    //  - the Java expression (terminated by ';') for the final result.
    val (initCollection, addElement, getResult): (String, String => String, String) =
      customCollectionCls match {
        case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) ||
          classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
          // Scala sequence or set
          val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
          val builder = ctx.freshName("collectionBuilder")
          (
            s"""
               ${classOf[Builder[_, _]].getName} $builder = $getBuilder;
               $builder.sizeHint($dataLength);
             """,
            genValue => s"$builder.$$plus$$eq($genValue);",
            s"(${cls.getName}) $builder.result();"
          )
        case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
          // Java list
          // The List / AbstractList / AbstractSequentialList types are backed by an ArrayList.
          // Any other class is instantiated directly. The length is passed only if the class has
          // an `int` constructor.
          val builder = ctx.freshName("collectionBuilder")
          (
            if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
              cls == classOf[java.util.AbstractSequentialList[_]]) {
              s"${cls.getName} $builder = new java.util.ArrayList($dataLength);"
            } else {
              val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("")
              s"${cls.getName} $builder = new ${cls.getName}($param);"
            },
            genValue => s"$builder.add($genValue);",
            s"$builder;"
          )
        case None =>
          // array
          (
            s"""
               $convertedType[] $convertedArray = null;
               $convertedArray = $arrayConstructor;
             """,
            genValue => s"$convertedArray[$loopIndex] = $genValue;",
            s"new ${classOf[GenericArrayData].getName}($convertedArray);"
          )
      }

    // Assemble the final code. `ev.value` keeps its default value when the input is null, and
    // the expression's null flag is the input's null flag.
    val code = genInputData.code + code"""
      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};

      if (!${genInputData.isNull}) {
        $determineCollectionType
        int $dataLength = $getLength;
        $initCollection

        int $loopIndex = 0;
        $prepareLoop
        while ($loopIndex < $dataLength) {
          $loopValue = ($elementJavaType) ($getLoopVar);
          $loopNullCheck

          ${genFunction.code}
          if (${genFunction.isNull}) {
            ${addElement("null")}
          } else {
            ${addElement(genFunctionValue)}
          }

          $loopIndex += 1;
        }

        ${ev.value} = $getResult
      }
    """
    ev.copy(code = code, isNull = genInputData.isNull)
  }
}

object CatalystToExternalMap {
  // Monotonic counter that makes the generated loop-variable names unique per instance.
  private val curId = new java.util.concurrent.atomic.AtomicInteger()

  /**
   * Builds a [[CatalystToExternalMap]] whose key/value lambdas are applied to freshly named
   * loop variables.
   *
   * @param keyFunction The function applied on the key collection elements.
   * @param valueFunction The function applied on the value collection elements.
   * @param inputData An expression that when evaluated returns a map object.
   * @param collClass The type of the resulting collection.
   * @return the expression wrapping both lambdas, the input and the target collection class.
   */
  def apply(
      keyFunction: Expression => Expression,
      valueFunction: Expression => Expression,
      inputData: Expression,
      collClass: Class[_]): CatalystToExternalMap = {
    val uniqueId = curId.getAndIncrement()

    // Key side: the key variable is declared non-nullable and has an empty is-null name.
    val keyName = s"CatalystToExternalMap_keyLoopValue$uniqueId"
    val catalystMapType = inputData.dataType.asInstanceOf[MapType]
    val keyVariable = LambdaVariable(keyName, "", catalystMapType.keyType, nullable = false)

    // Value side: an is-null flag is only named when the map can hold null values. Otherwise
    // the literal "false" is used.
    val valueName = s"CatalystToExternalMap_valueLoopValue$uniqueId"
    val valueIsNullName =
      if (catalystMapType.valueContainsNull) s"CatalystToExternalMap_valueLoopIsNull$uniqueId"
      else "false"
    val valueVariable = LambdaVariable(valueName, valueIsNullName, catalystMapType.valueType)

    CatalystToExternalMap(
      keyName, keyFunction(keyVariable),
      valueName, valueIsNullName, valueFunction(valueVariable),
      inputData, collClass)
  }
}

/**
 * Expression used to convert a Catalyst Map to an external Scala Map.
 * The collection is constructed using the associated builder, obtained by calling `newBuilder`
 * on the collection's companion object.
 *
 * @param keyLoopValue the name of the loop variable that is used when iterating over the key
 *                     collection, and which is used as input for the `keyLambdaFunction`
 * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as
 *                          a lambda function to handle collection elements.
 * @param valueLoopValue the name of the loop variable that is used when iterating over the value
 *                       collection, and which is used as input for the `valueLambdaFunction`
 * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over
 *                        the value collection, and which is used as input for the
 *                        `valueLambdaFunction`
 * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as
 *                            a lambda function to handle collection elements.
 * @param inputData An expression that when evaluated returns a map object.
 * @param collClass The type of the resulting collection.
 */
case class CatalystToExternalMap private(
    keyLoopValue: String,
    keyLambdaFunction: Expression,
    valueLoopValue: String,
    valueLoopIsNull: String,
    valueLambdaFunction: Expression,
    inputData: Expression,
    collClass: Class[_]) extends Expression with NonSQLExpression {

  override def nullable: Boolean = inputData.nullable

  override def children: Seq[Expression] =
    keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil

  private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType]

  // The interpreted path converts keys/values with CatalystTypeConverters rather than by
  // evaluating the lambda functions.
  private lazy val keyConverter =
    CatalystTypeConverters.createToScalaConverter(inputMapType.keyType)
  private lazy val valueConverter =
    CatalystTypeConverters.createToScalaConverter(inputMapType.valueType)

  // Reflectively resolves the companion object of `collClass` and its `newBuilder` method.
  // The binary name (`getName`) is used instead of `getCanonicalName`:
  //  - for nested classes the canonical name joins with '.', which `Class.forName` cannot load;
  //  - for local/anonymous classes the canonical name is null.
  // This also matches the `${collClass.getName}$.MODULE$` reference emitted by `doGenCode`.
  private lazy val (newMapBuilderMethod, moduleField) = {
    val clazz = Utils.classForName(collClass.getName + "$")
    (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null))
  }

  /** Creates a fresh builder for the target collection type via its companion's `newBuilder`. */
  private def newMapBuilder(): Builder[AnyRef, AnyRef] = {
    newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]]
  }

  /**
   * Interpreted evaluation.
   *
   * @return null if `inputData` evaluates to null. Otherwise, a `collClass` instance holding
   *         every converted (key, value) pair.
   */
  override def eval(input: InternalRow): Any = {
    val result = inputData.eval(input).asInstanceOf[MapData]
    if (result != null) {
      val length = result.numElements()
      val builder = newMapBuilder()
      builder.sizeHint(length)
      val keyArray = result.keyArray()
      val valueArray = result.valueArray()
      var i = 0
      while (i < length) {
        val key = keyConverter(keyArray.get(i, inputMapType.keyType))
        val value = valueConverter(valueArray.get(i, inputMapType.valueType))
        builder += Tuple2(key, value)
        i += 1
      }
      builder.result()
    } else {
      null
    }
  }

  override def dataType: DataType = ObjectType(collClass)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
    // When we want to apply CatalystToExternalMap on it, we have to use it.
    def inputDataType(dataType: DataType) = dataType match {
      case p: PythonUserDefinedType => p.sqlType
      case _ => dataType
    }

    val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
    val keyElementJavaType = CodeGenerator.javaType(mapType.keyType)
    ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false)
    val genKeyFunction = keyLambdaFunction.genCode(ctx)
    val valueElementJavaType = CodeGenerator.javaType(mapType.valueType)
    ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true,
      useFreshName = false)
    val genValueFunction = valueLambdaFunction.genCode(ctx)
    val genInputData = inputData.genCode(ctx)
    val dataLength = ctx.freshName("dataLength")
    val loopIndex = ctx.freshName("loopIndex")
    val tupleLoopValue = ctx.freshName("tupleLoopValue")
    val builderValue = ctx.freshName("builderValue")

    val getLength = s"${genInputData.value}.numElements()"

    val keyArray = ctx.freshName("keyArray")
    val valueArray = ctx.freshName("valueArray")
    val getKeyArray =
      s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
    val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
    val getValueArray =
      s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
    val getValueLoopVar = CodeGenerator.getValue(
      valueArray, inputDataType(mapType.valueType), loopIndex)

    // Make a copy of the data if it's unsafe-backed
    def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
      s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value"
    def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) =
      lambdaFunction.dataType match {
        case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
        case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
        case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
        case _ => genFunction.value
      }
    val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction)
    val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)

    // A literal "false" means values are never null: no flag field and no per-element check.
    val valueLoopNullCheck = if (valueLoopIsNull != "false") {
      ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true,
        useFreshName = false)
      s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
    } else {
      ""
    }

    val builderClass = classOf[Builder[_, _]].getName
    val constructBuilder = s"""
      $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder();
      $builderValue.sizeHint($dataLength);
    """

    val tupleClass = classOf[(_, _)].getName
    val appendToBuilder = s"""
      $tupleClass $tupleLoopValue;

      if (${genValueFunction.isNull}) {
        $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null);
      } else {
        $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue);
      }

      $builderValue.$$plus$$eq($tupleLoopValue);
     """
    val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"

    val code = genInputData.code + code"""
      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};

      if (!${genInputData.isNull}) {
        int $dataLength = $getLength;
        $constructBuilder
        $getKeyArray
        $getValueArray

        int $loopIndex = 0;
        while ($loopIndex < $dataLength) {
          $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
          $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
          $valueLoopNullCheck

          ${genKeyFunction.code}
          ${genValueFunction.code}

          $appendToBuilder

          $loopIndex += 1;
        }

        $getBuilderResult
      }
    """
    ev.copy(code = code, isNull = genInputData.isNull)
  }
}

object ExternalMapToCatalyst {
  // Monotonic counter that makes the generated key/value variable names unique per instance.
  private val curId = new java.util.concurrent.atomic.AtomicInteger()

  /**
   * Builds an [[ExternalMapToCatalyst]] whose key/value converters are applied to freshly named
   * lambda variables.
   *
   * @param inputMap expression producing the external (Scala or Java) map.
   * @param keyType data type of the external map keys.
   * @param keyConverter builds the key conversion expression from the key lambda variable.
   * @param keyNullable whether the external key variable may be null.
   * @param valueType data type of the external map values.
   * @param valueConverter builds the value conversion expression from the value lambda variable.
   * @param valueNullable whether the external value variable may be null.
   */
  def apply(
      inputMap: Expression,
      keyType: DataType,
      keyConverter: Expression => Expression,
      keyNullable: Boolean,
      valueType: DataType,
      valueConverter: Expression => Expression,
      valueNullable: Boolean): ExternalMapToCatalyst = {
    val id = curId.getAndIncrement()

    // A nullable variable gets a named is-null flag. A non-nullable one uses the literal
    // "false", which the generated code treats as "no null check needed".
    def isNullFlag(nullable: Boolean, flagName: String): String =
      if (nullable) flagName else "false"

    val keyVarName = s"ExternalMapToCatalyst_key$id"
    val keyNullFlag = isNullFlag(keyNullable, s"ExternalMapToCatalyst_key_isNull$id")
    val keyVariable = LambdaVariable(keyVarName, keyNullFlag, keyType, keyNullable)

    val valueVarName = s"ExternalMapToCatalyst_value$id"
    val valueNullFlag = isNullFlag(valueNullable, s"ExternalMapToCatalyst_value_isNull$id")
    val valueVariable = LambdaVariable(valueVarName, valueNullFlag, valueType, valueNullable)

    ExternalMapToCatalyst(
      keyVarName,
      keyNullFlag,
      keyType,
      keyConverter(keyVariable),
      valueVarName,
      valueNullFlag,
      valueType,
      valueConverter(valueVariable),
      inputMap
    )
  }
}

/**
 * Converts a Scala/Java map object into catalyst format by applying the key/value converters
 * while iterating over the map.
 *
 * The two paths treat null keys slightly differently:
 *  - the interpreted path (`eval`) throws `RuntimeException("Cannot use null as map key!")`
 *    for a null input key;
 *  - the generated code throws the same exception when the converted key is null.
 * A null input value is stored as null in both paths.
 *
 * @param key the name of the map key variable that is used when iterating over the map, and
 *            used as input for the `keyConverter`
 * @param keyIsNull the name of the is-null flag of the map key variable, or the literal "false"
 *                  when the key is not nullable; used as input for the `keyConverter`
 * @param keyType the data type of the map key variable that is used when iterating over the map,
 *                and used as input for the `keyConverter`
 * @param keyConverter A function that takes the `key` as input, and converts it to catalyst
 *                     format.
 * @param value the name of the map value variable that is used when iterating over the map, and
 *              used as input for the `valueConverter`
 * @param valueIsNull the name of the is-null flag of the map value variable, or the literal
 *                    "false" when the value is not nullable; used as input for the
 *                    `valueConverter`
 * @param valueType the data type of the map value variable that is used when iterating over the
 *                  map, and used as input for the `valueConverter`
 * @param valueConverter A function that takes the `value` as input, and converts it to catalyst
 *                       format.
 * @param child An expression that when evaluated returns the input map object.
 */
case class ExternalMapToCatalyst private(
    key: String,
    keyIsNull: String,
    keyType: DataType,
    keyConverter: Expression,
    value: String,
    valueIsNull: String,
    valueType: DataType,
    valueConverter: Expression,
    child: Expression)
  extends UnaryExpression with NonSQLExpression {

  override def foldable: Boolean = false

  override def dataType: MapType = MapType(
    keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable)

  // Interpreted converter: turns an external map into parallel arrays of converted keys and
  // values.
  private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = {
    // A single reusable one-field row. Each key/value is written into slot 0 and the converter
    // is then evaluated against that row.
    // NOTE(review): this buffer is shared mutable state captured by the closures below, so
    // concurrent `eval` calls on the same instance would presumably interfere. Confirm callers
    // never do that.
    val rowBuffer = InternalRow.fromSeq(Array[Any](1))
    def rowWrapper(data: Any): InternalRow = {
      rowBuffer.update(0, data)
      rowBuffer
    }

    // NOTE(review): only java.util.Map and scala.collection.Map inputs are handled. Any other
    // child type raises a MatchError when this lazy val is first forced.
    child.dataType match {
      case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
        (input: Any) => {
          val data = input.asInstanceOf[java.util.Map[Any, Any]]
          val keys = new Array[Any](data.size)
          val values = new Array[Any](data.size)
          val iter = data.entrySet().iterator()
          var i = 0
          while (iter.hasNext) {
            val entry = iter.next()
            val (key, value) = (entry.getKey, entry.getValue)
            keys(i) = if (key != null) {
              keyConverter.eval(rowWrapper(key))
            } else {
              throw new RuntimeException("Cannot use null as map key!")
            }
            values(i) = if (value != null) {
              valueConverter.eval(rowWrapper(value))
            } else {
              null
            }
            i += 1
          }
          (keys, values)
        }

      case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) =>
        (input: Any) => {
          val data = input.asInstanceOf[scala.collection.Map[Any, Any]]
          val keys = new Array[Any](data.size)
          val values = new Array[Any](data.size)
          var i = 0
          for ((key, value) <- data) {
            keys(i) = if (key != null) {
              keyConverter.eval(rowWrapper(key))
            } else {
              throw new RuntimeException("Cannot use null as map key!")
            }
            values(i) = if (value != null) {
              valueConverter.eval(rowWrapper(value))
            } else {
              null
            }
            i += 1
          }
          (keys, values)
        }
    }
  }

  /**
   * Interpreted evaluation: null in, null out. Otherwise returns an `ArrayBasedMapData` over
   * the converted keys and values.
   */
  override def eval(input: InternalRow): Any = {
    val result = child.eval(input)
    if (result != null) {
      val (keys, values) = mapCatalystConverter(result)
      new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
    } else {
      null
    }
  }

  /**
   * Generates a Java loop over the map's entry iterator. The loop assigns each entry to the
   * `key` / `value` variables, runs both converters, and stores the results into two
   * `Object[]` arrays that back the resulting `ArrayBasedMapData`.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val inputMap = child.genCode(ctx)
    val genKeyConverter = keyConverter.genCode(ctx)
    val genValueConverter = valueConverter.genCode(ctx)
    val length = ctx.freshName("length")
    val index = ctx.freshName("index")
    val convertedKeys = ctx.freshName("convertedKeys")
    val convertedValues = ctx.freshName("convertedValues")
    val entry = ctx.freshName("entry")
    val entries = ctx.freshName("entries")

    // The key/value loop variables are declared under their exact names (useFreshName = false),
    // since the lambda variables built in the companion `apply` refer to them by these names.
    val keyElementJavaType = CodeGenerator.javaType(keyType)
    val valueElementJavaType = CodeGenerator.javaType(valueType)
    ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false)
    ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false)

    // `defineEntries` creates the entry iterator. `defineKeyValue` pulls the next entry and
    // assigns the key/value loop variables.
    val (defineEntries, defineKeyValue) = child.dataType match {
      case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
        val javaIteratorCls = classOf[java.util.Iterator[_]].getName
        val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName

        val defineEntries =
          s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();"

        val defineKeyValue =
          s"""
            final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
            $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey();
            $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue();
          """

        defineEntries -> defineKeyValue

      case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) =>
        val scalaIteratorCls = classOf[Iterator[_]].getName
        val scalaMapEntryCls = classOf[Tuple2[_, _]].getName

        val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();"

        val defineKeyValue =
          s"""
            final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
            $key = (${CodeGenerator.boxedType(keyType)}) $entry._1();
            $value = (${CodeGenerator.boxedType(valueType)}) $entry._2();
          """

        defineEntries -> defineKeyValue
    }

    // A literal "false" flag name means the variable is non-nullable: no flag field is declared
    // and no null check is emitted.
    val keyNullCheck = if (keyIsNull != "false") {
      ctx.addMutableState(
        CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false)
      s"$keyIsNull = $key == null;"
    } else {
      ""
    }

    val valueNullCheck = if (valueIsNull != "false") {
      ctx.addMutableState(
        CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false)
      s"$valueIsNull = $value == null;"
    } else {
      ""
    }

    val arrayCls = classOf[GenericArrayData].getName
    val mapCls = classOf[ArrayBasedMapData].getName
    val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType)
    val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType)
    val code = inputMap.code +
      code"""
        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        if (!${inputMap.isNull}) {
          final int $length = ${inputMap.value}.size();
          final Object[] $convertedKeys = new Object[$length];
          final Object[] $convertedValues = new Object[$length];
          int $index = 0;
          $defineEntries
          while($entries.hasNext()) {
            $defineKeyValue
            $keyNullCheck
            $valueNullCheck

            ${genKeyConverter.code}
            if (${genKeyConverter.isNull}) {
              throw new RuntimeException("Cannot use null as map key!");
            } else {
              $convertedKeys[$index] = ($convertedKeyType) ${genKeyConverter.value};
            }

            ${genValueConverter.code}
            if (${genValueConverter.isNull}) {
              $convertedValues[$index] = null;
            } else {
              $convertedValues[$index] = ($convertedValueType) ${genValueConverter.value};
            }

            $index++;
          }

          ${ev.value} = new $mapCls(new $arrayCls($convertedKeys), new $arrayCls($convertedValues));
        }
      """
    ev.copy(code = code, isNull = inputMap.isNull)
  }
}

/**
 * Builds an external [[Row]] (a `GenericRowWithSchema`) from the values of the given
 * expressions. The resulting row is never null; individual fields may be.
 *
 * @param children A list of expressions whose values become the fields of the external row,
 *                 in order.
 * @param schema The schema attached to the created row.
 */
case class CreateExternalRow(children: Seq[Expression], schema: StructType)
  extends Expression with NonSQLExpression {

  override def nullable: Boolean = false

  override def dataType: DataType = ObjectType(classOf[Row])

  override def eval(input: InternalRow): Any =
    new GenericRowWithSchema(children.map(_.eval(input)).toArray, schema)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val genericRowClassName = classOf[GenericRowWithSchema].getName
    val fieldValues = ctx.freshName("values")

    // One snippet per child: evaluate it and store its value (or null) in the Object[] slot.
    val fieldAssignments = for ((childExpr, slot) <- children.zipWithIndex) yield {
      val childGen = childExpr.genCode(ctx)
      s"""
         |${childGen.code}
         |if (${childGen.isNull}) {
         |  $fieldValues[$slot] = null;
         |} else {
         |  $fieldValues[$slot] = ${childGen.value};
         |}
       """.stripMargin
    }

    // The snippets may be split into helper methods, so the array is passed explicitly.
    val assignmentsCode = ctx.splitExpressionsWithCurrentInputs(
      expressions = fieldAssignments,
      funcName = "createExternalRow",
      extraArguments = "Object[]" -> fieldValues :: Nil)
    val schemaRef = ctx.addReferenceObj("schema", schema)

    val rowCode =
      code"""
         |Object[] $fieldValues = new Object[${children.size}];
         |$assignmentsCode
         |final ${classOf[Row].getName} ${ev.value} = new $genericRowClassName($fieldValues, $schemaRef);
       """.stripMargin
    ev.copy(code = rowCode, isNull = FalseLiteral)
  }
}

/**
 * Serializes the value of `child` into a byte array with a generic serializer.
 * A null input yields a null result.
 *
 * @param kryo if true, use Kryo. Otherwise, use Java.
 */
case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
  extends UnaryExpression with NonSQLExpression with SerializerSupport {

  override def dataType: DataType = BinaryType

  override def nullSafeEval(input: Any): Any = {
    val serialized = serializerInstance.serialize(input)
    serialized.array()
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val serializerRef = addImmutableSerializerIfNeeded(ctx)
    val childGen = child.genCode(ctx)
    val resultJavaType = CodeGenerator.javaType(dataType)
    val nullResult = CodeGenerator.defaultValue(dataType)
    // Only serialize when the child is non-null. Otherwise fall back to the default value.
    val serializeExpr = s"$serializerRef.serialize(${childGen.value}, null).array()"

    val generated = childGen.code + code"""
      final $resultJavaType ${ev.value} =
        ${childGen.isNull} ? $nullResult : $serializeExpr;
     """
    ev.copy(code = generated, isNull = childGen.isNull)
  }
}

/**
 * Deserializes a byte array produced by a generic serializer back into an object of type `T`.
 * A null input yields a null result. The ClassTag is an explicit (not implicit) parameter
 * because TreeNode cannot copy implicit parameters.
 *
 * @param kryo if true, use Kryo. Otherwise, use Java.
 */
case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean)
  extends UnaryExpression with NonSQLExpression with SerializerSupport {

  override def dataType: DataType = ObjectType(tag.runtimeClass)

  override def nullSafeEval(input: Any): Any = {
    val rawBytes = input.asInstanceOf[Array[Byte]]
    val buffer = java.nio.ByteBuffer.wrap(rawBytes)
    serializerInstance.deserialize(buffer)
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val serializerRef = addImmutableSerializerIfNeeded(ctx)
    val childGen = child.genCode(ctx)
    val resultJavaType = CodeGenerator.javaType(dataType)
    val nullResult = CodeGenerator.defaultValue(dataType)
    // Wrap the bytes and cast the deserialized object to the target Java type.
    val deserializeExpr =
      s"($resultJavaType) $serializerRef.deserialize(java.nio.ByteBuffer.wrap(${childGen.value}), null)"

    val generated = childGen.code + code"""
      final $resultJavaType ${ev.value} =
         ${childGen.isNull} ? $nullResult : $deserializeExpr;
     """
    ev.copy(code = generated, isNull = childGen.isNull)
  }
}

/**
 * Initialize a Java Bean instance by setting its field values via setters.
 *
 * @param beanInstance expression producing the bean to initialize. Its data type must be an
 *                     `ObjectType`.
 * @param setters map from setter method name to the expression producing that setter's
 *                argument. A setter is not called when its argument evaluates to null.
 */
case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression])
  extends Expression with NonSQLExpression {

  override def nullable: Boolean = beanInstance.nullable
  override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
  override def dataType: DataType = beanInstance.dataType

  /**
   * Reflective setter lookup used by the interpreted path. It maps each setter name to a
   * `Method`.
   *
   * Methods declared directly on the bean class are preferred, tried in the order of the
   * candidate parameter types. Only if none is found does the lookup fall back to public
   * methods, including those inherited from superclasses/interfaces. The generated code can
   * already call those via `bean.setterName(...)`, so this keeps both paths consistent.
   *
   * @throws NoSuchMethodException if no matching setter exists.
   */
  private lazy val resolvedSetters = {
    assert(beanInstance.dataType.isInstanceOf[ObjectType])

    val ObjectType(beanClass) = beanInstance.dataType
    setters.map {
      case (name, expr) =>
        // Looking for known type mapping.
        // But also looking for general `Object`-type parameter for generic methods.
        val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object])
        val declaredMethods = paramTypes.flatMap { fieldClass =>
          findMethod(beanClass.getDeclaredMethod(name, fieldClass))
        }
        // `getDeclaredMethod` ignores supertypes. Fall back to `getMethod`, which also finds
        // public methods inherited from superclasses and interfaces.
        val methods = if (declaredMethods.nonEmpty) {
          declaredMethods
        } else {
          paramTypes.flatMap { fieldClass => findMethod(beanClass.getMethod(name, fieldClass)) }
        }
        if (methods.isEmpty) {
          throw new NoSuchMethodException(s"""A method named "$name" is not declared """ +
            "in any enclosing class nor any supertype")
        }
        methods.head -> expr
    }
  }

  /** Runs a reflective method lookup, mapping `NoSuchMethodException` to `None`. */
  private def findMethod(lookup: => Method): Option[Method] = {
    try {
      Some(lookup)
    } catch {
      case _: NoSuchMethodException => None
    }
  }

  /**
   * Interpreted evaluation: calls every resolved setter whose argument is non-null on the bean
   * produced by `beanInstance`, then returns that bean (or null).
   */
  override def eval(input: InternalRow): Any = {
    val instance = beanInstance.eval(input)
    if (instance != null) {
      val bean = instance.asInstanceOf[Object]
      resolvedSetters.foreach {
        case (setter, expr) =>
          val paramVal = expr.eval(input)
          // We don't call setter if input value is null.
          if (paramVal != null) {
            setter.invoke(bean, paramVal.asInstanceOf[AnyRef])
          }
      }
    }
    instance
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val instanceGen = beanInstance.genCode(ctx)

    val javaBeanInstance = ctx.freshName("javaBean")
    val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType)

    // One snippet per setter. A setter is only invoked when its argument is non-null.
    val initialize = setters.map {
      case (setterMethod, fieldValue) =>
        val fieldGen = fieldValue.genCode(ctx)
        s"""
           |${fieldGen.code}
           |if (!${fieldGen.isNull}) {
           |  $javaBeanInstance.$setterMethod(${fieldGen.value});
           |}
         """.stripMargin
    }
    val initializeCode = ctx.splitExpressionsWithCurrentInputs(
      expressions = initialize.toSeq,
      funcName = "initializeJavaBean",
      extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil)

    val code = instanceGen.code +
      code"""
         |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value};
         |if (!${instanceGen.isNull}) {
         |  $initializeCode
         |}
       """.stripMargin
    ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
  }
}

/**
 * Asserts that input values of a non-nullable child expression are not null, throwing a
 * `NullPointerException` that names the walked type path otherwise.
 *
 * Note that there are cases where `child.nullable == true`, while we still need to add this
 * assertion.  Consider a nullable column `s` whose data type is a struct containing a non-nullable
 * `Int` field named `i`.  Expression `s.i` is nullable because `s` can be null.  However, for all
 * non-null `s`, `s.i` can't be null.
 */
case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
  extends UnaryExpression with NonSQLExpression {

  override def dataType: DataType = child.dataType
  override def foldable: Boolean = false
  override def nullable: Boolean = false

  override def flatArguments: Iterator[Any] = Iterator(child)

  private val errMsg = "Null value appeared in non-nullable field:" +
    walkedTypePath.mkString("\n", "\n", "\n") +
    "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
    "please try to use scala.Option[_] or other nullable types " +
    "(e.g. java.lang.Integer instead of int/scala.Int)."

  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null => throw new NullPointerException(errMsg)
    case nonNull => nonNull
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childGen = child.genCode(ctx)

    // The message is only needed on the failure path, so it is passed as an inline reference
    // instead of a dedicated class field.
    val errMsgRef = ctx.addReferenceObj("errMsg", errMsg)

    val nullGuard = code"""
      if (${childGen.isNull}) {
        throw new NullPointerException($errMsgRef);
      }
     """
    ev.copy(code = childGen.code + nullGuard, isNull = FalseLiteral, value = childGen.value)
  }
}

/**
 * Returns the value of field at index `index` from the external row `child`.
 * This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s.
 *
 * Both the input row and the requested field must be non-null. If either is null, a
 * `RuntimeException` is thrown.
 */
case class GetExternalRowField(
    child: Expression,
    index: Int,
    fieldName: String) extends UnaryExpression with NonSQLExpression {

  override def nullable: Boolean = false

  override def dataType: DataType = ObjectType(classOf[Object])

  private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null."

  override def eval(input: InternalRow): Any = child.eval(input).asInstanceOf[Row] match {
    case null => throw new RuntimeException("The input external row cannot be null.")
    case row if row.isNullAt(index) => throw new RuntimeException(errMsg)
    case row => row.get(index)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // The field-null message is only needed on the failure path, so it is passed as an inline
    // reference instead of a dedicated class field.
    val errMsgRef = ctx.addReferenceObj("errMsg", errMsg)
    val rowGen = child.genCode(ctx)
    val rowValue = rowGen.value
    val fieldCode = code"""
      ${rowGen.code}

      if (${rowGen.isNull}) {
        throw new RuntimeException("The input external row cannot be null.");
      }

      if ($rowValue.isNullAt($index)) {
        throw new RuntimeException($errMsgRef);
      }

      final Object ${ev.value} = $rowValue.get($index);
     """
    ev.copy(code = fieldCode, isNull = FalseLiteral)
  }
}

/**
 * Validates the actual data type of input expression at runtime.  If it doesn't match the
 * expectation, throw an exception.
 *
 * A null input is passed through unchanged (the expression is nullable exactly when `child`
 * is); only non-null values are type-checked. The interpreted (`eval`) and generated-code
 * paths follow the same contract.
 *
 * @param child expression producing an external (JVM object) value
 * @param expected the Catalyst data type the external value must correspond to
 */
case class ValidateExternalType(child: Expression, expected: DataType)
  extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))

  override def nullable: Boolean = child.nullable

  override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected)

  // Suffix appended to the offending value's class name in the error message.
  private val errMsg = s" is not a valid external type for schema of ${expected.catalogString}"

  // Accepts a NON-NULL value whose runtime class is allowed for `expected`.
  // Null must be handled by the caller: the ArrayType branch dereferences the value.
  private lazy val checkType: (Any) => Boolean = expected match {
    case _: DecimalType =>
      (value: Any) => {
        value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] ||
          value.isInstanceOf[Decimal]
      }
    case _: ArrayType =>
      (value: Any) => {
        value.getClass.isArray || value.isInstanceOf[Seq[_]]
      }
    case _ =>
      val dataTypeClazz = ScalaReflection.javaBoxedType(dataType)
      (value: Any) => {
        dataTypeClazz.isInstance(value)
      }
  }

  /**
   * Returns the child's value if it is null or has an acceptable runtime type; otherwise throws
   * a RuntimeException naming the offending class.
   */
  override def eval(input: InternalRow): Any = {
    val result = child.eval(input)
    // Null passes through, mirroring the `!isNull` guard in doGenCode. Previously null reached
    // `checkType`, which rejected it (or dereferenced it), and the error path then threw a
    // NullPointerException from `result.getClass`.
    if (result == null || checkType(result)) {
      result
    } else {
      throw new RuntimeException(s"${result.getClass.getName}$errMsg")
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Use unnamed reference that doesn't create a local field here to reduce the number of fields
    // because errMsgField is used only when the type doesn't match.
    val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
    val input = child.genCode(ctx)
    val obj = input.value

    // Java boolean expression mirroring `checkType`; it is only evaluated for non-null input.
    val typeCheck = expected match {
      case _: DecimalType =>
        Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
          .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
      case _: ArrayType =>
        s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}"
      case _ =>
        s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
    }

    // A null input leaves the default value in place and is reported through `isNull`.
    val code = code"""
      ${input.code}
      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
      if (!${input.isNull}) {
        if ($typeCheck) {
          ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj;
        } else {
          throw new RuntimeException($obj.getClass().getName() + $errMsgField);
        }
      }

    """
    ev.copy(code = code, isNull = input.isNull)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy