All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.repository.zoo.BaseModelLoader Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.repository.zoo;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.nn.Block;
import ai.djl.nn.BlockFactory;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.translate.DefaultTranslatorFactory;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import ai.djl.util.Utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Reader;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Collectors;

/** Shared code for the {@link ModelLoader} implementations. */
public class BaseModelLoader implements ModelLoader {

    private static final Logger logger = LoggerFactory.getLogger(BaseModelLoader.class);

    /** Prefix of {@code serving.properties} keys that are routed to the model options map. */
    private static final String OPTION_PREFIX = "option.";

    protected MRL mrl;
    protected TranslatorFactory defaultFactory;

    /**
     * Constructs a {@link ModelLoader} for the given mrl.
     *
     * <p>The loader falls back to a {@link DefaultTranslatorFactory} when neither the criteria nor
     * the model arguments provide a usable translator factory.
     *
     * @param mrl the mrl of the model to load
     */
    public BaseModelLoader(MRL mrl) {
        this.mrl = mrl;
        defaultFactory = new DefaultTranslatorFactory();
    }

    /** {@inheritDoc} */
    @Override
    public String getGroupId() {
        return mrl.getGroupId();
    }

    /** {@inheritDoc} */
    @Override
    public String getArtifactId() {
        return mrl.getArtifactId();
    }

    /** {@inheritDoc} */
    @Override
    public Application getApplication() {
        return mrl.getApplication();
    }

    /**
     * {@inheritDoc}
     *
     * <p>The artifact is prepared (downloaded if necessary), {@code serving.properties} in the
     * model directory is merged into the arguments/options without overriding values already
     * present, an engine is resolved, and the model is loaded with a translator. If loading the
     * model or creating the translator fails, the partially created model is closed before the
     * exception propagates.
     */
    @Override
    public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
            throws IOException, ModelNotFoundException, MalformedModelException {
        Artifact artifact = mrl.match(criteria.getFilters());
        if (artifact == null) {
            throw new ModelNotFoundException("No matching filter found");
        }

        Progress progress = criteria.getProgress();
        Map<String, Object> arguments = artifact.getArguments(criteria.getArguments());
        Map<String, String> options = artifact.getOptions(criteria.getOptions());

        try {
            TranslatorFactory factory = getTranslatorFactory(criteria, arguments);
            Class<I> input = criteria.getInputClass();
            Class<O> output = criteria.getOutputClass();
            if (factory == null || !factory.isSupported(input, output)) {
                // Fall back to the built-in factory before giving up.
                factory = defaultFactory;
                if (!factory.isSupported(input, output)) {
                    throw new ModelNotFoundException(getFactoryLookupErrorMessage(factory));
                }
            }

            mrl.prepare(artifact, progress);
            if (progress != null) {
                progress.reset("Loading", 2);
                progress.update(1);
            }

            Path modelPath = mrl.getRepository().getResourceDirectory(artifact);
            // The resource may be a single file; in that case its parent is the model directory.
            Path modelDir = Files.isRegularFile(modelPath) ? modelPath.getParent() : modelPath;
            if (modelDir == null) {
                throw new AssertionError("Directory should not be null.");
            }
            modelDir = Utils.getNestedModelDir(modelDir);

            loadServingProperties(modelDir, arguments, options);
            Application application = criteria.getApplication();
            if (application != Application.UNDEFINED) {
                arguments.put("application", application.getPath());
            }
            String engine = resolveEngine(criteria.getEngine(), arguments);

            String modelName = criteria.getModelName();
            if (modelName == null) {
                modelName = options.get("modelName");
                if (modelName == null) {
                    modelName = artifact.getName();
                }
            }

            Model model =
                    createModel(
                            modelDir,
                            modelName,
                            criteria.getDevice(),
                            criteria.getBlock(),
                            arguments,
                            engine);
            boolean success = false;
            try {
                model.load(modelPath, null, options);
                Translator<I, O> translator = factory.newInstance(input, output, model, arguments);
                ZooModel<I, O> zooModel = new ZooModel<>(model, translator);
                success = true;
                return zooModel;
            } finally {
                // The model is only handed to the ZooModel on success; otherwise close it here
                // so a failed load or translator lookup does not leak it.
                if (!success) {
                    model.close();
                }
            }
        } catch (TranslateException e) {
            throw new ModelNotFoundException("No matching translator found", e);
        } finally {
            if (progress != null) {
                progress.end();
            }
        }
    }

    /** {@inheritDoc} */
    @Override
    public <I, O> boolean isDownloaded(Criteria<I, O> criteria)
            throws IOException, ModelNotFoundException {
        Artifact artifact = mrl.match(criteria.getFilters());
        if (artifact == null) {
            throw new ModelNotFoundException("No matching filter found");
        }
        return mrl.isPrepared(artifact);
    }

    /** {@inheritDoc} */
    @Override
    public <I, O> void downloadModel(Criteria<I, O> criteria, Progress progress)
            throws IOException, ModelNotFoundException {
        Artifact artifact = mrl.match(criteria.getFilters());
        if (artifact == null) {
            throw new ModelNotFoundException("No matching filter found");
        }
        mrl.prepare(artifact, progress);
    }

    /**
     * {@inheritDoc}
     *
     * <p>When the mrl specifies a version, only artifacts with exactly that version are returned.
     */
    @Override
    public List<Artifact> listModels() throws IOException {
        List<Artifact> list = mrl.listArtifacts();
        String version = mrl.getVersion();
        return list.stream()
                .filter(a -> version == null || version.equals(a.getVersion()))
                .collect(Collectors.toList());
    }

    /**
     * Creates a {@link Model}, attaches a block and copies the arguments into model properties.
     *
     * <p>If {@code block} is {@code null}, a block is built from the {@code blockFactory}
     * argument when one is configured or discoverable. Argument entries with a {@code null} value
     * are skipped when copying properties. If anything fails after the model has been created,
     * the model is closed before the exception propagates.
     *
     * @param modelPath the model directory
     * @param name the model name
     * @param device the device to load the model on (may be {@code null})
     * @param block an explicit block to use, or {@code null} to consult the arguments
     * @param arguments the model arguments
     * @param engine the engine name, or {@code null} for the default engine
     * @return the created model
     * @throws IOException if the block factory fails to build the block
     * @throws IllegalArgumentException if a {@code blockFactory} class name cannot be loaded
     */
    protected Model createModel(
            Path modelPath,
            String name,
            Device device,
            Block block,
            Map<String, Object> arguments,
            String engine)
            throws IOException {
        Model model = Model.newInstance(name, device, engine);
        boolean success = false;
        try {
            Block modelBlock = block;
            if (modelBlock == null) {
                modelBlock = newBlockFromArguments(model, modelPath, arguments);
            }
            if (modelBlock != null) {
                model.setBlock(modelBlock);
            }
            for (Map.Entry<String, Object> entry : arguments.entrySet()) {
                Object value = entry.getValue();
                // Skip null values instead of failing with a NullPointerException.
                if (value != null) {
                    model.setProperty(entry.getKey(), value.toString());
                }
            }
            success = true;
            return model;
        } finally {
            if (!success) {
                model.close();
            }
        }
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append(mrl.getGroupId())
                .append(':')
                .append(mrl.getArtifactId())
                .append(' ')
                .append(getApplication())
                .append(" [\n");
        try {
            for (Artifact artifact : listModels()) {
                sb.append('\t').append(artifact).append('\n');
            }
        } catch (IOException e) {
            sb.append("\tFailed load metadata.");
        }
        sb.append(']');
        return sb.toString();
    }

    /**
     * Returns the translator factory requested by the criteria or the model arguments.
     *
     * <p>A factory set on the criteria wins. Otherwise the {@code translatorFactory} argument is
     * treated as a class name and instantiated via the context class loader; a warning is logged
     * if that fails.
     *
     * @param criteria the loading criteria
     * @param arguments the model arguments
     * @return the translator factory, or {@code null} if none was requested or it failed to load
     */
    protected TranslatorFactory getTranslatorFactory(
            Criteria<?, ?> criteria, Map<String, Object> arguments) {
        TranslatorFactory factory = criteria.getTranslatorFactory();
        if (factory != null) {
            return factory;
        }

        String factoryClass = (String) arguments.get("translatorFactory");
        if (factoryClass != null) {
            ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
            factory = ClassLoaderUtils.initClass(cl, TranslatorFactory.class, factoryClass);
            if (factory == null) {
                logger.warn("Failed to load translatorFactory: {}", factoryClass);
            }
        }
        return factory;
    }

    /**
     * Resolves the engine name used to create the model.
     *
     * <p>Precedence: the engine from the criteria, then the {@code engine} argument (which may
     * come from {@code serving.properties}), then an engine supported by the model zoo registered
     * for this mrl's group id. Among the zoo's engines, the default engine is preferred; otherwise
     * the last one in the zoo's list that is available is used. {@code null} is returned when
     * nothing is specified and no model zoo is registered, which leaves engine selection to
     * {@link Model#newInstance}.
     *
     * @param criteriaEngine the engine from the criteria, may be {@code null}
     * @param arguments the model arguments
     * @return the engine name, or {@code null}
     * @throws ModelNotFoundException if no supported engine is available or the chosen engine is
     *     not present
     */
    private String resolveEngine(String criteriaEngine, Map<String, Object> arguments)
            throws ModelNotFoundException {
        String engine = criteriaEngine;
        if (engine == null) {
            engine = (String) arguments.get("engine");
        }
        if (engine == null) {
            ModelZoo modelZoo = ModelZoo.getModelZoo(mrl.getGroupId());
            if (modelZoo != null) {
                String defaultEngine = Engine.getDefaultEngineName();
                for (String supportedEngine : modelZoo.getSupportedEngines()) {
                    if (supportedEngine.equals(defaultEngine)) {
                        engine = supportedEngine;
                        break;
                    } else if (Engine.hasEngine(supportedEngine)) {
                        engine = supportedEngine;
                    }
                }
                if (engine == null) {
                    throw new ModelNotFoundException(
                            "No supported engine available for model zoo: "
                                    + modelZoo.getGroupId());
                }
            }
        }
        if (engine != null && !Engine.hasEngine(engine)) {
            throw new ModelNotFoundException(engine + " is not supported");
        }
        return engine;
    }

    /**
     * Builds a block from the {@code blockFactory} argument.
     *
     * <p>The argument may be a {@link BlockFactory} instance or a class name. When it is absent,
     * an implementation may still be discovered from the model directory.
     *
     * @param model the model the block is built for
     * @param modelPath the model directory
     * @param arguments the model arguments
     * @return the new block, or {@code null} when no block factory is configured or discovered
     * @throws IOException if the block factory fails
     * @throws IllegalArgumentException if a named block factory cannot be loaded
     */
    private static Block newBlockFromArguments(
            Model model, Path modelPath, Map<String, Object> arguments) throws IOException {
        Object bf = arguments.get("blockFactory");
        if (bf instanceof BlockFactory) {
            return ((BlockFactory) bf).newBlock(model, modelPath, arguments);
        }
        String className = (String) bf;
        BlockFactory factory =
                ClassLoaderUtils.findImplementation(modelPath, BlockFactory.class, className);
        if (factory != null) {
            return factory.newBlock(model, modelPath, arguments);
        }
        if (className != null) {
            throw new IllegalArgumentException("Failed to load BlockFactory: " + className);
        }
        return null;
    }

    /** Builds the error message listing the input/output types the factory supports. */
    private String getFactoryLookupErrorMessage(TranslatorFactory factory) {
        StringBuilder sb = new StringBuilder(200);
        sb.append(
                "No matching default translator found. The valid input and output classes are: \n");
        for (Pair<Type, Type> io : factory.getSupportedTypes()) {
            sb.append("\t(")
                    .append(io.getKey().getTypeName())
                    .append(", ")
                    .append(io.getValue().getTypeName())
                    .append(")\n");
        }
        return sb.toString();
    }

    /**
     * Merges {@code serving.properties} from the model directory, if present.
     *
     * <p>Keys starting with {@code option.} go (without the prefix) into {@code options}; all
     * other keys go into {@code arguments}. Existing entries are never overwritten.
     */
    private void loadServingProperties(
            Path modelDir, Map<String, Object> arguments, Map<String, String> options)
            throws IOException {
        Path manifestFile = modelDir.resolve("serving.properties");
        if (Files.isRegularFile(manifestFile)) {
            Properties prop = new Properties();
            try (Reader reader = Files.newBufferedReader(manifestFile)) {
                prop.load(reader);
            }
            for (String key : prop.stringPropertyNames()) {
                if (key.startsWith(OPTION_PREFIX)) {
                    options.putIfAbsent(
                            key.substring(OPTION_PREFIX.length()), prop.getProperty(key));
                } else {
                    arguments.putIfAbsent(key, prop.getProperty(key));
                }
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy