All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.projectglow.sql.expressions.AggregateByIndex.scala Maven / Gradle / Ivy

/*
 * Copyright 2019 The Glow Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.projectglow.sql.expressions

import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate}
import org.apache.spark.sql.types._

/**
 * An expression that allows users to aggregate over all array elements at a specific index in an
 * array column. For example, this expression can be used to compute per-sample summary statistics
 * from a genotypes column.
 *
 * The user must provide the following arguments:
 * - The array for aggregation
 * - The initialValue for each element in the per-index buffer
 * - An update function to update the buffer with a new element
 * - A merge function to combine two buffers
 *
 * The user may optionally provide an evaluate function. If it's not provided, the identity function
 * is used.
 *
 * Rows whose input array is null are skipped: they leave the aggregation buffer unchanged instead
 * of resetting it.
 *
 * Example usage to calculate average depth across all sites for a sample:
 * aggregate_by_index(
 *   genotypes,
 *   named_struct('sum', 0l, 'count', 0l),
 *   (buf, genotype) -> named_struct('sum', buf.sum + genotype.depth, 'count', buf.count + 1),
 *   (buf1, buf2) -> named_struct('sum', buf1.sum + buf2.sum, 'count', buf1.count + buf2.count),
 *   buf -> buf.sum / buf.count)
 */
trait AggregateByIndex extends DeclarativeAggregate with HigherOrderFunction {

  // Fields specific to AggregateByIndex that must be overridden by subclasses
  /** The array over which we're aggregating */
  def arr: Expression

  /** The initial value for each element in the aggregation buffer */
  def initialValue: Expression

  /** Function to update an element of the buffer with a new input element */
  def update: Expression

  /** Function to merge two buffers' elements */
  def merge: Expression

  /** Function to turn a buffer into the actual output */
  def evaluate: Expression

  /**
   * Returns a copy of this expression in which the update, merge and evaluate functions are
   * replaced by the given (bound) expressions. All other fields are carried over by the
   * implementing class.
   */
  def withBoundExprs(
      newUpdate: Expression,
      newMerge: Expression,
      newEvaluate: Expression): AggregateByIndex

  /**
   * The single aggregation buffer attribute: an array with one element of type
   * `initialValue.dataType` per index of the input array.
   */
  protected val buffer: AttributeReference = {
    AttributeReference("buffer", ArrayType(initialValue.dataType))()
  }

  // Overrides from [[Expression]]
  override def prettyName: String = "aggregate_by_index"
  override def nullable: Boolean = children.exists(_.nullable)
  override def dataType: DataType = ArrayType(evaluate.dataType)

  // Overrides from [[HigherOrderFunction]]
  override def functions: Seq[Expression] = Seq(update, merge, evaluate)
  override def functionTypes: Seq[SQLUtils.ADT] = {
    Seq(SQLUtils.anyDataType, SQLUtils.anyDataType, SQLUtils.anyDataType)
  }

  override def arguments: Seq[Expression] = Seq(arr)
  override def argumentTypes: Seq[SQLUtils.ADT] = Seq(arr.dataType)

  /**
   * To create bound lambda functions, we bind each of the child higher order functions,
   * extract their bound lambda functions, and then copy.
   */
  override def bind(
      f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): AggregateByIndex = {
    withBoundExprs(
      updateExpr.bind(f).function,
      mergeExpr.bind(f).function,
      evaluateExpr.bind(f).function
    )
  }

  // Overrides from [[DeclarativeAggregate]]
  override def aggBufferAttributes: Seq[AttributeReference] = Seq(buffer)

  /**
   * Element-wise application of `update` to the (possibly freshly initialized) buffer and the
   * input array. Kept as a plain higher order function so that [[bind]] can bind its lambda.
   */
  private lazy val updateExpr = {
    ZipWith(bufferNotNull(buffer, arr), arr, update)
  }

  /** Element-wise application of `merge` to the two partial buffers. */
  private lazy val mergeExpr = {
    val leftBuffer = bufferNotNull(buffer.left, buffer.right)
    val rightBuffer = bufferNotNull(buffer.right, buffer.left)

    ZipWith(leftBuffer, rightBuffer, merge)
  }

  /** Element-wise application of `evaluate` to the final buffer. */
  private lazy val evaluateExpr = ArrayTransform(buffer, evaluate)

  override lazy val evaluateExpression: Expression = evaluateExpr

  // The buffer starts out as null because its length is unknown until the first input array is
  // seen; see bufferNotNull.
  override lazy val initialValues: Seq[Expression] = {
    Seq(Literal(null, ArrayType(initialValue.dataType)))
  }

  override lazy val mergeExpressions: Seq[Expression] = Seq(mergeExpr)

  override lazy val updateExpressions: Seq[Expression] = {
    // zip_with evaluates to null when either input array is null. Without this guard a single
    // row with a null input array would overwrite the buffer with null, and the next non-null
    // row would re-initialize it, silently discarding everything aggregated so far. Skipping
    // null input rows keeps the accumulated state intact.
    Seq(If(arr.isNull, buffer, updateExpr))
  }

  /**
   * A helper for accessing a buffer that may not yet be initialized.
   *
   * @param buf The maybe initialized buffer to return
   * @param array An input array that has the same length as the aggregation buffer
   * @return `buf` if it is not null, otherwise [[initialValue]] repeated size(`array`) times
   */
  private def bufferNotNull(buf: Expression, array: Expression): Expression = {
    If(buf.isNull, ArrayRepeat(initialValue, Size(array)), buf)
  }
}

/**
 * A hack to make Spark SQL recognize [[AggregateByIndex]] as an aggregate expression.
 *
 * An implementation is the "unwrapped" form of an aggregate function; [[asWrapped]] exposes the
 * corresponding "wrapped" [[AggregateFunction]].
 *
 * See [[io.projectglow.sql.optimizer.ResolveAggregateFunctionsRule]] for details.
 */
trait UnwrappedAggregateFunction extends AggregateFunction {
  /** Returns the wrapped counterpart of this aggregate function. */
  def asWrapped: AggregateFunction
}

case class UnwrappedAggregateByIndex(
    arr: Expression,
    initialValue: Expression,
    update: Expression,
    merge: Expression,
    evaluate: Expression)
    extends AggregateByIndex
    with UnwrappedAggregateFunction {

  /** Four-argument form: no evaluate function was given, so the identity lambda is used. */
  def this(arr: Expression, initialValue: Expression, update: Expression, merge: Expression) =
    this(arr, initialValue, update, merge, LambdaFunction.identity)

  override def prettyName: String = "unwrapped_aggregate_by_index"

  /** Rebuilds this expression with the three lambda functions swapped for the given ones. */
  override def withBoundExprs(
      newUpdate: Expression,
      newMerge: Expression,
      newEvaluate: Expression): AggregateByIndex =
    UnwrappedAggregateByIndex(arr, initialValue, newUpdate, newMerge, newEvaluate)

  /** Builds the wrapped variant from exactly the same five fields. */
  override def asWrapped: AggregateFunction =
    WrappedAggregateByIndex(
      arr = arr,
      initialValue = initialValue,
      update = update,
      merge = merge,
      evaluate = evaluate)
}

case class WrappedAggregateByIndex(
    arr: Expression,
    initialValue: Expression,
    update: Expression,
    merge: Expression,
    evaluate: Expression = LambdaFunction.identity)
    extends AggregateByIndex {

  override def prettyName: String = "wrapped_agg_by"

  /** Rebuilds this expression with the three lambda functions swapped for the given ones. */
  override def withBoundExprs(
      newUpdate: Expression,
      newMerge: Expression,
      newEvaluate: Expression): AggregateByIndex =
    WrappedAggregateByIndex(arr, initialValue, newUpdate, newMerge, newEvaluate)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy