Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
io.trino.operator.scalar.annotations.ParametricScalarImplementation Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.scalar.annotations;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Primitives;
import io.trino.metadata.FunctionBinding;
import io.trino.operator.ParametricImplementation;
import io.trino.operator.annotations.FunctionsParserHelper;
import io.trino.operator.annotations.ImplementationDependency;
import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction;
import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction.ScalarImplementationChoice;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.spi.block.Block;
import io.trino.spi.block.ValueBlock;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.function.InOut;
import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention;
import io.trino.spi.function.InvocationConvention.InvocationReturnConvention;
import io.trino.spi.function.IsNull;
import io.trino.spi.function.Signature;
import io.trino.spi.function.SqlNullable;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.type.FunctionType;
import java.lang.annotation.Annotation;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.ImmutableSortedSet.toImmutableSortedSet;
import static io.trino.operator.ParametricFunctionHelpers.bindDependencies;
import static io.trino.operator.annotations.FunctionsParserHelper.containsImplementationDependencyAnnotation;
import static io.trino.operator.annotations.FunctionsParserHelper.containsLegacyNullable;
import static io.trino.operator.annotations.FunctionsParserHelper.createTypeVariableConstraints;
import static io.trino.operator.annotations.FunctionsParserHelper.getDeclaredSpecializedTypeParameters;
import static io.trino.operator.annotations.FunctionsParserHelper.parseLiteralParameters;
import static io.trino.operator.annotations.FunctionsParserHelper.parseLongVariableConstraints;
import static io.trino.operator.annotations.ImplementationDependency.Factory.createDependency;
import static io.trino.operator.annotations.ImplementationDependency.checkTypeParameters;
import static io.trino.operator.annotations.ImplementationDependency.getImplementationDependencyAnnotation;
import static io.trino.operator.annotations.ImplementationDependency.validateImplementationDependencyAnnotation;
import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.VALUE_BLOCK_POSITION_NOT_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature;
import static io.trino.util.Failures.checkCondition;
import static io.trino.util.Reflection.constructorMethodHandle;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.String.CASE_INSENSITIVE_ORDER;
import static java.lang.String.format;
import static java.lang.invoke.MethodHandles.permuteArguments;
import static java.lang.reflect.Modifier.isStatic;
import static java.util.Objects.requireNonNull;
public class ParametricScalarImplementation
implements ParametricImplementation
{
private final Signature signature;
private final List>> argumentNativeContainerTypes; // argument native container type is Optional.empty() for function type
private final Map> specializedTypeParameters;
private final Class> returnNativeContainerType;
private final List choices;
private final FunctionNullability functionNullability;
private ParametricScalarImplementation(
Signature signature,
List>> argumentNativeContainerTypes,
Map> specializedTypeParameters,
List choices,
Class> returnContainerType)
{
this.signature = requireNonNull(signature, "signature is null");
this.argumentNativeContainerTypes = ImmutableList.copyOf(requireNonNull(argumentNativeContainerTypes, "argumentNativeContainerTypes is null"));
this.specializedTypeParameters = ImmutableMap.copyOf(requireNonNull(specializedTypeParameters, "specializedTypeParameters is null"));
this.choices = requireNonNull(choices, "choices is null");
checkArgument(!choices.isEmpty(), "choices is empty");
this.returnNativeContainerType = requireNonNull(returnContainerType, "returnContainerType is null");
for (Class> specializedJavaType : specializedTypeParameters.values()) {
checkArgument(!Primitives.isWrapperType(specializedJavaType), "specializedTypeParameter must not contain boxed primitive types");
}
ParametricScalarImplementationChoice defaultChoice = choices.get(0);
boolean hasBlockPositionArgument = defaultChoice.getArgumentConventions().stream()
.noneMatch(argumentConvention -> BLOCK_POSITION == argumentConvention || BLOCK_POSITION_NOT_NULL == argumentConvention);
checkArgument(hasBlockPositionArgument, "default choice can not use the block and position calling convention: %s", signature);
boolean returnNullability = defaultChoice.getReturnConvention().isNullable();
checkArgument(choices.stream().allMatch(choice -> choice.getReturnConvention().isNullable() == returnNullability), "all choices must have the same nullable flag: %s", signature);
List argumentNullability = defaultChoice.getArgumentConventions().stream()
.map(InvocationArgumentConvention::isNullable)
.collect(toImmutableList());
functionNullability = new FunctionNullability(returnNullability, argumentNullability);
checkArgument(
choices.stream().allMatch(choice -> matches(argumentNullability, choice.getArgumentConventions())),
"all choices must have the same nullable parameter flags: %s",
signature);
}
@Override
public FunctionNullability getFunctionNullability()
{
return functionNullability;
}
public Optional specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
List implementationChoices = new ArrayList<>();
for (Map.Entry> entry : specializedTypeParameters.entrySet()) {
if (!entry.getValue().isAssignableFrom(functionBinding.getTypeVariable(entry.getKey()).getJavaType())) {
return Optional.empty();
}
}
BoundSignature boundSignature = functionBinding.getBoundSignature();
if (returnNativeContainerType != Object.class && returnNativeContainerType != boundSignature.getReturnType().getJavaType()) {
return Optional.empty();
}
for (int i = 0; i < boundSignature.getArgumentTypes().size(); i++) {
if (boundSignature.getArgumentTypes().get(i) instanceof FunctionType) {
if (argumentNativeContainerTypes.get(i).isPresent()) {
return Optional.empty();
}
}
else {
if (argumentNativeContainerTypes.get(i).isEmpty()) {
return Optional.empty();
}
Class> argumentType = boundSignature.getArgumentTypes().get(i).getJavaType();
Class> argumentNativeContainerType = argumentNativeContainerTypes.get(i).get();
if (argumentNativeContainerType != Object.class && argumentNativeContainerType != argumentType) {
return Optional.empty();
}
}
}
for (ParametricScalarImplementationChoice choice : choices) {
MethodHandle boundMethodHandle = bindDependencies(choice.getMethodHandle(), choice.getDependencies(), functionBinding, functionDependencies);
Optional boundConstructor = choice.getConstructor().map(constructor -> {
MethodHandle result = bindDependencies(constructor, choice.getConstructorDependencies(), functionBinding, functionDependencies);
checkCondition(
result.type().parameterList().isEmpty(),
FUNCTION_IMPLEMENTATION_ERROR,
"All parameters of a constructor in a function definition class must be Dependencies. Signature: %s",
boundSignature);
return result;
});
implementationChoices.add(new ScalarImplementationChoice(
choice.getReturnConvention(),
choice.getArgumentConventions(),
choice.getLambdaInterfaces(),
boundMethodHandle.asType(javaMethodType(choice, boundSignature)),
boundConstructor));
}
return Optional.of(new ChoicesSpecializedSqlScalarFunction(boundSignature, implementationChoices));
}
@Override
public boolean hasSpecializedTypeParameters()
{
return !specializedTypeParameters.isEmpty();
}
@Override
public Signature getSignature()
{
return signature;
}
@VisibleForTesting
public List getChoices()
{
return choices;
}
private static MethodType javaMethodType(ParametricScalarImplementationChoice choice, BoundSignature signature)
{
// This method accomplishes two purposes:
// * Assert that the method signature is as expected.
// This catches errors that would otherwise surface during bytecode generation and class loading.
// * Adapt the method signature when necessary (for example, when the parameter type or return type is declared as Object).
ImmutableList.Builder> methodHandleParameterTypes = ImmutableList.builder();
if (choice.getConstructor().isPresent()) {
methodHandleParameterTypes.add(Object.class);
}
if (choice.hasConnectorSession()) {
methodHandleParameterTypes.add(ConnectorSession.class);
}
List argumentConventions = choice.getArgumentConventions();
int lambdaArgumentIndex = 0;
for (int i = 0; i < argumentConventions.size(); i++) {
InvocationArgumentConvention argumentConvention = argumentConventions.get(i);
Type signatureType = signature.getArgumentTypes().get(i);
switch (argumentConvention) {
case NEVER_NULL:
methodHandleParameterTypes.add(signatureType.getJavaType());
break;
case NULL_FLAG:
methodHandleParameterTypes.add(signatureType.getJavaType());
methodHandleParameterTypes.add(boolean.class);
break;
case BOXED_NULLABLE:
methodHandleParameterTypes.add(Primitives.wrap(signatureType.getJavaType()));
break;
case BLOCK_POSITION_NOT_NULL:
case BLOCK_POSITION:
methodHandleParameterTypes.add(Block.class);
methodHandleParameterTypes.add(int.class);
break;
case VALUE_BLOCK_POSITION:
case VALUE_BLOCK_POSITION_NOT_NULL:
methodHandleParameterTypes.add(ValueBlock.class);
methodHandleParameterTypes.add(int.class);
break;
case IN_OUT:
methodHandleParameterTypes.add(InOut.class);
break;
case FUNCTION:
methodHandleParameterTypes.add(choice.getLambdaInterfaces().get(lambdaArgumentIndex));
lambdaArgumentIndex++;
break;
default:
throw new UnsupportedOperationException("unknown argument convention: " + argumentConvention);
}
}
Class> methodHandleReturnType = signature.getReturnType().getJavaType();
if (choice.getReturnConvention().isNullable()) {
methodHandleReturnType = Primitives.wrap(methodHandleReturnType);
}
return MethodType.methodType(methodHandleReturnType, methodHandleParameterTypes.build());
}
private static boolean matches(List argumentNullability, List argumentConventions)
{
if (argumentNullability.size() != argumentConventions.size()) {
return false;
}
for (int i = 0; i < argumentNullability.size(); i++) {
boolean expectedNullable = argumentNullability.get(i);
InvocationArgumentConvention argumentConvention = argumentConventions.get(i);
if (argumentConvention == FUNCTION) {
// functions are never null
if (expectedNullable) {
return false;
}
}
else if (expectedNullable != argumentConvention.isNullable()) {
return false;
}
}
return true;
}
public static final class Builder
{
private final Signature signature;
private final List>> argumentNativeContainerTypes; // argument native container type is Optional.empty() for function type
private final Map> specializedTypeParameters;
private final Class> returnNativeContainerType;
private final List choices;
public Builder(
Signature signature,
List>> argumentNativeContainerTypes,
Map> specializedTypeParameters,
Class> returnNativeContainerType)
{
this.signature = requireNonNull(signature, "signature is null");
this.argumentNativeContainerTypes = ImmutableList.copyOf(requireNonNull(argumentNativeContainerTypes, "argumentNativeContainerTypes is null"));
this.specializedTypeParameters = ImmutableMap.copyOf(requireNonNull(specializedTypeParameters, "specializedTypeParameters is null"));
this.choices = new ArrayList<>();
this.returnNativeContainerType = requireNonNull(returnNativeContainerType, "returnNativeContainerType is null");
}
void addChoice(ParametricScalarImplementationChoice choice)
{
this.choices.add(choice);
}
public ParametricScalarImplementation build()
{
choices.sort(ParametricScalarImplementationChoice::compareTo);
return new ParametricScalarImplementation(signature, argumentNativeContainerTypes, specializedTypeParameters, choices, returnNativeContainerType);
}
}
public static final class ParametricScalarImplementationChoice
implements Comparable
{
private final InvocationReturnConvention returnConvention;
private final List argumentConventions;
private final List> lambdaInterfaces;
private final MethodHandle methodHandle;
private final Optional constructor;
private final List dependencies;
private final List constructorDependencies;
private final int numberOfBlockPositionArguments;
private final boolean hasConnectorSession;
private ParametricScalarImplementationChoice(
InvocationReturnConvention returnConvention,
boolean hasConnectorSession,
List argumentConventions,
List> lambdaInterfaces,
MethodHandle methodHandle,
Optional constructor,
List dependencies,
List constructorDependencies)
{
this.returnConvention = requireNonNull(returnConvention, "returnConvention is null");
this.hasConnectorSession = hasConnectorSession;
this.argumentConventions = ImmutableList.copyOf(requireNonNull(argumentConventions, "argumentConventions is null"));
this.lambdaInterfaces = ImmutableList.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null"));
this.methodHandle = requireNonNull(methodHandle, "methodHandle is null");
this.constructor = requireNonNull(constructor, "constructor is null");
this.dependencies = ImmutableList.copyOf(requireNonNull(dependencies, "dependencies is null"));
this.constructorDependencies = ImmutableList.copyOf(requireNonNull(constructorDependencies, "constructorDependencies is null"));
this.numberOfBlockPositionArguments = (int) argumentConventions.stream()
.filter(argumentConvention -> BLOCK_POSITION == argumentConvention || BLOCK_POSITION_NOT_NULL == argumentConvention)
.count();
}
public InvocationReturnConvention getReturnConvention()
{
return returnConvention;
}
public boolean hasConnectorSession()
{
return hasConnectorSession;
}
public MethodHandle getMethodHandle()
{
return methodHandle;
}
@VisibleForTesting
public List getDependencies()
{
return dependencies;
}
public List getArgumentConventions()
{
return argumentConventions;
}
public List> getLambdaInterfaces()
{
return lambdaInterfaces;
}
public boolean checkDependencies()
{
for (int i = 1; i < getDependencies().size(); i++) {
if (!getDependencies().get(i).equals(getDependencies().get(0))) {
return false;
}
}
return true;
}
@VisibleForTesting
public List getConstructorDependencies()
{
return constructorDependencies;
}
public Optional getConstructor()
{
return constructor;
}
@Override
public int compareTo(ParametricScalarImplementationChoice choice)
{
if (choice.numberOfBlockPositionArguments < this.numberOfBlockPositionArguments) {
return 1;
}
return -1;
}
}
public static final class SpecializedSignature
{
private final Signature signature;
private final List>> argumentNativeContainerTypes;
private final Map> specializedTypeParameters;
private final Class> returnNativeContainerType;
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SpecializedSignature that = (SpecializedSignature) o;
return Objects.equals(signature, that.signature) &&
Objects.equals(argumentNativeContainerTypes, that.argumentNativeContainerTypes) &&
Objects.equals(specializedTypeParameters, that.specializedTypeParameters) &&
Objects.equals(returnNativeContainerType, that.returnNativeContainerType);
}
@Override
public int hashCode()
{
return Objects.hash(signature, argumentNativeContainerTypes, specializedTypeParameters, returnNativeContainerType);
}
private SpecializedSignature(
Signature signature,
List>> argumentNativeContainerTypes,
Map> specializedTypeParameters,
Class> returnNativeContainerType)
{
this.signature = signature;
this.argumentNativeContainerTypes = argumentNativeContainerTypes;
this.specializedTypeParameters = specializedTypeParameters;
this.returnNativeContainerType = returnNativeContainerType;
}
}
public static final class Parser
{
private final Signature signature;
private final List argumentConventions = new ArrayList<>();
private final List> lambdaInterfaces = new ArrayList<>();
private final List>> argumentNativeContainerTypes = new ArrayList<>();
private final MethodHandle methodHandle;
private final Set typeParameters;
private final Set literalParameters;
private final Set typeParameterNames;
private final Map> specializedTypeParameters;
private final Class> returnNativeContainerType;
private boolean hasConnectorSession;
private final ParametricScalarImplementationChoice choice;
Parser(Method method, Optional> constructor)
{
Signature.Builder signatureBuilder = Signature.builder();
boolean nullable = method.getAnnotation(SqlNullable.class) != null;
checkArgument(nullable || !containsLegacyNullable(method.getAnnotations()), "Method [%s] is annotated with @Nullable but not @SqlNullable", method);
typeParameters = ImmutableSet.copyOf(method.getAnnotationsByType(TypeParameter.class));
literalParameters = parseLiteralParameters(method);
typeParameterNames = typeParameters.stream()
.map(TypeParameter::value)
.collect(toImmutableSortedSet(CASE_INSENSITIVE_ORDER));
SqlType returnType = method.getAnnotation(SqlType.class);
checkArgument(returnType != null, "Method [%s] is missing @SqlType annotation", method);
signatureBuilder.returnType(parseTypeSignature(returnType.value(), literalParameters));
Class> actualReturnType = method.getReturnType();
this.returnNativeContainerType = Primitives.unwrap(actualReturnType);
if (Primitives.isWrapperType(actualReturnType)) {
checkArgument(nullable, "Method [%s] has wrapper return type %s but is missing @SqlNullable", method, actualReturnType.getSimpleName());
}
else if (actualReturnType.isPrimitive()) {
checkArgument(!nullable, "Method [%s] annotated with @SqlNullable has primitive return type %s", method, actualReturnType.getSimpleName());
}
parseLongVariableConstraints(method, signatureBuilder);
this.specializedTypeParameters = getDeclaredSpecializedTypeParameters(method, typeParameters);
for (TypeParameter typeParameter : typeParameters) {
checkArgument(
typeParameter.value().matches("[A-Z][A-Z0-9]*"),
"Expected type parameter to only contain A-Z and 0-9 (starting with A-Z), but got %s on method [%s]", typeParameter.value(), method);
}
inferSpecialization(method, actualReturnType, returnType.value());
List dependencies = new ArrayList<>();
parseArguments(method, signatureBuilder, dependencies);
List constructorDependencies = new ArrayList<>();
Optional constructorMethodHandle = getConstructor(method, constructor, constructorDependencies);
this.methodHandle = getMethodHandle(method, dependencies);
this.choice = new ParametricScalarImplementationChoice(
nullable ? NULLABLE_RETURN : FAIL_ON_NULL,
hasConnectorSession,
argumentConventions,
lambdaInterfaces,
methodHandle,
constructorMethodHandle,
dependencies,
constructorDependencies);
createTypeVariableConstraints(typeParameters, dependencies)
.forEach(signatureBuilder::typeVariableConstraint);
signature = signatureBuilder.build();
}
private void parseArguments(Method method, Signature.Builder signatureBuilder, List dependencies)
{
boolean encounteredNonDependencyAnnotation = false;
int parameterIndex = 0;
while (parameterIndex < method.getParameterCount()) {
Parameter parameter = method.getParameters()[parameterIndex];
Class> parameterType = parameter.getType();
// Skip injected parameters
if (parameterType == ConnectorSession.class) {
checkCondition(!hasConnectorSession, FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] has more than 1 ConnectorSession in the parameter list", method);
hasConnectorSession = true;
parameterIndex++;
continue;
}
Optional implementationDependency = getImplementationDependencyAnnotation(parameter);
if (implementationDependency.isPresent()) {
checkCondition(!encounteredNonDependencyAnnotation, FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] has parameters annotated with Dependency annotations that appears after other parameters", method);
// check if only declared typeParameters and literalParameters are used
validateImplementationDependencyAnnotation(method, implementationDependency.get(), typeParameterNames, literalParameters);
dependencies.add(createDependency(implementationDependency.get(), literalParameters, parameterType));
parameterIndex++;
}
else {
encounteredNonDependencyAnnotation = true;
Annotation[] annotations = parameter.getAnnotations();
checkArgument(Stream.of(annotations).noneMatch(IsNull.class::isInstance), "Method [%s] has @IsNull parameter that does not follow a @SqlType parameter", method);
SqlType type = Stream.of(annotations)
.filter(SqlType.class::isInstance)
.map(SqlType.class::cast)
.findFirst()
.orElseThrow(() -> new IllegalArgumentException(format("Method [%s] is missing @SqlType annotation for parameter", method)));
TypeSignature typeSignature = parseTypeSignature(type.value(), literalParameters);
signatureBuilder.argumentType(typeSignature);
if (typeSignature.getBase().equals(FunctionType.NAME)) {
// function type
checkCondition(parameterType.isAnnotationPresent(FunctionalInterface.class), FUNCTION_IMPLEMENTATION_ERROR, "argument %s is marked as lambda but the function interface class is not annotated: %s", parameterIndex, methodHandle);
argumentConventions.add(FUNCTION);
lambdaInterfaces.add(parameterType);
argumentNativeContainerTypes.add(Optional.empty());
parameterIndex++;
}
else {
// value type
InvocationArgumentConvention argumentConvention;
boolean nullable = Stream.of(annotations).anyMatch(SqlNullable.class::isInstance);
if (Stream.of(annotations).anyMatch(BlockPosition.class::isInstance)) {
verify(method.getParameterCount() > (parameterIndex + 1));
if (parameterType == Block.class) {
argumentConvention = nullable ? BLOCK_POSITION : BLOCK_POSITION_NOT_NULL;
}
else {
verify(ValueBlock.class.isAssignableFrom(parameterType));
argumentConvention = nullable ? VALUE_BLOCK_POSITION : VALUE_BLOCK_POSITION_NOT_NULL;
}
Annotation[] parameterAnnotations = method.getParameterAnnotations()[parameterIndex + 1];
verify(Stream.of(parameterAnnotations).anyMatch(BlockIndex.class::isInstance));
}
else if (nullable) {
checkCondition(!parameterType.isPrimitive(), FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] has parameter with primitive type %s annotated with @SqlNullable", method, parameterType.getSimpleName());
argumentConvention = BOXED_NULLABLE;
}
else if (parameterType.equals(InOut.class)) {
argumentConvention = IN_OUT;
}
else {
// USE_NULL_FLAG or RETURN_NULL_ON_NULL
checkCondition(parameterType == Void.class || !Primitives.isWrapperType(parameterType), FUNCTION_IMPLEMENTATION_ERROR, "A parameter with USE_NULL_FLAG or RETURN_NULL_ON_NULL convention must not use wrapper type. Found in method [%s]", method);
boolean useNullFlag = false;
if (method.getParameterCount() > (parameterIndex + 1)) {
Annotation[] parameterAnnotations = method.getParameterAnnotations()[parameterIndex + 1];
if (Stream.of(parameterAnnotations).anyMatch(IsNull.class::isInstance)) {
Class> isNullType = method.getParameterTypes()[parameterIndex + 1];
checkArgument(Stream.of(parameterAnnotations).filter(FunctionsParserHelper::isTrinoAnnotation).allMatch(IsNull.class::isInstance), "Method [%s] has @IsNull parameter that has other annotations", method);
checkArgument(isNullType == boolean.class, "Method [%s] has non-boolean parameter with @IsNull", method);
checkArgument((parameterType == Void.class) || !Primitives.isWrapperType(parameterType), "Method [%s] uses @IsNull following a parameter with boxed primitive type: %s", method, parameterType.getSimpleName());
useNullFlag = true;
}
}
if (useNullFlag) {
argumentConvention = NULL_FLAG;
}
else {
argumentConvention = NEVER_NULL;
}
}
if (argumentConvention == BLOCK_POSITION || argumentConvention == BLOCK_POSITION_NOT_NULL || argumentConvention == VALUE_BLOCK_POSITION || argumentConvention == VALUE_BLOCK_POSITION_NOT_NULL) {
argumentNativeContainerTypes.add(Optional.of(type.nativeContainerType()));
}
else {
inferSpecialization(method, parameterType, type.value());
checkCondition(type.nativeContainerType().equals(Object.class), FUNCTION_IMPLEMENTATION_ERROR, "@SqlType can only contain an explicitly specified nativeContainerType when using @BlockPosition");
argumentNativeContainerTypes.add(Optional.of(Primitives.unwrap(parameterType)));
}
argumentConventions.add(argumentConvention);
parameterIndex += argumentConvention.getParameterCount();
}
}
}
}
private void inferSpecialization(Method method, Class> parameterType, String typeParameterName)
{
if (typeParameterNames.contains(typeParameterName) && parameterType != Object.class) {
// Infer specialization on this type parameter.
// We don't do this for Object because it could match any type.
Class> specialization = specializedTypeParameters.get(typeParameterName);
Class> nativeParameterType = Primitives.unwrap(parameterType);
checkArgument(specialization == null || specialization.equals(nativeParameterType), "Method [%s] type %s has conflicting specializations %s and %s", method, typeParameterName, specialization, nativeParameterType);
specializedTypeParameters.put(typeParameterName, nativeParameterType);
}
}
// Find matching constructor, if this is an instance method, and populate constructorDependencies
private Optional getConstructor(Method method, Optional> optionalConstructor, List constructorDependencies)
{
if (isStatic(method.getModifiers())) {
return Optional.empty();
}
checkArgument(optionalConstructor.isPresent(), "Method [%s] is an instance method. It must be in a class annotated with @ScalarFunction or @ScalarOperator, and the class is required to have a public constructor.", method);
Constructor> constructor = optionalConstructor.get();
Set constructorTypeParameters = Stream.of(constructor.getAnnotationsByType(TypeParameter.class))
.collect(toImmutableSet());
checkArgument(constructorTypeParameters.containsAll(typeParameters), "Method [%s] is an instance method and requires a public constructor containing all type parameters: %s", method, typeParameters);
for (int i = 0; i < constructor.getParameterCount(); i++) {
Annotation[] annotations = constructor.getParameterAnnotations()[i];
checkArgument(containsImplementationDependencyAnnotation(annotations), "Constructors may only have meta parameters [%s]", constructor);
checkArgument(annotations.length == 1, "Meta parameters may only have a single annotation [%s]", constructor);
Annotation annotation = annotations[0];
if (annotation instanceof TypeParameter) {
checkTypeParameters(parseTypeSignature(((TypeParameter) annotation).value(), ImmutableSet.of()), typeParameterNames, method);
}
constructorDependencies.add(createDependency(annotation, literalParameters, constructor.getParameterTypes()[i]));
}
MethodHandle result = constructorMethodHandle(FUNCTION_IMPLEMENTATION_ERROR, constructor);
// Change type of return value to Object to make sure callers won't have classloader issues
return Optional.of(result.asType(result.type().changeReturnType(Object.class)));
}
private static MethodHandle getMethodHandle(Method method, List dependencies)
{
MethodHandle methodHandle = methodHandle(FUNCTION_IMPLEMENTATION_ERROR, method);
if (!isStatic(method.getModifiers())) {
// Change type of "this" argument to Object to make sure callers won't have classloader issues
methodHandle = methodHandle.asType(methodHandle.type().changeParameterType(0, Object.class));
// Re-arrange the parameters, so that the "this" parameter is after the meta parameters
int[] permutedIndices = new int[methodHandle.type().parameterCount()];
permutedIndices[0] = dependencies.size();
MethodType newType = methodHandle.type().changeParameterType(dependencies.size(), methodHandle.type().parameterType(0));
for (int i = 0; i < dependencies.size(); i++) {
permutedIndices[i + 1] = i;
newType = newType.changeParameterType(i, methodHandle.type().parameterType(i + 1));
}
for (int i = dependencies.size() + 1; i < permutedIndices.length; i++) {
permutedIndices[i] = i;
}
methodHandle = permuteArguments(methodHandle, newType, permutedIndices);
}
return methodHandle;
}
public List>> getArgumentNativeContainerTypes()
{
return argumentNativeContainerTypes;
}
public Map> getSpecializedTypeParameters()
{
return specializedTypeParameters;
}
public Class> getReturnNativeContainerType()
{
return returnNativeContainerType;
}
public ParametricScalarImplementationChoice getChoice()
{
return choice;
}
public SpecializedSignature getSpecializedSignature()
{
return new SpecializedSignature(
getSignature(),
argumentNativeContainerTypes,
specializedTypeParameters,
returnNativeContainerType);
}
public Signature getSignature()
{
return signature;
}
}
}