All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.translate.ExpansionTranslatorFactory Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.translate;

import ai.djl.Model;
import ai.djl.util.Pair;

import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

/**
 * A {@link TranslatorFactory} based on a {@link Translator} and it's {@link TranslatorOptions}.
 *
 * @param  the input type for the base translator
 * @param  the output type for the base translator
 */
@SuppressWarnings({"PMD.GenericsNaming", "InterfaceTypeParameterName"})
public abstract class ExpansionTranslatorFactory implements TranslatorFactory {

    /** {@inheritDoc} */
    @Override
    public Set> getSupportedTypes() {
        Set> results = new HashSet<>();
        results.addAll(getExpansions().keySet());

        Set preProcessorTypes = new HashSet<>();
        preProcessorTypes.addAll(getPreprocessorExpansions().keySet());
        preProcessorTypes.add(getBaseInputType());

        Set postProcessorTypes = new HashSet<>();
        postProcessorTypes.addAll(getPostprocessorExpansions().keySet());
        postProcessorTypes.add(getBaseOutputType());

        for (Type i : preProcessorTypes) {
            for (Type o : postProcessorTypes) {
                results.add(new Pair<>(i, o));
            }
        }
        return results;
    }

    /** {@inheritDoc} */
    @Override
    public  Translator newInstance(
            Class input, Class output, Model model, Map arguments) {
        Translator baseTranslator = buildBaseTranslator(model, arguments);
        return newInstance(input, output, baseTranslator);
    }

    /**
     * Returns a new instance of the {@link Translator} class.
     *
     * @param  the input data type
     * @param  the output data type
     * @param input the input class
     * @param output the output class
     * @param translator the base translator to expand from
     * @return a new instance of the {@code Translator} class
     */
    @SuppressWarnings("unchecked")
     Translator newInstance(
            Class input, Class output, Translator translator) {

        if (input.equals(getBaseInputType()) && output.equals(getBaseOutputType())) {
            return (Translator) translator;
        }

        TranslatorExpansion expansion =
                getExpansions().get(new Pair<>(input, output));
        if (expansion != null) {
            return (Translator) expansion.apply(translator);
        }

        // Note that regular expansions take precedence over pre-processor+post-processor expansions
        PreProcessor preProcessor = null;
        if (input.equals(getBaseInputType())) {
            preProcessor = (PreProcessor) translator;
        } else {
            Function, PreProcessor> expander =
                    getPreprocessorExpansions().get(input);
            if (expander != null) {
                preProcessor = (PreProcessor) expander.apply(translator);
            }
        }

        PostProcessor postProcessor = null;
        if (output.equals(getBaseOutputType())) {
            postProcessor = (PostProcessor) translator;
        } else {
            Function, PostProcessor> expander =
                    getPostprocessorExpansions().get(output);
            if (expander != null) {
                postProcessor = (PostProcessor) expander.apply(translator);
            }
        }

        if (preProcessor != null && postProcessor != null) {
            return new BasicTranslator<>(preProcessor, postProcessor, translator.getBatchifier());
        }

        throw new IllegalArgumentException("Unsupported expansion input/output types.");
    }

    /**
     * Creates a set of {@link TranslatorOptions} based on the expansions of a given translator.
     *
     * @param translator the translator to expand
     * @return the {@link TranslatorOptions}
     */
    public ExpandedTranslatorOptions withTranslator(Translator translator) {
        return new ExpandedTranslatorOptions(translator);
    }

    /**
     * Builds the base translator that can be expanded.
     *
     * @param model the {@link Model} that uses the {@link Translator}
     * @param arguments the configurations for a new {@code Translator} instance
     * @return a base translator that can be expanded to form the factory options
     */
    protected abstract Translator buildBaseTranslator(
            Model model, Map arguments);

    /**
     * Returns the input type for the base translator.
     *
     * @return the input type for the base translator
     */
    public abstract Class getBaseInputType();

    /**
     * Returns the output type for the base translator.
     *
     * @return the output type for the base translator
     */
    public abstract Class getBaseOutputType();

    /**
     * Returns the possible expansions of this factory.
     *
     * @return the possible expansions of this factory
     */
    protected Map, TranslatorExpansion> getExpansions() {
        return Collections.emptyMap();
    }

    /**
     * Returns the possible expansions of this factory.
     *
     * @return the possible expansions of this factory
     */
    protected Map, PreProcessor>>
            getPreprocessorExpansions() {
        return Collections.singletonMap(getBaseInputType(), p -> p);
    }

    /**
     * Returns the possible expansions of this factory.
     *
     * @return the possible expansions of this factory
     */
    protected Map, PostProcessor>>
            getPostprocessorExpansions() {
        return Collections.singletonMap(getBaseOutputType(), p -> p);
    }

    /** Represents {@link TranslatorOptions} by applying expansions to a base {@link Translator}. */
    final class ExpandedTranslatorOptions implements TranslatorOptions {

        private Translator translator;

        private ExpandedTranslatorOptions(Translator translator) {
            this.translator = translator;
        }

        /** {@inheritDoc} */
        @Override
        public Set> getOptions() {
            return getSupportedTypes();
        }

        /** {@inheritDoc} */
        @Override
        public  Translator option(Class input, Class output) {
            return newInstance(input, output, translator);
        }
    }

    /**
     * A function from a base translator to an expanded translator.
     *
     * @param  the base translator input type
     * @param  the base translator output type
     */
    @FunctionalInterface
    @SuppressWarnings({"PMD.GenericsNaming", "InterfaceTypeParameterName"})
    public interface TranslatorExpansion
            extends Function, Translator> {}
}