Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance. Project price only 1 $
You can buy this project and download/modify it how often you want.
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.aggregate
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
/**
* The base class of [[SortBasedAggregationIterator]] and [[TungstenAggregationIterator]].
* It mainly contains two parts:
* 1. It initializes aggregate functions.
* 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of
* its aggregate functions. `processRow` is the function to handle an input. `generateOutput`
* is used to generate result.
*/
abstract class AggregationIterator(
partIndex: Int,
groupingExpressions: Seq[NamedExpression],
inputAttributes: Seq[Attribute],
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection)
extends Iterator[UnsafeRow] with Logging {
///////////////////////////////////////////////////////////////////////////
// Initializing functions.
///////////////////////////////////////////////////////////////////////////
/**
* The following combinations of AggregationMode are supported:
* - Partial
* - PartialMerge (for single distinct)
* - Partial and PartialMerge (for single distinct)
* - Final
* - Complete (for SortAggregate with functions that does not support Partial)
* - Final and Complete (currently not used)
*
* TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression
* could have a flag to tell it's final or not.
*/
{
val modes = aggregateExpressions.map(_.mode).distinct.toSet
require(modes.size <= 2,
s"$aggregateExpressions are not supported because they have more than 2 distinct modes.")
require(modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)),
s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.")
}
// Initialize all AggregateFunctions by binding references if necessary,
// and set inputBufferOffset and mutableBufferOffset.
protected def initializeAggregateFunctions(
expressions: Seq[AggregateExpression],
startingInputBufferOffset: Int): Array[AggregateFunction] = {
var mutableBufferOffset = 0
var inputBufferOffset: Int = startingInputBufferOffset
val expressionsLength = expressions.length
val functions = new Array[AggregateFunction](expressionsLength)
var i = 0
while (i < expressionsLength) {
val func = expressions(i).aggregateFunction
val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
// We need to create BoundReferences if the function is not an
// expression-based aggregate function (it does not support code-gen) and the mode of
// this function is Partial or Complete because we will call eval of this
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
BindReferences.bindReference(func, inputAttributes)
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
val updatedFunc = func match {
case function: ImperativeAggregate =>
function.withNewInputAggBufferOffset(inputBufferOffset)
case function => function
}
inputBufferOffset += func.aggBufferSchema.length
updatedFunc
}
val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
case function: ImperativeAggregate =>
// Set mutableBufferOffset for this function. It is important that setting
// mutableBufferOffset happens after all potential bindReference operations
// because bindReference will create a new instance of the function.
function.withNewMutableAggBufferOffset(mutableBufferOffset)
case function => function
}
mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
functions(i) = funcWithUpdatedAggBufferOffset
i += 1
}
functions
}
protected val aggregateFunctions: Array[AggregateFunction] =
initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset)
// Positions of those imperative aggregate functions in allAggregateFunctions.
// For example, we have func1, func2, func3, func4 in aggregateFunctions, and
// func2 and func3 are imperative aggregate functions.
// ImperativeAggregateFunctionPositions will be [1, 2].
protected[this] val allImperativeAggregateFunctionPositions: Array[Int] = {
val positions = new ArrayBuffer[Int]()
var i = 0
while (i < aggregateFunctions.length) {
aggregateFunctions(i) match {
case agg: DeclarativeAggregate =>
case _ => positions += i
}
i += 1
}
positions.toArray
}
// The projection used to initialize buffer values for all expression-based aggregates.
protected[this] val expressionAggInitialProjection = {
val initExpressions = aggregateFunctions.flatMap {
case ae: DeclarativeAggregate => ae.initialValues
// For the positions corresponding to imperative aggregate functions, we'll use special
// no-op expressions which are ignored during projection code-generation.
case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp)
}
newMutableProjection(initExpressions, Nil)
}
// All imperative AggregateFunctions.
protected[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] =
allImperativeAggregateFunctionPositions
.map(aggregateFunctions)
.map(_.asInstanceOf[ImperativeAggregate])
// Initializing functions used to process a row.
protected def generateProcessRow(
expressions: Seq[AggregateExpression],
functions: Seq[AggregateFunction],
inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit = {
val joinedRow = new JoinedRow
if (expressions.nonEmpty) {
val mergeExpressions = functions.zipWithIndex.flatMap {
case (ae: DeclarativeAggregate, i) =>
expressions(i).mode match {
case Partial | Complete => ae.updateExpressions
case PartialMerge | Final => ae.mergeExpressions
}
case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val updateFunctions = functions.zipWithIndex.collect {
case (ae: ImperativeAggregate, i) =>
expressions(i).mode match {
case Partial | Complete =>
(buffer: InternalRow, row: InternalRow) => ae.update(buffer, row)
case PartialMerge | Final =>
(buffer: InternalRow, row: InternalRow) => ae.merge(buffer, row)
}
}.toArray
// This projection is used to merge buffer values for all expression-based aggregates.
val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes)
val updateProjection =
newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)
(currentBuffer: InternalRow, row: InternalRow) => {
// Process all expression-based aggregate functions.
updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row))
// Process all imperative aggregate functions.
var i = 0
while (i < updateFunctions.length) {
updateFunctions(i)(currentBuffer, row)
i += 1
}
}
} else {
// Grouping only.
(currentBuffer: InternalRow, row: InternalRow) => {}
}
}
protected val processRow: (InternalRow, InternalRow) => Unit =
generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes)
protected val groupingProjection: UnsafeProjection =
UnsafeProjection.create(groupingExpressions, inputAttributes)
protected val groupingAttributes = groupingExpressions.map(_.toAttribute)
// Initializing the function used to generate the output row.
protected def generateResultProjection(): (UnsafeRow, InternalRow) => UnsafeRow = {
val joinedRow = new JoinedRow
val modes = aggregateExpressions.map(_.mode).distinct
val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
if (modes.contains(Final) || modes.contains(Complete)) {
val evalExpressions = aggregateFunctions.map {
case ae: DeclarativeAggregate => ae.evaluateExpression
case agg: AggregateFunction => NoOp
}
val aggregateResult = new SpecificInternalRow(aggregateAttributes.map(_.dataType))
val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)
expressionAggEvalProjection.target(aggregateResult)
val resultProjection =
UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes)
resultProjection.initialize(partIndex)
(currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
// Generate results for all expression-based aggregate functions.
expressionAggEvalProjection(currentBuffer)
// Generate results for all imperative aggregate functions.
var i = 0
while (i < allImperativeAggregateFunctions.length) {
aggregateResult.update(
allImperativeAggregateFunctionPositions(i),
allImperativeAggregateFunctions(i).eval(currentBuffer))
i += 1
}
resultProjection(joinedRow(currentGroupingKey, aggregateResult))
}
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
val resultProjection = UnsafeProjection.create(
groupingAttributes ++ bufferAttributes,
groupingAttributes ++ bufferAttributes)
resultProjection.initialize(partIndex)
// TypedImperativeAggregate stores generic object in aggregation buffer, and requires
// calling serialization before shuffling. See [[TypedImperativeAggregate]] for more info.
val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = {
aggregateFunctions.collect {
case (ag: TypedImperativeAggregate[_]) => ag
}
}
(currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
// Serializes the generic object stored in aggregation buffer
var i = 0
while (i < typedImperativeAggregates.length) {
typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer)
i += 1
}
resultProjection(joinedRow(currentGroupingKey, currentBuffer))
}
} else {
// Grouping-only: we only output values based on grouping expressions.
val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
resultProjection.initialize(partIndex)
(currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
resultProjection(currentGroupingKey)
}
}
}
protected val generateOutput: (UnsafeRow, InternalRow) => UnsafeRow =
generateResultProjection()
/** Initializes buffer values for all aggregate functions. */
protected def initializeBuffer(buffer: InternalRow): Unit = {
expressionAggInitialProjection.target(buffer)(EmptyRow)
var i = 0
while (i < allImperativeAggregateFunctions.length) {
allImperativeAggregateFunctions(i).initialize(buffer)
i += 1
}
}
}