All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.modality.cv.zoo.SimplePoseModelLoader Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.cv.zoo;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.translator.SimplePoseTranslator;
import ai.djl.modality.cv.translator.wrapper.FileTranslatorFactory;
import ai.djl.modality.cv.translator.wrapper.InputStreamTranslatorFactory;
import ai.djl.modality.cv.translator.wrapper.UrlTranslatorFactory;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
import java.util.Map;

/**
 * The translator for Simple Pose models.
 *
 * 

The model was trained on Gluon and loaded in DJL in MXNet Symbol Block. See Simple Pose. */ public class SimplePoseModelLoader extends BaseModelLoader { private static final Application APPLICATION = Application.CV.POSE_ESTIMATION; /** * Creates the Model loader from the given repository. * * @param repository the repository to load the model from * @param groupId the group id of the model * @param artifactId the artifact id of the model * @param version the version number of the model * @param modelZoo the modelZoo type that is being used to get supported engine types */ public SimplePoseModelLoader( Repository repository, String groupId, String artifactId, String version, ModelZoo modelZoo) { super(repository, MRL.model(APPLICATION, groupId, artifactId), version, modelZoo); FactoryImpl factory = new FactoryImpl(); factories.put(new Pair<>(Image.class, Joints.class), factory); factories.put(new Pair<>(Path.class, Joints.class), new FileTranslatorFactory<>(factory)); factories.put(new Pair<>(URL.class, Joints.class), new UrlTranslatorFactory<>(factory)); factories.put( new Pair<>(InputStream.class, Joints.class), new InputStreamTranslatorFactory<>(factory)); } /** * Loads the model. * * @return the loaded model * @throws IOException for various exceptions loading data from the repository * @throws ModelNotFoundException if no model with the specified criteria is found * @throws MalformedModelException if the model data is malformed */ public ZooModel loadModel() throws MalformedModelException, ModelNotFoundException, IOException { return loadModel(null, null, null); } /** * Loads the model. 
* * @param progress the progress tracker to update while loading the model * @return the loaded model * @throws IOException for various exceptions loading data from the repository * @throws ModelNotFoundException if no model with the specified criteria is found * @throws MalformedModelException if the model data is malformed */ public ZooModel loadModel(Progress progress) throws MalformedModelException, ModelNotFoundException, IOException { return loadModel(null, null, progress); } /** * Loads the model with the given search filters. * * @param filters the search filters to match against the loaded model * @param device the device the loaded model should use * @param progress the progress tracker to update while loading the model * @return the loaded model * @throws IOException for various exceptions loading data from the repository * @throws ModelNotFoundException if no model with the specified criteria is found * @throws MalformedModelException if the model data is malformed */ public ZooModel loadModel( Map filters, Device device, Progress progress) throws IOException, ModelNotFoundException, MalformedModelException { Criteria criteria = Criteria.builder() .setTypes(Image.class, Joints.class) .optModelZoo(modelZoo) .optGroupId(resource.getMrl().getGroupId()) .optArtifactId(resource.getMrl().getArtifactId()) .optFilters(filters) .optDevice(device) .optProgress(progress) .build(); return loadModel(criteria); } private static final class FactoryImpl implements TranslatorFactory { /** {@inheritDoc} */ @Override public Translator newInstance(Model model, Map arguments) { // default size: 192 return SimplePoseTranslator.builder(arguments).build(); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy