All downloads are free. The search and download functionality uses the official Maven repository.

scala3encoders.derivation.Deserializer.scala Maven / Gradle / Ivy

There is a newer version: 0.2.6
Show newest version
package scala3encoders.derivation

import scala.compiletime.{constValue, erasedValue, summonInline}
import scala.deriving.Mirror
import scala.reflect.{ClassTag, Enum}
import org.apache.spark.sql.catalyst.expressions.{Expression, If, IsNull, Literal}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.*
import org.apache.spark.sql.catalyst.WalkedTypePath
import org.apache.spark.sql.catalyst.expressions.objects.*
import org.apache.spark.sql.helper.Helper
import org.apache.spark.sql.types.*
import scala.concurrent.duration.FiniteDuration
import scala.jdk.javaapi.DurationConverters

/** Type class describing how a value of type `T` is rebuilt from a Catalyst
  * expression.
  *
  * @tparam T
  *   the Scala/Java type produced by the deserializer
  */
trait Deserializer[T]:
  /** Catalyst data type of the column this deserializer reads from. */
  def inputType: DataType

  /** Whether the field may hold null; `true` unless an instance overrides it. */
  def nullable: Boolean = true

  /** Builds the expression that turns the value found at `path` into a `T`.
    *
    * @param path
    *   expression that yields the Catalyst value to convert
    * @return
    *   the expression producing the external object
    */
  def deserialize(path: Expression): Expression

object Deserializer:
  // See DeserializerBuildHelper.createDeserializer
  /** Lifts a deserializer for `T` into one for `Option[T]`.
    *
    * The input type is that of the underlying deserializer; the deserialized
    * value is wrapped with `WrapOption`. The element class comes from the
    * `ClassTag`, replaced by its entry in `Helper.typeBoxedJavaMapping` when
    * one exists.
    */
  inline given deriveOpt[T](using
      d: Deserializer[T],
      ct: ClassTag[T]
  ): Deserializer[Option[T]] =
    new Deserializer[Option[T]]:
      override def deserialize(path: Expression): Expression =
        val rawClass = ct.runtimeClass
        // Fall back to the raw runtime class when the mapping has no entry for it.
        val wrappedClass =
          Helper.typeBoxedJavaMapping.getOrElse(rawClass, rawClass)
        val inner = d.deserialize(path)
        WrapOption(inner, ObjectType(wrappedClass))

      override def inputType: DataType = d.inputType

  /** Deserializer for `Int`, read from an `IntegerType` column. */
  given Deserializer[Int] with
    // Declared non-nullable, unlike the trait default.
    override def nullable: Boolean = false
    override def inputType: DataType = IntegerType
    override def deserialize(path: Expression): Expression =
      val boxedClass = classOf[java.lang.Integer]
      createDeserializerForTypesSupportValueOf(path, boxedClass)

  /** Deserializer for boxed `java.lang.Integer`, read from an `IntegerType`
    * column.
    *
    * `nullable` is intentionally left at the trait default (`true`): a boxed
    * integer is a reference type and may legitimately be null, exactly like the
    * `java.lang.Double` instance in this object. Declaring it non-nullable
    * would make product derivation mark the field as non-nullable and apply a
    * non-null check to a value that is allowed to be null.
    */
  given Deserializer[java.lang.Integer] with
    def inputType: DataType = IntegerType
    def deserialize(path: Expression): Expression =
      createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Integer])

  /** Deserializer for `Long`, read from a `LongType` column. */
  given Deserializer[Long] with
    // Declared non-nullable, unlike the trait default.
    override def nullable: Boolean = false
    override def inputType: DataType = LongType
    override def deserialize(path: Expression): Expression =
      val boxedClass = classOf[java.lang.Long]
      createDeserializerForTypesSupportValueOf(path, boxedClass)

  /** Deserializer for `Double`, read from a `DoubleType` column. */
  given Deserializer[Double] with
    // Declared non-nullable, unlike the trait default.
    override def nullable: Boolean = false
    override def inputType: DataType = DoubleType
    override def deserialize(path: Expression): Expression =
      val boxedClass = classOf[java.lang.Double]
      createDeserializerForTypesSupportValueOf(path, boxedClass)

  /** Deserializer for boxed `java.lang.Double`, read from a `DoubleType`
    * column. Keeps the trait's default nullability (`true`).
    */
  given given_Deserializer_JavaDouble: Deserializer[java.lang.Double] with
    override def deserialize(path: Expression): Expression =
      val boxedClass = classOf[java.lang.Double]
      createDeserializerForTypesSupportValueOf(path, boxedClass)
    override def inputType: DataType = DoubleType

  /** Deserializer for `Float`, read from a `FloatType` column. */
  given Deserializer[Float] with
    // Declared non-nullable, unlike the trait default.
    override def nullable: Boolean = false
    override def inputType: DataType = FloatType
    override def deserialize(path: Expression): Expression =
      val boxedClass = classOf[java.lang.Float]
      createDeserializerForTypesSupportValueOf(path, boxedClass)

  /** Deserializer for `Short`, read from a `ShortType` column. */
  given Deserializer[Short] with
    // Declared non-nullable, unlike the trait default.
    override def nullable: Boolean = false
    override def inputType: DataType = ShortType
    override def deserialize(path: Expression): Expression =
      val boxedClass = classOf[java.lang.Short]
      createDeserializerForTypesSupportValueOf(path, boxedClass)

  /** Deserializer for `Byte`, read from a `ByteType` column. */
  given Deserializer[Byte] with
    // Declared non-nullable, unlike the trait default.
    override def nullable: Boolean = false
    override def inputType: DataType = ByteType
    override def deserialize(path: Expression): Expression =
      val boxedClass = classOf[java.lang.Byte]
      createDeserializerForTypesSupportValueOf(path, boxedClass)

  /** Deserializer for `Boolean`, read from a `BooleanType` column. */
  given Deserializer[Boolean] with
    // Declared non-nullable, unlike the trait default.
    override def nullable: Boolean = false
    override def inputType: DataType = BooleanType
    override def deserialize(path: Expression): Expression =
      val boxedClass = classOf[java.lang.Boolean]
      createDeserializerForTypesSupportValueOf(path, boxedClass)

  /** Deserializer for `java.time.LocalDate`, read from a `DateType` column. */
  given Deserializer[java.time.LocalDate] with
    override def deserialize(path: Expression): Expression =
      val localDate = createDeserializerForLocalDate(path)
      localDate
    override def inputType: DataType = DateType

  /** Deserializer for `java.sql.Date`, read from a `DateType` column. */
  given Deserializer[java.sql.Date] with
    override def deserialize(path: Expression): Expression =
      val sqlDate = createDeserializerForSqlDate(path)
      sqlDate
    override def inputType: DataType = DateType

  /** Deserializer for `java.time.Instant`, read from a `TimestampType` column. */
  given Deserializer[java.time.Instant] with
    override def deserialize(path: Expression): Expression =
      val instant = createDeserializerForInstant(path)
      instant
    override def inputType: DataType = TimestampType

  /** Deserializer for `java.sql.Timestamp`, read from a `TimestampType` column. */
  given Deserializer[java.sql.Timestamp] with
    override def deserialize(path: Expression): Expression =
      val sqlTimestamp = createDeserializerForSqlTimestamp(path)
      sqlTimestamp
    override def inputType: DataType = TimestampType

  /** Deserializer for `java.time.Duration`, read from a `DayTimeIntervalType`
    * column.
    */
  given Deserializer[java.time.Duration] with
    override def deserialize(path: Expression): Expression =
      val duration = createDeserializerForDuration(path)
      duration
    override def inputType: DataType = DayTimeIntervalType()

  /** Deserializer for `scala.concurrent.duration.FiniteDuration`.
    *
    * Reads a `DayTimeIntervalType` column through the `java.time.Duration`
    * deserializer, then converts the result by statically invoking
    * `DurationConverters.toScala`.
    */
  given Deserializer[FiniteDuration] with
    override def inputType: DataType = DayTimeIntervalType()

    override def deserialize(path: Expression): Expression =
      val javaDeserializer = summon[Deserializer[java.time.Duration]]
      val conversionArgs = List(javaDeserializer.deserialize(path))
      val resultType = ObjectType(classOf[FiniteDuration])
      StaticInvoke(
        DurationConverters.getClass,
        resultType,
        "toScala",
        conversionArgs,
        returnNullable = false
      )

  /** Deserializer for `java.time.Period`, read from a `YearMonthIntervalType`
    * column.
    */
  given Deserializer[java.time.Period] with
    override def deserialize(path: Expression): Expression =
      val period = createDeserializerForPeriod(path)
      period
    override def inputType: DataType = YearMonthIntervalType()

  /** Deserializer for `String`, read from a `StringType` column. */
  given Deserializer[String] with
    override def deserialize(path: Expression): Expression =
      // The second argument is passed as `false`, exactly as before.
      val string = createDeserializerForString(path, false)
      string
    override def inputType: DataType = StringType

  /** Deserializer for Scala's `BigDecimal` (`scala.math.BigDecimal`), read
    * from a `DecimalType.SYSTEM_DEFAULT` column.
    *
    * The Scala-specific helper is required here: the Java variant produces a
    * `java.math.BigDecimal`, which is not the type this instance is declared
    * for.
    */
  given Deserializer[BigDecimal] with
    def inputType: DataType =
      DecimalType.SYSTEM_DEFAULT
    def deserialize(path: Expression): Expression =
      createDeserializerForScalaBigDecimal(path, returnNullable = false)

  /** Deserializer for `java.math.BigInteger`, read from a `DecimalType(38, 0)`
    * column.
    */
  given Deserializer[java.math.BigInteger] with
    override def deserialize(path: Expression): Expression =
      val bigInteger =
        createDeserializerForJavaBigInteger(path, returnNullable = false)
      bigInteger

    // Spelled out literally because `.BigIntDecimal` is private.
    override def inputType: DataType = DecimalType(38, 0)

  /** Deserializer for `scala.math.BigInt`, read from a `DecimalType(38, 0)`
    * column.
    */
  given Deserializer[scala.math.BigInt] with
    override def deserialize(path: Expression): Expression =
      val bigInt = createDeserializerForScalaBigInt(path)
      bigInt

    // Spelled out literally because `.BigIntDecimal` is private.
    override def inputType: DataType = DecimalType(38, 0)

  /** Deserializer for Scala 3 enum cases, read from a `StringType` column.
    *
    * The column is first deserialized as a `String`; the enum value is then
    * obtained by statically invoking `valueOf` on the enum's runtime class.
    */
  given[E <: Enum : ClassTag]: Deserializer[E] with
    override def inputType: DataType = StringType

    override def deserialize(path: Expression): Expression =
      val enumClass = summon[ClassTag[E]].runtimeClass
      val caseName = summon[Deserializer[String]].deserialize(path)
      StaticInvoke(
        enumClass,
        ObjectType(enumClass),
        "valueOf",
        List(caseName),
        returnNullable = false
      )

  /** Derives a deserializer for `Array[T]` from the element deserializer.
    *
    * The input is an `ArrayType` of the element's input type. Each element is
    * deserialized through `Helper.deserializerForWithNullSafetyAndUpcast`, and
    * the mapped array data is converted to a JVM array by invoking a
    * conversion method on it.
    *
    * The primitive conversions (`toIntArray`, `toLongArray`, ...) are only
    * selected when the element's runtime class is itself a JVM primitive.
    * Element types that share a primitive Catalyst type but are objects at
    * runtime (for example `java.lang.Integer` or `Option[Int]`, both
    * `IntegerType`) must use the generic `array` accessor; otherwise the
    * primitive array returned would not match the declared object array class.
    */
  inline given deriveArray[T](using
      d: Deserializer[T],
      ct: ClassTag[T]
  ): Deserializer[Array[T]] =
    // TODO: nullable. walked
    new Deserializer[Array[T]]:
      override def inputType: DataType = ArrayType(d.inputType)
      override def deserialize(path: Expression): Expression =
        val mapFunction: Expression => Expression = el =>
          Helper.deserializerForWithNullSafetyAndUpcast(
            el,
            d.inputType,
            nullable = true,
            WalkedTypePath(Nil),
            d.deserialize
          )
        val arrayClass = ObjectType(ct.newArray(0).getClass)
        val arrayData = UnresolvedMapObjects(mapFunction, path)

        val methodName =
          // Boxed or wrapped elements live in an object array, whatever their
          // Catalyst type is.
          if !ct.runtimeClass.isPrimitive then "array"
          else
            d.inputType match
              case IntegerType => "toIntArray"
              case LongType    => "toLongArray"
              case DoubleType  => "toDoubleArray"
              case FloatType   => "toFloatArray"
              case ShortType   => "toShortArray"
              case ByteType    => "toByteArray"
              case BooleanType => "toBooleanArray"
              // primitive class without a dedicated conversion
              case _ => "array"

        Invoke(arrayData, methodName, arrayClass, returnNullable = true)

  /** Derives a deserializer for any `F[T]` that is a subtype of `Seq[T]`.
    *
    * The input is an `ArrayType` of the element's input type. Elements are
    * deserialized through `Helper.deserializerForWithNullSafetyAndUpcast`, and
    * the result is collected with `classOf[Seq[T]]` as the requested
    * collection class.
    */
  inline given deriveSeq[F[_], T](using d: Deserializer[T], ct: ClassTag[T])(
      using F[T] <:< Seq[T]
  ): Deserializer[F[T]] =
    // TODO: Nullable
    new Deserializer[F[T]]:
      override def deserialize(path: Expression): Expression =
        // Applied to every element of the array found at `path`.
        def deserializeElement(item: Expression): Expression =
          Helper.deserializerForWithNullSafetyAndUpcast(
            item,
            d.inputType,
            nullable = true,
            WalkedTypePath(Nil),
            d.deserialize
          )
        val collectionClass = Some(classOf[Seq[T]])
        UnresolvedMapObjects(deserializeElement, path, collectionClass)

      override def inputType: DataType = ArrayType(d.inputType)

  /** Derives a deserializer for `Set[T]`.
    *
    * Reuses the `List` deserializer obtained from `deriveSeq` for both the
    * input type and the element mapping, then rebuilds the mapping expression
    * with `classOf[Set[T]]` as the requested collection class.
    */
  inline given derivedSet[T: Deserializer: ClassTag]: Deserializer[Set[T]] =
    val listDeserializer = deriveSeq[List, T]
    new Deserializer[Set[T]]:
      override def deserialize(path: Expression): Expression =
        // `deriveSeq` produces an `UnresolvedMapObjects`; only its target
        // collection class is swapped here.
        val listMapping =
          listDeserializer.deserialize(path).asInstanceOf[UnresolvedMapObjects]
        val setClass = Some(classOf[Set[T]])
        UnresolvedMapObjects(listMapping.function, listMapping.child, setClass)

      override def inputType: DataType = listDeserializer.inputType

  /** Derives a deserializer for `Map[K, V]` from the key and value
    * deserializers.
    *
    * The input is a `MapType` built from the key and value input types. Keys
    * and values are converted by their own deserializers, and the runtime
    * class of the `ClassTag` is passed as the collection class.
    */
  inline given derivedMap[K, V](using
      kd: Deserializer[K],
      vd: Deserializer[V],
      ct: ClassTag[Map[K, V]]
  ): Deserializer[Map[K, V]] =
    new Deserializer[Map[K, V]]:
      override def deserialize(path: Expression): Expression =
        val mapClass = ct.runtimeClass
        UnresolvedCatalystToExternalMap(
          path,
          keyExpr => kd.deserialize(keyExpr),
          valueExpr => vd.deserialize(valueExpr),
          mapClass
        )

      override def inputType: DataType = MapType(kd.inputType, vd.inputType)

// inspired by https://github.com/apache/spark/blob/39542bb81f8570219770bb6533c077f44f6cbd2a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala#L356-L390
  /** Derives a deserializer for a product type (case class or tuple).
    *
    * One deserializer is summoned per element of the mirror. The input type is
    * a `StructType` with one field per element, carrying that element's input
    * type and nullability. Deserialization reads each field from `path`,
    * wraps it with null safety, and passes the results to the constructor of
    * the runtime class; a null `path` yields a null object.
    */
  inline given derivedProduct[T](using
      mirror: Mirror.ProductOf[T],
      classTag: ClassTag[T]
  ): Deserializer[T] =
    val elems = summonAll[mirror.MirroredElemLabels, mirror.MirroredElemTypes]
    lazy val fields =
      for (fieldName, fieldDeserializer) <- elems
      yield StructField(
        fieldName,
        fieldDeserializer.inputType,
        fieldDeserializer.nullable
      )
    def cls = classTag.runtimeClass
    def isTuple = cls.getName.startsWith("scala.Tuple")
    val rootTypePath = WalkedTypePath()

    new Deserializer[T]:
      override def deserialize(path: Expression): Expression =
        // code is partly taken from `DeserializerBuildHelper.createDeserializer`: `case ProductEncoder`
        val constructorArgs =
          for ((fieldName, fieldDeserializer), ordinal) <- elems.zipWithIndex
          yield
            val fieldTypePath = rootTypePath.recordField(cls.getName, fieldName)
            val fieldType = fieldDeserializer.inputType
            // Tuple fields are addressed by position, all others by name.
            val fieldGetter =
              if isTuple then
                addToPathOrdinal(path, ordinal, fieldType, fieldTypePath)
              else addToPath(path, fieldName, fieldType, fieldTypePath)
            expressionWithNullSafety(
              fieldDeserializer.deserialize(fieldGetter),
              fieldDeserializer.nullable,
              fieldTypePath
            )

        val construct =
          NewInstance(cls, constructorArgs, ObjectType(cls), propagateNull = false)
        val nullObject = Literal.create(null, ObjectType(cls))
        If(IsNull(path), nullObject, construct)

      override def inputType: DataType = StructType(fields)

  /** Pairs each element label with the deserializer summoned for the
    * corresponding element type, entirely at compile time.
    *
    * @tparam T
    *   tuple of singleton label types; each is turned into a `String` with
    *   `constValue[...].toString`
    * @tparam U
    *   tuple of element types; a `Deserializer` is summoned inline for each
    * @return
    *   the `(label, deserializer)` pairs in tuple order; `Nil` when neither of
    *   the two patterns below matches (e.g. the tuples are empty)
    */
  private inline def summonAll[T <: Tuple, U <: Tuple]
      : List[(String, Deserializer[?])] =
    inline (erasedValue[T], erasedValue[U]) match
      // same bulk processing as in Serializer to prevent stackoverflow on summoning decoders for large case classes
      // First case: peel off 16 labels and 16 element types at once, then
      // recurse on the remaining tails `ts` / `us`.
      case _: (
              t1 *: t2 *: t3 *: t4 *: t5 *: t6 *: t7 *: t8 *: t9 *: t10 *:
                t11 *: t12 *: t13 *: t14 *: t15 *: t16 *: ts,
              u1 *: u2 *: u3 *: u4 *: u5 *: u6 *: u7 *: u8 *: u9 *: u10 *:
                u11 *: u12 *: u13 *: u14 *: u15 *: u16 *: us
          ) =>
        List(
          (constValue[t1].toString, summonInline[Deserializer[u1]]),
          (constValue[t2].toString, summonInline[Deserializer[u2]]),
          (constValue[t3].toString, summonInline[Deserializer[u3]]),
          (constValue[t4].toString, summonInline[Deserializer[u4]]),
          (constValue[t5].toString, summonInline[Deserializer[u5]]),
          (constValue[t6].toString, summonInline[Deserializer[u6]]),
          (constValue[t7].toString, summonInline[Deserializer[u7]]),
          (constValue[t8].toString, summonInline[Deserializer[u8]]),
          (constValue[t9].toString, summonInline[Deserializer[u9]]),
          (constValue[t10].toString, summonInline[Deserializer[u10]]),
          (constValue[t11].toString, summonInline[Deserializer[u11]]),
          (constValue[t12].toString, summonInline[Deserializer[u12]]),
          (constValue[t13].toString, summonInline[Deserializer[u13]]),
          (constValue[t14].toString, summonInline[Deserializer[u14]]),
          (constValue[t15].toString, summonInline[Deserializer[u15]]),
          (constValue[t16].toString, summonInline[Deserializer[u16]])
        )
          ::: summonAll[ts, us]
      // Fewer than 16 elements left: process one label/type pair per step.
      case _: ((t *: ts), (u *: us)) =>
        (constValue[t].toString, summonInline[Deserializer[u]]) :: summonAll[
          ts,
          us
        ]
      // Nothing left to process.
      case _ => Nil




© 2015 - 2024 Weber Informatics LLC | Privacy Policy