All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.translate.ServingTranslatorFactory Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.translate;

import ai.djl.Application;
import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.modality.cv.translator.SingleShotDetectionTranslator;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.util.JsonSerializable;
import ai.djl.util.JsonUtils;
import ai.djl.util.PairList;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.lang.reflect.Constructor;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A {@link TranslatorFactory} that creates an generic {@link Translator}. */
public class ServingTranslatorFactory implements TranslatorFactory {

    private static final Logger logger = LoggerFactory.getLogger(ServingTranslatorFactory.class);

    /** {@inheritDoc} */
    @Override
    public Translator newInstance(Model model, Map arguments)
            throws TranslateException {
        Map merged = new ConcurrentHashMap<>(arguments);
        Path modelDir = model.getModelPath();
        String className = null;
        Path manifestFile = modelDir.resolve("serving.properties");
        if (Files.isRegularFile(manifestFile)) {
            Properties prop = new Properties();
            try (Reader reader = Files.newBufferedReader(manifestFile)) {
                prop.load(reader);
            } catch (IOException e) {
                throw new TranslateException("Failed to load serving.properties file", e);
            }
            for (String key : prop.stringPropertyNames()) {
                merged.putIfAbsent(key, prop.getProperty(key));
            }
            className = prop.getProperty("translator");
        }

        Path libPath = modelDir.resolve("libs");
        if (!Files.isDirectory(libPath)) {
            libPath = modelDir.resolve("lib");
            if (!Files.isDirectory(libPath)) {
                return loadDefaultTranslator(merged);
            }
        }
        ServingTranslator translator = findTranslator(libPath, className);
        if (translator != null) {
            translator.setArguments(merged);
            return translator;
        }
        return loadDefaultTranslator(merged);
    }

    private ServingTranslator findTranslator(Path path, String className) {
        try {
            Path classesDir = path.resolve("classes");
            compileJavaClass(classesDir);

            List jarFiles =
                    Files.list(path)
                            .filter(p -> p.toString().endsWith(".jar"))
                            .collect(Collectors.toList());
            List urls = new ArrayList<>(jarFiles.size() + 1);
            urls.add(classesDir.toUri().toURL());
            for (Path p : jarFiles) {
                urls.add(p.toUri().toURL());
            }

            ClassLoader parentCl = Thread.currentThread().getContextClassLoader();
            ClassLoader cl = new URLClassLoader(urls.toArray(new URL[0]), parentCl);
            if (className != null && !className.isEmpty()) {
                return initTranslator(cl, className);
            }

            ServingTranslator translator = scanDirectory(cl, classesDir);
            if (translator != null) {
                return translator;
            }

            for (Path p : jarFiles) {
                translator = scanJarFile(cl, p);
                if (translator != null) {
                    return translator;
                }
            }
        } catch (IOException e) {
            logger.debug("Failed to find Translator", e);
        }
        return null;
    }

    private ServingTranslator scanDirectory(ClassLoader cl, Path dir) throws IOException {
        if (!Files.isDirectory(dir)) {
            logger.debug("Directory not exists: {}", dir);
            return null;
        }
        Collection files =
                Files.walk(dir)
                        .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
                        .collect(Collectors.toList());
        for (Path file : files) {
            Path p = dir.relativize(file);
            String className = p.toString();
            className = className.substring(0, className.lastIndexOf('.'));
            className = className.replace(File.separatorChar, '.');
            ServingTranslator translator = initTranslator(cl, className);
            if (translator != null) {
                return translator;
            }
        }
        return null;
    }

    private ServingTranslator scanJarFile(ClassLoader cl, Path path) throws IOException {
        try (JarFile jarFile = new JarFile(path.toFile())) {
            Enumeration en = jarFile.entries();
            while (en.hasMoreElements()) {
                JarEntry entry = en.nextElement();
                String fileName = entry.getName();
                if (fileName.endsWith(".class")) {
                    fileName = fileName.substring(0, fileName.lastIndexOf('.'));
                    fileName = fileName.replace('/', '.');
                    ServingTranslator translator = initTranslator(cl, fileName);
                    if (translator != null) {
                        return translator;
                    }
                }
            }
        }
        return null;
    }

    private ServingTranslator initTranslator(ClassLoader cl, String className) {
        try {
            Class clazz = Class.forName(className, true, cl);
            Class subclass = clazz.asSubclass(ServingTranslator.class);
            Constructor constructor = subclass.getConstructor();
            return constructor.newInstance();
        } catch (Throwable e) {
            logger.trace("Not able to load ModelServerTranslator", e);
        }
        return null;
    }

    private Translator loadDefaultTranslator(Map arguments) {
        String appName = (String) arguments.get("application");
        if (appName != null) {
            Application application = Application.of(appName);
            if (application == Application.CV.IMAGE_CLASSIFICATION) {
                return getImageClassificationTranslator(arguments);
            } else if (application == Application.CV.OBJECT_DETECTION) {
                // TODO: check model name
                return getSsdTranslator(arguments);
            }
        }
        return new RawTranslator();
    }

    private Translator getImageClassificationTranslator(
            Map arguments) {
        return new ImageServingTranslator(ImageClassificationTranslator.builder(arguments).build());
    }

    private Translator getSsdTranslator(Map arguments) {
        return new ImageServingTranslator(SingleShotDetectionTranslator.builder(arguments).build());
    }

    private void compileJavaClass(Path dir) {
        try {
            if (!Files.isDirectory(dir)) {
                logger.debug("Directory not exists: {}", dir);
                return;
            }
            String[] files =
                    Files.walk(dir)
                            .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".java"))
                            .map(p -> p.toAbsolutePath().toString())
                            .toArray(String[]::new);
            JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
            if (files.length > 0) {
                compiler.run(null, null, null, files);
            }
        } catch (Throwable e) {
            logger.warn("Failed to compile bundled java file", e);
        }
    }

    private static final class ImageServingTranslator implements Translator {

        private Translator translator;
        private ImageFactory factory;

        public ImageServingTranslator(Translator translator) {
            this.translator = translator;
            factory = ImageFactory.getInstance();
        }

        /** {@inheritDoc} */
        @Override
        public Batchifier getBatchifier() {
            return translator.getBatchifier();
        }

        /** {@inheritDoc} */
        @Override
        public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
            Input input = (Input) ctx.getAttachment("input");
            Output output = new Output(input.getRequestId(), 200, "OK");
            Object obj = translator.processOutput(ctx, list);
            if (obj instanceof JsonSerializable) {
                output.setContent(((JsonSerializable) obj).toJson() + '\n');
            } else {
                output.setContent(JsonUtils.GSON_PRETTY.toJson(obj) + '\n');
            }
            return output;
        }

        /** {@inheritDoc} */
        @Override
        public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
            ctx.setAttachment("input", input);
            PairList inputs = input.getContent();
            byte[] data = inputs.get("data");
            if (data == null) {
                data = inputs.get("body");
            }
            if (data == null) {
                data = input.getContent().valueAt(0);
            }
            Image image = factory.fromInputStream(new ByteArrayInputStream(data));
            return translator.processInput(ctx, image);
        }

        /** {@inheritDoc} */
        @Override
        public void prepare(NDManager manager, Model model) throws IOException {
            translator.prepare(manager, model);
        }
    }

    private static final class RawTranslator implements Translator {

        /** {@inheritDoc} */
        @Override
        public Batchifier getBatchifier() {
            return null;
        }

        /** {@inheritDoc} */
        @Override
        public NDList processInput(TranslatorContext ctx, Input input) {
            ctx.setAttachment("input", input);
            PairList inputs = input.getContent();
            byte[] data = inputs.get("data");
            if (data == null) {
                data = inputs.get("body");
            }
            if (data == null) {
                data = input.getContent().valueAt(0);
            }
            NDManager manager = ctx.getNDManager();
            return NDList.decode(manager, data);
        }

        /** {@inheritDoc} */
        @Override
        public Output processOutput(TranslatorContext ctx, NDList list) {
            Input input = (Input) ctx.getAttachment("input");
            Output output = new Output(input.getRequestId(), 200, "OK");
            output.setContent(list.encode());
            return output;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy