All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.repository.zoo.Criteria Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.repository.zoo;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.nn.Block;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import java.net.MalformedURLException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;

/**
 * The {@code Criteria} class contains search criteria to look up a {@link ZooModel}.
 *
 * 

Criteria follows Builder pattern. See {@link Builder} for detail. In DJL's builder convention, * the methods start with {@code set} are required fields, and {@code opt} for optional fields. * *

Examples * *

 * Criteria<Image, Classifications> criteria = Criteria.builder()
 *         .setTypes(Image.class, Classifications.class) // defines input and output data type
 *         .optTranslator(ImageClassificationTranslator.builder().setSynsetArtifactName("synset.txt").build())
 *         .optModelUrls("file:///var/models/my_resnet50") // search models in specified path
 *         .optModelName("resnet50") // specify model file prefix
 *         .build();
 * 
* *

See Model loading for * more detail. * * @param the model input type * @param the model output type */ public class Criteria { private Application application; private Class inputClass; private Class outputClass; private String engine; private Device device; private String groupId; private String artifactId; private ModelZoo modelZoo; private Map filters; private Map arguments; private Map options; private TranslatorFactory factory; private Block block; private String modelName; private Progress progress; Criteria(Builder builder) { this.application = builder.application; this.inputClass = builder.inputClass; this.outputClass = builder.outputClass; this.engine = builder.engine; this.device = builder.device; this.groupId = builder.groupId; this.artifactId = builder.artifactId; this.modelZoo = builder.modelZoo; this.filters = builder.filters; this.arguments = builder.arguments; this.options = builder.options; this.factory = builder.factory; this.block = builder.block; this.modelName = builder.modelName; this.progress = builder.progress; } /** * Returns the application of the model. * * @return the application of the model */ public Application getApplication() { return application; } /** * Returns the input data type. * * @return the input data type */ public Class getInputClass() { return inputClass; } /** * Returns the output data type. * * @return the output data type */ public Class getOutputClass() { return outputClass; } /** * Returns the engine name. * * @return the engine name */ public String getEngine() { return engine; } /** * Returns the {@link Device} of the model to be loaded on. * * @return the {@link Device} of the model to be loaded on */ public Device getDevice() { return device; } /** * Returns the groupId of the {@link ModelZoo} to be searched. * * @return the groupId of the {@link ModelZoo} to be searched */ public String getGroupId() { return groupId; } /** * Returns the artifactId of the {@link ModelLoader} to be searched. 
* * @return the artifactIds of the {@link ModelLoader} to be searched */ public String getArtifactId() { return artifactId; } /** * Returns the {@link ModelZoo} to be searched. * * @return the {@link ModelZoo} to be searched */ public ModelZoo getModelZoo() { return modelZoo; } /** * Returns the search filters that must match the properties of the model. * * @return the search filters that must match the properties of the model. */ public Map getFilters() { return filters; } /** * Returns the override configurations of the model loading arguments. * * @return the override configurations of the model loading arguments */ public Map getArguments() { return arguments; } /** * Returns the model loading options. * * @return the model loading options */ public Map getOptions() { return options; } /** * Returns the optional {@link TranslatorFactory} to be used for {@link ZooModel}. * * @return the optional {@link TranslatorFactory} to be used for {@link ZooModel} */ public TranslatorFactory getTranslatorFactory() { return factory; } /** * Returns the optional {@link Block} to be used for {@link ZooModel}. * * @return the optional {@link Block} to be used for {@link ZooModel} */ public Block getBlock() { return block; } /** * Returns the optional model name to be used for {@link ZooModel}. * * @return the optional model name to be used for {@link ZooModel} */ public String getModelName() { return modelName; } /** * Returns the optional {@link Progress} for the model loading. 
* * @return the optional {@link Progress} for the model loading */ public Progress getProgress() { return progress; } /** {@inheritDoc} */ @Override public String toString() { StringBuilder sb = new StringBuilder(128); sb.append("Criteria:\n"); if (application != null) { sb.append("\tApplication: ").append(application).append('\n'); } sb.append("\tInput: ").append(inputClass).append('\n'); sb.append("\tOutput: ").append(outputClass).append('\n'); if (engine != null) { sb.append("\tEngine: ").append(engine).append('\n'); } if (modelZoo != null) { sb.append("\tModelZoo: ").append(modelZoo.getGroupId()).append('\n'); } if (groupId != null) { sb.append("\tGroupID: ").append(groupId).append('\n'); } if (artifactId != null) { sb.append("\tArtifactId: ").append(artifactId).append('\n'); } if (filters != null) { sb.append("\tFilter: ").append(JsonUtils.GSON.toJson(filters)).append('\n'); } if (arguments != null) { sb.append("\tArguments: ").append(JsonUtils.GSON.toJson(arguments)).append('\n'); } if (options != null) { sb.append("\tOptions: ").append(JsonUtils.GSON.toJson(options)).append('\n'); } if (factory == null) { sb.append("\tNo translator supplied\n"); } return sb.toString(); } /** * Creates a builder to build a {@code Criteria}. * *

The methods start with {@code set} are required fields, and {@code opt} for optional * fields. * * @return a new builder */ public static Builder builder() { return new Builder<>(); } /** A Builder to construct a {@code Criteria}. */ public static final class Builder { Application application; Class inputClass; Class outputClass; String engine; Device device; String groupId; String artifactId; ModelZoo modelZoo; Map filters; Map arguments; Map options; TranslatorFactory factory; Block block; String modelName; Progress progress; Builder() { application = Application.UNDEFINED; } private Builder(Class inputClass, Class outputClass, Builder parent) { this.inputClass = inputClass; this.outputClass = outputClass; application = parent.application; engine = parent.engine; device = parent.device; groupId = parent.groupId; filters = parent.filters; arguments = parent.arguments; options = parent.options; block = parent.block; modelName = parent.modelName; progress = parent.progress; } /** * Creates a new @{code Builder} class with the specified input and output data type. * * @param

the input data type * @param the output data type * @param inputClass the input class * @param outputClass the output class * @return a new @{code Builder} class with the specified input and output data type */ public Builder setTypes(Class

inputClass, Class outputClass) { return new Builder<>(inputClass, outputClass, this); } /** * Sets the model application for this criteria. * * @param application the model application * @return this {@code Builder} */ public Builder optApplication(Application application) { this.application = application; return this; } /** * Sets the engine name for this criteria. * * @param engine the engine name * @return this {@code Builder} */ public Builder optEngine(String engine) { this.engine = engine; return this; } /** * Sets the {@link Device} for this criteria. * * @param device the {@link Device} for the criteria * @return this {@code Builder} */ public Builder optDevice(Device device) { this.device = device; return this; } /** * Sets optional groupId of the {@link ModelZoo} for this criteria. * * @param groupId the groupId of the {@link ModelZoo} * @return this {@code Builder} */ public Builder optGroupId(String groupId) { this.groupId = groupId; return this; } /** * Sets optional artifactId of the {@link ModelLoader} for this criteria. * * @param artifactId the artifactId of the {@link ModelLoader} * @return this {@code Builder} */ public Builder optArtifactId(String artifactId) { if (artifactId.contains(":")) { String[] tokens = artifactId.split(":"); groupId = tokens[0]; this.artifactId = tokens[1]; } else { this.artifactId = artifactId; } return this; } /** * Sets optional model urls of the {@link ModelLoader} for this criteria. * * @param modelUrls the comma delimited url string * @return this {@code Builder} */ public Builder optModelUrls(String modelUrls) { this.modelZoo = new DefaultModelZoo(modelUrls); return this; } /** * Sets the optional model path of the {@link ModelLoader} for this criteria. 
* * @param modelPath the path to the model folder/files * @return this {@code Builder} * @throws MalformedURLException wrong path format */ public Builder optModelPath(Path modelPath) throws MalformedURLException { this.modelZoo = new DefaultModelZoo(modelPath.toUri().toURL().toString()); return this; } /** * Sets optional {@link ModelZoo} of the {@link ModelLoader} for this criteria. * * @param modelZoo ModelZoo} of the {@link ModelLoader} for this criteria * @return this {@code Builder} */ public Builder optModelZoo(ModelZoo modelZoo) { this.modelZoo = modelZoo; return this; } /** * Sets the extra search filters for this criteria. * * @param filters the extra search filters * @return this {@code Builder} */ public Builder optFilters(Map filters) { this.filters = filters; return this; } /** * Sets an extra search filter for this criteria. * * @param key the search key * @param value the search value * @return this {@code Builder} */ public Builder optFilter(String key, String value) { if (filters == null) { filters = new HashMap<>(); } filters.put(key, value); return this; } /** * Sets an optional model {@link Block} for this criteria. * * @param block optional model {@link Block} for this criteria * @return this {@code Builder} */ public Builder optBlock(Block block) { this.block = block; return this; } /** * Sets an optional model name for this criteria. * * @param modelName optional model name for this criteria * @return this {@code Builder} */ public Builder optModelName(String modelName) { this.modelName = modelName; return this; } /** * Sets an extra model loading argument for this criteria. * * @param arguments optional model loading arguments * @return this {@code Builder} */ public Builder optArguments(Map arguments) { this.arguments = arguments; return this; } /** * Sets the optional model loading argument for this criteria. 
* * @param key the model loading argument key * @param value the model loading argument value * @return this {@code Builder} */ public Builder optArgument(String key, Object value) { if (arguments == null) { arguments = new HashMap<>(); } arguments.put(key, value); return this; } /** * Sets the model loading options for this criteria. * * @param options the model loading options * @return this {@code Builder} */ public Builder optOptions(Map options) { this.options = options; return this; } /** * Sets the optional model loading option for this criteria. * * @param key the model loading option key * @param value the model loading option value * @return this {@code Builder} */ public Builder optOption(String key, String value) { if (options == null) { options = new HashMap<>(); } options.put(key, value); return this; } /** * Sets the optional {@link Translator} to override default {@code Translator}. * * @param translator the override {@code Translator} * @return this {@code Builder} */ public Builder optTranslator(Translator translator) { this.factory = new TranslatorFactorImpl<>(translator); return this; } /** * Sets the optional {@link TranslatorFactory} to override default {@code Translator}. * * @param factory the override {@code TranslatorFactory} * @return this {@code Builder} */ public Builder optTranslatorFactory(TranslatorFactory factory) { this.factory = factory; return this; } /** * Set the optional {@link Progress}. * * @param progress the {@code Progress} * @return this {@code Builder} */ public Builder optProgress(Progress progress) { this.progress = progress; return this; } /** * Builds a {@link Criteria} instance. 
* * @return the {@link Criteria} instance */ public Criteria build() { return new Criteria<>(this); } } private static final class TranslatorFactorImpl implements TranslatorFactory { private Translator translator; public TranslatorFactorImpl(Translator translator) { this.translator = translator; } /** {@inheritDoc} */ @Override public Translator newInstance(Model model, Map arguments) { return translator; } } }