All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.encoders

import java.util.concurrent.ConcurrentMap

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}

import org.apache.spark.util.Utils
import org.apache.spark.sql.{AnalysisException, Encoder}
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection}
import org.apache.spark.sql.types.{StructField, ObjectType, StructType}

/**
 * A factory for constructing encoders that convert objects and primitives to and from the
 * internal row format using catalyst expressions and code generation.  By default, the
 * expressions used to retrieve values from an input row when producing an object will be created as
 * follows:
 *  - Classes will have their sub fields extracted by name using [[UnresolvedAttribute]] expressions
 *    and [[UnresolvedExtractValue]] expressions.
 *  - Tuples will have their subfields extracted by position using [[BoundReference]] expressions.
 *  - Primitives will have their values extracted from the first ordinal with a schema that defaults
 *    to the name `value`.
 */
object ExpressionEncoder {
  /**
   * Builds an unresolved encoder for the Scala type `T` by reflecting on its [[TypeTag]].
   *
   * The encoder is marked `flat` when the runtime class of `T` is not a [[Product]].  When the
   * reflected schema is not a [[StructType]], it is wrapped in a single column named `value`.
   */
  def apply[T : TypeTag](): ExpressionEncoder[T] = {
    // The TypeTag is not serializable, so it is reduced to a StructType and a ClassTag here.
    val tag = typeTag[T]
    val runtimeCls = tag.mirror.runtimeClass(tag.tpe)
    val isProduct = classOf[Product].isAssignableFrom(runtimeCls)

    // Ordinal 0 of the input row holds the object that is being encoded.
    val source = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
    val serializers = ScalaReflection.extractorsFor[T](source)
    val deserializer = ScalaReflection.constructorFor[T]

    val rowSchema = ScalaReflection.schemaFor[T] match {
      case ScalaReflection.Schema(struct: StructType, _) => struct
      case ScalaReflection.Schema(dataType, isNullable) =>
        new StructType().add("value", dataType, isNullable)
    }

    new ExpressionEncoder[T](
      schema = rowSchema,
      flat = !isProduct,
      toRowExpressions = serializers.flatten,
      fromRowExpression = deserializer,
      clsTag = ClassTag[T](runtimeCls))
  }

  // TODO: improve error message for java bean encoder.
  /**
   * Builds an unresolved, non-flat encoder for the Java bean class `beanClass`.
   *
   * The data type inferred for the bean is asserted to be a [[StructType]] and is used as the
   * schema of the encoder.
   */
  def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
    val (inferredType, _) = JavaTypeInference.inferDataType(beanClass)
    assert(inferredType.isInstanceOf[StructType])

    val serializers = JavaTypeInference.extractorsFor(beanClass)
    val deserializer = JavaTypeInference.constructorFor(beanClass)

    new ExpressionEncoder[T](
      schema = inferredType.asInstanceOf[StructType],
      flat = false,
      toRowExpressions = serializers.flatten,
      fromRowExpression = deserializer,
      clsTag = ClassTag[T](beanClass))
  }

  /**
   * Given a set of N encoders, constructs a new encoder that produces objects as items in an
   * N-tuple.  Note that these encoders should be unresolved so that information about
   * name/positional binding is preserved.
   *
   * @param encoders the component encoders, in tuple order.  There must be between 1 and 22 of
   *                 them, because the result is backed by the class `scala.Tuple<N>`.
   * @return a non-flat encoder whose schema has one field (`_1`, `_2`, ...) per input encoder.
   * @throws IllegalArgumentException if the number of encoders is not between 1 and 22.
   */
  def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
    // Fail fast with a readable message.  Without this check the reflective class lookup below
    // dies with a bare ClassNotFoundException for `scala.Tuple0` or `scala.Tuple23`.
    require(
      encoders.nonEmpty && encoders.size <= 22,
      s"Cannot create a tuple encoder from ${encoders.size} encoders: " +
        "Scala tuples support between 1 and 22 elements.")

    encoders.foreach(_.assertUnresolved())

    // A flat encoder contributes the type and nullability of its single column; any other
    // encoder is nested as a struct column that is always nullable.
    val schema = StructType(encoders.zipWithIndex.map {
      case (e, i) =>
        val (dataType, nullable) = if (e.flat) {
          e.schema.head.dataType -> e.schema.head.nullable
        } else {
          e.schema -> true
        }
        StructField(s"_${i + 1}", dataType, nullable)
    })

    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

    // Each component originally reads its input object from ordinal 0.  Inside the tuple encoder
    // the object at ordinal 0 is the tuple itself, so those references are redirected to the
    // matching `_N` accessor of the tuple.
    val toRowExpressions = encoders.map {
      case e if e.flat => e.toRowExpressions.head
      case other => CreateStruct(other.toRowExpressions)
    }.zipWithIndex.map { case (expr, index) =>
      expr.transformUp {
        case BoundReference(0, t, _) =>
          Invoke(
            BoundReference(0, ObjectType(cls), nullable = true),
            s"_${index + 1}",
            t)
      }
    }

    val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
      if (enc.flat) {
        // A flat component reads from exactly one column: the one at its tuple position.
        enc.fromRowExpression.transform {
          case b: BoundReference => b.copy(ordinal = index)
        }
      } else {
        // A non-flat component now lives inside the struct column at `index`, so by-name and
        // by-position lookups are rewritten into extractions from that struct.
        val input = BoundReference(index, enc.schema, nullable = true)
        enc.fromRowExpression.transformUp {
          case UnresolvedAttribute(nameParts) =>
            assert(nameParts.length == 1)
            UnresolvedExtractValue(input, Literal(nameParts.head))
          case BoundReference(ordinal, _, _) => GetStructField(input, ordinal)
        }
      }
    }

    val fromRowExpression =
      NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)

    new ExpressionEncoder[Any](
      schema,
      flat = false,
      toRowExpressions,
      fromRowExpression,
      ClassTag(cls))
  }

  /** Typed convenience overload that combines two encoders into an encoder for `(T1, T2)`. */
  def tuple[T1, T2](
      e1: ExpressionEncoder[T1],
      e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = {
    val components: Seq[ExpressionEncoder[_]] = Seq(e1, e2)
    tuple(components).asInstanceOf[ExpressionEncoder[(T1, T2)]]
  }

  /** Typed convenience overload that combines three encoders into one for `(T1, T2, T3)`. */
  def tuple[T1, T2, T3](
      e1: ExpressionEncoder[T1],
      e2: ExpressionEncoder[T2],
      e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] = {
    val components: Seq[ExpressionEncoder[_]] = Seq(e1, e2, e3)
    tuple(components).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
  }

  /** Typed convenience overload that combines four encoders into one for `(T1, T2, T3, T4)`. */
  def tuple[T1, T2, T3, T4](
      e1: ExpressionEncoder[T1],
      e2: ExpressionEncoder[T2],
      e3: ExpressionEncoder[T3],
      e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] = {
    val components: Seq[ExpressionEncoder[_]] = Seq(e1, e2, e3, e4)
    tuple(components).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
  }

  /** Typed convenience overload that combines five encoders into one for a 5-tuple. */
  def tuple[T1, T2, T3, T4, T5](
      e1: ExpressionEncoder[T1],
      e2: ExpressionEncoder[T2],
      e3: ExpressionEncoder[T3],
      e4: ExpressionEncoder[T4],
      e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = {
    val components: Seq[ExpressionEncoder[_]] = Seq(e1, e2, e3, e4, e5)
    tuple(components).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
  }
}

/**
 * A generic encoder for JVM objects.
 *
 * @param schema The schema after converting `T` to a Spark SQL row.
 * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
 *                           extract the values from a raw object into an [[InternalRow]].
 * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
 * @param clsTag A classtag for `T`.
 */
case class ExpressionEncoder[T](
    schema: StructType,
    flat: Boolean,
    toRowExpressions: Seq[Expression],
    fromRowExpression: Expression,
    clsTag: ClassTag[T])
  extends Encoder[T] {

  // Construction-time invariant: a flat encoder must have exactly one extraction expression.
  if (flat) require(toRowExpressions.size == 1)

  // Generated projection that evaluates `toRowExpressions`.  It is built on first use and is
  // marked transient, so it is not part of the serialized form of the encoder.
  @transient
  private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)

  // Single-slot row that `toRow` fills with the object to encode before projecting it.
  @transient
  private lazy val inputRow = new GenericMutableRow(1)

  // Generated projection that evaluates `fromRowExpression`; `fromRow` reads its column 0.
  @transient
  private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)

  /**
   * Returns this encoder resolved and bound against the attributes of its own schema, i.e. no
   * remapping of columns is performed.
   */
  def defaultBinding: ExpressionEncoder[T] = {
    val ownOutput = schema.toAttributes
    val resolved = resolve(ownOutput, OuterScopes.outerScopes)
    resolved.bind(ownOutput)
  }

  /**
   * Encodes `t` as a Spark SQL row.  Successive calls may hand back the same underlying
   * [[InternalRow]] instance, so callers that need to keep a result must copy it before calling
   * `toRow` again.
   *
   * Any exception raised while encoding is rethrown as a `RuntimeException` whose message
   * includes the tree string of every extraction expression.
   */
  def toRow(t: T): InternalRow = {
    try {
      inputRow(0) = t
      extractProjection(inputRow)
    } catch {
      case e: Exception =>
        val expressionTrees = toRowExpressions.map(_.treeString).mkString("\n")
        throw new RuntimeException(s"Error while encoding: $e\n$expressionTrees", e)
    }
  }

  /**
   * Builds an object of type `T` from the values in `row`.  The encoder must have been
   * `resolve`d and `bind`ed to a specific schema before this method is called.
   *
   * Any exception raised while decoding is rethrown as a `RuntimeException` whose message
   * includes the tree string of the construction expression.
   */
  def fromRow(row: InternalRow): T = {
    try {
      val constructed = constructProjection(row)
      constructed.get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
    } catch {
      case e: Exception =>
        val expressionTree = fromRowExpression.treeString
        throw new RuntimeException(s"Error while decoding: $e\n$expressionTree", e)
    }
  }

  /**
   * Resolving an encoder to a schema discards the information about whether a field is bound by
   * ordinal or by name.  This check fails if that has already happened, and is meant for places
   * where encoders are going to be composed later.
   *
   * An [[AttributeReference]] anywhere in the expressions counts as evidence of resolution,
   * except for one named `loopVar`.
   */
  def assertUnresolved(): Unit = {
    for (expression <- fromRowExpression +: toRowExpressions) {
      expression.foreach {
        case a: AttributeReference if a.name != "loopVar" =>
          sys.error(s"Unresolved encoder expected, but $a was found.")
        case _ =>
      }
    }
  }

  /**
   * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to
   * the given schema.
   *
   * @param schema the attributes that `fromRowExpression` should be resolved against.
   * @param outerScopes registry, keyed by declaring class name, of the outer instances needed
   *                    to construct member (inner) classes.
   * @throws AnalysisException if the number of fields referenced by ordinal does not line up
   *                           with the given schema, or if the outer instance required by an
   *                           inner class is not registered in `outerScopes`.
   */
  def resolve(
      schema: Seq[Attribute],
      outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
    // Reports a mismatch between the fields of `st` and the highest ordinal read from it.
    def fail(st: StructType, highestOrdinal: Int): Nothing = {
      val inputSchema = StructType.fromAttributes(schema).simpleString
      val targetSchema = this.schema.simpleString
      throw new AnalysisException(
        s"Try to map ${st.simpleString} to Tuple${highestOrdinal + 1}, " +
          "but failed as the number of fields does not line up.\n" +
          s" - Input schema: $inputSchema\n" +
          s" - Target schema: $targetSchema")
    }

    // Top-level check: if anything is bound by ordinal, the highest ordinal must address the
    // last attribute of the input schema.
    val boundOrdinals = fromRowExpression.collect { case b: BoundReference => b.ordinal }
    val highestBoundOrdinal = (-1 +: boundOrdinals).max
    if (highestBoundOrdinal >= 0 && highestBoundOrdinal != schema.length - 1) {
      fail(StructType.fromAttributes(schema), highestBoundOrdinal)
    }

    val unbound = fromRowExpression.transform {
      case b: BoundReference => schema(b.ordinal)
    }

    // Nested check: for every struct read through GetStructField, the highest field ordinal
    // must address the last field of that struct.
    val highestFieldPerStruct = scala.collection.mutable.HashMap.empty[Expression, Int]
    unbound.foreach {
      case field: GetStructField
          if highestFieldPerStruct.getOrElse(field.child, -1) < field.ordinal =>
        highestFieldPerStruct.update(field.child, field.ordinal)
      case _ =>
    }
    for ((struct, highestField) <- highestFieldPerStruct) {
      val structType = struct.dataType.asInstanceOf[StructType]
      if (highestField != structType.length - 1) {
        fail(structType, highestField)
      }
    }

    // Run the expression through the analyzer as a single-column projection over the schema.
    val projection = Project(Seq(Alias(unbound, "")()), LocalRelation(schema))
    val analyzed = SimpleAnalyzer.execute(projection)
    SimpleAnalyzer.checkAnalysis(analyzed)
    val optimized = SimplifyCasts(analyzed)

    // Unwrap the alias that was added above to get the resolved expression back.
    val resolvedFromRow = optimized.expressions.head.children.head

    // In order to construct instances of inner classes (for example those declared in a REPL
    // cell), we need an instance of the outer scope.  This rule substitutes those outer objects
    // into expressions that are missing them by looking up the name in the `outerScopes`
    // registry.
    val withOuterPointers = resolvedFromRow.transform {
      case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
        val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
        if (outer == null) {
          throw new AnalysisException(
            s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " +
              "to the scope that this class was defined in. " +
              "Try moving this class out of its parent class.")
        }

        n.copy(outerPointer = Some(Literal.fromObject(outer)))
    }

    copy(fromRowExpression = withOuterPointers)
  }

  /**
   * Returns a copy of this encoder whose object-construction expression has been bound to the
   * ordinals of the given schema.  `resolve` must be called before `bind`.
   */
  def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
    val boundFromRow = BindReferences.bindReference(fromRowExpression, schema)
    copy(fromRowExpression = boundFromRow)
  }

  /**
   * Returns a new encoder in which every bound input column of the construction expression has
   * had `delta` added to its ordinal.
   */
  def shift(delta: Int): ExpressionEncoder[T] = {
    val shiftedFromRow = fromRowExpression.transform {
      case ref: BoundReference => ref.copy(ordinal = ref.ordinal + delta)
    }
    copy(fromRowExpression = shiftedFromRow)
  }

  // Short textual suffixes for the inputs referenced by the extraction expressions: empty for
  // an unresolved attribute, `#<exprId>` for any other attribute and `[<ordinal>]` for a bound
  // reference.
  protected val attrs = for {
    expression <- toRowExpressions
    suffix <- expression.collect {
      case _: UnresolvedAttribute => ""
      case a: Attribute => s"#${a.exprId}"
      case b: BoundReference => s"[${b.ordinal}]"
    }
  } yield suffix

  // Comma-separated `name<suffix>: type` description, pairing schema fields with `attrs` by
  // position.
  protected val schemaString = {
    val described = schema.zip(attrs).map { case (field, suffix) =>
      s"${field.name}$suffix: ${field.dataType.simpleString}"
    }
    described.mkString(", ")
  }

  /** Renders the encoder as `class[...]`, with the schema description between the brackets. */
  override def toString: String = "class[" + schemaString + "]"
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy