org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of flink-table-planner-blink_2.11 Show documentation
Show all versions of flink-table-planner-blink_2.11 Show documentation
This module bridges Table/SQL API and runtime. It contains
all resources that are required during pre-flight and runtime
phase. The content of this module is work-in-progress. It will
replace flink-table-planner once it is stable. See FLINK-11439
and FLIP-32 for more details.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.codegen.agg
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.table.api.TableException
import org.apache.flink.table.data.GenericRowData
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.ImperativeAggregateFunction
import org.apache.flink.table.planner.JLong
import org.apache.flink.table.planner.codegen.CodeGenUtils.{ROW_DATA, _}
import org.apache.flink.table.planner.codegen.Indenter.toISC
import org.apache.flink.table.planner.codegen._
import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator._
import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef
import org.apache.flink.table.planner.expressions._
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
import org.apache.flink.table.planner.plan.utils.AggregateInfoList
import org.apache.flink.table.planner.typeutils.DataViewUtils.{DataViewSpec, ListViewSpec, MapViewSpec}
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
import org.apache.flink.table.runtime.dataview.{StateListView, StateMapView}
import org.apache.flink.table.runtime.generated._
import org.apache.flink.table.runtime.operators.window.slicing.SliceAssigner
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.utils.LogicalTypeUtils
import org.apache.flink.table.types.logical.{BooleanType, IntType, LogicalType, RowType}
import org.apache.flink.util.Collector
import org.apache.calcite.rex.RexLiteral
import org.apache.calcite.tools.RelBuilder
import java.util.Optional
import java.time.ZoneId
/**
* A code generator for generating [[AggsHandleFunction]].
*
* @param copyInputField copy input field element if true (only mutable type will be copied),
* set to true if field will be buffered (such as local aggregate)
*/
class AggsHandlerCodeGenerator(
ctx: CodeGeneratorContext,
relBuilder: RelBuilder,
inputFieldTypes: Seq[LogicalType],
copyInputField: Boolean) {
private val inputType = RowType.of(inputFieldTypes: _*)
/** constant expressions that act like a second input in the parameter indices. */
private var constants: Seq[RexLiteral] = Seq()
private var constantExprs: Seq[GeneratedExpression] = Seq()
/** window properties like window_start and window_end, only used in window aggregates */
private var namespaceClassName: String = _
private var windowProperties: Seq[PlannerWindowProperty] = Seq()
private var hasNamespace: Boolean = false
private var sliceAssignerTerm: String = _
private var shiftTimeZone: ZoneId = _
/** Aggregates informations */
private var accTypeInfo: RowType = _
private var aggBufferSize: Int = _
private var mergedAccExternalTypes: Array[DataType] = _
private var mergedAccOffset: Int = 0
private var mergedAccOnHeap: Boolean = false
private var ignoreAggValues: Array[Int] = Array()
private var isAccumulateNeeded = false
private var isRetractNeeded = false
private var isMergeNeeded = false
var valueType: RowType = _
/**
* The [[aggBufferCodeGens]] and [[aggActionCodeGens]] will be both created when code generate
* an [[AggsHandleFunction]] or [[NamespaceAggsHandleFunction]]. They both contain all the
* same AggCodeGens, but are different in the organizational form. The [[aggBufferCodeGens]]
* flatten all the AggCodeGens in a flat format. The [[aggActionCodeGens]] organize all the
* AggCodeGens in a tree format. If there is no distinct aggregate, the [[aggBufferCodeGens]]
* and [[aggActionCodeGens]] are totally the same.
*
* When different aggregate distinct on the same field but on different filter conditions,
* they will share the same distinct state, see DistinctAggCodeGen.DistinctValueGenerator
* for more information.
*/
/**
* The aggBufferCodeGens is organized according to the agg buffer order, which is in a flat
* format, and is only used to generate the methods relative to accumulators, Such as
* [[genCreateAccumulators()]], [[genGetAccumulators()]], [[genSetAccumulators()]].
*
* For example if we have :
* count(*), count(distinct a), count(distinct a) filter d > 5, sum(a), sum(distinct a)
*
* then the members of aggBufferCodeGens are organized looks like this:
* +----------+-----------+-----------+---------+---------+----------------+
* | count(*) | count(a') | count(a') | sum(a) | sum(a') | distinct(a) a' |
* +----------+-----------+-----------+---------+---------+----------------+
* */
private var aggBufferCodeGens: Array[AggCodeGen] = _
/**
* The aggActionCodeGens is organized according to the aggregate calling order, which is in
* a tree format. Such as the aggregates distinct on the same fields should be accumulated
* together when distinct is satisfied. And this is only used to generate the methods relative
* to aggregate action. Such as [[genAccumulate()]], [[genRetract()]], [[genMerge()]].
*
* For example if we have :
* count(*), count(distinct a), count(distinct a) filter d > 5, sum(a), sum(distinct a)
*
* then the members of aggActionCodeGens are organized looks like this:
*
* +----------------------------------------------------+
* | count(*) | sum(a) | distinct(a) a' |
* | | | |-- count(a') |
* | | | |-- count(a') (filter d > 5) |
* | | | |-- sum(a') |
* +----------------------------------------------------+
*/
private var aggActionCodeGens: Array[AggCodeGen] = _
/**
* Adds constant expressions that act like a second input in the parameter indices.
*/
def withConstants(literals: Seq[RexLiteral]): AggsHandlerCodeGenerator = {
// create constants
this.constants = literals
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
val exprs = literals.map(exprGenerator.generateExpression)
this.constantExprs = exprs.map(ctx.addReusableConstant(_, nullCheck = true))
this
}
/**
* Tells the generator to generate `accumulate(..)` method for the [[AggsHandleFunction]] and
* [[NamespaceAggsHandleFunction]]. Default not generate `accumulate(..)` method.
*/
def needAccumulate(): AggsHandlerCodeGenerator = {
this.isAccumulateNeeded = true
this
}
/**
* Tells the generator to generate `retract(..)` method for the [[AggsHandleFunction]] and
* [[NamespaceAggsHandleFunction]]. Default not generate `retract(..)` method.
*
* @return
*/
def needRetract(): AggsHandlerCodeGenerator = {
this.isRetractNeeded = true
this
}
/**
* Tells the generator to generate `merge(..)` method with the merged accumulator information
* for the [[AggsHandleFunction]] and [[NamespaceAggsHandleFunction]].
* Default not generate `merge(..)` method.
*
* @param mergedAccOffset the mergedAcc may come from local aggregate,
* this is the first buffer offset in the row
* @param mergedAccOnHeap true if the mergedAcc is on heap, otherwise
* @param mergedAccExternalTypes the merged acc types
*/
def needMerge(
mergedAccOffset: Int,
mergedAccOnHeap: Boolean,
mergedAccExternalTypes: Array[DataType] = null): AggsHandlerCodeGenerator = {
this.mergedAccOffset = mergedAccOffset
this.mergedAccOnHeap = mergedAccOnHeap
this.mergedAccExternalTypes = mergedAccExternalTypes
this.isMergeNeeded = true
this
}
/**
* Adds window properties such as window_start, window_end
*/
private def initialWindowProperties(
windowProperties: Seq[PlannerWindowProperty],
windowClass: Class[_],
shiftTimeZone: ZoneId): Unit = {
this.windowProperties = windowProperties
this.namespaceClassName = windowClass.getCanonicalName
this.hasNamespace = true
this.shiftTimeZone = shiftTimeZone
}
/**
* Adds aggregate infos into context
*/
private def initialAggregateInformation(aggInfoList: AggregateInfoList): Unit = {
this.accTypeInfo = RowType.of(
aggInfoList.getAccTypes.map(fromDataTypeToLogicalType): _*)
this.aggBufferSize = accTypeInfo.getFieldCount
var aggBufferOffset: Int = 0
if (mergedAccExternalTypes == null) {
mergedAccExternalTypes = aggInfoList.getAccTypes
}
val aggCodeGens = aggInfoList.aggInfos.map { aggInfo =>
val filterExpr = createFilterExpression(
aggInfo.agg.filterArg,
aggInfo.aggIndex,
aggInfo.agg.name)
val codegen = aggInfo.function match {
case _: DeclarativeAggregateFunction =>
new DeclarativeAggCodeGen(
ctx,
aggInfo,
filterExpr,
mergedAccOffset,
aggBufferOffset,
aggBufferSize,
inputFieldTypes,
constants,
relBuilder)
case _: ImperativeAggregateFunction[_, _] =>
new ImperativeAggCodeGen(
ctx,
aggInfo,
filterExpr,
mergedAccOffset,
aggBufferOffset,
aggBufferSize,
inputFieldTypes,
constantExprs,
relBuilder,
hasNamespace,
mergedAccOnHeap,
mergedAccExternalTypes(aggBufferOffset),
copyInputField)
}
aggBufferOffset = aggBufferOffset + aggInfo.externalAccTypes.length
codegen
}
val distinctCodeGens = aggInfoList.distinctInfos.zipWithIndex.map {
case (distinctInfo, index) =>
val innerCodeGens = distinctInfo.aggIndexes.map(aggCodeGens(_)).toArray
val distinctIndex = aggCodeGens.length + index
val filterExpr = distinctInfo.filterArgs.map(
createFilterExpression(_, distinctIndex, "distinct aggregate"))
val codegen = new DistinctAggCodeGen(
ctx,
distinctInfo,
index,
innerCodeGens,
filterExpr.toArray,
constantExprs,
mergedAccOffset,
aggBufferOffset,
aggBufferSize,
hasNamespace,
isMergeNeeded,
mergedAccOnHeap,
distinctInfo.consumeRetraction,
copyInputField,
relBuilder)
// distinct agg buffer occupies only one field
aggBufferOffset += 1
codegen
}
val distinctAggIndexes = aggInfoList.distinctInfos.flatMap(_.aggIndexes)
val nonDistinctAggIndexes = aggCodeGens.indices.filter(!distinctAggIndexes.contains(_)).toArray
this.aggBufferCodeGens = aggCodeGens ++ distinctCodeGens
this.aggActionCodeGens = nonDistinctAggIndexes.map(aggCodeGens(_)) ++ distinctCodeGens
// when input contains retractions, we inserted a count1 agg in the agg list
// the count1 agg value shouldn't be in the aggregate result
if (aggInfoList.indexOfCountStar.nonEmpty && aggInfoList.countStarInserted) {
ignoreAggValues ++= Array(aggInfoList.indexOfCountStar.get)
}
// the distinct value shouldn't be in the aggregate result
if (aggInfoList.distinctInfos.nonEmpty) {
ignoreAggValues ++= distinctCodeGens.indices.map(_ + aggCodeGens.length)
}
}
/**
* Creates filter argument access expression, none if no filter
*/
private def createFilterExpression(
filterArg: Int,
aggIndex: Int,
aggName: String): Option[Expression] = {
if (filterArg > 0) {
val name = s"agg_${aggIndex}_filter"
val filterType = inputFieldTypes(filterArg)
if (!filterType.isInstanceOf[BooleanType]) {
throw new TableException(s"filter arg must be boolean, but is $filterType, " +
s"the aggregate is $aggName.")
}
Some(toRexInputRef(relBuilder, filterArg, inputFieldTypes(filterArg)))
} else {
None
}
}
/**
* Generate [[GeneratedAggsHandleFunction]] with the given function name and aggregate infos.
*/
def generateAggsHandler(
name: String,
aggInfoList: AggregateInfoList): GeneratedAggsHandleFunction = {
initialAggregateInformation(aggInfoList)
// generates all methods body first to add necessary reuse code to context
val createAccumulatorsCode = genCreateAccumulators()
val getAccumulatorsCode = genGetAccumulators()
val setAccumulatorsCode = genSetAccumulators()
val resetAccumulatorsCode = genResetAccumulators()
val accumulateCode = genAccumulate()
val retractCode = genRetract()
val mergeCode = genMerge()
val getValueCode = genGetValue()
val functionName = newName(name)
val functionCode =
j"""
public final class $functionName implements $AGGS_HANDLER_FUNCTION {
${ctx.reuseMemberCode()}
private $STATE_DATA_VIEW_STORE store;
public $functionName(java.lang.Object[] references) throws Exception {
${ctx.reuseInitCode()}
}
private $RUNTIME_CONTEXT getRuntimeContext() {
return store.getRuntimeContext();
}
@Override
public void open($STATE_DATA_VIEW_STORE store) throws Exception {
this.store = store;
${ctx.reuseOpenCode()}
}
@Override
public void accumulate($ROW_DATA $ACCUMULATE_INPUT_TERM) throws Exception {
$accumulateCode
}
@Override
public void retract($ROW_DATA $RETRACT_INPUT_TERM) throws Exception {
$retractCode
}
@Override
public void merge($ROW_DATA $MERGED_ACC_TERM) throws Exception {
$mergeCode
}
@Override
public void setAccumulators($ROW_DATA $ACC_TERM) throws Exception {
$setAccumulatorsCode
}
@Override
public void resetAccumulators() throws Exception {
$resetAccumulatorsCode
}
@Override
public $ROW_DATA getAccumulators() throws Exception {
$getAccumulatorsCode
}
@Override
public $ROW_DATA createAccumulators() throws Exception {
$createAccumulatorsCode
}
@Override
public $ROW_DATA getValue() throws Exception {
$getValueCode
}
@Override
public void cleanup() throws Exception {
${ctx.reuseCleanupCode()}
}
@Override
public void close() throws Exception {
${ctx.reuseCloseCode()}
}
}
""".stripMargin
new GeneratedAggsHandleFunction(functionName, functionCode, ctx.references.toArray)
}
/**
* Generate [[GeneratedTableAggsHandleFunction]] with the given function name and aggregate
* infos.
*/
def generateTableAggsHandler(
name: String,
aggInfoList: AggregateInfoList): GeneratedTableAggsHandleFunction = {
initialAggregateInformation(aggInfoList)
// generates all methods body first to add necessary reuse code to context
val createAccumulatorsCode = genCreateAccumulators()
val getAccumulatorsCode = genGetAccumulators()
val setAccumulatorsCode = genSetAccumulators()
val resetAccumulatorsCode = genResetAccumulators()
val accumulateCode = genAccumulate()
val retractCode = genRetract()
val mergeCode = genMerge()
val emitValueCode = genEmitValue()
// gen converter
val aggExternalType = aggInfoList.getActualAggregateInfos(0).externalResultType
val recordInputName = newName("recordInput")
val recordToRowDataCode = genRecordToRowData(aggExternalType, recordInputName)
val functionName = newName(name)
val functionCode =
j"""
public final class $functionName implements ${className[TableAggsHandleFunction]} {
${ctx.reuseMemberCode()}
private $STATE_DATA_VIEW_STORE store;
private $CONVERT_COLLECTOR_TYPE_TERM $MEMBER_COLLECTOR_TERM;
public $functionName(java.lang.Object[] references) throws Exception {
${ctx.reuseInitCode()}
$MEMBER_COLLECTOR_TERM = new $CONVERT_COLLECTOR_TYPE_TERM(references);
}
private $RUNTIME_CONTEXT getRuntimeContext() {
return store.getRuntimeContext();
}
@Override
public void open($STATE_DATA_VIEW_STORE store) throws Exception {
this.store = store;
${ctx.reuseOpenCode()}
}
@Override
public void accumulate($ROW_DATA $ACCUMULATE_INPUT_TERM) throws Exception {
$accumulateCode
}
@Override
public void retract($ROW_DATA $RETRACT_INPUT_TERM) throws Exception {
$retractCode
}
@Override
public void merge($ROW_DATA $MERGED_ACC_TERM) throws Exception {
$mergeCode
}
@Override
public void setAccumulators($ROW_DATA $ACC_TERM) throws Exception {
$setAccumulatorsCode
}
@Override
public void resetAccumulators() throws Exception {
$resetAccumulatorsCode
}
@Override
public $ROW_DATA getAccumulators() throws Exception {
$getAccumulatorsCode
}
@Override
public $ROW_DATA createAccumulators() throws Exception {
$createAccumulatorsCode
}
@Override
public void emitValue(
$COLLECTOR<$ROW_DATA> $COLLECTOR_TERM, $ROW_DATA key, boolean isRetract)
throws Exception {
$MEMBER_COLLECTOR_TERM.reset(key, isRetract, $COLLECTOR_TERM);
$emitValueCode
}
@Override
public void cleanup() throws Exception {
${ctx.reuseCleanupCode()}
}
@Override
public void close() throws Exception {
${ctx.reuseCloseCode()}
}
private class $CONVERT_COLLECTOR_TYPE_TERM implements $COLLECTOR {
private $COLLECTOR<$ROW_DATA> $COLLECTOR_TERM;
private $ROW_DATA key;
private $JOINED_ROW result;
private boolean isRetract = false;
${ctx.reuseMemberCode()}
public $CONVERT_COLLECTOR_TYPE_TERM(java.lang.Object[] references) throws Exception {
${ctx.reuseInitCode()}
result = new $JOINED_ROW();
}
public void reset(
$ROW_DATA key, boolean isRetract, $COLLECTOR<$ROW_DATA> $COLLECTOR_TERM) {
this.key = key;
this.isRetract = isRetract;
this.$COLLECTOR_TERM = $COLLECTOR_TERM;
}
public $ROW_DATA convertToRowData(Object $recordInputName) throws Exception {
$recordToRowDataCode
}
@Override
public void collect(Object $recordInputName) throws Exception {
$ROW_DATA tempRowData = convertToRowData($recordInputName);
result.replace(key, tempRowData);
if (isRetract) {
result.setRowKind($ROW_KIND.DELETE);
} else {
result.setRowKind($ROW_KIND.INSERT);
}
$COLLECTOR_TERM.collect(result);
}
@Override
public void close() {
$COLLECTOR_TERM.close();
}
}
}
""".stripMargin
new GeneratedTableAggsHandleFunction(functionName, functionCode, ctx.references.toArray)
}
/**
* Generate [[NamespaceAggsHandleFunction]] with the given function name and aggregate infos
* and window properties.
*/
def generateNamespaceAggsHandler(
name: String,
aggInfoList: AggregateInfoList,
windowProperties: Seq[PlannerWindowProperty],
sliceAssigner: SliceAssigner,
shiftTimeZone: ZoneId): GeneratedNamespaceAggsHandleFunction[JLong] = {
this.sliceAssignerTerm = newName("sliceAssigner")
ctx.addReusableObjectWithName(sliceAssigner, sliceAssignerTerm)
// we use window end timestamp to indicate a window, see SliceAssigner
generateNamespaceAggsHandler(
name,
aggInfoList,
windowProperties,
classOf[JLong],
shiftTimeZone)
}
/**
* Generate [[NamespaceAggsHandleFunction]] with the given function name and aggregate infos
* and window properties.
*/
def generateNamespaceAggsHandler[N](
name: String,
aggInfoList: AggregateInfoList,
windowProperties: Seq[PlannerWindowProperty],
windowClass: Class[N],
shiftTimeZone: ZoneId): GeneratedNamespaceAggsHandleFunction[N] = {
initialWindowProperties(windowProperties, windowClass, shiftTimeZone)
initialAggregateInformation(aggInfoList)
// generates all methods body first to add necessary reuse code to context
val createAccumulatorsCode = genCreateAccumulators()
val getAccumulatorsCode = genGetAccumulators()
val setAccumulatorsCode = genSetAccumulators()
val accumulateCode = genAccumulate()
val retractCode = genRetract()
val mergeCode = genMerge()
val getValueCode = genGetValue()
val functionName = newName(name)
val functionCode =
j"""
public final class $functionName
implements $NAMESPACE_AGGS_HANDLER_FUNCTION<$namespaceClassName> {
${ctx.reuseMemberCode()}
private $STATE_DATA_VIEW_STORE store;
private $namespaceClassName $NAMESPACE_TERM;
public $functionName(Object[] references) throws Exception {
${ctx.reuseInitCode()}
}
private $RUNTIME_CONTEXT getRuntimeContext() {
return store.getRuntimeContext();
}
@Override
public void open($STATE_DATA_VIEW_STORE store) throws Exception {
this.store = store;
${ctx.reuseOpenCode()}
}
@Override
public void accumulate($ROW_DATA $ACCUMULATE_INPUT_TERM) throws Exception {
$accumulateCode
}
@Override
public void retract($ROW_DATA $RETRACT_INPUT_TERM) throws Exception {
$retractCode
}
@Override
public void merge(Object ns, $ROW_DATA $MERGED_ACC_TERM) throws Exception {
$NAMESPACE_TERM = ($namespaceClassName) ns;
$mergeCode
}
@Override
public void setAccumulators(Object ns, $ROW_DATA $ACC_TERM)
throws Exception {
$NAMESPACE_TERM = ($namespaceClassName) ns;
$setAccumulatorsCode
}
@Override
public $ROW_DATA getAccumulators() throws Exception {
$getAccumulatorsCode
}
@Override
public $ROW_DATA createAccumulators() throws Exception {
$createAccumulatorsCode
}
@Override
public $ROW_DATA getValue(Object ns) throws Exception {
$NAMESPACE_TERM = ($namespaceClassName) ns;
$getValueCode
}
@Override
public void cleanup(Object ns) throws Exception {
$NAMESPACE_TERM = ($namespaceClassName) ns;
${ctx.reuseCleanupCode()}
}
@Override
public void close() throws Exception {
${ctx.reuseCloseCode()}
}
}
""".stripMargin
new GeneratedNamespaceAggsHandleFunction[N](functionName, functionCode, ctx.references.toArray)
}
/**
* Generate [[NamespaceTableAggsHandleFunction]] with the given function name and aggregate infos
* and window properties.
*/
def generateNamespaceTableAggsHandler[N](
name: String,
aggInfoList: AggregateInfoList,
windowProperties: Seq[PlannerWindowProperty],
windowClass: Class[N],
shiftedTimeZone: ZoneId): GeneratedNamespaceTableAggsHandleFunction[N] = {
initialWindowProperties(windowProperties, windowClass, shiftedTimeZone)
initialAggregateInformation(aggInfoList)
// generates all methods body first to add necessary reuse code to context
val createAccumulatorsCode = genCreateAccumulators()
val getAccumulatorsCode = genGetAccumulators()
val setAccumulatorsCode = genSetAccumulators()
val accumulateCode = genAccumulate()
val retractCode = genRetract()
val mergeCode = genMerge()
val emitValueCode = genEmitValue(isWindow = true)
// gen converter
val aggExternalType = aggInfoList.getActualAggregateInfos(0).externalResultType
val recordInputName = newName("recordInput")
val recordToRowDataCode = genRecordToRowData(aggExternalType, recordInputName)
val functionName = newName(name)
val functionCode =
j"""
public final class $functionName
implements ${className[NamespaceTableAggsHandleFunction[_]]}<$namespaceClassName> {
${ctx.reuseMemberCode()}
private $STATE_DATA_VIEW_STORE store;
private $namespaceClassName $NAMESPACE_TERM;
private $CONVERT_COLLECTOR_TYPE_TERM $MEMBER_COLLECTOR_TERM;
public $functionName(Object[] references) throws Exception {
${ctx.reuseInitCode()}
$MEMBER_COLLECTOR_TERM = new $CONVERT_COLLECTOR_TYPE_TERM(references);
}
private $RUNTIME_CONTEXT getRuntimeContext() {
return store.getRuntimeContext();
}
@Override
public void open($STATE_DATA_VIEW_STORE store) throws Exception {
this.store = store;
${ctx.reuseOpenCode()}
}
@Override
public void accumulate($ROW_DATA $ACCUMULATE_INPUT_TERM) throws Exception {
$accumulateCode
}
@Override
public void retract($ROW_DATA $RETRACT_INPUT_TERM) throws Exception {
$retractCode
}
@Override
public void merge(Object ns, $ROW_DATA $MERGED_ACC_TERM) throws Exception {
$NAMESPACE_TERM = ($namespaceClassName) ns;
$mergeCode
}
@Override
public void setAccumulators(Object ns, $ROW_DATA $ACC_TERM)
throws Exception {
$NAMESPACE_TERM = ($namespaceClassName) ns;
$setAccumulatorsCode
}
@Override
public $ROW_DATA getAccumulators() throws Exception {
$getAccumulatorsCode
}
@Override
public $ROW_DATA createAccumulators() throws Exception {
$createAccumulatorsCode
}
@Override
public void emitValue(Object ns, $ROW_DATA $KEY_TERM,
$COLLECTOR<$ROW_DATA> $COLLECTOR_TERM) throws Exception {
$MEMBER_COLLECTOR_TERM.$COLLECTOR_TERM = $COLLECTOR_TERM;
$NAMESPACE_TERM = ($namespaceClassName) ns;
$emitValueCode
}
@Override
public void cleanup(Object ns) throws Exception {
$NAMESPACE_TERM = ($namespaceClassName) ns;
${ctx.reuseCleanupCode()}
}
@Override
public void close() throws Exception {
${ctx.reuseCloseCode()}
}
private class $CONVERT_COLLECTOR_TYPE_TERM implements $COLLECTOR {
public $COLLECTOR<$ROW_DATA> $COLLECTOR_TERM;
private $ROW_DATA timeProperties;
private $ROW_DATA key;
private $JOINED_ROW outerResult;
private $JOINED_ROW innerResult;
${ctx.reuseMemberCode()}
public $CONVERT_COLLECTOR_TYPE_TERM(java.lang.Object[] references) throws Exception {
${ctx.reuseInitCode()}
outerResult = new $JOINED_ROW();
innerResult = new $JOINED_ROW();
}
public void reset($ROW_DATA $KEY_TERM, $ROW_DATA timeProperties) {
this.timeProperties = timeProperties;
this.key = $KEY_TERM;
}
public $ROW_DATA convertToRowData(Object $recordInputName) throws Exception {
$recordToRowDataCode
}
@Override
public void collect(Object $recordInputName) throws Exception {
$ROW_DATA tempRowData = convertToRowData($recordInputName);
innerResult.replace(tempRowData, timeProperties);
outerResult.replace(key, innerResult);
$COLLECTOR_TERM.collect(outerResult);
}
@Override
public void close() {
$COLLECTOR_TERM.close();
}
}
}
""".stripMargin
new GeneratedNamespaceTableAggsHandleFunction[N](
functionName, functionCode, ctx.references.toArray)
}
private def genCreateAccumulators(): String = {
val methodName = "createAccumulators"
ctx.startNewLocalVariableStatement(methodName)
// not need to bind input for ExprCodeGenerator
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
val initAccExprs = aggBufferCodeGens.flatMap(_.createAccumulator(exprGenerator))
val accTerm = newName("acc")
val resultExpr = exprGenerator.generateResultExpression(
initAccExprs,
accTypeInfo,
classOf[GenericRowData],
outRow = accTerm,
reusedOutRow = false,
allowSplit = true,
methodName = methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${resultExpr.code}
|return ${resultExpr.resultTerm};
""".stripMargin
}
private def genGetAccumulators(): String = {
val methodName = "getAccumulators"
ctx.startNewLocalVariableStatement(methodName)
// no need to bind input
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
val accExprs = aggBufferCodeGens.flatMap(_.getAccumulator(exprGenerator))
val accTerm = newName("acc")
// always create a new accumulator row
val resultExpr = exprGenerator.generateResultExpression(
accExprs,
accTypeInfo,
classOf[GenericRowData],
outRow = accTerm,
reusedOutRow = false,
allowSplit = true,
methodName = methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${resultExpr.code}
|return ${resultExpr.resultTerm};
""".stripMargin
}
private def genSetAccumulators(): String = {
val methodName = "setAccumulators"
ctx.startNewLocalVariableStatement(methodName)
// bind input1 as accumulators
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
.bindInput(accTypeInfo, inputTerm = ACC_TERM)
val body = splitExpressionsIfNecessary(
aggBufferCodeGens.map(_.setAccumulator(exprGenerator)), methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${ctx.reuseInputUnboxingCode(ACC_TERM)}
|$body
""".stripMargin
}
private def genResetAccumulators(): String = {
val methodName = "resetAccumulators"
ctx.startNewLocalVariableStatement(methodName)
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
val body = splitExpressionsIfNecessary(aggBufferCodeGens.map(_.resetAccumulator(exprGenerator)),
methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|$body
""".stripMargin
}
private def genAccumulate(): String = {
if (isAccumulateNeeded) {
// validation check
checkNeededMethods(needAccumulate = true)
val methodName = "accumulate"
ctx.startNewLocalVariableStatement(methodName)
// bind input1 as inputRow
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
.bindInput(inputType, inputTerm = ACCUMULATE_INPUT_TERM)
val body = splitExpressionsIfNecessary(
aggActionCodeGens.map(_.accumulate(exprGenerator)), methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${ctx.reuseInputUnboxingCode(ACCUMULATE_INPUT_TERM)}
|$body
|""".stripMargin
} else {
genThrowException(
"This function not require accumulate method, but the accumulate method is called.")
}
}
private def genRetract(): String = {
if (isRetractNeeded) {
// validation check
checkNeededMethods(needRetract = true)
val methodName = "retract"
ctx.startNewLocalVariableStatement(methodName)
// bind input1 as inputRow
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
.bindInput(inputType, inputTerm = RETRACT_INPUT_TERM)
val body = splitExpressionsIfNecessary(
aggActionCodeGens.map(_.retract(exprGenerator)), methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${ctx.reuseInputUnboxingCode(RETRACT_INPUT_TERM)}
|$body
""".stripMargin
} else {
genThrowException(
"This function not require retract method, but the retract method is called.")
}
}
private def genMerge(): String = {
if (isMergeNeeded) {
// validation check
checkNeededMethods(needMerge = true)
val methodName = "merge"
ctx.startNewLocalVariableStatement(methodName)
// the mergedAcc is partial of mergedInput, such as in local-global, ignore keys
val internalAccTypes = mergedAccExternalTypes.map(fromDataTypeToLogicalType)
val mergedAccType = if (mergedAccOffset > 0) {
// concat padding types and acc types, use int type as padding
// the padding types will be ignored
val padding = Array.range(0, mergedAccOffset).map(_ => new IntType())
RowType.of(padding ++ internalAccTypes: _*)
} else {
RowType.of(internalAccTypes: _*)
}
// bind input1 as otherAcc
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
.bindInput(mergedAccType, inputTerm = MERGED_ACC_TERM)
val body = splitExpressionsIfNecessary(
aggActionCodeGens.map(_.merge(exprGenerator)), methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${ctx.reuseInputUnboxingCode(MERGED_ACC_TERM)}
|$body
""".stripMargin
} else {
genThrowException(
"This function not require merge method, but the merge method is called.")
}
}
private def splitExpressionsIfNecessary(exprs: Array[String], methodName: String): String = {
val totalLen = exprs.map(_.length).sum
val maxCodeLength = ctx.tableConfig.getMaxGeneratedCodeLength
if (totalLen > maxCodeLength) {
ctx.setCodeSplit(methodName)
exprs.map(expr => {
val splitMethodName = newName("split_" + methodName)
val method =
s"""
|private void $splitMethodName() throws Exception {
| $expr
|}
|""".stripMargin
ctx.addReusableMember(method)
s"$splitMethodName();"
}).mkString("\n")
} else {
exprs.mkString("\n")
}
}
private def getWindowExpressions(
windowProperties: Seq[PlannerWindowProperty]): Seq[GeneratedExpression] = {
if (namespaceClassName.equals(classOf[JLong].getCanonicalName)) {
// slicing optimization, we are using window end timestamp to indicate a window
windowProperties.map {
case w: PlannerWindowStart =>
// return a Timestamp(Internal is TimestampData)
GeneratedExpression(
s"$TIMESTAMP_DATA.fromEpochMillis($sliceAssignerTerm.getWindowStart($NAMESPACE_TERM))",
"false",
"",
w.getResultType)
case w: PlannerWindowEnd =>
// return a Timestamp(Internal is TimestampData)
GeneratedExpression(
s"$TIMESTAMP_DATA.fromEpochMillis($NAMESPACE_TERM)",
"false",
"",
w.getResultType)
case r: PlannerRowtimeAttribute =>
// return a rowtime, use TimestampData as internal type
GeneratedExpression(
s"""
|$TIMESTAMP_DATA.fromEpochMillis(
|${getShiftEpochMills(s"$NAMESPACE_TERM - 1")})
""".stripMargin,
"false",
"",
r.getResultType)
case p: PlannerProctimeAttribute =>
// ignore this property, it will be null at the position later
GeneratedExpression(s"$TIMESTAMP_DATA.fromEpochMillis(-1L)", "true", "", p.getResultType)
}
} else {
windowProperties.map {
case w: PlannerWindowStart =>
// return a Timestamp(Internal is TimestampData)
GeneratedExpression(
s"$TIMESTAMP_DATA.fromEpochMillis($NAMESPACE_TERM.getStart())",
"false",
"",
w.getResultType)
case w: PlannerWindowEnd =>
// return a Timestamp(Internal is TimestampData)
GeneratedExpression(
s"$TIMESTAMP_DATA.fromEpochMillis($NAMESPACE_TERM.getEnd())",
"false",
"",
w.getResultType)
case r: PlannerRowtimeAttribute =>
// return a rowtime, use TimestampData as internal type
GeneratedExpression(
s"""
|$TIMESTAMP_DATA.fromEpochMillis(
|${getShiftEpochMills(s"$NAMESPACE_TERM.getEnd() - 1")})
""".stripMargin,
"false",
"",
r.getResultType)
case p: PlannerProctimeAttribute =>
// ignore this property, it will be null at the position later
GeneratedExpression(s"$TIMESTAMP_DATA.fromEpochMillis(-1L)", "true", "", p.getResultType)
}
}
}
private def getShiftEpochMills(itemExpr: String): String = {
if ("UTC".equals(shiftTimeZone.getId)) {
itemExpr
} else {
val timeZoneId = ctx.addReusableShiftTimeZone(shiftTimeZone)
s"""
|$TIME_WINDOW_UTIL.toEpochMills($itemExpr, $timeZoneId)
""".stripMargin
}
}
private def genGetValue(): String = {
val methodName = "getValue"
ctx.startNewLocalVariableStatement(methodName)
// no need to bind input
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
var valueExprs = aggBufferCodeGens.zipWithIndex.filter { case (_, index) =>
// ignore the count1 agg codegen and distinct agg codegen
ignoreAggValues.isEmpty || !ignoreAggValues.contains(index)
}.map { case (codegen, _) =>
codegen.getValue(exprGenerator)
}
if (hasNamespace) {
// append window property results
val windowExprs = getWindowExpressions(windowProperties)
valueExprs = valueExprs ++ windowExprs
}
val aggValueTerm = newName("aggValue")
valueType = RowType.of(valueExprs.map(_.resultType): _*)
// always create a new result row
val resultExpr = exprGenerator.generateResultExpression(
valueExprs,
valueType,
classOf[GenericRowData],
outRow = aggValueTerm,
reusedOutRow = false,
allowSplit = true,
methodName = methodName)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${resultExpr.code}
|return ${resultExpr.resultTerm};
""".stripMargin
}
private def genEmitValue(isWindow: Boolean = false): String = {
// validation check
checkNeededMethods(needEmitValue = true)
val methodName = "emitValue"
ctx.startNewLocalVariableStatement(methodName)
val windowCode =
if (isWindow) {
// no need to bind input
val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
val valueExprs = getWindowExpressions(windowProperties)
val aggValueTerm = newName("windowProperties")
valueType = RowType.of(valueExprs.map(_.resultType): _*)
// always create a new result row
val resultExpr = exprGenerator.generateResultExpression(
valueExprs,
valueType,
classOf[GenericRowData],
outRow = aggValueTerm,
reusedOutRow = false)
s"""
|${ctx.reuseLocalVariableCode(methodName)}
|${resultExpr.code}
|$MEMBER_COLLECTOR_TERM.reset($KEY_TERM, ${resultExpr.resultTerm});
""".stripMargin
} else {
""
}
windowCode + aggBufferCodeGens(0).asInstanceOf[ImperativeAggCodeGen].emitValue
}
private def genRecordToRowData(aggExternalType: DataType, recordInputName: String): String = {
val resultType = fromDataTypeToLogicalType(aggExternalType)
val resultRowType = LogicalTypeUtils.toRowType(resultType)
val newCtx = CodeGeneratorContext(ctx.tableConfig)
val exprGenerator = new ExprCodeGenerator(newCtx, false).bindInput(resultType)
val resultExpr = exprGenerator.generateConverterResultExpression(
resultRowType, classOf[GenericRowData], "convertResult")
val converterCode = CodeGenUtils.genToInternalConverter(ctx, aggExternalType, recordInputName)
val resultTypeClass = boxedTypeTermForType(resultType)
s"""
|${newCtx.reuseMemberCode()}
|$resultTypeClass ${exprGenerator.input1Term} = ($resultTypeClass) $converterCode;
|${newCtx.reuseLocalVariableCode()}
|${newCtx.reuseInputUnboxingCode()}
|${resultExpr.code}
|return ${resultExpr.resultTerm};
""".stripMargin
}
private def checkNeededMethods(
needAccumulate: Boolean = false,
needRetract: Boolean = false,
needMerge: Boolean = false,
needReset: Boolean = false,
needEmitValue: Boolean = false): Unit = {
// check and validate the needed methods
aggBufferCodeGens.foreach(
_.checkNeededMethods(needAccumulate, needRetract, needMerge, needReset, needEmitValue))
}
private def genThrowException(msg: String): String = {
s"""
|throw new java.lang.RuntimeException("$msg");
""".stripMargin
}
}
object AggsHandlerCodeGenerator {
/** static terms **/
val ACC_TERM = "acc"
val MERGED_ACC_TERM = "otherAcc"
val ACCUMULATE_INPUT_TERM = "accInput"
val RETRACT_INPUT_TERM = "retractInput"
val DISTINCT_KEY_TERM = "distinctKey"
val NAMESPACE_TERM = "namespace"
val STORE_TERM = "store"
val COLLECTOR: String = className[Collector[_]]
val COLLECTOR_TERM = "out"
val MEMBER_COLLECTOR_TERM = "convertCollector"
val CONVERT_COLLECTOR_TYPE_TERM = "ConvertCollector"
val KEY_TERM = "groupKey"
val INPUT_NOT_NULL = false
/**
* Create DataView term, for example, acc1_map_dataview.
*
* @return term to access MapView or ListView
*/
def createDataViewTerm(spec: DataViewSpec): String = {
s"${spec.getStateId}_dataview"
}
/**
* Creates RawValueData term which wraps the specific DataView term.
*/
def createDataViewRawValueTerm(spec: DataViewSpec): String = {
s"${createDataViewTerm(spec)}_raw_value"
}
/**
* Create DataView backup term, for example, acc1_map_dataview_backup.
* The backup dataview term is used for merging two statebackend
* dataviews, e.g. session window.
*
* @return term to access backup MapView or ListView
*/
def createDataViewBackupTerm(spec: DataViewSpec): String = {
s"${spec.getStateId}_dataview_backup"
}
/**
* Creates RawValueData term which wraps the specific DataView backup term.
*/
def createDataViewBackupRawValueTerm(spec: DataViewSpec): String = {
s"${createDataViewBackupTerm(spec)}_raw_value"
}
def addReusableStateDataViews(
ctx: CodeGeneratorContext,
viewSpecs: Array[DataViewSpec],
hasNamespace: Boolean,
enableBackupDataView: Boolean): Unit = {
// add reusable dataviews to context
viewSpecs.foreach { spec =>
val stateId = '"' + spec.getStateId + '"'
val (viewTypeTerm, stateStoreCall) = spec match {
case spec: ListViewSpec =>
val viewTypeTerm = className[StateListView[_, _]]
val elementSerializerTerm = addReusableDataViewSerializer(
ctx,
spec.getElementSerializer,
() => spec.getElementDataType)
val stateStoreCall =
s"getStateListView($stateId, $elementSerializerTerm)"
(viewTypeTerm, stateStoreCall)
case spec: MapViewSpec =>
val viewTypeTerm = className[StateMapView[_, _, _]]
val withNullKey = spec.containsNullKey()
val keySerializerTerm = addReusableDataViewSerializer(
ctx,
spec.getKeySerializer,
() => spec.getKeyDataType
)
val valueSerializerTerm = addReusableDataViewSerializer(
ctx,
spec.getValueSerializer,
() => spec.getValueDataType)
val stateStoreCall =
s"getStateMapView($stateId, $withNullKey, $keySerializerTerm, $valueSerializerTerm)"
(viewTypeTerm, stateStoreCall)
}
val viewFieldTerm = createDataViewTerm(spec)
val viewFieldInternalTerm = createDataViewRawValueTerm(spec)
ctx.addReusableMember(s"private $viewTypeTerm $viewFieldTerm;")
ctx.addReusableMember(s"private $BINARY_RAW_VALUE $viewFieldInternalTerm;")
val openCode =
s"""
|$viewFieldTerm = ($viewTypeTerm) $STORE_TERM.$stateStoreCall;
|$viewFieldInternalTerm = $BINARY_RAW_VALUE.fromObject($viewFieldTerm);
""".stripMargin
ctx.addReusableOpenStatement(openCode)
// only cleanup dataview term, do not need to cleanup backup
val cleanupCode = if (hasNamespace) {
s"""
|$viewFieldTerm.setCurrentNamespace($NAMESPACE_TERM);
|$viewFieldTerm.clear();
""".stripMargin
} else {
s"""
|$viewFieldTerm.clear();
""".stripMargin
}
ctx.addReusableCleanupStatement(cleanupCode)
// generate backup dataview codes
if (enableBackupDataView) {
val backupViewTerm = createDataViewBackupTerm(spec)
val backupViewInternalTerm = createDataViewBackupRawValueTerm(spec)
// create backup dataview
ctx.addReusableMember(s"private $viewTypeTerm $backupViewTerm;")
ctx.addReusableMember(s"private $BINARY_RAW_VALUE $backupViewInternalTerm;")
val backupOpenCode =
s"""
|$backupViewTerm = ($viewTypeTerm) $STORE_TERM.$stateStoreCall;
|$backupViewInternalTerm = $BINARY_RAW_VALUE.fromObject($backupViewTerm);
""".stripMargin
ctx.addReusableOpenStatement(backupOpenCode)
}
}
}
private def addReusableDataViewSerializer(
ctx: CodeGeneratorContext,
legacySerializer: Optional[TypeSerializer[_]],
dataType: () => DataType)
: String = {
toScala(legacySerializer) match {
case Some(serializer) =>
ctx.addReusableObject(serializer, "serializer")
case None =>
ctx.addReusableExternalSerializer(dataType())
}
}
}