All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.spi.function.AggregationImplementation Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.spi.function;

import io.trino.spi.Experimental;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;

@Experimental(eta = "2022-10-31")
public class AggregationImplementation
{
    private final MethodHandle inputFunction;
    private final Optional removeInputFunction;
    private final Optional combineFunction;
    private final MethodHandle outputFunction;
    private final List> accumulatorStateDescriptors;
    private final List> lambdaInterfaces;

    private AggregationImplementation(
            MethodHandle inputFunction,
            Optional removeInputFunction,
            Optional combineFunction,
            MethodHandle outputFunction,
            List> accumulatorStateDescriptors,
            List> lambdaInterfaces)
    {
        this.inputFunction = requireNonNull(inputFunction, "inputFunction is null");
        this.removeInputFunction = requireNonNull(removeInputFunction, "removeInputFunction is null");
        this.combineFunction = requireNonNull(combineFunction, "combineFunction is null");
        this.outputFunction = requireNonNull(outputFunction, "outputFunction is null");
        this.accumulatorStateDescriptors = requireNonNull(accumulatorStateDescriptors, "accumulatorStateDescriptors is null");
        this.lambdaInterfaces = List.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null"));
    }

    public MethodHandle getInputFunction()
    {
        return inputFunction;
    }

    public Optional getRemoveInputFunction()
    {
        return removeInputFunction;
    }

    public Optional getCombineFunction()
    {
        return combineFunction;
    }

    public MethodHandle getOutputFunction()
    {
        return outputFunction;
    }

    public List> getAccumulatorStateDescriptors()
    {
        return accumulatorStateDescriptors;
    }

    public List> getLambdaInterfaces()
    {
        return lambdaInterfaces;
    }

    public static class AccumulatorStateDescriptor
    {
        private final Class stateInterface;
        private final AccumulatorStateSerializer serializer;
        private final AccumulatorStateFactory factory;

        private AccumulatorStateDescriptor(Class stateInterface, AccumulatorStateSerializer serializer, AccumulatorStateFactory factory)
        {
            this.stateInterface = requireNonNull(stateInterface, "stateInterface is null");
            this.serializer = requireNonNull(serializer, "serializer is null");
            this.factory = requireNonNull(factory, "factory is null");
        }

        // this is only used to verify method interfaces
        public Class getStateInterface()
        {
            return stateInterface;
        }

        public AccumulatorStateSerializer getSerializer()
        {
            return serializer;
        }

        public AccumulatorStateFactory getFactory()
        {
            return factory;
        }

        public static  Builder builder(Class stateInterface)
        {
            return new Builder<>(stateInterface);
        }

        public static class Builder
        {
            private final Class stateInterface;
            private AccumulatorStateSerializer serializer;
            private AccumulatorStateFactory factory;

            private Builder(Class stateInterface)
            {
                this.stateInterface = requireNonNull(stateInterface, "stateInterface is null");
            }

            public Builder serializer(AccumulatorStateSerializer serializer)
            {
                this.serializer = serializer;
                return this;
            }

            public Builder factory(AccumulatorStateFactory factory)
            {
                this.factory = factory;
                return this;
            }

            public AccumulatorStateDescriptor build()
            {
                return new AccumulatorStateDescriptor<>(stateInterface, serializer, factory);
            }
        }
    }

    public static Builder builder()
    {
        return new Builder();
    }

    public static class Builder
    {
        private MethodHandle inputFunction;
        private Optional removeInputFunction = Optional.empty();
        private Optional combineFunction = Optional.empty();
        private MethodHandle outputFunction;
        private List> accumulatorStateDescriptors = new ArrayList<>();
        private List> lambdaInterfaces = List.of();

        private Builder() {}

        public Builder inputFunction(MethodHandle inputFunction)
        {
            this.inputFunction = requireNonNull(inputFunction, "inputFunction is null");
            return this;
        }

        public Builder removeInputFunction(MethodHandle removeInputFunction)
        {
            this.removeInputFunction = Optional.of(requireNonNull(removeInputFunction, "removeInputFunction is null"));
            return this;
        }

        public Builder combineFunction(MethodHandle combineFunction)
        {
            this.combineFunction = Optional.of(requireNonNull(combineFunction, "combineFunction is null"));
            return this;
        }

        public Builder outputFunction(MethodHandle outputFunction)
        {
            this.outputFunction = requireNonNull(outputFunction, "outputFunction is null");
            return this;
        }

        public  Builder accumulatorStateDescriptor(Class stateInterface, AccumulatorStateSerializer serializer, AccumulatorStateFactory factory)
        {
            this.accumulatorStateDescriptors.add(AccumulatorStateDescriptor.builder(stateInterface)
                    .serializer(serializer)
                    .factory(factory)
                    .build());
            return this;
        }

        public Builder accumulatorStateDescriptors(List> accumulatorStateDescriptors)
        {
            requireNonNull(accumulatorStateDescriptors, "accumulatorStateDescriptors is null");

            this.accumulatorStateDescriptors = new ArrayList<>();
            this.accumulatorStateDescriptors.addAll(accumulatorStateDescriptors);
            return this;
        }

        public Builder lambdaInterfaces(Class... lambdaInterfaces)
        {
            return lambdaInterfaces(List.of(lambdaInterfaces));
        }

        public Builder lambdaInterfaces(List> lambdaInterfaces)
        {
            this.lambdaInterfaces = List.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null"));
            return this;
        }

        public AggregationImplementation build()
        {
            return new AggregationImplementation(
                    inputFunction,
                    removeInputFunction,
                    combineFunction,
                    outputFunction,
                    accumulatorStateDescriptors,
                    lambdaInterfaces);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy