All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.translate.ServingTranslatorFactory Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.translate;

import ai.djl.Application;
import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.Utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Constructor;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/** A {@link TranslatorFactory} that creates a generic {@link Translator}. */
public class ServingTranslatorFactory implements TranslatorFactory {

    private static final Logger logger = LoggerFactory.getLogger(ServingTranslatorFactory.class);

    /** {@inheritDoc} */
    @Override
    public Set> getSupportedTypes() {
        return Collections.singleton(new Pair<>(Input.class, Output.class));
    }

    /** {@inheritDoc} */
    @Override
    @SuppressWarnings("unchecked")
    public  Translator newInstance(
            Class input, Class output, Model model, Map arguments)
            throws TranslateException {
        if (!isSupported(input, output)) {
            throw new IllegalArgumentException("Unsupported input/output types.");
        }

        Path modelDir = model.getModelPath();
        String factoryClass = ArgumentsUtil.stringValue(arguments, "translatorFactory");
        if (factoryClass != null) {
            Translator translator =
                    getServingTranslator(factoryClass, model, arguments);
            if (translator != null) {
                return (Translator) translator;
            }
            throw new TranslateException("Failed to load translatorFactory: " + factoryClass);
        }

        String className = (String) arguments.get("translator");
        Path libPath = modelDir.resolve("libs");
        if (!Files.isDirectory(libPath)) {
            libPath = modelDir.resolve("lib");
            if (!Files.isDirectory(libPath) && className == null) {
                return (Translator) loadDefaultTranslator(model, arguments);
            }
        }
        ServingTranslator servingTranslator = findTranslator(libPath, className);
        if (servingTranslator != null) {
            servingTranslator.setArguments(arguments);
            logger.info("Using translator: {}", servingTranslator.getClass().getName());
            return (Translator) servingTranslator;
        } else if (className != null) {
            throw new TranslateException("Failed to load translator: " + className);
        }

        return (Translator) loadDefaultTranslator(model, arguments);
    }

    private ServingTranslator findTranslator(Path path, String className) {
        Path classesDir = path.resolve("classes");
        ClassLoaderUtils.compileJavaClass(classesDir);
        return ClassLoaderUtils.findImplementation(path, ServingTranslator.class, className);
    }

    private TranslatorFactory loadTranslatorFactory(String className) {
        try {
            Class clazz = Class.forName(className);
            Class subclass = clazz.asSubclass(TranslatorFactory.class);
            Constructor constructor = subclass.getConstructor();
            return constructor.newInstance();
        } catch (Throwable e) {
            logger.trace("Not able to load TranslatorFactory: {}", className, e);
        }
        return null;
    }

    private Translator loadDefaultTranslator(Model model, Map arguments)
            throws TranslateException {
        String factoryClass = detectTranslatorFactory(arguments);
        Translator translator = getServingTranslator(factoryClass, model, arguments);
        if (translator != null) {
            return translator;
        }

        NoopServingTranslatorFactory factory = new NoopServingTranslatorFactory();
        return factory.newInstance(Input.class, Output.class, null, arguments);
    }

    private String detectTranslatorFactory(Map arguments) {
        Application application;
        String app = ArgumentsUtil.stringValue(arguments, "application");
        if (app != null) {
            application = Application.of(app);
        } else {
            String task = Utils.getEnvOrSystemProperty("HF_TASK");
            task = ArgumentsUtil.stringValue(arguments, "task", task);
            if (task != null) {
                task = task.replace("-", "_").toLowerCase(Locale.ROOT);
                application = Application.of(task);
            } else {
                application = Application.UNDEFINED;
            }
        }
        if (application == Application.CV.IMAGE_CLASSIFICATION) {
            return "ai.djl.modality.cv.translator.ImageClassificationTranslatorFactory";
        } else if (application == Application.NLP.FILL_MASK) {
            return "ai.djl.huggingface.translator.FillMaskTranslatorFactory";
        } else if (application == Application.NLP.QUESTION_ANSWER) {
            return "ai.djl.huggingface.translator.QuestionAnsweringTranslatorFactory";
        } else if (application == Application.NLP.TEXT_CLASSIFICATION) {
            return "ai.djl.huggingface.translator.TextClassificationTranslatorFactory";
        } else if (application == Application.NLP.TEXT_EMBEDDING) {
            return "ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory";
        } else if (application == Application.NLP.TOKEN_CLASSIFICATION) {
            return "ai.djl.huggingface.translator.TokenClassificationTranslatorFactory";
        }
        return null;
    }

    private Translator getServingTranslator(
            String factoryClass, Model model, Map arguments) throws TranslateException {
        TranslatorFactory factory = loadTranslatorFactory(factoryClass);
        if (factory != null && factory.isSupported(Input.class, Output.class)) {
            logger.info("Using TranslatorFactory: {}", factoryClass);
            return factory.newInstance(Input.class, Output.class, model, arguments);
        }
        return null;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy