All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.repository.zoo.ZooModel Maven / Gradle / Ivy

/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.repository.zoo;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import ai.djl.util.PairList;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
import java.util.Map;
import java.util.function.Function;

/**
 * A {@code ZooModel} is a {@link Model} loaded from a model zoo and includes a default {@link
 * Translator}.
 *
 * @param  the model input type
 * @param  the model output type
 */
public class ZooModel implements Model {

    private Model model;
    private Translator translator;

    /**
     * Constructs a {@code ZooModel} given the model and translator.
     *
     * @param model the model to wrap
     * @param translator the translator
     */
    public ZooModel(Model model, Translator translator) {
        this.model = model;
        this.translator = translator;
    }

    /** {@inheritDoc} */
    @Override
    public void load(Path modelPath, String prefix, Map options) {
        throw new IllegalArgumentException("ZooModel should not be re-loaded.");
    }

    /** {@inheritDoc} */
    @Override
    public void load(InputStream modelStream, Map options) throws IOException {
        throw new IllegalArgumentException("ZooModel should not be re-loaded.");
    }

    /**
     * Returns the wrapped model.
     *
     * @return the wrapped model
     */
    public Model getWrappedModel() {
        return model;
    }

    /** {@inheritDoc} */
    @Override
    public void save(Path modelPath, String modelName) throws IOException {
        model.save(modelPath, modelName);
    }

    /** {@inheritDoc} */
    @Override
    public Path getModelPath() {
        return model.getModelPath();
    }

    /** {@inheritDoc} */
    @Override
    public Block getBlock() {
        return model.getBlock();
    }

    /** {@inheritDoc} */
    @Override
    public void setBlock(Block block) {
        model.setBlock(block);
    }

    /** {@inheritDoc} */
    @Override
    public String getName() {
        return model.getName();
    }

    /** {@inheritDoc} */
    @Override
    public String getProperty(String key) {
        return model.getProperty(key);
    }

    /** {@inheritDoc} */
    @Override
    public void setProperty(String key, String value) {
        model.setProperty(key, value);
    }

    /** {@inheritDoc} */
    @Override
    public Map getProperties() {
        return model.getProperties();
    }

    /** {@inheritDoc} */
    @Override
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        return model.newTrainer(trainingConfig);
    }

    /**
     * Creates a new Predictor based on the model with the default translator.
     *
     * @return an instance of {@code Predictor}
     */
    public Predictor newPredictor() {
        return newPredictor(translator);
    }

    /**
     * Creates a new Predictor based on the model with the default translator and a specified
     * device.
     *
     * @param device the device to use for prediction
     * @return an instance of {@code Predictor}
     */
    public Predictor newPredictor(Device device) {
        return model.newPredictor(translator, device);
    }

    /** {@inheritDoc} */
    @Override
    public  Predictor newPredictor(Translator translator, Device device) {
        return model.newPredictor(translator, device);
    }

    /**
     * Returns the default translator.
     *
     * @return the default translator
     */
    public Translator getTranslator() {
        return translator;
    }

    /** {@inheritDoc} */
    @Override
    public PairList describeInput() {
        return model.describeInput();
    }

    /** {@inheritDoc} */
    @Override
    public PairList describeOutput() {
        return model.describeOutput();
    }

    /** {@inheritDoc} */
    @Override
    public String[] getArtifactNames() {
        return model.getArtifactNames();
    }

    /** {@inheritDoc} */
    @Override
    public  T getArtifact(String name, Function function) throws IOException {
        return model.getArtifact(name, function);
    }

    /** {@inheritDoc} */
    @Override
    public URL getArtifact(String name) throws IOException {
        return model.getArtifact(name);
    }

    /** {@inheritDoc} */
    @Override
    public InputStream getArtifactAsStream(String name) throws IOException {
        return model.getArtifactAsStream(name);
    }

    /** {@inheritDoc} */
    @Override
    public NDManager getNDManager() {
        return model.getNDManager();
    }

    /** {@inheritDoc} */
    @Override
    public void setDataType(DataType dataType) {
        model.setDataType(dataType);
    }

    /** {@inheritDoc} */
    @Override
    public DataType getDataType() {
        return model.getDataType();
    }

    /** {@inheritDoc} */
    @Override
    public void cast(DataType dataType) {
        model.cast(dataType);
    }

    /** {@inheritDoc} */
    @Override
    public void close() {
        model.close();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy