// ai.djl.modality.cv.translator.ObjectDetectionTranslator (Maven / Gradle / Ivy)
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.ndarray.NDArray;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import java.util.List;
import java.util.Map;
/**
* A {@link BaseImageTranslator} that post-process the {@link NDArray} into {@link DetectedObjects}
* with boundaries.
*/
public abstract class ObjectDetectionTranslator extends BaseImageTranslator {
protected float threshold;
private SynsetLoader synsetLoader;
protected List classes;
protected boolean applyRatio;
protected boolean removePadding;
/**
* Creates the {@link ObjectDetectionTranslator} from the given builder.
*
* @param builder the builder for the translator
*/
protected ObjectDetectionTranslator(ObjectDetectionBuilder> builder) {
super(builder);
this.threshold = builder.threshold;
this.synsetLoader = builder.synsetLoader;
this.applyRatio = builder.applyRatio;
this.removePadding = builder.removePadding;
}
/** {@inheritDoc} */
@Override
public void prepare(TranslatorContext ctx) throws Exception {
if (classes == null) {
classes = synsetLoader.load(ctx.getModel());
}
}
/** The base builder for the object detection translator. */
@SuppressWarnings("rawtypes")
public abstract static class ObjectDetectionBuilder
extends ClassificationBuilder {
protected float threshold = 0.2f;
protected boolean applyRatio;
protected boolean removePadding;
/**
* Sets the threshold for prediction accuracy.
*
* Predictions below the threshold will be dropped.
*
* @param threshold the threshold for the prediction accuracy
* @return this builder
*/
public T optThreshold(float threshold) {
this.threshold = threshold;
return self();
}
/**
* Determine Whether to divide output object width/height on the inference result. Default
* false.
*
*
DetectedObject value should always bring a ratio based on the width/height instead of
* actual width/height. Most of the model will produce ratio as the inference output. This
* function is aimed to cover those who produce the pixel value. Make this to true to divide
* the width/height in postprocessing in order to get ratio in detectedObjects.
*
* @param value whether to apply ratio
* @return this builder
*/
public T optApplyRatio(boolean value) {
this.applyRatio = value;
return self();
}
/** {@inheritDoc} */
@Override
protected void configPostProcess(Map arguments) {
super.configPostProcess(arguments);
if (ArgumentsUtil.booleanValue(arguments, "optApplyRatio")
|| ArgumentsUtil.booleanValue(arguments, "applyRatio")) {
optApplyRatio(true);
}
threshold = ArgumentsUtil.floatValue(arguments, "threshold", 0.2f);
String centerFit = ArgumentsUtil.stringValue(arguments, "centerFit", "false");
removePadding = "true".equals(centerFit);
}
}
}