All downloads are free. Search and download functionality uses the official Maven repository.

com.facebook.presto.jdbc.internal.spi.function.aggregation.AggregationMetadata Maven / Gradle / Ivy

The newest version!
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.jdbc.internal.spi.function.aggregation;

import com.facebook.presto.jdbc.internal.common.block.Block;
import com.facebook.presto.jdbc.internal.common.block.BlockBuilder;
import com.facebook.presto.jdbc.internal.common.type.Type;
import com.facebook.presto.jdbc.internal.spi.function.AccumulatorStateFactory;
import com.facebook.presto.jdbc.internal.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.jdbc.internal.io.airlift.slice.Slice;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static com.facebook.presto.jdbc.internal.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static com.facebook.presto.jdbc.internal.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static com.facebook.presto.jdbc.internal.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL;
import static com.facebook.presto.jdbc.internal.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL;
import static com.facebook.presto.jdbc.internal.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.jdbc.internal.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.inputChannelParameterType;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class AggregationMetadata
{
    public static final Set> SUPPORTED_PARAMETER_TYPES = new HashSet<>(Arrays.asList(Block.class, long.class, double.class, boolean.class, Slice.class));

    private final String name;
    private final List valueInputMetadata;
    private final List lambdaInterfaces;
    private final MethodHandle inputFunction;
    private final MethodHandle combineFunction;
    private final MethodHandle outputFunction;
    private final List accumulatorStateDescriptors;
    private final Type outputType;

    public AggregationMetadata(
            String name,
            List valueInputMetadata,
            MethodHandle inputFunction,
            MethodHandle combineFunction,
            MethodHandle outputFunction,
            List accumulatorStateDescriptors,
            Type outputType)
    {
        this(
                name,
                valueInputMetadata,
                inputFunction,
                combineFunction,
                outputFunction,
                accumulatorStateDescriptors,
                outputType,
                Collections.emptyList());
    }

    public AggregationMetadata(
            String name,
            List valueInputMetadata,
            MethodHandle inputFunction,
            MethodHandle combineFunction,
            MethodHandle outputFunction,
            List accumulatorStateDescriptors,
            Type outputType,
            List lambdaInterfaces)
    {
        this.outputType = requireNonNull(outputType);
        this.valueInputMetadata = Collections.unmodifiableList(new ArrayList<>(requireNonNull(valueInputMetadata, "valueInputMetadata is null")));
        this.name = requireNonNull(name, "name is null");
        this.inputFunction = requireNonNull(inputFunction, "inputFunction is null");
        this.combineFunction = requireNonNull(combineFunction, "combineFunction is null");
        this.outputFunction = requireNonNull(outputFunction, "outputFunction is null");
        this.accumulatorStateDescriptors = requireNonNull(accumulatorStateDescriptors, "accumulatorStateDescriptors is null");
        this.lambdaInterfaces = Collections.unmodifiableList(new ArrayList<>(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null")));

        verifyInputFunctionSignature(inputFunction, valueInputMetadata, lambdaInterfaces, accumulatorStateDescriptors);
        verifyCombineFunction(combineFunction, lambdaInterfaces, accumulatorStateDescriptors);
        verifyExactOutputFunction(outputFunction, accumulatorStateDescriptors);
    }

    public Type getOutputType()
    {
        return outputType;
    }

    public List getValueInputMetadata()
    {
        return valueInputMetadata;
    }

    public List getLambdaInterfaces()
    {
        return lambdaInterfaces;
    }

    public String getName()
    {
        return name;
    }

    public MethodHandle getInputFunction()
    {
        return inputFunction;
    }

    public MethodHandle getCombineFunction()
    {
        return combineFunction;
    }

    public MethodHandle getOutputFunction()
    {
        return outputFunction;
    }

    public List getAccumulatorStateDescriptors()
    {
        return accumulatorStateDescriptors;
    }

    private static void verifyInputFunctionSignature(MethodHandle method, List dataChannelMetadata, List lambdaInterfaces, List stateDescriptors)
    {
        Class[] parameters = method.type().parameterArray();
        checkArgument(parameters.length > 0, "Aggregation input function must have at least one parameter");
        checkArgument(
                parameters.length == dataChannelMetadata.size() + lambdaInterfaces.size(),
                format("Number of parameters in input function (%d) must be the total number of data channels and lambda channels (%d)",
                        parameters.length,
                        dataChannelMetadata.size() + lambdaInterfaces.size()));
        checkArgument(dataChannelMetadata.stream().filter(m -> m.getParameterType() == STATE).count() == stateDescriptors.size(), "Number of state parameter in input function must be the same as size of stateDescriptors");
        checkArgument(dataChannelMetadata.get(0).getParameterType() == STATE, "First parameter must be state");

        // verify data channels
        int stateIndex = 0;
        for (int i = 0; i < dataChannelMetadata.size(); i++) {
            ParameterMetadata metadata = dataChannelMetadata.get(i);
            switch (metadata.getParameterType()) {
                case STATE:
                    checkArgument(stateDescriptors.get(stateIndex).getStateInterface() == parameters[i], format("State argument must be of type %s, found %s", stateDescriptors.get(stateIndex).getStateInterface(), parameters[i]));
                    stateIndex++;
                    break;
                case BLOCK_INPUT_CHANNEL:
                case NULLABLE_BLOCK_INPUT_CHANNEL:
                    checkArgument(parameters[i] == Block.class, "Parameter must be Block if it has @BlockPosition");
                    break;
                case INPUT_CHANNEL:
                    checkArgument(SUPPORTED_PARAMETER_TYPES.contains(parameters[i]), "Unsupported type: %s", parameters[i].getSimpleName());
                    verifyMethodParameterType(method, i, metadata.getSqlType().getJavaType(), metadata.getSqlType().getDisplayName());
                    break;
                case BLOCK_INDEX:
                    checkArgument(parameters[i] == int.class, "Block index parameter must be an int");
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported parameter: " + metadata.getParameterType());
            }
        }
        checkArgument(stateIndex == stateDescriptors.size(), format("Input function only has %d states, expected: %d.", stateIndex, stateDescriptors.size()));

        // verify lambda channels
        for (int i = 0; i < lambdaInterfaces.size(); i++) {
            verifyMethodParameterType(method, i + dataChannelMetadata.size(), lambdaInterfaces.get(i), "function");
        }
    }

    private static void verifyCombineFunction(MethodHandle method, List lambdaInterfaces, List stateDescriptors)
    {
        Class[] parameterTypes = method.type().parameterArray();
        checkArgument(
                parameterTypes.length == stateDescriptors.size() * 2 + lambdaInterfaces.size(),
                "Number of arguments for combine function must be 2 times the size of states plus number of lambda channels.");

        for (int i = 0; i < stateDescriptors.size() * 2; i++) {
            checkArgument(
                    parameterTypes[i].equals(stateDescriptors.get(i % stateDescriptors.size()).getStateInterface()),
                    format("Type for Parameter index %d is unexpected. Arguments for combine function must appear in the order of state1, state2, ..., otherState1, otherState2, ...", i));
        }

        for (int i = 0; i < lambdaInterfaces.size(); i++) {
            verifyMethodParameterType(method, i + stateDescriptors.size() * 2, lambdaInterfaces.get(i), "function");
        }
    }

    private static void verifyExactOutputFunction(MethodHandle method, List stateDescriptors)
    {
        if (method == null) {
            return;
        }
        Class[] parameterTypes = method.type().parameterArray();
        checkArgument(parameterTypes.length == stateDescriptors.size() + 1, "Number of arguments for output function must be exactly one plus than number of states.");
        for (int i = 0; i < stateDescriptors.size(); i++) {
            checkArgument(parameterTypes[i].equals(stateDescriptors.get(i).getStateInterface()), format("Type for Parameter index %d is unexpected", i));
        }
        checkArgument(Arrays.stream(parameterTypes).filter(type -> type.equals(BlockBuilder.class)).count() == 1, "Output function must take exactly one BlockBuilder parameter");
    }

    private static void verifyMethodParameterType(MethodHandle method, int index, Class javaType, String sqlTypeDisplayName)
    {
        checkArgument(method.type().parameterArray()[index] == javaType,
                "Expected method %s parameter %s type to be %s (%s)",
                method,
                index,
                javaType.getName(),
                sqlTypeDisplayName);
    }

    public static int countInputChannels(List metadatas)
    {
        int parameters = 0;
        for (ParameterMetadata metadata : metadatas) {
            if (metadata.getParameterType() == INPUT_CHANNEL ||
                    metadata.getParameterType() == BLOCK_INPUT_CHANNEL ||
                    metadata.getParameterType() == NULLABLE_BLOCK_INPUT_CHANNEL) {
                parameters++;
            }
        }

        return parameters;
    }

    public static class ParameterMetadata
    {
        private final ParameterType parameterType;
        private final Type sqlType;

        public ParameterMetadata(ParameterType parameterType)
        {
            this(parameterType, null);
        }

        public ParameterMetadata(ParameterType parameterType, Type sqlType)
        {
            checkArgument((sqlType == null) == (parameterType == BLOCK_INDEX || parameterType == STATE),
                    "sqlType must be provided only for input channels");
            this.parameterType = parameterType;
            this.sqlType = sqlType;
        }

        public static ParameterMetadata fromSqlType(Type sqlType, boolean isBlock, boolean isNullable, String methodName)
        {
            return new ParameterMetadata(inputChannelParameterType(isNullable, isBlock, methodName), sqlType);
        }

        public static ParameterMetadata forBlockIndexParameter()
        {
            return new ParameterMetadata(BLOCK_INDEX);
        }

        public static ParameterMetadata forStateParameter()
        {
            return new ParameterMetadata(STATE);
        }

        public ParameterType getParameterType()
        {
            return parameterType;
        }

        public Type getSqlType()
        {
            return sqlType;
        }

        public enum ParameterType
        {
            INPUT_CHANNEL,
            BLOCK_INPUT_CHANNEL,
            NULLABLE_BLOCK_INPUT_CHANNEL,
            BLOCK_INDEX,
            STATE;

            public static ParameterType inputChannelParameterType(boolean isNullable, boolean isBlock, String methodName)
            {
                if (isBlock) {
                    if (isNullable) {
                        return NULLABLE_BLOCK_INPUT_CHANNEL;
                    }
                    else {
                        return BLOCK_INPUT_CHANNEL;
                    }
                }
                else {
                    if (isNullable) {
                        throw new IllegalArgumentException(methodName + " contains a parameter with @NullablePosition that is not @BlockPosition");
                    }
                    else {
                        return INPUT_CHANNEL;
                    }
                }
            }
        }
    }

    public static class AccumulatorStateDescriptor
    {
        private final Class stateInterface;
        private final AccumulatorStateSerializer serializer;
        private final AccumulatorStateFactory factory;

        public AccumulatorStateDescriptor(Class stateInterface, AccumulatorStateSerializer serializer, AccumulatorStateFactory factory)
        {
            this.stateInterface = requireNonNull(stateInterface, "stateInterface is null");
            this.serializer = requireNonNull(serializer, "serializer is null");
            this.factory = requireNonNull(factory, "factory is null");
        }

        public Class getStateInterface()
        {
            return stateInterface;
        }

        public AccumulatorStateSerializer getSerializer()
        {
            return serializer;
        }

        public AccumulatorStateFactory getFactory()
        {
            return factory;
        }
    }

    private static void checkArgument(boolean condition, String message, Object... messageArgs)
    {
        if (!condition) {
            throw new IllegalArgumentException(format(message, messageArgs));
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy