All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.modality.cv.translator.YoloSegmentationTranslator Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.TranslatorContext;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** A translator for Yolov8 instance segmentation models. */
public class YoloSegmentationTranslator extends YoloV5Translator {

    private static final int[] AXIS_0 = {0};
    private static final int[] AXIS_1 = {1};

    private float threshold;
    private float nmsThreshold;

    /**
     * Creates the instance segmentation translator from the given builder.
     *
     * @param builder the builder for the translator
     */
    public YoloSegmentationTranslator(Builder builder) {
        super(builder);
        this.threshold = builder.threshold;
        this.nmsThreshold = builder.nmsThreshold;
    }

    /** {@inheritDoc} */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDArray pred = list.get(0);
        NDArray protos = list.get(1);
        int maskIndex = classes.size() + 4;
        NDArray candidates = pred.get("4:" + maskIndex).max(AXIS_0).gt(threshold);
        pred = pred.transpose();
        NDArray sub = pred.get("..., :4");
        sub = xywh2xyxy(sub);
        pred = sub.concat(pred.get("..., 4:"), -1);
        pred = pred.get(candidates);

        NDList split = pred.split(new long[] {4, maskIndex}, 1);
        NDArray box = split.get(0);

        int numBox = Math.toIntExact(box.getShape().get(0));

        float[] buf = box.toFloatArray();
        float[] confidences = split.get(1).max(AXIS_1).toFloatArray();
        long[] ids = split.get(1).argMax(1).toLongArray();

        List boxes = new ArrayList<>(numBox);
        List scores = new ArrayList<>(numBox);
        for (int i = 0; i < numBox; ++i) {
            float xPos = buf[i * 4];
            float yPos = buf[i * 4 + 1];
            float w = buf[i * 4 + 2] - xPos;
            float h = buf[i * 4 + 3] - yPos;
            Rectangle rect = new Rectangle(xPos, yPos, w, h);
            boxes.add(rect);
            scores.add((double) confidences[i]);
        }
        List nms = Rectangle.nms(boxes, scores, nmsThreshold);
        long[] idx = nms.stream().mapToLong(Integer::longValue).toArray();
        NDArray selected = box.getManager().create(idx);
        NDArray masks = split.get(2).get(selected);

        int maskW = Math.toIntExact(protos.getShape().get(2));
        int maskH = Math.toIntExact(protos.getShape().get(1));

        protos = protos.reshape(32, (long) maskH * maskW);
        masks =
                masks.matMul(protos)
                        .reshape(nms.size(), maskH, maskW)
                        .gt(0f)
                        .toType(DataType.FLOAT32, true);

        float[] maskArray = masks.toFloatArray();
        box = box.get(selected);
        buf = box.toFloatArray();

        List retClasses = new ArrayList<>();
        List retProbs = new ArrayList<>();
        List retBB = new ArrayList<>();
        for (int i = 0; i < idx.length; ++i) {
            float x = buf[i * 4] / width;
            float y = buf[i * 4 + 1] / height;
            float w = buf[i * 4 + 2] / width - x;
            float h = buf[i * 4 + 3] / width - y;
            int id = nms.get(i);
            retClasses.add(classes.get((int) ids[id]));
            retProbs.add((double) confidences[id]);

            float[][] maskFloat = new float[maskH][maskW];
            for (int j = 0; j < maskH; j++) {
                System.arraycopy(maskArray, j * maskW, maskFloat[j], 0, maskW);
            }
            Mask bb = new Mask(x, y, w, h, maskFloat, true);
            retBB.add(bb);
        }
        return new DetectedObjects(retClasses, retProbs, retBB);
    }

    private NDArray xywh2xyxy(NDArray array) {
        NDArray xy = array.get("..., :2");
        NDArray wh = array.get("..., 2:").div(2);
        return xy.sub(wh).concat(xy.add(wh), -1);
    }

    /**
     * Creates a builder to build a {@code YoloSegmentationTranslator}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder to build a {@code YoloSegmentationTranslator} with specified arguments.
     *
     * @param arguments arguments to specify builder options
     * @return a new builder
     */
    public static Builder builder(Map arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);

        return builder;
    }

    /** The builder for instance segmentation translator. */
    public static class Builder extends YoloV5Translator.Builder {

        Builder() {}

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /** {@inheritDoc} */
        @Override
        public YoloSegmentationTranslator build() {
            validate();
            return new YoloSegmentationTranslator(this);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy