All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.operator.scalar.annotations.ScalarFromAnnotationsParser Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.scalar.annotations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.metadata.SqlScalarFunction;
import io.trino.operator.ParametricImplementationsGroup;
import io.trino.operator.annotations.FunctionsParserHelper;
import io.trino.operator.scalar.ParametricScalar;
import io.trino.operator.scalar.ScalarHeader;
import io.trino.operator.scalar.annotations.ParametricScalarImplementation.SpecializedSignature;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.ScalarOperator;
import io.trino.spi.function.Signature;
import io.trino.spi.function.SqlType;

import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.operator.scalar.annotations.OperatorValidator.validateOperator;
import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static io.trino.util.Failures.checkCondition;
import static java.util.Objects.requireNonNull;

public final class ScalarFromAnnotationsParser
{
    private ScalarFromAnnotationsParser() {}

    public static List parseFunctionDefinition(Class clazz)
    {
        ImmutableList.Builder builder = ImmutableList.builder();
        boolean deprecated = clazz.getAnnotationsByType(Deprecated.class).length > 0;
        for (ScalarHeaderAndMethods scalar : findScalarsInFunctionDefinitionClass(clazz)) {
            builder.add(parseParametricScalar(scalar, FunctionsParserHelper.findConstructor(clazz), deprecated));
        }
        return builder.build();
    }

    public static List parseFunctionDefinitions(Class clazz)
    {
        ImmutableList.Builder builder = ImmutableList.builder();
        for (ScalarHeaderAndMethods methods : findScalarsInFunctionSetClass(clazz)) {
            boolean deprecated = methods.methods().iterator().next().getAnnotationsByType(Deprecated.class).length > 0;
            // Non-static function only makes sense in classes annotated with @ScalarFunction or @ScalarOperator.
            builder.add(parseParametricScalar(methods, FunctionsParserHelper.findConstructor(clazz), deprecated));
        }
        return builder.build();
    }

    private static List findScalarsInFunctionDefinitionClass(Class annotated)
    {
        ImmutableList.Builder builder = ImmutableList.builder();
        List classHeaders = ScalarHeader.fromAnnotatedElement(annotated);
        checkArgument(!classHeaders.isEmpty(), "Class [%s] that defines function must be annotated with @ScalarFunction or @ScalarOperator", annotated.getName());

        for (ScalarHeader header : classHeaders) {
            Set methods = FunctionsParserHelper.findPublicMethodsWithAnnotation(annotated, SqlType.class, ScalarFunction.class, ScalarOperator.class);
            checkCondition(!methods.isEmpty(), FUNCTION_IMPLEMENTATION_ERROR, "Parametric class [%s] does not have any annotated methods", annotated.getName());
            for (Method method : methods) {
                checkArgument(method.getAnnotation(ScalarFunction.class) == null, "Parametric class method [%s] is annotated with @ScalarFunction", method);
                checkArgument(method.getAnnotation(ScalarOperator.class) == null, "Parametric class method [%s] is annotated with @ScalarOperator", method);
            }
            builder.add(new ScalarHeaderAndMethods(header, methods));
        }

        return builder.build();
    }

    private static List findScalarsInFunctionSetClass(Class annotated)
    {
        ImmutableList.Builder builder = ImmutableList.builder();
        for (Method method : FunctionsParserHelper.findPublicMethodsWithAnnotation(annotated, SqlType.class, ScalarFunction.class, ScalarOperator.class)) {
            checkCondition((method.getAnnotation(ScalarFunction.class) != null) || (method.getAnnotation(ScalarOperator.class) != null),
                    FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] annotated with @SqlType is missing @ScalarFunction or @ScalarOperator", method);
            for (ScalarHeader header : ScalarHeader.fromAnnotatedElement(method)) {
                builder.add(new ScalarHeaderAndMethods(header, ImmutableSet.of(method)));
            }
        }
        List methods = builder.build();
        checkArgument(!methods.isEmpty(), "Class [%s] does not have any methods annotated with @ScalarFunction or @ScalarOperator", annotated.getName());
        return methods;
    }

    private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods scalar, Optional> constructor, boolean deprecated)
    {
        Map signatures = new HashMap<>();
        for (Method method : scalar.methods()) {
            ParametricScalarImplementation.Parser implementation = new ParametricScalarImplementation.Parser(method, constructor);
            if (!signatures.containsKey(implementation.getSpecializedSignature())) {
                ParametricScalarImplementation.Builder builder = new ParametricScalarImplementation.Builder(
                        implementation.getSignature(),
                        implementation.getArgumentNativeContainerTypes(),
                        implementation.getSpecializedTypeParameters(),
                        implementation.getReturnNativeContainerType());
                signatures.put(implementation.getSpecializedSignature(), builder);
                builder.addChoice(implementation.getChoice());
            }
            else {
                ParametricScalarImplementation.Builder builder = signatures.get(implementation.getSpecializedSignature());
                builder.addChoice(implementation.getChoice());
            }
        }

        ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder();
        for (ParametricScalarImplementation.Builder implementation : signatures.values()) {
            implementationsBuilder.addImplementation(implementation.build());
        }
        ParametricImplementationsGroup implementations = implementationsBuilder.build();
        Signature scalarSignature = implementations.getSignature();

        scalar.header().getOperatorType().ifPresent(operatorType ->
                validateOperator(operatorType, scalarSignature.getReturnType(), scalarSignature.getArgumentTypes()));

        return new ParametricScalar(scalarSignature, scalar.header(), implementations, deprecated);
    }

    private record ScalarHeaderAndMethods(ScalarHeader header, Set methods)
    {
        private ScalarHeaderAndMethods
        {
            requireNonNull(header);
            requireNonNull(methods);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy