All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.expressions.UserDefinedFunction.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.expressions

import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import com.google.protobuf.ByteString

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket}
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils}

/**
 * A user-defined function. To create one, use the `udf` functions in `functions`.
 *
 * As an example:
 * {{{
 *   // Define a UDF that returns true or false based on some numeric score.
 *   val predict = udf((score: Double) => score > 0.5)
 *
 *   // Projects a column that adds a prediction column based on the score column.
 *   df.select( predict(df("score")) )
 * }}}
 *
 * @since 3.4.0
 */
sealed abstract class UserDefinedFunction {
  // The class is sealed, so every concrete subclass has to be declared in this source file.

  /**
   * Returns true when the UDF can return a nullable value.
   *
   * @return
   *   `true` if the result of the UDF may be null, `false` otherwise.
   * @since 3.4.0
   */
  def nullable: Boolean

  /**
   * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the
   * same input.
   *
   * @return
   *   `true` if the UDF is deterministic, `false` otherwise.
   * @since 3.4.0
   */
  def deterministic: Boolean

  /**
   * Returns an expression that invokes the UDF, using the given arguments.
   *
   * @param exprs
   *   the columns handed to the UDF as its arguments.
   * @return
   *   a [[Column]] representing the invocation of this UDF on `exprs`.
   * @since 3.4.0
   */
  @scala.annotation.varargs
  def apply(exprs: Column*): Column

  /**
   * Updates UserDefinedFunction with a given name.
   *
   * @param name
   *   the name to associate with the UDF.
   * @return
   *   a UserDefinedFunction that carries the given name.
   * @since 3.4.0
   */
  def withName(name: String): UserDefinedFunction

  /**
   * Updates UserDefinedFunction to non-nullable.
   *
   * @return
   *   a UserDefinedFunction marked as non-nullable.
   * @since 3.4.0
   */
  def asNonNullable(): UserDefinedFunction

  /**
   * Updates UserDefinedFunction to nondeterministic.
   *
   * @return
   *   a UserDefinedFunction marked as nondeterministic.
   * @since 3.4.0
   */
  def asNondeterministic(): UserDefinedFunction
}

/**
 * Holder class for a scalar user-defined function and its input/output encoder(s).
 */
case class ScalarUserDefinedFunction private[sql] (
    // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class.
    serializedUdfPacket: Array[Byte],
    inputTypes: Seq[proto.DataType],
    outputType: proto.DataType,
    name: Option[String],
    override val nullable: Boolean,
    override val deterministic: Boolean)
    extends UserDefinedFunction {

  // Proto message describing the Scala UDF itself. It is built lazily, once per instance, from
  // the already-serialized packet plus the declared input/output types and nullability.
  private[this] lazy val udf: proto.ScalarScalaUDF = {
    val payload = ByteString.copyFrom(serializedUdfPacket)
    proto.ScalarScalaUDF
      .newBuilder()
      .setPayload(payload)
      // The input and output types travel next to the payload so that they can be read
      // without deserializing the UDF bytes.
      .addAllInputTypes(inputTypes.asJava)
      .setOutputType(outputType)
      .setNullable(nullable)
      .build()
  }

  /**
   * Builds a column whose expression is an inline user-defined function call: the scalar Scala
   * UDF of this instance applied to the expressions of `exprs`. The function name is only set
   * when `name` is defined.
   */
  @scala.annotation.varargs
  override def apply(exprs: Column*): Column = Column { builder =>
    val inlineUdf = builder.getCommonInlineUserDefinedFunctionBuilder
    inlineUdf.setDeterministic(deterministic)
    inlineUdf.setScalarScalaUdf(udf)
    inlineUdf.addAllArguments(exprs.map(_.expr).asJava)

    name.foreach { functionName =>
      inlineUdf.setFunctionName(functionName)
    }
  }

  /** Returns a copy of this UDF whose name is `Option(name)`. */
  override def withName(name: String): ScalarUserDefinedFunction = copy(name = Option(name))

  /** Returns a copy of this UDF with `nullable` set to false. */
  override def asNonNullable(): ScalarUserDefinedFunction = copy(nullable = false)

  /** Returns a copy of this UDF with `deterministic` set to false. */
  override def asNondeterministic(): ScalarUserDefinedFunction = copy(deterministic = false)

  /**
   * Converts this UDF into a standalone `CommonInlineUserDefinedFunction` message. No arguments
   * are attached; the function name is only set when `name` is defined.
   */
  def toProto: proto.CommonInlineUserDefinedFunction = {
    val inlineUdf = proto.CommonInlineUserDefinedFunction
      .newBuilder()
      .setDeterministic(deterministic)
      .setScalarScalaUdf(udf)

    name.foreach { functionName =>
      inlineUdf.setFunctionName(functionName)
    }
    inlineUdf.build()
  }
}

object ScalarUserDefinedFunction {
  // Fragment of the ClassCastException message that checkDeserializable treats as the sign of a
  // lambda that could not be restored by java serialization.
  private val LAMBDA_DESERIALIZATION_ERR_MSG: String =
    "cannot assign instance of java.lang.invoke.SerializedLambda to field"

  /**
   * Deserializes `bytes` with the context (or Spark) class loader purely as a self-test, so that
   * a UDF which cannot be read back is rejected when it is created.
   *
   * @param bytes
   *   the serialized UDF packet to verify.
   * @throws SparkException
   *   if deserialization fails with a non-fatal error. A dedicated message is used when the
   *   failure is a `ClassCastException` whose message contains
   *   `LAMBDA_DESERIALIZATION_ERR_MSG`.
   */
  private def checkDeserializable(bytes: Array[Byte]): Unit = {
    try {
      SparkSerDeUtils.deserialize(bytes, SparkClassUtils.getContextOrSparkClassLoader)
    } catch {
      // Throwable.getMessage may return null. Going through Option keeps the guard from raising
      // a NullPointerException (which would escape this catch and hide the real failure); a
      // ClassCastException without a message falls through to the NonFatal case instead.
      case e: ClassCastException
          if Option(e.getMessage).exists(_.contains(LAMBDA_DESERIALIZATION_ERR_MSG)) =>
        throw new SparkException(
          "UDF cannot be executed on a Spark cluster: it cannot be deserialized. " +
            "This is very likely to be caused by the lambda function (the UDF) having a " +
            "self-reference. This is not supported by java serialization.")
      case NonFatal(e) =>
        throw new SparkException(
          "UDF cannot be executed on a Spark cluster: it cannot be deserialized.",
          e)
    }
  }

  /**
   * Creates a UDF from a function and the type tags of its result and parameters. The encoders
   * are derived from the tags and the call is forwarded to the encoder-based overload.
   *
   * @param function
   *   the function object to wrap.
   * @param returnType
   *   type tag of the function result.
   * @param parameterTypes
   *   type tags of the function parameters.
   */
  private[sql] def apply(
      function: AnyRef,
      returnType: TypeTag[_],
      parameterTypes: TypeTag[_]*): ScalarUserDefinedFunction = {

    ScalarUserDefinedFunction(
      function = function,
      // Input can be a row because the input data schema can be found from the plan.
      inputEncoders =
        parameterTypes.map(tag => ScalaReflection.encoderForWithRowEncoderSupport(tag)),
      // Output cannot be a row as there is no good way to get the return data type.
      outputEncoder = ScalaReflection.encoderFor(returnType))
  }

  /**
   * Creates a UDF from a function and explicit encoders. The function and encoders are
   * serialized eagerly as a `UdfPacket`, and the resulting bytes are checked to be
   * deserializable before the UDF is built. The new UDF has no name and is nullable and
   * deterministic.
   *
   * @param function
   *   the function object to wrap.
   * @param inputEncoders
   *   encoders for the function inputs; their data types become the UDF input types.
   * @param outputEncoder
   *   encoder for the function result; its data type becomes the UDF output type.
   * @throws SparkException
   *   if the serialized packet cannot be deserialized again.
   */
  private[sql] def apply(
      function: AnyRef,
      inputEncoders: Seq[AgnosticEncoder[_]],
      outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = {
    val udfPacketBytes =
      SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
    checkDeserializable(udfPacketBytes)
    ScalarUserDefinedFunction(
      serializedUdfPacket = udfPacketBytes,
      inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType),
      outputType = DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType),
      name = None,
      nullable = true,
      deterministic = true)
  }

  /**
   * Creates a UDF from a function and an explicit return data type. No input encoders are
   * supplied; the output encoder is obtained from `RowEncoder.encoderForDataType` in non-lenient
   * mode.
   *
   * @param function
   *   the function object to wrap.
   * @param returnType
   *   the data type of the function result.
   */
  private[sql] def apply(function: AnyRef, returnType: DataType): ScalarUserDefinedFunction = {
    ScalarUserDefinedFunction(
      function = function,
      inputEncoders = Seq.empty[AgnosticEncoder[_]],
      outputEncoder = RowEncoder.encoderForDataType(returnType, lenient = false))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy