All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.flink.python.util.PythonConfigUtil Maven / Gradle / Ivy

There is a newer version: 2.0-preview1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.python.util;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.python.PythonConfig;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.graph.StreamGraphGenerator;
import org.apache.flink.streaming.api.graph.TransformationTranslator;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
import org.apache.flink.streaming.api.operators.python.DataStreamPythonFunctionOperator;
import org.apache.flink.streaming.api.operators.python.process.AbstractExternalOneInputPythonFunctionOperator;
import org.apache.flink.streaming.api.transformations.AbstractMultipleInputTransformation;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.SideOutputTransformation;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.api.transformations.python.DelegateOperatorTransformation;
import org.apache.flink.streaming.api.transformations.python.PythonBroadcastStateTransformation;
import org.apache.flink.streaming.api.transformations.python.PythonKeyedBroadcastStateTransformation;
import org.apache.flink.streaming.api.utils.ByteArrayWrapper;
import org.apache.flink.streaming.api.utils.ByteArrayWrapperSerializer;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.translators.python.PythonBroadcastStateTransformationTranslator;
import org.apache.flink.streaming.runtime.translators.python.PythonKeyedBroadcastStateTransformationTranslator;
import org.apache.flink.util.OutputTag;

import org.apache.flink.shaded.guava30.com.google.common.collect.Iterables;
import org.apache.flink.shaded.guava30.com.google.common.collect.Queues;
import org.apache.flink.shaded.guava30.com.google.common.collect.Sets;

import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;

/** A Util class to handle the configurations of Python jobs. */
public class PythonConfigUtil {

    /**
     * Name of the transformation that {@code alignTransformation} aligns with its downstream
     * transformation when it appears as that transformation's first input.
     */
    public static final String KEYED_STREAM_VALUE_OPERATOR_NAME = "_keyed_stream_values_operator";

    /**
     * Name of the transformation that {@code alignTransformation} aligns with its upstream
     * (first input) transformation.
     */
    public static final String STREAM_KEY_BY_MAP_OPERATOR_NAME = "_stream_key_by_map_operator";

    /**
     * Name of the transformation that {@code alignTransformation} treats exactly like {@link
     * #STREAM_KEY_BY_MAP_OPERATOR_NAME}, i.e. it is aligned with its upstream transformation.
     */
    public static final String STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME =
            "_partition_custom_map_operator";

    /**
     * Returns the configuration of the given environment as a {@link Configuration}.
     *
     * @param env the environment whose configuration is requested
     * @return the object returned by {@code env.getConfiguration()}, cast to {@link Configuration}
     */
    public static Configuration getEnvironmentConfig(StreamExecutionEnvironment env) {
        final ReadableConfig environmentConfig = env.getConfiguration();
        return (Configuration) environmentConfig;
    }

    /**
     * Prepares the transformations registered in the given environment for Python execution.
     *
     * <p>Every transformation is first aligned with its neighbour where required. Each Python
     * operator then declares the Python managed memory use case at slot scope and receives the
     * extracted Python configuration. Finally the side output tags are propagated to the Python
     * operators and the Python broadcast state translators are registered.
     *
     * @param env the environment whose transformations are configured
     * @throws Exception if aligning a transformation or registering the translators fails
     */
    public static void configPythonOperator(StreamExecutionEnvironment env) throws Exception {
        final Configuration config =
                extractPythonConfiguration(env.getCachedFiles(), env.getConfiguration());

        for (Transformation<?> transformation : env.getTransformations()) {
            alignTransformation(transformation);

            if (!isPythonOperator(transformation)) {
                continue;
            }

            // declare the use case of managed memory
            transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);

            // Not every Python transformation exposes its operator through a
            // SimpleOperatorFactory, so the lookup may legitimately return null.
            final AbstractPythonFunctionOperator<?> pythonFunctionOperator =
                    getPythonOperator(transformation);
            if (pythonFunctionOperator != null) {
                pythonFunctionOperator.getConfiguration().addAll(config);
            }
        }

        processSideOutput(env.getTransformations());
        registerPythonBroadcastTransformationTranslator();
    }

    /** Extract the configurations which is used in the Python operators. */
    public static Configuration extractPythonConfiguration(
            List> cachedFiles,
            ReadableConfig config) {
        final Configuration pythonDependencyConfig =
                PythonDependencyUtils.configurePythonDependencies(cachedFiles, config);
        final PythonConfig pythonConfig = new PythonConfig(config, pythonDependencyConfig);
        return pythonConfig.toConfiguration();
    }

    /**
     * Process {@link SideOutputTransformation}s, set the {@link OutputTag}s into the Python
     * corresponding operator to make it aware of the {@link OutputTag}s.
     */
    private static void processSideOutput(List> transformations) {
        final Set> visitedTransforms = Sets.newIdentityHashSet();
        final Queue> queue = Queues.newArrayDeque(transformations);

        while (!queue.isEmpty()) {
            Transformation transform = queue.poll();
            visitedTransforms.add(transform);

            if (transform instanceof SideOutputTransformation) {
                final SideOutputTransformation sideTransform =
                        (SideOutputTransformation) transform;
                final Transformation upTransform =
                        Iterables.getOnlyElement(sideTransform.getInputs());
                if (PythonConfigUtil.isPythonDataStreamOperator(upTransform)) {
                    final DataStreamPythonFunctionOperator upOperator =
                            (DataStreamPythonFunctionOperator)
                                    ((SimpleOperatorFactory) getOperatorFactory(upTransform))
                                            .getOperator();
                    upOperator.addSideOutputTags(
                            Collections.singletonList(sideTransform.getOutputTag()));
                }
            }

            for (Transformation upTransform : transform.getInputs()) {
                if (!visitedTransforms.contains(upTransform)) {
                    queue.add(upTransform);
                }
            }
        }
    }

    /**
     * Returns the operator factory of the given transformation.
     *
     * @param transform the transformation to inspect
     * @return the operator factory if {@code transform} is a {@link OneInputTransformation},
     *     {@link TwoInputTransformation}, {@link AbstractMultipleInputTransformation} or {@link
     *     DelegateOperatorTransformation}; {@code null} for any other kind of transformation
     */
    public static StreamOperatorFactory<?> getOperatorFactory(Transformation<?> transform) {
        if (transform instanceof OneInputTransformation) {
            return ((OneInputTransformation<?, ?>) transform).getOperatorFactory();
        } else if (transform instanceof TwoInputTransformation) {
            return ((TwoInputTransformation<?, ?, ?>) transform).getOperatorFactory();
        } else if (transform instanceof AbstractMultipleInputTransformation) {
            return ((AbstractMultipleInputTransformation<?>) transform).getOperatorFactory();
        } else if (transform instanceof DelegateOperatorTransformation) {
            return ((DelegateOperatorTransformation<?>) transform).getOperatorFactory();
        } else {
            return null;
        }
    }

    /**
     * Configure the {@link AbstractExternalOneInputPythonFunctionOperator} to be chained with the
     * upstream/downstream operator by setting their parallelism, slot sharing group, co-location
     * group to be the same, and applying a {@link ForwardPartitioner}.
     *
     * <ol>
     *   <li>An operator named "_keyed_stream_values_operator" is aligned with its downstream
     *       operator.
     *   <li>An operator named "_stream_key_by_map_operator" or "_partition_custom_map_operator"
     *       is aligned with its upstream operator.
     * </ol>
     *
     * <p>Only the first input of the transformation is considered; transformations without
     * inputs are left untouched.
     *
     * @param transformation the transformation to align with its first input
     * @throws NoSuchFieldException if the transformation class declares no {@code input} field
     * @throws IllegalAccessException if the {@code input} field cannot be written
     */
    private static void alignTransformation(Transformation<?> transformation)
            throws NoSuchFieldException, IllegalAccessException {
        final String transformName = transformation.getName();
        if (transformation.getInputs().isEmpty()) {
            return;
        }
        final Transformation<?> inputTransformation = transformation.getInputs().get(0);
        final String inputTransformName = inputTransformation.getName();

        // Constant-first equals keeps the comparison safe even for an unnamed transformation.
        if (KEYED_STREAM_VALUE_OPERATOR_NAME.equals(inputTransformName)) {
            chainTransformation(inputTransformation, transformation);
            configForwardPartitioner(inputTransformation, transformation);
        }
        if (STREAM_KEY_BY_MAP_OPERATOR_NAME.equals(transformName)
                || STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME.equals(transformName)) {
            chainTransformation(transformation, inputTransformation);
            configForwardPartitioner(inputTransformation, transformation);
        }
    }

    /**
     * Copies the placement settings of {@code secondTransformation} onto {@code
     * firstTransformation}: the slot sharing group (only if the second one has one), the
     * co-location group key and the parallelism.
     *
     * @param firstTransformation the transformation that is modified
     * @param secondTransformation the transformation whose settings are copied
     */
    private static void chainTransformation(
            Transformation<?> firstTransformation, Transformation<?> secondTransformation) {
        secondTransformation
                .getSlotSharingGroup()
                .ifPresent(firstTransformation::setSlotSharingGroup);
        firstTransformation.setCoLocationGroupKey(secondTransformation.getCoLocationGroupKey());
        // NOTE(review): the `false` flag presumably marks the parallelism as not explicitly
        // configured by the user - confirm against Transformation#setParallelism.
        firstTransformation.setParallelism(secondTransformation.getParallelism(), false);
    }

    /**
     * Re-wires {@code transformation} so that it reads from {@code upTransformation} through a
     * {@link PartitionTransformation} using a {@link ForwardPartitioner}.
     *
     * <p>The new input is written reflectively into the field named {@code input} that must be
     * declared directly on the runtime class of {@code transformation}.
     *
     * @param upTransformation the upstream transformation to forward from
     * @param transformation the transformation whose {@code input} field is replaced
     * @throws NoSuchFieldException if the runtime class declares no field named {@code input}
     * @throws IllegalAccessException if the field cannot be written
     */
    private static void configForwardPartitioner(
            Transformation<?> upTransformation, Transformation<?> transformation)
            throws IllegalAccessException, NoSuchFieldException {
        // set ForwardPartitioner
        final PartitionTransformation<?> partitionTransform =
                new PartitionTransformation<>(upTransformation, new ForwardPartitioner<>());
        final Field inputTransformationField =
                transformation.getClass().getDeclaredField("input");
        inputTransformationField.setAccessible(true);
        inputTransformationField.set(transformation, partitionTransform);
    }

    /**
     * Returns the Python function operator wrapped by the given transformation, if there is one.
     *
     * <p>The operator factory is resolved through {@link #getOperatorFactory(Transformation)}.
     * Only a {@link SimpleOperatorFactory} that holds an {@link AbstractPythonFunctionOperator}
     * yields a result.
     *
     * @param transformation the transformation to inspect
     * @return the wrapped Python operator, or {@code null} if the transformation has no operator
     *     factory, the factory is not a {@link SimpleOperatorFactory}, or its operator is not an
     *     {@link AbstractPythonFunctionOperator}
     */
    private static AbstractPythonFunctionOperator<?> getPythonOperator(
            Transformation<?> transformation) {
        final StreamOperatorFactory<?> operatorFactory = getOperatorFactory(transformation);
        if (!(operatorFactory instanceof SimpleOperatorFactory)) {
            return null;
        }

        // Fetch the operator once instead of once for the type check and once for the cast.
        final Object operator = ((SimpleOperatorFactory<?>) operatorFactory).getOperator();
        return operator instanceof AbstractPythonFunctionOperator
                ? (AbstractPythonFunctionOperator<?>) operator
                : null;
    }

    /**
     * Tells whether the given transformation runs a Python operator.
     *
     * <p>One-input, two-input and multiple-input transformations are judged by their operator
     * factory. {@link PythonBroadcastStateTransformation} and {@link
     * PythonKeyedBroadcastStateTransformation} always count as Python operators.
     *
     * @param transform the transformation to inspect
     * @return {@code true} if the transformation is considered a Python operator
     */
    private static boolean isPythonOperator(Transformation<?> transform) {
        if (transform instanceof OneInputTransformation) {
            return isPythonOperator(
                    ((OneInputTransformation<?, ?>) transform).getOperatorFactory());
        } else if (transform instanceof TwoInputTransformation) {
            return isPythonOperator(
                    ((TwoInputTransformation<?, ?, ?>) transform).getOperatorFactory());
        } else if (transform instanceof AbstractMultipleInputTransformation) {
            return isPythonOperator(
                    ((AbstractMultipleInputTransformation<?>) transform).getOperatorFactory());
        } else {
            return transform instanceof PythonBroadcastStateTransformation
                    || transform instanceof PythonKeyedBroadcastStateTransformation;
        }
    }

    /**
     * Tells whether the given factory is a {@link SimpleOperatorFactory} that holds an {@link
     * AbstractPythonFunctionOperator}.
     *
     * @param streamOperatorFactory the factory to inspect, may be {@code null}
     * @return {@code true} only for a {@link SimpleOperatorFactory} wrapping a Python operator
     */
    private static boolean isPythonOperator(StreamOperatorFactory<?> streamOperatorFactory) {
        return streamOperatorFactory instanceof SimpleOperatorFactory
                && ((SimpleOperatorFactory<?>) streamOperatorFactory).getOperator()
                        instanceof AbstractPythonFunctionOperator;
    }

    /**
     * Tells whether the given transformation runs a Python DataStream operator.
     *
     * <p>One-input and two-input transformations are judged by their operator factory. {@link
     * PythonBroadcastStateTransformation} and {@link PythonKeyedBroadcastStateTransformation}
     * always count as Python DataStream operators. Every other kind of transformation,
     * including multiple-input transformations, yields {@code false}.
     *
     * @param transform the transformation to inspect
     * @return {@code true} if the transformation is considered a Python DataStream operator
     */
    public static boolean isPythonDataStreamOperator(Transformation<?> transform) {
        if (transform instanceof OneInputTransformation) {
            return isPythonDataStreamOperator(
                    ((OneInputTransformation<?, ?>) transform).getOperatorFactory());
        } else if (transform instanceof TwoInputTransformation) {
            return isPythonDataStreamOperator(
                    ((TwoInputTransformation<?, ?, ?>) transform).getOperatorFactory());
        } else {
            return transform instanceof PythonBroadcastStateTransformation
                    || transform instanceof PythonKeyedBroadcastStateTransformation;
        }
    }

    /**
     * Tells whether the given factory is a {@link SimpleOperatorFactory} that holds a {@link
     * DataStreamPythonFunctionOperator}.
     *
     * @param streamOperatorFactory the factory to inspect, may be {@code null}
     * @return {@code true} only for a {@link SimpleOperatorFactory} wrapping a Python DataStream
     *     operator
     */
    private static boolean isPythonDataStreamOperator(
            StreamOperatorFactory<?> streamOperatorFactory) {
        return streamOperatorFactory instanceof SimpleOperatorFactory
                && ((SimpleOperatorFactory<?>) streamOperatorFactory).getOperator()
                        instanceof DataStreamPythonFunctionOperator;
    }

    public static void setPartitionCustomOperatorNumPartitions(
            List> transformations) {
        // Update the numPartitions of PartitionCustomOperator after aligned all operators.
        final Set> alreadyTransformed = Sets.newIdentityHashSet();
        final Queue> toTransformQueue = Queues.newArrayDeque(transformations);
        while (!toTransformQueue.isEmpty()) {
            final Transformation transformation = toTransformQueue.poll();
            if (!alreadyTransformed.contains(transformation)
                    && !(transformation instanceof PartitionTransformation)) {
                alreadyTransformed.add(transformation);

                getNonPartitionTransformationInput(transformation)
                        .ifPresent(
                                input -> {
                                    AbstractPythonFunctionOperator pythonFunctionOperator =
                                            getPythonOperator(input);
                                    if (pythonFunctionOperator
                                            instanceof DataStreamPythonFunctionOperator) {
                                        DataStreamPythonFunctionOperator
                                                pythonDataStreamFunctionOperator =
                                                        (DataStreamPythonFunctionOperator)
                                                                pythonFunctionOperator;
                                        pythonDataStreamFunctionOperator.setNumPartitions(
                                                transformation.getParallelism());
                                    }
                                });

                toTransformQueue.addAll(transformation.getInputs());
            }
        }
    }

    private static Optional> getNonPartitionTransformationInput(
            Transformation transformation) {
        if (transformation.getInputs().size() != 1) {
            return Optional.empty();
        }

        final Transformation inputTransformation = transformation.getInputs().get(0);
        if (inputTransformation instanceof PartitionTransformation) {
            return getNonPartitionTransformationInput(inputTransformation);
        } else {
            return Optional.of(inputTransformation);
        }
    }

    public static List>
            convertStateNamesToStateDescriptors(String[] names) {
        List> descriptors =
                new ArrayList<>(names.length);
        TypeSerializer byteArraySerializer =
                PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.createSerializer(
                        new ExecutionConfig());
        for (String name : names) {
            descriptors.add(
                    new MapStateDescriptor<>(
                            name, ByteArrayWrapperSerializer.INSTANCE, byteArraySerializer));
        }
        return descriptors;
    }

    @SuppressWarnings("rawtypes,unchecked")
    public static void registerPythonBroadcastTransformationTranslator() throws Exception {
        final Field translatorMapField =
                StreamGraphGenerator.class.getDeclaredField("translatorMap");
        translatorMapField.setAccessible(true);
        final Map, TransformationTranslator> translatorMap =
                (Map, TransformationTranslator>)
                        translatorMapField.get(null);
        final Field underlyingMapField = translatorMap.getClass().getDeclaredField("m");
        underlyingMapField.setAccessible(true);
        final Map, TransformationTranslator> underlyingMap =
                (Map, TransformationTranslator>)
                        underlyingMapField.get(translatorMap);

        underlyingMap.put(
                PythonBroadcastStateTransformation.class,
                new PythonBroadcastStateTransformationTranslator<>());
        underlyingMap.put(
                PythonKeyedBroadcastStateTransformation.class,
                new PythonKeyedBroadcastStateTransformationTranslator<>());
    }

    /**
     * Creates a {@link SingleOutputStreamOperator} for the given environment and transformation
     * by reflectively invoking its {@code (StreamExecutionEnvironment, Transformation)}
     * constructor.
     *
     * @param env the environment passed to the constructor
     * @param transformation the transformation passed to the constructor
     * @return the newly created stream operator
     * @throws Exception if the constructor cannot be found, accessed or invoked
     */
    @SuppressWarnings("rawtypes")
    public static SingleOutputStreamOperator<?> createSingleOutputStreamOperator(
            StreamExecutionEnvironment env, Transformation<?> transformation) throws Exception {
        // A class literal can only name the raw type, hence Constructor<SingleOutputStreamOperator>.
        final Constructor<SingleOutputStreamOperator> constructor =
                SingleOutputStreamOperator.class.getDeclaredConstructor(
                        StreamExecutionEnvironment.class, Transformation.class);
        constructor.setAccessible(true);
        return constructor.newInstance(env, transformation);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy