All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.functions.UserDefinedFunctionHelper Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.functions;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.ExecutionConfig.ClosureCleanerLevel;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.ClosureCleaner;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.PipelineOptions;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.catalog.CatalogFunction;
import org.apache.flink.table.functions.SpecializedFunction.SpecializedContext;
import org.apache.flink.table.functions.python.utils.PythonFunctionUtils;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.extraction.ExtractionUtils;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.util.InstantiationUtil;

import javax.annotation.Nullable;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.getAllDeclaredMethods;
import static org.apache.flink.util.Preconditions.checkState;

/**
 * Utility for dealing with subclasses of {@link UserDefinedFunction}. The purpose of this class is
 * to keep the user-facing APIs clean and offer methods/constants from here.
 *
 * 

It contains methods for instantiating, validating and extracting types during function * registration in a catalog. */ @Internal public final class UserDefinedFunctionHelper { // method names of code generated UDFs public static final String SCALAR_EVAL = "eval"; public static final String TABLE_EVAL = "eval"; public static final String AGGREGATE_ACCUMULATE = "accumulate"; public static final String AGGREGATE_RETRACT = "retract"; public static final String AGGREGATE_MERGE = "merge"; public static final String TABLE_AGGREGATE_ACCUMULATE = "accumulate"; public static final String TABLE_AGGREGATE_RETRACT = "retract"; public static final String TABLE_AGGREGATE_MERGE = "merge"; public static final String TABLE_AGGREGATE_EMIT = "emitValue"; public static final String TABLE_AGGREGATE_EMIT_RETRACT = "emitUpdateWithRetract"; public static final String ASYNC_TABLE_EVAL = "eval"; /** * Tries to infer the TypeInformation of an AggregateFunction's accumulator type. * * @param aggregateFunction The AggregateFunction for which the accumulator type is inferred. * @return The inferred accumulator type of the AggregateFunction. */ public static TypeInformation getReturnTypeOfAggregateFunction( ImperativeAggregateFunction aggregateFunction) { return getReturnTypeOfAggregateFunction(aggregateFunction, null); } /** * Tries to infer the TypeInformation of an AggregateFunction's accumulator type. * * @param aggregateFunction The AggregateFunction for which the accumulator type is inferred. * @param scalaType The implicitly inferred type of the accumulator type. * @return The inferred accumulator type of the AggregateFunction. */ public static TypeInformation getReturnTypeOfAggregateFunction( ImperativeAggregateFunction aggregateFunction, TypeInformation scalaType) { TypeInformation userProvidedType = aggregateFunction.getResultType(); if (userProvidedType != null) { return userProvidedType; } else if (scalaType != null) { return scalaType; } else { return TypeExtractor.createTypeInfo( aggregateFunction, ImperativeAggregateFunction.class, aggregateFunction.getClass(), 0); } } /** * Tries to infer the TypeInformation of an AggregateFunction's accumulator type. * * @param aggregateFunction The AggregateFunction for which the accumulator type is inferred. * @return The inferred accumulator type of the AggregateFunction. */ public static TypeInformation getAccumulatorTypeOfAggregateFunction( ImperativeAggregateFunction aggregateFunction) { return getAccumulatorTypeOfAggregateFunction(aggregateFunction, null); } /** * Tries to infer the TypeInformation of an AggregateFunction's accumulator type. * * @param aggregateFunction The AggregateFunction for which the accumulator type is inferred. * @param scalaType The implicitly inferred type of the accumulator type. * @return The inferred accumulator type of the AggregateFunction. */ public static TypeInformation getAccumulatorTypeOfAggregateFunction( ImperativeAggregateFunction aggregateFunction, TypeInformation scalaType) { TypeInformation userProvidedType = aggregateFunction.getAccumulatorType(); if (userProvidedType != null) { return userProvidedType; } else if (scalaType != null) { return scalaType; } else { return TypeExtractor.createTypeInfo( aggregateFunction, ImperativeAggregateFunction.class, aggregateFunction.getClass(), 1); } } /** * Tries to infer the TypeInformation of an AggregateFunction's accumulator type. * * @param tableFunction The TableFunction for which the accumulator type is inferred. * @return The inferred accumulator type of the AggregateFunction. */ public static TypeInformation getReturnTypeOfTableFunction( TableFunction tableFunction) { return getReturnTypeOfTableFunction(tableFunction, null); } /** * Tries to infer the TypeInformation of an AggregateFunction's accumulator type. * * @param tableFunction The TableFunction for which the accumulator type is inferred. * @param scalaType The implicitly inferred type of the accumulator type. * @return The inferred accumulator type of the AggregateFunction. */ public static TypeInformation getReturnTypeOfTableFunction( TableFunction tableFunction, TypeInformation scalaType) { TypeInformation userProvidedType = tableFunction.getResultType(); if (userProvidedType != null) { return userProvidedType; } else if (scalaType != null) { return scalaType; } else { return TypeExtractor.createTypeInfo( tableFunction, TableFunction.class, tableFunction.getClass(), 0); } } /** * Instantiates a {@link UserDefinedFunction} from a {@link CatalogFunction}. * *

Requires access to {@link ReadableConfig} if Python functions should be supported. */ @SuppressWarnings({"unchecked", "rawtypes"}) public static UserDefinedFunction instantiateFunction( ClassLoader classLoader, @Nullable ReadableConfig config, String name, CatalogFunction catalogFunction) { try { switch (catalogFunction.getFunctionLanguage()) { case PYTHON: if (config == null) { throw new IllegalStateException( "Python functions are not supported at this location."); } return (UserDefinedFunction) PythonFunctionUtils.getPythonFunction( catalogFunction.getClassName(), config, classLoader); case JAVA: case SCALA: final Class functionClass = classLoader.loadClass(catalogFunction.getClassName()); return UserDefinedFunctionHelper.instantiateFunction((Class) functionClass); default: throw new IllegalArgumentException( "Unknown function language: " + catalogFunction.getFunctionLanguage()); } } catch (Exception e) { throw new ValidationException( String.format("Cannot instantiate user-defined function '%s'.", name), e); } } /** * Instantiates a {@link UserDefinedFunction} assuming a JVM function with default constructor. */ public static UserDefinedFunction instantiateFunction( Class functionClass) { validateClass(functionClass, true); try { return functionClass.newInstance(); } catch (Exception e) { throw new ValidationException( String.format( "Cannot instantiate user-defined function class '%s'.", functionClass.getName()), e); } } /** Prepares a {@link UserDefinedFunction} instance for usage in the API. */ public static void prepareInstance(ReadableConfig config, UserDefinedFunction function) { validateClass(function.getClass(), false); cleanFunction(config, function); } /** * Validates a {@link UserDefinedFunction} class for usage in the API. * *

Note: This is an initial validation to indicate common errors early. The concrete * signature validation happens in the code generation when the actual {@link DataType}s for * arguments and result are known. */ public static void validateClass(Class functionClass) { validateClass(functionClass, true); } /** * Validates a {@link UserDefinedFunction} class for usage in the runtime. * *

Note: This is for the final validation when actual {@link DataType}s for arguments and * result are known. */ public static void validateClassForRuntime( Class functionClass, String methodName, Class[] argumentClasses, Class outputClass, String functionName) { final List methods = ExtractionUtils.collectMethods(functionClass, methodName); // verifies regular JVM calling semantics final boolean isMatching = methods.stream() .anyMatch( method -> ExtractionUtils.isInvokable(method, argumentClasses) && ExtractionUtils.isAssignable( outputClass, method.getReturnType(), true)); if (!isMatching) { throw new ValidationException( String.format( "Could not find an implementation method '%s' in class '%s' for function '%s' that " + "matches the following signature:\n%s", methodName, functionClass.getName(), functionName, ExtractionUtils.createMethodSignatureString( methodName, argumentClasses, outputClass))); } } /** * Creates the runtime implementation of a {@link FunctionDefinition} as an instance of {@link * UserDefinedFunction}. * * @see SpecializedFunction */ public static UserDefinedFunction createSpecializedFunction( String functionName, FunctionDefinition definition, CallContext callContext, ClassLoader builtInClassLoader, @Nullable ReadableConfig configuration) { if (definition instanceof SpecializedFunction) { final SpecializedFunction specialized = (SpecializedFunction) definition; final SpecializedContext specializedContext = new SpecializedContext() { @Override public CallContext getCallContext() { return callContext; } @Override public ReadableConfig getConfiguration() { if (configuration == null) { throw new TableException( "Access to configuration is currently not supported for all kinds of calls."); } return configuration; } @Override public ClassLoader getBuiltInClassLoader() { return builtInClassLoader; } }; final UserDefinedFunction udf = specialized.specialize(specializedContext); checkState( udf.getKind() == definition.getKind(), "Function kind must not change during specialization."); return udf; } else if (definition instanceof UserDefinedFunction) { return (UserDefinedFunction) definition; } else { throw new TableException( String.format( "Could not find a runtime implementation for function definition '%s'.", functionName)); } } /** Validates a {@link UserDefinedFunction} class for usage in the API. */ private static void validateClass( Class functionClass, boolean requiresDefaultConstructor) { if (TableFunction.class.isAssignableFrom(functionClass)) { validateNotSingleton(functionClass); } validateInstantiation(functionClass, requiresDefaultConstructor); validateImplementationMethods(functionClass); } /** * Check whether this is a Scala object. Using Scala objects can lead to concurrency issues, * e.g., due to a shared collector. */ private static void validateNotSingleton(Class clazz) { if (Arrays.stream(clazz.getFields()).anyMatch(f -> f.getName().equals("MODULE$"))) { throw new ValidationException( String.format( "Function implemented by class %s is a Scala object. This is forbidden because of concurrency" + " problems when using them.", clazz.getName())); } } /** * Validates the implementation methods such as {@link #SCALAR_EVAL} or {@link * #AGGREGATE_ACCUMULATE} depending on the {@link UserDefinedFunction} subclass. * *

This method must be kept in sync with the code generation requirements and the individual * docs of each function. */ private static void validateImplementationMethods( Class functionClass) { if (ScalarFunction.class.isAssignableFrom(functionClass)) { validateImplementationMethod(functionClass, false, false, SCALAR_EVAL); } else if (TableFunction.class.isAssignableFrom(functionClass)) { validateImplementationMethod(functionClass, true, false, TABLE_EVAL); } else if (AsyncTableFunction.class.isAssignableFrom(functionClass)) { validateImplementationMethod(functionClass, true, false, ASYNC_TABLE_EVAL); } else if (AggregateFunction.class.isAssignableFrom(functionClass)) { validateImplementationMethod(functionClass, true, false, AGGREGATE_ACCUMULATE); validateImplementationMethod(functionClass, true, true, AGGREGATE_RETRACT); validateImplementationMethod(functionClass, true, true, AGGREGATE_MERGE); } else if (TableAggregateFunction.class.isAssignableFrom(functionClass)) { validateImplementationMethod(functionClass, true, false, TABLE_AGGREGATE_ACCUMULATE); validateImplementationMethod(functionClass, true, true, TABLE_AGGREGATE_RETRACT); validateImplementationMethod(functionClass, true, true, TABLE_AGGREGATE_MERGE); validateImplementationMethod( functionClass, true, false, TABLE_AGGREGATE_EMIT, TABLE_AGGREGATE_EMIT_RETRACT); } } /** Validates an implementation method such as {@code eval()} or {@code accumulate()}. */ private static void validateImplementationMethod( Class clazz, boolean rejectStatic, boolean isOptional, String... methodNameOptions) { final Set nameSet = new HashSet<>(Arrays.asList(methodNameOptions)); final List methods = getAllDeclaredMethods(clazz); boolean found = false; for (Method method : methods) { if (!nameSet.contains(method.getName())) { continue; } found = true; final int modifier = method.getModifiers(); if (!Modifier.isPublic(modifier)) { throw new ValidationException( String.format( "Method '%s' of function class '%s' is not public.", method.getName(), clazz.getName())); } if (Modifier.isAbstract(modifier)) { throw new ValidationException( String.format( "Method '%s' of function class '%s' must not be abstract.", method.getName(), clazz.getName())); } if (rejectStatic && Modifier.isStatic(modifier)) { throw new ValidationException( String.format( "Method '%s' of function class '%s' must not be static.", method.getName(), clazz.getName())); } } if (!found && !isOptional) { throw new ValidationException( String.format( "Function class '%s' does not implement a method named %s.", clazz.getName(), nameSet.stream() .map(s -> "'" + s + "'") .collect(Collectors.joining(" or ")))); } } /** Checks if a user-defined function can be easily instantiated. */ private static void validateInstantiation(Class clazz, boolean requiresDefaultConstructor) { if (!InstantiationUtil.isPublic(clazz)) { throw new ValidationException( String.format("Function class '%s' is not public.", clazz.getName())); } else if (!InstantiationUtil.isProperClass(clazz)) { throw new ValidationException( String.format( "Function class '%s' is not a proper class. It is either abstract, an interface, or a primitive type.", clazz.getName())); } else if (requiresDefaultConstructor && !InstantiationUtil.hasPublicNullaryConstructor(clazz)) { throw new ValidationException( String.format( "Function class '%s' must have a public default constructor.", clazz.getName())); } } /** * Modifies a function instance by removing any reference to outer classes. This enables * non-static inner function classes. */ private static void cleanFunction(ReadableConfig config, UserDefinedFunction function) { final ClosureCleanerLevel level = config.get(PipelineOptions.CLOSURE_CLEANER_LEVEL); try { ClosureCleaner.clean(function, level, true); } catch (Throwable t) { throw new ValidationException( String.format( "Function class '%s' is not serializable. Make sure that the class is self-contained " + "(i.e. no references to outer classes) and all inner fields are serializable as well.", function.getClass()), t); } } private UserDefinedFunctionHelper() { // no instantiation } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy