ai.djl.repository.zoo.Criteria Maven / Gradle / Ivy
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.repository.zoo;
import ai.djl.Application;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.nn.Block;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import java.net.MalformedURLException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
/**
* The {@code Criteria} class contains search criteria to look up a {@link ZooModel}.
*
* Criteria follows Builder pattern. See {@link Builder} for detail. In DJL's builder convention,
* the methods start with {@code set} are required fields, and {@code opt} for optional fields.
*
*
Examples
*
*
* Criteria<Image, Classifications> criteria = Criteria.builder()
* .setTypes(Image.class, Classifications.class) // defines input and output data type
* .optTranslator(ImageClassificationTranslator.builder().setSynsetArtifactName("synset.txt").build())
* .optModelUrls("file:///var/models/my_resnet50") // search models in specified path
* .optModelName("resnet50") // specify model file prefix
* .build();
*
*
* See Model loading for
* more detail.
*
* @param the model input type
* @param the model output type
*/
public class Criteria {
private Application application;
private Class inputClass;
private Class outputClass;
private String engine;
private Device device;
private String groupId;
private String artifactId;
private ModelZoo modelZoo;
private Map filters;
private Map arguments;
private Map options;
private TranslatorFactory factory;
private Block block;
private String modelName;
private Progress progress;
Criteria(Builder builder) {
this.application = builder.application;
this.inputClass = builder.inputClass;
this.outputClass = builder.outputClass;
this.engine = builder.engine;
this.device = builder.device;
this.groupId = builder.groupId;
this.artifactId = builder.artifactId;
this.modelZoo = builder.modelZoo;
this.filters = builder.filters;
this.arguments = builder.arguments;
this.options = builder.options;
this.factory = builder.factory;
this.block = builder.block;
this.modelName = builder.modelName;
this.progress = builder.progress;
}
/**
* Returns the application of the model.
*
* @return the application of the model
*/
public Application getApplication() {
return application;
}
/**
* Returns the input data type.
*
* @return the input data type
*/
public Class getInputClass() {
return inputClass;
}
/**
* Returns the output data type.
*
* @return the output data type
*/
public Class getOutputClass() {
return outputClass;
}
/**
* Returns the engine name.
*
* @return the engine name
*/
public String getEngine() {
return engine;
}
/**
* Returns the {@link Device} of the model to be loaded on.
*
* @return the {@link Device} of the model to be loaded on
*/
public Device getDevice() {
return device;
}
/**
* Returns the groupId of the {@link ModelZoo} to be searched.
*
* @return the groupId of the {@link ModelZoo} to be searched
*/
public String getGroupId() {
return groupId;
}
/**
* Returns the artifactId of the {@link ModelLoader} to be searched.
*
* @return the artifactIds of the {@link ModelLoader} to be searched
*/
public String getArtifactId() {
return artifactId;
}
/**
* Returns the {@link ModelZoo} to be searched.
*
* @return the {@link ModelZoo} to be searched
*/
public ModelZoo getModelZoo() {
return modelZoo;
}
/**
* Returns the search filters that must match the properties of the model.
*
* @return the search filters that must match the properties of the model.
*/
public Map getFilters() {
return filters;
}
/**
* Returns the override configurations of the model loading arguments.
*
* @return the override configurations of the model loading arguments
*/
public Map getArguments() {
return arguments;
}
/**
* Returns the model loading options.
*
* @return the model loading options
*/
public Map getOptions() {
return options;
}
/**
* Returns the optional {@link TranslatorFactory} to be used for {@link ZooModel}.
*
* @return the optional {@link TranslatorFactory} to be used for {@link ZooModel}
*/
public TranslatorFactory getTranslatorFactory() {
return factory;
}
/**
* Returns the optional {@link Block} to be used for {@link ZooModel}.
*
* @return the optional {@link Block} to be used for {@link ZooModel}
*/
public Block getBlock() {
return block;
}
/**
* Returns the optional model name to be used for {@link ZooModel}.
*
* @return the optional model name to be used for {@link ZooModel}
*/
public String getModelName() {
return modelName;
}
/**
* Returns the optional {@link Progress} for the model loading.
*
* @return the optional {@link Progress} for the model loading
*/
public Progress getProgress() {
return progress;
}
/** {@inheritDoc} */
@Override
public String toString() {
StringBuilder sb = new StringBuilder(128);
sb.append("Criteria:\n");
if (application != null) {
sb.append("\tApplication: ").append(application).append('\n');
}
sb.append("\tInput: ").append(inputClass).append('\n');
sb.append("\tOutput: ").append(outputClass).append('\n');
if (engine != null) {
sb.append("\tEngine: ").append(engine).append('\n');
}
if (modelZoo != null) {
sb.append("\tModelZoo: ").append(modelZoo.getGroupId()).append('\n');
}
if (groupId != null) {
sb.append("\tGroupID: ").append(groupId).append('\n');
}
if (artifactId != null) {
sb.append("\tArtifactId: ").append(artifactId).append('\n');
}
if (filters != null) {
sb.append("\tFilter: ").append(JsonUtils.GSON.toJson(filters)).append('\n');
}
if (arguments != null) {
sb.append("\tArguments: ").append(JsonUtils.GSON.toJson(arguments)).append('\n');
}
if (options != null) {
sb.append("\tOptions: ").append(JsonUtils.GSON.toJson(options)).append('\n');
}
if (factory == null) {
sb.append("\tNo translator supplied\n");
}
return sb.toString();
}
/**
* Creates a builder to build a {@code Criteria}.
*
* The methods start with {@code set} are required fields, and {@code opt} for optional
* fields.
*
* @return a new builder
*/
public static Builder, ?> builder() {
return new Builder<>();
}
/** A Builder to construct a {@code Criteria}. */
public static final class Builder {
Application application;
Class inputClass;
Class outputClass;
String engine;
Device device;
String groupId;
String artifactId;
ModelZoo modelZoo;
Map filters;
Map arguments;
Map options;
TranslatorFactory factory;
Block block;
String modelName;
Progress progress;
Builder() {
application = Application.UNDEFINED;
}
private Builder(Class inputClass, Class outputClass, Builder, ?> parent) {
this.inputClass = inputClass;
this.outputClass = outputClass;
application = parent.application;
engine = parent.engine;
device = parent.device;
groupId = parent.groupId;
filters = parent.filters;
arguments = parent.arguments;
options = parent.options;
block = parent.block;
modelName = parent.modelName;
progress = parent.progress;
}
/**
* Creates a new @{code Builder} class with the specified input and output data type.
*
* @param the input data type
* @param the output data type
* @param inputClass the input class
* @param outputClass the output class
* @return a new @{code Builder} class with the specified input and output data type
*/
public Builder
setTypes(Class
inputClass, Class outputClass) {
return new Builder<>(inputClass, outputClass, this);
}
/**
* Sets the model application for this criteria.
*
* @param application the model application
* @return this {@code Builder}
*/
public Builder optApplication(Application application) {
this.application = application;
return this;
}
/**
* Sets the engine name for this criteria.
*
* @param engine the engine name
* @return this {@code Builder}
*/
public Builder optEngine(String engine) {
this.engine = engine;
return this;
}
/**
* Sets the {@link Device} for this criteria.
*
* @param device the {@link Device} for the criteria
* @return this {@code Builder}
*/
public Builder optDevice(Device device) {
this.device = device;
return this;
}
/**
* Sets optional groupId of the {@link ModelZoo} for this criteria.
*
* @param groupId the groupId of the {@link ModelZoo}
* @return this {@code Builder}
*/
public Builder optGroupId(String groupId) {
this.groupId = groupId;
return this;
}
/**
* Sets optional artifactId of the {@link ModelLoader} for this criteria.
*
* @param artifactId the artifactId of the {@link ModelLoader}
* @return this {@code Builder}
*/
public Builder optArtifactId(String artifactId) {
if (artifactId.contains(":")) {
String[] tokens = artifactId.split(":");
groupId = tokens[0];
this.artifactId = tokens[1];
} else {
this.artifactId = artifactId;
}
return this;
}
/**
* Sets optional model urls of the {@link ModelLoader} for this criteria.
*
* @param modelUrls the comma delimited url string
* @return this {@code Builder}
*/
public Builder optModelUrls(String modelUrls) {
this.modelZoo = new DefaultModelZoo(modelUrls);
return this;
}
/**
* Sets the optional model path of the {@link ModelLoader} for this criteria.
*
* @param modelPath the path to the model folder/files
* @return this {@code Builder}
* @throws MalformedURLException wrong path format
*/
public Builder optModelPath(Path modelPath) throws MalformedURLException {
this.modelZoo = new DefaultModelZoo(modelPath.toUri().toURL().toString());
return this;
}
/**
* Sets optional {@link ModelZoo} of the {@link ModelLoader} for this criteria.
*
* @param modelZoo ModelZoo} of the {@link ModelLoader} for this criteria
* @return this {@code Builder}
*/
public Builder optModelZoo(ModelZoo modelZoo) {
this.modelZoo = modelZoo;
return this;
}
/**
* Sets the extra search filters for this criteria.
*
* @param filters the extra search filters
* @return this {@code Builder}
*/
public Builder optFilters(Map filters) {
this.filters = filters;
return this;
}
/**
* Sets an extra search filter for this criteria.
*
* @param key the search key
* @param value the search value
* @return this {@code Builder}
*/
public Builder optFilter(String key, String value) {
if (filters == null) {
filters = new HashMap<>();
}
filters.put(key, value);
return this;
}
/**
* Sets an optional model {@link Block} for this criteria.
*
* @param block optional model {@link Block} for this criteria
* @return this {@code Builder}
*/
public Builder optBlock(Block block) {
this.block = block;
return this;
}
/**
* Sets an optional model name for this criteria.
*
* @param modelName optional model name for this criteria
* @return this {@code Builder}
*/
public Builder optModelName(String modelName) {
this.modelName = modelName;
return this;
}
/**
* Sets an extra model loading argument for this criteria.
*
* @param arguments optional model loading arguments
* @return this {@code Builder}
*/
public Builder optArguments(Map arguments) {
this.arguments = arguments;
return this;
}
/**
* Sets the optional model loading argument for this criteria.
*
* @param key the model loading argument key
* @param value the model loading argument value
* @return this {@code Builder}
*/
public Builder optArgument(String key, Object value) {
if (arguments == null) {
arguments = new HashMap<>();
}
arguments.put(key, value);
return this;
}
/**
* Sets the model loading options for this criteria.
*
* @param options the model loading options
* @return this {@code Builder}
*/
public Builder optOptions(Map options) {
this.options = options;
return this;
}
/**
* Sets the optional model loading option for this criteria.
*
* @param key the model loading option key
* @param value the model loading option value
* @return this {@code Builder}
*/
public Builder optOption(String key, String value) {
if (options == null) {
options = new HashMap<>();
}
options.put(key, value);
return this;
}
/**
* Sets the optional {@link Translator} to override default {@code Translator}.
*
* @param translator the override {@code Translator}
* @return this {@code Builder}
*/
public Builder optTranslator(Translator translator) {
this.factory = new TranslatorFactorImpl<>(translator);
return this;
}
/**
* Sets the optional {@link TranslatorFactory} to override default {@code Translator}.
*
* @param factory the override {@code TranslatorFactory}
* @return this {@code Builder}
*/
public Builder optTranslatorFactory(TranslatorFactory factory) {
this.factory = factory;
return this;
}
/**
* Set the optional {@link Progress}.
*
* @param progress the {@code Progress}
* @return this {@code Builder}
*/
public Builder optProgress(Progress progress) {
this.progress = progress;
return this;
}
/**
* Builds a {@link Criteria} instance.
*
* @return the {@link Criteria} instance
*/
public Criteria build() {
return new Criteria<>(this);
}
}
private static final class TranslatorFactorImpl implements TranslatorFactory {
private Translator translator;
public TranslatorFactorImpl(Translator translator) {
this.translator = translator;
}
/** {@inheritDoc} */
@Override
public Translator newInstance(Model model, Map arguments) {
return translator;
}
}
}