// All downloads are free. Search and download functionality uses the official Maven repository.
//
// ai.djl.inference.Predictor Maven / Gradle / Ivy

/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.inference;

import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import java.util.List;

// TODO(review): the list of tutorial links and the multi-threading "here" link in the Javadoc
// below were lost when this file was extracted from an HTML page; restore them from the upstream
// DJL source.
/**
 * The {@code Predictor} interface provides a session for model inference.
 *
 * <p>You can use a {@code Predictor}, with a specified {@link Translator}, to perform inference on
 * a {@link Model}. The following is example code that uses {@code Predictor}:
 *
 * <pre>{@code
 * Model model = Model.load(modelDir, modelName);
 *
 * // User must implement the Translator interface; see Translator for details.
 * Translator<String, String> translator = new MyTranslator();
 *
 * try (Predictor<String, String> predictor = model.newPredictor(translator)) {
 *     String result = predictor.predict("What's up");
 * }
 * }</pre>
 *
 * <p>See the DJL tutorials for further examples of running inference.
 *
 * <p>For information about running multi-threaded inference, see the DJL documentation.
 *
 * @param <I> the input type
 * @param <O> the output type
 * @see Model
 * @see Translator
 */
public interface Predictor<I, O> extends AutoCloseable {

    /**
     * Predicts an item for inference.
     *
     * @param input the input
     * @return the output object defined by the user
     * @throws TranslateException if an error occurs during prediction
     */
    O predict(I input) throws TranslateException;

    /**
     * Predicts a batch for inference.
     *
     * @param inputs a list of inputs
     * @return a list of output objects defined by the user
     * @throws TranslateException if an error occurs during prediction
     */
    List<O> batchPredict(List<I> inputs) throws TranslateException;

    /**
     * Attaches a {@link Metrics} instance to use for benchmarking.
     *
     * @param metrics the {@code Metrics} instance to attach
     */
    void setMetrics(Metrics metrics);

    /**
     * {@inheritDoc}
     *
     * <p>Unlike {@link AutoCloseable#close()}, this override declares no checked exception, so it
     * can be called (including implicitly by try-with-resources) without a {@code catch} block.
     */
    @Override
    void close();
}