All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.dimajix.spark.sql.expressions.CreateNullableStruct.scala Maven / Gradle / Ivy

There is a newer version: 1.2.0-synapse3.3-spark3.3-hadoop3.3
Show newest version
/*
 * Copyright (C) 2019 The Flowman Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.dimajix.spark.sql.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
import org.apache.spark.sql.catalyst.expressions.EmptyRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.ExpressionDescription
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.NamePlaceholder
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.Metadata
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType


/**
  * Returns a Row containing the evaluation of all children expressions.
  */
object CreateNullableStruct extends FunctionBuilder {
    def apply(children: Seq[Expression]): CreateNullableNamedStruct = {
        CreateNullableNamedStruct(children.zipWithIndex.flatMap {
            //case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e)
            case (e: NamedExpression, _) if e.name.nonEmpty => Seq(Literal(e.name), e)
            case (e: NamedExpression, _) => Seq(NamePlaceholder, e)
            case (e, index) => Seq(Literal(s"col${index + 1}"), e)
        })
    }
}

/**
  * Creates a struct with the given field names and values
  *
  * @param children Seq(name1, val1, name2, val2, ...)
  */
// scalastyle:off line.size.limit
@ExpressionDescription(
    usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
    examples = """
    Examples:
      > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
       {"a":1,"b":2,"c":3}
  """)
// scalastyle:on line.size.limit
case class CreateNullableNamedStruct(override val children: Seq[Expression]) extends Expression {
    lazy val (nameExprs, valExprs) = children.grouped(2).map {
        case Seq(name, value) => (name, value)
    }.toList.unzip

    lazy val names = nameExprs.map(_.eval(EmptyRow))

    override def nullable: Boolean = true

    override def foldable: Boolean = valExprs.forall(_.foldable)

    override lazy val dataType: StructType = {
        val fields = names.zip(valExprs).map {
            case (name, expr) =>
                val metadata = expr match {
                    case ne: NamedExpression => ne.metadata
                    case _ => Metadata.empty
                }
                StructField(name.toString, expr.dataType, expr.nullable, metadata)
        }
        StructType(fields)
    }

    override def checkInputDataTypes(): TypeCheckResult = {
        if (children.size % 2 != 0) {
            TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
        } else {
            val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
            if (invalidNames.nonEmpty) {
                TypeCheckResult.TypeCheckFailure(
                    s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" +
                        s" position, got: ${invalidNames.mkString(",")}")
            } else if (!names.contains(null)) {
                TypeCheckResult.TypeCheckSuccess
            } else {
                TypeCheckResult.TypeCheckFailure("Field name should not be null")
            }
        }
    }

    /**
     * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this
     * StructType.
     */
    def flatten: Seq[NamedExpression] = valExprs.zip(names).map {
        case (v, n) => Alias(v, n.toString)()
    }

    override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
        val rowClass = classOf[GenericInternalRow].getName
        val values = ctx.freshName("values")
        val nonnull = ctx.freshName("nonnull")
        val evals = valExprs.zipWithIndex.map { case (e, i) =>
            val eval = e.genCode(ctx)
            s"""
               |${eval.code}
               |if (${eval.isNull}) {
               |  $values[$i] = null;
               |} else {
               |  $values[$i] = ${eval.value};
               |  $nonnull = true;
               |}
       """.stripMargin
        }
        val codes = ctx.splitExpressionsWithCurrentInputs(
            expressions = evals,
            funcName = "nullable_struct",
            returnType = "boolean",
            extraArguments = "Object[]" -> values :: "boolean" -> nonnull :: Nil,
            makeSplitFunction = body =>
                s"""
                   |do {
                   |  $body
                   |} while (false);
                   |return $nonnull;
                 """.stripMargin,
            foldFunctions = _.map { funcCall =>
                s"""
                   |$nonnull = $funcCall;
                 """.stripMargin
            }.mkString
        )

        import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
        val code =
            code"""
               |Object[] $values = new Object[${valExprs.size}];
               |boolean $nonnull = false;
               |do {
               |  $codes
               |} while (false);
               |
               |boolean ${ev.isNull} = !$nonnull;
               |InternalRow ${ev.value} = null;
               |if (!${ev.isNull}) {
               |  ${ev.value} = new $rowClass($values);
               |}
               |$values = null;
            """.stripMargin
        ev.copy(code = code)
    }

    override def prettyName: String = "named_nullable_struct"

    override def eval(input: InternalRow): Any = {
        val values = valExprs.map(_.eval(input))
        if (values.forall(_ == null))
            null
        else
            InternalRow(values: _*)
    }

    override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy