All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.flink.client.python.PythonFunctionFactory Maven / Gradle / Ivy

There is a newer version: 2.0-preview1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.client.python;

import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.python.util.PythonDependencyUtils;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.functions.python.PythonFunction;
import org.apache.flink.util.FileUtils;
import org.apache.flink.util.concurrent.ExecutorThreadFactory;

import org.apache.flink.shaded.guava31.com.google.common.cache.CacheBuilder;
import org.apache.flink.shaded.guava31.com.google.common.cache.CacheLoader;
import org.apache.flink.shaded.guava31.com.google.common.cache.LoadingCache;
import org.apache.flink.shaded.guava31.com.google.common.cache.RemovalListener;

import py4j.GatewayServer;

import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.apache.flink.client.python.PythonEnvUtils.CHECK_INTERVAL;
import static org.apache.flink.client.python.PythonEnvUtils.PythonProcessShutdownHook;
import static org.apache.flink.client.python.PythonEnvUtils.TIMEOUT_MILLIS;
import static org.apache.flink.client.python.PythonEnvUtils.getGatewayServer;
import static org.apache.flink.client.python.PythonEnvUtils.launchPy4jPythonClient;
import static org.apache.flink.client.python.PythonEnvUtils.maxConcurrentPythonFunctionFactories;
import static org.apache.flink.client.python.PythonEnvUtils.shutdownPythonProcess;
import static org.apache.flink.client.python.PythonEnvUtils.startGatewayServer;

/**
 * The factory which creates the PythonFunction objects from given module name and object name.
 *
 * <p>In addition to the instance-level lookup, this interface hosts static helpers that create
 * and cache factory instances (one per {@link CacheKey}, i.e. per classloader) and that launch
 * the Python callback process when no gateway server is available yet.
 */
public interface PythonFunctionFactory {

    /** Single-threaded scheduler which periodically triggers the cleanup of the factory cache. */
    ScheduledExecutorService CACHE_CLEANUP_EXECUTOR_SERVICE =
            Executors.newSingleThreadScheduledExecutor(
                    new ExecutorThreadFactory("PythonFunctionFactory"));

    /** Switched from {@code false} to {@code true} once the periodic cleanup task is scheduled. */
    AtomicReference<Boolean> CACHE_CLEANUP_EXECUTOR_SERVICE_STARTED = new AtomicReference<>(false);

    /**
     * Cache of factories keyed by {@link CacheKey}. Entries expire one minute after their last
     * access and the cache holds at most {@code maxConcurrentPythonFunctionFactories} entries.
     * Evicted factories which implement {@link Closeable} are closed by the removal listener.
     */
    LoadingCache<CacheKey, PythonFunctionFactory> PYTHON_FUNCTION_FACTORY_CACHE =
            CacheBuilder.newBuilder()
                    .expireAfterAccess(1, TimeUnit.MINUTES)
                    .maximumSize(maxConcurrentPythonFunctionFactories)
                    .removalListener(
                            (RemovalListener<CacheKey, PythonFunctionFactory>)
                                    removalNotification -> {
                                        PythonFunctionFactory removed =
                                                removalNotification.getValue();
                                        if (removed instanceof Closeable) {
                                            try {
                                                ((Closeable) removed).close();
                                            } catch (IOException ignored) {
                                                // Best effort: the entry is already evicted and
                                                // there is no caller to report the failure to.
                                            }
                                        }
                                    })
                    .build(
                            new CacheLoader<CacheKey, PythonFunctionFactory>() {
                                @Override
                                public PythonFunctionFactory load(CacheKey cacheKey) {
                                    try {
                                        return createPythonFunctionFactory(cacheKey.config);
                                    } catch (Throwable t) {
                                        throw new RuntimeException(t);
                                    }
                                }
                            });

    /**
     * Returns PythonFunction according to moduleName and objectName.
     *
     * @param moduleName The module name of the Python UDF.
     * @param objectName The function name / class name of the Python UDF.
     * @return The PythonFunction object which represents the Python UDF.
     */
    PythonFunction getPythonFunction(String moduleName, String objectName);

    /**
     * Returns PythonFunction according to the fully qualified name of the Python UDF, i.e.
     * ${moduleName}.${functionName} or ${moduleName}.${className}.
     *
     * <p>The name is split at its last dot. The given config is merged on top of a copy of the
     * configuration of the current {@link ExecutionEnvironment} before the factory is looked up
     * in (or loaded into) {@link #PYTHON_FUNCTION_FACTORY_CACHE}.
     *
     * @param fullyQualifiedName The fully qualified name of the Python UDF.
     * @param config The configuration of python dependencies.
     * @param classLoader The classloader which is used to identify different jobs.
     * @return The PythonFunction object which represents the Python UDF.
     * @throws IllegalArgumentException If the name contains no dot, starts with a dot or ends
     *     with a dot, i.e. the module name or the object name would be empty.
     * @throws ExecutionException If loading the factory into the cache fails with a checked
     *     exception.
     */
    static PythonFunction getPythonFunction(
            String fullyQualifiedName, ReadableConfig config, ClassLoader classLoader)
            throws ExecutionException {
        int splitIndex = fullyQualifiedName.lastIndexOf(".");
        // Reject names without a module part as well as names without an object part.
        if (splitIndex <= 0 || splitIndex == fullyQualifiedName.length() - 1) {
            throw new IllegalArgumentException(
                    String.format("The fully qualified name is invalid: '%s'", fullyQualifiedName));
        }
        String moduleName = fullyQualifiedName.substring(0, splitIndex);
        String objectName = fullyQualifiedName.substring(splitIndex + 1);

        Configuration mergedConfig =
                new Configuration(
                        ExecutionEnvironment.getExecutionEnvironment().getConfiguration());
        if (config instanceof TableConfig) {
            PythonDependencyUtils.merge(mergedConfig, ((TableConfig) config).getConfiguration());
        } else {
            PythonDependencyUtils.merge(mergedConfig, (Configuration) config);
        }
        PythonFunctionFactory pythonFunctionFactory =
                PYTHON_FUNCTION_FACTORY_CACHE.get(CacheKey.of(mergedConfig, classLoader));
        ensureCacheCleanupExecutorServiceStarted();
        return pythonFunctionFactory.getPythonFunction(moduleName, objectName);
    }

    /**
     * Schedules the periodic cleanup of {@link #PYTHON_FUNCTION_FACTORY_CACHE} (once per minute,
     * first run after one minute). The compare-and-set guard ensures the task is scheduled at
     * most once, no matter how often this method is called.
     */
    static void ensureCacheCleanupExecutorServiceStarted() {
        if (CACHE_CLEANUP_EXECUTOR_SERVICE_STARTED.compareAndSet(false, true)) {
            CACHE_CLEANUP_EXECUTOR_SERVICE.scheduleAtFixedRate(
                    PYTHON_FUNCTION_FACTORY_CACHE::cleanUp, 1, 1, TimeUnit.MINUTES);
        }
    }

    /**
     * Creates a new factory which delegates to the "PythonFunctionFactory" entry registered in
     * the entry point of the gateway server.
     *
     * <p>If no gateway server is available yet, one is started together with a Python process
     * running {@code pyflink.pyflink_callback_server}; the method then polls the entry point
     * until the Python side has registered itself, the process has died or the timeout has
     * elapsed. On any failure the gateway server, the Python process and the temporary
     * directory are cleaned up before the original exception is rethrown.
     *
     * @param config The configuration handed to the launched Python process.
     * @return The factory wrapping the Python side implementation.
     * @throws ExecutionException Propagated from starting the gateway server or the process.
     * @throws InterruptedException Propagated from starting the gateway server or the process.
     * @throws IOException Propagated from starting the gateway server or the process.
     */
    // The entry point is only exposed as Object; the casts to Map<String, Object> are unchecked.
    @SuppressWarnings("unchecked")
    static PythonFunctionFactory createPythonFunctionFactory(ReadableConfig config)
            throws ExecutionException, InterruptedException, IOException {
        Map<String, Object> entryPoint;
        PythonProcessShutdownHook shutdownHook = null;
        if (getGatewayServer() == null) {
            GatewayServer gatewayServer = null;
            Process pythonProcess = null;
            String tmpDir = null;
            // Remembers an interrupt received while polling so that the flag can be restored
            // once the cleanup below has finished.
            boolean interrupted = false;
            try {
                gatewayServer = startGatewayServer();
                List<String> commands = new ArrayList<>();
                commands.add("-m");
                commands.add("pyflink.pyflink_callback_server");
                tmpDir =
                        System.getProperty("java.io.tmpdir")
                                + File.separator
                                + "pyflink"
                                + File.separator
                                + UUID.randomUUID();
                pythonProcess =
                        launchPy4jPythonClient(
                                gatewayServer, config, commands, null, tmpDir, false);
                entryPoint = (Map<String, Object>) gatewayServer.getGateway().getEntryPoint();
                int i = 0;
                // Wait until the Python side has registered its factory in the entry point.
                while (!entryPoint.containsKey("PythonFunctionFactory")) {
                    if (!pythonProcess.isAlive()) {
                        throw new RuntimeException("Python callback server start failed!");
                    }
                    try {
                        Thread.sleep(CHECK_INTERVAL);
                    } catch (InterruptedException e) {
                        interrupted = true;
                        throw new RuntimeException(
                                "Interrupted while waiting for the python process to start.", e);
                    }
                    i++;
                    if (i > TIMEOUT_MILLIS / CHECK_INTERVAL) {
                        throw new RuntimeException("Python callback server start failed!");
                    }
                }
            } catch (Throwable e) {
                try {
                    if (gatewayServer != null) {
                        gatewayServer.shutdown();
                    }
                } catch (Throwable e2) {
                    // ignore, do not swallow the origin exception.
                }
                try {
                    if (pythonProcess != null) {
                        shutdownPythonProcess(pythonProcess, TIMEOUT_MILLIS);
                    }
                } catch (Throwable e3) {
                    // ignore, do not swallow the origin exception.
                }
                if (tmpDir != null) {
                    FileUtils.deleteDirectoryQuietly(new File(tmpDir));
                }
                if (interrupted) {
                    // Restore the interrupt status for the callers. This is done only after the
                    // cleanup so that cleanup steps which may block (presumably the shutdown of
                    // the Python process - TODO confirm) are not cut short by the pending flag.
                    Thread.currentThread().interrupt();
                }

                throw e;
            }
            shutdownHook = new PythonProcessShutdownHook(pythonProcess, gatewayServer, tmpDir);
            Runtime.getRuntime().addShutdownHook(shutdownHook);
        } else {
            entryPoint = (Map<String, Object>) getGatewayServer().getGateway().getEntryPoint();
        }

        return new PythonFunctionFactoryImpl(
                (PythonFunctionFactory) entryPoint.get("PythonFunctionFactory"), shutdownHook);
    }

    /**
     * The cache key. It only considers the classloader: two keys are equal if and only if they
     * reference the very same classloader instance. The config is merely carried along so that
     * the cache loader can create the factory; it takes no part in equals/hashCode.
     */
    class CacheKey {
        private final ReadableConfig config;
        private final ClassLoader classLoader;

        CacheKey(ReadableConfig config, ClassLoader classLoader) {
            this.config = config;
            this.classLoader = classLoader;
        }

        /** Creates a key for the given config and classloader. */
        public static CacheKey of(ReadableConfig config, ClassLoader classLoader) {
            return new CacheKey(config, classLoader);
        }

        @Override
        public boolean equals(Object other) {
            return other instanceof CacheKey && this.classLoader == ((CacheKey) other).classLoader;
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(classLoader);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy