// ai.djl.modality.cv.output.Rectangle Maven / Gradle / Ivy
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.output;
import ai.djl.util.JsonSerializable;
import ai.djl.util.JsonUtils;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
/**
 * A {@code Rectangle} specifies an area in a coordinate space that is enclosed by the {@code
 * Rectangle} object's upper-left point {@link Point} in the coordinate space, its width, and its
 * height.
 *
 * <p>The rectangle coordinates are usually from 0-1 and are ratios of the image size. For example,
 * if you have an image width of 400 pixels and the rectangle starts at 100 pixels, you would use
 * .25.
 */
public class Rectangle implements BoundingBox, JsonSerializable {

    private static final long serialVersionUID = 1L;

    /** The four corners, stored as upper-left, upper-right, bottom-right, bottom-left. */
    @SuppressWarnings("serial")
    private final List<Point> corners;

    private final double width;
    private final double height;

    /**
     * Constructs a new {@code Rectangle} whose upper-left corner is specified as {@code (x,y)} and
     * whose width and height are specified by the arguments of the same name.
     *
     * @param x the specified X coordinate (0-1)
     * @param y the specified Y coordinate (0-1)
     * @param width the width of the {@code Rectangle} (0-1)
     * @param height the height of the {@code Rectangle} (0-1)
     */
    public Rectangle(double x, double y, double width, double height) {
        this(new Point(x, y), width, height);
    }

    /**
     * Constructs a new {@code Rectangle} whose upper-left corner is specified as coordinate {@code
     * point} and whose width and height are specified by the arguments of the same name.
     *
     * @param point the upper-left corner of the coordinate (0-1)
     * @param width the width of the {@code Rectangle} (0-1)
     * @param height the height of the {@code Rectangle} (0-1)
     */
    public Rectangle(Point point, double width, double height) {
        this.width = width;
        this.height = height;
        // Corners are derived once, walking from the upper-left corner through the upper-right
        // and bottom-right corners to the bottom-left one.
        corners = new ArrayList<>(4);
        corners.add(point);
        corners.add(new Point(point.getX() + width, point.getY()));
        corners.add(new Point(point.getX() + width, point.getY() + height));
        corners.add(new Point(point.getX(), point.getY() + height));
    }

    /** {@inheritDoc} */
    @Override
    public Rectangle getBounds() {
        return this;
    }

    /** {@inheritDoc} */
    @Override
    public Iterable<Point> getPath() {
        return corners;
    }

    /** {@inheritDoc} */
    @Override
    public Point getPoint() {
        return corners.get(0);
    }

    /**
     * {@inheritDoc}
     *
     * <p>This implementation adds 1 to every width and height before multiplying, for both the
     * individual areas and the intersection. NOTE(review): this looks like the inclusive pixel
     * convention; with 0-1 ratio coordinates the added 1 dominates the result, and the value
     * differs from the IoU used internally by {@link #nms(List, List, float)} - confirm which
     * convention callers expect.
     */
    @Override
    public double getIoU(BoundingBox box) {
        Rectangle rect = box.getBounds();

        // computing area of each rectangles
        double s1 = (width + 1) * (height + 1);
        double s2 = (rect.getWidth() + 1) * (rect.getHeight() + 1);
        double sumArea = s1 + s2;

        // find each edge of intersect rectangle
        double left = Math.max(getX(), rect.getX());
        double top = Math.max(getY(), rect.getY());
        double right = Math.min(getX() + getWidth(), rect.getX() + rect.getWidth());
        double bottom = Math.min(getY() + getHeight(), rect.getY() + rect.getHeight());

        // judge if there is a intersect
        if (left > right || top > bottom) {
            return 0.0;
        }

        double intersect = (right - left + 1) * (bottom - top + 1);
        return intersect / (sumArea - intersect);
    }

    /**
     * Returns the left x-coordinate of the Rectangle.
     *
     * @return the left x-coordinate of the Rectangle (0-1)
     */
    public double getX() {
        return getPoint().getX();
    }

    /**
     * Returns the top y-coordinate of the Rectangle.
     *
     * @return the top y-coordinate of the Rectangle (0-1)
     */
    public double getY() {
        return getPoint().getY();
    }

    /**
     * Returns the width of the Rectangle.
     *
     * @return the width of the Rectangle (0-1)
     */
    public double getWidth() {
        return width;
    }

    /**
     * Returns the height of the Rectangle.
     *
     * @return the height of the Rectangle (0-1)
     */
    public double getHeight() {
        return height;
    }

    /**
     * Returns the upper left and bottom right coordinates.
     *
     * @return a four element array laid out as {@code {left, top, right, bottom}}
     */
    public double[] getCoordinates() {
        Point upLeft = corners.get(0);
        Point bottomRight = corners.get(2);
        return new double[] {upLeft.getX(), upLeft.getY(), bottomRight.getX(), bottomRight.getY()};
    }

    /** {@inheritDoc} */
    @Override
    public JsonObject serialize() {
        JsonObject ret = new JsonObject();
        ret.add("rect", JsonUtils.GSON.toJsonTree(getCoordinates()));
        return ret;
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return toJson();
    }

    /**
     * Applies nms (non-maximum suppression) to the list of rectangles.
     *
     * <p>The highest scoring remaining rectangle is kept, every remaining rectangle whose IoU with
     * it is not below {@code nmsThreshold} is discarded, and the process repeats until no
     * candidates are left.
     *
     * @param boxes a list of {@code Rectangle}
     * @param scores a list of scores, where {@code scores.get(i)} belongs to {@code boxes.get(i)}
     * @param nmsThreshold the nms threshold
     * @return the filtered list with the index of the original list
     */
    public static List<Integer> nms(
            List<Rectangle> boxes, List<Double> scores, float nmsThreshold) {
        List<Integer> ret = new ArrayList<>();
        PriorityQueue<Integer> pq =
                new PriorityQueue<>(
                        // PriorityQueue rejects an initial capacity below 1.
                        Math.max(1, boxes.size()),
                        // Intentionally reversed to put high confidence at the head of the queue.
                        (lhs, rhs) -> Double.compare(scores.get(rhs), scores.get(lhs)));
        for (int i = 0; i < boxes.size(); ++i) {
            pq.add(i);
        }

        // do non maximum suppression
        while (!pq.isEmpty()) {
            // poll() is the only documented way to obtain the head; the iteration order of a
            // PriorityQueue is unspecified.
            int best = pq.poll();
            ret.add(best);
            Rectangle box = boxes.get(best);
            // Keep only candidates whose IoU is strictly below the threshold. The negated form
            // also drops candidates whose IoU is NaN (two zero-area boxes).
            pq.removeIf(candidate -> !(box.boxIou(boxes.get(candidate)) < nmsThreshold));
        }
        return ret;
    }

    /**
     * Computes the intersection over union with another rectangle using plain areas.
     *
     * @param other the rectangle to compare with
     * @return the intersection area divided by the union area; NaN when the union area is zero
     */
    private double boxIou(Rectangle other) {
        double intersection = intersection(other);
        double union =
                getWidth() * getHeight() + other.getWidth() * other.getHeight() - intersection;
        return intersection / union;
    }

    /**
     * Computes the area shared with another rectangle.
     *
     * @param b the rectangle to intersect with
     * @return the overlapping area, or 0 if the rectangles do not overlap
     */
    private double intersection(Rectangle b) {
        // overlap() works on (center, extent) pairs, so each start coordinate is first converted
        // to the center of the side.
        double w =
                overlap(
                        (getX() * 2 + getWidth()) / 2,
                        getWidth(),
                        (b.getX() * 2 + b.getWidth()) / 2,
                        b.getWidth());
        double h =
                overlap(
                        (getY() * 2 + getHeight()) / 2,
                        getHeight(),
                        (b.getY() * 2 + b.getHeight()) / 2,
                        b.getHeight());
        if (w < 0 || h < 0) {
            return 0;
        }
        return w * h;
    }

    /**
     * Computes the signed overlap of two one-dimensional segments given by center and length.
     *
     * @param x1 the center of the first segment
     * @param w1 the length of the first segment
     * @param x2 the center of the second segment
     * @param w2 the length of the second segment
     * @return the overlap length; negative when the segments are disjoint
     */
    private double overlap(double x1, double w1, double x2, double w2) {
        double l1 = x1 - w1 / 2;
        double l2 = x2 - w2 / 2;
        double left = Math.max(l1, l2);
        double r1 = x1 + w1 / 2;
        double r2 = x2 + w2 / 2;
        double right = Math.min(r1, r2);
        return right - left;
    }
}