org.apache.spark.sql.expressions.UserDefinedFunction.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.expressions
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
import org.apache.spark.sql.types.DataType
/**
 * A user-defined function. Instances are created with the `udf` functions in `functions`.
 *
 * For example:
 * {{{
 *   // Define a UDF that returns true or false based on some numeric score.
 *   val predict = udf((score: Double) => score > 0.5)
 *
 *   // Select a prediction column computed from the score column.
 *   df.select( predict(df("score")) )
 * }}}
 *
 * @param f the function object wrapped by this UDF
 * @param dataType the data type handed to the `ScalaUDF` expression built by `apply`
 * @param inputTypes the input data types when known; `None` is forwarded to `ScalaUDF` as `Nil`
 *
 * @since 1.3.0
 */
@InterfaceStability.Stable
case class UserDefinedFunction protected[sql] (
    f: AnyRef,
    dataType: DataType,
    inputTypes: Option[Seq[DataType]]) {

  // Settings that live outside the constructor; `copyAll()` carries them over to copies.
  private var _nameOption: Option[String] = None
  private var _nullable: Boolean = true
  private var _deterministic: Boolean = true

  // Kept as a `var` rather than a constructor parameter to preserve backward compatibility of
  // this case class.
  // TODO: revisit this case class in Spark 3.0, and narrow down the public surface.
  private[sql] var nullableTypes: Option[Seq[Boolean]] = None

  /**
   * Whether this UDF may return a nullable value.
   *
   * @since 2.3.0
   */
  def nullable: Boolean = _nullable

  /**
   * Whether this UDF is deterministic, i.e. yields the same output whenever it is given the same
   * input.
   *
   * @since 2.3.0
   */
  def deterministic: Boolean = _deterministic

  /**
   * Builds a column expression that invokes this UDF on the given argument columns.
   *
   * @param exprs the argument columns
   * @return a [[Column]] wrapping the resulting `ScalaUDF` expression
   *
   * @since 1.3.0
   */
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
    // and `nullableTypes` is always set.
    // Until then, derive the per-argument nullability from `f` on first use and remember it on
    // this instance so that later calls and copies reuse it.
    val argNullability = nullableTypes.getOrElse {
      val inferred = ScalaReflection.getParameterTypeNullability(f)
      nullableTypes = Some(inferred)
      inferred
    }

    // When input types were declared, there must be exactly one nullability flag per input.
    inputTypes.foreach { declared =>
      assert(declared.length == argNullability.length)
    }

    val children = exprs.map(_.expr)
    val udfExpr = ScalaUDF(
      f,
      dataType,
      children,
      argNullability,
      inputTypes.getOrElse(Nil),
      udfName = _nameOption,
      nullable = _nullable,
      udfDeterministic = _deterministic)
    Column(udfExpr)
  }

  /**
   * Copies this UDF together with the state that the generated `copy()` does not carry over
   * (name, nullability, determinism and `nullableTypes`).
   */
  private def copyAll(): UserDefinedFunction = {
    val duplicate = copy()
    duplicate._nameOption = _nameOption
    duplicate._nullable = _nullable
    duplicate._deterministic = _deterministic
    duplicate.nullableTypes = nullableTypes
    duplicate
  }

  /**
   * Returns a copy of this UserDefinedFunction carrying the given name.
   *
   * @param name the name to attach; a `null` name results in no name being set
   *
   * @since 2.3.0
   */
  def withName(name: String): UserDefinedFunction = {
    val renamed = copyAll()
    renamed._nameOption = Option(name)
    renamed
  }

  /**
   * Returns a non-nullable version of this UserDefinedFunction. If it is already non-nullable,
   * this instance itself is returned.
   *
   * @since 2.3.0
   */
  def asNonNullable(): UserDefinedFunction = {
    if (nullable) {
      val nonNullable = copyAll()
      nonNullable._nullable = false
      nonNullable
    } else {
      this
    }
  }

  /**
   * Returns a nondeterministic version of this UserDefinedFunction. If it is already
   * nondeterministic, this instance itself is returned.
   *
   * @since 2.3.0
   */
  def asNondeterministic(): UserDefinedFunction = {
    if (_deterministic) {
      val nondeterministic = copyAll()
      nondeterministic._deterministic = false
      nondeterministic
    } else {
      this
    }
  }
}
// We have to use a name different from `UserDefinedFunction` here, to avoid breaking the binary
// compatibility of the auto-generated UserDefinedFunction object.
private[sql] object SparkUserDefinedFunction {

  /**
   * Builds a [[UserDefinedFunction]] and populates its `nullableTypes` from the input schemas.
   *
   * @param f the function object to wrap
   * @param dataType the data type passed to the `UserDefinedFunction` constructor
   * @param inputSchemas one optional schema per input; if any entry is `None`, the resulting
   *                     UDF has no `inputTypes`
   * @return the new UDF with `nullableTypes` set (one flag per entry of `inputSchemas`)
   */
  def create(
      f: AnyRef,
      dataType: DataType,
      inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = {
    // Input types are only reported when every input schema is known.
    val declaredTypes: Option[Seq[DataType]] =
      if (inputSchemas.forall(_.isDefined)) {
        Some(inputSchemas.collect { case Some(schema) => schema.dataType })
      } else {
        None
      }

    // An input without a schema is treated as nullable.
    val nullability = inputSchemas.map(_.fold(true)(_.nullable))

    val created = new UserDefinedFunction(f, dataType, declaredTypes)
    created.nullableTypes = Some(nullability)
    created
  }
}