All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.DeserializerBuildHelper.scala Maven / Gradle / Ivy

There is a newer version: 3.5.3
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, MapKeys, MapValues, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.types._

object DeserializerBuildHelper {
  /** Returns the current path with a sub-field extracted. */
  def addToPath(
      path: Expression,
      part: String,
      dataType: DataType,
      walkedTypePath: WalkedTypePath): Expression = {
    val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
    upCastToExpectedType(newPath, dataType, walkedTypePath)
  }

  /** Returns the current path with a field at ordinal extracted. */
  def addToPathOrdinal(
      path: Expression,
      ordinal: Int,
      dataType: DataType,
      walkedTypePath: WalkedTypePath): Expression = {
    val newPath = GetStructField(path, ordinal)
    upCastToExpectedType(newPath, dataType, walkedTypePath)
  }

  def deserializerForWithNullSafetyAndUpcast(
      expr: Expression,
      dataType: DataType,
      nullable: Boolean,
      walkedTypePath: WalkedTypePath,
      funcForCreatingDeserializer: Expression => Expression): Expression = {
    val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
    expressionWithNullSafety(funcForCreatingDeserializer(casted), nullable, walkedTypePath)
  }

  def expressionWithNullSafety(
      expr: Expression,
      nullable: Boolean,
      walkedTypePath: WalkedTypePath): Expression = {
    if (nullable) {
      expr
    } else {
      AssertNotNull(expr, walkedTypePath.getPaths)
    }
  }

  def createDeserializerForTypesSupportValueOf(
      path: Expression,
      clazz: Class[_]): Expression = {
    StaticInvoke(
      clazz,
      ObjectType(clazz),
      "valueOf",
      path :: Nil,
      returnNullable = false)
  }

  def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression = {
    Invoke(path, "toString", ObjectType(classOf[java.lang.String]),
      returnNullable = returnNullable)
  }

  def createDeserializerForSqlDate(path: Expression): Expression = {
    StaticInvoke(
      DateTimeUtils.getClass,
      ObjectType(classOf[java.sql.Date]),
      "toJavaDate",
      path :: Nil,
      returnNullable = false)
  }

  def createDeserializerForLocalDate(path: Expression): Expression = {
    StaticInvoke(
      DateTimeUtils.getClass,
      ObjectType(classOf[java.time.LocalDate]),
      "daysToLocalDate",
      path :: Nil,
      returnNullable = false)
  }

  def createDeserializerForInstant(path: Expression): Expression = {
    StaticInvoke(
      DateTimeUtils.getClass,
      ObjectType(classOf[java.time.Instant]),
      "microsToInstant",
      path :: Nil,
      returnNullable = false)
  }

  def createDeserializerForSqlTimestamp(path: Expression): Expression = {
    StaticInvoke(
      DateTimeUtils.getClass,
      ObjectType(classOf[java.sql.Timestamp]),
      "toJavaTimestamp",
      path :: Nil,
      returnNullable = false)
  }

  def createDeserializerForLocalDateTime(path: Expression): Expression = {
    StaticInvoke(
      DateTimeUtils.getClass,
      ObjectType(classOf[java.time.LocalDateTime]),
      "microsToLocalDateTime",
      path :: Nil,
      returnNullable = false)
  }

  def createDeserializerForJavaBigDecimal(
      path: Expression,
      returnNullable: Boolean): Expression = {
    Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
      returnNullable = returnNullable)
  }

  def createDeserializerForScalaBigDecimal(
      path: Expression,
      returnNullable: Boolean): Expression = {
    Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = returnNullable)
  }

  def createDeserializerForJavaBigInteger(
      path: Expression,
      returnNullable: Boolean): Expression = {
    Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
      returnNullable = returnNullable)
  }

  def createDeserializerForScalaBigInt(path: Expression): Expression = {
    Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
      returnNullable = false)
  }

  def createDeserializerForDuration(path: Expression): Expression = {
    StaticInvoke(
      IntervalUtils.getClass,
      ObjectType(classOf[java.time.Duration]),
      "microsToDuration",
      path :: Nil,
      returnNullable = false)
  }

  def createDeserializerForPeriod(path: Expression): Expression = {
    StaticInvoke(
      IntervalUtils.getClass,
      ObjectType(classOf[java.time.Period]),
      "monthsToPeriod",
      path :: Nil,
      returnNullable = false)
  }

  /**
   * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
   * and lost the required data type, which may lead to runtime error if the real type doesn't
   * match the encoder's schema.
   * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
   * is [a: int, b: long], then we will hit runtime error and say that we can't construct class
   * `Data` with int and long, because we lost the information that `b` should be a string.
   *
   * This method help us "remember" the required data type by adding a `UpCast`. Note that we
   * only need to do this for leaf nodes.
   */
  private[catalyst] def upCastToExpectedType(
      expr: Expression,
      expected: DataType,
      walkedTypePath: WalkedTypePath): Expression = expected match {
    case _: StructType => expr
    case _: ArrayType => expr
    case _: MapType => expr
    case _: DecimalType =>
      // For Scala/Java `BigDecimal`, we accept decimal types of any valid precision/scale.
      // Here we use the `DecimalType` object to indicate it.
      UpCast(expr, DecimalType, walkedTypePath.getPaths)
    case _ => UpCast(expr, expected, walkedTypePath.getPaths)
  }

  /**
   * Returns an expression for deserializing the Spark SQL representation of an object into its
   * external form. The mapping between the internal and external representations is
   * described by encoder `enc`. The Spark SQL representation is located at ordinal 0 of
   * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using
   * `UnresolvedExtractValue`.
   *
   * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this
   * deserializer expression when using it.
   *
   * @param enc encoder that describes the mapping between the Spark SQL representation and the
   *            external representation.
   */
  def createDeserializer[T](enc: AgnosticEncoder[T]): Expression = {
    val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName)
    // Assumes we are deserializing the first column of a row.
    val input = GetColumnByOrdinal(0, enc.dataType)
    enc match {
      case AgnosticEncoders.RowEncoder(fields) =>
        val children = fields.zipWithIndex.map { case (f, i) =>
          createDeserializer(f.enc, GetStructField(input, i), walkedTypePath)
        }
        CreateExternalRow(children, enc.schema)
      case _ =>
        val deserializer = createDeserializer(
          enc,
          upCastToExpectedType(input, enc.dataType, walkedTypePath),
          walkedTypePath)
        expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
    }
  }

  /**
   * Returns an expression for deserializing the value of an input expression into its external
   * representation. The mapping between the internal and external representations is
   * described by encoder `enc`.
   *
   * @param enc encoder that describes the mapping between the Spark SQL representation and the
   *            external representation.
   * @param path The expression which can be used to extract serialized value.
   * @param walkedTypePath The paths from top to bottom to access current field when deserializing.
   */
  private def createDeserializer(
      enc: AgnosticEncoder[_],
      path: Expression,
      walkedTypePath: WalkedTypePath): Expression = enc match {
    case _ if isNativeEncoder(enc) =>
      path
    case _: BoxedLeafEncoder[_, _] =>
      createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
    case JavaEnumEncoder(tag) =>
      val toString = createDeserializerForString(path, returnNullable = false)
      createDeserializerForTypesSupportValueOf(toString, tag.runtimeClass)
    case ScalaEnumEncoder(parent, tag) =>
      StaticInvoke(
        parent,
        ObjectType(tag.runtimeClass),
        "withName",
        createDeserializerForString(path, returnNullable = false) :: Nil,
        returnNullable = false)
    case StringEncoder =>
      createDeserializerForString(path, returnNullable = false)
    case _: ScalaDecimalEncoder =>
      createDeserializerForScalaBigDecimal(path, returnNullable = false)
    case _: JavaDecimalEncoder =>
      createDeserializerForJavaBigDecimal(path, returnNullable = false)
    case ScalaBigIntEncoder =>
      createDeserializerForScalaBigInt(path)
    case JavaBigIntEncoder =>
      createDeserializerForJavaBigInteger(path, returnNullable = false)
    case DayTimeIntervalEncoder =>
      createDeserializerForDuration(path)
    case YearMonthIntervalEncoder =>
      createDeserializerForPeriod(path)
    case _: DateEncoder =>
      createDeserializerForSqlDate(path)
    case _: LocalDateEncoder =>
      createDeserializerForLocalDate(path)
    case _: TimestampEncoder =>
      createDeserializerForSqlTimestamp(path)
    case _: InstantEncoder =>
      createDeserializerForInstant(path)
    case LocalDateTimeEncoder =>
      createDeserializerForLocalDateTime(path)
    case UDTEncoder(udt, udtClass) =>
      val obj = NewInstance(udtClass, Nil, ObjectType(udtClass))
      Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
    case OptionEncoder(valueEnc) =>
      val newTypePath = walkedTypePath.recordOption(valueEnc.clsTag.runtimeClass.getName)
      val deserializer = createDeserializer(valueEnc, path, newTypePath)
      WrapOption(deserializer, externalDataTypeFor(valueEnc))

    case ArrayEncoder(elementEnc: AgnosticEncoder[_], containsNull) =>
      Invoke(
        deserializeArray(
          path,
          elementEnc,
          containsNull,
          None,
          walkedTypePath),
        toArrayMethodName(elementEnc),
        ObjectType(enc.clsTag.runtimeClass),
        returnNullable = false)

    case IterableEncoder(clsTag, elementEnc, containsNull, _) =>
      deserializeArray(
        path,
        elementEnc,
        containsNull,
        Option(clsTag.runtimeClass),
        walkedTypePath)

    case MapEncoder(tag, keyEncoder, valueEncoder, _)
      if classOf[java.util.Map[_, _]].isAssignableFrom(tag.runtimeClass) =>
      // TODO (hvanhovell) this is can be improved.
      val newTypePath = walkedTypePath.recordMap(
        keyEncoder.clsTag.runtimeClass.getName,
        valueEncoder.clsTag.runtimeClass.getName)

      val keyData =
        Invoke(
          UnresolvedMapObjects(
            p => createDeserializer(keyEncoder, p, newTypePath),
            MapKeys(path)),
          "array",
          ObjectType(classOf[Array[Any]]))

      val valueData =
        Invoke(
          UnresolvedMapObjects(
            p => createDeserializer(valueEncoder, p, newTypePath),
            MapValues(path)),
          "array",
          ObjectType(classOf[Array[Any]]))

      StaticInvoke(
        ArrayBasedMapData.getClass,
        ObjectType(classOf[java.util.Map[_, _]]),
        "toJavaMap",
        keyData :: valueData :: Nil,
        returnNullable = false)

    case MapEncoder(tag, keyEncoder, valueEncoder, _) =>
      val newTypePath = walkedTypePath.recordMap(
        keyEncoder.clsTag.runtimeClass.getName,
        valueEncoder.clsTag.runtimeClass.getName)
      UnresolvedCatalystToExternalMap(
        path,
        createDeserializer(keyEncoder, _, newTypePath),
        createDeserializer(valueEncoder, _, newTypePath),
        tag.runtimeClass)

    case ProductEncoder(tag, fields, outerPointerGetter) =>
      val cls = tag.runtimeClass
      val dt = ObjectType(cls)
      val isTuple = cls.getName.startsWith("scala.Tuple")
      val arguments = fields.zipWithIndex.map {
        case (field, i) =>
          val newTypePath = walkedTypePath.recordField(
            field.enc.clsTag.runtimeClass.getName,
            field.name)
          // For tuples, we grab the inner fields by ordinal instead of name.
          val getter = if (isTuple) {
            addToPathOrdinal(path, i, field.enc.dataType, newTypePath)
          } else {
            addToPath(path, field.name, field.enc.dataType, newTypePath)
          }
          expressionWithNullSafety(
            createDeserializer(field.enc, getter, newTypePath),
            field.enc.nullable,
            newTypePath)
      }
      exprs.If(
        IsNull(path),
        exprs.Literal.create(null, dt),
        NewInstance(cls, arguments, Nil, propagateNull = false, dt, outerPointerGetter))

    case AgnosticEncoders.RowEncoder(fields) =>
      val convertedFields = fields.zipWithIndex.map { case (f, i) =>
        val newTypePath = walkedTypePath.recordField(
          f.enc.clsTag.runtimeClass.getName,
          f.name)
        exprs.If(
          Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
          exprs.Literal.create(null, externalDataTypeFor(f.enc)),
          createDeserializer(f.enc, GetStructField(path, i), newTypePath))
      }
      exprs.If(IsNull(path),
        exprs.Literal.create(null, externalDataTypeFor(enc)),
        CreateExternalRow(convertedFields, enc.schema))

    case JavaBeanEncoder(tag, fields) =>
      val setters = fields
        .filter(_.writeMethod.isDefined)
        .map { f =>
        val newTypePath = walkedTypePath.recordField(
          f.enc.clsTag.runtimeClass.getName,
          f.name)
        val setter = expressionWithNullSafety(
          createDeserializer(
            f.enc,
            addToPath(path, f.name, f.enc.dataType, newTypePath),
            newTypePath),
          nullable = f.nullable,
          newTypePath)
        f.writeMethod.get -> setter
      }

      val cls = tag.runtimeClass
      val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false)
      val result = InitializeJavaBean(newInstance, setters.toMap)
      exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
  }

  private def deserializeArray(
      path: Expression,
      elementEnc: AgnosticEncoder[_],
      containsNull: Boolean,
      cls: Option[Class[_]],
      walkedTypePath: WalkedTypePath): Expression = {
    val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
    val mapFunction: Expression => Expression = element => {
      // upcast the array element to the data type the encoder expects.
      deserializerForWithNullSafetyAndUpcast(
        element,
        elementEnc.dataType,
        nullable = containsNull,
        newTypePath,
        createDeserializer(elementEnc, _, newTypePath))
    }
    UnresolvedMapObjects(mapFunction, path, cls)
  }

  private def toArrayMethodName(enc: AgnosticEncoder[_]): String = enc match {
    case PrimitiveBooleanEncoder => "toBooleanArray"
    case PrimitiveByteEncoder => "toByteArray"
    case PrimitiveShortEncoder => "toShortArray"
    case PrimitiveIntEncoder => "toIntArray"
    case PrimitiveLongEncoder => "toLongArray"
    case PrimitiveFloatEncoder => "toFloatArray"
    case PrimitiveDoubleEncoder => "toDoubleArray"
    case _ => "array"
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy