All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.execution.AggregatingAccumulator.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution

import scala.collection.mutable

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Expression, InterpretedMutableProjection, InterpretedUnsafeProjection, JoinedRow, MutableProjection, NamedExpression, Projection, SpecificInternalRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate, ImperativeAggregate, NoOp, TypedImperativeAggregate}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.AccumulatorV2

/**
 * Accumulator that computes a global aggregate.
 */
/**
 * Accumulator that computes a global aggregate.
 *
 * The aggregation state is kept in a single [[SpecificInternalRow]] buffer that is created lazily
 * on the first `add`/`merge`. Declarative aggregates are driven by the projections below;
 * imperative aggregates read and write their own slots of the same buffer directly.
 *
 * @param bufferSchema      data types of the aggregation buffer slots.
 * @param initialValues     expressions that seed a freshly created buffer (one per slot).
 * @param updateExpressions expressions evaluated against `JoinedRow(buffer, input)` in `add`.
 * @param mergeExpressions  expressions evaluated against `JoinedRow(buffer, otherBuffer)` in
 *                          `merge`. They are serialized together with the accumulator, so `merge`
 *                          (and therefore `copy`) also works on deserialized instances.
 * @param resultExpressions expressions evaluated against the buffer to produce `value`. Transient,
 *                          so they are not shipped along with the accumulator.
 * @param imperatives       non-typed imperative aggregates, updated/merged in place on the buffer.
 * @param typedImperatives  typed imperative aggregates, updated/merged in place on the buffer.
 * @param conf              SQL configuration captured at construction time. Transient, so it is
 *                          `null` on deserialized instances.
 */
class AggregatingAccumulator private(
    bufferSchema: Seq[DataType],
    initialValues: Seq[Expression],
    updateExpressions: Seq[Expression],
    mergeExpressions: Seq[Expression],
    @transient private val resultExpressions: Seq[Expression],
    imperatives: Array[ImperativeAggregate],
    typedImperatives: Array[TypedImperativeAggregate[_]],
    @transient private val conf: SQLConf)
  extends AccumulatorV2[InternalRow, InternalRow] {
  assert(bufferSchema.size == initialValues.size)
  assert(bufferSchema.size == updateExpressions.size)
  assert(mergeExpressions == null || bufferSchema.size == mergeExpressions.size)

  /** Row that pairs the buffer (left) with an input row or another buffer (right). */
  @transient
  private var joinedRow: JoinedRow = _

  /** Aggregation state; `null` means the accumulator is still zero. */
  private var buffer: SpecificInternalRow = _

  /** Creates a new buffer initialized with the initial values of all aggregates. */
  private def createBuffer(): SpecificInternalRow = {
    val buffer = new SpecificInternalRow(bufferSchema)

    // Initialize the buffer. Note that we do not use a code generated projection here because
    // generating and compiling a projection is probably more expensive than using an interpreted
    // projection.
    InterpretedMutableProjection.createProjection(initialValues)
      .target(buffer)
      .apply(InternalRow.empty)
    imperatives.foreach(_.initialize(buffer))
    typedImperatives.foreach(_.initialize(buffer))
    buffer
  }

  /** Returns the current buffer, lazily creating it (and the joined row) on first use. */
  private def getOrCreateBuffer(): SpecificInternalRow = {
    if (buffer == null) {
      buffer = createBuffer()

      // Create the joined row and set the buffer as its 'left' row.
      joinedRow = new JoinedRow()
      joinedRow.withLeft(buffer)
    }
    buffer
  }

  /** Initializes `projection` for the current partition id and returns it. */
  private def initializeProjection[T <: Projection](projection: T): T = {
    projection.initialize(TaskContext.getPartitionId())
    projection
  }

  @transient
  private[this] lazy val updateProjection = initializeProjection {
    MutableProjection.create(updateExpressions)
  }

  @transient
  private[this] lazy val mergeProjection = initializeProjection {
    InterpretedMutableProjection.createProjection(mergeExpressions)
  }

  @transient
  private[this] lazy val resultProjection = initializeProjection {
    InterpretedUnsafeProjection.createProjection(resultExpressions)
  }

  /**
   * Driver side operations like `merge` and `value` are executed in the DAGScheduler thread. This
   * thread does not have a SQL configuration so we attach our own here.
   *
   * `conf` is transient, so it is only non-null on instances that have not been deserialized.
   * On any other instance, `body` is run as-is when `canRunOnExecutor` is set; it then uses
   * whatever SQL configuration is active on the calling thread. Otherwise `default` is returned,
   * because the transient state `body` needs (e.g. `resultExpressions`) is not available there.
   */
  private[this] def withSQLConf[T](canRunOnExecutor: Boolean, default: => T)(body: => T): T = {
    if (conf != null) {
      SQLConf.withExistingConf(conf)(body)
    } else if (canRunOnExecutor) {
      body
    } else {
      default
    }
  }

  /** Drops the aggregation state; the accumulator becomes zero again. */
  override def reset(): Unit = {
    buffer = null
    joinedRow = null
  }

  override def isZero: Boolean = buffer == null

  /** Creates a new, zero accumulator sharing this accumulator's (immutable) configuration. */
  override def copyAndReset(): AggregatingAccumulator = {
    new AggregatingAccumulator(
      bufferSchema,
      initialValues,
      updateExpressions,
      mergeExpressions,
      resultExpressions,
      imperatives,
      typedImperatives,
      conf)
  }

  /**
   * Creates a copy holding the same aggregation state. This relies on `merge` running on any
   * instance, including deserialized ones; otherwise the copy would silently be empty.
   */
  override def copy(): AggregatingAccumulator = {
    val copy = copyAndReset()
    copy.merge(this)
    copy
  }

  /** Folds one input row into the aggregation buffer. */
  override def add(v: InternalRow): Unit = {
    val buffer = getOrCreateBuffer()
    updateProjection.target(buffer)(joinedRow.withRight(v))
    var i = 0
    while (i < imperatives.length) {
      imperatives(i).update(buffer, v)
      i += 1
    }
    i = 0
    while (i < typedImperatives.length) {
      typedImperatives(i).update(buffer, v)
      i += 1
    }
  }

  /**
   * Merges the state of another [[AggregatingAccumulator]] into this one. Merging a zero
   * accumulator is a no-op.
   *
   * @throws UnsupportedOperationException if `other` is not an [[AggregatingAccumulator]].
   */
  override def merge(
      other: AccumulatorV2[InternalRow, InternalRow]): Unit = {
    withSQLConf(canRunOnExecutor = true, ()) {
      if (!other.isZero) {
        other match {
          case agg: AggregatingAccumulator =>
            val buffer = getOrCreateBuffer()
            val otherBuffer = agg.buffer
            mergeProjection.target(buffer)(joinedRow.withRight(otherBuffer))
            var i = 0
            while (i < imperatives.length) {
              imperatives(i).merge(buffer, otherBuffer)
              i += 1
            }
            i = 0
            while (i < typedImperatives.length) {
              typedImperatives(i).mergeBuffersObjects(buffer, otherBuffer)
              i += 1
            }
          case _ =>
            throw new UnsupportedOperationException(
              s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
        }
      }
    }
  }

  /**
   * Evaluates the result expressions over the current state. On an instance without a captured
   * `conf` (i.e. one that was deserialized) this returns `InternalRow.empty`, because
   * `resultExpressions` is transient and therefore unavailable there.
   */
  override def value: InternalRow = withSQLConf(canRunOnExecutor = false, InternalRow.empty) {
    // Either use the existing buffer or create a temporary one.
    val input = if (!isZero) {
      buffer
    } else {
      // Create a temporary buffer because we want to avoid changing the state of the accumulator
      // here, which would happen if we called getOrCreateBuffer(). This is relatively expensive to
      // do but it should be no problem since this method is supposed to be called rarely (once per
      // query execution).
      createBuffer()
    }
    resultProjection(input)
  }

  /**
   * Get the output schema of the aggregating accumulator. Named result expressions keep their
   * name and metadata; unnamed ones are called `c_<position>`.
   */
  lazy val schema: StructType = {
    StructType(resultExpressions.zipWithIndex.map {
      case (e: NamedExpression, _) => StructField(e.name, e.dataType, e.nullable, e.metadata)
      case (e, i) => StructField(s"c_$i", e.dataType, e.nullable)
    })
  }

  /**
   * Set the state of the accumulator to the state of another accumulator. This is used in cases
   * where we only want to publish the state of the accumulator when the task completes, see
   * [[CollectMetricsExec]] for an example.
   */
  private[execution] def setState(other: AggregatingAccumulator): Unit = {
    assert(buffer == null || (buffer eq other.buffer))
    buffer = other.buffer
    joinedRow = other.joinedRow
  }
}

object AggregatingAccumulator {
  /**
   * Create an aggregating accumulator for the given functions and input schema.
   *
   * @param functions       expressions that may contain [[AggregateExpression]]s; every aggregate
   *                        found is given its own slots in the shared aggregation buffer, and the
   *                        expressions themselves become the accumulator's result expressions.
   * @param inputAttributes attributes of the rows passed to `add`; update expressions and imperative
   *                        aggregates are bound against them.
   * @return a new, zero accumulator.
   */
  def apply(functions: Seq[Expression], inputAttributes: Seq[Attribute]): AggregatingAccumulator = {
    // There are a couple of things happening here:
    // - Collect the schemas of the aggregate and input aggregate buffers. These are needed to bind
    //   the expressions which will be done when we create the accumulator.
    // - Collect the initialValues, update and merge expressions for declarative aggregate
    //   functions.
    // - Bind and Collect the imperative aggregate functions. Note that we insert NoOps into the
    //   (declarative) initialValues, update and merge expression buffers to keep these aligned with
    //   the aggregate buffer.
    // - Build the result expressions.
    val aggBufferAttributes = mutable.Buffer.empty[AttributeReference]
    val inputAggBufferAttributes = mutable.Buffer.empty[AttributeReference]
    val initialValues = mutable.Buffer.empty[Expression]
    val updateExpressions = mutable.Buffer.empty[Expression]
    val mergeExpressions = mutable.Buffer.empty[Expression]
    val imperatives = mutable.Buffer.empty[ImperativeAggregate]
    val typedImperatives = mutable.Buffer.empty[TypedImperativeAggregate[_]]
    val inputAttributeSeq: AttributeSeq = inputAttributes
    val resultExpressions = functions.map(_.transform {
      case AggregateExpression(agg: DeclarativeAggregate, _, _, _, _) =>
        aggBufferAttributes ++= agg.aggBufferAttributes
        inputAggBufferAttributes ++= agg.inputAggBufferAttributes
        initialValues ++= agg.initialValues
        updateExpressions ++= agg.updateExpressions
        mergeExpressions ++= agg.mergeExpressions
        agg.evaluateExpression
      case AggregateExpression(agg: ImperativeAggregate, _, _, _, _) =>
        // Point the aggregate at its slots in the (input) aggregation buffer, which start right
        // after the slots of all previously collected aggregates.
        val imperative = BindReferences.bindReference(agg
          .withNewMutableAggBufferOffset(aggBufferAttributes.size)
          .withNewInputAggBufferOffset(inputAggBufferAttributes.size),
          inputAttributeSeq)
        imperative match {
          case typedImperative: TypedImperativeAggregate[_] =>
            typedImperatives += typedImperative
          case _ =>
            imperatives += imperative
        }
        aggBufferAttributes ++= imperative.aggBufferAttributes
        inputAggBufferAttributes ++= agg.inputAggBufferAttributes
        val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp)
        initialValues ++= noOps
        updateExpressions ++= noOps
        mergeExpressions ++= noOps
        imperative
    })

    // `.toSeq` turns the local mutable buffers into the `Seq` the binding/constructor APIs expect.
    // On Scala 2.13 `scala.Seq` is `immutable.Seq`, so a `mutable.Buffer` is not accepted directly.
    val updateAttrSeq: AttributeSeq = (aggBufferAttributes ++ inputAttributes).toSeq
    val mergeAttrSeq: AttributeSeq = (aggBufferAttributes ++ inputAggBufferAttributes).toSeq
    val aggBufferAttributesSeq: AttributeSeq = aggBufferAttributes.toSeq

    // Create the accumulator.
    new AggregatingAccumulator(
      aggBufferAttributes.map(_.dataType).toSeq,
      initialValues.toSeq,
      updateExpressions.map(BindReferences.bindReference(_, updateAttrSeq)).toSeq,
      mergeExpressions.map(BindReferences.bindReference(_, mergeAttrSeq)).toSeq,
      resultExpressions.map(BindReferences.bindReference(_, aggBufferAttributesSeq)),
      imperatives.toArray,
      typedImperatives.toArray,
      SQLConf.get)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy