// Source listing for ai.djl.translate.ExpansionTranslatorFactory (newest published version; available via Maven / Gradle / Ivy).
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.translate;
import ai.djl.Model;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
/**
* A {@link TranslatorFactory} based on a {@link Translator} and it's {@link TranslatorOptions}.
*
* @param the input type for the base translator
* @param the output type for the base translator
*/
@SuppressWarnings({"PMD.GenericsNaming", "InterfaceTypeParameterName"})
public abstract class ExpansionTranslatorFactory implements TranslatorFactory {
/** {@inheritDoc} */
@Override
public Set> getSupportedTypes() {
Set> results = new HashSet<>();
results.addAll(getExpansions().keySet());
Set preProcessorTypes = new HashSet<>();
preProcessorTypes.addAll(getPreprocessorExpansions().keySet());
preProcessorTypes.add(getBaseInputType());
Set postProcessorTypes = new HashSet<>();
postProcessorTypes.addAll(getPostprocessorExpansions().keySet());
postProcessorTypes.add(getBaseOutputType());
for (Type i : preProcessorTypes) {
for (Type o : postProcessorTypes) {
results.add(new Pair<>(i, o));
}
}
return results;
}
/** {@inheritDoc} */
@Override
public Translator newInstance(
Class input, Class output, Model model, Map arguments) {
Translator baseTranslator = buildBaseTranslator(model, arguments);
return newInstance(input, output, baseTranslator);
}
/**
* Returns a new instance of the {@link Translator} class.
*
* @param the input data type
* @param the output data type
* @param input the input class
* @param output the output class
* @param translator the base translator to expand from
* @return a new instance of the {@code Translator} class
*/
@SuppressWarnings("unchecked")
Translator newInstance(
Class input, Class output, Translator translator) {
if (input.equals(getBaseInputType()) && output.equals(getBaseOutputType())) {
return (Translator) translator;
}
TranslatorExpansion expansion =
getExpansions().get(new Pair<>(input, output));
if (expansion != null) {
return (Translator) expansion.apply(translator);
}
// Note that regular expansions take precedence over pre-processor+post-processor expansions
PreProcessor preProcessor = null;
if (input.equals(getBaseInputType())) {
preProcessor = (PreProcessor) translator;
} else {
Function, PreProcessor>> expander =
getPreprocessorExpansions().get(input);
if (expander != null) {
preProcessor = (PreProcessor) expander.apply(translator);
}
}
PostProcessor postProcessor = null;
if (output.equals(getBaseOutputType())) {
postProcessor = (PostProcessor) translator;
} else {
Function, PostProcessor>> expander =
getPostprocessorExpansions().get(output);
if (expander != null) {
postProcessor = (PostProcessor) expander.apply(translator);
}
}
if (preProcessor != null && postProcessor != null) {
return new BasicTranslator<>(preProcessor, postProcessor, translator.getBatchifier());
}
throw new IllegalArgumentException("Unsupported expansion input/output types.");
}
/**
* Creates a set of {@link TranslatorOptions} based on the expansions of a given translator.
*
* @param translator the translator to expand
* @return the {@link TranslatorOptions}
*/
public ExpandedTranslatorOptions withTranslator(Translator translator) {
return new ExpandedTranslatorOptions(translator);
}
/**
* Builds the base translator that can be expanded.
*
* @param model the {@link Model} that uses the {@link Translator}
* @param arguments the configurations for a new {@code Translator} instance
* @return a base translator that can be expanded to form the factory options
*/
protected abstract Translator buildBaseTranslator(
Model model, Map arguments);
/**
* Returns the input type for the base translator.
*
* @return the input type for the base translator
*/
public abstract Class getBaseInputType();
/**
* Returns the output type for the base translator.
*
* @return the output type for the base translator
*/
public abstract Class getBaseOutputType();
/**
* Returns the possible expansions of this factory.
*
* @return the possible expansions of this factory
*/
protected Map, TranslatorExpansion> getExpansions() {
return Collections.emptyMap();
}
/**
* Returns the possible expansions of this factory.
*
* @return the possible expansions of this factory
*/
protected Map, PreProcessor>>>
getPreprocessorExpansions() {
return Collections.singletonMap(getBaseInputType(), p -> p);
}
/**
* Returns the possible expansions of this factory.
*
* @return the possible expansions of this factory
*/
protected Map, PostProcessor>>>
getPostprocessorExpansions() {
return Collections.singletonMap(getBaseOutputType(), p -> p);
}
/** Represents {@link TranslatorOptions} by applying expansions to a base {@link Translator}. */
final class ExpandedTranslatorOptions implements TranslatorOptions {
private Translator translator;
private ExpandedTranslatorOptions(Translator translator) {
this.translator = translator;
}
/** {@inheritDoc} */
@Override
public Set> getOptions() {
return getSupportedTypes();
}
/** {@inheritDoc} */
@Override
public Translator option(Class input, Class output) {
return newInstance(input, output, translator);
}
}
/**
* A function from a base translator to an expanded translator.
*
* @param the base translator input type
* @param the base translator output type
*/
@FunctionalInterface
@SuppressWarnings({"PMD.GenericsNaming", "InterfaceTypeParameterName"})
public interface TranslatorExpansion
extends Function, Translator, ?>> {}
}