All downloads are free. The search and download functionality uses the official Maven repository.

ai.djl.repository.zoo.Criteria Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.repository.zoo;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.translate.Translator;
import ai.djl.util.Progress;
import java.util.HashMap;
import java.util.Map;

/**
 * The {@code Criteria} class contains search criteria to look up a {@link ZooModel}.
 *
 * <p>Instances are created through {@link #builder()} and hold the values that were set on the
 * {@link Builder} at the time {@link Builder#build()} was called.
 *
 * @param <I> the model input type
 * @param <O> the model output type
 */
public class Criteria<I, O> {

    private final Application application;
    private final Class<I> inputClass;
    private final Class<O> outputClass;
    private final String engine;
    private final Device device;
    private final String groupId;
    private final String artifactId;
    private final Map<String, String> filters;
    private final Map<String, Object> arguments;
    private final Map<String, Object> options;
    private final Translator<I, O> translator;
    private final Progress progress;

    /**
     * Constructs a {@code Criteria} from the values currently held by the builder.
     *
     * @param builder the builder to read the values from
     */
    Criteria(Builder<I, O> builder) {
        this.application = builder.application;
        this.inputClass = builder.inputClass;
        this.outputClass = builder.outputClass;
        this.engine = builder.engine;
        this.device = builder.device;
        this.groupId = builder.groupId;
        this.artifactId = builder.artifactId;
        // The maps are stored by reference, not copied.
        this.filters = builder.filters;
        this.arguments = builder.arguments;
        this.options = builder.options;
        this.translator = builder.translator;
        this.progress = builder.progress;
    }

    /**
     * Returns the application of the model.
     *
     * @return the application of the model
     */
    public Application getApplication() {
        return application;
    }

    /**
     * Returns the input data type.
     *
     * @return the input data type
     */
    public Class<I> getInputClass() {
        return inputClass;
    }

    /**
     * Returns the output data type.
     *
     * @return the output data type
     */
    public Class<O> getOutputClass() {
        return outputClass;
    }

    /**
     * Returns the engine name.
     *
     * @return the engine name
     */
    public String getEngine() {
        return engine;
    }

    /**
     * Returns the {@link Device} of the model to be loaded on.
     *
     * @return the {@link Device} of the model to be loaded on
     */
    public Device getDevice() {
        return device;
    }

    /**
     * Returns the groupId of the {@link ModelZoo} to be searched.
     *
     * @return the groupId of the {@link ModelZoo} to be searched
     */
    public String getGroupId() {
        return groupId;
    }

    /**
     * Returns the artifactId of the {@link ModelLoader} to be searched.
     *
     * @return the artifactId of the {@link ModelLoader} to be searched
     */
    public String getArtifactId() {
        return artifactId;
    }

    /**
     * Returns the search filters that must match the properties of the model.
     *
     * @return the search filters that must match the properties of the model
     */
    public Map<String, String> getFilters() {
        return filters;
    }

    /**
     * Returns the override configurations of the model loading arguments.
     *
     * @return the override configurations of the model loading arguments
     */
    public Map<String, Object> getArguments() {
        return arguments;
    }

    /**
     * Returns the model loading options.
     *
     * @return the model loading options
     */
    public Map<String, Object> getOptions() {
        return options;
    }

    /**
     * Returns the optional {@link Translator} to be used for {@link ZooModel}.
     *
     * @return the optional {@link Translator} to be used for {@link ZooModel}
     */
    public Translator<I, O> getTranslator() {
        return translator;
    }

    /**
     * Returns the optional {@link Progress} for the model loading.
     *
     * @return the optional {@link Progress} for the model loading
     */
    public Progress getProgress() {
        return progress;
    }

    /**
     * Creates a builder to build a {@code Criteria}.
     *
     * @return a new builder
     */
    public static Builder<?, ?> builder() {
        return new Builder<>();
    }

    /**
     * A Builder to construct a {@code Criteria}.
     *
     * @param <I> the model input type
     * @param <O> the model output type
     */
    public static final class Builder<I, O> {

        Application application;
        Class<I> inputClass;
        Class<O> outputClass;
        String engine;
        Device device;
        String groupId;
        String artifactId;
        Map<String, String> filters;
        Map<String, Object> arguments;
        Map<String, Object> options;
        Translator<I, O> translator;
        Progress progress;

        /** Creates a builder whose engine name is taken from {@link Engine#getInstance()}. */
        Builder() {
            engine = Engine.getInstance().getEngineName();
        }

        /**
         * Creates a builder with new input/output types, carrying over the parent's settings.
         *
         * @param inputClass the input class
         * @param outputClass the output class
         * @param parent the builder to copy the type-independent settings from
         */
        private Builder(Class<I> inputClass, Class<O> outputClass, Builder<?, ?> parent) {
            this.inputClass = inputClass;
            this.outputClass = outputClass;
            application = parent.application;
            engine = parent.engine;
            device = parent.device;
            groupId = parent.groupId;
            // Previously omitted: an artifactId set before setTypes() was silently lost.
            artifactId = parent.artifactId;
            filters = parent.filters;
            arguments = parent.arguments;
            options = parent.options;
            progress = parent.progress;
            // The translator is not carried over: it is typed to the parent's input/output
            // types, which this constructor replaces.
        }

        /**
         * Creates a new {@code Builder} class with the specified input and output data type.
         *
         * <p>All settings of this builder except the translator are carried over to the
         * returned builder.
         *
         * @param <P> the input data type
         * @param <Q> the output data type
         * @param inputClass the input class
         * @param outputClass the output class
         * @return a new {@code Builder} class with the specified input and output data type
         */
        public <P, Q> Builder<P, Q> setTypes(Class<P> inputClass, Class<Q> outputClass) {
            return new Builder<>(inputClass, outputClass, this);
        }

        /**
         * Sets the model application for this criteria.
         *
         * @param application the model application
         * @return this {@code Builder}
         */
        public Builder<I, O> optApplication(Application application) {
            this.application = application;
            return this;
        }

        /**
         * Sets the engine name for this criteria.
         *
         * @param engine the engine name
         * @return this {@code Builder}
         */
        public Builder<I, O> optEngine(String engine) {
            this.engine = engine;
            return this;
        }

        /**
         * Sets the {@link Device} for this criteria.
         *
         * @param device the {@link Device} for the criteria
         * @return this {@code Builder}
         */
        public Builder<I, O> optDevice(Device device) {
            this.device = device;
            return this;
        }

        /**
         * Sets optional groupId of the {@link ModelZoo} for this criteria.
         *
         * @param groupId the groupId of the {@link ModelZoo}
         * @return this {@code Builder}
         */
        public Builder<I, O> optGroupId(String groupId) {
            this.groupId = groupId;
            return this;
        }

        /**
         * Sets optional artifactId of the {@link ModelLoader} for this criteria.
         *
         * <p>If the value contains a {@code ':'}, it is split on {@code ':'}; the first token
         * replaces the groupId and the second token becomes the artifactId. Any further tokens
         * are ignored.
         *
         * @param artifactId the artifactId of the {@link ModelLoader}, optionally prefixed with
         *     {@code groupId:}
         * @return this {@code Builder}
         * @throws IllegalArgumentException if the value contains a {@code ':'} but does not
         *     yield both a groupId and an artifactId token (for example {@code "group:"})
         */
        public Builder<I, O> optArtifactId(String artifactId) {
            if (artifactId.contains(":")) {
                String[] tokens = artifactId.split(":");
                // split() drops trailing empty strings, so "group:" yields a single token.
                if (tokens.length < 2) {
                    throw new IllegalArgumentException(
                            "Invalid artifactId, expected groupId:artifactId: " + artifactId);
                }
                groupId = tokens[0];
                this.artifactId = tokens[1];
            } else {
                this.artifactId = artifactId;
            }
            return this;
        }

        /**
         * Sets the extra search filters for this criteria.
         *
         * <p>The given map replaces any filters set previously and is stored by reference.
         *
         * @param filters the extra search filters
         * @return this {@code Builder}
         */
        public Builder<I, O> optFilters(Map<String, String> filters) {
            this.filters = filters;
            return this;
        }

        /**
         * Sets an extra search filter for this criteria.
         *
         * @param key the search key
         * @param value the search value
         * @return this {@code Builder}
         */
        public Builder<I, O> optFilter(String key, String value) {
            if (filters == null) {
                filters = new HashMap<>();
            }
            filters.put(key, value);
            return this;
        }

        /**
         * Sets the model loading arguments for this criteria.
         *
         * <p>The given map replaces any arguments set previously and is stored by reference.
         *
         * @param arguments optional model loading arguments
         * @return this {@code Builder}
         */
        public Builder<I, O> optArguments(Map<String, Object> arguments) {
            this.arguments = arguments;
            return this;
        }

        /**
         * Sets an extra model loading argument for this criteria.
         *
         * @param key the model loading argument key
         * @param value the model loading argument value
         * @return this {@code Builder}
         */
        public Builder<I, O> optArgument(String key, Object value) {
            if (arguments == null) {
                arguments = new HashMap<>();
            }
            arguments.put(key, value);
            return this;
        }

        /**
         * Sets the model loading options for this criteria.
         *
         * <p>The given map replaces any options set previously and is stored by reference.
         *
         * @param options the model loading options
         * @return this {@code Builder}
         */
        public Builder<I, O> optOptions(Map<String, Object> options) {
            this.options = options;
            return this;
        }

        /**
         * Sets the optional model loading option for this criteria.
         *
         * @param key the model loading option key
         * @param value the model loading option value
         * @return this {@code Builder}
         */
        public Builder<I, O> optOption(String key, Object value) {
            if (options == null) {
                options = new HashMap<>();
            }
            options.put(key, value);
            return this;
        }

        /**
         * Sets the optional {@link Translator} to override default {@code Translator}.
         *
         * @param translator the override {@code Translator}
         * @return this {@code Builder}
         */
        public Builder<I, O> optTranslator(Translator<I, O> translator) {
            this.translator = translator;
            return this;
        }

        /**
         * Set the optional {@link Progress}.
         *
         * @param progress the {@code Progress}
         * @return this {@code Builder}
         */
        public Builder<I, O> optProgress(Progress progress) {
            this.progress = progress;
            return this;
        }

        /**
         * Builds a {@link Criteria} instance.
         *
         * @return the {@link Criteria} instance
         */
        public Criteria<I, O> build() {
            return new Criteria<>(this);
        }
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy