All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.inference.streaming.StreamingTranslator Maven / Gradle / Ivy

/*
 * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.inference.streaming;

import ai.djl.ndarray.NDList;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

import java.util.stream.Stream;

/**
 * An expansion of the {@link Translator} with postProcessing for the {@link StreamingBlock} (used
 * by {@link ai.djl.inference.Predictor#streamingPredict(Object)}.
 *
 * @param  the input type
 * @param  the output type
 */
public interface StreamingTranslator extends Translator {

    /**
     * Processes the output NDList to the corresponding output object.
     *
     * @param ctx the toolkit used for post-processing
     * @param list the output NDList after inference, usually immutable in engines like
     *     PyTorch. @see Issue 1774
     * @return the output object of expected type
     * @throws Exception if an error occurs during processing output
     */
    @SuppressWarnings("PMD.SignatureDeclareThrowsException")
    StreamOutput processStreamOutput(TranslatorContext ctx, Stream list)
            throws Exception;

    /**
     * Returns what kind of {@link StreamOutput} this {@link StreamingTranslator} supports.
     *
     * @return what kind of {@link StreamOutput} this {@link StreamingTranslator} supports
     */
    Support getSupport();

    /**
     * Returns whether the {@link StreamingTranslator} supports iterative output.
     *
     * @return whether the {@link StreamingTranslator} supports iterative output
     * @see StreamOutput#getIterativeOutput()
     */
    default boolean supportsIterative() {
        return getSupport().iterative();
    }

    /**
     * Returns whether the {@link StreamingTranslator} supports iterative output.
     *
     * @return whether the {@link StreamingTranslator} supports iterative output
     * @see StreamOutput#getAsyncOutput()
     */
    default boolean supportsAsync() {
        return getSupport().async();
    }

    /**
     * A {@link StreamOutput} represents a streamable output type (either iterative or
     * asynchronous).
     *
     * 

There are two modes for the {@link StreamOutput}. When using it, you must choose one of * the modes and can only access it once. Any other usage including trying both modes or trying * one mode twice will result in an {@link IllegalStateException}. * *

The first mode is the iterative mode which can be used by calling {@link * #getIterativeOutput()}, it returns a result that has an internal iterate method. When calling * the iterating method, it will compute an additional part of the output. * *

The second mode is asynchronous mode. Here, you can produce a mutable output object by * calling {@link #getAsyncOutput()}. Then, calling {@link #computeAsyncOutput()} will * synchronously compute the results and deposit them into the prepared output. This method * works best with manual threading where the worker can return the template result to another * thread and then continue to compute it. * * @param the output type */ abstract class StreamOutput { private O output; private boolean computed; /** * Returns a template object to be used with the async output. * *

This should only be an empty data structure until {@link #computeAsyncOutput()} is * called. * * @return a template object to be used with the async output */ public final O getAsyncOutput() { if (output != null) { throw new IllegalStateException("The StreamOutput can only be gotten once"); } if (computed) { throw new IllegalStateException( "Attempted to getAsyncOutput, but has already called getIterativeOutput." + " Only one kind of output can be used."); } output = buildAsyncOutput(); return output; } /** * Performs the internal implementation of {@link #getAsyncOutput()}. * * @return the output for {@link #getAsyncOutput()}. */ protected abstract O buildAsyncOutput(); /** * Computes the actual value and stores it in the object returned earlier by {@link * #getAsyncOutput()}. */ public final void computeAsyncOutput() { if (output == null) { throw new IllegalStateException( "Before calling computeAsyncOutput, you must first getAsyncOutput"); } if (computed) { throw new IllegalStateException("Attempted to computeAsyncOutput multiple times."); } computed = true; computeAsyncOutputInternal(output); } /** * Performs the internal implementation of {@link #computeAsyncOutput()}. * * @param output the output object returned by the earlier call to {@link * #getAsyncOutput()}. */ protected abstract void computeAsyncOutputInternal(O output); /** * Returns an iterative streamable output. * * @return an iterative streamable output */ public final O getIterativeOutput() { if (output != null) { throw new IllegalStateException( "Can't call getIterativeOutput after already using getAsyncOutput."); } if (computed) { throw new IllegalStateException( "Attempted to getIterativeOutput multiple times. getIterativeOutput can" + " only be called once"); } return getIterativeOutputInternal(); } /** * Performs the internal implementation of {@link #getIterativeOutput()}. * * @return the output for {@link #getIterativeOutput()} */ public abstract O getIterativeOutputInternal(); } /** What types of {@link StreamOutput}s are supported by a {@link StreamingTranslator}. */ enum Support { /** Supports {@link #iterative()} but not {@link #async()}. */ ITERATIVE(true, false), /** Supports {@link #async()} but not {@link #iterative()}. */ ASYNC(false, true), /** Supports both {@link #iterative()} and {@link #async()}. */ BOTH(true, true); private boolean iterative; private boolean async; Support(boolean iterative, boolean async) { this.iterative = iterative; this.async = async; } /** * Returns whether the {@link StreamingTranslator} supports iterative output. * * @return whether the {@link StreamingTranslator} supports iterative output * @see StreamOutput#getIterativeOutput() */ public boolean iterative() { return iterative; } /** * Returns whether the {@link StreamingTranslator} supports iterative output. * * @return whether the {@link StreamingTranslator} supports iterative output * @see StreamOutput#getAsyncOutput() */ public boolean async() { return async; } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy