All downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.spark.sql.expressions.UserDefinedFunction.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.expressions

import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

import org.apache.spark.annotation.Stable
import org.apache.spark.sql.{Column, Encoder}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike}
import org.apache.spark.sql.types.DataType

/**
 * A function supplied by the user that can be applied to [[Column]]s. Instances are obtained
 * through the `udf` functions in `functions`.
 *
 * For example:
 * {{{
 *   // A UDF that turns a numeric score into a true/false prediction.
 *   val predict = udf((score: Double) => score > 0.5)
 *
 *   // Select the prediction computed from the score column.
 *   df.select( predict(df("score")) )
 * }}}
 *
 * @since 1.3.0
 */
@Stable
sealed abstract class UserDefinedFunction extends UserDefinedFunctionLike {

  /**
   * Whether this UDF may produce a null result.
   *
   * @since 2.3.0
   */
  def nullable: Boolean

  /**
   * Whether this UDF is deterministic, meaning that identical input always yields identical
   * output.
   *
   * @since 2.3.0
   */
  def deterministic: Boolean

  /**
   * Builds a [[Column]] whose expression invokes this UDF on the supplied argument columns.
   *
   * @param exprs the columns passed to the UDF as arguments, in order
   * @return a column representing the invocation
   * @since 1.3.0
   */
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    // Only the underlying node of each argument column is handed to the invocation node.
    val argumentNodes = exprs.map(column => column.node)
    Column(InvokeInlineUserDefinedFunction(this, argumentNodes))
  }

  /**
   * Returns a UserDefinedFunction that carries the given name.
   *
   * @param name the name to assign
   * @since 2.3.0
   */
  def withName(name: String): UserDefinedFunction

  /**
   * Returns a UserDefinedFunction that is marked as non-nullable.
   *
   * @since 2.3.0
   */
  def asNonNullable(): UserDefinedFunction

  /**
   * Returns a UserDefinedFunction that is marked as nondeterministic.
   *
   * @since 2.3.0
   */
  def asNondeterministic(): UserDefinedFunction
}

/**
 * A [[UserDefinedFunction]] described by a function object together with its result data type
 * and optional encoders for its arguments and result.
 *
 * @param f the function object backing this UDF (its shape is not inspected here)
 * @param dataType the data type of the value the UDF produces
 * @param inputEncoders one optional encoder per argument; `None` where no encoder is known
 * @param outputEncoder optional encoder for the result
 * @param givenName an explicitly assigned name, if any
 * @param nullable whether the UDF may return null
 * @param deterministic whether the UDF is deterministic
 */
private[spark] case class SparkUserDefinedFunction(
    f: AnyRef,
    dataType: DataType,
    inputEncoders: Seq[Option[Encoder[_]]] = Nil,
    outputEncoder: Option[Encoder[_]] = None,
    givenName: Option[String] = None,
    nullable: Boolean = true,
    deterministic: Boolean = true)
    extends UserDefinedFunction {

  /** Returns a copy carrying `name`; a null `name` clears the assigned name. */
  override def withName(name: String): SparkUserDefinedFunction =
    copy(givenName = Option(name))

  /** Returns a non-nullable version of this UDF, reusing this instance if it already is one. */
  override def asNonNullable(): SparkUserDefinedFunction =
    if (nullable) copy(nullable = false) else this

  /** Returns a nondeterministic version of this UDF, reusing this instance if it already is one. */
  override def asNondeterministic(): SparkUserDefinedFunction =
    if (deterministic) copy(deterministic = false) else this

  /** The assigned name, or the generic label "UDF" when none was given. */
  override def name: String = givenName match {
    case Some(assignedName) => assignedName
    case None => "UDF"
  }
}

/** Factory methods that assemble a [[SparkUserDefinedFunction]] from different descriptions. */
object SparkUserDefinedFunction {

  /**
   * Creates a UDF from type tags. The result encoder is derived from `returnTypeTag` and also
   * supplies the data type and nullability.
   *
   * @param function the function object backing the UDF
   * @param returnTypeTag type tag of the result
   * @param inputTypeTags type tags of the arguments, in order
   */
  private[sql] def apply(
      function: AnyRef,
      returnTypeTag: TypeTag[_],
      inputTypeTags: TypeTag[_]*): SparkUserDefinedFunction = {
    val resultEncoder = ScalaReflection.encoderFor(returnTypeTag)
    // Best effort: an argument whose encoder cannot be derived is recorded as None.
    val argumentEncoders = inputTypeTags.map(tag => Try(ScalaReflection.encoderFor(tag)).toOption)
    SparkUserDefinedFunction(
      f = function,
      inputEncoders = argumentEncoders,
      dataType = resultEncoder.dataType,
      outputEncoder = Option(resultEncoder),
      nullable = resultEncoder.nullable)
  }

  /**
   * Creates a UDF from already-built encoders. The data type and nullability are taken from
   * `outputEncoder`.
   *
   * @param function the function object backing the UDF
   * @param inputEncoders encoders of the arguments, in order
   * @param outputEncoder encoder of the result
   */
  private[sql] def apply(
      function: AnyRef,
      inputEncoders: Seq[AgnosticEncoder[_]],
      outputEncoder: AgnosticEncoder[_]): SparkUserDefinedFunction = {
    val wrappedInputs: Seq[Option[Encoder[_]]] = inputEncoders.map(Option(_))
    SparkUserDefinedFunction(
      f = function,
      inputEncoders = wrappedInputs,
      outputEncoder = Option(outputEncoder),
      dataType = outputEncoder.dataType,
      nullable = outputEncoder.nullable)
  }

  /**
   * Creates a UDF with an explicit return type and no encoders: every argument slot is `None`
   * and no output encoder is set.
   *
   * @param function the function object backing the UDF
   * @param returnType the data type of the result
   * @param cardinality the number of arguments the UDF takes
   */
  private[sql] def apply(
      function: AnyRef,
      returnType: DataType,
      cardinality: Int): SparkUserDefinedFunction = {
    val unknownInputs: Seq[Option[Encoder[_]]] = Seq.fill(cardinality)(None)
    SparkUserDefinedFunction(
      f = function,
      dataType = returnType,
      inputEncoders = unknownInputs,
      outputEncoder = None)
  }
}

/**
 * A [[UserDefinedFunction]] that wraps an [[Aggregator]] together with the encoder for its
 * input.
 *
 * @param aggregator the wrapped aggregator
 * @param inputEncoder encoder for the aggregator's input type
 * @param givenName an explicitly assigned name, if any
 * @param nullable whether the function may return null
 * @param deterministic whether the function is deterministic
 */
private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
    aggregator: Aggregator[IN, BUF, OUT],
    inputEncoder: Encoder[IN],
    givenName: Option[String] = None,
    nullable: Boolean = true,
    deterministic: Boolean = true)
    extends UserDefinedFunction {

  /** Returns a copy carrying `name`; a null `name` clears the assigned name. */
  override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] =
    copy(givenName = Option(name))

  /** Returns a non-nullable version, reusing this instance if it already is one. */
  override def asNonNullable(): UserDefinedAggregator[IN, BUF, OUT] =
    if (nullable) copy(nullable = false) else this

  /** Returns a nondeterministic version, reusing this instance if it already is one. */
  override def asNondeterministic(): UserDefinedAggregator[IN, BUF, OUT] =
    if (deterministic) copy(deterministic = false) else this

  /** The assigned name, falling back to the aggregator's own name when none was given. */
  override def name: String = givenName match {
    case Some(assignedName) => assignedName
    case None => aggregator.name
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy