org.apache.spark.sql.catalyst.expressions.collectionOperations.scala Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.util.Comparator
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData}
import org.apache.spark.sql.types._
/**
* Given an array or map, returns its size.
*/
case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
override def nullSafeEval(value: Any): Int = child.dataType match {
case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
case _: MapType => value.asInstanceOf[MapData].numElements()
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();")
}
}
/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
*/
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
def this(e: Expression) = this(e, Literal(true))
override def left: Expression = base
override def right: Expression = ascendingOrder
override def dataType: DataType = base.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
TypeCheckResult.TypeCheckSuccess
case ArrayType(dt, _) =>
TypeCheckResult.TypeCheckFailure(
s"$prettyName does not support sorting array of type ${dt.simpleString}")
case _ =>
TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
}
@transient
private lazy val lt: Comparator[Any] = {
val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}
new Comparator[Any]() {
override def compare(o1: Any, o2: Any): Int = {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
-1
} else if (o2 == null) {
1
} else {
ordering.compare(o1, o2)
}
}
}
}
@transient
private lazy val gt: Comparator[Any] = {
val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}
new Comparator[Any]() {
override def compare(o1: Any, o2: Any): Int = {
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
1
} else if (o2 == null) {
-1
} else {
-ordering.compare(o1, o2)
}
}
}
}
override def nullSafeEval(array: Any, ascending: Any): Any = {
val elementType = base.dataType.asInstanceOf[ArrayType].elementType
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementType != NullType) {
java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt)
}
new GenericArrayData(data.asInstanceOf[Array[Any]])
}
override def prettyName: String = "sort_array"
}
/**
* Checks if the array (left) has the element (right)
*/
case class ArrayContains(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = BooleanType
override def inputTypes: Seq[AbstractDataType] = right.dataType match {
case NullType => Seq()
case _ => left.dataType match {
case n @ ArrayType(element, _) => Seq(n, element)
case _ => Seq()
}
}
override def checkInputDataTypes(): TypeCheckResult = {
if (right.dataType == NullType) {
TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments")
} else if (!left.dataType.isInstanceOf[ArrayType]
|| left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) {
TypeCheckResult.TypeCheckFailure(
"Arguments must be an array followed by a value of same type as the array members")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
override def nullable: Boolean = {
left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull
}
override def nullSafeEval(arr: Any, value: Any): Any = {
var hasNull = false
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == null) {
hasNull = true
} else if (v == value) {
return true
}
)
if (hasNull) {
null
} else {
false
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (arr, value) => {
val i = ctx.freshName("i")
val getValue = ctx.getValue(arr, right.dataType, i)
s"""
for (int $i = 0; $i < $arr.numElements(); $i ++) {
if ($arr.isNullAt($i)) {
${ev.isNull} = true;
} else if (${ctx.genEqual(right.dataType, value, getValue)}) {
${ev.isNull} = false;
${ev.value} = true;
break;
}
}
"""
})
}
override def prettyName: String = "array_contains"
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy