All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.prestosql.operator.aggregation.ParametricAggregation Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.prestosql.metadata.BoundSignature;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionDependencies;
import io.prestosql.metadata.FunctionDependencyDeclaration;
import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.operator.ParametricImplementationsGroup;
import io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata;
import io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType;
import io.prestosql.operator.aggregation.state.StateCompiler;
import io.prestosql.operator.annotations.ImplementationDependency;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.function.AccumulatorStateFactory;
import io.prestosql.spi.function.AccumulatorStateSerializer;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;

import java.lang.invoke.MethodHandle;
import java.util.Collection;
import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.prestosql.metadata.FunctionKind.AGGREGATE;
import static io.prestosql.operator.ParametricFunctionHelpers.bindDependencies;
import static io.prestosql.operator.aggregation.AggregationUtils.generateAggregationName;
import static io.prestosql.operator.aggregation.state.StateCompiler.generateStateSerializer;
import static io.prestosql.operator.aggregation.state.StateCompiler.getSerializedType;
import static io.prestosql.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_CALL;
import static io.prestosql.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * Aggregation function backed by a {@link ParametricImplementationsGroup} of
 * {@link AggregationImplementation}s.
 * <p>
 * For a given bound signature the matching implementation is looked up (exact
 * implementations first, then generic ones) and turned into an
 * {@link InternalAggregationFunction} by {@link #specialize}.
 */
public class ParametricAggregation
        extends SqlAggregationFunction
{
    private final ParametricImplementationsGroup<AggregationImplementation> implementations;

    /**
     * @param signature declared signature of the aggregation
     * @param details header supplying the hidden flag, description, decomposability and order sensitivity
     * @param implementations group of candidate implementations; must report itself as nullable
     * @param deprecated forwarded to the {@link FunctionMetadata}
     * @throws NullPointerException if {@code details} or {@code implementations} is null
     * @throws IllegalArgumentException if {@code implementations} is not nullable
     */
    public ParametricAggregation(
            Signature signature,
            AggregationHeader details,
            ParametricImplementationsGroup<AggregationImplementation> implementations,
            boolean deprecated)
    {
        // super(...) has to be the first statement, so the null checks are done inline at the
        // first dereference of each argument; checking afterwards could never fire.
        super(
                new FunctionMetadata(
                        signature,
                        true,
                        requireNonNull(implementations, "implementations is null").getArgumentDefinitions(),
                        requireNonNull(details, "details is null").isHidden(),
                        true,
                        details.getDescription().orElse(""),
                        AGGREGATE,
                        deprecated),
                details.isDecomposable(),
                details.isOrderSensitive());
        checkArgument(implementations.isNullable(), "currently aggregates are required to be nullable");
        this.implementations = implementations;
    }

    /**
     * Collects the dependencies declared by every exact, specialized and generic implementation.
     */
    @Override
    public FunctionDependencyDeclaration getFunctionDependencies()
    {
        FunctionDependencyDeclarationBuilder builder = FunctionDependencyDeclaration.builder();
        declareDependencies(builder, implementations.getExactImplementations().values());
        declareDependencies(builder, implementations.getSpecializedImplementations());
        declareDependencies(builder, implementations.getGenericImplementations());
        return builder.build();
    }

    /**
     * Registers the input, remove-input, combine and output dependencies of each implementation
     * with {@code builder}.
     */
    private static void declareDependencies(FunctionDependencyDeclarationBuilder builder, Collection<AggregationImplementation> implementations)
    {
        for (AggregationImplementation implementation : implementations) {
            for (ImplementationDependency dependency : implementation.getInputDependencies()) {
                dependency.declareDependencies(builder);
            }
            // specialize() binds the remove-input dependencies as well, so they must be declared here too
            for (ImplementationDependency dependency : implementation.getRemoveInputDependencies()) {
                dependency.declareDependencies(builder);
            }
            for (ImplementationDependency dependency : implementation.getCombineDependencies()) {
                dependency.declareDependencies(builder);
            }
            for (ImplementationDependency dependency : implementation.getOutputDependencies()) {
                dependency.declareDependencies(builder);
            }
        }
    }

    /**
     * Returns a single-element list holding the type signature of the serialized form of the
     * matching implementation's state class.
     *
     * @throws PrestoException if no implementation, or more than one generic implementation, matches
     */
    @Override
    public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
    {
        // Find implementation matching arguments
        AggregationImplementation concreteImplementation = findMatchingImplementation(functionBinding.getBoundSignature());

        // Use state compiler to extract intermediate types
        Type serializedType = getSerializedType(concreteImplementation.getStateClass());
        return ImmutableList.of(serializedType.getTypeSignature());
    }

    /**
     * Builds the {@link InternalAggregationFunction} for the given binding: selects the matching
     * implementation, generates its state serializer and factory, binds the implementation's
     * dependencies into its method handles, and wraps everything in a lazily bound accumulator factory.
     *
     * @throws PrestoException if no implementation, or more than one generic implementation, matches
     */
    @Override
    public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
    {
        // Declared (unbound) signature; only its name is used below
        Signature signature = getFunctionMetadata().getSignature();

        // Find implementation matching arguments
        AggregationImplementation concreteImplementation = findMatchingImplementation(functionBinding.getBoundSignature());

        // Argument and return types come straight from the bound signature
        List<Type> inputTypes = functionBinding.getBoundSignature().getArgumentTypes();
        Type outputType = functionBinding.getBoundSignature().getReturnType();

        // Create classloader for additional aggregation dependencies
        Class<?> definitionClass = concreteImplementation.getDefinitionClass();
        DynamicClassLoader classLoader = new DynamicClassLoader(definitionClass.getClassLoader(), getClass().getClassLoader());

        // Build state factory and serializer
        Class<?> stateClass = concreteImplementation.getStateClass();
        AccumulatorStateSerializer<?> stateSerializer = generateStateSerializer(stateClass, classLoader);
        AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader);

        // Bind provided dependencies to aggregation method handles
        MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), functionBinding, functionDependencies);
        Optional<MethodHandle> removeInputHandle = concreteImplementation.getRemoveInputFunction().map(
                removeInputFunction -> bindDependencies(removeInputFunction, concreteImplementation.getRemoveInputDependencies(), functionBinding, functionDependencies));
        MethodHandle combineHandle = bindDependencies(concreteImplementation.getCombineFunction(), concreteImplementation.getCombineDependencies(), functionBinding, functionDependencies);
        MethodHandle outputHandle = bindDependencies(concreteImplementation.getOutputFunction(), concreteImplementation.getOutputDependencies(), functionBinding, functionDependencies);

        // Build metadata of input parameters
        List<ParameterMetadata> parametersMetadata = buildParameterMetadata(concreteImplementation.getInputParameterMetadataTypes(), inputTypes);

        // Generate aggregation name
        String aggregationName = generateAggregationName(signature.getName(), outputType.getTypeSignature(), signaturesFromTypes(inputTypes));

        // Gather everything collected above into the aggregation metadata
        AggregationMetadata aggregationMetadata = new AggregationMetadata(
                aggregationName,
                parametersMetadata,
                inputHandle,
                removeInputHandle,
                combineHandle,
                outputHandle,
                ImmutableList.of(new AccumulatorStateDescriptor(
                        stateClass,
                        stateSerializer,
                        stateFactory)),
                outputType);

        // Create the specialized InternalAggregationFunction
        return new InternalAggregationFunction(
                signature.getName(),
                inputTypes,
                ImmutableList.of(stateSerializer.getSerializedType()),
                outputType,
                new LazyAccumulatorFactoryBinder(aggregationMetadata, classLoader));
    }

    @VisibleForTesting
    public ParametricImplementationsGroup<AggregationImplementation> getImplementations()
    {
        return implementations;
    }

    /**
     * Picks the implementation for {@code boundSignature}.
     * <p>
     * An exact implementation keyed by the signature wins outright. Otherwise the generic
     * implementations are scanned and exactly one of them must accept the bound types.
     *
     * @throws PrestoException with {@code AMBIGUOUS_FUNCTION_CALL} if several generic implementations match,
     * or {@code FUNCTION_IMPLEMENTATION_MISSING} if none does
     */
    private AggregationImplementation findMatchingImplementation(BoundSignature boundSignature)
    {
        Signature signature = boundSignature.toSignature();
        Optional<AggregationImplementation> foundImplementation = Optional.empty();
        if (implementations.getExactImplementations().containsKey(signature)) {
            foundImplementation = Optional.of(implementations.getExactImplementations().get(signature));
        }
        else {
            for (AggregationImplementation candidate : implementations.getGenericImplementations()) {
                if (candidate.areTypesAssignable(boundSignature)) {
                    if (foundImplementation.isPresent()) {
                        throw new PrestoException(AMBIGUOUS_FUNCTION_CALL, format("Ambiguous function call (%s) for %s", boundSignature, getFunctionMetadata().getSignature()));
                    }
                    foundImplementation = Optional.of(candidate);
                }
            }
        }

        return foundImplementation.orElseThrow(() ->
                new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("Unsupported type parameters (%s) for %s", boundSignature, getFunctionMetadata().getSignature())));
    }

    /**
     * Maps each type to its type signature, preserving order.
     */
    private static List<TypeSignature> signaturesFromTypes(List<Type> types)
    {
        return types
                .stream()
                .map(Type::getTypeSignature)
                .collect(toImmutableList());
    }

    /**
     * Creates one {@link ParameterMetadata} per handled parameter kind, in order.
     * <p>
     * Channel-style parameters consume the next entry of {@code inputTypes}; {@code STATE} and
     * {@code BLOCK_INDEX} parameters carry no type.
     */
    private static List<ParameterMetadata> buildParameterMetadata(List<ParameterType> parameterMetadataTypes, List<Type> inputTypes)
    {
        ImmutableList.Builder<ParameterMetadata> builder = ImmutableList.builder();
        int inputId = 0;

        for (ParameterType parameterMetadataType : parameterMetadataTypes) {
            // There is no default branch: any kind not listed below adds no metadata entry
            switch (parameterMetadataType) {
                case STATE:
                case BLOCK_INDEX:
                    builder.add(new ParameterMetadata(parameterMetadataType));
                    break;
                case INPUT_CHANNEL:
                case BLOCK_INPUT_CHANNEL:
                case NULLABLE_BLOCK_INPUT_CHANNEL:
                    builder.add(new ParameterMetadata(parameterMetadataType, inputTypes.get(inputId++)));
                    break;
            }
        }

        return builder.build();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy