All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.SerializerBuildHelper.scala Maven / Gradle / Ivy

There is a newer version: 3.5.3
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst

import scala.language.existentials

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object SerializerBuildHelper {

  /**
   * Flag handed to `CheckOverflow` by the decimal serializers in this object: `true` when
   * `SQLConf.get.ansiEnabled` is off, `false` when it is on.
   */
  private def nullOnOverflow: Boolean = {
    val ansiEnabled = SQLConf.get.ansiEnabled
    !ansiEnabled
  }

  /**
   * Builds an `Invoke` of the `booleanValue` method on `inputObject`, typed as `BooleanType`.
   */
  def createSerializerForBoolean(inputObject: Expression): Expression =
    Invoke(targetObject = inputObject, functionName = "booleanValue", BooleanType)

  /**
   * Builds an `Invoke` of the `byteValue` method on `inputObject`, typed as `ByteType`.
   */
  def createSerializerForByte(inputObject: Expression): Expression =
    Invoke(targetObject = inputObject, functionName = "byteValue", ByteType)

  /**
   * Builds an `Invoke` of the `shortValue` method on `inputObject`, typed as `ShortType`.
   */
  def createSerializerForShort(inputObject: Expression): Expression =
    Invoke(targetObject = inputObject, functionName = "shortValue", ShortType)

  /**
   * Builds an `Invoke` of the `intValue` method on `inputObject`, typed as `IntegerType`.
   */
  def createSerializerForInteger(inputObject: Expression): Expression =
    Invoke(targetObject = inputObject, functionName = "intValue", IntegerType)

  /**
   * Builds an `Invoke` of the `longValue` method on `inputObject`, typed as `LongType`.
   */
  def createSerializerForLong(inputObject: Expression): Expression =
    Invoke(targetObject = inputObject, functionName = "longValue", LongType)

  /**
   * Builds an `Invoke` of the `floatValue` method on `inputObject`, typed as `FloatType`.
   */
  def createSerializerForFloat(inputObject: Expression): Expression =
    Invoke(targetObject = inputObject, functionName = "floatValue", FloatType)

  /**
   * Builds an `Invoke` of the `doubleValue` method on `inputObject`, typed as `DoubleType`.
   */
  def createSerializerForDouble(inputObject: Expression): Expression =
    Invoke(targetObject = inputObject, functionName = "doubleValue", DoubleType)

  /**
   * Builds a `StaticInvoke` of `fromString` on `UTF8String` with `inputObject` as its single
   * argument. The result is typed as `StringType` and flagged `returnNullable = false`.
   */
  def createSerializerForString(inputObject: Expression): Expression = {
    val args = List(inputObject)
    StaticInvoke(classOf[UTF8String], StringType, "fromString", args, returnNullable = false)
  }

  /**
   * Builds a `StaticInvoke` of `instantToMicros` on `DateTimeUtils` with `inputObject` as its
   * single argument. The result is typed as `TimestampType` and flagged
   * `returnNullable = false`.
   */
  def createSerializerForJavaInstant(inputObject: Expression): Expression = {
    val args = List(inputObject)
    StaticInvoke(DateTimeUtils.getClass, TimestampType, "instantToMicros", args,
      returnNullable = false)
  }

  /**
   * Serializes an enumeration value through its `toString` result: an `Invoke` of `toString`
   * (flagged `returnNullable = false`) is fed to [[createSerializerForString]].
   */
  def createSerializerForScalaEnum(inputObject: Expression): Expression = {
    val stringType = ObjectType(classOf[String])
    val asString = Invoke(inputObject, "toString", stringType, returnNullable = false)
    createSerializerForString(asString)
  }

  /**
   * Serializes an enum value through its `name` result: an `Invoke` of `name` on
   * `inputObject` is fed to [[createSerializerForString]].
   */
  def createSerializerForJavaEnum(inputObject: Expression): Expression = {
    val constantName = Invoke(inputObject, "name", ObjectType(classOf[String]))
    createSerializerForString(constantName)
  }

  /**
   * Builds a `StaticInvoke` of `fromJavaTimestamp` on `DateTimeUtils` with `inputObject` as
   * its single argument. The result is typed as `TimestampType` and flagged
   * `returnNullable = false`.
   */
  def createSerializerForSqlTimestamp(inputObject: Expression): Expression = {
    val args = List(inputObject)
    StaticInvoke(DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", args,
      returnNullable = false)
  }

  /**
   * Builds a `StaticInvoke` of `anyToMicros` on `DateTimeUtils` with `inputObject` as its
   * single argument. The result is typed as `TimestampType` and flagged
   * `returnNullable = false`.
   */
  def createSerializerForAnyTimestamp(inputObject: Expression): Expression = {
    val args = List(inputObject)
    StaticInvoke(DateTimeUtils.getClass, TimestampType, "anyToMicros", args,
      returnNullable = false)
  }

  /**
   * Builds a `StaticInvoke` of `localDateTimeToMicros` on `DateTimeUtils` with `inputObject`
   * as its single argument. The result is typed as `TimestampNTZType` and flagged
   * `returnNullable = false`.
   */
  def createSerializerForLocalDateTime(inputObject: Expression): Expression = {
    val args = List(inputObject)
    StaticInvoke(DateTimeUtils.getClass, TimestampNTZType, "localDateTimeToMicros", args,
      returnNullable = false)
  }

  /**
   * Builds a `StaticInvoke` of `localDateToDays` on `DateTimeUtils` with `inputObject` as its
   * single argument. The result is typed as `DateType` and flagged `returnNullable = false`.
   */
  def createSerializerForJavaLocalDate(inputObject: Expression): Expression = {
    val args = List(inputObject)
    StaticInvoke(DateTimeUtils.getClass, DateType, "localDateToDays", args,
      returnNullable = false)
  }

  /**
   * Builds a `StaticInvoke` of `fromJavaDate` on `DateTimeUtils` with `inputObject` as its
   * single argument. The result is typed as `DateType` and flagged `returnNullable = false`.
   */
  def createSerializerForSqlDate(inputObject: Expression): Expression = {
    val args = List(inputObject)
    StaticInvoke(DateTimeUtils.getClass, DateType, "fromJavaDate", args,
      returnNullable = false)
  }

  /**
   * Builds a `StaticInvoke` of `anyToDays` on `DateTimeUtils` with `inputObject` as its
   * single argument. The result is typed as `DateType` and flagged `returnNullable = false`.
   */
  def createSerializerForAnyDate(inputObject: Expression): Expression = {
    val args = List(inputObject)
    StaticInvoke(DateTimeUtils.getClass, DateType, "anyToDays", args,
      returnNullable = false)
  }

  /**
   * Builds a `StaticInvoke` of `durationToMicros` on `IntervalUtils` with `inputObject` as
   * its single argument. The result is typed as `DayTimeIntervalType()` and flagged
   * `returnNullable = false`.
   */
  def createSerializerForJavaDuration(inputObject: Expression): Expression = {
    val intervalType = DayTimeIntervalType()
    StaticInvoke(IntervalUtils.getClass, intervalType, "durationToMicros", List(inputObject),
      returnNullable = false)
  }

  /**
   * Builds a `StaticInvoke` of `periodToMonths` on `IntervalUtils` with `inputObject` as its
   * single argument. The result is typed as `YearMonthIntervalType()` and flagged
   * `returnNullable = false`.
   */
  def createSerializerForJavaPeriod(inputObject: Expression): Expression = {
    val intervalType = YearMonthIntervalType()
    StaticInvoke(IntervalUtils.getClass, intervalType, "periodToMonths", List(inputObject),
      returnNullable = false)
  }

  /**
   * Convenience overload: serializes a big decimal using `DecimalType.SYSTEM_DEFAULT` as the
   * target decimal type.
   */
  def createSerializerForBigDecimal(inputObject: Expression): Expression =
    createSerializerForBigDecimal(inputObject, dt = DecimalType.SYSTEM_DEFAULT)

  /**
   * Builds a `StaticInvoke` of `apply` on the `Decimal` companion with `inputObject` as its
   * single argument (typed as `dt`, `returnNullable = false`) and wraps it in a
   * `CheckOverflow` against `dt`, passing `nullOnOverflow` as the overflow flag.
   */
  def createSerializerForBigDecimal(inputObject: Expression, dt: DecimalType): Expression = {
    val toDecimal =
      StaticInvoke(Decimal.getClass, dt, "apply", List(inputObject), returnNullable = false)
    CheckOverflow(toDecimal, dt, nullOnOverflow)
  }

  /**
   * Builds a `StaticInvoke` of `fromDecimal` on the `Decimal` companion with `inputObject` as
   * its single argument (typed as `dt`, `returnNullable = false`) and wraps it in a
   * `CheckOverflow` against `dt`, passing `nullOnOverflow` as the overflow flag.
   */
  def createSerializerForAnyDecimal(inputObject: Expression, dt: DecimalType): Expression = {
    val toDecimal =
      StaticInvoke(Decimal.getClass, dt, "fromDecimal", List(inputObject), returnNullable = false)
    CheckOverflow(toDecimal, dt, nullOnOverflow)
  }

  /**
   * Builds a `StaticInvoke` of `apply` on the `Decimal` companion with `inputObject` as its
   * single argument, typed as `DecimalType.BigIntDecimal` (`returnNullable = false`), and
   * wraps it in a `CheckOverflow` against the same type with `nullOnOverflow` as the flag.
   */
  def createSerializerForBigInteger(inputObject: Expression): Expression = {
    val targetType = DecimalType.BigIntDecimal
    val toDecimal =
      StaticInvoke(Decimal.getClass, targetType, "apply", List(inputObject), returnNullable = false)
    CheckOverflow(toDecimal, targetType, nullOnOverflow)
  }

  /**
   * Builds a `StaticInvoke` of `fromPrimitiveArray` on `UnsafeArrayData` with `inputObject`
   * as its single argument. The result is typed as an `ArrayType` of `dataType` with
   * `containsNull = false` and flagged `returnNullable = false`.
   */
  def createSerializerForPrimitiveArray(
      inputObject: Expression,
      dataType: DataType): Expression = {
    val arrayType = ArrayType(dataType, containsNull = false)
    StaticInvoke(
      classOf[UnsafeArrayData],
      arrayType,
      "fromPrimitiveArray",
      List(inputObject),
      returnNullable = false)
  }

  /**
   * Builds a `NewInstance` of `GenericArrayData` with `inputObject` as the only constructor
   * argument. The result is typed as `ArrayType(dataType, nullable)`.
   */
  def createSerializerForGenericArray(
      inputObject: Expression,
      dataType: DataType,
      nullable: Boolean): Expression = {
    val arrayType = ArrayType(dataType, nullable)
    NewInstance(classOf[GenericArrayData], List(inputObject), dataType = arrayType)
  }

  /**
   * Thin wrapper around `MapObjects`: forwards `funcForNewExpr`, `inputObject` and
   * `dataType` to it, in that order.
   */
  def createSerializerForMapObjects(
      inputObject: Expression,
      dataType: ObjectType,
      funcForNewExpr: Expression => Expression): Expression =
    MapObjects(funcForNewExpr, inputObject, dataType)

  /**
   * Bundles what [[createSerializerForMap]] needs to know about one side (keys or values) of
   * a map.
   *
   * @param dataType data type forwarded to `ExternalMapToCatalyst` for this side of the map
   * @param nullable forwarded as `keyNullable` or `valueNullable`
   * @param funcForNewExpr expression-to-expression function forwarded to
   *                       `ExternalMapToCatalyst` alongside `dataType`
   */
  case class MapElementInformation(
      dataType: DataType,
      nullable: Boolean,
      funcForNewExpr: Expression => Expression)

  /**
   * Builds an `ExternalMapToCatalyst` expression over `inputObject`, taking the data type,
   * converter function and nullability of keys and values from `keyInformation` and
   * `valueInformation` respectively.
   */
  def createSerializerForMap(
      inputObject: Expression,
      keyInformation: MapElementInformation,
      valueInformation: MapElementInformation): Expression = {
    val keyType = keyInformation.dataType
    val serializeKey = keyInformation.funcForNewExpr
    val valueType = valueInformation.dataType
    val serializeValue = valueInformation.funcForNewExpr
    ExternalMapToCatalyst(
      inputObject,
      keyType,
      serializeKey,
      keyNullable = keyInformation.nullable,
      valueType,
      serializeValue,
      valueNullable = valueInformation.nullable)
  }

  /**
   * Returns the two-element argument list for one struct field: a literal holding
   * `fieldName` followed by `serializerForFieldValue`.
   */
  private def argumentsForFieldSerializer(
      fieldName: String,
      serializerForFieldValue: Expression): Seq[Expression] =
    List(exprs.Literal(fieldName), serializerForFieldValue)

  /**
   * Builds a `CreateNamedStruct` from the given `(name, serializer)` pairs. When
   * `inputObject` is nullable, the struct is wrapped in an `If` that yields a null literal of
   * the struct's data type whenever `inputObject` is null.
   */
  def createSerializerForObject(
      inputObject: Expression,
      fields: Seq[(String, Expression)]): Expression = {
    val structArgs = fields.flatMap { case (name, valueSerializer) =>
      argumentsForFieldSerializer(name, valueSerializer)
    }
    val struct = CreateNamedStruct(structArgs)
    if (!inputObject.nullable) {
      struct
    } else {
      val nullStruct = exprs.Literal.create(null, struct.dataType)
      exprs.If(IsNull(inputObject), nullStruct, struct)
    }
  }

  /**
   * Builds an `Invoke` of `serialize` (typed as `udt`) with `inputObject` as its single
   * argument, on a `NewInstance` of `udtClass` created with no constructor arguments.
   */
  def createSerializerForUserDefinedType(
      inputObject: Expression,
      udt: UserDefinedType[_],
      udtClass: Class[_]): Expression = {
    val udtInstance = NewInstance(udtClass, Nil, dataType = ObjectType(udtClass))
    Invoke(udtInstance, "serialize", udt, List(inputObject))
  }

  /**
   * Creates the expression that serializes an object into its Spark SQL form, following the
   * external-to-internal mapping described by encoder `enc`. The object to serialize is read
   * from ordinal 0 of the input row, i.e. `BoundReference(0, _)`; the reference uses the
   * lenient external data type of `enc` and the nullability of `enc`.
   */
  def createSerializer(enc: AgnosticEncoder[_]): Expression =
    createSerializer(
      enc,
      BoundReference(0, lenientExternalDataTypeFor(enc), nullable = enc.nullable))

  /**
   * Returns an expression for serializing the value of an input expression into its Spark SQL
   * representation. The mapping between the external and internal representations is described
   * by encoder `enc`.
   *
   * The method dispatches on the encoder: leaf encoders delegate to the matching
   * `createSerializerFor...` helper, while container and struct-like encoders (option, array,
   * iterable, map, product, row, Java bean) recurse into this method for their components.
   */
  private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
    // Encoders reported as native by `isNativeEncoder` use the input expression unchanged.
    case _ if isNativeEncoder(enc) => input
    // Boxed primitives.
    case BoxedBooleanEncoder => createSerializerForBoolean(input)
    case BoxedByteEncoder => createSerializerForByte(input)
    case BoxedShortEncoder => createSerializerForShort(input)
    case BoxedIntEncoder => createSerializerForInteger(input)
    case BoxedLongEncoder => createSerializerForLong(input)
    case BoxedFloatEncoder => createSerializerForFloat(input)
    case BoxedDoubleEncoder => createSerializerForDouble(input)
    // Enums and strings.
    case JavaEnumEncoder(_) => createSerializerForJavaEnum(input)
    case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input)
    case StringEncoder => createSerializerForString(input)
    // Decimals and big integers.
    case ScalaDecimalEncoder(dt) => createSerializerForBigDecimal(input, dt)
    case JavaDecimalEncoder(dt, false) => createSerializerForBigDecimal(input, dt)
    case JavaDecimalEncoder(dt, true) => createSerializerForAnyDecimal(input, dt)
    case ScalaBigIntEncoder => createSerializerForBigInteger(input)
    case JavaBigIntEncoder => createSerializerForBigInteger(input)
    // Intervals, dates and timestamps. A `true` flag on the date/timestamp encoders selects the
    // `Any...` serializer; `false` selects the serializer specific to that encoder.
    case DayTimeIntervalEncoder => createSerializerForJavaDuration(input)
    case YearMonthIntervalEncoder => createSerializerForJavaPeriod(input)
    case DateEncoder(true) | LocalDateEncoder(true) => createSerializerForAnyDate(input)
    case DateEncoder(false) => createSerializerForSqlDate(input)
    case LocalDateEncoder(false) => createSerializerForJavaLocalDate(input)
    case TimestampEncoder(true) | InstantEncoder(true) => createSerializerForAnyTimestamp(input)
    case TimestampEncoder(false) => createSerializerForSqlTimestamp(input)
    case InstantEncoder(false) => createSerializerForJavaInstant(input)
    case LocalDateTimeEncoder => createSerializerForLocalDateTime(input)
    case UDTEncoder(udt, udtClass) => createSerializerForUserDefinedType(input, udt, udtClass)
    // Options: unwrap the option first, then serialize the result with the value encoder.
    case OptionEncoder(valueEnc) =>
      createSerializer(valueEnc, UnwrapOption(externalDataTypeFor(valueEnc), input))

    // Arrays: primitive element encoders use the primitive-array serializer; everything else
    // goes through the generic array path without lenient serialization.
    case ArrayEncoder(elementEncoder, containsNull) =>
      if (elementEncoder.isPrimitive) {
        createSerializerForPrimitiveArray(input, elementEncoder.dataType)
      } else {
        serializerForArray(elementEncoder, containsNull, input, lenientSerialization = false)
      }

    case IterableEncoder(ctag, elementEncoder, containsNull, lenientSerialization) =>
      val getter = if (classOf[scala.collection.Set[_]].isAssignableFrom(ctag.runtimeClass)) {
        // There is no corresponding Catalyst type for `Set`, so a `Set` is serialized to a
        // Catalyst array. Note that the `Set` property is only preserved while the data is
        // manipulated as a domain object.
        Invoke(input, "toSeq", ObjectType(classOf[scala.collection.Seq[_]]))
      } else {
        input
      }
      serializerForArray(elementEncoder, containsNull, getter, lenientSerialization)

    // Maps: keys and values are each validated and serialized element by element. Key
    // nullability comes from the key encoder, value nullability from `valueContainsNull`.
    case MapEncoder(_, keyEncoder, valueEncoder, valueContainsNull) =>
      createSerializerForMap(
        input,
        MapElementInformation(
          ObjectType(classOf[AnyRef]),
          nullable = keyEncoder.nullable,
          validateAndSerializeElement(keyEncoder, keyEncoder.nullable)),
        MapElementInformation(
          ObjectType(classOf[AnyRef]),
          nullable = valueContainsNull,
          validateAndSerializeElement(valueEncoder, valueContainsNull))
      )

    // Products: every field is read by invoking the method named after the field.
    case ProductEncoder(_, fields, _) =>
      val serializedFields = fields.map { field =>
        // SPARK-26730: inputObject won't be null thanks to the If guard below, and
        // KnownNotNull is necessary here. For a nullable nested inputObject with struct data
        // type, e.g. StructType(IntegerType, StringType), the getter would return
        // nullable=true for IntegerType without KnownNotNull, which is not what we expect.
        val getter = Invoke(
          KnownNotNull(input),
          field.name,
          externalDataTypeFor(field.enc),
          returnNullable = field.nullable)
        field.name -> createSerializer(field.enc, getter)
      }
      createSerializerForObject(input, serializedFields)

    // Rows: every field is fetched by index, validated against the expected external type and
    // then serialized with the field's encoder.
    case AgnosticEncoders.RowEncoder(fields) =>
      val serializedFields = fields.zipWithIndex.map { case (field, index) =>
        val fieldValue = createSerializer(
          field.enc,
          ValidateExternalType(
            GetExternalRowField(input, index, field.name),
            field.enc.dataType,
            lenientExternalDataTypeFor(field.enc)))

        // Nullable fields yield a null literal when `isNullAt(index)` holds; non-nullable fields
        // are wrapped in `AssertNotNull`.
        val convertedField = if (field.nullable) {
          exprs.If(
            Invoke(input, "isNullAt", BooleanType, exprs.Literal(index) :: Nil),
            // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`.
            // We should use `fieldValue.dataType` here.
            exprs.Literal.create(null, fieldValue.dataType),
            fieldValue
          )
        } else {
          AssertNotNull(fieldValue)
        }
        field.name -> convertedField
      }
      createSerializerForObject(input, serializedFields)

    // Java beans: every field is read through its read method.
    case JavaBeanEncoder(_, fields) =>
      val serializedFields = fields.map { f =>
        // NOTE(review): `readMethod.get` presumably unwraps an Option and would throw for a
        // field without a read method -- confirm that such fields cannot reach this point.
        val fieldValue = Invoke(
          KnownNotNull(input),
          f.readMethod.get,
          externalDataTypeFor(f.enc),
          propagateNull = f.nullable,
          returnNullable = f.nullable)
        f.name -> createSerializer(f.enc, fieldValue)
      }
      createSerializerForObject(input, serializedFields)
  }

  /**
   * Builds the serializer for a sequence or array whose elements are described by
   * `elementEnc`.
   *
   * The generic path maps every element through [[validateAndSerializeElement]]. When
   * `lenientSerialization` is set and `elementEnc` is a primitive or boxed leaf encoder, the
   * result additionally checks at runtime whether the input is an instance of the matching
   * primitive array class and, if so, converts it with `ArrayData.toArrayData` instead.
   */
  private def serializerForArray(
      elementEnc: AgnosticEncoder[_],
      elementNullable: Boolean,
      input: Expression,
      lenientSerialization: Boolean): Expression = {
    // Element-wise serializer for Seq and generic Arrays. This does not work for primitive
    // arrays.
    val elementWiseSerializer = createSerializerForMapObjects(
      input,
      ObjectType(classOf[AnyRef]),
      validateAndSerializeElement(elementEnc, elementNullable))

    // Work out whether the user may legitimately pass a primitive array. Only then is it safe
    // to convert the input to an array directly (for generic arrays and Seqs both the type and
    // the nullability can be violated), so primitive arrays get a dedicated code path.
    val primitiveEncoderOption = elementEnc match {
      case _ if !lenientSerialization => None
      case leaf: PrimitiveLeafEncoder[_] => Option(leaf)
      case boxed: BoxedLeafEncoder[_, _] => Option(boxed.primitive)
      case _ => None
    }

    primitiveEncoderOption match {
      case None =>
        elementWiseSerializer
      case Some(primitiveEncoder) =>
        val arrayClass = primitiveEncoder.clsTag.wrap.runtimeClass
        val isPrimitiveArray = Invoke(
          targetObject = exprs.Literal.fromObject(arrayClass),
          functionName = "isInstance",
          BooleanType,
          arguments = input :: Nil,
          propagateNull = false,
          returnNullable = false)
        // TODO replace this with `createSerializerForPrimitiveArray` as
        //  soon as Cast support ObjectType casts.
        val primitiveArraySerializer = StaticInvoke(
          classOf[ArrayData],
          ArrayType(elementEnc.dataType, containsNull = false),
          "toArrayData",
          input :: Nil,
          propagateNull = false,
          returnNullable = false)
        exprs.If(isPrimitiveArray, primitiveArraySerializer, elementWiseSerializer)
    }
  }

  /**
   * Returns a function that turns the expression for a single collection element into its
   * serializer: the element is wrapped in `ValidateExternalType`, serialized with `enc`, and
   * the result is passed through `expressionWithNullSafety` with the given `nullable` flag.
   *
   * The expected type used for validation is the lenient external data type for option
   * encoders and `enc.dataType` for every other encoder.
   */
  private def validateAndSerializeElement(
      enc: AgnosticEncoder[_],
      nullable: Boolean): Expression => Expression = { element =>
    val expectedType = enc match {
      case OptionEncoder(_) => lenientExternalDataTypeFor(enc)
      case _ => enc.dataType
    }
    val validatedElement =
      ValidateExternalType(element, expectedType, lenientExternalDataTypeFor(enc))
    val elementSerializer = createSerializer(enc, validatedElement)
    expressionWithNullSafety(elementSerializer, nullable, WalkedTypePath())
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy