Please wait. This can take some minutes ...
Downloading a project consumes considerable server resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download or modify it as often as you want.
io.trino.operator.aggregation.AccumulatorCompiler Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation;
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.bytecode.FieldDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.trino.operator.window.InternalWindowIndex;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.ColumnarRow;
import io.trino.spi.block.RowBlockBuilder;
import io.trino.spi.block.RowValueBuilder;
import io.trino.spi.block.ValueBlock;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.AggregationImplementation.AccumulatorStateDescriptor;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.function.GroupedAccumulatorState;
import io.trino.spi.function.WindowIndex;
import io.trino.sql.gen.Binding;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.CompilerOperations;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.bytecode.Access.FINAL;
import static io.airlift.bytecode.Access.PRIVATE;
import static io.airlift.bytecode.Access.PUBLIC;
import static io.airlift.bytecode.Access.a;
import static io.airlift.bytecode.Parameter.arg;
import static io.airlift.bytecode.ParameterizedType.type;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantString;
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic;
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic;
import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance;
import static io.trino.operator.aggregation.AggregationLoopBuilder.buildLoop;
import static io.trino.operator.aggregation.AggregationMaskCompiler.generateAggregationMaskBuilder;
import static io.trino.sql.gen.Bootstrap.BOOTSTRAP_METHOD;
import static io.trino.sql.gen.BytecodeUtils.invoke;
import static io.trino.sql.gen.BytecodeUtils.loadConstant;
import static io.trino.sql.gen.LambdaMetafactoryGenerator.generateMetafactory;
import static io.trino.util.CompilerUtils.defineClass;
import static io.trino.util.CompilerUtils.makeClassName;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
public final class AccumulatorCompiler
{
// Static utility holder: the constructor is private and empty so the class cannot be instantiated from outside.
private AccumulatorCompiler()
{
}
public static AccumulatorFactory generateAccumulatorFactory(
BoundSignature boundSignature,
AggregationImplementation implementation,
FunctionNullability functionNullability,
boolean specializedLoops)
{
// change types used in Aggregation methods to types used in the core Trino engine to simplify code generation
implementation = normalizeAggregationMethods(implementation);
DynamicClassLoader classLoader = new DynamicClassLoader(AccumulatorCompiler.class.getClassLoader());
List argumentNullable = functionNullability.getArgumentNullable()
.subList(0, functionNullability.getArgumentNullable().size() - implementation.getLambdaInterfaces().size());
Constructor extends GroupedAccumulator> groupedAccumulatorConstructor = generateAccumulatorClass(
boundSignature,
GroupedAccumulator.class,
implementation,
argumentNullable,
classLoader,
specializedLoops);
Constructor extends Accumulator> accumulatorConstructor = generateAccumulatorClass(
boundSignature,
Accumulator.class,
implementation,
argumentNullable,
classLoader,
specializedLoops);
List nonNullArguments = new ArrayList<>();
for (int argumentIndex = 0; argumentIndex < argumentNullable.size(); argumentIndex++) {
if (!argumentNullable.get(argumentIndex)) {
nonNullArguments.add(argumentIndex);
}
}
Constructor extends AggregationMaskBuilder> maskBuilderConstructor = generateAggregationMaskBuilder(nonNullArguments.stream().mapToInt(Integer::intValue).toArray());
return new CompiledAccumulatorFactory(
accumulatorConstructor,
groupedAccumulatorConstructor,
implementation.getLambdaInterfaces(),
maskBuilderConstructor);
}
private static Constructor extends T> generateAccumulatorClass(
BoundSignature boundSignature,
Class accumulatorInterface,
AggregationImplementation implementation,
List argumentNullable,
DynamicClassLoader classLoader,
boolean specializedLoops)
{
boolean grouped = accumulatorInterface == GroupedAccumulator.class;
ClassDefinition definition = new ClassDefinition(
a(PUBLIC, FINAL),
makeClassName(boundSignature.getName().getFunctionName() + accumulatorInterface.getSimpleName()),
type(Object.class),
type(accumulatorInterface));
CallSiteBinder callSiteBinder = new CallSiteBinder();
List> stateDescriptors = implementation.getAccumulatorStateDescriptors();
List stateFieldAndDescriptors = new ArrayList<>();
for (int i = 0; i < stateDescriptors.size(); i++) {
stateFieldAndDescriptors.add(new StateFieldAndDescriptor(
stateDescriptors.get(i),
definition.declareField(a(PRIVATE, FINAL), "stateSerializer_" + i, AccumulatorStateSerializer.class),
definition.declareField(a(PRIVATE, FINAL), "stateFactory_" + i, AccumulatorStateFactory.class),
definition.declareField(a(PRIVATE, FINAL), "state_" + i, grouped ? GroupedAccumulatorState.class : AccumulatorState.class)));
}
List stateFields = stateFieldAndDescriptors.stream()
.map(StateFieldAndDescriptor::getStateField)
.collect(toImmutableList());
int lambdaCount = implementation.getLambdaInterfaces().size();
List lambdaProviderFields = new ArrayList<>(lambdaCount);
for (int i = 0; i < lambdaCount; i++) {
lambdaProviderFields.add(definition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + i, Supplier.class));
}
// Generate constructors
generateConstructor(
definition,
stateFieldAndDescriptors,
lambdaProviderFields,
callSiteBinder,
grouped);
generateCopyConstructor(
definition,
stateFieldAndDescriptors,
lambdaProviderFields);
// Generate methods
generateCopy(definition, Accumulator.class);
generateAddInput(
definition,
specializedLoops,
stateFields,
argumentNullable,
lambdaProviderFields,
implementation.getInputFunction(),
callSiteBinder,
grouped);
generateGetEstimatedSize(definition, stateFields);
if (grouped) {
generateSetGroupCount(definition, stateFields);
}
generateAddIntermediateAsCombine(
definition,
stateFieldAndDescriptors,
lambdaProviderFields,
implementation.getCombineFunction(),
callSiteBinder,
grouped);
if (grouped) {
generateGroupedEvaluateIntermediate(definition, stateFieldAndDescriptors, true);
}
else {
generateEvaluateIntermediate(definition, stateFieldAndDescriptors, true);
}
if (grouped) {
generateGroupedEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder);
}
else {
generateEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder);
}
if (grouped) {
generatePrepareFinal(definition);
}
Class extends T> accumulatorClass = defineClass(definition, accumulatorInterface, callSiteBinder.getBindings(), classLoader);
try {
return accumulatorClass.getConstructor(List.class);
}
catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
}
public static Constructor extends WindowAccumulator> generateWindowAccumulatorClass(
BoundSignature boundSignature,
AggregationImplementation implementation,
FunctionNullability functionNullability)
{
// change types used in Aggregation methods to types used in the core Trino engine to simplify code generation
implementation = normalizeAggregationMethods(implementation);
DynamicClassLoader classLoader = new DynamicClassLoader(AccumulatorCompiler.class.getClassLoader());
List argumentNullable = functionNullability.getArgumentNullable()
.subList(0, functionNullability.getArgumentNullable().size() - implementation.getLambdaInterfaces().size());
ClassDefinition definition = new ClassDefinition(
a(PUBLIC, FINAL),
makeClassName(boundSignature.getName().getFunctionName() + WindowAccumulator.class.getSimpleName()),
type(Object.class),
type(WindowAccumulator.class));
CallSiteBinder callSiteBinder = new CallSiteBinder();
List> stateDescriptors = implementation.getAccumulatorStateDescriptors();
List stateFieldAndDescriptors = new ArrayList<>();
for (int i = 0; i < stateDescriptors.size(); i++) {
stateFieldAndDescriptors.add(new StateFieldAndDescriptor(
stateDescriptors.get(i),
definition.declareField(a(PRIVATE, FINAL), "stateSerializer_" + i, AccumulatorStateSerializer.class),
definition.declareField(a(PRIVATE, FINAL), "stateFactory_" + i, AccumulatorStateFactory.class),
definition.declareField(a(PRIVATE, FINAL), "state_" + i, AccumulatorState.class)));
}
List stateFields = stateFieldAndDescriptors.stream()
.map(StateFieldAndDescriptor::getStateField)
.collect(toImmutableList());
int lambdaCount = implementation.getLambdaInterfaces().size();
List lambdaProviderFields = new ArrayList<>(lambdaCount);
for (int i = 0; i < lambdaCount; i++) {
lambdaProviderFields.add(definition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + i, Supplier.class));
}
// Generate constructor
generateWindowAccumulatorConstructor(
definition,
stateFieldAndDescriptors,
lambdaProviderFields,
callSiteBinder);
generateCopyConstructor(
definition,
stateFieldAndDescriptors,
lambdaProviderFields);
// Generate methods
generateCopy(definition, WindowAccumulator.class);
generateAddOrRemoveInputWindowIndex(
definition,
stateFields,
argumentNullable,
lambdaProviderFields,
implementation.getInputFunction(),
"addInput",
callSiteBinder);
implementation.getRemoveInputFunction().ifPresent(
removeInputFunction -> generateAddOrRemoveInputWindowIndex(
definition,
stateFields,
argumentNullable,
lambdaProviderFields,
removeInputFunction,
"removeInput",
callSiteBinder));
generateEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder);
generateGetEstimatedSize(definition, stateFields);
Class extends WindowAccumulator> windowAccumulatorClass = defineClass(definition, WindowAccumulator.class, callSiteBinder.getBindings(), classLoader);
try {
return windowAccumulatorClass.getConstructor(List.class);
}
catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
}
/**
 * Declares the public constructor of a generated window accumulator. The constructor takes a
 * single {@code lambdaProviders} list, calls {@code Object}'s constructor, and then initializes
 * the state fields (non-grouped) and the lambda provider fields.
 */
private static void generateWindowAccumulatorConstructor(
        ClassDefinition definition,
        List<StateFieldAndDescriptor> stateFieldAndDescriptors,
        List<FieldDefinition> lambdaProviderFields,
        CallSiteBinder callSiteBinder)
{
    Parameter lambdaProviders = arg("lambdaProviders", type(List.class, Supplier.class));
    MethodDefinition method = definition.declareConstructor(
            a(PUBLIC),
            lambdaProviders);

    BytecodeBlock body = method.getBody();
    Variable thisVariable = method.getThis();

    body.comment("super();")
            .append(thisVariable)
            .invokeConstructor(Object.class);

    // "false" selects the non-grouped state initialization
    initializeStateFields(method, stateFieldAndDescriptors, callSiteBinder, false);
    initializeLambdaProviderFields(method, lambdaProviderFields, lambdaProviders);

    body.ret();
}
/**
 * Declares the public {@code getEstimatedSize()} method, which returns the sum of
 * {@code getEstimatedSize()} invoked on every state field (zero when there are no state fields).
 */
private static void generateGetEstimatedSize(ClassDefinition definition, List<FieldDefinition> stateFields)
{
    MethodDefinition method = definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class));
    Variable estimatedSize = method.getScope().declareVariable(long.class, "estimatedSize");
    method.getBody().append(estimatedSize.set(constantLong(0L)));

    // estimatedSize += this.state_i.getEstimatedSize() for each state field
    for (FieldDefinition stateField : stateFields) {
        method.getBody()
                .append(estimatedSize.set(
                        BytecodeExpressions.add(
                                estimatedSize,
                                method.getThis().getField(stateField).invoke("getEstimatedSize", long.class))));
    }
    method.getBody().append(estimatedSize.ret());
}
/**
 * Declares the public {@code setGroupCount(long groupCount)} method, which calls
 * {@code ensureCapacity(groupCount)} on every state field and then returns.
 */
private static void generateSetGroupCount(ClassDefinition definition, List<FieldDefinition> stateFields)
{
    Parameter groupCount = arg("groupCount", long.class);

    MethodDefinition method = definition.declareMethod(a(PUBLIC), "setGroupCount", type(void.class), groupCount);
    BytecodeBlock body = method.getBody();
    for (FieldDefinition stateField : stateFields) {
        BytecodeExpression state = method.getScope().getThis().getField(stateField);
        body.append(state.invoke("ensureCapacity", void.class, groupCount));
    }
    body.ret();
}
/**
 * Declares the public {@code addInput} method. Its parameters are {@code (Page arguments,
 * AggregationMask mask)}, preceded by {@code int[] groupIds} when {@code grouped} is true.
 * <p>
 * The generated body first loads one {@link Block} per entry of {@code argumentNullable} from
 * the page (channel {@code i} into variable {@code block<i>}) and then appends the input loop
 * produced by {@code generateInputForLoop}.
 */
private static void generateAddInput(
        ClassDefinition definition,
        boolean specializedLoops,
        List<FieldDefinition> stateField,
        List<Boolean> argumentNullable,
        List<FieldDefinition> lambdaProviderFields,
        MethodHandle inputFunction,
        CallSiteBinder callSiteBinder,
        boolean grouped)
{
    ImmutableList.Builder<Parameter> parameters = ImmutableList.builder();
    if (grouped) {
        parameters.add(arg("groupIds", int[].class));
    }
    Parameter arguments = arg("arguments", Page.class);
    parameters.add(arguments);
    Parameter mask = arg("mask", AggregationMask.class);
    parameters.add(mask);

    MethodDefinition method = definition.declareMethod(a(PUBLIC), "addInput", type(void.class), parameters.build());
    Scope scope = method.getScope();
    BytecodeBlock body = method.getBody();

    // Only the size of argumentNullable is used here: one Block variable per argument
    List<Variable> parameterVariables = new ArrayList<>();
    for (int i = 0; i < argumentNullable.size(); i++) {
        parameterVariables.add(scope.declareVariable(Block.class, "block" + i));
    }

    // Get all parameter blocks
    for (int i = 0; i < parameterVariables.size(); i++) {
        body.comment("%s = arguments.getBlock(%d);", parameterVariables.get(i).getName(), i)
                .append(parameterVariables.get(i).set(arguments.invoke("getBlock", Block.class, constantInt(i))));
    }

    BytecodeBlock block = generateInputForLoop(
            specializedLoops,
            stateField,
            inputFunction,
            scope,
            parameterVariables,
            lambdaProviderFields,
            mask,
            callSiteBinder,
            grouped);

    body.append(block);
    body.ret();
}
/**
 * Declares a public method named {@code generatedFunctionName} with parameters
 * {@code (WindowIndex index, int startPosition, int endPosition)}.
 * <p>
 * The generated body loops {@code position} from {@code startPosition} to {@code endPosition}
 * inclusive and, for every position where none of the non-nullable arguments is null, invokes
 * {@code inputFunction} through an invokedynamic call site with the state fields, the unwrapped
 * value block/position of each input channel, and the lambda values.
 */
private static void generateAddOrRemoveInputWindowIndex(
        ClassDefinition definition,
        List<FieldDefinition> stateField,
        List<Boolean> argumentNullable,
        List<FieldDefinition> lambdaProviderFields,
        MethodHandle inputFunction,
        String generatedFunctionName,
        CallSiteBinder callSiteBinder)
{
    // TODO: implement masking based on maskChannel field once Window Functions support DISTINCT arguments to the functions.

    Parameter index = arg("index", WindowIndex.class);
    Parameter startPosition = arg("startPosition", int.class);
    Parameter endPosition = arg("endPosition", int.class);

    MethodDefinition method = definition.declareMethod(
            a(PUBLIC),
            generatedFunctionName,
            type(void.class),
            ImmutableList.of(index, startPosition, endPosition));
    Scope scope = method.getScope();
    BytecodeBlock body = method.getBody();

    Variable position = scope.declareVariable(int.class, "position");

    // input parameters
    Variable inputBlockPosition = scope.declareVariable(int.class, "inputBlockPosition");
    List<Variable> inputBlockVariables = new ArrayList<>();
    for (int i = 0; i < argumentNullable.size(); i++) {
        inputBlockVariables.add(scope.declareVariable(Block.class, "inputBlock" + i));
    }

    Binding binding = callSiteBinder.bind(inputFunction);
    BytecodeBlock invokeInputFunction = new BytecodeBlock();
    // WindowIndex is built on PagesIndex, which simply wraps Blocks
    // and currently does not understand ValueBlocks.
    // Until PagesIndex is updated to understand ValueBlocks, the
    // input function parameters must be directly unwrapped to ValueBlocks.
    invokeInputFunction.append(inputBlockPosition.set(index.cast(InternalWindowIndex.class).invoke("getRawBlockPosition", int.class, position)));
    for (int i = 0; i < inputBlockVariables.size(); i++) {
        invokeInputFunction.append(inputBlockVariables.get(i).set(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, constantInt(i), position)));
    }
    invokeInputFunction.append(invokeDynamic(
            BOOTSTRAP_METHOD,
            ImmutableList.of(binding.getBindingId()),
            generatedFunctionName,
            binding.getType(),
            getInvokeFunctionOnWindowIndexParameters(
                    scope.getThis(),
                    stateField,
                    inputBlockPosition,
                    inputBlockVariables,
                    lambdaProviderFields)));

    // The loop bound uses lessThanOrEqual, so endPosition itself is processed
    body.append(new ForLoop()
                    .initialize(position.set(startPosition))
                    .condition(BytecodeExpressions.lessThanOrEqual(position, endPosition))
                    .update(position.increment())
                    .body(new IfStatement()
                            .condition(anyParametersAreNull(argumentNullable, index, position))
                            .ifFalse(invokeInputFunction)))
            .ret();
}
/**
 * Builds an expression that ORs together {@code index.isNull(channel, position)} for every
 * channel whose entry in {@code argumentNullable} is false. Channels marked nullable are not
 * checked; with no non-nullable channel the result is the constant {@code false}.
 */
private static BytecodeExpression anyParametersAreNull(
        List<Boolean> argumentNullable,
        Variable index,
        Variable position)
{
    BytecodeExpression isNull = constantFalse();
    for (int inputChannel = 0; inputChannel < argumentNullable.size(); inputChannel++) {
        if (!argumentNullable.get(inputChannel)) {
            isNull = BytecodeExpressions.or(isNull, index.invoke("isNull", boolean.class, constantInt(inputChannel), position));
        }
    }

    return isNull;
}
/**
 * Assembles the argument expressions for invoking the input function from a window index, in
 * this order: every state field, then for each input block its underlying value block followed
 * by its underlying value position (looked up with {@code inputBlockPosition}), then the value
 * returned by {@code get()} on every lambda provider field.
 */
private static List<BytecodeExpression> getInvokeFunctionOnWindowIndexParameters(
        Variable thisVariable,
        List<FieldDefinition> stateField,
        Variable inputBlockPosition,
        List<Variable> inputBlockVariables,
        List<FieldDefinition> lambdaProviderFields)
{
    List<BytecodeExpression> expressions = new ArrayList<>();

    // state parameters
    for (FieldDefinition field : stateField) {
        expressions.add(thisVariable.getField(field));
    }

    // input parameters
    for (Variable blockVariable : inputBlockVariables) {
        expressions.add(blockVariable.invoke("getUnderlyingValueBlock", ValueBlock.class));
        expressions.add(blockVariable.invoke("getUnderlyingValuePosition", int.class, inputBlockPosition));
    }

    // lambda parameters
    for (FieldDefinition lambdaProviderField : lambdaProviderFields) {
        expressions.add(thisVariable.getField(lambdaProviderField)
                .invoke("get", Object.class));
    }

    return expressions;
}
/**
 * Builds the bytecode that feeds input rows to {@code inputFunction}.
 * <p>
 * With {@code specializedLoops} the whole loop is delegated to a method handle produced by
 * {@code buildLoop}, invoked with the mask, the group ids (grouped only), the state fields, the
 * parameter blocks and the lambda values. Otherwise an explicit loop is generated: when
 * {@code mask.isSelectAll()} is true it walks positions {@code 0..mask.getPositionCount()},
 * else it walks the first {@code mask.getSelectedPositionCount()} entries of
 * {@code mask.getSelectedPositions()}.
 */
private static BytecodeBlock generateInputForLoop(
        boolean specializedLoops,
        List<FieldDefinition> stateField,
        MethodHandle inputFunction,
        Scope scope,
        List<Variable> parameterVariables,
        List<FieldDefinition> lambdaProviderFields,
        Variable mask,
        CallSiteBinder callSiteBinder,
        boolean grouped)
{
    if (specializedLoops) {
        BytecodeBlock newBlock = new BytecodeBlock();
        Variable thisVariable = scope.getThis();

        MethodHandle mainLoop = buildLoop(inputFunction, stateField.size(), parameterVariables.size(), grouped);

        // Argument order: mask, [groupIds], states..., blocks..., lambdas...
        ImmutableList.Builder<BytecodeExpression> parameters = ImmutableList.builder();
        parameters.add(mask);
        if (grouped) {
            parameters.add(scope.getVariable("groupIds"));
        }
        for (FieldDefinition fieldDefinition : stateField) {
            parameters.add(thisVariable.getField(fieldDefinition));
        }
        parameters.addAll(parameterVariables);
        for (FieldDefinition lambdaProviderField : lambdaProviderFields) {
            parameters.add(scope.getThis().getField(lambdaProviderField)
                    .invoke("get", Object.class));
        }

        newBlock.append(invoke(callSiteBinder.bind(mainLoop), "mainLoop", parameters.build()));
        return newBlock;
    }

    // For-loop over rows
    Variable positionVariable = scope.declareVariable(int.class, "position");
    Variable rowsVariable = scope.declareVariable(int.class, "rows");
    Variable selectedPositionsArrayVariable = scope.declareVariable(int[].class, "selectedPositionsArray");
    Variable selectedPositionVariable = scope.declareVariable(int.class, "selectedPosition");

    BytecodeBlock block = new BytecodeBlock()
            .initializeVariable(rowsVariable)
            .initializeVariable(selectedPositionVariable)
            .initializeVariable(positionVariable);

    // Every position of the page is selected: iterate positions directly
    ForLoop selectAllLoop = new ForLoop()
            .initialize(new BytecodeBlock()
                    .append(rowsVariable.set(mask.invoke("getPositionCount", int.class)))
                    .append(positionVariable.set(constantInt(0))))
            .condition(BytecodeExpressions.lessThan(positionVariable, rowsVariable))
            .update(new BytecodeBlock().incrementVariable(positionVariable, (byte) 1))
            .body(generateInvokeInputFunction(
                    scope,
                    stateField,
                    positionVariable,
                    parameterVariables,
                    lambdaProviderFields,
                    inputFunction,
                    callSiteBinder,
                    grouped));

    // Only some positions are selected: iterate the selected-positions array and map each entry to a position
    ForLoop selectedPositionsLoop = new ForLoop()
            .initialize(new BytecodeBlock()
                    .append(rowsVariable.set(mask.invoke("getSelectedPositionCount", int.class)))
                    .append(selectedPositionsArrayVariable.set(mask.invoke("getSelectedPositions", int[].class)))
                    .append(positionVariable.set(constantInt(0))))
            .condition(BytecodeExpressions.lessThan(positionVariable, rowsVariable))
            .update(new BytecodeBlock().incrementVariable(positionVariable, (byte) 1))
            .body(new BytecodeBlock()
                    .append(selectedPositionVariable.set(selectedPositionsArrayVariable.getElement(positionVariable)))
                    .append(generateInvokeInputFunction(
                            scope,
                            stateField,
                            selectedPositionVariable,
                            parameterVariables,
                            lambdaProviderFields,
                            inputFunction,
                            callSiteBinder,
                            grouped)));

    block.append(new IfStatement()
            .condition(mask.invoke("isSelectAll", boolean.class))
            .ifTrue(selectAllLoop)
            .ifFalse(selectedPositionsLoop));

    return block;
}
/**
 * Builds the bytecode that calls {@code inputFunction} for a single position.
 * <p>
 * For grouped accumulators the group id of every state is first set from {@code groupIds[position]}.
 * The call arguments are, in order: every state field, then for each parameter block its
 * underlying value block and underlying value position, then the value returned by {@code get()}
 * on every lambda provider field.
 */
private static BytecodeBlock generateInvokeInputFunction(
        Scope scope,
        List<FieldDefinition> stateField,
        Variable position,
        List<Variable> parameterVariables,
        List<FieldDefinition> lambdaProviderFields,
        MethodHandle inputFunction,
        CallSiteBinder callSiteBinder,
        boolean grouped)
{
    BytecodeBlock block = new BytecodeBlock();

    if (grouped) {
        generateSetGroupIdFromGroupIds(scope, stateField, block, position);
    }

    block.comment("Call input function with unpacked Block arguments");

    List<BytecodeExpression> parameters = new ArrayList<>();

    // state parameters
    for (FieldDefinition field : stateField) {
        parameters.add(scope.getThis().getField(field));
    }

    // input parameters
    for (Variable variable : parameterVariables) {
        parameters.add(variable.invoke("getUnderlyingValueBlock", ValueBlock.class));
        parameters.add(variable.invoke("getUnderlyingValuePosition", int.class, position));
    }

    // lambda parameters
    for (FieldDefinition lambdaProviderField : lambdaProviderFields) {
        parameters.add(scope.getThis().getField(lambdaProviderField)
                .invoke("get", Object.class));
    }

    block.append(invoke(callSiteBinder.bind(inputFunction), "input", parameters));
    return block;
}
/**
 * Declares the {@code addIntermediate} method, implemented on top of the combine function.
 * <p>
 * When {@code combineFunction} is empty the generated method only throws
 * {@link UnsupportedOperationException} ("Aggregation is not decomposable"). Otherwise, for every
 * non-null position of the intermediate block, each serialized state is deserialized into a
 * scratch single state and the combine function is invoked with the accumulator states, the
 * scratch states and the lambda values.
 */
private static void generateAddIntermediateAsCombine(
        ClassDefinition definition,
        List<StateFieldAndDescriptor> stateFieldAndDescriptors,
        List<FieldDefinition> lambdaProviderFields,
        Optional<MethodHandle> combineFunction,
        CallSiteBinder callSiteBinder,
        boolean grouped)
{
    MethodDefinition method = declareAddIntermediate(definition, grouped);
    if (combineFunction.isEmpty()) {
        method.getBody()
                .append(newInstance(UnsupportedOperationException.class, constantString("Aggregation is not decomposable")))
                .throwObject();
        return;
    }

    Scope scope = method.getScope();
    BytecodeBlock body = method.getBody();
    Variable thisVariable = method.getThis();

    int stateCount = stateFieldAndDescriptors.size();
    List<Variable> scratchStates = new ArrayList<>();
    for (int i = 0; i < stateCount; i++) {
        scratchStates.add(scope.declareVariable(AccumulatorState.class, "scratchState_" + i));
    }

    List<Variable> block;
    if (stateCount == 1) {
        // A single state is read straight from the "block" parameter
        block = ImmutableList.of(scope.getVariable("block"));
    }
    else {
        // ColumnarRow is used to get the column blocks represents each state, this allows to
        // 1. handle single state and multiple states in a unified way
        // 2. avoid the cost of constructing SingleRowBlock for each group
        Variable columnarRow = scope.declareVariable(ColumnarRow.class, "columnarRow");
        body.append(columnarRow.set(
                invokeStatic(ColumnarRow.class, "toColumnarRow", ColumnarRow.class, scope.getVariable("block"))));

        block = new ArrayList<>();
        for (int i = 0; i < stateCount; i++) {
            Variable columnBlock = scope.declareVariable(Block.class, "columnBlock_" + i);
            body.append(columnBlock.set(
                    columnarRow.invoke("getField", Block.class, constantInt(i))));
            block.add(columnBlock);
        }
    }

    Variable position = scope.declareVariable(int.class, "position");
    for (int i = 0; i < stateCount; i++) {
        FieldDefinition stateFactoryField = stateFieldAndDescriptors.get(i).getStateFactoryField();
        body.comment(format("scratchState_%s = stateFactory[%s].createSingleState();", i, i))
                .append(thisVariable.getField(stateFactoryField))
                .invokeInterface(AccumulatorStateFactory.class, "createSingleState", AccumulatorState.class)
                .checkCast(scratchStates.get(i).getType())
                .putVariable(scratchStates.get(i));
    }

    List<FieldDefinition> stateFields = stateFieldAndDescriptors.stream()
            .map(StateFieldAndDescriptor::getStateField)
            .collect(toImmutableList());

    BytecodeBlock loopBody = new BytecodeBlock();

    loopBody.comment("combine(state_0, state_1, ... scratchState_0, scratchState_1, ... lambda_0, lambda_1, ...)");
    // Each state is pushed as a combine argument; grouped states are first pointed at groupIds[position]
    for (FieldDefinition stateField : stateFields) {
        if (grouped) {
            Variable groupIds = scope.getVariable("groupIds");
            loopBody.append(thisVariable.getField(stateField).invoke("setGroupId", void.class, groupIds.getElement(position).cast(long.class)));
        }
        loopBody.append(thisVariable.getField(stateField));
    }
    // Each scratch state is filled by its serializer and then pushed as a combine argument
    for (int i = 0; i < stateCount; i++) {
        FieldDefinition stateSerializerField = stateFieldAndDescriptors.get(i).getStateSerializerField();
        loopBody.append(thisVariable.getField(stateSerializerField).invoke("deserialize", void.class, block.get(i), position, scratchStates.get(i).cast(AccumulatorState.class)));
        loopBody.append(scratchStates.get(i));
    }
    for (FieldDefinition lambdaProviderField : lambdaProviderFields) {
        loopBody.append(scope.getThis().getField(lambdaProviderField)
                .invoke("get", Object.class));
    }
    loopBody.append(invoke(callSiteBinder.bind(combineFunction.get()), "combine"));

    body.append(generateBlockNonNullPositionForLoop(scope, position, loopBody))
            .ret();
}
/**
 * Appends to {@code block} one {@code setGroupId((long) groupIds[position])} call per state
 * field. The {@code groupIds} variable is looked up by name in {@code scope}.
 */
private static void generateSetGroupIdFromGroupIds(Scope scope, List<FieldDefinition> stateFields, BytecodeBlock block, Variable position)
{
    Variable groupIds = scope.getVariable("groupIds");
    for (FieldDefinition stateField : stateFields) {
        BytecodeExpression state = scope.getThis().getField(stateField);
        block.append(state.invoke("setGroupId", void.class, groupIds.getElement(position).cast(long.class)));
    }
}
/**
 * Declares the public {@code void addIntermediate(...)} method and returns its definition.
 * The parameter list is {@code (Block block)}, preceded by {@code int[] groupIds} when
 * {@code grouped} is true.
 */
private static MethodDefinition declareAddIntermediate(ClassDefinition definition, boolean grouped)
{
    ImmutableList.Builder<Parameter> parameters = ImmutableList.builder();
    if (grouped) {
        parameters.add(arg("groupIds", int[].class));
    }
    parameters.add(arg("block", Block.class));

    return definition.declareMethod(
            a(PUBLIC),
            "addIntermediate",
            type(void.class),
            parameters.build());
}
// Builds a loop over every position of the "block" variable found in scope, from 0 up to
// block.getPositionCount(), keeping the current index in positionVariable.
// loopBody is executed only for positions where block.isNull(position) is false.
private static BytecodeBlock generateBlockNonNullPositionForLoop(Scope scope, Variable positionVariable, BytecodeBlock loopBody)
{
    Variable rowCount = scope.declareVariable(int.class, "rows");
    Variable inputBlock = scope.getVariable("block");

    // rows = block.getPositionCount()
    BytecodeBlock result = new BytecodeBlock();
    result.append(inputBlock)
            .invokeInterface(Block.class, "getPositionCount", int.class)
            .putVariable(rowCount);

    // block.isNull(position)
    BytecodeBlock isNullAtPosition = new BytecodeBlock()
            .append(inputBlock)
            .append(positionVariable)
            .invokeInterface(Block.class, "isNull", boolean.class, int.class);
    IfStatement skipNulls = new IfStatement("if(!block.isNull(position))")
            .condition(isNullAtPosition)
            .ifFalse(loopBody);

    // position < rows
    BytecodeBlock hasMoreRows = new BytecodeBlock()
            .append(positionVariable)
            .append(rowCount)
            .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class);
    BytecodeBlock advance = new BytecodeBlock().incrementVariable(positionVariable, (byte) 1);

    ForLoop loop = new ForLoop()
            .initialize(positionVariable.set(constantInt(0)))
            .condition(hasMoreRows)
            .update(advance)
            .body(skipNulls);
    result.append(loop);

    return result;
}
/**
 * Declares the grouped {@code evaluateIntermediate(int groupId, BlockBuilder out)} method.
 * <p>
 * When {@code decomposable} is false the generated method only throws
 * {@link UnsupportedOperationException}. Otherwise every state is pointed at {@code groupId};
 * a single state is then serialized directly into {@code out}, while multiple states are
 * serialized through the row-entry path built by {@code generateSerializeState}.
 */
private static void generateGroupedEvaluateIntermediate(ClassDefinition definition, List<StateFieldAndDescriptor> stateFieldAndDescriptors, boolean decomposable)
{
    Parameter groupId = arg("groupId", int.class);
    Parameter out = arg("out", BlockBuilder.class);
    MethodDefinition method = definition.declareMethod(a(PUBLIC), "evaluateIntermediate", type(void.class), groupId, out);
    if (!decomposable) {
        method.getBody()
                .append(newInstance(UnsupportedOperationException.class, constantString("Aggregation is not decomposable")))
                .throwObject();
        return;
    }

    Variable thisVariable = method.getThis();
    BytecodeBlock body = method.getBody();

    if (stateFieldAndDescriptors.size() == 1) {
        BytecodeExpression stateSerializer = thisVariable.getField(getOnlyElement(stateFieldAndDescriptors).getStateSerializerField());
        BytecodeExpression state = thisVariable.getField(getOnlyElement(stateFieldAndDescriptors).getStateField());

        body.append(state.invoke("setGroupId", void.class, groupId.cast(long.class)))
                .append(stateSerializer.invoke("serialize", void.class, state.cast(AccumulatorState.class), out))
                .ret();
    }
    else {
        // All states must be positioned on the group before the row entry is built
        for (StateFieldAndDescriptor stateFieldAndDescriptor : stateFieldAndDescriptors) {
            BytecodeExpression state = thisVariable.getField(stateFieldAndDescriptor.getStateField());
            body.append(state.invoke("setGroupId", void.class, groupId.cast(long.class)));
        }

        generateSerializeState(definition, stateFieldAndDescriptors, out, thisVariable, body);
        body.ret();
    }
}
/**
 * Declares the non-grouped {@code evaluateIntermediate(BlockBuilder out)} method.
 * <p>
 * When {@code decomposable} is false the generated method only throws
 * {@link UnsupportedOperationException}. Otherwise a single state is serialized directly into
 * {@code out}, while multiple states are serialized through the row-entry path built by
 * {@code generateSerializeState}.
 */
private static void generateEvaluateIntermediate(ClassDefinition definition, List<StateFieldAndDescriptor> stateFieldAndDescriptors, boolean decomposable)
{
    Parameter out = arg("out", BlockBuilder.class);
    MethodDefinition method = definition.declareMethod(
            a(PUBLIC),
            "evaluateIntermediate",
            type(void.class),
            out);
    if (!decomposable) {
        method.getBody()
                .append(newInstance(UnsupportedOperationException.class, constantString("Aggregation is not decomposable")))
                .throwObject();
        return;
    }

    Variable thisVariable = method.getThis();
    BytecodeBlock body = method.getBody();

    if (stateFieldAndDescriptors.size() == 1) {
        BytecodeExpression stateSerializer = thisVariable.getField(getOnlyElement(stateFieldAndDescriptors).getStateSerializerField());
        BytecodeExpression state = thisVariable.getField(getOnlyElement(stateFieldAndDescriptors).getStateField());

        body.append(stateSerializer.invoke("serialize", void.class, state.cast(AccumulatorState.class), out))
                .ret();
    }
    else {
        generateSerializeState(definition, stateFieldAndDescriptors, out, thisVariable, body);
        body.ret();
    }
}
/**
 * Appends to {@code body} a call to {@code ((RowBlockBuilder) out).buildEntry(...)}. The argument
 * is a {@link RowValueBuilder} created by {@code generateMetafactory} from a newly declared
 * {@code serializeState} method, with {@code thisVariable} as the captured argument.
 */
private static void generateSerializeState(ClassDefinition definition, List<StateFieldAndDescriptor> stateFieldAndDescriptors, Parameter out, Variable thisVariable, BytecodeBlock body)
{
    MethodDefinition serializeState = generateSerializeStateMethod(definition, stateFieldAndDescriptors);

    BytecodeExpression rowEntryBuilder = generateMetafactory(RowValueBuilder.class, serializeState, ImmutableList.of(thisVariable));
    body.append(out.cast(RowBlockBuilder.class).invoke("buildEntry", void.class, rowEntryBuilder));
}
/**
 * Declares the private {@code serializeState(List<BlockBuilder> fieldBuilders)} method and
 * returns its definition. The generated body serializes state {@code i} with its own serializer
 * into {@code fieldBuilders.get(i)}, for every state in order.
 */
private static MethodDefinition generateSerializeStateMethod(ClassDefinition definition, List<StateFieldAndDescriptor> stateFieldAndDescriptors)
{
    Parameter fieldBuilders = arg("fieldBuilders", type(List.class, BlockBuilder.class));
    MethodDefinition method = definition.declareMethod(a(PRIVATE), "serializeState", type(void.class), fieldBuilders);

    Variable thisVariable = method.getThis();
    BytecodeBlock body = method.getBody();

    for (int i = 0; i < stateFieldAndDescriptors.size(); i++) {
        StateFieldAndDescriptor stateFieldAndDescriptor = stateFieldAndDescriptors.get(i);
        BytecodeExpression stateSerializer = thisVariable.getField(stateFieldAndDescriptor.getStateSerializerField());
        BytecodeExpression state = thisVariable.getField(stateFieldAndDescriptor.getStateField());
        // List.get is invoked with its erased Object return type, so the element needs an explicit cast
        BytecodeExpression fieldBuilder = fieldBuilders.invoke("get", Object.class, constantInt(i)).cast(BlockBuilder.class);
        body.append(stateSerializer.invoke("serialize", void.class, state.cast(AccumulatorState.class), fieldBuilder));
    }
    body.ret();
    return method;
}
/**
 * Declares the public {@code evaluateFinal(int groupId, BlockBuilder out)} method.
 * <p>
 * The generated body first calls {@code setGroupId} on every state field with {@code groupId}
 * cast to {@code long}, then pushes all state fields followed by {@code out} and invokes the
 * bound {@code outputFunction}.
 *
 * @param definition class that receives the generated method
 * @param stateFields fields holding the states, in output function argument order
 * @param outputFunction method handle invoked with the states and the output builder
 * @param callSiteBinder binder used to bind {@code outputFunction} into the generated class
 */
private static void generateGroupedEvaluateFinal(
        ClassDefinition definition,
        List<FieldDefinition> stateFields,
        MethodHandle outputFunction,
        CallSiteBinder callSiteBinder)
{
    Parameter groupId = arg("groupId", int.class);
    Parameter out = arg("out", BlockBuilder.class);
    MethodDefinition method = definition.declareMethod(a(PUBLIC), "evaluateFinal", type(void.class), groupId, out);

    BytecodeBlock body = method.getBody();
    Variable thisVariable = method.getThis();

    List<BytecodeExpression> states = new ArrayList<>();
    for (FieldDefinition stateField : stateFields) {
        BytecodeExpression state = thisVariable.getField(stateField);
        // position every state on the requested group before any of them is read
        body.append(state.invoke("setGroupId", void.class, groupId.cast(long.class)));
        states.add(state);
    }

    body.comment("output(state_0, state_1, ..., out)");
    states.forEach(body::append);
    body.append(out);
    body.append(invoke(callSiteBinder.bind(outputFunction), "output"));

    body.ret();
}
/**
 * Declares the public {@code evaluateFinal(BlockBuilder out)} method.
 * <p>
 * The generated body pushes every state field followed by {@code out} and invokes the bound
 * {@code outputFunction}.
 *
 * @param definition class that receives the generated method
 * @param stateFields fields holding the states, in output function argument order
 * @param outputFunction method handle invoked with the states and the output builder
 * @param callSiteBinder binder used to bind {@code outputFunction} into the generated class
 */
private static void generateEvaluateFinal(
        ClassDefinition definition,
        List<FieldDefinition> stateFields,
        MethodHandle outputFunction,
        CallSiteBinder callSiteBinder)
{
    Parameter out = arg("out", BlockBuilder.class);
    MethodDefinition method = definition.declareMethod(
            a(PUBLIC),
            "evaluateFinal",
            type(void.class),
            out);

    BytecodeBlock body = method.getBody();
    Variable thisVariable = method.getThis();

    List<BytecodeExpression> states = new ArrayList<>();
    for (FieldDefinition stateField : stateFields) {
        BytecodeExpression state = thisVariable.getField(stateField);
        states.add(state);
    }

    body.comment("output(state_0, state_1, ..., out)");
    states.forEach(body::append);
    body.append(out);
    body.append(invoke(callSiteBinder.bind(outputFunction), "output"));

    body.ret();
}
/**
 * Declares a public, argument-less {@code prepareFinal} method on {@code definition}
 * whose generated body consists of a single return.
 */
private static void generatePrepareFinal(ClassDefinition definition)
{
    MethodDefinition prepareFinal = definition.declareMethod(a(PUBLIC), "prepareFinal", type(void.class));
    BytecodeBlock emptyBody = prepareFinal.getBody();
    emptyBody.ret();
}
/**
 * Declares the public constructor taking {@code List<Supplier> lambdaProviders}.
 * <p>
 * The generated body calls {@code Object}'s constructor, null-checks {@code lambdaProviders},
 * initializes the serializer, factory and state fields, assigns the lambda provider fields,
 * and returns.
 *
 * @param definition class that receives the generated constructor
 * @param stateFieldAndDescriptors state fields to initialize
 * @param lambdaProviderFields fields populated from {@code lambdaProviders}, by index
 * @param callSiteBinder binder used to load serializer and factory constants
 * @param grouped whether states are created with {@code createGroupedState} instead of {@code createSingleState}
 */
private static void generateConstructor(
        ClassDefinition definition,
        List<StateFieldAndDescriptor> stateFieldAndDescriptors,
        List<FieldDefinition> lambdaProviderFields,
        CallSiteBinder callSiteBinder,
        boolean grouped)
{
    Parameter lambdaProviders = arg("lambdaProviders", type(List.class, Supplier.class));
    MethodDefinition method = definition.declareConstructor(
            a(PUBLIC),
            lambdaProviders);

    BytecodeBlock body = method.getBody();
    Variable thisVariable = method.getThis();

    body.comment("super();")
            .append(thisVariable)
            .invokeConstructor(Object.class);

    body.append(generateRequireNotNull(lambdaProviders));

    initializeStateFields(method, stateFieldAndDescriptors, callSiteBinder, grouped);
    initializeLambdaProviderFields(method, lambdaProviderFields, lambdaProviders);

    body.ret();
}
/**
 * Appends bytecode to {@code method}'s body that initializes, for every state, its
 * serializer field, its factory field and the state field itself.
 * <p>
 * Serializer and factory are loaded as constants from the state's descriptor; the state
 * instance is created by invoking {@code createGroupedState} or {@code createSingleState}
 * on the factory field. Each assigned field is null-checked right after the assignment.
 *
 * @param method constructor being generated
 * @param stateFieldAndDescriptors states to initialize
 * @param callSiteBinder binder used to load serializer and factory constants
 * @param grouped selects {@code createGroupedState} when true, {@code createSingleState} otherwise
 */
private static void initializeStateFields(
        MethodDefinition method,
        List<StateFieldAndDescriptor> stateFieldAndDescriptors,
        CallSiteBinder callSiteBinder,
        boolean grouped)
{
    BytecodeBlock body = method.getBody();
    Variable thisVariable = method.getThis();

    for (StateFieldAndDescriptor fieldAndDescriptor : stateFieldAndDescriptors) {
        AccumulatorStateDescriptor<?> accumulatorStateDescriptor = fieldAndDescriptor.getAccumulatorStateDescriptor();

        body.append(thisVariable.setField(
                fieldAndDescriptor.getStateSerializerField(),
                loadConstant(callSiteBinder, accumulatorStateDescriptor.getSerializer(), AccumulatorStateSerializer.class)));
        body.append(generateRequireNotNull(thisVariable, fieldAndDescriptor.getStateSerializerField()));

        body.append(thisVariable.setField(
                fieldAndDescriptor.getStateFactoryField(),
                loadConstant(callSiteBinder, accumulatorStateDescriptor.getFactory(), AccumulatorStateFactory.class)));
        body.append(generateRequireNotNull(thisVariable, fieldAndDescriptor.getStateFactoryField()));

        // create the state object
        FieldDefinition stateField = fieldAndDescriptor.getStateField();
        BytecodeExpression stateFactory = thisVariable.getField(fieldAndDescriptor.getStateFactoryField());
        BytecodeExpression createStateInstance = stateFactory.invoke(grouped ? "createGroupedState" : "createSingleState", AccumulatorState.class);
        body.append(thisVariable.setField(stateField, createStateInstance.cast(stateField.getType())));
        body.append(generateRequireNotNull(thisVariable, stateField));
    }
}
/**
 * Appends bytecode to {@code method}'s body that assigns the i-th element of
 * {@code lambdaProviders} (cast to {@link Supplier}) to the i-th lambda provider field and
 * null-checks the field afterwards.
 *
 * @param method constructor being generated
 * @param lambdaProviderFields fields to populate, matched to {@code lambdaProviders} by index
 * @param lambdaProviders constructor parameter holding the suppliers
 */
private static void initializeLambdaProviderFields(MethodDefinition method, List<FieldDefinition> lambdaProviderFields, Parameter lambdaProviders)
{
    BytecodeBlock body = method.getBody();
    Variable thisVariable = method.getThis();

    for (int i = 0; i < lambdaProviderFields.size(); i++) {
        FieldDefinition lambdaProviderField = lambdaProviderFields.get(i);
        body.append(thisVariable.setField(
                lambdaProviderField,
                lambdaProviders.invoke("get", Object.class, constantInt(i))
                        .cast(Supplier.class)));
        body.append(generateRequireNotNull(thisVariable, lambdaProviderField));
    }
}
/**
 * Declares the public copy constructor taking a {@code source} instance of the generated class.
 * <p>
 * The generated body calls {@code Object}'s constructor, null-checks {@code source}, and then
 * for every state copies the serializer and factory field references from {@code source} and
 * assigns the state field from {@code source}'s state {@code copy()} result. Lambda provider
 * field references are copied as-is. Every assigned field is null-checked after assignment.
 *
 * @param definition class that receives the generated constructor
 * @param stateFieldAndDescriptors states whose fields are copied
 * @param lambdaProviderFields lambda provider fields whose references are copied
 */
private static void generateCopyConstructor(
        ClassDefinition definition,
        List<StateFieldAndDescriptor> stateFieldAndDescriptors,
        List<FieldDefinition> lambdaProviderFields)
{
    Parameter source = arg("source", definition.getType());
    MethodDefinition method = definition.declareConstructor(
            a(PUBLIC),
            source);

    BytecodeBlock body = method.getBody();
    Variable thisVariable = method.getThis();

    body.comment("super();")
            .append(thisVariable)
            .invokeConstructor(Object.class);

    body.append(generateRequireNotNull(source));

    for (StateFieldAndDescriptor descriptor : stateFieldAndDescriptors) {
        FieldDefinition stateSerializerField = descriptor.getStateSerializerField();
        body.append(thisVariable.setField(stateSerializerField, source.getField(stateSerializerField)));
        body.append(generateRequireNotNull(thisVariable, stateSerializerField));

        FieldDefinition stateFactoryField = descriptor.getStateFactoryField();
        body.append(thisVariable.setField(stateFactoryField, source.getField(stateFactoryField)));
        body.append(generateRequireNotNull(thisVariable, stateFactoryField));

        // the state itself is duplicated through copy() rather than shared with the source
        FieldDefinition stateField = descriptor.getStateField();
        body.append(thisVariable.setField(stateField, source.getField(stateField).invoke("copy", AccumulatorState.class).cast(stateField.getType())));
        body.append(generateRequireNotNull(thisVariable, stateField));
    }

    for (FieldDefinition lambdaProviderField : lambdaProviderFields) {
        body.append(thisVariable.setField(lambdaProviderField, source.getField(lambdaProviderField)));
        body.append(generateRequireNotNull(thisVariable, lambdaProviderField));
    }

    body.ret();
}
private static void generateCopy(ClassDefinition definition, Class> returnType)
{
MethodDefinition copy = definition.declareMethod(a(PUBLIC), "copy", type(returnType));
copy.getBody()
.append(newInstance(definition.getType(), copy.getScope().getThis()).ret());
}
/**
 * Builds a null check for {@code variable} whose failure message is the variable's name
 * followed by {@code " is null"}.
 */
private static BytecodeExpression generateRequireNotNull(Variable variable)
{
    String message = variable.getName() + " is null";
    return generateRequireNotNull(variable, message);
}
/**
 * Builds a null check for the value of {@code field} read from {@code variable}; the failure
 * message is the field's name followed by {@code " is null"}.
 */
private static BytecodeExpression generateRequireNotNull(Variable variable, FieldDefinition field)
{
    BytecodeExpression fieldValue = variable.getField(field);
    String message = field.getName() + " is null";
    return generateRequireNotNull(fieldValue, message);
}
/**
 * Builds an expression that passes {@code expression} (cast to {@code Object}) and
 * {@code message} to {@link Objects#requireNonNull(Object, String)} and casts the result
 * back to the type of {@code expression}.
 */
private static BytecodeExpression generateRequireNotNull(BytecodeExpression expression, String message)
{
    BytecodeExpression checked = invokeStatic(
            Objects.class,
            "requireNonNull",
            Object.class,
            expression.cast(Object.class),
            constantString(message));
    return checked.cast(expression.getType());
}
/**
 * Returns a copy of {@code implementation} whose input, optional remove-input, optional
 * combine and output functions have been passed through {@code normalizeParameters}.
 * State descriptors and lambda interfaces are carried over unchanged.
 */
private static AggregationImplementation normalizeAggregationMethods(AggregationImplementation implementation)
{
    // change aggregations state variables to simply AccumulatorState to avoid any class loader issues in generated code
    int lambdaCount = implementation.getLambdaInterfaces().size();
    AggregationImplementation.Builder normalized = AggregationImplementation.builder();

    normalized.inputFunction(normalizeParameters(implementation.getInputFunction(), lambdaCount));
    implementation.getRemoveInputFunction()
            .ifPresent(removeFunction -> normalized.removeInputFunction(normalizeParameters(removeFunction, lambdaCount)));
    implementation.getCombineFunction()
            .ifPresent(combineFunction -> normalized.combineFunction(normalizeParameters(combineFunction, lambdaCount)));
    // the output function is normalized with a lambda parameter count of zero
    normalized.outputFunction(normalizeParameters(implementation.getOutputFunction(), 0));

    normalized.accumulatorStateDescriptors(implementation.getAccumulatorStateDescriptors());
    normalized.lambdaInterfaces(implementation.getLambdaInterfaces());
    return normalized.build();
}
private static MethodHandle normalizeParameters(MethodHandle function, int lambdaParameterCount)
{
Class>[] parameterTypes = function.type().parameterArray();
for (int i = 0; i < parameterTypes.length; i++) {
Class> parameterType = parameterTypes[i];
if (AccumulatorState.class.isAssignableFrom(parameterType)) {
parameterTypes[i] = AccumulatorState.class;
}
else if (ValueBlock.class.isAssignableFrom(parameterType)) {
parameterTypes[i] = ValueBlock.class;
}
}
for (int i = parameterTypes.length - lambdaParameterCount; i < parameterTypes.length; i++) {
parameterTypes[i] = Object.class;
}
MethodType newType = MethodType.methodType(function.type().returnType(), parameterTypes);
return MethodHandles.explicitCastArguments(function, newType);
}
private static class StateFieldAndDescriptor
{
private final AccumulatorStateDescriptor> accumulatorStateDescriptor;
private final FieldDefinition stateSerializerField;
private final FieldDefinition stateFactoryField;
private final FieldDefinition stateField;
private StateFieldAndDescriptor(AccumulatorStateDescriptor> accumulatorStateDescriptor, FieldDefinition stateSerializerField, FieldDefinition stateFactoryField, FieldDefinition stateField)
{
this.accumulatorStateDescriptor = accumulatorStateDescriptor;
this.stateSerializerField = requireNonNull(stateSerializerField, "stateSerializerField is null");
this.stateFactoryField = requireNonNull(stateFactoryField, "stateFactoryField is null");
this.stateField = requireNonNull(stateField, "stateField is null");
}
public AccumulatorStateDescriptor> getAccumulatorStateDescriptor()
{
return accumulatorStateDescriptor;
}
private FieldDefinition getStateSerializerField()
{
return stateSerializerField;
}
private FieldDefinition getStateFactoryField()
{
return stateFactoryField;
}
private FieldDefinition getStateField()
{
return stateField;
}
}
}