All downloads are free. Search and download functionality uses the official Maven repository.

org.opensearch.ml.common.MLCommonsClassLoader Maven / Gradle / Ivy

The newest version!
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.ml.common;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.ml.common.annotation.Connector;
import org.opensearch.ml.common.annotation.ExecuteInput;
import org.opensearch.ml.common.annotation.ExecuteOutput;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.annotation.MLAlgoOutput;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.annotation.MLInput;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLOutputType;
import org.reflections.Reflections;

import lombok.extern.log4j.Log4j2;

@Log4j2
@SuppressWarnings("removal")
public class MLCommonsClassLoader {

    private static Map, Class> parameterClassMap = new HashMap<>();
    private static Map, Class> executeInputClassMap = new HashMap<>();
    private static Map, Class> executeOutputClassMap = new HashMap<>();
    private static Map, Class> mlInputClassMap = new HashMap<>();
    private static Map> connectorClassMap = new HashMap<>();

    static {
        try {
            AccessController.doPrivileged((PrivilegedExceptionAction) () -> {
                loadClassMapping();
                return null;
            });
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Can't load class mapping in ML commons", e);
        }
    }

    public static void loadClassMapping() {
        ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader();
        try {
            Thread.currentThread().setContextClassLoader(MLCommonsClassLoader.class.getClassLoader());
            loadMLAlgoParameterClassMapping();
            loadMLOutputClassMapping();
            loadMLInputDataSetClassMapping();
            loadExecuteInputClassMapping();
            loadExecuteOutputClassMapping();
            loadMLInputClassMapping();
            loadConnectorClassMapping();
        } finally {
            Thread.currentThread().setContextClassLoader(originalClassLoader);
        }
    }

    private static void loadConnectorClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.connector");
        Set> classes = reflections.getTypesAnnotatedWith(Connector.class);
        for (Class clazz : classes) {
            Connector connector = clazz.getAnnotation(Connector.class);
            if (connector != null) {
                String name = connector.value();
                if (name != null && name.length() > 0) {
                    connectorClassMap.put(name, clazz);
                }
            }
        }
    }

    /**
     * Load ML algorithm parameter and ML output class.
     */
    private static void loadMLAlgoParameterClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.input.parameter");

        Set> classes = reflections.getTypesAnnotatedWith(MLAlgoParameter.class);
        // Load ML algorithm parameter class
        for (Class clazz : classes) {
            MLAlgoParameter mlAlgoParameter = clazz.getAnnotation(MLAlgoParameter.class);
            if (mlAlgoParameter != null) {
                FunctionName[] algorithms = mlAlgoParameter.algorithms();
                if (algorithms != null && algorithms.length > 0) {
                    for (FunctionName name : algorithms) {
                        parameterClassMap.put(name, clazz);
                    }
                }
            }
        }

        // Load ML output class
        classes = reflections.getTypesAnnotatedWith(MLAlgoOutput.class);
        for (Class clazz : classes) {
            MLAlgoOutput mlAlgoOutput = clazz.getAnnotation(MLAlgoOutput.class);
            MLOutputType mlOutputType = mlAlgoOutput.value();
            if (mlOutputType != null) {
                parameterClassMap.put(mlOutputType, clazz);
            }
        }
    }

    /**
     * Load ML algorithm parameter and ML output class.
     */
    private static void loadMLOutputClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.output");

        Set> classes = reflections.getTypesAnnotatedWith(MLAlgoOutput.class);
        for (Class clazz : classes) {
            MLAlgoOutput mlAlgoOutput = clazz.getAnnotation(MLAlgoOutput.class);
            if (mlAlgoOutput != null) {
                MLOutputType mlOutputType = mlAlgoOutput.value();
                if (mlOutputType != null) {
                    parameterClassMap.put(mlOutputType, clazz);
                }
            }
        }
    }

    /**
     * Load ML input data set class
     */
    private static void loadMLInputDataSetClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.dataset");
        Set> classes = reflections.getTypesAnnotatedWith(InputDataSet.class);
        for (Class clazz : classes) {
            InputDataSet inputDataSet = clazz.getAnnotation(InputDataSet.class);
            if (inputDataSet != null) {
                MLInputDataType value = inputDataSet.value();
                if (value != null) {
                    parameterClassMap.put(value, clazz);
                }
            }
        }
    }

    /**
     * Load execute input output class.
     */
    private static void loadExecuteInputClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.input.execute");
        Set> classes = reflections.getTypesAnnotatedWith(ExecuteInput.class);
        for (Class clazz : classes) {
            ExecuteInput executeInput = clazz.getAnnotation(ExecuteInput.class);
            if (executeInput != null) {
                FunctionName[] algorithms = executeInput.algorithms();
                if (algorithms != null && algorithms.length > 0) {
                    for (FunctionName name : algorithms) {
                        executeInputClassMap.put(name, clazz);
                    }
                }
            }
        }
    }

    /**
     * Load execute input output class.
     */
    private static void loadExecuteOutputClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.output.execute");
        Set> classes = reflections.getTypesAnnotatedWith(ExecuteOutput.class);
        for (Class clazz : classes) {
            ExecuteOutput executeOutput = clazz.getAnnotation(ExecuteOutput.class);
            if (executeOutput != null) {
                FunctionName[] algorithms = executeOutput.algorithms();
                if (algorithms != null && algorithms.length > 0) {
                    for (FunctionName name : algorithms) {
                        executeOutputClassMap.put(name, clazz);
                    }
                }
            }
        }
    }

    private static void loadMLInputClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.input");
        Set> classes = reflections.getTypesAnnotatedWith(MLInput.class);
        for (Class clazz : classes) {
            MLInput mlInput = clazz.getAnnotation(MLInput.class);
            if (mlInput != null) {
                FunctionName[] algorithms = mlInput.functionNames();
                if (algorithms != null && algorithms.length > 0) {
                    for (FunctionName name : algorithms) {
                        mlInputClassMap.put(name, clazz);
                    }
                }
            }
        }
    }

    @SuppressWarnings("unchecked")
    public static , S, I extends Object> S initMLInstance(T type, I in, Class constructorParamClass) {
        return init(parameterClassMap, type, in, constructorParamClass);
    }

    @SuppressWarnings("unchecked")
    public static , S, I extends Object> S initExecuteInputInstance(T type, I in, Class constructorParamClass) {
        try {
            return init(executeInputClassMap, type, in, constructorParamClass);
        } catch (Exception e) {
            return init(mlInputClassMap, type, in, constructorParamClass);
        }
    }

    @SuppressWarnings("unchecked")
    public static , S, I extends Object> S initExecuteOutputInstance(T type, I in, Class constructorParamClass) {
        try {
            return init(executeOutputClassMap, type, in, constructorParamClass);
        } catch (Exception e) {
            if (in instanceof StreamInput) {
                try {
                    return (S) MLOutput.fromStream((StreamInput) in);
                } catch (IOException ex) {
                    throw new RuntimeException(ex);
                }
            }
            throw e;
        }
    }

    @SuppressWarnings("unchecked")
    private static  S init(Map> map, T type, I in, Class constructorParamClass) {
        Class clazz = map.get(type);
        if (clazz == null) {
            throw new IllegalArgumentException("Can't find class for type " + type);
        }
        try {
            Constructor constructor = clazz.getConstructor(constructorParamClass);
            return (S) constructor.newInstance(in);
        } catch (Exception e) {
            Throwable cause = e.getCause();
            if (cause instanceof MLException || cause instanceof IllegalArgumentException) {
                throw (RuntimeException) cause;
            } else {
                log.error("Failed to init instance for type " + type, e);
                return null;
            }
        }
    }

    public static boolean canInitMLInput(FunctionName functionName) {
        return mlInputClassMap.containsKey(functionName);
    }

    public static  S initConnector(String name, Object[] initArgs, Class... constructorParameterTypes) {
        return init(connectorClassMap, name, initArgs, constructorParameterTypes);
    }

    @SuppressWarnings("unchecked")
    public static , S> S initMLInput(T type, Object[] initArgs, Class... constructorParameterTypes) {
        return init(mlInputClassMap, type, initArgs, constructorParameterTypes);
    }

    private static  S init(Map> map, T type, Object[] initArgs, Class... constructorParameterTypes) {
        Class clazz = map.get(type);
        if (clazz == null) {
            throw new IllegalArgumentException("Can't find class for type " + type);
        }
        try {
            Constructor constructor = clazz.getConstructor(constructorParameterTypes);
            return (S) constructor.newInstance(initArgs);
        } catch (Exception e) {
            Throwable cause = e.getCause();
            if (cause instanceof MLException) {
                throw (MLException) cause;
            } else if (cause instanceof IllegalArgumentException) {
                throw (IllegalArgumentException) cause;
            } else {
                log.error("Failed to init instance for type " + type, e);
                return null;
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy