All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.operator.aggregation.ParametricAggregationImplementation Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.aggregation;

import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import io.trino.operator.ParametricImplementation;
import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails;
import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind;
import io.trino.operator.annotations.ImplementationDependency;
import io.trino.spi.block.ValueBlock;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.Signature;
import io.trino.spi.function.SqlNullable;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.function.WindowAccumulator;
import io.trino.spi.type.TypeSignature;
import io.trino.util.Reflection;

import java.lang.annotation.Annotation;
import java.lang.invoke.MethodHandle;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE;
import static io.trino.operator.annotations.FunctionsParserHelper.containsAnnotation;
import static io.trino.operator.annotations.FunctionsParserHelper.createTypeVariableConstraints;
import static io.trino.operator.annotations.FunctionsParserHelper.parseLiteralParameters;
import static io.trino.operator.annotations.FunctionsParserHelper.parseLongVariableConstraints;
import static io.trino.operator.annotations.ImplementationDependency.Factory.createDependency;
import static io.trino.operator.annotations.ImplementationDependency.getImplementationDependencyAnnotation;
import static io.trino.operator.annotations.ImplementationDependency.isImplementationDependencyAnnotation;
import static io.trino.operator.annotations.ImplementationDependency.validateImplementationDependencyAnnotation;
import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature;
import static io.trino.util.Reflection.methodHandle;
import static java.util.Objects.requireNonNull;

public class ParametricAggregationImplementation
        implements ParametricImplementation
{
    public record AggregateNativeContainerType(Class javaType, boolean isBlockPosition) {}

    private final Signature signature;

    private final Class definitionClass;
    private final MethodHandle inputFunction;
    private final MethodHandle outputFunction;
    private final Optional combineFunction;
    private final Optional> windowAccumulator;
    private final List argumentNativeContainerTypes;
    private final List inputDependencies;
    private final List combineDependencies;
    private final List outputDependencies;
    private final List inputParameterKinds;
    private final FunctionNullability functionNullability;

    private ParametricAggregationImplementation(
            Signature signature,
            Class definitionClass,
            MethodHandle inputFunction,
            MethodHandle outputFunction,
            Optional combineFunction,
            Optional> windowAccumulator,
            List argumentNativeContainerTypes,
            List inputDependencies,
            List combineDependencies,
            List outputDependencies,
            List inputParameterKinds)
    {
        this.signature = requireNonNull(signature, "signature cannot be null");
        this.definitionClass = requireNonNull(definitionClass, "definition class cannot be null");
        this.inputFunction = requireNonNull(inputFunction, "inputFunction cannot be null");
        this.outputFunction = requireNonNull(outputFunction, "outputFunction cannot be null");
        this.combineFunction = requireNonNull(combineFunction, "combineFunction cannot be null");
        this.windowAccumulator = requireNonNull(windowAccumulator, "windowAccumulator cannot be null");
        this.argumentNativeContainerTypes = requireNonNull(argumentNativeContainerTypes, "argumentNativeContainerTypes cannot be null");
        this.inputDependencies = requireNonNull(inputDependencies, "inputDependencies cannot be null");
        this.outputDependencies = requireNonNull(outputDependencies, "outputDependencies cannot be null");
        this.combineDependencies = requireNonNull(combineDependencies, "combineDependencies cannot be null");
        this.inputParameterKinds = requireNonNull(inputParameterKinds, "inputParameterKinds cannot be null");
        this.functionNullability = new FunctionNullability(
                true,
                inputParameterKinds.stream()
                        .filter(parameterType -> parameterType != BLOCK_INDEX && parameterType != STATE)
                        .map(NULLABLE_BLOCK_INPUT_CHANNEL::equals)
                        .collect(toImmutableList()));
    }

    @Override
    public Signature getSignature()
    {
        return signature;
    }

    @Override
    public boolean hasSpecializedTypeParameters()
    {
        return false;
    }

    @Override
    public FunctionNullability getFunctionNullability()
    {
        return functionNullability;
    }

    public Class getDefinitionClass()
    {
        return definitionClass;
    }

    public MethodHandle getInputFunction()
    {
        return inputFunction;
    }

    public MethodHandle getOutputFunction()
    {
        return outputFunction;
    }

    public Optional getCombineFunction()
    {
        return combineFunction;
    }

    public Optional> getWindowAccumulator()
    {
        return windowAccumulator;
    }

    public List getInputDependencies()
    {
        return inputDependencies;
    }

    public List getOutputDependencies()
    {
        return outputDependencies;
    }

    public List getCombineDependencies()
    {
        return combineDependencies;
    }

    public List getInputParameterKinds()
    {
        return inputParameterKinds;
    }

    public boolean areTypesAssignable(BoundSignature boundSignature)
    {
        checkState(argumentNativeContainerTypes.size() == boundSignature.getArgumentTypes().size(), "Number of argument assigned to AggregationImplementation is different than number parsed from annotations.");

        // TODO specialized functions variants support is missing here
        for (int i = 0; i < boundSignature.getArgumentTypes().size(); i++) {
            Class argumentType = boundSignature.getArgumentTypes().get(i).getJavaType();
            Class methodDeclaredType = argumentNativeContainerTypes.get(i).javaType();
            boolean isCurrentBlockPosition = argumentNativeContainerTypes.get(i).isBlockPosition();

            // block and position works for any type, but if block is annotated with SqlType nativeContainerType, then only types with the
            // specified container type match
            if (isCurrentBlockPosition && ValueBlock.class.isAssignableFrom(methodDeclaredType)) {
                continue;
            }
            if (methodDeclaredType.isAssignableFrom(argumentType)) {
                continue;
            }
            return false;
        }

        return true;
    }

    public static final class Parser
    {
        private final Class aggregationDefinition;
        private final MethodHandle inputHandle;
        private final MethodHandle outputHandle;
        private final Optional combineHandle;
        private final Optional> windowAccumulator;

        private final List argumentNativeContainerTypes;
        private final List inputDependencies;
        private final List combineDependencies;
        private final List outputDependencies;
        private final List inputParameterKinds;

        private final Signature.Builder signatureBuilder = Signature.builder();

        private final Set literalParameters;
        private final List typeParameters;

        private Parser(
                Class aggregationDefinition,
                List> stateDetails,
                Method inputFunction,
                Method outputFunction,
                Optional combineFunction,
                Optional> windowAccumulator)
        {
            // rewrite data passed directly
            this.aggregationDefinition = aggregationDefinition;

            // parse declared literal and type parameters
            // it is required to declare all literal and type parameters in input function
            literalParameters = parseLiteralParameters(inputFunction);
            typeParameters = Arrays.asList(inputFunction.getAnnotationsByType(TypeParameter.class));

            // parse dependencies
            inputDependencies = parseImplementationDependencies(inputFunction);
            outputDependencies = parseImplementationDependencies(outputFunction);
            combineDependencies = combineFunction.map(this::parseImplementationDependencies).orElse(ImmutableList.of());

            // parse input parameters
            inputParameterKinds = parseInputParameterKinds(inputFunction);

            // parse constraints
            parseLongVariableConstraints(inputFunction, signatureBuilder);
            List allDependencies =
                    Stream.of(
                            stateDetails.stream().map(AccumulatorStateDetails::getDependencies).flatMap(Collection::stream),
                            inputDependencies.stream(),
                            outputDependencies.stream(),
                            combineDependencies.stream())
                            .reduce(Stream::concat)
                            .orElseGet(Stream::empty)
                            .collect(toImmutableList());
            createTypeVariableConstraints(typeParameters, allDependencies)
                    .forEach(signatureBuilder::typeVariableConstraint);

            // parse native types of arguments
            argumentNativeContainerTypes = parseSignatureArgumentsTypes(inputFunction);

            // determine TypeSignatures of function declaration
            signatureBuilder.argumentTypes(getInputTypesSignatures(inputFunction));
            signatureBuilder.returnType(parseTypeSignature(outputFunction.getAnnotation(OutputFunction.class).value(), literalParameters));

            inputHandle = methodHandle(inputFunction);
            combineHandle = combineFunction.map(Reflection::methodHandle);
            outputHandle = methodHandle(outputFunction);

            this.windowAccumulator = windowAccumulator;
        }

        private ParametricAggregationImplementation get()
        {
            return new ParametricAggregationImplementation(
                    signatureBuilder.build(),
                    aggregationDefinition,
                    inputHandle,
                    outputHandle,
                    combineHandle,
                    windowAccumulator,
                    argumentNativeContainerTypes,
                    inputDependencies,
                    combineDependencies,
                    outputDependencies,
                    inputParameterKinds);
        }

        public static ParametricAggregationImplementation parseImplementation(
                Class aggregationDefinition,
                List> stateDetails,
                Method inputFunction,
                Method outputFunction,
                Optional combineFunction,
                Optional> windowAccumulator)
        {
            return new Parser(aggregationDefinition, stateDetails, inputFunction, outputFunction, combineFunction, windowAccumulator).get();
        }

        private static List parseInputParameterKinds(Method method)
        {
            ImmutableList.Builder builder = ImmutableList.builder();

            Annotation[][] annotations = method.getParameterAnnotations();
            String methodName = method.getDeclaringClass() + "." + method.getName();

            checkArgument(method.getParameterCount() > 0, "At least @AggregationState argument is required for each of aggregation functions.");

            int i = 0;
            if (annotations[0].length == 0) {
                // Backward compatibility - first argument without annotations is interpreted as State argument
                builder.add(STATE);
                i++;
            }

            for (; i < annotations.length; ++i) {
                Annotation baseTypeAnnotation = baseTypeAnnotation(annotations[i], methodName);
                if (isImplementationDependencyAnnotation(baseTypeAnnotation)) {
                    // Implementation dependencies are bound in specializing phase.
                    // For that reason there are omitted in parameter metadata, as they
                    // are no longer visible while processing aggregations.
                }
                else if (baseTypeAnnotation instanceof AggregationState) {
                    builder.add(STATE);
                }
                else if (baseTypeAnnotation instanceof SqlType) {
                    boolean isParameterBlock = isParameterBlock(annotations[i]);
                    boolean isParameterNullable = isParameterNullable(annotations[i]);
                    builder.add(getInputParameterKind(isParameterNullable, isParameterBlock, methodName));
                }
                else if (baseTypeAnnotation instanceof BlockIndex) {
                    builder.add(BLOCK_INDEX);
                }
                else {
                    throw new VerifyException("Unhandled annotation: " + baseTypeAnnotation);
                }
            }
            return builder.build();
        }

        static AggregationParameterKind getInputParameterKind(boolean isNullable, boolean isBlock, String methodName)
        {
            if (isBlock) {
                if (isNullable) {
                    return NULLABLE_BLOCK_INPUT_CHANNEL;
                }
                return BLOCK_INPUT_CHANNEL;
            }
            if (isNullable) {
                throw new IllegalArgumentException(methodName + " contains a parameter with @NullablePosition that is not @BlockPosition");
            }
            return INPUT_CHANNEL;
        }

        private static Annotation baseTypeAnnotation(Annotation[] annotations, String methodName)
        {
            List baseTypes = Arrays.stream(annotations)
                    .filter(annotation -> isAggregationMetaAnnotation(annotation) || annotation instanceof SqlType)
                    .collect(toImmutableList());

            checkArgument(baseTypes.size() == 1, "Parameter of %s must have exactly one of @SqlType, @BlockIndex", methodName);

            boolean nullable = isParameterNullable(annotations);
            boolean isBlock = isParameterBlock(annotations);

            Annotation annotation = baseTypes.getFirst();
            checkArgument((!isBlock && !nullable) || (annotation instanceof SqlType),
                    "%s contains a parameter with @BlockPosition and/or @NullablePosition that is not @SqlType", methodName);

            return annotation;
        }

        public static List parseSignatureArgumentsTypes(Method inputFunction)
        {
            ImmutableList.Builder builder = ImmutableList.builder();

            for (int i = 0; i < inputFunction.getParameterCount(); i++) {
                Class parameterType = inputFunction.getParameterTypes()[i];
                Annotation[] annotations = inputFunction.getParameterAnnotations()[i];

                // Skip injected parameters
                if (parameterType == ConnectorSession.class) {
                    continue;
                }

                if (containsAnnotation(annotations, Parser::isAggregationMetaAnnotation)) {
                    continue;
                }

                Optional> nativeContainerType = Arrays.stream(annotations)
                        .filter(SqlType.class::isInstance)
                        .map(SqlType.class::cast)
                        .findFirst()
                        .map(SqlType::nativeContainerType);
                // Note: this cannot be done as a chain due to strange generic type mismatches
                if (nativeContainerType.isPresent() && !nativeContainerType.get().equals(Object.class)) {
                    parameterType = nativeContainerType.get();
                }
                builder.add(new AggregateNativeContainerType(parameterType, isParameterBlock(annotations)));
            }

            return builder.build();
        }

        public List parseImplementationDependencies(Method inputFunction)
        {
            ImmutableList.Builder builder = ImmutableList.builder();

            for (Parameter parameter : inputFunction.getParameters()) {
                Class parameterType = parameter.getType();

                // Skip injected parameters
                if (parameterType == ConnectorSession.class) {
                    continue;
                }

                getImplementationDependencyAnnotation(parameter).ifPresent(annotation -> {
                    // check if only declared typeParameters and literalParameters are used
                    validateImplementationDependencyAnnotation(
                            inputFunction,
                            annotation,
                            typeParameters.stream()
                                    .map(TypeParameter::value)
                                    .collect(toImmutableSet()),
                            literalParameters);
                    builder.add(createDependency(annotation, literalParameters, parameter.getType()));
                });
            }
            return builder.build();
        }

        public static boolean isParameterNullable(Annotation[] annotations)
        {
            return containsAnnotation(annotations, annotation -> annotation instanceof SqlNullable);
        }

        public static boolean isParameterBlock(Annotation[] annotations)
        {
            return containsAnnotation(annotations, annotation -> annotation instanceof BlockPosition);
        }

        public List getInputTypesSignatures(Method inputFunction)
        {
            ImmutableList.Builder builder = ImmutableList.builder();

            Annotation[][] parameterAnnotations = inputFunction.getParameterAnnotations();
            for (Annotation[] annotations : parameterAnnotations) {
                for (Annotation annotation : annotations) {
                    if (annotation instanceof SqlType) {
                        String typeName = ((SqlType) annotation).value();
                        builder.add(parseTypeSignature(typeName, literalParameters));
                    }
                }
            }

            return builder.build();
        }

        private static boolean isAggregationMetaAnnotation(Annotation annotation)
        {
            return annotation instanceof BlockIndex || annotation instanceof AggregationState || isImplementationDependencyAnnotation(annotation);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy