All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.modality.cv.translator.YoloPoseTranslator Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Joints.Joint;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A translator for Yolov8 pose estimation models.
 *
 * <p>The post-processing step keeps the candidates whose confidence is above the configured
 * threshold, runs non-maximum suppression on their bounding boxes and returns one {@link Joints}
 * instance for each surviving detection (at most {@value #MAX_DETECTION}).
 */
public class YoloPoseTranslator extends BaseImageTranslator<Joints[]> {

    /** Maximum number of detections returned by {@link #processOutput}. */
    private static final int MAX_DETECTION = 300;

    /** Number of key points read for every detection. */
    private static final int NUM_KEYPOINTS = 17;

    /** Number of values stored per key point: x, y and confidence. */
    private static final int KEYPOINT_SIZE = 3;

    private final float threshold;
    private final float nmsThreshold;

    /**
     * Creates the Pose Estimation translator from the given builder.
     *
     * @param builder the builder for the translator
     */
    public YoloPoseTranslator(Builder builder) {
        super(builder);
        this.threshold = builder.threshold;
        this.nmsThreshold = builder.nmsThreshold;
    }

    /** {@inheritDoc} */
    @Override
    public Joints[] processOutput(TranslatorContext ctx, NDList list) {
        // NOTE(review): assumes the single output is laid out as (features, candidates) with the
        // batch dimension already removed -- confirm against the exported model.
        NDArray pred = list.singletonOrThrow();
        // Index 4 of the first axis is the value that is later used as the NMS score; only
        // candidates strictly above the threshold are kept.
        NDArray candidates = pred.get(4).gt(threshold);
        pred = pred.transpose();
        // Convert the first four columns from (center x, center y, w, h) to corner coordinates.
        NDArray sub = pred.get("..., :4");
        sub = xywh2xyxy(sub);
        pred = sub.concat(pred.get("..., 4:"), -1);
        pred = pred.get(candidates);

        // Columns [0, 4) -> box, [4, 5) -> confidence, [5, ...) -> flattened key points.
        NDList split = pred.split(new long[] {4, 5}, 1);
        NDArray box = split.get(0);

        int numBox = Math.toIntExact(box.getShape().get(0));

        float[] buf = box.toFloatArray();
        float[] confidences = split.get(1).toFloatArray();
        float[] mask = split.get(2).toFloatArray();

        List<Rectangle> boxes = new ArrayList<>(numBox);
        List<Double> scores = new ArrayList<>(numBox);

        for (int i = 0; i < numBox; ++i) {
            float xPos = buf[i * 4];
            float yPos = buf[i * 4 + 1];
            float w = buf[i * 4 + 2] - xPos;
            float h = buf[i * 4 + 3] - yPos;
            boxes.add(new Rectangle(xPos, yPos, w, h));
            scores.add((double) confidences[i]);
        }

        // nms holds indices into boxes/scores, and therefore into the key point buffer as well.
        List<Integer> nms = Rectangle.nms(boxes, scores, nmsThreshold);
        if (nms.size() > MAX_DETECTION) {
            nms = nms.subList(0, MAX_DETECTION);
        }

        Joints[] ret = new Joints[nms.size()];
        for (int i = 0; i < ret.length; ++i) {
            int index = nms.get(i);
            int pos = index * NUM_KEYPOINTS * KEYPOINT_SIZE;

            List<Joint> joints = new ArrayList<>(NUM_KEYPOINTS);
            for (int j = 0; j < NUM_KEYPOINTS; ++j) {
                int offset = pos + j * KEYPOINT_SIZE;
                // x and y are divided by the configured input width/height.
                joints.add(
                        new Joint(
                                mask[offset] / width,
                                mask[offset + 1] / height,
                                mask[offset + 2]));
            }
            // Build the result only after the list is complete so the outcome does not depend
            // on whether Joints keeps a reference to the list or copies it.
            ret[i] = new Joints(joints);
        }
        return ret;
    }

    /**
     * Converts boxes given as (center x, center y, width, height) to (x1, y1, x2, y2).
     *
     * @param array an array whose last axis holds the four box values
     * @return an array of the same layout holding the two opposite corners of each box
     */
    private NDArray xywh2xyxy(NDArray array) {
        NDArray xy = array.get("..., :2");
        NDArray wh = array.get("..., 2:").div(2);
        return xy.sub(wh).concat(xy.add(wh), -1);
    }

    /**
     * Creates a builder to build a {@code YoloPoseTranslator}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder to build a {@code YoloPoseTranslator} with specified arguments.
     *
     * @param arguments arguments to specify builder options
     * @return a new builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);

        return builder;
    }

    /** The builder for Pose Estimation translator. */
    public static class Builder extends BaseBuilder<Builder> {

        float threshold = 0.25f;
        float nmsThreshold = 0.7f;

        Builder() {}

        /**
         * Sets the threshold for prediction accuracy.
         *
         * <p>Predictions whose confidence is not above the threshold will be dropped.
         *
         * @param threshold the threshold for prediction accuracy
         * @return the builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return self();
        }

        /**
         * Sets the NMS threshold.
         *
         * @param nmsThreshold the NMS threshold
         * @return this builder
         */
        public Builder optNmsThreshold(float nmsThreshold) {
            this.nmsThreshold = nmsThreshold;
            return self();
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * {@inheritDoc}
         *
         * <p>Reads the optional {@code threshold} and {@code nmsThreshold} arguments; the current
         * builder values are used when an argument is absent.
         */
        @Override
        protected void configPostProcess(Map<String, ?> arguments) {
            optThreshold(ArgumentsUtil.floatValue(arguments, "threshold", threshold));
            optNmsThreshold(ArgumentsUtil.floatValue(arguments, "nmsThreshold", nmsThreshold));
        }

        /**
         * Builds the translator.
         *
         * @return the new translator
         */
        public YoloPoseTranslator build() {
            validate();
            return new YoloPoseTranslator(this);
        }
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy