All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.sparkutils.quality.impl.aggregates.ExpressionAggregates.scala Maven / Gradle / Ivy

package com.sparkutils.quality.impl.aggregates

import com.sparkutils.quality.QualityException.qualityException
import eu.timepit.refined.boolean.False
import org.apache.spark.sql.QualitySparkUtils
import org.apache.spark.sql.ShimUtils.cast
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, If, LambdaFunction, Literal, NamedLambdaVariable}
import org.apache.spark.sql.qualityFunctions._
import org.apache.spark.sql.types.{DataType, DecimalType, MapType}
import org.apache.spark.sql.ShimUtils.{cast => castf}

/**
 * Factory and type-correction helpers used to build [[ExpressionAggregates]] from a filter expression,
 * a sum lambda and an evaluate lambda.
 */
object AggregateExpressions {

  /**
   * Rewrites `function` so that usages of the lambda variable `param` (matched by `exprId`) carry `newDT`.
   *
   * Three shapes are rewritten:
   *  - the bare variable: its data type is replaced by `newDT`
   *  - a cast of a cast of the variable: the inner cast is re-targeted to `newDT`, the outer cast keeps its own type
   *  - a single cast of the variable: the cast is re-targeted to `newDT`
   *
   * Casts around other lambda variables are returned untouched.
   *
   * @param function the lambda body to rewrite
   * @param param    the lambda variable whose type should change
   * @param newDT    the data type to apply
   * @return the rewritten expression tree
   */
  def transformSumType(function: Expression, param: NamedLambdaVariable, newDT: DataType): Expression =
    function.transform {
      case n: NamedLambdaVariable if n.exprId == param.exprId =>
        n.copy(dataType = newDT)
      // spark for decimals already gets the type wrong for 0.7.1 syntax type and double wraps it (only checking
      // single breaks deprecated syntax), so this is an ugly workaround for the types not matching.
      // Type tests are used rather than a Cast extractor pattern; the local is named `cast`, which is why the
      // cast function is referenced through the `castf` alias in this method.
      case cast: Cast if cast.child.isInstanceOf[Cast] && cast.child.asInstanceOf[Cast].child.isInstanceOf[NamedLambdaVariable] =>
        val nvl = cast.child.asInstanceOf[Cast].child.asInstanceOf[NamedLambdaVariable]
        if (nvl.exprId == param.exprId)
          castf(castf(nvl.copy(dataType = newDT), newDT), cast.dataType)
        else
          cast
      // for dbr > 11.2, the cast is on the variable not the expression
      case cast: Cast if cast.child.isInstanceOf[NamedLambdaVariable] =>
        val nvl = cast.child.asInstanceOf[NamedLambdaVariable]
        if (nvl.exprId == param.exprId)
          castf(nvl.copy(dataType = newDT), newDT)
        else
          cast

    }

  /**
   * Should only be created from within the RuleRunnerObject functions, not doing so will lead to unresolved functions and lambda variables - use sql strings only to construct.
   *
   * @param sumType when non-null it triggers a rewrite of both sum and evaluate expressions to set their parameter types to sumType, otherwise it takes those specified in the expressions already
   * @param ifExpr count if true/1 etc.
   * @param sum called lambda expression when ifExpr is true, with the current sum value and returns the new sum value (e.g. current -> current + col2 would total all col2's)
   * @param evaluate combines the count and overall result of combinations in lambda to provide the overall results (e.g. (current, count) -> current / count)
   * @param zero     lookup function to get the zero value for the sum type
   * @param add      lookup function to get the function combining two sum expressions for the sum type, it is passed through to [[ExpressionAggregates]]
   * @param notYetResolved additional casts are needed for the dsl due to decimal precision handling.
   *                 adding them in at the dsl level, however, prevents resolving the lambdas, using Column cast doesn't work either, so they are added in the expression itself
   *                 As the lambda's won't be resolved when calling this the existing matching for the expr sql variant cannot work.
   *                 When true the sum and evaluate expressions are passed through without correction.
   * @return the aggregate wrapped via toAggregateExpression()
   */
  def apply(sumType: DataType, ifExpr: Expression, sum: Expression, evaluate: Expression, zero: DataType => Option[Any], add: DataType => Option[( Expression, Expression ) => Expression], notYetResolved: Boolean = false ): Expression = {
    /*
     * in the case of decimal's being used the DecimalPrecision analysis can change the types such that the
     * precision is ignored e.g.
     * sumWith('DECIMAL(38,18)', entry -> entry + field)
     * may end up with 38,17 for the lambda despite 38,18 being the input as such this must be casted.
     */

    // corrections are only attempted with a sum type to apply and with lambdas that are already resolved
    val (useSum, useEvaluate) =
      if ((sumType eq null) || notYetResolved)
        (sum, evaluate)
      else
        (correctSum(sumType, sum), correctEvaluate(sumType, evaluate))

    // pull the leaf arguments out of both expressions, they become children of the aggregate
    val SeqArgs(sum1 +: _, _) = useSum
    val SeqArgs(Seq(sum2, count), _) = useEvaluate

    ExpressionAggregates(Seq(count, sum1, sum2, ifExpr, useSum, useEvaluate), zero, add, sumType).toAggregateExpression()
  }

  /**
   * Re-packs the evaluate expression so its sum argument and sum lambda parameter use `sumType`.
   *
   * @param sumType      the type to apply to the sum argument and parameter
   * @param evaluate     the evaluate expression, either a FunN over a two parameter lambda or a FunForward wrapping a FunN
   * @param wrapCastOnly when true the lambda body is kept as is, when false it is rewritten via [[transformSumType]]
   * @return the re-packed expression, or `evaluate` unchanged when it matches neither shape
   */
  def correctEvaluate(sumType: DataType, evaluate: Expression, wrapCastOnly: Boolean = false): Expression =
    evaluate match {
      case FunN(Seq(RefExpression(_, nullable, _), cref),
      LambdaFunction(function, Seq(sparam: NamedLambdaVariable, cparam: NamedLambdaVariable), hidden),
      name, _, _) =>

        val correctedFunction =
          if (wrapCastOnly)
            function
          else
            transformSumType(function, sparam, sumType)

        FunN(Seq(RefExpression(sumType, nullable), cref),
          LambdaFunction(correctedFunction,
            Seq(sparam.copy(dataType = sumType), cparam), hidden), name)

      // when we partially apply the lambda is further down
      case FunForward(Seq(RefExpression(_, nullable, paramIndex), cref) :+ (
        i@FunN(funparams, LambdaFunction(function, params, hidden), _, _, _)
        )) =>
        // params can be longer and not 1:1 due to placeholders
        val sparam = params(paramIndex).asInstanceOf[NamedLambdaVariable]
        val correctedFunction =
          if (wrapCastOnly)
            function
          else
            transformSumType(function, sparam, sumType)

        // the retyped parameter must stay at paramIndex, the other parameters keep their positions
        FunForward(Seq(RefExpression(sumType, nullable, paramIndex), cref) :+
          i.copy(function = LambdaFunction(correctedFunction,
            params.updated(paramIndex, sparam.copy(dataType = sumType)), hidden),
            arguments = funparams.updated(paramIndex,
              funparams(paramIndex).asInstanceOf[RefExpression].copy(dataType = sumType)
            )))

      case _ => evaluate // shouldn't happen but better from spark than a match error
    }

  /**
   * Re-packs the sum expression so its sum argument and lambda parameter use `sumType`, casting the lambda
   * result back to the sum type (the map value type for the mapWith shape).
   *
   * @param sumType      the type to apply, for the MapTransform shape it must be a MapType
   * @param sum          the sum expression: a FunN over a single parameter lambda, a MapTransform, or a FunForward wrapping a FunN
   * @param wrapCastOnly when true the lambda body is only wrapped in the cast, when false it is also rewritten via [[transformSumType]]
   * @return the re-packed expression, or `sum` unchanged when it matches none of the shapes
   */
  def correctSum(sumType: DataType, sum: Expression, wrapCastOnly: Boolean = false): Expression =
    sum match {
      // re-pack with new type

      // for sumWith
      case FunN(Seq(RefExpression(_, nullable, _)),
      LambdaFunction(fun, Seq(param: NamedLambdaVariable), hidden),
      name, _, _) =>

        val correctedFunction =
          if (wrapCastOnly)
            fun
          else
            transformSumType(fun, param, sumType)

        FunN(Seq(RefExpression(sumType, nullable)),
          // DecimalPrecision analysis can change the lambda's precision, cast the result back to the sum type
          LambdaFunction(cast(correctedFunction, sumType),
            Seq(param.copy(dataType = sumType)), hidden), name)

      // for mapWith
      case MapTransform(r: RefExpression, key, LambdaFunction(function, Seq(param: NamedLambdaVariable), hidden), zeroF) =>

        // the lambda works on the map's values, so it is the value type which is applied
        val MapType(_, valueType, _) = sumType

        val correctedFunction =
          if (wrapCastOnly)
            function
          else
            transformSumType(function, param, valueType)

        MapTransform(RefExpression(sumType, r.nullable), key,
          // DecimalPrecision analysis can change the lambda's precision, cast the result back to the value type
          LambdaFunction(cast(correctedFunction, valueType),
            Seq(param.copy(dataType = valueType)), hidden), zeroF)

      // when we partially apply the lambda is further down
      case FunForward(Seq(RefExpression(_, nullable, paramIndex)) :+ (
        i@FunN(funparams, LambdaFunction(function, params, hidden), _, _, _)
        )) =>
        // params can be longer and not 1:1 due to placeholders
        val sparam = params(paramIndex).asInstanceOf[NamedLambdaVariable]
        val correctedFunction =
          if (wrapCastOnly)
            function
          else
            transformSumType(function, sparam, sumType)

        // the retyped parameter must stay at paramIndex, the other parameters keep their positions
        FunForward(Seq(RefExpression(sumType, nullable, paramIndex)) :+
          // DecimalPrecision analysis can change the lambda's precision, cast the result back to the sum type
          i.copy(function = LambdaFunction(cast(correctedFunction, sumType),
            params.updated(paramIndex, sparam.copy(dataType = sumType)), hidden),
            arguments = funparams.updated(paramIndex,
              funparams(paramIndex).asInstanceOf[RefExpression].copy(dataType = sumType)
            )))

      case _ => sum // not expected but better errors will come from Spark
    }
}

/**
 * Represents an aggregation expression built from a filter function, a sum lambda and an evaluate lambda which uses the
 * count of filter hits and the sum as parameters.
 *
 * @param children in order: the count leaf, the sum leaf used by the sum lambda, the sum leaf used by the evaluate
 *                 lambda, the filter expression, the sum lambda and the evaluate lambda
 * @param zero     lookup for the zero value of a data type, used as the initial value of the sum buffer
 * @param addF     lookup for the function combining two sum expressions of a data type, used when merging buffers
 * @param sumType  when non-null the sum and evaluate lambdas are re-packed with this type once all children are resolved
 */
case class ExpressionAggregates(override val children: Seq[Expression], zero: DataType => Option[Any], addF: DataType => Option[( Expression, Expression ) => Expression], sumType: DataType) extends DeclarativeAggregate {
  // extending higherorder fun is needed otherwise case other => other.failAnalysis( is thrown in analysis/higherOrderFunctions
  // NOTE(review): this class extends DeclarativeAggregate, the comment above may be out of date - confirm
  lazy val Seq(countLeaf: RefExpression, sumLeaf: RefExpression, evalSumLeaf: RefExpression, ifExpr, sumWith, evaluate) =
    if (children.forall(_.resolved))
      rewriteChildren(children)
    else
      children

  /**
   * Re-derives the six children from the sum and evaluate lambdas, re-packing both with sumType when it is non-null.
   * The lambda bodies are only wrapped (wrapCastOnly = true), they are not rewritten.
   */
  private def rewriteChildren(children: Seq[Expression]): Seq[Expression] = {
    import AggregateExpressions.{correctEvaluate, correctSum}
    val Seq(_, _, _, ifExpr, sumWith, evaluate) = children

    // only correct when there is a sum type to apply: correcting with a null sumType is wasted work and the
    // MapTransform case of correctSum cannot destructure a null type
    val (useSum, useEvaluate) =
      if (sumType eq null)
        (sumWith, evaluate)
      else {
        val correctedEvaluate = correctEvaluate(sumType, evaluate, wrapCastOnly = true)
        val correctedSum = correctSum(sumType, sumWith, wrapCastOnly = true)
        (correctedSum, correctedEvaluate)
      }

    // the leaves are taken from the (possibly corrected) lambdas so they agree on types
    val SeqArgs(sum1 +: _, _) = useSum
    val SeqArgs(Seq(sum2, count), _) = useEvaluate

    Seq(count, sum1, sum2, ifExpr, useSum, useEvaluate)
  }

  // aggregation buffer attributes for the running sum and the count of filter hits
  lazy val sumRef = AttributeReference("sum", sumLeaf.dataType, true)()
  lazy val countRef = AttributeReference("count", countLeaf.dataType)()

  // pair each leaf with the buffer attribute backing it
  lazy val sum = RefSetterExpression(Seq(sumLeaf, sumRef))
  lazy val sumEval = RefSetterExpression(Seq(evalSumLeaf, sumRef))
  lazy val count = RefSetterExpression(Seq(countLeaf, countRef))

  // the setters are listed before the lambda which uses the leaves
  lazy val runEvaluate = RunAllReturnLast(Seq(sumEval, count, evaluate))
  lazy val sumWithEvaluate = RunAllReturnLast(Seq(sum, sumWith))

  lazy val add = addF(sumWith.dataType).getOrElse(qualityException(s"Cannot find the monoidal add for type ${sumWith.dataType}"))

  override lazy val initialValues: Seq[Expression] = Seq(
    Literal(0L), // count
    {
      // sum starts from the zero value registered for its type
      val dt = sumLeaf.dataType
      val zeroE = zero(dt).getOrElse(qualityException(s"Could not find zero for type ${dt}"))
      Literal(zeroE, dt)
    }
  )

  // count and sum only change for rows where ifExpr holds
  override lazy val updateExpressions: Seq[Expression] = Seq(
    If(ifExpr, count + 1L, count),
    If(ifExpr, sumWithEvaluate, sum)
  )

  override lazy val mergeExpressions: Seq[Expression] = Seq(
    count.left + count.right,
    sumLeaf.dataType match {
      case _: DecimalType => cast(add(sum.left, sum.right), sumLeaf.dataType) // extra protection against SPARK-39316
      case _ => add(sum.left, sum.right)
    }
  )

  override lazy val evaluateExpression: Expression = runEvaluate

  override lazy val aggBufferAttributes: Seq[AttributeReference] = Seq(countRef, sumRef)

  override def dataType: DataType = evaluate.dataType

  override def nullable: Boolean = false

  /** Provides left / right views of a setter for use in the merge expressions. */
  implicit class RichAttribute(a: RefSetterExpression) {
    /** Represents this attribute at the mutable buffer side. */
    def left: RefSetterExpression = a

    /** Represents this attribute at the input buffer side (the data value is read-only). */
    // looked up by the position of a.from within aggBufferAttributes - presumably a.from is the buffer attribute
    // the setter was built with; verify against RefSetterExpression
    def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a.from))
  }

  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
    copy(children = newChildren)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy