All downloads are free. The search and download features use the official Maven repository.

io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.scalar;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention;
import io.trino.spi.function.InvocationConvention.InvocationReturnConvention;
import io.trino.spi.function.ScalarFunctionAdapter;
import io.trino.spi.function.ScalarFunctionImplementation;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND;
import static java.lang.String.format;
import static java.util.Comparator.comparingInt;
import static java.util.Objects.requireNonNull;

public final class ChoicesSpecializedSqlScalarFunction
        implements SpecializedSqlScalarFunction
{
    private final BoundSignature boundSignature;
    private final List choices;

    public ChoicesSpecializedSqlScalarFunction(
            BoundSignature boundSignature,
            InvocationReturnConvention returnConvention,
            List argumentConventions,
            MethodHandle methodHandle)
    {
        this(boundSignature, returnConvention, argumentConventions, ImmutableList.of(), methodHandle, Optional.empty());
    }

    public ChoicesSpecializedSqlScalarFunction(
            BoundSignature boundSignature,
            InvocationReturnConvention returnConvention,
            List argumentConventions,
            MethodHandle methodHandle,
            Optional instanceFactory)
    {
        this(boundSignature, returnConvention, argumentConventions, ImmutableList.of(), methodHandle, instanceFactory);
    }

    public ChoicesSpecializedSqlScalarFunction(
            BoundSignature boundSignature,
            InvocationReturnConvention returnConvention,
            List argumentConventions,
            List> lambdaInterfaces,
            MethodHandle methodHandle,
            Optional instanceFactory)
    {
        this(boundSignature, ImmutableList.of(new ScalarImplementationChoice(returnConvention, argumentConventions, lambdaInterfaces, methodHandle, instanceFactory)));
    }

    /**
     * Creates a ScalarFunctionImplementation consisting of one or more choices.
     * 

* All choices must have the same SQL signature, and are equivalent in what they do. * The first choice is the default choice, which is the one used for legacy access methods. * The default choice must be usable under any context. (e.g. it must not use BLOCK_POSITION convention.) * * @param choices the list of choices, ordered from generic to specific */ public ChoicesSpecializedSqlScalarFunction(BoundSignature boundSignature, List choices) { this.boundSignature = boundSignature; checkArgument(!choices.isEmpty(), "choices is an empty list"); this.choices = ImmutableList.copyOf(choices); } @VisibleForTesting public List getChoices() { return choices; } @Override public ScalarFunctionImplementation getScalarFunctionImplementation(InvocationConvention invocationConvention) { List choices = new ArrayList<>(); for (ScalarImplementationChoice choice : this.choices) { InvocationConvention callingConvention = choice.getInvocationConvention(); if (ScalarFunctionAdapter.canAdapt(callingConvention, invocationConvention)) { choices.add(choice); } } if (choices.isEmpty()) { throw new TrinoException(FUNCTION_NOT_FOUND, format("Function implementation for (%s) cannot be adapted to convention (%s)", boundSignature, invocationConvention)); } ScalarImplementationChoice bestChoice = Collections.max(choices, comparingInt(ScalarImplementationChoice::getScore)); MethodHandle methodHandle = ScalarFunctionAdapter.adapt( bestChoice.getMethodHandle(), boundSignature.getReturnType(), boundSignature.getArgumentTypes(), bestChoice.getInvocationConvention(), invocationConvention); ScalarFunctionImplementation.Builder builder = ScalarFunctionImplementation.builder() .methodHandle(methodHandle); bestChoice.getInstanceFactory().ifPresent(builder::instanceFactory); builder.lambdaInterfaces(bestChoice.getLambdaInterfaces()); return builder.build(); } public static class ScalarImplementationChoice { private final MethodHandle methodHandle; private final Optional instanceFactory; private final 
InvocationConvention invocationConvention; private final List> lambdaInterfaces; private final int score; public ScalarImplementationChoice( InvocationReturnConvention returnConvention, List argumentConventions, List> lambdaInterfaces, MethodHandle methodHandle, Optional instanceFactory) { this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null"); if (instanceFactory.isPresent()) { Class instanceType = instanceFactory.get().type().returnType(); checkArgument(instanceFactory.get().type().parameterList().isEmpty(), "instanceFactory should have no parameter"); checkArgument(instanceType.equals(methodHandle.type().parameterType(0)), "methodHandle is not an instance method"); } List> parameterList = methodHandle.type().parameterList(); boolean hasSession = false; if (parameterList.contains(ConnectorSession.class)) { checkArgument(parameterList.stream().filter(ConnectorSession.class::equals).count() == 1, "function implementation should have exactly one ConnectorSession parameter"); if (instanceFactory.isEmpty()) { checkArgument(parameterList.get(0) == ConnectorSession.class, "ConnectorSession must be the first argument when instanceFactory is not present"); } else { checkArgument(parameterList.get(1) == ConnectorSession.class, "ConnectorSession must be the second argument when instanceFactory is present"); } hasSession = true; } this.lambdaInterfaces = ImmutableList.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null")); invocationConvention = new InvocationConvention( argumentConventions, returnConvention, hasSession, instanceFactory.isPresent()); checkArgument(lambdaInterfaces.size() <= argumentConventions.size()); score = computeScore(invocationConvention); } public MethodHandle getMethodHandle() { return methodHandle; } public Optional getInstanceFactory() { return instanceFactory; } public List> getLambdaInterfaces() { return lambdaInterfaces; } public 
InvocationConvention getInvocationConvention() { return invocationConvention; } public int getScore() { return score; } private static int computeScore(InvocationConvention callingConvention) { int score = 0; for (InvocationArgumentConvention argument : callingConvention.getArgumentConventions()) { switch (argument) { case NULL_FLAG: score += 1; break; case BLOCK_POSITION_NOT_NULL: case BLOCK_POSITION: score += 1000; break; case VALUE_BLOCK_POSITION_NOT_NULL: case VALUE_BLOCK_POSITION: score += 2000; break; case IN_OUT: score += 10_000; break; default: break; } } return score; } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy