
io.prestosql.operator.aggregation.ParametricAggregation Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.operator.aggregation;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.prestosql.metadata.BoundSignature;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionDependencies;
import io.prestosql.metadata.FunctionDependencyDeclaration;
import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.operator.ParametricImplementationsGroup;
import io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata;
import io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType;
import io.prestosql.operator.aggregation.state.StateCompiler;
import io.prestosql.operator.annotations.ImplementationDependency;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.function.AccumulatorStateFactory;
import io.prestosql.spi.function.AccumulatorStateSerializer;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import java.lang.invoke.MethodHandle;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.metadata.FunctionKind.AGGREGATE;
import static io.prestosql.operator.ParametricFunctionHelpers.bindDependencies;
import static io.prestosql.operator.aggregation.AggregationUtils.generateAggregationName;
import static io.prestosql.operator.aggregation.state.StateCompiler.generateStateSerializer;
import static io.prestosql.operator.aggregation.state.StateCompiler.getSerializedType;
import static io.prestosql.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_CALL;
import static io.prestosql.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
/**
 * A {@link SqlAggregationFunction} backed by a group of {@link AggregationImplementation}s.
 * <p>
 * When specialized for a bound signature, the implementation to use is selected (an exact-signature
 * implementation if one exists, otherwise the single generic implementation whose types are
 * assignable). Its method handles are bound to the resolved function dependencies, and its state
 * class is compiled into an {@link AccumulatorStateSerializer} and {@link AccumulatorStateFactory}.
 */
public class ParametricAggregation
        extends SqlAggregationFunction
{
    private final ParametricImplementationsGroup<AggregationImplementation> implementations;

    /**
     * @param signature the declared signature of the aggregation
     * @param details header metadata (hidden flag, description, decomposability, order sensitivity)
     * @param implementations the candidate implementations; must report themselves as nullable
     * @param deprecated deprecation flag passed through to the {@link FunctionMetadata}
     * @throws NullPointerException if {@code details} or {@code implementations} is null
     * @throws IllegalArgumentException if {@code implementations} is not nullable
     */
    public ParametricAggregation(
            Signature signature,
            AggregationHeader details,
            ParametricImplementationsGroup<AggregationImplementation> implementations,
            boolean deprecated)
    {
        // details and implementations are dereferenced while the super(...) arguments are built,
        // so a null check placed after super(...) could never fire. The checks live in
        // createFunctionMetadata instead: it is the first argument and is evaluated first.
        super(
                createFunctionMetadata(signature, details, implementations, deprecated),
                details.isDecomposable(),
                details.isOrderSensitive());
        checkArgument(implementations.isNullable(), "currently aggregates are required to be nullable");
        this.implementations = implementations;
    }

    /**
     * Validates the constructor arguments and builds the {@link FunctionMetadata} handed to the
     * superclass constructor.
     */
    private static FunctionMetadata createFunctionMetadata(
            Signature signature,
            AggregationHeader details,
            ParametricImplementationsGroup<AggregationImplementation> implementations,
            boolean deprecated)
    {
        requireNonNull(details, "details is null");
        requireNonNull(implementations, "implementations is null");
        return new FunctionMetadata(
                signature,
                true,
                implementations.getArgumentDefinitions(),
                details.isHidden(),
                true,
                details.getDescription().orElse(""),
                AGGREGATE,
                deprecated);
    }

    /**
     * Declares the union of the dependencies of every implementation (exact, specialized and
     * generic). The implementation actually used is only chosen during specialization.
     */
    @Override
    public FunctionDependencyDeclaration getFunctionDependencies()
    {
        FunctionDependencyDeclarationBuilder builder = FunctionDependencyDeclaration.builder();
        declareDependencies(builder, implementations.getExactImplementations().values());
        declareDependencies(builder, implementations.getSpecializedImplementations());
        declareDependencies(builder, implementations.getGenericImplementations());
        return builder.build();
    }

    /**
     * Adds the input, remove-input, combine and output dependencies of each implementation to
     * {@code builder}.
     * <p>
     * Remove-input dependencies are bound in {@link #specialize} from the supplied
     * {@link FunctionDependencies} just like the others, so they must be declared here as well.
     */
    private static void declareDependencies(FunctionDependencyDeclarationBuilder builder, Collection<AggregationImplementation> implementations)
    {
        for (AggregationImplementation implementation : implementations) {
            declareAll(builder, implementation.getInputDependencies());
            declareAll(builder, implementation.getRemoveInputDependencies());
            declareAll(builder, implementation.getCombineDependencies());
            declareAll(builder, implementation.getOutputDependencies());
        }
    }

    /**
     * Declares every dependency in {@code dependencies} on {@code builder}.
     */
    private static void declareAll(FunctionDependencyDeclarationBuilder builder, Iterable<? extends ImplementationDependency> dependencies)
    {
        for (ImplementationDependency dependency : dependencies) {
            dependency.declareDependencies(builder);
        }
    }

    /**
     * Returns a single intermediate type: the serialized type of the matching implementation's
     * state class.
     *
     * @throws PrestoException if no implementation, or more than one generic implementation,
     *         matches the bound signature
     */
    @Override
    public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
    {
        // Find implementation matching arguments
        AggregationImplementation concreteImplementation = findMatchingImplementation(functionBinding.getBoundSignature());

        // Use state compiler to extract intermediate types
        Type serializedType = getSerializedType(concreteImplementation.getStateClass());
        return ImmutableList.of(serializedType.getTypeSignature());
    }

    /**
     * Builds an {@link InternalAggregationFunction} for the bound signature. It selects the
     * matching implementation, compiles its state serializer and factory, and binds its
     * input/remove-input/combine/output method handles to the resolved dependencies.
     *
     * @throws PrestoException if no implementation, or more than one generic implementation,
     *         matches the bound signature
     */
    @Override
    public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
    {
        // Declared (unbound) signature; only its name is used below
        Signature signature = getFunctionMetadata().getSignature();
        BoundSignature boundSignature = functionBinding.getBoundSignature();

        // Find implementation matching arguments
        AggregationImplementation concreteImplementation = findMatchingImplementation(boundSignature);

        // Concrete argument and return types come from the bound signature
        List<Type> inputTypes = boundSignature.getArgumentTypes();
        Type outputType = boundSignature.getReturnType();

        // Create classloader for additional aggregation dependencies
        Class<?> definitionClass = concreteImplementation.getDefinitionClass();
        DynamicClassLoader classLoader = new DynamicClassLoader(definitionClass.getClassLoader(), getClass().getClassLoader());

        // Build state factory and serializer
        Class<?> stateClass = concreteImplementation.getStateClass();
        AccumulatorStateSerializer<?> stateSerializer = generateStateSerializer(stateClass, classLoader);
        AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader);

        // Bind provided dependencies to aggregation method handlers
        MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), functionBinding, functionDependencies);
        Optional<MethodHandle> removeInputHandle = concreteImplementation.getRemoveInputFunction().map(
                removeInputFunction -> bindDependencies(removeInputFunction, concreteImplementation.getRemoveInputDependencies(), functionBinding, functionDependencies));
        MethodHandle combineHandle = bindDependencies(concreteImplementation.getCombineFunction(), concreteImplementation.getCombineDependencies(), functionBinding, functionDependencies);
        MethodHandle outputHandle = bindDependencies(concreteImplementation.getOutputFunction(), concreteImplementation.getOutputDependencies(), functionBinding, functionDependencies);

        // Build metadata of input parameters
        List<ParameterMetadata> parametersMetadata = buildParameterMetadata(concreteImplementation.getInputParameterMetadataTypes(), inputTypes);

        // Generate Aggregation name
        String aggregationName = generateAggregationName(signature.getName(), outputType.getTypeSignature(), signaturesFromTypes(inputTypes));

        // Collect all collected data in Metadata
        AggregationMetadata aggregationMetadata = new AggregationMetadata(
                aggregationName,
                parametersMetadata,
                inputHandle,
                removeInputHandle,
                combineHandle,
                outputHandle,
                ImmutableList.of(new AccumulatorStateDescriptor(
                        stateClass,
                        stateSerializer,
                        stateFactory)),
                outputType);

        // Create specialized InternalAggregationFunction for Presto
        return new InternalAggregationFunction(
                signature.getName(),
                inputTypes,
                ImmutableList.of(stateSerializer.getSerializedType()),
                outputType,
                new LazyAccumulatorFactoryBinder(aggregationMetadata, classLoader));
    }

    @VisibleForTesting
    public ParametricImplementationsGroup<AggregationImplementation> getImplementations()
    {
        return implementations;
    }

    /**
     * Selects the implementation to use for {@code boundSignature}. An exact-signature
     * implementation wins outright. Otherwise exactly one generic implementation must report its
     * types as assignable.
     *
     * @throws PrestoException with {@code AMBIGUOUS_FUNCTION_CALL} if several generic
     *         implementations match, or {@code FUNCTION_IMPLEMENTATION_MISSING} if none does
     */
    private AggregationImplementation findMatchingImplementation(BoundSignature boundSignature)
    {
        AggregationImplementation exactImplementation = implementations.getExactImplementations().get(boundSignature.toSignature());
        if (exactImplementation != null) {
            return exactImplementation;
        }

        Optional<AggregationImplementation> foundImplementation = Optional.empty();
        for (AggregationImplementation candidate : implementations.getGenericImplementations()) {
            if (candidate.areTypesAssignable(boundSignature)) {
                if (foundImplementation.isPresent()) {
                    throw new PrestoException(AMBIGUOUS_FUNCTION_CALL, format("Ambiguous function call (%s) for %s", boundSignature, getFunctionMetadata().getSignature()));
                }
                foundImplementation = Optional.of(candidate);
            }
        }

        return foundImplementation.orElseThrow(() ->
                new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("Unsupported type parameters (%s) for %s", boundSignature, getFunctionMetadata().getSignature())));
    }

    /**
     * Maps each type to its {@link TypeSignature}, preserving order.
     */
    private static List<TypeSignature> signaturesFromTypes(List<Type> types)
    {
        return types
                .stream()
                .map(Type::getTypeSignature)
                .collect(toImmutableList());
    }

    /**
     * Builds one {@link ParameterMetadata} per declared parameter type. Channel parameters
     * consume the next entry of {@code inputTypes} in order. STATE and BLOCK_INDEX parameters
     * carry no type.
     *
     * @throws IllegalArgumentException for a parameter type this method does not know how to map,
     *         rather than silently dropping it and misaligning the remaining parameters
     */
    private static List<ParameterMetadata> buildParameterMetadata(List<ParameterType> parameterMetadataTypes, List<Type> inputTypes)
    {
        ImmutableList.Builder<ParameterMetadata> builder = ImmutableList.builder();
        int inputId = 0;

        for (ParameterType parameterMetadataType : parameterMetadataTypes) {
            switch (parameterMetadataType) {
                case STATE:
                case BLOCK_INDEX:
                    builder.add(new ParameterMetadata(parameterMetadataType));
                    break;
                case INPUT_CHANNEL:
                case BLOCK_INPUT_CHANNEL:
                case NULLABLE_BLOCK_INPUT_CHANNEL:
                    builder.add(new ParameterMetadata(parameterMetadataType, inputTypes.get(inputId++)));
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported aggregation parameter type: " + parameterMetadataType);
            }
        }

        return builder.build();
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy