org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of flink-table-planner-blink_2.11 Show documentation
This module bridges Table/SQL API and runtime. It contains
all resources that are required during pre-flight and runtime
phase. The content of this module is work-in-progress. It will
replace flink-table-planner once it is stable. See FLINK-11439
and FLIP-32 for more details.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.planner.plan.nodes.exec.common;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
import org.apache.flink.table.planner.plan.nodes.exec.utils.CommonPythonUtil;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
/** Base class for exec Python Calc. */
public abstract class CommonExecPythonCalc extends ExecNodeBase
implements SingleTransformationTranslator {
private static final String PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
"org.apache.flink.table.runtime.operators.python.scalar."
+ "RowDataPythonScalarFunctionOperator";
private static final String ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
"org.apache.flink.table.runtime.operators.python.scalar.arrow."
+ "RowDataArrowPythonScalarFunctionOperator";
private final RexProgram calcProgram;
public CommonExecPythonCalc(
RexProgram calcProgram,
InputProperty inputProperty,
RowType outputType,
String description) {
super(Collections.singletonList(inputProperty), outputType, description);
this.calcProgram = calcProgram;
}
@SuppressWarnings("unchecked")
@Override
protected Transformation translateToPlanInternal(PlannerBase planner) {
final ExecEdge inputEdge = getInputEdges().get(0);
final Transformation inputTransform =
(Transformation) inputEdge.translateToPlan(planner);
final Configuration config =
CommonPythonUtil.getMergedConfig(planner.getExecEnv(), planner.getTableConfig());
OneInputTransformation ret =
createPythonOneInputTransformation(
inputTransform, calcProgram, getDescription(), config);
if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(config)) {
ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
return ret;
}
private OneInputTransformation createPythonOneInputTransformation(
Transformation inputTransform,
RexProgram calcProgram,
String name,
Configuration config) {
List pythonRexCalls =
calcProgram.getProjectList().stream()
.map(calcProgram::expandLocalRef)
.filter(x -> x instanceof RexCall)
.map(x -> (RexCall) x)
.collect(Collectors.toList());
List forwardedFields =
calcProgram.getProjectList().stream()
.map(calcProgram::expandLocalRef)
.filter(x -> x instanceof RexInputRef)
.map(x -> ((RexInputRef) x).getIndex())
.collect(Collectors.toList());
Tuple2 extractResult =
extractPythonScalarFunctionInfos(pythonRexCalls);
int[] pythonUdfInputOffsets = extractResult.f0;
PythonFunctionInfo[] pythonFunctionInfos = extractResult.f1;
LogicalType[] inputLogicalTypes =
((InternalTypeInfo) inputTransform.getOutputType()).toRowFieldTypes();
InternalTypeInfo pythonOperatorInputTypeInfo =
(InternalTypeInfo) inputTransform.getOutputType();
List forwardedFieldsLogicalTypes =
forwardedFields.stream()
.map(i -> inputLogicalTypes[i])
.collect(Collectors.toList());
List pythonCallLogicalTypes =
pythonRexCalls.stream()
.map(node -> FlinkTypeFactory.toLogicalType(node.getType()))
.collect(Collectors.toList());
List fieldsLogicalTypes = new ArrayList<>();
fieldsLogicalTypes.addAll(forwardedFieldsLogicalTypes);
fieldsLogicalTypes.addAll(pythonCallLogicalTypes);
InternalTypeInfo pythonOperatorResultTyeInfo =
InternalTypeInfo.ofFields(fieldsLogicalTypes.toArray(new LogicalType[0]));
OneInputStreamOperator pythonOperator =
getPythonScalarFunctionOperator(
config,
pythonOperatorInputTypeInfo,
pythonOperatorResultTyeInfo,
pythonUdfInputOffsets,
pythonFunctionInfos,
forwardedFields.stream().mapToInt(x -> x).toArray(),
calcProgram.getExprList().stream()
.anyMatch(
x ->
PythonUtil.containsPythonCall(
x, PythonFunctionKind.PANDAS)));
return new OneInputTransformation<>(
inputTransform,
name,
pythonOperator,
pythonOperatorResultTyeInfo,
inputTransform.getParallelism());
}
private Tuple2 extractPythonScalarFunctionInfos(
List rexCalls) {
LinkedHashMap inputNodes = new LinkedHashMap<>();
PythonFunctionInfo[] pythonFunctionInfos =
rexCalls.stream()
.map(x -> CommonPythonUtil.createPythonFunctionInfo(x, inputNodes))
.collect(Collectors.toList())
.toArray(new PythonFunctionInfo[rexCalls.size()]);
int[] udfInputOffsets =
inputNodes.keySet().stream()
.map(
x -> {
if (x instanceof RexInputRef) {
return ((RexInputRef) x).getIndex();
} else if (x instanceof RexFieldAccess) {
return ((RexFieldAccess) x).getField().getIndex();
}
return null;
})
.mapToInt(i -> i)
.toArray();
return Tuple2.of(udfInputOffsets, pythonFunctionInfos);
}
@SuppressWarnings("unchecked")
private OneInputStreamOperator getPythonScalarFunctionOperator(
Configuration config,
InternalTypeInfo inputRowTypeInfo,
InternalTypeInfo outputRowTypeInfo,
int[] udfInputOffsets,
PythonFunctionInfo[] pythonFunctionInfos,
int[] forwardedFields,
boolean isArrow) {
Class clazz;
if (isArrow) {
clazz = CommonPythonUtil.loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
} else {
clazz = CommonPythonUtil.loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
}
try {
Constructor ctor =
clazz.getConstructor(
Configuration.class,
PythonFunctionInfo[].class,
RowType.class,
RowType.class,
int[].class,
int[].class);
return (OneInputStreamOperator)
ctor.newInstance(
config,
pythonFunctionInfos,
inputRowTypeInfo.toRowType(),
outputRowTypeInfo.toRowType(),
udfInputOffsets,
forwardedFields);
} catch (Exception e) {
throw new TableException("Python Scalar Function Operator constructed failed.", e);
}
}
}