All downloads are free. Search and download functionality uses the official Maven repository.

io.prestosql.operator.aggregation.AggregationImplementation Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.operator.aggregation;

import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import io.prestosql.metadata.BoundSignature;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.LongVariableConstraint;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.TypeVariableConstraint;
import io.prestosql.operator.ParametricImplementation;
import io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType;
import io.prestosql.operator.annotations.FunctionsParserHelper;
import io.prestosql.operator.annotations.ImplementationDependency;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.function.AggregationState;
import io.prestosql.spi.function.BlockIndex;
import io.prestosql.spi.function.BlockPosition;
import io.prestosql.spi.function.OutputFunction;
import io.prestosql.spi.function.SqlType;
import io.prestosql.spi.function.TypeParameter;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.util.Reflection;

import java.lang.annotation.Annotation;
import java.lang.invoke.MethodHandle;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL;
import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.inputChannelParameterType;
import static io.prestosql.operator.annotations.FunctionsParserHelper.containsAnnotation;
import static io.prestosql.operator.annotations.FunctionsParserHelper.createTypeVariableConstraints;
import static io.prestosql.operator.annotations.FunctionsParserHelper.parseLiteralParameters;
import static io.prestosql.operator.annotations.ImplementationDependency.Factory.createDependency;
import static io.prestosql.operator.annotations.ImplementationDependency.getImplementationDependencyAnnotation;
import static io.prestosql.operator.annotations.ImplementationDependency.isImplementationDependencyAnnotation;
import static io.prestosql.operator.annotations.ImplementationDependency.validateImplementationDependencyAnnotation;
import static io.prestosql.sql.analyzer.TypeSignatureTranslator.parseTypeSignature;
import static io.prestosql.util.Reflection.methodHandle;
import static java.util.Objects.requireNonNull;

/**
 * One parsed implementation of an annotation-based aggregation function.
 * <p>
 * Holds the method handles of the aggregation's input, optional remove-input, output and
 * combine functions, the {@link ImplementationDependency implementation dependencies} each of
 * those methods declares, the Java types declared for the SQL arguments of the input function,
 * and the per-parameter metadata of the input function. {@link Parser#parseImplementation}
 * builds instances from reflected methods.
 */
public class AggregationImplementation
        implements ParametricImplementation
{
    /**
     * Java type declared for one SQL argument of the input function, together with whether
     * that parameter is annotated with {@link BlockPosition}.
     */
    public static class AggregateNativeContainerType
    {
        private final Class<?> javaType;
        private final boolean isBlockPosition;

        public AggregateNativeContainerType(Class<?> javaType, boolean isBlockPosition)
        {
            this.javaType = javaType;
            this.isBlockPosition = isBlockPosition;
        }

        /**
         * Returns the Java type of the parameter as declared on the input method.
         */
        public Class<?> getJavaType()
        {
            return javaType;
        }

        /**
         * Returns {@code true} when the parameter carries a {@link BlockPosition} annotation.
         */
        public boolean isBlockPosition()
        {
            return isBlockPosition;
        }
    }

    private final Signature signature;

    private final Class<?> definitionClass;
    private final Class<?> stateClass;
    private final MethodHandle inputFunction;
    private final Optional<MethodHandle> removeInputFunction;
    private final MethodHandle outputFunction;
    private final MethodHandle combineFunction;
    private final List<AggregateNativeContainerType> argumentNativeContainerTypes;
    private final List<ImplementationDependency> inputDependencies;
    private final List<ImplementationDependency> removeInputDependencies;
    private final List<ImplementationDependency> combineDependencies;
    private final List<ImplementationDependency> outputDependencies;
    private final List<ParameterType> inputParameterMetadataTypes;
    private final List<FunctionArgumentDefinition> argumentDefinitions;

    /**
     * Creates an implementation from already-parsed parts.
     *
     * @param signature SQL signature of this implementation
     * @param definitionClass class that declares the aggregation methods
     * @param stateClass class of the aggregation state
     * @param inputFunction handle of the input method
     * @param removeInputFunction handle of the remove-input method, if one exists
     * @param outputFunction handle of the output method
     * @param combineFunction handle of the combine method
     * @param argumentNativeContainerTypes declared Java types of the SQL arguments of the input method
     * @param inputDependencies dependencies declared by the input method
     * @param removeInputDependencies dependencies declared by the remove-input method
     * @param combineDependencies dependencies declared by the combine method
     * @param outputDependencies dependencies declared by the output method
     * @param inputParameterMetadataTypes metadata kind of each non-dependency input-method parameter
     * @throws NullPointerException if any argument is {@code null}
     */
    public AggregationImplementation(
            Signature signature,
            Class<?> definitionClass,
            Class<?> stateClass,
            MethodHandle inputFunction,
            Optional<MethodHandle> removeInputFunction,
            MethodHandle outputFunction,
            MethodHandle combineFunction,
            List<AggregateNativeContainerType> argumentNativeContainerTypes,
            List<ImplementationDependency> inputDependencies,
            List<ImplementationDependency> removeInputDependencies,
            List<ImplementationDependency> combineDependencies,
            List<ImplementationDependency> outputDependencies,
            List<ParameterType> inputParameterMetadataTypes)
    {
        this.signature = requireNonNull(signature, "signature cannot be null");
        this.definitionClass = requireNonNull(definitionClass, "definition class cannot be null");
        this.stateClass = requireNonNull(stateClass, "stateClass cannot be null");
        this.inputFunction = requireNonNull(inputFunction, "inputFunction cannot be null");
        this.removeInputFunction = requireNonNull(removeInputFunction, "removeInputFunction cannot be null");
        this.outputFunction = requireNonNull(outputFunction, "outputFunction cannot be null");
        this.combineFunction = requireNonNull(combineFunction, "combineFunction cannot be null");
        this.argumentNativeContainerTypes = requireNonNull(argumentNativeContainerTypes, "argumentNativeContainerTypes cannot be null");
        this.inputDependencies = requireNonNull(inputDependencies, "inputDependencies cannot be null");
        this.removeInputDependencies = requireNonNull(removeInputDependencies, "removeInputDependencies cannot be null");
        this.outputDependencies = requireNonNull(outputDependencies, "outputDependencies cannot be null");
        this.combineDependencies = requireNonNull(combineDependencies, "combineDependencies cannot be null");
        this.inputParameterMetadataTypes = requireNonNull(inputParameterMetadataTypes, "inputParameterMetadataTypes cannot be null");
        // One argument definition per SQL input: state and block-index parameters are not SQL
        // arguments. An argument is nullable exactly when it is a nullable block input channel.
        this.argumentDefinitions = inputParameterMetadataTypes.stream()
                .filter(parameterType -> parameterType != BLOCK_INDEX && parameterType != STATE)
                .map(NULLABLE_BLOCK_INPUT_CHANNEL::equals)
                .map(FunctionArgumentDefinition::new)
                .collect(toImmutableList());
    }

    @Override
    public Signature getSignature()
    {
        return signature;
    }

    @Override
    public boolean hasSpecializedTypeParameters()
    {
        return false;
    }

    @Override
    public final boolean isNullable()
    {
        // for now all aggregation functions are considered nullable
        return true;
    }

    @Override
    public List<FunctionArgumentDefinition> getArgumentDefinitions()
    {
        return argumentDefinitions;
    }

    public Class<?> getDefinitionClass()
    {
        return definitionClass;
    }

    public Class<?> getStateClass()
    {
        return stateClass;
    }

    public MethodHandle getInputFunction()
    {
        return inputFunction;
    }

    public Optional<MethodHandle> getRemoveInputFunction()
    {
        return removeInputFunction;
    }

    public MethodHandle getOutputFunction()
    {
        return outputFunction;
    }

    public MethodHandle getCombineFunction()
    {
        return combineFunction;
    }

    public List<ImplementationDependency> getInputDependencies()
    {
        return inputDependencies;
    }

    public List<ImplementationDependency> getRemoveInputDependencies()
    {
        return removeInputDependencies;
    }

    public List<ImplementationDependency> getOutputDependencies()
    {
        return outputDependencies;
    }

    public List<ImplementationDependency> getCombineDependencies()
    {
        return combineDependencies;
    }

    public List<ParameterType> getInputParameterMetadataTypes()
    {
        return inputParameterMetadataTypes;
    }

    /**
     * Checks whether the argument types of a bound signature can be passed to this
     * implementation's input method.
     * <p>
     * For a {@link BlockPosition} parameter, the declared Java type must accept a {@link Block}.
     * For any other parameter, it must accept the bound argument's Java type.
     *
     * @param boundSignature signature whose argument types are checked
     * @return {@code true} if every argument is assignable to the corresponding declared parameter
     * @throws IllegalStateException if the bound signature has a different number of arguments
     *         than were parsed from the input method's annotations
     */
    public boolean areTypesAssignable(BoundSignature boundSignature)
    {
        checkState(argumentNativeContainerTypes.size() == boundSignature.getArgumentTypes().size(), "Number of argument assigned to AggregationImplementation is different than number parsed from annotations.");

        // TODO specialized functions variants support is missing here
        for (int i = 0; i < boundSignature.getArgumentTypes().size(); i++) {
            Class<?> argumentType = boundSignature.getArgumentTypes().get(i).getJavaType();
            AggregateNativeContainerType declared = argumentNativeContainerTypes.get(i);
            // A block-position parameter receives the whole Block rather than a single value
            Class<?> suppliedType = declared.isBlockPosition() ? Block.class : argumentType;
            if (!declared.getJavaType().isAssignableFrom(suppliedType)) {
                return false;
            }
        }

        return true;
    }

    /**
     * Builds an {@link AggregationImplementation} from the reflected input, optional
     * remove-input, output and combine methods of an aggregation definition class.
     */
    public static final class Parser
    {
        private final Class<?> aggregationDefinition;
        private final Class<?> stateClass;
        private final MethodHandle inputHandle;
        private final Optional<MethodHandle> removeInputHandle;
        private final MethodHandle outputHandle;
        private final MethodHandle combineHandle;
        private final List<AggregateNativeContainerType> argumentNativeContainerTypes;
        private final List<ImplementationDependency> inputDependencies;
        private final List<ImplementationDependency> removeInputDependencies;
        private final List<ImplementationDependency> combineDependencies;
        private final List<ImplementationDependency> outputDependencies;
        private final List<ParameterType> parameterMetadataTypes;

        private final List<LongVariableConstraint> longVariableConstraints;
        private final List<TypeVariableConstraint> typeVariableConstraints;
        private final List<TypeSignature> inputTypes;
        private final TypeSignature returnType;

        private final AggregationHeader header;
        private final Set<String> literalParameters;
        private final List<TypeParameter> typeParameters;

        private Parser(
                Class<?> aggregationDefinition,
                AggregationHeader header,
                Class<?> stateClass,
                Method inputFunction,
                Optional<Method> removeInputFunction,
                Method outputFunction,
                Method combineFunction)
        {
            // rewrite data passed directly
            this.aggregationDefinition = aggregationDefinition;
            this.header = header;
            this.stateClass = stateClass;

            // parse declared literal and type parameters
            // it is required to declare all literal and type parameters in input function.
            // These must be assigned before parseImplementationDependencies runs, which reads them.
            literalParameters = parseLiteralParameters(inputFunction);
            typeParameters = Arrays.asList(inputFunction.getAnnotationsByType(TypeParameter.class));

            // parse dependencies
            inputDependencies = parseImplementationDependencies(inputFunction);
            removeInputDependencies = removeInputFunction.map(this::parseImplementationDependencies).orElse(ImmutableList.of());
            outputDependencies = parseImplementationDependencies(outputFunction);
            combineDependencies = parseImplementationDependencies(combineFunction);

            // parse metadata types
            parameterMetadataTypes = parseParameterMetadataTypes(inputFunction);

            // parse constraints
            longVariableConstraints = FunctionsParserHelper.parseLongVariableConstraints(inputFunction);
            List<ImplementationDependency> allDependencies =
                    Stream.of(
                            inputDependencies.stream(),
                            removeInputDependencies.stream(),
                            outputDependencies.stream(),
                            combineDependencies.stream())
                            .reduce(Stream::concat)
                            .orElseGet(Stream::empty)
                            .collect(toImmutableList());
            typeVariableConstraints = createTypeVariableConstraints(typeParameters, allDependencies);

            // parse native types of arguments
            argumentNativeContainerTypes = parseSignatureArgumentsTypes(inputFunction);

            // determine TypeSignatures of function declaration
            inputTypes = getInputTypesSignatures(inputFunction);
            // NOTE(review): assumes outputFunction carries @OutputFunction; getAnnotation is not null-checked
            returnType = parseTypeSignature(outputFunction.getAnnotation(OutputFunction.class).value(), literalParameters);

            inputHandle = methodHandle(inputFunction);
            removeInputHandle = removeInputFunction.map(Reflection::methodHandle);
            combineHandle = methodHandle(combineFunction);
            outputHandle = methodHandle(outputFunction);
        }

        /**
         * Assembles the parsed parts into a signature and an {@link AggregationImplementation}.
         */
        private AggregationImplementation get()
        {
            Signature signature = new Signature(
                    header.getName(),
                    typeVariableConstraints,
                    longVariableConstraints,
                    returnType,
                    inputTypes,
                    false);

            return new AggregationImplementation(signature,
                    aggregationDefinition,
                    stateClass,
                    inputHandle,
                    removeInputHandle,
                    outputHandle,
                    combineHandle,
                    argumentNativeContainerTypes,
                    inputDependencies,
                    removeInputDependencies,
                    combineDependencies,
                    outputDependencies,
                    parameterMetadataTypes);
        }

        /**
         * Parses one aggregation implementation from its reflected methods.
         *
         * @param aggregationDefinition class declaring the aggregation methods
         * @param header parsed aggregation header; supplies the function name
         * @param stateClass class of the aggregation state
         * @param inputFunction input method; must declare all literal and type parameters
         * @param removeInputFunction optional remove-input method
         * @param outputFunction output method; its {@link OutputFunction} value is the return type
         * @param combineFunction combine method
         * @return the parsed implementation
         */
        public static AggregationImplementation parseImplementation(
                Class<?> aggregationDefinition,
                AggregationHeader header,
                Class<?> stateClass,
                Method inputFunction,
                Optional<Method> removeInputFunction,
                Method outputFunction,
                Method combineFunction)
        {
            return new Parser(aggregationDefinition, header, stateClass, inputFunction, removeInputFunction, outputFunction, combineFunction).get();
        }

        /**
         * Classifies each parameter of {@code method} as state, SQL input channel or block index.
         * Parameters annotated as implementation dependencies are omitted from the result.
         * For backward compatibility, a first parameter without any annotation is treated as state.
         *
         * @throws IllegalArgumentException if the method has no parameters, or a parameter does
         *         not carry exactly one base-type annotation
         */
        private static List<ParameterType> parseParameterMetadataTypes(Method method)
        {
            ImmutableList.Builder<ParameterType> builder = ImmutableList.builder();

            Annotation[][] annotations = method.getParameterAnnotations();
            String methodName = method.getDeclaringClass() + "." + method.getName();

            checkArgument(method.getParameterCount() > 0, "At least @AggregationState argument is required for each of aggregation functions.");

            int i = 0;
            if (annotations[0].length == 0) {
                // Backward compatibility - first argument without annotations is interpreted as State argument
                builder.add(STATE);
                i++;
            }

            for (; i < annotations.length; ++i) {
                Annotation baseTypeAnnotation = baseTypeAnnotation(annotations[i], methodName);
                if (isImplementationDependencyAnnotation(baseTypeAnnotation)) {
                    // Implementation dependencies are bound in specializing phase.
                    // For that reason they are omitted in parameter metadata, as they
                    // are no longer visible while processing aggregations.
                }
                else if (baseTypeAnnotation instanceof AggregationState) {
                    builder.add(STATE);
                }
                else if (baseTypeAnnotation instanceof SqlType) {
                    boolean isParameterBlock = isParameterBlock(annotations[i]);
                    boolean isParameterNullable = isParameterNullable(annotations[i]);
                    builder.add(inputChannelParameterType(isParameterNullable, isParameterBlock, methodName));
                }
                else if (baseTypeAnnotation instanceof BlockIndex) {
                    builder.add(BLOCK_INDEX);
                }
                else {
                    throw new VerifyException("Unhandled annotation: " + baseTypeAnnotation);
                }
            }
            return builder.build();
        }

        /**
         * Returns the single base-type annotation of a parameter: {@link SqlType},
         * {@link BlockIndex}, {@link AggregationState} or an implementation-dependency annotation.
         *
         * @throws IllegalArgumentException if there is not exactly one such annotation, or if
         *         {@link BlockPosition}/{@link NullablePosition} is used on a non-{@link SqlType} parameter
         */
        private static Annotation baseTypeAnnotation(Annotation[] annotations, String methodName)
        {
            List<Annotation> baseTypes = Arrays.stream(annotations)
                    .filter(annotation -> isAggregationMetaAnnotation(annotation) || annotation instanceof SqlType)
                    .collect(toImmutableList());

            checkArgument(baseTypes.size() == 1, "Parameter of %s must have exactly one of @SqlType, @BlockIndex", methodName);

            boolean nullable = isParameterNullable(annotations);
            boolean isBlock = isParameterBlock(annotations);

            Annotation annotation = baseTypes.get(0);
            checkArgument((!isBlock && !nullable) || (annotation instanceof SqlType),
                    "%s contains a parameter with @BlockPosition and/or @NullablePosition that is not @SqlType", methodName);

            return annotation;
        }

        /**
         * Collects the declared Java types of the SQL arguments of {@code inputFunction}, in
         * parameter order. {@link ConnectorSession} parameters and parameters carrying a state,
         * block-index or implementation-dependency annotation are skipped.
         * <p>
         * NOTE(review): a legacy first parameter with no annotations is treated as STATE by
         * {@code parseParameterMetadataTypes} but is not skipped here, so it would be counted as
         * an argument. Confirm whether callers rely on state parameters always being annotated.
         */
        public static List<AggregateNativeContainerType> parseSignatureArgumentsTypes(Method inputFunction)
        {
            ImmutableList.Builder<AggregateNativeContainerType> builder = ImmutableList.builder();

            // Both accessors return a fresh copy on every call, so fetch them once
            Class<?>[] parameterTypes = inputFunction.getParameterTypes();
            Annotation[][] parameterAnnotations = inputFunction.getParameterAnnotations();

            for (int i = 0; i < parameterTypes.length; i++) {
                Class<?> parameterType = parameterTypes[i];
                Annotation[] annotations = parameterAnnotations[i];

                // Skip injected parameters
                if (parameterType == ConnectorSession.class) {
                    continue;
                }

                if (containsAnnotation(annotations, Parser::isAggregationMetaAnnotation)) {
                    continue;
                }

                builder.add(new AggregateNativeContainerType(parameterType, isParameterBlock(annotations)));
            }

            return builder.build();
        }

        /**
         * Parses the implementation dependencies declared on the parameters of an aggregation
         * method (input, remove-input, output or combine). Each dependency is validated to use only
         * the type and literal parameters declared on the input function.
         * {@link ConnectorSession} parameters are skipped.
         */
        public List<ImplementationDependency> parseImplementationDependencies(Method inputFunction)
        {
            ImmutableList.Builder<ImplementationDependency> builder = ImmutableList.builder();

            // Names of the declared type parameters; identical for every parameter checked below
            Set<String> typeParameterNames = typeParameters.stream()
                    .map(TypeParameter::value)
                    .collect(toImmutableSet());

            for (Parameter parameter : inputFunction.getParameters()) {
                Class<?> parameterType = parameter.getType();

                // Skip injected parameters
                if (parameterType == ConnectorSession.class) {
                    continue;
                }

                getImplementationDependencyAnnotation(parameter).ifPresent(annotation -> {
                    // check if only declared typeParameters and literalParameters are used
                    validateImplementationDependencyAnnotation(
                            inputFunction,
                            annotation,
                            typeParameterNames,
                            literalParameters);
                    builder.add(createDependency(annotation, literalParameters, parameterType));
                });
            }
            return builder.build();
        }

        /**
         * Returns {@code true} if the annotations include {@link NullablePosition}.
         */
        public static boolean isParameterNullable(Annotation[] annotations)
        {
            return containsAnnotation(annotations, annotation -> annotation instanceof NullablePosition);
        }

        /**
         * Returns {@code true} if the annotations include {@link BlockPosition}.
         */
        public static boolean isParameterBlock(Annotation[] annotations)
        {
            return containsAnnotation(annotations, annotation -> annotation instanceof BlockPosition);
        }

        /**
         * Parses, in parameter order, the {@link TypeSignature} of every {@link SqlType}
         * annotation on the parameters of {@code inputFunction}, resolving the declared
         * literal parameters.
         */
        public List<TypeSignature> getInputTypesSignatures(Method inputFunction)
        {
            ImmutableList.Builder<TypeSignature> builder = ImmutableList.builder();

            Annotation[][] parameterAnnotations = inputFunction.getParameterAnnotations();
            for (Annotation[] annotations : parameterAnnotations) {
                for (Annotation annotation : annotations) {
                    if (annotation instanceof SqlType) {
                        String typeName = ((SqlType) annotation).value();
                        builder.add(parseTypeSignature(typeName, literalParameters));
                    }
                }
            }

            return builder.build();
        }

        /**
         * Returns the Java type of the first state parameter of {@code inputFunction},
         * located with {@link #findAggregationStateParamId(Method)}.
         */
        public static Class<?> findAggregationStateParamType(Method inputFunction)
        {
            return inputFunction.getParameterTypes()[findAggregationStateParamId(inputFunction)];
        }

        /**
         * Returns the index of the first state parameter of {@code method}.
         */
        public static int findAggregationStateParamId(Method method)
        {
            return findAggregationStateParamId(method, 0);
        }

        /**
         * Returns the parameter index of the {@code id}-th (0-based) parameter annotated with
         * {@link AggregationState}. If fewer such parameters exist, returns {@code id} itself.
         */
        public static int findAggregationStateParamId(Method method, int id)
        {
            int currentParamId = 0;
            int found = 0;
            for (Annotation[] annotations : method.getParameterAnnotations()) {
                for (Annotation annotation : annotations) {
                    if (annotation instanceof AggregationState) {
                        if (found++ == id) {
                            return currentParamId;
                        }
                    }
                }
                currentParamId++;
            }

            // backward compatibility @AggregationState annotation didn't exists before
            // some third party aggregates may assume that State will be id-th parameter
            return id;
        }

        /**
         * Returns {@code true} for annotations that mark a parameter as something other than a
         * SQL argument: block index, aggregation state or implementation dependency.
         */
        private static boolean isAggregationMetaAnnotation(Annotation annotation)
        {
            return annotation instanceof BlockIndex || annotation instanceof AggregationState || isImplementationDependencyAnnotation(annotation);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy