
org.apache.spark.sql.catalyst.expressions.higherOrderFunctions.scala Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.util.Comparator
import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods
/**
 * Placeholder for a lambda variable that has not been bound yet. It keeps the
 * [[LambdaFunction]] from being resolved prematurely; every attribute that requires
 * resolution information throws an [[UnresolvedException]].
 *
 * @param nameParts the (possibly multi-part) name written by the user
 */
case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
  extends LeafExpression with NamedExpression with Unevaluable {

  // Never resolved: binding replaces this node with a NamedLambdaVariable.
  override lazy val resolved = false

  /** Dot-joined name; a part that itself contains a dot is wrapped in backticks. */
  override def name: String = {
    val quotedParts = nameParts.map { part =>
      if (part.contains(".")) s"`$part`" else part
    }
    quotedParts.mkString(".")
  }

  override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
  override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier")
  override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
  override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")

  override def sql: String = name

  override def toString: String = s"lambda '$name"
}
/**
 * A lambda variable bound to a concrete name, data type and nullability. Its current value is
 * held in a mutable [[AtomicReference]]; the owning higher order function sets that value
 * before evaluating the lambda body, and [[eval]] simply reads it back.
 */
case class NamedLambdaVariable(
    name: String,
    dataType: DataType,
    nullable: Boolean,
    exprId: ExprId = NamedExpression.newExprId,
    value: AtomicReference[Any] = new AtomicReference())
  extends LeafExpression
  with NamedExpression
  with CodegenFallback {

  /** Lambda variables never carry a qualifier. */
  override def qualifier: Seq[String] = Seq.empty

  /** Returns the value most recently stored in [[value]]; the input row is not consulted. */
  override def eval(input: InternalRow): Any = value.get

  /** A copy with a fresh expression id and its own, empty value holder. */
  override def newInstance(): NamedExpression = {
    val freshId = NamedExpression.newExprId
    copy(exprId = freshId, value = new AtomicReference())
  }

  override def toAttribute: Attribute =
    AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, Seq.empty)

  override def simpleString(maxFields: Int): String =
    s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}"

  override def toString: String = s"lambda $name#${exprId.id}$typeSuffix"
}
/**
 * A lambda function together with its arguments. A lambda function can be hidden when a user
 * wants to process a completely independent expression in a [[HigherOrderFunction]]; the lambda
 * function and its variables are then only used for internal bookkeeping within the higher
 * order function.
 *
 * @param function  the lambda body
 * @param arguments the lambda's parameters
 * @param hidden    marks a lambda used only for internal bookkeeping (see above)
 */
case class LambdaFunction(
    function: Expression,
    arguments: Seq[NamedExpression],
    hidden: Boolean = false)
  extends Expression with CodegenFallback {

  /** True once every lambda argument is resolved, i.e. no longer an unresolved placeholder. */
  lazy val bound: Boolean = arguments.forall(_.resolved)

  override def children: Seq[Expression] = function +: arguments

  // Type and nullability are entirely those of the body.
  override def dataType: DataType = function.dataType
  override def nullable: Boolean = function.nullable

  /** Evaluating the lambda just evaluates its body against the given row. */
  override def eval(input: InternalRow): Any = function.eval(input)
}
object LambdaFunction {

  /** The identity lambda `id -> id`; used, for example, as the default finish of `aggregate`. */
  val identity: LambdaFunction = {
    val idVar = UnresolvedNamedLambdaVariable("id" :: Nil)
    LambdaFunction(function = idVar, arguments = idVar :: Nil)
  }
}
/**
 * A higher order function takes one or more (lambda) functions and applies them to some objects.
 * The function produces a number of variables which can be consumed by some lambda function.
 */
trait HigherOrderFunction extends Expression with ExpectsInputTypes {

  /** Arguments of the higher order function, i.e. its non-function inputs. */
  def arguments: Seq[Expression]

  /** Expected types of [[arguments]], position by position. */
  def argumentTypes: Seq[AbstractDataType]

  /** Functions applied by the higher order function. */
  def functions: Seq[Expression]

  /** Expected types of [[functions]], position by position. */
  def functionTypes: Seq[AbstractDataType]

  override def children: Seq[Expression] = arguments ++ functions

  override def inputTypes: Seq[AbstractDataType] = argumentTypes ++ functionTypes

  /** Nullable whenever any argument is nullable; subclasses may refine this. */
  override def nullable: Boolean = arguments.exists(_.nullable)

  /**
   * All arguments have been resolved. This means that the types and nullability of (most of) the
   * lambda function arguments are known, and that we can start binding the lambda functions.
   */
  lazy val argumentsResolved: Boolean = arguments.forall(_.resolved)

  /**
   * Checks the argument data types, returning `TypeCheckResult.success` if they are valid, or a
   * `TypeCheckResult` carrying an error message otherwise.
   * Note: it is not valid to call this method until `argumentsResolved == true`.
   */
  def checkArgumentDataTypes(): TypeCheckResult =
    ExpectsInputTypes.checkInputDataTypes(arguments, argumentTypes)

  /**
   * Resolved only when every argument is resolved and every function is a resolved lambda.
   * The functions are not inspected at all while an argument is still unresolved.
   */
  override lazy val resolved: Boolean =
    if (!argumentsResolved) {
      false
    } else {
      functions.forall {
        case lambda: LambdaFunction => lambda.resolved
        case _ => false
      }
    }

  /**
   * Binds the lambda functions to this [[HigherOrderFunction]] using `f`. The bind function
   * receives a potential lambda plus the (data type, nullable) pairs of its (partial) arguments
   * and turns it into a bound lambda function.
   */
  def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction

  /**
   * The lambda bodies, rewritten so that every lambda variable refers to the very instance held
   * in the lambda's argument list. This matters when the variables were instantiated separately,
   * e.g. during serialization, because values are passed through the argument instances.
   */
  @transient lazy val functionsForEval: Seq[Expression] = functions.map {
    case LambdaFunction(body, lambdaArgs, _) =>
      val argsById = lambdaArgs.map(arg => (arg.exprId, arg)).toMap
      body.transformUp {
        case v: NamedLambdaVariable if argsById.contains(v.exprId) => argsById(v.exprId)
      }
  }
}
/**
 * Trait for higher order functions that take exactly one argument and one function.
 *
 * Subclasses either override [[eval]] themselves or, when a null argument should simply yield
 * null, keep the default [[eval]] and implement [[nullSafeEval]].
 */
trait SimpleHigherOrderFunction extends HigherOrderFunction {

  def argument: Expression

  override def arguments: Seq[Expression] = argument :: Nil

  def argumentType: AbstractDataType

  // Declared without parentheses to match the parameterless `argumentTypes` it overrides;
  // mixing the empty-paren and parameterless forms in an override is rejected by Scala 3.
  override def argumentTypes: Seq[AbstractDataType] = argumentType :: Nil

  def function: Expression

  override def functions: Seq[Expression] = function :: Nil

  def functionType: AbstractDataType = AnyDataType

  override def functionTypes: Seq[AbstractDataType] = functionType :: Nil

  /** The single lambda body with its variables rebound for evaluation. */
  def functionForEval: Expression = functionsForEval.head

  /**
   * Called by the default [[eval]] with a non-null argument value. A subclass that keeps the
   * default null-propagating [[eval]] must override this method; doing so saves it from writing
   * its own null-check code.
   */
  protected def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any =
    // The message names this trait (it previously referred to a non-existent
    // "UnaryHigherOrderFunction").
    sys.error("SimpleHigherOrderFunction must override either eval or nullSafeEval")

  /**
   * Evaluates [[argument]]. A null argument yields null without invoking [[nullSafeEval]];
   * otherwise the result of [[nullSafeEval]] is returned.
   */
  override def eval(inputRow: InternalRow): Any = {
    val value = argument.eval(inputRow)
    if (value == null) {
      null
    } else {
      nullSafeEval(inputRow, value)
    }
  }
}
/** A [[SimpleHigherOrderFunction]] whose single argument must be an array. */
trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
  override def argumentType: AbstractDataType = {
    ArrayType
  }
}
/** A [[SimpleHigherOrderFunction]] whose single argument must be a map. */
trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
  override def argumentType: AbstractDataType = {
    MapType
  }
}
/**
 * Transforms elements in an array using the transform function. This is similar to
 * a `map` in functional programming. When the lambda takes two arguments, the second one
 * receives the element's position in the array.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Transforms elements in an array using the function.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), x -> x + 1);
[2,3,4]
> SELECT _FUNC_(array(1, 2, 3), (x, i) -> x + i);
[1,3,5]
""",
  since = "2.4.0")
case class ArrayTransform(
    argument: Expression,
    function: Expression)
  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

  override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)

  /** Binds the element variable and, for a two-argument lambda, a non-null integer index. */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
    val ArrayType(elementType, containsNull) = argument.dataType
    val elementArg = (elementType, containsNull)
    val lambdaArgs = function match {
      case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
        elementArg :: (IntegerType, false) :: Nil
      case _ =>
        elementArg :: Nil
    }
    copy(function = f(function, lambdaArgs))
  }

  @transient lazy val (elementVar, indexVar) = {
    val LambdaFunction(_, (elem: NamedLambdaVariable) +: rest, _) = function
    (elem, rest.headOption.map(_.asInstanceOf[NamedLambdaVariable]))
  }

  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val input = argumentValue.asInstanceOf[ArrayData]
    val lambda = functionForEval
    val length = input.numElements
    val transformed = new GenericArrayData(new Array[Any](length))
    for (idx <- 0 until length) {
      elementVar.value.set(input.get(idx, elementVar.dataType))
      indexVar.foreach(_.value.set(idx))
      transformed.update(idx, lambda.eval(inputRow))
    }
    transformed
  }

  override def prettyName: String = "transform"
}
/**
 * Sorts elements in an array using a comparator function. Without an explicit function,
 * [[ArraySort.defaultComparator]] is used, which sorts ascending and places nulls last.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """_FUNC_(expr, func) - Sorts the input array. If func is omitted, sort
in ascending order. The elements of the input array must be orderable. Null elements
will be placed at the end of the returned array. Since 3.0.0 this function also sorts
and returns the array based on the given comparator function. The comparator will
take two arguments representing two elements of the array.
It returns a negative integer, 0, or a positive integer as the first element is less than,
equal to, or greater than the second element. If the comparator function returns null,
the function will fail and raise an error.
""",
  examples = """
Examples:
> SELECT _FUNC_(array(5, 6, 1), (left, right) -> case when left < right then -1 when left > right then 1 else 0 end);
[1,5,6]
> SELECT _FUNC_(array('bc', 'ab', 'dc'), (left, right) -> case when left is null and right is null then 0 when left is null then -1 when right is null then 1 when left < right then 1 when left > right then -1 else 0 end);
["dc","bc","ab"]
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
["a","b","c","d",null]
""",
  since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraySort(
    argument: Expression,
    function: Expression)
  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

  def this(argument: Expression) = this(argument, ArraySort.defaultComparator)

  @transient lazy val elementType: DataType =
    argument.dataType.asInstanceOf[ArrayType].elementType

  override def dataType: ArrayType = argument.dataType.asInstanceOf[ArrayType]

  /** Requires an array of orderable elements and a comparator returning IntegerType. */
  override def checkInputDataTypes(): TypeCheckResult = {
    checkArgumentDataTypes() match {
      case TypeCheckResult.TypeCheckSuccess =>
        argument.dataType match {
          case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
            if (function.dataType == IntegerType) {
              TypeCheckResult.TypeCheckSuccess
            } else {
              TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " +
                "IntegerType")
            }
          case ArrayType(dt, _) =>
            val dtSimple = dt.catalogString
            TypeCheckResult.TypeCheckFailure(
              s"$prettyName does not support sorting array of type $dtSimple which is not " +
                "orderable")
          case _ =>
            TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
        }
      case failure => failure
    }
  }

  /** Both lambda arguments take the array's element type and nullability. */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySort = {
    val ArrayType(elementType, containsNull) = argument.dataType
    copy(function =
      f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil))
  }

  @transient lazy val LambdaFunction(_,
    Seq(firstElemVar: NamedLambdaVariable, secondElemVar: NamedLambdaVariable), _) = function

  /**
   * Builds a comparator that evaluates the bound lambda for each pair of elements. Only the
   * sign of the lambda's result matters.
   *
   * @throws IllegalArgumentException if the lambda evaluates to null. Without this check,
   *                                  `null.asInstanceOf[Int]` would silently be 0 ("equal").
   */
  def comparator(inputRow: InternalRow): Comparator[Any] = {
    val f = functionForEval
    (o1: Any, o2: Any) => {
      firstElemVar.value.set(o1)
      secondElemVar.value.set(o2)
      val cmp = f.eval(inputRow)
      if (cmp == null) {
        throw new IllegalArgumentException(
          s"The comparator returned null when comparing $o1 and $o2; it must return " +
            "a negative integer, 0 or a positive integer.")
      }
      cmp.asInstanceOf[Int]
    }
  }

  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val arr = argumentValue.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
    // An array of NullType holds only nulls, so there is nothing to order.
    if (elementType != NullType) {
      java.util.Arrays.sort(arr, comparator(inputRow))
    }
    new GenericArrayData(arr.asInstanceOf[Array[Any]])
  }

  override def prettyName: String = "array_sort"
}
object ArraySort {

  /**
   * Builds the default ordering expression: -1, 0 or 1 when `left` is less than, equal to or
   * greater than `right`. Null sorts after every non-null value, and two nulls are equal.
   */
  def comparator(left: Expression, right: Expression): Expression = {
    val (equal, greater, less) = (Literal(0), Literal(1), Literal(-1))
    val nonNullComparison =
      If(LessThan(left, right), less, If(GreaterThan(left, right), greater, equal))
    If(And(IsNull(left), IsNull(right)), equal,
      If(IsNull(left), greater,
        If(IsNull(right), less, nonNullComparison)))
  }

  /** The lambda `(left, right) -> comparator(left, right)`, used when no function is given. */
  val defaultComparator: LambdaFunction = {
    val leftVar = UnresolvedNamedLambdaVariable(Seq("left"))
    val rightVar = UnresolvedNamedLambdaVariable(Seq("right"))
    LambdaFunction(comparator(leftVar, rightVar), Seq(leftVar, rightVar))
  }
}
/**
 * Filters entries in a map using the provided function: an entry is kept when the predicate
 * evaluates to true for its (key, value) pair.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
  examples = """
Examples:
> SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
{1:0,3:-1}
""",
  since = "3.0.0")
case class MapFilter(
    argument: Expression,
    function: Expression)
  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

  override def dataType: DataType = argument.dataType

  override def functionType: AbstractDataType = BooleanType

  @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

  @transient lazy val (keyVar, valueVar) = {
    val lambdaArgs = function.asInstanceOf[LambdaFunction].arguments
    val key = lambdaArgs.head.asInstanceOf[NamedLambdaVariable]
    val value = lambdaArgs.tail.head.asInstanceOf[NamedLambdaVariable]
    (key, value)
  }

  /** The key variable is bound as non-null; the value keeps the map's value nullability. */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
  }

  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val map = argumentValue.asInstanceOf[MapData]
    val predicate = functionForEval
    val keptKeys = mutable.ListBuffer.empty[Any]
    val keptValues = mutable.ListBuffer.empty[Any]
    map.foreach(keyType, valueType, (entryKey, entryValue) => {
      keyVar.value.set(entryKey)
      valueVar.value.set(entryValue)
      // A null predicate result unboxes to false, so such entries are dropped.
      if (predicate.eval(inputRow).asInstanceOf[Boolean]) {
        keptKeys += entryKey
        keptValues += entryValue
      }
    })
    ArrayBasedMapData(keptKeys.toArray, keptValues.toArray)
  }

  override def prettyName: String = "map_filter"
}
/**
 * Filters the input array using the given lambda function. A two-argument lambda also
 * receives each element's position in the array.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Filters the input array using the given predicate.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
[1,3]
> SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i);
[2,3]
""",
  since = "2.4.0",
  note = """
The inner function may use the index argument since 3.0.0.
""")
case class ArrayFilter(
    argument: Expression,
    function: Expression)
  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

  override def dataType: DataType = argument.dataType

  override def functionType: AbstractDataType = BooleanType

  /** Binds the element variable and, for a two-argument lambda, a non-null integer index. */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
    val ArrayType(elementType, containsNull) = argument.dataType
    val elementArg = (elementType, containsNull)
    val lambdaArgs = function match {
      case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
        elementArg :: (IntegerType, false) :: Nil
      case _ =>
        elementArg :: Nil
    }
    copy(function = f(function, lambdaArgs))
  }

  @transient lazy val (elementVar, indexVar) = {
    val LambdaFunction(_, (elem: NamedLambdaVariable) +: rest, _) = function
    (elem, rest.headOption.map(_.asInstanceOf[NamedLambdaVariable]))
  }

  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val input = argumentValue.asInstanceOf[ArrayData]
    val predicate = functionForEval
    val kept = new mutable.ArrayBuffer[Any](input.numElements)
    for (idx <- 0 until input.numElements) {
      elementVar.value.set(input.get(idx, elementVar.dataType))
      indexVar.foreach(_.value.set(idx))
      // A null predicate result unboxes to false, so such elements are dropped.
      if (predicate.eval(inputRow).asInstanceOf[Boolean]) {
        kept += elementVar.value.get
      }
    }
    new GenericArrayData(kept)
  }

  override def prettyName: String = "filter"
}
/**
 * Tests whether a predicate holds for one or more elements in the array. With
 * `followThreeValuedLogic`, a null predicate result makes the answer null unless some element
 * satisfies the predicate; otherwise nulls count as false.
 */
@ExpressionDescription(usage =
  "_FUNC_(expr, pred) - Tests whether a predicate holds for one or more elements in the array.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0);
true
> SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 10);
false
> SELECT _FUNC_(array(1, null, 3), x -> x % 2 == 0);
NULL
""",
  since = "2.4.0")
case class ArrayExists(
    argument: Expression,
    function: Expression,
    followThreeValuedLogic: Boolean)
  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

  /** Reads the null-handling mode from the active SQL configuration. */
  def this(argument: Expression, function: Expression) = {
    this(
      argument,
      function,
      SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC))
  }

  // Keep the mode flag out of the string representation.
  override def stringArgs: Iterator[Any] = super.stringArgs.take(2)

  override def nullable: Boolean =
    super.nullable || (followThreeValuedLogic && function.nullable)

  override def dataType: DataType = BooleanType

  override def functionType: AbstractDataType = BooleanType

  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = {
    val ArrayType(elementType, containsNull) = argument.dataType
    copy(function = f(function, (elementType, containsNull) :: Nil))
  }

  @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function

  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val input = argumentValue.asInstanceOf[ArrayData]
    val predicate = functionForEval
    var exists = false
    var sawNull = false
    var idx = 0
    // Stop at the first element for which the predicate is true.
    while (!exists && idx < input.numElements) {
      elementVar.value.set(input.get(idx, elementVar.dataType))
      predicate.eval(inputRow) match {
        case null => sawNull = true
        case outcome => exists = outcome.asInstanceOf[Boolean]
      }
      idx += 1
    }
    if (exists) true
    else if (sawNull && followThreeValuedLogic) null
    else false
  }

  override def prettyName: String = "exists"
}
object ArrayExists {

  /** Creates an [[ArrayExists]] whose null-handling mode comes from the active SQL config. */
  def apply(argument: Expression, function: Expression): ArrayExists =
    new ArrayExists(argument, function)
}
/**
 * Tests whether a predicate holds for all elements in the array, using three-valued logic:
 * any false gives false, otherwise any null gives null, otherwise true.
 */
@ExpressionDescription(usage =
  "_FUNC_(expr, pred) - Tests whether a predicate holds for all elements in the array.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0);
false
> SELECT _FUNC_(array(2, 4, 8), x -> x % 2 == 0);
true
> SELECT _FUNC_(array(1, null, 3), x -> x % 2 == 0);
false
> SELECT _FUNC_(array(2, null, 8), x -> x % 2 == 0);
NULL
""",
  since = "3.0.0")
case class ArrayForAll(
    argument: Expression,
    function: Expression)
  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

  override def nullable: Boolean = super.nullable || function.nullable

  override def dataType: DataType = BooleanType

  override def functionType: AbstractDataType = BooleanType

  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayForAll = {
    val ArrayType(elementType, containsNull) = argument.dataType
    copy(function = f(function, (elementType, containsNull) :: Nil))
  }

  @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function

  /*
   * all non-null results true | any null result | answer
   *            no             |       no        | false
   *            no             |       yes       | false
   *            yes            |       no        | true
   *            yes            |       yes       | null
   */
  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val input = argumentValue.asInstanceOf[ArrayData]
    val predicate = functionForEval
    var allTrue = true
    var sawNull = false
    var idx = 0
    // Stop at the first element for which the predicate is false.
    while (allTrue && idx < input.numElements) {
      elementVar.value.set(input.get(idx, elementVar.dataType))
      predicate.eval(inputRow) match {
        case null => sawNull = true
        case outcome => allTrue = outcome.asInstanceOf[Boolean]
      }
      idx += 1
    }
    if (allTrue && sawNull) null else allTrue
  }

  override def prettyName: String = "forall"
}
/**
 * Applies a binary operator to a start value and all elements in the array, then maps the
 * final accumulator through `finish` (the identity lambda when omitted).
 */
@ExpressionDescription(
  usage =
"""
_FUNC_(expr, start, merge, finish) - Applies a binary operator to an initial state and all
elements in the array, and reduces this to a single state. The final state is converted
into the final result by applying a finish function.
""",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x);
6
> SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10);
60
""",
  since = "2.4.0")
case class ArrayAggregate(
    argument: Expression,
    zero: Expression,
    merge: Expression,
    finish: Expression)
  extends HigherOrderFunction with CodegenFallback {

  def this(argument: Expression, zero: Expression, merge: Expression) = {
    this(argument, zero, merge, LambdaFunction.identity)
  }

  override def arguments: Seq[Expression] = argument :: zero :: Nil

  override def argumentTypes: Seq[AbstractDataType] = ArrayType :: AnyDataType :: Nil

  override def functions: Seq[Expression] = merge :: finish :: Nil

  override def functionTypes: Seq[AbstractDataType] = zero.dataType :: AnyDataType :: Nil

  override def nullable: Boolean = argument.nullable || finish.nullable

  override def dataType: DataType = finish.dataType

  /** `merge` must produce the same type as `zero` (ignoring nullability). */
  override def checkInputDataTypes(): TypeCheckResult = checkArgumentDataTypes() match {
    case TypeCheckResult.TypeCheckSuccess
        if !DataType.equalsStructurally(zero.dataType, merge.dataType, ignoreNullability = true) =>
      TypeCheckResult.TypeCheckFailure(
        s"argument 3 requires ${zero.dataType.simpleString} type, " +
          s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
    case result => result
  }

  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = {
    val ArrayType(elementType, containsNull) = argument.dataType
    // Be very conservative with nullable: we cannot be sure the accumulator never evaluates to
    // null, so it is always bound as nullable.
    val acc = (zero.dataType, true)
    copy(
      merge = f(merge, acc :: (elementType, containsNull) :: Nil),
      finish = f(finish, acc :: Nil))
  }

  @transient lazy val LambdaFunction(_,
    Seq(accForMergeVar: NamedLambdaVariable, elementVar: NamedLambdaVariable), _) = merge

  @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish

  override def eval(input: InternalRow): Any = argument.eval(input) match {
    case null => null
    case value =>
      val arr = value.asInstanceOf[ArrayData]
      val Seq(mergeForEval, finishForEval) = functionsForEval
      accForMergeVar.value.set(zero.eval(input))
      // Fold left: each merge result becomes the accumulator for the next element.
      for (idx <- 0 until arr.numElements()) {
        elementVar.value.set(arr.get(idx, elementVar.dataType))
        accForMergeVar.value.set(mergeForEval.eval(input))
      }
      accForFinishVar.value.set(accForMergeVar.value.get)
      finishForEval.eval(input)
  }

  override def prettyName: String = "aggregate"
}
/**
 * Transforms the key of every map entry by applying the function to its (key, value) pair.
 * Returns a map with the transformed keys and the original values.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.",
  examples = """
Examples:
> SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1);
{2:1,3:2,4:3}
> SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
{2:1,4:2,6:3}
""",
  since = "3.0.0")
case class TransformKeys(
    argument: Expression,
    function: Expression)
  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

  @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

  override def dataType: MapType = MapType(function.dataType, valueType, valueContainsNull)

  /**
   * Validates the map argument first, as ArraySort, ArrayAggregate and MapZipWith do. A non-map
   * input is then reported as an argument type error before `function.dataType` (whose lambda
   * may still be unbound) is consulted. Afterwards the produced key type is checked for
   * validity as a map key.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    checkArgumentDataTypes() match {
      case TypeCheckResult.TypeCheckSuccess =>
        TypeUtils.checkForMapKeyType(function.dataType)
      case failure => failure
    }
  }

  /** The key variable is bound as non-null; the value keeps the map's value nullability. */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = {
    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
  }

  @transient lazy val LambdaFunction(
    _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

  private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)

  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val map = argumentValue.asInstanceOf[MapData]
    // Loop-invariant: fetch the entry arrays and the rebound lambda once.
    val keys = map.keyArray()
    val values = map.valueArray()
    val f = functionForEval
    val resultKeys = new GenericArrayData(new Array[Any](map.numElements))
    var i = 0
    while (i < map.numElements) {
      keyVar.value.set(keys.get(i, keyVar.dataType))
      valueVar.value.set(values.get(i, valueVar.dataType))
      resultKeys.update(i, f.eval(inputRow))
      i += 1
    }
    // NOTE(review): null and duplicate result keys are presumably handled by
    // ArrayBasedMapBuilder.from; verify against that class.
    mapBuilder.from(resultKeys, values)
  }

  override def prettyName: String = "transform_keys"
}
/**
 * Returns a map that applies the function to each (key, value) pair. The keys are kept as-is
 * and each value is replaced by the function's result.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Transforms values in the map using the function.",
  examples = """
Examples:
> SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1);
{1:2,2:3,3:4}
> SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
{1:2,2:4,3:6}
""",
  since = "3.0.0")
case class TransformValues(
    argument: Expression,
    function: Expression)
  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

  @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

  override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)

  /** The key variable is bound as non-null; the value keeps the map's value nullability. */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction)
    : TransformValues = {
    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
  }

  @transient lazy val LambdaFunction(
    _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val map = argumentValue.asInstanceOf[MapData]
    val transformed = new Array[Any](map.numElements)
    for (idx <- transformed.indices) {
      keyVar.value.set(map.keyArray().get(idx, keyVar.dataType))
      valueVar.value.set(map.valueArray().get(idx, valueVar.dataType))
      transformed(idx) = functionForEval.eval(inputRow)
    }
    new ArrayBasedMapData(map.keyArray(), new GenericArrayData(transformed))
  }

  override def prettyName: String = "transform_values"
}
/**
* Merges two given maps into a single map by applying function to the pair of values with
* the same key.
*
* Result keys appear in first-seen order: the left map's keys in order, followed by keys that
* occur only in the right map. A key missing from one side is passed to the function with a
* null value for that side. Within one map, only the first occurrence of a duplicated key is
* used.
*/
@ExpressionDescription(
usage =
"""
_FUNC_(map1, map2, function) - Merges two given maps into a single map by applying
function to the pair of values with the same key. For keys only presented in one map,
NULL will be passed as the value for the missing key. If an input map contains duplicated
keys, only the first entry of the duplicated key is passed into the lambda function.
""",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2));
{1:"ax",2:"by"}
""",
since = "3.0.0")
case class MapZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback {
// The single lambda body with its variables rebound for evaluation (see functionsForEval).
def functionForEval: Expression = functionsForEval.head
@transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType
@transient lazy val MapType(rightKeyType, rightValueType, rightValueContainsNull) = right.dataType
// Common key type of both maps, ignoring nullability. The `.get` assumes the key types are
// compatible; checkArgumentDataTypes verifies this with `sameType`.
// NOTE(review): presumably only forced after that check has passed — confirm.
@transient lazy val keyType =
TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(leftKeyType, rightKeyType).get
// Used to compare keys whose type lacks proper equals (brute-force path below).
@transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType)
override def arguments: Seq[Expression] = left :: right :: Nil
override def argumentTypes: Seq[AbstractDataType] = MapType :: MapType :: Nil
override def functions: Seq[Expression] = function :: Nil
override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)
// The key is bound as non-null. Both values are bound as nullable because either side may be
// missing for a given key.
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = {
val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true))
copy(function = f(function, arguments))
}
// Beyond the generic argument check, both key types must match and be orderable.
override def checkArgumentDataTypes(): TypeCheckResult = {
super.checkArgumentDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
if (leftKeyType.sameType(rightKeyType)) {
TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName")
} else {
TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " +
s"been two ${MapType.simpleString}s with compatible key types, but the key types are " +
s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].")
}
case failure => failure
}
}
override def checkInputDataTypes(): TypeCheckResult = checkArgumentDataTypes()
// Null if either map is null; the right map is not evaluated when the left one is null.
override def eval(input: InternalRow): Any = {
val value1 = left.eval(input)
if (value1 == null) {
null
} else {
val value2 = right.eval(input)
if (value2 == null) {
null
} else {
nullSafeEval(input, value1, value2)
}
}
}
@transient lazy val LambdaFunction(_, Seq(
keyVar: NamedLambdaVariable,
value1Var: NamedLambdaVariable,
value2Var: NamedLambdaVariable),
_) = function
/**
* The function accepts two key arrays and returns a collection of keys with indexes
* to value arrays. Indexes are represented as an array of two items (slot 0 for the left map,
* slot 1 for the right map). This is a small optimization leveraging mutability of arrays.
* It uses a hash-based lookup when the key type has proper equals, and a linear scan with
* `ordering.equiv` otherwise.
*/
@transient private lazy val getKeysWithValueIndexes:
(ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = {
if (TypeUtils.typeWithProperEquals(keyType)) {
getKeysWithIndexesFast
} else {
getKeysWithIndexesBruteForce
}
}
// Guards the number of distinct keys collected by the brute-force path.
// NOTE(review): only getKeysWithIndexesBruteForce calls this; the hash-based path has no
// such guard.
private def assertSizeOfArrayBuffer(size: Int): Unit = {
if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to zip maps with $size " +
s"unique keys due to exceeding the array size limit " +
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}
}
// Hash-based collection. The LinkedHashMap preserves first-insertion order, and an index slot
// is only filled once, so a duplicated key keeps its first position.
private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = {
val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]]
for((z, array) <- Array((0, keys1), (1, keys2))) {
var i = 0
while (i < array.numElements()) {
val key = array.get(i, keyType)
hashMap.get(key) match {
case Some(indexes) =>
if (indexes(z).isEmpty) {
indexes(z) = Some(i)
}
case None =>
val indexes = Array[Option[Int]](None, None)
indexes(z) = Some(i)
hashMap.put(key, indexes)
}
i += 1
}
}
hashMap
}
// Linear-scan collection, quadratic in the number of keys, for key types without proper
// equals. Keys are compared with `ordering.equiv` and kept in first-seen order.
private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = {
val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
for((z, array) <- Array((0, keys1), (1, keys2))) {
var i = 0
while (i < array.numElements()) {
val key = array.get(i, keyType)
var found = false
var j = 0
while (!found && j < arrayBuffer.size) {
val (bufferKey, indexes) = arrayBuffer(j)
if (ordering.equiv(bufferKey, key)) {
found = true
if(indexes(z).isEmpty) {
indexes(z) = Some(i)
}
}
j += 1
}
if (!found) {
assertSizeOfArrayBuffer(arrayBuffer.size)
val indexes = Array[Option[Int]](None, None)
indexes(z) = Some(i)
arrayBuffer += Tuple2(key, indexes)
}
i += 1
}
}
arrayBuffer
}
// Builds the result map. For each distinct key the lambda receives the key plus the value
// from each side, or null for a side that lacks the key.
private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = {
val mapData1 = value1.asInstanceOf[MapData]
val mapData2 = value2.asInstanceOf[MapData]
val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray())
val size = keysWithIndexes.size
val keys = new GenericArrayData(new Array[Any](size))
val values = new GenericArrayData(new Array[Any](size))
val valueData1 = mapData1.valueArray()
val valueData2 = mapData2.valueArray()
var i = 0
for ((key, Array(index1, index2)) <- keysWithIndexes) {
val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null)
val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null)
keyVar.value.set(key)
value1Var.value.set(v1)
value2Var.value.set(v2)
keys.update(i, key)
values.update(i, functionForEval.eval(inputRow))
i += 1
}
new ArrayBasedMapData(keys, values)
}
override def prettyName: String = "map_zip_with"
}
/**
 * Merges two arrays element-wise with a binary lambda. The result is as long as the longer
 * input; positions past the end of the shorter array see a null for that side.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(left, right, func) - Merges the two given arrays, element-wise, into a single array using function. If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying function.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x));
[{"y":"a","x":1},{"y":"b","x":2},{"y":"c","x":3}]
> SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y);
[4,6]
> SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y));
["ad","be","cf"]
""",
  since = "2.4.0")
// scalastyle:on line.size.limit
case class ZipWith(left: Expression, right: Expression, function: Expression)
  extends HigherOrderFunction with CodegenFallback {

  /** The single lambda body with its variables rebound for evaluation. */
  def functionForEval: Expression = functionsForEval.head

  override def arguments: Seq[Expression] = left :: right :: Nil

  override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil

  override def functions: Seq[Expression] = List(function)

  override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil

  override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)

  /** Both lambda arguments are bound as nullable because of the null padding. */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ZipWith = {
    val ArrayType(leftElementType, _) = left.dataType
    val ArrayType(rightElementType, _) = right.dataType
    copy(function = f(function,
      (leftElementType, true) :: (rightElementType, true) :: Nil))
  }

  @transient lazy val LambdaFunction(_,
    Seq(leftElemVar: NamedLambdaVariable, rightElemVar: NamedLambdaVariable), _) = function

  /** Null if either array is null; the right side is not evaluated when the left is null. */
  override def eval(input: InternalRow): Any = left.eval(input) match {
    case null => null
    case leftValue =>
      right.eval(input) match {
        case null => null
        case rightValue =>
          val leftArr = leftValue.asInstanceOf[ArrayData]
          val rightArr = rightValue.asInstanceOf[ArrayData]
          val leftLen = leftArr.numElements()
          val rightLen = rightArr.numElements()
          val resultLength = math.max(leftLen, rightLen)
          val f = functionForEval
          val zipped = new GenericArrayData(new Array[Any](resultLength))
          for (i <- 0 until resultLength) {
            val leftElem = if (i < leftLen) leftArr.get(i, leftElemVar.dataType) else null
            val rightElem = if (i < rightLen) rightArr.get(i, rightElemVar.dataType) else null
            leftElemVar.value.set(leftElem)
            rightElemVar.value.set(rightElem)
            zipped.update(i, f.eval(input))
          }
          zipped
      }
  }

  override def prettyName: String = "zip_with"
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy