// ai.djl.repository.zoo.ModelZoo (Maven / Gradle / Ivy)
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.repository.zoo;
import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.repository.Artifact;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.TreeMap;
/** An interface represents a collection of models. */
public interface ModelZoo {
/**
* Returns the global unique identifier of the {@code ModelZoo}.
*
* We recommend to use reverse DNS name as your model zoo group ID to make sure it's not
* conflict with other ModelZoos.
*
* @return the global unique identifier of the {@code ModelZoo}
*/
String getGroupId();
/**
* Lists the available model families in the ModelZoo.
*
* @return the list of all available model families
*/
default List> getModelLoaders() {
List> list = new ArrayList<>();
try {
Field[] fields = getClass().getDeclaredFields();
for (Field field : fields) {
if (ModelLoader.class.isAssignableFrom(field.getType())) {
list.add((ModelLoader, ?>) field.get(null));
}
}
} catch (ReflectiveOperationException e) {
// ignore
}
return list;
}
/**
* Returns the {@link ModelLoader} based on the model name.
*
* @param name the name of the model
* @return the {@link ModelLoader} of the model
*/
default ModelLoader, ?> getModelLoader(String name) {
for (ModelLoader, ?> loader : getModelLoaders()) {
if (name.equals(loader.getArtifactId())) {
return loader;
}
}
return null;
}
/**
* Returns all supported engine names.
*
* @return all supported engine names
*/
Set getSupportedEngines();
/**
* Gets the {@link ModelLoader} based on the model name.
*
* @param criteria the name of the model
* @param the input data type for preprocessing
* @param the output data type after postprocessing
* @return the model that matches the criteria
* @throws IOException for various exceptions loading data from the repository
* @throws ModelNotFoundException if no model with the specified criteria is found
* @throws MalformedModelException if the model data is malformed
*/
static ZooModel loadModel(Criteria criteria)
throws IOException, ModelNotFoundException, MalformedModelException {
String groupId = criteria.getGroupId();
ServiceLoader providers = ServiceLoader.load(ZooProvider.class);
for (ZooProvider provider : providers) {
ModelZoo zoo = provider.getModelZoo();
if (zoo == null) {
continue;
}
if (groupId != null && !zoo.getGroupId().equals(groupId)) {
// filter out ModelZoo by groupId
continue;
}
Set supportedEngine = zoo.getSupportedEngines();
if (!supportedEngine.contains(criteria.getEngine())) {
continue;
}
Application application = criteria.getApplication();
String artifactId = criteria.getArtifactId();
for (ModelLoader, ?> loader : zoo.getModelLoaders()) {
if (artifactId != null && !artifactId.equals(loader.getArtifactId())) {
// filter out by model loader artifactId
continue;
}
Application app = loader.getApplication();
if (application != null
&& app != Application.UNDEFINED
&& !app.equals(application)) {
// filter out ModelLoader by application
continue;
}
try {
return loader.loadModel(criteria);
} catch (ModelNotFoundException e) {
// ignore
}
}
}
throw new ModelNotFoundException(
"No matching model with specified Input/Output type found.");
}
/**
* Returns the available {@link Application} and their model artifact metadata.
*
* @return the available {@link Application} and their model artifact metadata
* @throws IOException if failed to download to repository metadata
* @throws ModelNotFoundException if failed to parse repository metadata
*/
static Map> listModels()
throws IOException, ModelNotFoundException {
@SuppressWarnings("PMD.UseConcurrentHashMap")
Map> models =
new TreeMap<>(Comparator.comparing(Application::getPath));
ServiceLoader providers = ServiceLoader.load(ZooProvider.class);
for (ZooProvider provider : providers) {
ModelZoo zoo = provider.getModelZoo();
if (zoo == null) {
continue;
}
List> list = zoo.getModelLoaders();
for (ModelLoader, ?> loader : list) {
Application app = loader.getApplication();
final List artifacts = loader.listModels();
models.compute(
app,
(key, val) -> {
if (val == null) {
val = new ArrayList<>();
}
val.addAll(artifacts);
return val;
});
}
}
return models;
}
}