/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.lang.reflect.ParameterizedType

import org.apache.spark.annotation.Stable
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.internal.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.JavaTypeInference
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF}
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

/**
 * Functions for registering user-defined functions. Use `SparkSession.udf` to access this:
 *
 * {{{
 *   spark.udf
 * }}}
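 *
 * Closures can then be registered directly as SQL functions. A minimal hedged
 * sketch (`strLen` is a hypothetical function name):
 *
 * {{{
 *   spark.udf.register("strLen", (s: String) => s.length)
 *   spark.sql("SELECT strLen('Spark')").show()
 * }}}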
 *
 * @since 1.3.0
 */
@Stable
class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
  extends api.UDFRegistration
  with Logging {
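  /**
   * Registers a Python UDF as a temporary function, logging its metadata at
   * debug level. Invoked internally when a UDF is registered through PySpark.
   */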
  protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
    log.debug(
      s"""
        | Registering new PythonUDF:
        | name: $name
        | command: ${udf.func.command}
        | envVars: ${udf.func.envVars}
        | pythonIncludes: ${udf.func.pythonIncludes}
        | pythonExec: ${udf.func.pythonExec}
        | dataType: ${udf.dataType}
        | pythonEvalType: ${PythonEvalType.toString(udf.pythonEvalType)}
        | udfDeterministic: ${udf.udfDeterministic}
      """.stripMargin)

    functionRegistry.createOrReplaceTempFunction(name, udf.builder, "python_udf")
  }

  /**
   * Registers a user-defined aggregate function (UDAF).
   *
   * @param name the name of the UDAF.
   * @param udaf the UDAF to be registered.
   * @return the registered UDAF.
   * @since 1.5.0
   * @deprecated this method and the use of UserDefinedAggregateFunction are deprecated.
   *             Aggregator[IN, BUF, OUT] should now be registered as a UDF via the
   *             functions.udaf(agg) method.
   */
  @deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
    " via the functions.udaf(agg) method.", "3.0.0")
  def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
    def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf, udafName = Some(name))
    functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf")
    udaf
  }
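
  // The recommended replacement pattern, as a hedged sketch (`LongSum` is a
  // hypothetical Aggregator; the explicit-encoder overload of `functions.udaf`
  // is used to keep the example self-contained):
  //
  // {{{
  //   import org.apache.spark.sql.{Encoder, Encoders}
  //   import org.apache.spark.sql.expressions.Aggregator
  //   import org.apache.spark.sql.functions.udaf
  //
  //   object LongSum extends Aggregator[Long, Long, Long] {
  //     def zero: Long = 0L
  //     def reduce(buffer: Long, input: Long): Long = buffer + input
  //     def merge(b1: Long, b2: Long): Long = b1 + b2
  //     def finish(buffer: Long): Long = buffer
  //     def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  //     def outputEncoder: Encoder[Long] = Encoders.scalaLong
  //   }
  //
  //   spark.udf.register("long_sum", udaf(LongSum, Encoders.scalaLong))
  // }}}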

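  /**
   * Registers a `UserDefinedFunction` under the given name. When
   * `validateParameterCount` is set, calls with the wrong number of arguments
   * fail at analysis time with a `wrongNumArgsError`.
   */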
  override protected def register(
      name: String,
      udf: UserDefinedFunction,
      source: String,
      validateParameterCount: Boolean): UserDefinedFunction = {
    val named = udf.withName(name)
    val builder: Seq[Expression] => Expression = named match {
      case udaf: UserDefinedAggregator[_, _, _] =>
        ScalaAggregator(udaf, _)
      case udf: SparkUserDefinedFunction if validateParameterCount =>
        val expectedParameterCount = udf.inputEncoders.size
        children => {
          val actualParameterCount = children.length
          if (expectedParameterCount == actualParameterCount) {
            toScalaUDF(udf, children)
          } else {
            throw QueryCompilationErrors.wrongNumArgsError(
              name,
              expectedParameterCount.toString,
              actualParameterCount)
          }
        }
      case udf: SparkUserDefinedFunction =>
        toScalaUDF(udf, _)
    }
    functionRegistry.createOrReplaceTempFunction(name, builder, source)
    named
  }

  /**
   * Registers a Java UDAF class using reflection, for use from PySpark.
   *
   * @param name      the name of the UDAF.
   * @param className the fully qualified class name of the UDAF.
   */
  private[sql] def registerJavaUDAF(name: String, className: String): Unit = {
    try {
      val clazz = Utils.classForName[AnyRef](className)
      if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
        throw QueryCompilationErrors
          .classDoesNotImplementUserDefinedAggregateFunctionError(className)
      }
      val udaf = clazz.getConstructor().newInstance().asInstanceOf[UserDefinedAggregateFunction]
      register(name, udaf)
    } catch {
      case _: ClassNotFoundException =>
        throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className)
      case _: InstantiationException | _: IllegalArgumentException =>
        throw QueryCompilationErrors.classWithoutPublicNonArgumentConstructorError(className)
    }
  }
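
  // From the Python side, this is typically reached via a call such as
  // spark.udf.registerJavaUDAF("myAvg", "com.example.MyAverage"), where
  // "com.example.MyAverage" is a hypothetical class on the driver classpath.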

  // scalastyle:off line.size.limit
  /**
   * Registers a Java UDF class using reflection, for use from PySpark.
   *
   * @param name           the name of the UDF.
   * @param className      the fully qualified class name of the UDF.
   * @param returnDataType the return type of the UDF. If it is null, Spark will try to infer
   *                       it via reflection.
   */
  private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
    try {
      val clazz = Utils.classForName[AnyRef](className)
      val udfInterfaces = clazz.getGenericInterfaces
        .filter(_.isInstanceOf[ParameterizedType])
        .map(_.asInstanceOf[ParameterizedType])
        .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF"))
      if (udfInterfaces.length == 0) {
        throw QueryCompilationErrors.udfClassDoesNotImplementAnyUDFInterfaceError(className)
      } else if (udfInterfaces.length > 1) {
        throw QueryCompilationErrors.udfClassImplementMultiUDFInterfacesError(className)
      } else {
        try {
          val udf = clazz.getConstructor().newInstance()
          val udfReturnType = udfInterfaces(0).getActualTypeArguments.last
          var returnType = returnDataType
          if (returnType == null) {
            returnType = JavaTypeInference.inferDataType(udfReturnType)._1
          }

          udfInterfaces(0).getActualTypeArguments.length match {
            case 1 => register(name, udf.asInstanceOf[UDF0[_]], returnType)
            case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
            case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
            case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
            case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType)
            case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType)
            case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType)
            case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType)
            case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType)
            case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType)
            case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case n =>
              throw QueryCompilationErrors.udfClassWithTooManyTypeArgumentsError(n)
          }
        } catch {
          case _: InstantiationException | _: IllegalArgumentException =>
            throw QueryCompilationErrors.classWithoutPublicNonArgumentConstructorError(className)
        }
      }
    } catch {
      case _: ClassNotFoundException => throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className)
    }
  }
  // scalastyle:on line.size.limit
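
  // A hedged sketch of a UDF class this reflection path can load
  // ("com.example.StrLen" is a hypothetical name; the class needs a public
  // zero-argument constructor and must implement exactly one UDFn interface):
  //
  // {{{
  //   import org.apache.spark.sql.api.java.UDF1
  //
  //   class StrLen extends UDF1[String, Integer] {
  //     override def call(s: String): Integer = s.length
  //   }
  // }}}
  //
  // From PySpark it would then be registered with, e.g.,
  // spark.udf.registerJavaFunction("strLen", "com.example.StrLen", IntegerType()).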
}