All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.inference.Predictor Maven / Gradle / Ivy

/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.inference;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.streaming.StreamingBlock;
import ai.djl.inference.streaming.StreamingTranslator;
import ai.djl.inference.streaming.StreamingTranslator.StreamOutput;
import ai.djl.metric.Dimension;
import ai.djl.metric.Metrics;
import ai.djl.metric.Unit;
import ai.djl.ndarray.LazyNDArray;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.training.ParameterStore;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * The {@code Predictor} interface provides a session for model inference.
 *
 * 

You can use a {@code Predictor}, with a specified {@link Translator}, to perform inference on * a {@link Model}. The following is example code that uses {@code Predictor}: * *

 * Model model = Model.load(modelDir, modelName);
 *
 * // User must implement Translator interface, read {@link Translator} for detail.
 * Translator<String, String> translator = new MyTranslator();
 *
 * try (Predictor<String, String> predictor = model.newPredictor(translator)) {
 *   String result = predictor.predict("What's up");
 * }
 * 
* *

See the tutorials on: * *

* * @param the input type * @param the output type * @see Model * @see Translator * @see The guide on * memory management * @see The * guide on running multi-threaded inference * @see The * guide on inference performance optimization */ public class Predictor implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(Predictor.class); protected Translator translator; protected long timestamp; protected boolean prepared; protected Model model; protected NDManager manager; protected Metrics metrics; protected Block block; protected ParameterStore parameterStore; protected Dimension dimension; /** * Creates a new instance of {@code BasePredictor} with the given {@link Model} and {@link * Translator}. * * @param model the model on which the predictions are based * @param translator the translator to be used * @param device the device for prediction * @param copy whether to copy the parameters to the parameter store. If the device changes, it * will copy regardless */ public Predictor(Model model, Translator translator, Device device, boolean copy) { if (!device.equals(model.getNDManager().getDevice())) { // Always copy during device changes copy = true; } this.model = model; this.manager = model.getNDManager().newSubManager(device); this.manager.setName("predictor"); this.translator = translator; block = model.getBlock(); parameterStore = new ParameterStore(manager, copy); dimension = new Dimension("Model", model.getProperty("metric_dimension", "model")); } /** * Predicts an item for inference. * * @param input the input * @return the output object defined by the user * @throws TranslateException if an error occurs during prediction */ public O predict(I input) throws TranslateException { return batchPredict(Collections.singletonList(input)).get(0); } /** * Predicts an item for inference. * * @param ctx the context for the {@code Predictor}. * @param ndList the input {@code NDList} * @return the output {@code NDList} * @throws TranslateException if an error occurs during prediction */ protected NDList predictInternal(TranslatorContext ctx, NDList ndList) throws TranslateException { logger.trace("Predictor input data: {}", ndList); return block.forward(parameterStore, ndList, false); } /** * Predicts a batch for inference. * * @param inputs a list of inputs * @return a list of output objects defined by the user * @throws TranslateException if an error occurs during prediction */ @SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches", "unchecked"}) public List batchPredict(List inputs) throws TranslateException { try (PredictorContext context = new PredictorContext()) { if (!prepared) { translator.prepare(context); prepared = true; } if (translator.getBatchifier() == null) { List ret = new ArrayList<>(inputs.size()); for (I input : inputs) { timestamp = System.nanoTime(); long begin = timestamp; NDList ndList = translator.processInput(context, input); preprocessEnd(ndList, 1); NDList result = predictInternal(context, ndList); predictEnd(result, 1); ret.add(translator.processOutput(context, result)); postProcessEnd(begin, 1); } return ret; } int batchSize = inputs.size(); timestamp = System.nanoTime(); long begin = timestamp; NDList ndList = translator.batchProcessInput(context, inputs); preprocessEnd(ndList, batchSize); NDList result = predictInternal(context, ndList); predictEnd(result, batchSize); List ret = translator.batchProcessOutput(context, result); postProcessEnd(begin, batchSize); return ret; } catch (TranslateException e) { throw e; } catch (Exception e) { throw new TranslateException(e); } } /** * Predicts an item for inference. * * @param input the input * @return the output object defined by the user * @throws TranslateException if an error occurs during prediction */ @SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches"}) public StreamOutput streamingPredict(I input) throws TranslateException { String streamingSupported = streamingSupportError(); if (streamingSupported != null) { throw new IllegalStateException(streamingSupported); } StreamingBlock streamingBlock = (StreamingBlock) block; StreamingTranslator streamingTranslator = (StreamingTranslator) translator; try { PredictorContext context = new PredictorContext(); if (!prepared) { translator.prepare(context); prepared = true; } Batchifier batchifier = translator.getBatchifier(); if (batchifier == null) { NDList ndList = translator.processInput(context, input); return streamingTranslator.processStreamOutput( context, streamingBlock .forwardStream(parameterStore, ndList, false) .onClose(context::close)); } // For the batched case, need to create singleton batch and unbatchify singleton NDList inputBatch = processInputs(context, Collections.singletonList(input)); return streamingTranslator.processStreamOutput( context, streamingBlock .forwardStream(parameterStore, inputBatch, false) .map( result -> { NDList[] unbatched = translator.getBatchifier().unbatchify(result); if (unbatched.length != 1) { throw new IllegalStateException( "Unexpected number of outputs from model"); } return unbatched[0]; }) .onClose(context::close)); } catch (TranslateException e) { throw e; } catch (Exception e) { throw new TranslateException(e); } } /** * Returns true if streaming is supported by the predictor, block, and translator. * * @return true if streaming is supported by the predictor, block, and translator */ public boolean supportsStreaming() { return streamingSupportError() == null; } private String streamingSupportError() { if (!(block instanceof StreamingBlock)) { return "streamingPredict() can only be called with a StreamingBlock"; } if (!(translator instanceof StreamingTranslator)) { return "streamingPredict() can only be called with a StreamingTranslator"; } return null; } /** * Attaches a Metrics param to use for benchmark. * * @param metrics the Metrics class */ public void setMetrics(Metrics metrics) { this.metrics = metrics; } private void waitToRead(NDList list) { for (NDArray array : list) { if (array instanceof LazyNDArray) { ((LazyNDArray) array).waitToRead(); } } } @SuppressWarnings("PMD.SignatureDeclareThrowsException") private NDList processInputs(TranslatorContext ctx, List inputs) throws Exception { int batchSize = inputs.size(); NDList[] preprocessed = new NDList[batchSize]; for (int i = 0; i < batchSize; ++i) { preprocessed[i] = translator.processInput(ctx, inputs.get(i)); } return translator.getBatchifier().batchify(preprocessed); } private void preprocessEnd(NDList list, int batchSize) { if (metrics != null) { waitToRead(list); long tmp = System.nanoTime(); long duration = (tmp - timestamp) / 1000 / batchSize; timestamp = tmp; metrics.addMetric("Preprocess", duration, Unit.MICROSECONDS, dimension); } } private void predictEnd(NDList list, int batchSize) { if (metrics != null) { waitToRead(list); long tmp = System.nanoTime(); long duration = (tmp - timestamp) / 1000 / batchSize; timestamp = tmp; metrics.addMetric("Inference", duration, Unit.MICROSECONDS, dimension); } } private void postProcessEnd(long begin, int batchSize) { if (metrics != null) { long tmp = System.nanoTime(); long duration = (tmp - timestamp) / 1000 / batchSize; timestamp = tmp; metrics.addMetric("Postprocess", duration, Unit.MICROSECONDS, dimension); long prediction = (tmp - begin) / 1000; metrics.addMetric("Prediction", prediction, Unit.MICROSECONDS, dimension); } } /** {@inheritDoc} */ @Override public void close() { manager.close(); } /** {@inheritDoc} */ @SuppressWarnings("deprecation") @Override protected void finalize() throws Throwable { if (manager.isOpen()) { if (logger.isDebugEnabled()) { logger.warn("Predictor for {} was not closed explicitly.", model.getName()); } close(); } super.finalize(); } protected class PredictorContext implements TranslatorContext { private NDManager ctxManager; private Map attachments; /** Constructs a new {@code PredictorContext} instance. */ public PredictorContext() { ctxManager = manager.newSubManager(); ctxManager.setName("predictor ctx"); attachments = new ConcurrentHashMap<>(); } /** {@inheritDoc} */ @Override public Model getModel() { return model; } /** {@inheritDoc} */ @Override public NDManager getNDManager() { return ctxManager; } /** {@inheritDoc} */ @Override public NDManager getPredictorManager() { return manager; } /** {@inheritDoc} */ @Override public Block getBlock() { return block; } /** {@inheritDoc} */ @Override public Metrics getMetrics() { return metrics; } /** {@inheritDoc} */ @Override public void close() { ctxManager.close(); } /** {@inheritDoc} */ @Override public Object getAttachment(String key) { return attachments.get(key); } /** {@inheritDoc} */ @Override public void setAttachment(String key, Object value) { attachments.put(key, value); } } }