ai.djl.repository.zoo.Criteria Maven / Gradle / Ivy
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.repository.zoo;
import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.nn.Block;
import ai.djl.translate.DefaultTranslatorFactory;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import com.google.gson.Gson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* The {@code Criteria} class contains search criteria to look up a {@link ZooModel}.
*
* Criteria follows Builder pattern. See {@link Builder} for detail. In DJL's builder convention,
* the methods start with {@code set} are required fields, and {@code opt} for optional fields.
*
*
Examples
*
*
* Criteria<Image, Classifications> criteria = Criteria.builder()
* .setTypes(Image.class, Classifications.class) // defines input and output data type
* .optTranslator(ImageClassificationTranslator.builder().setSynsetArtifactName("synset.txt").build())
* .optModelUrls("file:///var/models/my_resnet50") // search models in specified path
* .optModelName("resnet50") // specify model file prefix
* .build();
*
*
* See Model loading
* for more detail.
*
* @param the model input type
* @param the model output type
*/
public class Criteria {
private Application application;
private Class inputClass;
private Class outputClass;
private String engine;
private Device device;
private String groupId;
private String artifactId;
private ModelZoo modelZoo;
private Map filters;
private Map arguments;
private Map options;
private TranslatorFactory factory;
private Block block;
private String modelName;
private Progress progress;
private List resolvedLoaders;
Criteria(Builder builder) {
this.application = builder.application;
this.inputClass = builder.inputClass;
this.outputClass = builder.outputClass;
this.engine = builder.engine;
this.device = builder.device;
this.groupId = builder.groupId;
this.artifactId = builder.artifactId;
this.modelZoo = builder.modelZoo;
this.filters = builder.filters;
this.arguments = builder.arguments;
this.options = builder.options;
this.factory = builder.factory;
this.block = builder.block;
this.modelName = builder.modelName;
this.progress = builder.progress;
}
/**
* Returns {@code true} of the model artifacts has been downloaded.
*
* @return {@code true} of the model artifacts has been downloaded
* @throws IOException for various exceptions loading data from the repository
* @throws ModelNotFoundException if no model with the specified criteria is found
*/
public boolean isDownloaded() throws IOException, ModelNotFoundException {
if (resolvedLoaders == null) {
resolvedLoaders = resolveModelLoaders();
}
for (ModelLoader loader : resolvedLoaders) {
if (!loader.isDownloaded(this)) {
return false;
}
}
return true;
}
/**
* Downloads the model artifacts that matches this criteria.
*
* @throws IOException for various exceptions loading data from the repository
* @throws ModelNotFoundException if no model with the specified criteria is found
*/
public void downloadModel() throws ModelNotFoundException, IOException {
if (!isDownloaded()) {
for (ModelLoader loader : resolvedLoaders) {
loader.downloadModel(this, progress);
}
}
}
/**
* Loads the {@link ZooModel} that matches this criteria.
*
* @return the model that matches the criteria
* @throws IOException for various exceptions loading data from the repository
* @throws ModelNotFoundException if no model with the specified criteria is found
* @throws MalformedModelException if the model data is malformed
*/
public ZooModel loadModel()
throws IOException, ModelNotFoundException, MalformedModelException {
if (resolvedLoaders == null) {
resolvedLoaders = resolveModelLoaders();
}
Logger logger = LoggerFactory.getLogger(ModelZoo.class);
Exception lastException = null;
for (ModelLoader loader : resolvedLoaders) {
try {
return loader.loadModel(this);
} catch (ModelNotFoundException e) {
lastException = e;
logger.trace("", e);
logger.debug(
"{} for ModelLoader: {}:{}",
e.getMessage(),
loader.getGroupId(),
loader.getArtifactId());
}
}
throw new ModelNotFoundException(
"No model with the specified URI or the matching Input/Output type is found.",
lastException);
}
/**
* Returns the application of the model.
*
* @return the application of the model
*/
public Application getApplication() {
return application;
}
/**
* Returns the input data type.
*
* @return the input data type
*/
public Class getInputClass() {
return inputClass;
}
/**
* Returns the output data type.
*
* @return the output data type
*/
public Class getOutputClass() {
return outputClass;
}
/**
* Returns the engine name.
*
* @return the engine name
*/
public String getEngine() {
return engine;
}
/**
* Returns the {@link Device} of the model to be loaded on.
*
* @return the {@link Device} of the model to be loaded on
*/
public Device getDevice() {
return device;
}
/**
* Returns the groupId of the {@link ModelZoo} to be searched.
*
* @return the groupId of the {@link ModelZoo} to be searched
*/
public String getGroupId() {
return groupId;
}
/**
* Returns the artifactId of the {@link ModelLoader} to be searched.
*
* @return the artifactIds of the {@link ModelLoader} to be searched
*/
public String getArtifactId() {
return artifactId;
}
/**
* Returns the {@link ModelZoo} to be searched.
*
* @return the {@link ModelZoo} to be searched
*/
public ModelZoo getModelZoo() {
return modelZoo;
}
/**
* Returns the search filters that must match the properties of the model.
*
* @return the search filters that must match the properties of the model.
*/
public Map getFilters() {
return filters;
}
/**
* Returns the override configurations of the model loading arguments.
*
* @return the override configurations of the model loading arguments
*/
public Map getArguments() {
return arguments;
}
/**
* Returns the model loading options.
*
* @return the model loading options
*/
public Map getOptions() {
return options;
}
/**
* Returns the optional {@link TranslatorFactory} to be used for {@link ZooModel}.
*
* @return the optional {@link TranslatorFactory} to be used for {@link ZooModel}
*/
public TranslatorFactory getTranslatorFactory() {
return factory;
}
/**
* Returns the optional {@link Block} to be used for {@link ZooModel}.
*
* @return the optional {@link Block} to be used for {@link ZooModel}
*/
public Block getBlock() {
return block;
}
/**
* Returns the optional model name to be used for {@link ZooModel}.
*
* @return the optional model name to be used for {@link ZooModel}
*/
public String getModelName() {
return modelName;
}
/**
* Returns the optional {@link Progress} for the model loading.
*
* @return the optional {@link Progress} for the model loading
*/
public Progress getProgress() {
return progress;
}
/** {@inheritDoc} */
@Override
public String toString() {
StringBuilder sb = new StringBuilder(128);
sb.append("Criteria:\n");
if (application != null) {
sb.append("\tApplication: ").append(application).append('\n');
}
sb.append("\tInput: ").append(inputClass);
sb.append("\n\tOutput: ").append(outputClass).append('\n');
if (engine != null) {
sb.append("\tEngine: ").append(engine).append('\n');
}
if (modelZoo != null) {
sb.append("\tModelZoo: ").append(modelZoo.getGroupId()).append('\n');
}
if (groupId != null) {
sb.append("\tGroupID: ").append(groupId).append('\n');
}
if (artifactId != null) {
sb.append("\tArtifactId: ").append(artifactId).append('\n');
}
if (filters != null) {
sb.append("\tFilter: ").append(JsonUtils.GSON.toJson(filters)).append('\n');
}
if (arguments != null) {
Gson gson = JsonUtils.builder().excludeFieldsWithoutExposeAnnotation().create();
sb.append("\tArguments: ").append(gson.toJson(arguments)).append('\n');
}
if (options != null) {
sb.append("\tOptions: ").append(JsonUtils.GSON.toJson(options)).append('\n');
}
if (factory == null) {
sb.append("\tNo translator supplied\n");
}
return sb.toString();
}
/**
* Creates a new {@link Builder} which starts with the values of this {@link Criteria}.
*
* @return a new {@link Builder}
*/
public Builder toBuilder() {
return Criteria.builder()
.setTypes(inputClass, outputClass)
.optApplication(application)
.optEngine(engine)
.optDevice(device)
.optGroupId(groupId)
.optArtifactId(artifactId)
.optModelZoo(modelZoo)
.optFilters(filters)
.optArguments(arguments)
.optOptions(options)
.optTranslatorFactory(factory)
.optBlock(block)
.optModelName(modelName)
.optProgress(progress);
}
/**
* Creates a builder to build a {@code Criteria}.
*
* The methods start with {@code set} are required fields, and {@code opt} for optional
* fields.
*
* @return a new builder
*/
public static Builder, ?> builder() {
return new Builder<>();
}
private List resolveModelLoaders() throws ModelNotFoundException {
if (inputClass == null || outputClass == null) {
throw new IllegalArgumentException("inputClass and outputClass are required.");
}
Logger logger = LoggerFactory.getLogger(ModelZoo.class);
logger.debug("Loading model with {}", this);
List list = new ArrayList<>();
if (modelZoo != null) {
logger.debug("Searching model in specified model zoo: {}", modelZoo.getGroupId());
if (groupId != null && !modelZoo.getGroupId().equals(groupId)) {
throw new ModelNotFoundException(
"groupId conflict with ModelZoo criteria."
+ modelZoo.getGroupId()
+ " v.s. "
+ groupId);
}
Set supportedEngine = modelZoo.getSupportedEngines();
if (engine != null && !supportedEngine.contains(engine)) {
throw new ModelNotFoundException(
"ModelZoo doesn't support specified engine: " + engine);
}
list.add(modelZoo);
} else {
for (ModelZoo zoo : ModelZoo.listModelZoo()) {
if (groupId != null && !zoo.getGroupId().equals(groupId)) {
// filter out ModelZoo by groupId
logger.debug("Ignore ModelZoo {} by groupId: {}", zoo.getGroupId(), groupId);
continue;
}
Set supportedEngine = zoo.getSupportedEngines();
if (engine != null && !supportedEngine.contains(engine)) {
logger.debug("Ignore ModelZoo {} by engine: {}", zoo.getGroupId(), engine);
continue;
}
list.add(zoo);
}
}
List loaders = new ArrayList<>();
for (ModelZoo zoo : list) {
String loaderGroupId = zoo.getGroupId();
for (ModelLoader loader : zoo.getModelLoaders()) {
Application app = loader.getApplication();
String loaderArtifactId = loader.getArtifactId();
logger.debug("Checking ModelLoader: {}", loader);
if (artifactId != null && !artifactId.equals(loaderArtifactId)) {
// filter out by model loader artifactId
logger.debug(
"artifactId mismatch for ModelLoader: {}:{}",
loaderGroupId,
loaderArtifactId);
continue;
}
if (application != Application.UNDEFINED
&& app != Application.UNDEFINED
&& !app.matches(application)) {
// filter out ModelLoader by application
logger.debug(
"application mismatch for ModelLoader: {}:{}",
loaderGroupId,
loaderArtifactId);
continue;
}
loaders.add(loader);
}
}
if (loaders.isEmpty()) {
throw new ModelNotFoundException("No model matching the criteria is found.");
}
return loaders;
}
/** A Builder to construct a {@code Criteria}. */
public static final class Builder {
Application application;
Class inputClass;
Class outputClass;
String engine;
Device device;
String groupId;
String artifactId;
ModelZoo modelZoo;
Map filters;
Map arguments;
Map options;
TranslatorFactory factory;
Block block;
String modelName;
Progress progress;
Translator translator;
Builder() {
application = Application.UNDEFINED;
}
@SuppressWarnings("unchecked")
private Builder(Class inputClass, Class outputClass, Builder, ?> parent) {
this.inputClass = inputClass;
this.outputClass = outputClass;
application = parent.application;
engine = parent.engine;
device = parent.device;
groupId = parent.groupId;
artifactId = parent.artifactId;
modelZoo = parent.modelZoo;
filters = parent.filters;
arguments = parent.arguments;
options = parent.options;
factory = parent.factory;
block = parent.block;
modelName = parent.modelName;
progress = parent.progress;
translator = (Translator) parent.translator;
}
/**
* Creates a new @{code Builder} class with the specified input and output data type.
*
* @param the input data type
* @param the output data type
* @param inputClass the input class
* @param outputClass the output class
* @return a new @{code Builder} class with the specified input and output data type
*/
public Builder
setTypes(Class
inputClass, Class outputClass) {
return new Builder<>(inputClass, outputClass, this);
}
/**
* Sets the model application for this criteria.
*
* @param application the model application
* @return this {@code Builder}
*/
public Builder optApplication(Application application) {
this.application = application;
return this;
}
/**
* Sets the engine name for this criteria.
*
* @param engine the engine name
* @return this {@code Builder}
*/
public Builder optEngine(String engine) {
this.engine = engine;
return this;
}
/**
* Sets the {@link Device} for this criteria.
*
* @param device the {@link Device} for the criteria
* @return this {@code Builder}
*/
public Builder optDevice(Device device) {
this.device = device;
return this;
}
/**
* Sets optional groupId of the {@link ModelZoo} for this criteria.
*
* @param groupId the groupId of the {@link ModelZoo}
* @return this {@code Builder}
*/
public Builder optGroupId(String groupId) {
this.groupId = groupId;
return this;
}
/**
* Sets optional artifactId of the {@link ModelLoader} for this criteria.
*
* @param artifactId the artifactId of the {@link ModelLoader}
* @return this {@code Builder}
*/
public Builder optArtifactId(String artifactId) {
if (artifactId != null && artifactId.contains(":")) {
String[] tokens = artifactId.split(":", -1);
groupId = tokens[0].isEmpty() ? null : tokens[0];
this.artifactId = tokens[1].isEmpty() ? null : tokens[1];
} else {
this.artifactId = artifactId;
}
return this;
}
/**
* Sets optional model urls of the {@link ModelLoader} for this criteria.
*
* @param modelUrls the comma delimited url string
* @return this {@code Builder}
*/
public Builder optModelUrls(String modelUrls) {
if (modelUrls != null) {
this.modelZoo = new DefaultModelZoo(modelUrls);
}
return this;
}
/**
* Sets the optional model path of the {@link ModelLoader} for this criteria.
*
* @param modelPath the path to the model folder/files
* @return this {@code Builder}
*/
public Builder optModelPath(Path modelPath) {
if (modelPath != null) {
try {
this.modelZoo = new DefaultModelZoo(modelPath.toUri().toURL().toString());
} catch (MalformedURLException e) {
throw new AssertionError("Invalid model path: " + modelPath, e);
}
}
return this;
}
/**
* Sets optional {@link ModelZoo} of the {@link ModelLoader} for this criteria.
*
* @param modelZoo ModelZoo} of the {@link ModelLoader} for this criteria
* @return this {@code Builder}
*/
public Builder optModelZoo(ModelZoo modelZoo) {
this.modelZoo = modelZoo;
return this;
}
/**
* Sets the extra search filters for this criteria.
*
* @param filters the extra search filters
* @return this {@code Builder}
*/
public Builder optFilters(Map filters) {
this.filters = filters;
return this;
}
/**
* Sets an extra search filter for this criteria.
*
* @param key the search key
* @param value the search value
* @return this {@code Builder}
*/
public Builder optFilter(String key, String value) {
if (filters == null) {
filters = new HashMap<>();
}
filters.put(key, value);
return this;
}
/**
* Sets an optional model {@link Block} for this criteria.
*
* @param block optional model {@link Block} for this criteria
* @return this {@code Builder}
*/
public Builder optBlock(Block block) {
this.block = block;
return this;
}
/**
* Sets an optional model name for this criteria.
*
* @param modelName optional model name for this criteria
* @return this {@code Builder}
*/
public Builder optModelName(String modelName) {
this.modelName = modelName;
return this;
}
/**
* Sets an extra model loading argument for this criteria.
*
* @param arguments optional model loading arguments
* @return this {@code Builder}
*/
public Builder optArguments(Map arguments) {
this.arguments = arguments;
return this;
}
/**
* Sets the optional model loading argument for this criteria.
*
* @param key the model loading argument key
* @param value the model loading argument value
* @return this {@code Builder}
*/
public Builder optArgument(String key, Object value) {
if (arguments == null) {
arguments = new HashMap<>();
}
arguments.put(key, value);
return this;
}
/**
* Sets the model loading options for this criteria.
*
* @param options the model loading options
* @return this {@code Builder}
*/
public Builder optOptions(Map options) {
this.options = options;
return this;
}
/**
* Sets the optional model loading option for this criteria.
*
* @param key the model loading option key
* @param value the model loading option value
* @return this {@code Builder}
*/
public Builder optOption(String key, String value) {
if (options == null) {
options = new HashMap<>();
}
options.put(key, value);
return this;
}
/**
* Sets the optional {@link Translator} to override default {@code Translator}.
*
* @param translator the override {@code Translator}
* @return this {@code Builder}
*/
public Builder optTranslator(Translator translator) {
this.factory = null;
this.translator = translator;
return this;
}
/**
* Sets the optional {@link TranslatorFactory} to override default {@code Translator}.
*
* @param factory the override {@code TranslatorFactory}
* @return this {@code Builder}
*/
public Builder optTranslatorFactory(TranslatorFactory factory) {
this.translator = null;
this.factory = factory;
return this;
}
/**
* Set the optional {@link Progress}.
*
* @param progress the {@code Progress}
* @return this {@code Builder}
*/
public Builder optProgress(Progress progress) {
this.progress = progress;
return this;
}
/**
* Builds a {@link Criteria} instance.
*
* @return the {@link Criteria} instance
*/
public Criteria build() {
if (factory == null && translator != null) {
DefaultTranslatorFactory f = new DefaultTranslatorFactory();
f.registerTranslator(inputClass, outputClass, translator);
factory = f;
}
return new Criteria<>(this);
}
}
}