// ai.djl.modality.cv.translator.ObjectDetectionTranslator (Maven / Gradle / Ivy)
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;
import ai.djl.Model;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import java.io.IOException;
import java.util.List;
import java.util.Map;
/**
* A {@link BaseImageTranslator} that post-process the {@link NDArray} into {@link DetectedObjects}
* with boundaries.
*/
public abstract class ObjectDetectionTranslator extends BaseImageTranslator {
protected float threshold;
private SynsetLoader synsetLoader;
protected List classes;
protected double imageWidth;
protected double imageHeight;
/**
* Creates the {@link ObjectDetectionTranslator} from the given builder.
*
* @param builder the builder for the translator
*/
protected ObjectDetectionTranslator(ObjectDetectionBuilder> builder) {
super(builder);
this.threshold = builder.threshold;
this.synsetLoader = builder.synsetLoader;
this.imageWidth = builder.imageWidth;
this.imageHeight = builder.imageHeight;
}
/** {@inheritDoc} */
@Override
public void prepare(NDManager manager, Model model) throws IOException {
if (classes == null) {
classes = synsetLoader.load(model);
}
}
/** The base builder for the object detection translator. */
@SuppressWarnings("rawtypes")
public abstract static class ObjectDetectionBuilder
extends ClassificationBuilder {
protected float threshold = 0.2f;
protected double imageWidth;
protected double imageHeight;
/**
* Sets the threshold for prediction accuracy.
*
* Predictions below the threshold will be dropped.
*
* @param threshold the threshold for the prediction accuracy
* @return this builder
*/
public T optThreshold(float threshold) {
this.threshold = threshold;
return self();
}
/**
* Sets the optional rescale size.
*
* @param imageWidth the width to rescale images to
* @param imageHeight the height to rescale images to
* @return this builder
*/
public T optRescaleSize(double imageWidth, double imageHeight) {
this.imageWidth = imageWidth;
this.imageHeight = imageHeight;
return self();
}
/**
* Get resized image width.
*
* @return image width
*/
public double getImageWidth() {
return imageWidth;
}
/**
* Get resized image height.
*
* @return image height
*/
public double getImageHeight() {
return imageHeight;
}
/** {@inheritDoc} */
@Override
protected void configPostProcess(Map arguments) {
super.configPostProcess(arguments);
if (getBooleanValue(arguments, "rescale", false)) {
optRescaleSize(width, height);
}
threshold = getFloatValue(arguments, "threshold", 0.2f);
}
}
}