org.apache.spark.sql.catalyst.analysis.FunctionRegistry.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis
import java.util.Locale
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.xml._
import org.apache.spark.sql.types._
/**
 * A catalog of functions that an [[Analyzer]] consults when it resolves function calls.
 *
 * Note:
 * 1) Implementations should be thread-safe so that the registry can be accessed concurrently.
 * 2) The database name is always case-sensitive here; callers are responsible for formatting
 *    the database name w.r.t. the case-sensitive config.
 */
trait FunctionRegistry {

  /**
   * Registers `builder` under `name`, deriving the [[ExpressionInfo]] from the builder's
   * canonical class name and from the database and function name of the identifier.
   */
  final def registerFunction(name: FunctionIdentifier, builder: FunctionBuilder): Unit = {
    val builderClassName = builder.getClass.getCanonicalName
    val derivedInfo = new ExpressionInfo(builderClassName, name.database.orNull, name.funcName)
    registerFunction(name, derivedInfo, builder)
  }

  /** Registers `builder` under `name` together with the explicitly supplied `info`. */
  def registerFunction(
      name: FunctionIdentifier,
      info: ExpressionInfo,
      builder: FunctionBuilder): Unit

  /**
   * Create or replace a temporary function. The function is registered under an identifier
   * built from `name` alone, i.e. without a database.
   */
  final def createOrReplaceTempFunction(name: String, builder: FunctionBuilder): Unit = {
    val identifier = FunctionIdentifier(name)
    registerFunction(identifier, builder)
  }

  /** Builds the expression for the function `name` applied to `children`. */
  @throws[AnalysisException]("If function does not exist")
  def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression

  /** List all of the registered function names. */
  def listFunction(): Seq[FunctionIdentifier]

  /** Get the [[ExpressionInfo]] of the registered function by specified name, if any. */
  def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo]

  /** Get the builder of the registered function by specified name, if any. */
  def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder]

  /** Drop a function and return whether the function existed. */
  def dropFunction(name: FunctionIdentifier): Boolean

  /** Checks if a function with a given name exists, by looking up its [[ExpressionInfo]]. */
  def functionExists(name: FunctionIdentifier): Boolean = lookupFunction(name).nonEmpty

  /** Clear all registered functions. */
  def clear(): Unit

  /**
   * Create a copy of this registry with identical functions as this registry. The default
   * implementation does not support copying and throws `CloneNotSupportedException`.
   */
  override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
}
/**
 * A [[FunctionRegistry]] that keeps its entries in a mutable hash map. Every read or write of
 * the map happens inside a `synchronized` block on this instance.
 */
class SimpleFunctionRegistry extends FunctionRegistry {

  @GuardedBy("this")
  private val functionBuilders =
    new mutable.HashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]

  // Function names are always resolved case-insensitively (lower-cased with the ROOT locale);
  // the database part is left untouched because its case handling is up to the caller.
  private def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
    val loweredFuncName = name.funcName.toLowerCase(Locale.ROOT)
    FunctionIdentifier(loweredFuncName, name.database)
  }

  /** Stores `(info, builder)` under the normalized `name`, replacing any previous entry. */
  override def registerFunction(
      name: FunctionIdentifier,
      info: ExpressionInfo,
      builder: FunctionBuilder): Unit = synchronized {
    val key = normalizeFuncName(name)
    functionBuilders(key) = (info, builder)
  }

  /**
   * Looks up the builder registered for `name` and applies it to `children`.
   * Throws an [[AnalysisException]] when no such function is registered.
   */
  override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
    // Only the map access is done under the lock; the builder itself runs outside of it.
    val registeredBuilder = synchronized {
      functionBuilders.get(normalizeFuncName(name)) match {
        case Some((_, found)) => found
        case None => throw new AnalysisException(s"undefined function $name")
      }
    }
    registeredBuilder(children)
  }

  /** Returns a snapshot list of all registered (normalized) identifiers. */
  override def listFunction(): Seq[FunctionIdentifier] = synchronized {
    functionBuilders.iterator.map { case (identifier, _) => identifier }.toList
  }

  /** Returns the [[ExpressionInfo]] registered for `name`, if any. */
  override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = synchronized {
    val entry = functionBuilders.get(normalizeFuncName(name))
    entry.map { case (info, _) => info }
  }

  /** Returns the builder registered for `name`, if any. */
  override def lookupFunctionBuilder(
      name: FunctionIdentifier): Option[FunctionBuilder] = synchronized {
    val entry = functionBuilders.get(normalizeFuncName(name))
    entry.map { case (_, builder) => builder }
  }

  /** Removes the entry for `name` and reports whether one was present. */
  override def dropFunction(name: FunctionIdentifier): Boolean = synchronized {
    val removed = functionBuilders.remove(normalizeFuncName(name))
    removed.nonEmpty
  }

  /** Removes every registered function. */
  override def clear(): Unit = synchronized {
    functionBuilders.clear()
  }

  /** Creates a new registry and re-registers every entry of this registry into it. */
  override def clone(): SimpleFunctionRegistry = synchronized {
    val copy = new SimpleFunctionRegistry
    for ((identifier, (info, builder)) <- functionBuilders.iterator) {
      copy.registerFunction(identifier, info, builder)
    }
    copy
  }
}
/**
 * A trivial catalog that returns an error when a function is requested. Used for testing when all
 * functions are already filled in and the analyzer needs only to resolve attribute references.
 *
 * Every operation except `clone()` throws `UnsupportedOperationException`; `clone()` returns
 * this singleton itself.
 */
object EmptyFunctionRegistry extends FunctionRegistry {

  override def registerFunction(
      name: FunctionIdentifier,
      info: ExpressionInfo,
      builder: FunctionBuilder): Unit =
    throw new UnsupportedOperationException

  override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression =
    throw new UnsupportedOperationException

  override def listFunction(): Seq[FunctionIdentifier] =
    throw new UnsupportedOperationException

  override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] =
    throw new UnsupportedOperationException

  override def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] =
    throw new UnsupportedOperationException

  override def dropFunction(name: FunctionIdentifier): Boolean =
    throw new UnsupportedOperationException

  override def clear(): Unit =
    throw new UnsupportedOperationException

  // The registry holds no state, so the singleton serves as its own copy.
  override def clone(): FunctionRegistry = this
}
object FunctionRegistry {
/** A function builder turns the argument expressions of a call into the resulting expression. */
type FunctionBuilder = Seq[Expression] => Expression
// Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite
/**
 * Table of the built-in functions: maps each function name to its [[ExpressionInfo]] and the
 * [[FunctionBuilder]] used to construct the expression.
 *
 * Entries are produced by the helpers defined further down in this object (`expression`,
 * `expressionGeneratorOuter` and `castAlias`) plus `CreateStruct.registryEntry`. Several names
 * are deliberately registered for the same expression class and act as aliases of each other
 * (for example "ceil"/"ceiling", "pow"/"power", "sha"/"sha1").
 */
val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
// misc non-aggregate functions
expression[Abs]("abs"),
expression[Coalesce]("coalesce"),
expression[Explode]("explode"),
// The `*_outer` entries reuse the generator's builder and wrap its result in `GeneratorOuter`.
expressionGeneratorOuter[Explode]("explode_outer"),
expression[Greatest]("greatest"),
expression[If]("if"),
expression[Inline]("inline"),
expressionGeneratorOuter[Inline]("inline_outer"),
expression[IsNaN]("isnan"),
expression[IfNull]("ifnull"),
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
expression[NaNvl]("nanvl"),
expression[NullIf]("nullif"),
expression[Nvl]("nvl"),
expression[Nvl2]("nvl2"),
expression[PosExplode]("posexplode"),
expressionGeneratorOuter[PosExplode]("posexplode_outer"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[Stack]("stack"),
expression[CaseWhen]("when"),
// math functions
expression[Acos]("acos"),
expression[Asin]("asin"),
expression[Atan]("atan"),
expression[Atan2]("atan2"),
expression[Bin]("bin"),
expression[BRound]("bround"),
expression[Cbrt]("cbrt"),
expression[Ceil]("ceil"),
expression[Ceil]("ceiling"),
expression[Cos]("cos"),
expression[Cosh]("cosh"),
expression[Conv]("conv"),
expression[ToDegrees]("degrees"),
expression[EulerNumber]("e"),
expression[Exp]("exp"),
expression[Expm1]("expm1"),
expression[Floor]("floor"),
expression[Factorial]("factorial"),
expression[Hex]("hex"),
expression[Hypot]("hypot"),
expression[Logarithm]("log"),
expression[Log10]("log10"),
expression[Log1p]("log1p"),
expression[Log2]("log2"),
expression[Log]("ln"),
expression[Remainder]("mod"),
expression[UnaryMinus]("negative"),
expression[Pi]("pi"),
expression[Pmod]("pmod"),
expression[UnaryPositive]("positive"),
expression[Pow]("pow"),
expression[Pow]("power"),
expression[ToRadians]("radians"),
expression[Rint]("rint"),
expression[Round]("round"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[ShiftRightUnsigned]("shiftrightunsigned"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
expression[Sinh]("sinh"),
expression[StringToMap]("str_to_map"),
expression[Sqrt]("sqrt"),
expression[Tan]("tan"),
expression[Cot]("cot"),
expression[Tanh]("tanh"),
expression[Add]("+"),
expression[Subtract]("-"),
expression[Multiply]("*"),
expression[Divide]("/"),
expression[Remainder]("%"),
// aggregate functions
expression[HyperLogLogPlusPlus]("approx_count_distinct"),
expression[Average]("avg"),
expression[Corr]("corr"),
expression[Count]("count"),
expression[CovPopulation]("covar_pop"),
expression[CovSample]("covar_samp"),
expression[First]("first"),
expression[First]("first_value"),
expression[Kurtosis]("kurtosis"),
expression[Last]("last"),
expression[Last]("last_value"),
expression[Max]("max"),
expression[Average]("mean"),
expression[Min]("min"),
expression[Percentile]("percentile"),
expression[Skewness]("skewness"),
expression[ApproximatePercentile]("percentile_approx"),
expression[ApproximatePercentile]("approx_percentile"),
expression[StddevSamp]("std"),
expression[StddevSamp]("stddev"),
expression[StddevPop]("stddev_pop"),
expression[StddevSamp]("stddev_samp"),
expression[Sum]("sum"),
expression[VarianceSamp]("variance"),
expression[VariancePop]("var_pop"),
expression[VarianceSamp]("var_samp"),
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),
// string functions
expression[Ascii]("ascii"),
expression[Chr]("char"),
expression[Chr]("chr"),
expression[Base64]("base64"),
expression[BitLength]("bit_length"),
expression[Length]("char_length"),
expression[Length]("character_length"),
expression[ConcatWs]("concat_ws"),
expression[Decode]("decode"),
expression[Elt]("elt"),
expression[Encode]("encode"),
expression[FindInSet]("find_in_set"),
expression[FormatNumber]("format_number"),
expression[FormatString]("format_string"),
expression[GetJsonObject]("get_json_object"),
expression[InitCap]("initcap"),
expression[StringInstr]("instr"),
expression[Lower]("lcase"),
expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
expression[Like]("like"),
expression[Lower]("lower"),
expression[OctetLength]("octet_length"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
expression[JsonTuple]("json_tuple"),
expression[ParseUrl]("parse_url"),
expression[StringLocate]("position"),
expression[FormatString]("printf"),
expression[RegExpExtract]("regexp_extract"),
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
expression[Sentences]("sentences"),
expression[SoundEx]("soundex"),
expression[StringSpace]("space"),
expression[StringSplit]("split"),
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[Left]("left"),
expression[Right]("right"),
expression[SubstringIndex]("substring_index"),
expression[StringTranslate]("translate"),
expression[StringTrim]("trim"),
expression[Upper]("ucase"),
expression[UnBase64]("unbase64"),
expression[Unhex]("unhex"),
expression[Upper]("upper"),
expression[XPathList]("xpath"),
expression[XPathBoolean]("xpath_boolean"),
expression[XPathDouble]("xpath_double"),
expression[XPathDouble]("xpath_number"),
expression[XPathFloat]("xpath_float"),
expression[XPathInt]("xpath_int"),
expression[XPathLong]("xpath_long"),
expression[XPathShort]("xpath_short"),
expression[XPathString]("xpath_string"),
// datetime functions
expression[AddMonths]("add_months"),
expression[CurrentDate]("current_date"),
expression[CurrentTimestamp]("current_timestamp"),
expression[DateDiff]("datediff"),
expression[DateAdd]("date_add"),
expression[DateFormatClass]("date_format"),
expression[DateSub]("date_sub"),
expression[DayOfMonth]("day"),
expression[DayOfYear]("dayofyear"),
expression[DayOfMonth]("dayofmonth"),
expression[FromUnixTime]("from_unixtime"),
expression[FromUTCTimestamp]("from_utc_timestamp"),
expression[Hour]("hour"),
expression[LastDay]("last_day"),
expression[Minute]("minute"),
expression[Month]("month"),
expression[MonthsBetween]("months_between"),
expression[NextDay]("next_day"),
expression[CurrentTimestamp]("now"),
expression[Quarter]("quarter"),
expression[Second]("second"),
expression[ParseToTimestamp]("to_timestamp"),
expression[ParseToDate]("to_date"),
expression[ToUnixTimestamp]("to_unix_timestamp"),
expression[ToUTCTimestamp]("to_utc_timestamp"),
expression[TruncDate]("trunc"),
expression[TruncTimestamp]("date_trunc"),
expression[UnixTimestamp]("unix_timestamp"),
expression[DayOfWeek]("dayofweek"),
expression[WeekDay]("weekday"),
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
expression[TimeWindow]("window"),
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArraysOverlap]("arrays_overlap"),
expression[ArrayIntersect]("array_intersect"),
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
expression[ArrayExcept]("array_except"),
expression[ArrayUnion]("array_union"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),
expression[MapFromArrays]("map_from_arrays"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapFromEntries]("map_from_entries"),
expression[MapConcat]("map_concat"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
expression[ArraysZip]("arrays_zip"),
expression[SortArray]("sort_array"),
expression[Shuffle]("shuffle"),
expression[ArrayMin]("array_min"),
expression[ArrayMax]("array_max"),
expression[Reverse]("reverse"),
expression[Concat]("concat"),
expression[Flatten]("flatten"),
expression[Sequence]("sequence"),
expression[ArrayRepeat]("array_repeat"),
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[ArrayFilter]("filter"),
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
expression[ZipWith]("zip_with"),
// This entry is supplied ready-made by `CreateStruct` instead of being built by `expression`.
CreateStruct.registryEntry,
// misc functions
expression[AssertTrue]("assert_true"),
expression[Crc32]("crc32"),
expression[Md5]("md5"),
expression[Uuid]("uuid"),
expression[Murmur3Hash]("hash"),
expression[Sha1]("sha"),
expression[Sha1]("sha1"),
expression[Sha2]("sha2"),
expression[SparkPartitionID]("spark_partition_id"),
expression[InputFileName]("input_file_name"),
expression[InputFileBlockStart]("input_file_block_start"),
expression[InputFileBlockLength]("input_file_block_length"),
expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),
expression[CurrentDatabase]("current_database"),
expression[CallMethodViaReflection]("reflect"),
expression[CallMethodViaReflection]("java_method"),
// grouping sets
expression[Cube]("cube"),
expression[Rollup]("rollup"),
expression[Grouping]("grouping"),
expression[GroupingID]("grouping_id"),
// window functions
expression[Lead]("lead"),
expression[Lag]("lag"),
expression[RowNumber]("row_number"),
expression[CumeDist]("cume_dist"),
expression[NTile]("ntile"),
expression[Rank]("rank"),
expression[DenseRank]("dense_rank"),
expression[PercentRank]("percent_rank"),
// predicates
expression[And]("and"),
expression[In]("in"),
expression[Not]("not"),
expression[Or]("or"),
// comparison operators
expression[EqualNullSafe]("<=>"),
expression[EqualTo]("="),
expression[EqualTo]("=="),
expression[GreaterThan](">"),
expression[GreaterThanOrEqual](">="),
expression[LessThan]("<"),
expression[LessThanOrEqual]("<="),
expression[Not]("!"),
// bitwise
expression[BitwiseAnd]("&"),
expression[BitwiseNot]("~"),
expression[BitwiseOr]("|"),
expression[BitwiseXor]("^"),
// json
expression[StructsToJson]("to_json"),
expression[JsonToStructs]("from_json"),
expression[SchemaOfJson]("schema_of_json"),
// cast
expression[Cast]("cast"),
// Cast aliases (SPARK-16730)
// Each alias is a one-argument function that casts its argument to the given data type.
castAlias("boolean", BooleanType),
castAlias("tinyint", ByteType),
castAlias("smallint", ShortType),
castAlias("int", IntegerType),
castAlias("bigint", LongType),
castAlias("float", FloatType),
castAlias("double", DoubleType),
castAlias("decimal", DecimalType.USER_DEFAULT),
castAlias("date", DateType),
castAlias("timestamp", TimestampType),
castAlias("binary", BinaryType),
castAlias("string", StringType)
)
/**
 * A registry pre-populated with every entry of `expressions`, each registered under an
 * identifier built from the function name alone (no database).
 */
val builtin: SimpleFunctionRegistry = {
  val registry = new SimpleFunctionRegistry
  for ((name, (info, builder)) <- expressions) {
    registry.registerFunction(FunctionIdentifier(name), info, builder)
  }
  registry
}

/** The identifiers of all functions listed by `builtin`, as a set. */
val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet
/**
 * Builds the registry entry for expression class `T` under the function name `name`.
 * See usage above.
 *
 * The returned builder instantiates `T` reflectively from the call's argument expressions:
 *  - if `T` has a public constructor taking a single `Seq`, that constructor is always used;
 *  - otherwise the public constructor whose parameters are exactly as many `Expression`s as
 *    there are arguments is used.
 *
 * The builder throws an [[AnalysisException]] when no constructor fits the arguments or when
 * the chosen constructor fails.
 *
 * @param name the function name to register the expression under
 * @param tag  class tag used to obtain the runtime class of `T`
 * @return the pair `(name, (expression info, builder))`
 */
private def expression[T <: Expression](name: String)
    (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
  // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main
  // constructor and contains non-parameter `child` and should not be used as function builder.
  val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) {
    val all = tag.runtimeClass.getConstructors
    val maxNumArgs = all.map(_.getParameterCount).max
    all.filterNot(_.getParameterCount == maxNumArgs)
  } else {
    tag.runtimeClass.getConstructors
  }

  // Invokes `ctor` with `args` and converts any failure into an AnalysisException.
  def instantiate(ctor: java.lang.reflect.Constructor[_], args: Seq[AnyRef]): Expression = {
    Try(ctor.newInstance(args: _*).asInstanceOf[Expression]) match {
      case Success(e) => e
      case Failure(e) =>
        // A failure inside the constructor arrives wrapped in an invocation exception, whose
        // cause carries the meaningful message. Other reflective failures (e.g. an argument
        // type mismatch) have no cause, so fall back to the exception itself rather than
        // dereferencing a null cause.
        val rootError = Option(e.getCause).getOrElse(e)
        throw new AnalysisException(rootError.getMessage)
    }
  }

  // See if we can find a constructor that accepts Seq[Expression]
  val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]]))
  val builder = (expressions: Seq[Expression]) => {
    varargCtor match {
      case Some(ctor) =>
        // If there is a constructor that accepts Seq[Expression], use that one. The whole
        // argument list is passed as the single constructor argument.
        instantiate(ctor, Seq(expressions))
      case None =>
        // Otherwise, find a constructor that matches the number of arguments, and use that.
        val params = Seq.fill(expressions.size)(classOf[Expression])
        val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
          val validParametersCount = constructors
            .filter(_.getParameterTypes.forall(_ == classOf[Expression]))
            .map(_.getParameterCount).distinct.sorted
          val invalidArgumentsMsg = if (validParametersCount.isEmpty) {
            // No constructor takes only Expressions, so there is no arity to suggest; calling
            // `init`/`last` on the empty collection would throw instead of reporting the error.
            s"Invalid arguments for function $name"
          } else {
            val expectedNumberOfParameters = if (validParametersCount.length == 1) {
              validParametersCount.head.toString
            } else {
              validParametersCount.init.mkString("one of ", ", ", " and ") +
                validParametersCount.last
            }
            s"Invalid number of arguments for function $name. " +
              s"Expected: $expectedNumberOfParameters; Found: ${params.length}"
          }
          throw new AnalysisException(invalidArgumentsMsg)
        }
        instantiate(f, expressions)
    }
  }
  (name, (expressionInfo[T](name), builder))
}
/**
 * Creates a function registry lookup entry for cast aliases (SPARK-16730).
 * For example, with name "int" and dataType IntegerType, int(x) becomes an alias for
 * cast(x as IntegerType). The generated builder accepts exactly one argument and throws an
 * [[AnalysisException]] otherwise.
 * See usage above.
 */
private def castAlias(
    name: String,
    dataType: DataType): (String, (ExpressionInfo, FunctionBuilder)) = {
  val castBuilder: FunctionBuilder = {
    case Seq(child) => Cast(child, dataType)
    case _ => throw new AnalysisException(s"Function $name accepts only one argument")
  }
  val usage = "_FUNC_(expr) - Casts the value `expr` to the target data type `_FUNC_`."
  // The alias is described as a `Cast` expression with only the usage text filled in.
  val aliasInfo =
    new ExpressionInfo(classOf[Cast].getCanonicalName, null, name, usage, "", "", "", "")
  name -> (aliasInfo -> castBuilder)
}
/**
 * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name.
 *
 * The description is read from the `ExpressionDescription` annotation on T's runtime class when
 * one is present; otherwise only the class name and function name are recorded.
 */
private def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = {
  val exprClass = scala.reflect.classTag[T].runtimeClass
  val className = exprClass.getCanonicalName
  Option(exprClass.getAnnotation(classOf[ExpressionDescription])) match {
    case None =>
      new ExpressionInfo(className, name)
    case Some(description) if description.extended().nonEmpty =>
      // This exists for the backward compatibility with old `ExpressionDescription`s defining
      // the extended description in `extended()`.
      new ExpressionInfo(className, null, name, description.usage(), description.extended())
    case Some(description) =>
      new ExpressionInfo(
        className,
        null,
        name,
        description.usage(),
        description.arguments(),
        description.examples(),
        description.note(),
        description.since())
  }
}
/**
 * Builds the registry entry for the "outer" flavour of generator `T`: it reuses the info and
 * builder produced by `expression[T](name)` and wraps the built generator in `GeneratorOuter`.
 */
private def expressionGeneratorOuter[T <: Generator : ClassTag](name: String)
    : (String, (ExpressionInfo, FunctionBuilder)) = {
  val (info, innerBuilder) = expression[T](name)._2
  val wrapInOuter = (args: Seq[Expression]) => {
    val generator = innerBuilder(args).asInstanceOf[Generator]
    GeneratorOuter(generator)
  }
  name -> (info -> wrapInOuter)
}
}