All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.spark.sql.internal.UdfToProtoUtils.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.internal

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import com.google.protobuf.ByteString

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter.toConnectProtoType
import org.apache.spark.sql.connect.common.UdfPacket
import org.apache.spark.sql.encoderFor
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils}

/**
 * Utility for converting a `UserDefinedFunction` into a Connect Protobuf message.
 */
private[sql] object UdfToProtoUtils {
  // Fragment of the ClassCastException message that `checkDeserializable` looks for to decide
  // whether to report the more specific "self-reference" error.
  private val LAMBDA_DESERIALIZATION_ERR_MSG: String =
    "cannot assign instance of java.lang.invoke.SerializedLambda to field"

  /**
   * Check that a serialized payload can be deserialized again, using the class loader returned
   * by `SparkClassUtils.getContextOrSparkClassLoader`. The deserialized value is discarded.
   *
   * @param bytes
   *   the serialized payload to round-trip.
   * @throws SparkException
   *   if deserialization fails with a non-fatal error. The original failure is always attached
   *   as the cause.
   */
  private def checkDeserializable(bytes: Array[Byte]): Unit = {
    try {
      SparkSerDeUtils.deserialize(bytes, SparkClassUtils.getContextOrSparkClassLoader)
    } catch {
      // `getMessage` may be null; wrap it in an Option so the guard itself can never throw a
      // NullPointerException and mask the real failure. A ClassCastException without a message
      // simply falls through to the generic branch below.
      case e: ClassCastException
          if Option(e.getMessage).exists(_.contains(LAMBDA_DESERIALIZATION_ERR_MSG)) =>
        throw new SparkException(
          "UDF cannot be executed on a Spark cluster: it cannot be deserialized. " +
            "This is very likely to be caused by the lambda function (the UDF) having a " +
            "self-reference. This is not supported by java serialization.",
          e)
      case NonFatal(e) =>
        throw new SparkException(
          "UDF cannot be executed on a Spark cluster: it cannot be deserialized.",
          e)
    }
  }

  /**
   * Clean the closure of `function`, wrap it together with its encoders in a [[UdfPacket]],
   * serialize the packet and return the bytes as a `ByteString`.
   *
   * @param function
   *   the function object to ship.
   * @param inputEncoders
   *   encoders for the function's inputs.
   * @param outputEncoder
   *   encoder for the function's result.
   * @return
   *   the serialized packet.
   * @throws SparkException
   *   if the serialized packet cannot be deserialized again (see `checkDeserializable`).
   */
  private[sql] def toUdfPacketBytes(
      function: AnyRef,
      inputEncoders: Seq[AgnosticEncoder[_]],
      outputEncoder: AgnosticEncoder[_]): ByteString = {
    ClosureCleaner.clean(function, cleanTransitively = true, mutable.Map.empty)
    val bytes = SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
    // Round-trip the bytes here so a payload that cannot be read back is reported immediately.
    checkDeserializable(bytes)
    ByteString.copyFrom(bytes)
  }

  /**
   * Convert a [[UserDefinedFunction]] to a [[proto.CommonInlineUserDefinedFunction]].
   *
   * @param udf
   *   the function to convert; either a [[SparkUserDefinedFunction]] or a
   *   [[UserDefinedAggregator]].
   * @param arguments
   *   expressions added as the arguments of the resulting message; empty by default.
   * @return
   *   the built message, carrying the serialized function, its input/output types, the
   *   determinism and nullability flags, and the function name when one was given.
   */
  private[sql] def toProto(
      udf: UserDefinedFunction,
      arguments: Seq[proto.Expression] = Nil): proto.CommonInlineUserDefinedFunction = {
    val invokeUdf = proto.CommonInlineUserDefinedFunction
      .newBuilder()
      .setDeterministic(udf.deterministic)
      .addAllArguments(arguments.asJava)
    // Builder for the scalar Scala UDF part of the message, obtained from `invokeUdf`.
    val protoUdf = invokeUdf.getScalarScalaUdfBuilder
      .setNullable(udf.nullable)
    udf match {
      case f: SparkUserDefinedFunction =>
        // Without an explicit output encoder, derive one from the declared return data type.
        val outputEncoder = f.outputEncoder
          .map(e => encoderFor(e))
          .getOrElse(RowEncoder.encoderForDataType(f.dataType, lenient = false))
        val inputEncoders = if (f.inputEncoders.forall(_.isEmpty)) {
          Nil // Java UDFs have no bindings for their inputs.
        } else {
          // NOTE: `e.get` throws if only some of the input encoders are defined.
          f.inputEncoders.map(e => encoderFor(e.get)) // TODO support Any and UnboundRow.
        }
        inputEncoders.foreach(e => protoUdf.addInputTypes(toConnectProtoType(e.dataType)))
        protoUdf
          .setPayload(toUdfPacketBytes(f.f, inputEncoders, outputEncoder))
          .setOutputType(toConnectProtoType(outputEncoder.dataType))
          .setAggregate(false)
        f.givenName.foreach(invokeUdf.setFunctionName)
      case f: UserDefinedAggregator[_, _, _] =>
        val outputEncoder = encoderFor(f.aggregator.outputEncoder)
        val inputEncoder = encoderFor(f.inputEncoder)
        // The aggregator object itself is the payload; it has exactly one input encoder.
        protoUdf
          .setPayload(toUdfPacketBytes(f.aggregator, inputEncoder :: Nil, outputEncoder))
          .addInputTypes(toConnectProtoType(inputEncoder.dataType))
          .setOutputType(toConnectProtoType(outputEncoder.dataType))
          .setAggregate(true)
        f.givenName.foreach(invokeUdf.setFunctionName)
    }
    invokeUdf.build()
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy