
// io.projectglow.sql.expressions.glueExpressions.scala (Maven / Gradle / Ivy)
/*
* Copyright 2019 The Glow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.projectglow.sql.expressions
import io.projectglow.sql.util.{Rewrite, RewriteAfterResolution}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{UnresolvedException, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, ExpectsInputTypes, Expression, Generator, GenericInternalRow, GetStructField, ImplicitCastInputTypes, Literal, NamedExpression, UnaryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
/**
 * Placeholder expression standing for all the fields of a struct, which may be unnamed.
 *
 * The expression always reports itself as unresolved and cannot be evaluated: asking for its
 * `dataType` or `nullable` throws an [[UnresolvedException]]. Call [[expand]] to obtain the
 * per-field expressions that take its place.
 *
 * @param struct The struct-typed expression whose fields should be expanded.
 */
case class ExpandStruct(struct: Expression) extends Expression with Unevaluable {

  override def children: Seq[Expression] = Seq(struct)

  // Never considered resolved; this node is only meaningful once it has been expanded.
  override lazy val resolved: Boolean = false

  override def dataType: DataType = throw new UnresolvedException(this, "dataType")

  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")

  /**
   * Produces one aliased field accessor per field of the wrapped struct, in schema order. Each
   * alias carries the name of the field it extracts.
   *
   * Throws an analysis exception if `struct` is not of struct type.
   *
   * @return The expanded field expressions.
   */
  def expand(): Seq[NamedExpression] = struct.dataType match {
    case schema: StructType =>
      schema.indices.map { ordinal =>
        Alias(GetStructField(struct, ordinal), schema(ordinal).name)()
      }
    case _ =>
      throw SQLUtils.newAnalysisException("Only structs can be expanded.")
  }
}
/**
 * Expression that builds a new struct out of selected fields of an existing struct.
 *
 * The rewritten form is a [[CreateNamedStruct]] whose arguments alternate between each entry of
 * `fields` and an [[UnresolvedExtractValue]] that extracts that entry from `struct`.
 *
 * @param struct The struct to take fields from.
 * @param fields The expressions identifying the fields to keep.
 */
case class SubsetStruct(struct: Expression, fields: Seq[Expression]) extends Rewrite {

  override def children: Seq[Expression] = struct +: fields

  override def rewrite: Expression = {
    // Flatten (field, extraction) pairs into the alternating argument list.
    val structArgs = for {
      field <- fields
      arg <- Seq(field, UnresolvedExtractValue(struct, field))
    } yield arg
    CreateNamedStruct(structArgs)
  }
}
/**
 * Expression that adds fields to an existing struct.
 *
 * At optimization time, this expression is rewritten as the creation of a new struct with all the
 * fields of the existing struct as well as the new fields. See
 * [[io.projectglow.sql.optimizer.ReplaceExpressionsRule]] for more details.
 *
 * @param struct The struct to extend. Must be of struct type; otherwise [[rewrite]] throws an
 *               analysis exception.
 * @param newFields The arguments appended after the existing fields. They are passed through
 *                  unchanged to [[CreateNamedStruct]].
 */
case class AddStructFields(struct: Expression, newFields: Seq[Expression])
    extends RewriteAfterResolution {

  override def children: Seq[Expression] = struct +: newFields

  override def rewrite: Expression = struct.dataType match {
    case baseType: StructType =>
      // Re-emit every existing field as a (name literal, accessor) pair, preserving schema order.
      val baseFields = baseType.indices.flatMap { idx =>
        Seq(Literal(baseType(idx).name), GetStructField(struct, idx))
      }
      CreateNamedStruct(baseFields ++ newFields)
    case other =>
      // Report a non-struct input as an analysis error, as ExpandStruct does, instead of letting
      // a ClassCastException escape from a blind cast.
      throw SQLUtils.newAnalysisException(
        s"Fields can only be added to a struct, but the input has type ${other.catalogString}.")
  }
}
/**
 * Generator that explodes a matrix by row: every row of the input matrix is emitted as a
 * single-column output row holding an array of doubles.
 *
 * If the input expression is null or has 0 rows, the output will be empty.
 *
 * @param matrixExpr The matrix to explode. May be dense or sparse.
 */
case class ExplodeMatrix(matrixExpr: Expression)
    extends Generator
    with CodegenFallback
    with ExpectsInputTypes {

  private val matrixUdt = SQLUtils.newMatrixUDT()

  override def children: Seq[Expression] = Seq(matrixExpr)

  // One non-nullable column named "row" containing a non-null array of doubles.
  override def elementSchema: StructType = {
    val rowField =
      StructField("row", ArrayType(DoubleType, containsNull = false), nullable = false)
    StructType(Seq(rowField))
  }

  override def inputTypes = Seq(matrixUdt) // scalastyle:ignore

  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
    Option(matrixExpr.eval(input)) match {
      case None =>
        Iterator.empty
      case Some(serialized) =>
        // Convert to a dense row-major layout so that each row is a contiguous run of `values`.
        val dense = matrixUdt.deserialize(serialized).toDenseRowMajor
        new Iterator[InternalRow] {
          private var nextRow = 0

          override def hasNext: Boolean = nextRow < dense.numRows

          override def next(): InternalRow = {
            val offset = nextRow * dense.numCols
            val cells = Array.tabulate[Any](dense.numCols)(col => dense.values(offset + col))
            nextRow += 1
            new GenericInternalRow(Array[Any](new GenericArrayData(cells)))
          }
        }
    }
  }
}
/**
 * Expression that converts an array of doubles into a sparse vector.
 *
 * Both the interpreted path and the generated code delegate to
 * `ArrayToSparseVector.fromDoubleArray` in the companion object.
 *
 * @param child The expression producing the array; the declared input type is array of double.
 */
case class ArrayToSparseVector(child: Expression)
    extends UnaryExpression
    with ImplicitCastInputTypes {

  override def inputTypes: Seq[SQLUtils.ADT] = Seq(ArrayType(DoubleType))

  override def dataType: DataType = ArrayToSparseVector.vectorType

  override def nullSafeEval(input: Any): Any = ArrayToSparseVector.fromDoubleArray(input)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Emits a call to the companion's converter, referenced by its fully qualified name.
    def conversionCode(arrayRef: String): String =
      s"""
         |${ev.value} =
         |io.projectglow.sql.expressions.ArrayToSparseVector.fromDoubleArray($arrayRef);
         |""".stripMargin

    nullSafeCodeGen(ctx, ev, conversionCode)
  }
}
/** Conversion helper shared by the interpreted and code-generated paths of the expression. */
object ArrayToSparseVector {

  lazy val vectorType = SQLUtils.newVectorUDT()

  /**
   * Serializes the given array of doubles as a sparse vector.
   *
   * @param input The array to convert; it is cast to [[ArrayData]].
   * @return The vector in its serialized row form.
   */
  def fromDoubleArray(input: Any): InternalRow = {
    val doubles = input.asInstanceOf[ArrayData].toDoubleArray()
    val sparse = Vectors.dense(doubles).toSparse
    vectorType.serialize(sparse)
  }
}
/**
 * Expression that converts an array of doubles into a dense vector.
 *
 * Both the interpreted path and the generated code delegate to
 * `ArrayToDenseVector.fromDoubleArray` in the companion object.
 *
 * @param child The expression producing the array; the declared input type is array of double.
 */
case class ArrayToDenseVector(child: Expression)
    extends UnaryExpression
    with ImplicitCastInputTypes {

  override def inputTypes: Seq[SQLUtils.ADT] = Seq(ArrayType(DoubleType))

  override def dataType: DataType = ArrayToDenseVector.vectorType

  override def nullSafeEval(input: Any): Any = ArrayToDenseVector.fromDoubleArray(input)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Emits a call to the companion's converter, referenced by its fully qualified name.
    def conversionCode(arrayRef: String): String =
      s"""
         |${ev.value} =
         |io.projectglow.sql.expressions.ArrayToDenseVector.fromDoubleArray($arrayRef);
         |""".stripMargin

    nullSafeCodeGen(ctx, ev, conversionCode)
  }
}
/** Conversion helper shared by the interpreted and code-generated paths of the expression. */
object ArrayToDenseVector {

  private lazy val vectorType = SQLUtils.newVectorUDT()

  /**
   * Serializes the given array of doubles as a dense vector.
   *
   * @param input The array to convert; it is cast to [[ArrayData]].
   * @return The vector in its serialized row form.
   */
  def fromDoubleArray(input: Any): InternalRow = {
    val doubles = input.asInstanceOf[ArrayData].toDoubleArray()
    vectorType.serialize(Vectors.dense(doubles))
  }
}
/**
 * Expression that converts a vector into an array of doubles.
 *
 * Both the interpreted path and the generated code delegate to `VectorToArray.toDoubleArray` in
 * the companion object.
 *
 * @param child The expression producing the vector.
 */
case class VectorToArray(child: Expression)
    extends UnaryExpression
    with ImplicitCastInputTypes {

  override def inputTypes: Seq[SQLUtils.ADT] = Seq(VectorToArray.vectorType)

  override def dataType: DataType = ArrayType(DoubleType)

  override def nullSafeEval(input: Any): Any = VectorToArray.toDoubleArray(input)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Emits a call to the companion's converter, referenced by its fully qualified name.
    def conversionCode(vectorRef: String): String =
      s"""
         |${ev.value} =
         |io.projectglow.sql.expressions.VectorToArray.toDoubleArray($vectorRef);
         |""".stripMargin

    nullSafeCodeGen(ctx, ev, conversionCode)
  }
}
/** Conversion helper shared by the interpreted and code-generated paths of the expression. */
object VectorToArray {

  lazy val vectorType = SQLUtils.newVectorUDT()

  /**
   * Deserializes the given vector and returns its values as array data.
   *
   * @param input The serialized vector.
   * @return The vector's values, one array element per vector entry.
   */
  def toDoubleArray(input: Any): ArrayData = {
    val doubles = vectorType.deserialize(input).toArray
    new GenericArrayData(doubles)
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy