All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc Maven / Gradle / Ivy

Go to download

There is a newer version: 1.13.6
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.planner.plan.nodes.exec.common;

import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
import org.apache.flink.table.planner.plan.nodes.exec.utils.CommonPythonUtil;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Base class for exec Python Calc.
 *
 * <p>Translates a calc {@link RexProgram} into a {@link OneInputTransformation} whose operator
 * evaluates the Python scalar function calls found in the program's projection list. The
 * operator class is resolved by name through {@link CommonPythonUtil#loadClass(String)} and
 * instantiated reflectively (presumably because the operator lives in a module that is not on
 * this module's compile classpath - confirm against the build setup).
 */
public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
        implements SingleTransformationTranslator<RowData> {

    /** Fully qualified name of the operator used for general Python scalar functions. */
    private static final String PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
            "org.apache.flink.table.runtime.operators.python.scalar."
                    + "RowDataPythonScalarFunctionOperator";

    /** Fully qualified name of the operator used when the program contains Pandas calls. */
    private static final String ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
            "org.apache.flink.table.runtime.operators.python.scalar.arrow."
                    + "RowDataArrowPythonScalarFunctionOperator";

    private final RexProgram calcProgram;

    /**
     * Creates the node.
     *
     * @param calcProgram program whose projection list holds the Python calls and forwarded
     *     input references
     * @param inputProperty property of the single input of this node
     * @param outputType output row type of this node
     * @param description description of this node; also used as the transformation name
     */
    public CommonExecPythonCalc(
            RexProgram calcProgram,
            InputProperty inputProperty,
            RowType outputType,
            String description) {
        super(Collections.singletonList(inputProperty), outputType, description);
        this.calcProgram = calcProgram;
    }

    /**
     * Translates the single input edge and wraps it with the Python scalar function operator.
     *
     * <p>When {@link CommonPythonUtil#isPythonWorkerUsingManagedMemory(Configuration)} reports
     * {@code true} for the merged configuration, the resulting transformation additionally
     * declares the {@link ManagedMemoryUseCase#PYTHON} use case at slot scope.
     *
     * @param planner planner providing the execution environment and table configuration
     * @return the transformation running the Python operator on top of the input
     */
    @SuppressWarnings("unchecked")
    @Override
    protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
        final ExecEdge inputEdge = getInputEdges().get(0);
        final Transformation<RowData> inputTransform =
                (Transformation<RowData>) inputEdge.translateToPlan(planner);
        final Configuration config =
                CommonPythonUtil.getMergedConfig(planner.getExecEnv(), planner.getTableConfig());
        OneInputTransformation<RowData, RowData> ret =
                createPythonOneInputTransformation(
                        inputTransform, calcProgram, getDescription(), config);
        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(config)) {
            ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
        }
        return ret;
    }

    /**
     * Builds the transformation that applies the Python operator to {@code inputTransform}.
     *
     * <p>The operator's result row consists of the forwarded input fields (projections that are
     * plain {@link RexInputRef}s) followed by the results of the projected {@link RexCall}s. The
     * transformation keeps the parallelism of its input.
     *
     * @param inputTransform upstream transformation
     * @param calcProgram program describing the projections
     * @param name name given to the created transformation
     * @param config configuration handed to the operator
     * @return the transformation wrapping the Python operator
     */
    private OneInputTransformation<RowData, RowData> createPythonOneInputTransformation(
            Transformation<RowData> inputTransform,
            RexProgram calcProgram,
            String name,
            Configuration config) {
        // Expand the local references once; both the call list and the forwarded field list
        // are derived from the same expanded projections.
        List<RexNode> projections =
                calcProgram.getProjectList().stream()
                        .map(calcProgram::expandLocalRef)
                        .collect(Collectors.toList());

        List<RexCall> pythonRexCalls =
                projections.stream()
                        .filter(x -> x instanceof RexCall)
                        .map(x -> (RexCall) x)
                        .collect(Collectors.toList());

        List<Integer> forwardedFields =
                projections.stream()
                        .filter(x -> x instanceof RexInputRef)
                        .map(x -> ((RexInputRef) x).getIndex())
                        .collect(Collectors.toList());

        Tuple2<int[], PythonFunctionInfo[]> extractResult =
                extractPythonScalarFunctionInfos(pythonRexCalls);
        int[] pythonUdfInputOffsets = extractResult.f0;
        PythonFunctionInfo[] pythonFunctionInfos = extractResult.f1;

        InternalTypeInfo<RowData> pythonOperatorInputTypeInfo =
                (InternalTypeInfo<RowData>) inputTransform.getOutputType();
        LogicalType[] inputLogicalTypes = pythonOperatorInputTypeInfo.toRowFieldTypes();

        // Result row layout: forwarded input fields first, then one field per Python call.
        List<LogicalType> fieldsLogicalTypes = new ArrayList<>();
        for (int fieldIndex : forwardedFields) {
            fieldsLogicalTypes.add(inputLogicalTypes[fieldIndex]);
        }
        for (RexCall pythonCall : pythonRexCalls) {
            fieldsLogicalTypes.add(FlinkTypeFactory.toLogicalType(pythonCall.getType()));
        }
        InternalTypeInfo<RowData> pythonOperatorResultTypeInfo =
                InternalTypeInfo.ofFields(fieldsLogicalTypes.toArray(new LogicalType[0]));

        // The Arrow operator is chosen as soon as any expression contains a Pandas call.
        boolean isArrow =
                calcProgram.getExprList().stream()
                        .anyMatch(
                                x -> PythonUtil.containsPythonCall(x, PythonFunctionKind.PANDAS));

        OneInputStreamOperator<RowData, RowData> pythonOperator =
                getPythonScalarFunctionOperator(
                        config,
                        pythonOperatorInputTypeInfo,
                        pythonOperatorResultTypeInfo,
                        pythonUdfInputOffsets,
                        pythonFunctionInfos,
                        forwardedFields.stream().mapToInt(x -> x).toArray(),
                        isArrow);

        return new OneInputTransformation<>(
                inputTransform,
                name,
                pythonOperator,
                pythonOperatorResultTypeInfo,
                inputTransform.getParallelism());
    }

    /**
     * Creates the {@link PythonFunctionInfo}s for the given calls and collects the offsets of the
     * input nodes they reference.
     *
     * @param rexCalls Python calls taken from the projection list
     * @return a tuple of (input offsets in the iteration order of the collected input nodes,
     *     function infos in the order of {@code rexCalls})
     * @throws TableException if a collected input node is neither a {@link RexInputRef} nor a
     *     {@link RexFieldAccess}
     */
    private Tuple2<int[], PythonFunctionInfo[]> extractPythonScalarFunctionInfos(
            List<RexCall> rexCalls) {
        // Insertion-ordered so that the offsets below follow the order in which
        // createPythonFunctionInfo registered the input nodes.
        LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>();
        PythonFunctionInfo[] pythonFunctionInfos =
                rexCalls.stream()
                        .map(x -> CommonPythonUtil.createPythonFunctionInfo(x, inputNodes))
                        .toArray(PythonFunctionInfo[]::new);

        int[] udfInputOffsets =
                inputNodes.keySet().stream()
                        .mapToInt(CommonExecPythonCalc::toInputOffset)
                        .toArray();
        return Tuple2.of(udfInputOffsets, pythonFunctionInfos);
    }

    /**
     * Resolves the field index referenced by a Python function input node.
     *
     * @param node input node collected while creating the function infos
     * @return the index of the referenced input or accessed field
     * @throws TableException if the node kind is not supported; this replaces the anonymous
     *     {@link NullPointerException} that unboxing a {@code null} offset would raise
     */
    private static int toInputOffset(RexNode node) {
        if (node instanceof RexInputRef) {
            return ((RexInputRef) node).getIndex();
        } else if (node instanceof RexFieldAccess) {
            return ((RexFieldAccess) node).getField().getIndex();
        }
        throw new TableException("Unsupported input node of Python scalar function: " + node);
    }

    /**
     * Reflectively instantiates the Python scalar function operator.
     *
     * @param config configuration passed to the operator
     * @param inputRowTypeInfo type information of the operator input; passed as a row type
     * @param outputRowTypeInfo type information of the operator output; passed as a row type
     * @param udfInputOffsets offsets of the fields used as Python function inputs
     * @param pythonFunctionInfos descriptions of the Python functions to execute
     * @param forwardedFields indices of the input fields forwarded to the output
     * @param isArrow whether to load the Arrow variant of the operator
     * @return the instantiated operator
     * @throws TableException if the constructor cannot be found or the instantiation fails
     */
    @SuppressWarnings("unchecked")
    private OneInputStreamOperator<RowData, RowData> getPythonScalarFunctionOperator(
            Configuration config,
            InternalTypeInfo<RowData> inputRowTypeInfo,
            InternalTypeInfo<RowData> outputRowTypeInfo,
            int[] udfInputOffsets,
            PythonFunctionInfo[] pythonFunctionInfos,
            int[] forwardedFields,
            boolean isArrow) {
        Class<?> clazz;
        if (isArrow) {
            clazz = CommonPythonUtil.loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
        } else {
            clazz = CommonPythonUtil.loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
        }

        try {
            Constructor<?> ctor =
                    clazz.getConstructor(
                            Configuration.class,
                            PythonFunctionInfo[].class,
                            RowType.class,
                            RowType.class,
                            int[].class,
                            int[].class);
            return (OneInputStreamOperator<RowData, RowData>)
                    ctor.newInstance(
                            config,
                            pythonFunctionInfos,
                            inputRowTypeInfo.toRowType(),
                            outputRowTypeInfo.toRowType(),
                            udfInputOffsets,
                            forwardedFields);
        } catch (Exception e) {
            throw new TableException("Python Scalar Function Operator constructed failed.", e);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy