// ai.djl.modality.cv.translator.ObjectDetectionTranslator (Maven / Gradle / Ivy)
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.ndarray.NDArray;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import java.util.List;
import java.util.Map;
/**
* A {@link BaseImageTranslator} that post-process the {@link NDArray} into {@link DetectedObjects}
* with boundaries.
*/
public abstract class ObjectDetectionTranslator extends BaseImageTranslator {
protected float threshold;
private SynsetLoader synsetLoader;
protected List classes;
protected double imageWidth;
protected double imageHeight;
protected boolean applyRatio;
/**
* Creates the {@link ObjectDetectionTranslator} from the given builder.
*
* @param builder the builder for the translator
*/
protected ObjectDetectionTranslator(ObjectDetectionBuilder> builder) {
super(builder);
this.threshold = builder.threshold;
this.synsetLoader = builder.synsetLoader;
this.imageWidth = builder.imageWidth;
this.imageHeight = builder.imageHeight;
this.applyRatio = builder.applyRatio;
}
/** {@inheritDoc} */
@Override
public void prepare(TranslatorContext ctx) throws Exception {
if (classes == null) {
classes = synsetLoader.load(ctx.getModel());
}
}
/** The base builder for the object detection translator. */
@SuppressWarnings("rawtypes")
public abstract static class ObjectDetectionBuilder
extends ClassificationBuilder {
protected float threshold = 0.2f;
protected double imageWidth;
protected double imageHeight;
protected boolean applyRatio;
/**
* Sets the threshold for prediction accuracy.
*
* Predictions below the threshold will be dropped.
*
* @param threshold the threshold for the prediction accuracy
* @return this builder
*/
public T optThreshold(float threshold) {
this.threshold = threshold;
return self();
}
/**
* Sets the optional rescale size.
*
* @param imageWidth the width to rescale images to
* @param imageHeight the height to rescale images to
* @return this builder
*/
public T optRescaleSize(double imageWidth, double imageHeight) {
this.imageWidth = imageWidth;
this.imageHeight = imageHeight;
return self();
}
/**
* Determine Whether to divide output object width/height on the inference result. Default
* false.
*
*
DetectedObject value should always bring a ratio based on the width/height instead of
* actual width/height. Most of the model will produce ratio as the inference output. This
* function is aimed to cover those who produce the pixel value. Make this to true to divide
* the width/height in postprocessing in order to get ratio in detectedObjects.
*
* @param value whether to apply ratio
* @return this builder
*/
public T optApplyRatio(boolean value) {
this.applyRatio = value;
return self();
}
/**
* Get resized image width.
*
* @return image width
*/
public double getImageWidth() {
return imageWidth;
}
/**
* Get resized image height.
*
* @return image height
*/
public double getImageHeight() {
return imageHeight;
}
/** {@inheritDoc} */
@Override
protected void configPostProcess(Map arguments) {
super.configPostProcess(arguments);
if (ArgumentsUtil.booleanValue(arguments, "rescale")) {
optRescaleSize(width, height);
}
optApplyRatio(ArgumentsUtil.booleanValue(arguments, "optApplyRatio"));
threshold = ArgumentsUtil.floatValue(arguments, "threshold", 0.2f);
}
}
}