All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.training.loss.YOLOv3Loss Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;

/**
 * {@code YOLOv3Loss} is an implementation of {@link Loss}. It is used to compute the loss while
 * training a YOLOv3 model for object detection. It involves computing the targets given the
 * generated anchors, labels and predictions, and then computing the sum of class predictions and
 * bounding box predictions.
 *
 * <p>All configuration is fixed at build time and instances keep no mutable state: the {@link
 * NDManager} used for temporary arrays is taken from the arrays passed to each call.
 */
public final class YOLOv3Loss extends Loss {
    // TODO: currently not finished, still have some bugs inside and it can only be trained with
    // PyTorch Engine
    /*
       PRESETANCHORS shapes come from the K-means clustering of the COCO dataset, whose image
       size is 416*416. They can be rescaled to any other size, e.g. for 256*256 multiply each
       value by 256/416. Layout: 3 output layers * 3 anchors * (width, height).
    */
    private static final float[] PRESETANCHORS = {
        116, 90, 156, 198, 373, 326,
        30, 61, 62, 45, 59, 119,
        10, 13, 16, 30, 33, 23
    };

    private static final float EPSILON = 1e-7f;

    /** Number of prediction layers consumed by {@link #evaluate(NDList, NDList)}. */
    private static final int OUTPUT_LAYERS = 3;

    /** Number of anchors attached to each prediction layer. */
    private static final int ANCHORS_PER_LAYER = 3;

    private final float[] anchors;
    private final int numClasses;
    private final int boxAttr;
    private final Shape inputShape;
    private final float ignoreThreshold;

    private YOLOv3Loss(Builder builder) {
        super(builder.name);
        // private copy so that later changes to the builder's array cannot alter this loss
        this.anchors = builder.anchorsArray.clone();
        this.numClasses = builder.numClasses;
        this.boxAttr = builder.numClasses + 5; // 5 for x,y,h,w,c
        this.inputShape = builder.inputShape;
        this.ignoreThreshold = builder.ignoreThreshold;
    }

    /**
     * Gets the preset anchors of YoloV3.
     *
     * @return a copy of the preset anchors of YoloV3
     */
    public static float[] getPresetAnchors() {
        return PRESETANCHORS.clone();
    }

    /**
     * Make the value of given NDArray between tMin and tMax.
     *
     * @param tList the given NDArray
     * @param tMin the min value
     * @param tMax the max value
     * @return a NDArray where values are set between tMin and tMax
     */
    public NDArray clipByTensor(NDArray tList, float tMin, float tMax) {
        // The built-in clip avoids the 0 * inf = NaN that a mask-multiplication approach
        // produces for infinite inputs.
        return tList.clip(tMin, tMax);
    }

    /**
     * Calculates the element-wise squared error between prediction and target.
     *
     * <p>No reduction is applied; the caller is responsible for summing or averaging.
     *
     * @param prediction the prediction array
     * @param target the target array
     * @return the element-wise squared difference between prediction and target
     */
    public NDArray mseLoss(NDArray prediction, NDArray target) {
        return prediction.sub(target).pow(2);
    }

    /**
     * Calculates the element-wise binary cross entropy between prediction and target.
     *
     * <p>The prediction is clipped to {@code [EPSILON, 1 - EPSILON]} before the logarithms are
     * taken. No reduction is applied.
     *
     * @param prediction the prediction array
     * @param target the target array
     * @return the element-wise BCELoss between prediction and target
     */
    public NDArray bceLoss(NDArray prediction, NDArray target) {
        NDArray clipped = clipByTensor(prediction, EPSILON, (float) (1.0 - EPSILON));
        NDArray positiveTerm = clipped.log().mul(target);
        NDArray negativeTerm = clipped.mul(-1).add(1).log().mul(target.mul(-1).add(1));
        return positiveTerm.add(negativeTerm).mul(-1);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDManager manager = predictions.getManager();
        NDArray label = labels.singletonOrThrow();

        /*
        three outputs in total
        NDArray out0 = predictions.get(0),  //Shape = (batchSize * 75 * 13 * 13) 75 = 3*(20+5)
                out1 = predictions.get(1),  //Shape = (batchSize * 75 * 26 * 26)
                out2 = predictions.get(2);  //Shape = (batchSize * 75 * 52 * 52)
        */
        NDArray[] lossComponents = new NDArray[OUTPUT_LAYERS];
        for (int i = 0; i < OUTPUT_LAYERS; i++) {
            lossComponents[i] = computeLayerLoss(manager, i, predictions.get(i), label);
        }

        // calculate the final loss
        return NDArrays.add(lossComponents);
    }

    /**
     * Computes the Loss for one outputLayer.
     *
     * <p>Temporary arrays are created with the manager of {@code input}.
     *
     * @param componentIndex which outputLayer does current input represent. the shape should be
     *     (13*13,26*26,52*52)
     * @param input one prediction layer of YOLOv3
     * @param labels target labels. Must contain (offsetLabels, masks, classlabels)
     * @return the total loss of a outputLayer
     */
    public NDArray evaluateOneOutput(int componentIndex, NDArray input, NDArray labels) {
        return computeLayerLoss(input.getManager(), componentIndex, input, labels);
    }

    /**
     * Gets target NDArray for a given evaluator.
     *
     * <p>Every image of the batch contributes one entry, including images without any object
     * (their target is all zeros). Temporary arrays are created with the manager of {@code
     * labels}.
     *
     * @param labels the true labels
     * @param inH the height of current layer
     * @param inW the width of current layer
     * @return an NDList of {boxLossScale and groundTruth}
     */
    public NDList getTarget(NDArray labels, int inH, int inW) {
        return buildTarget(labels.getManager(), labels, inH, inW);
    }

    /**
     * Calculates the IOU between priori Anchors and groundTruth.
     *
     * @param predx the tx value of prediction
     * @param predy the ty value of prediction
     * @param groundTruth the groundTruth value of labels
     * @param componentIndex the current component Index
     * @return an NDArray of IOU
     */
    public NDArray calculateIOU(
            NDArray predx, NDArray predy, NDArray groundTruth, int componentIndex) {
        int inW = (int) predx.getShape().get(2);
        int inH = (int) predx.getShape().get(3);
        int strideW = (int) inputShape.get(0) / inW;
        int strideH = (int) inputShape.get(1) / inH;

        NDList iouComponent = new NDList();
        // shape of predx, predy should all be (3,batchSize,inW,inH)
        // shape of groundTruth should be (3,batchSize,inW,inH,4)
        for (int i = 0; i < ANCHORS_PER_LAYER; i++) {
            int anchorOffset = componentIndex * 2 * ANCHORS_PER_LAYER + 2 * i;
            // anchor size expressed in grid cells
            float width = anchors[anchorOffset] / strideW;
            float height = anchors[anchorOffset + 1] / strideH;

            NDArray curPredx = predx.get(i);
            NDArray curPredy = predy.get(i);
            NDArray predLeft = curPredx.sub(width / 2);
            NDArray predRight = curPredx.add(width / 2);
            NDArray predTop = curPredy.sub(height / 2);
            NDArray predBottom = curPredy.add(height / 2);

            NDArray truth = groundTruth.get(i);
            NDArray trueX = truth.get("...,0");
            NDArray trueY = truth.get("...,1");
            // ground truth sizes are scaled by the grid size before being compared
            NDArray trueW = truth.get("...,2").mul(inW);
            NDArray trueH = truth.get("...,3").mul(inH);

            NDArray trueLeft = trueX.sub(trueW.div(2));
            NDArray trueRight = trueX.add(trueW.div(2));
            NDArray trueTop = trueY.sub(trueH.div(2));
            NDArray trueBottom = trueY.add(trueH.div(2));

            NDArray left = NDArrays.maximum(predLeft, trueLeft);
            NDArray right = NDArrays.minimum(predRight, trueRight);
            NDArray top = NDArrays.maximum(predTop, trueTop);
            NDArray bottom = NDArrays.minimum(predBottom, trueBottom);

            // Clamp each extent at zero: without it two boxes that are disjoint on both axes
            // would multiply two negative extents into a positive "intersection".
            NDArray overlapW = NDArrays.maximum(right.sub(left), 0f);
            NDArray overlapH = NDArrays.maximum(bottom.sub(top), 0f);
            NDArray inter = overlapW.mul(overlapH);

            NDArray union =
                    trueW.mul(trueH)
                            .add(width * height)
                            .sub(inter)
                            .add(EPSILON); // should not be divided by zero

            iouComponent.add(inter.div(union));
        }

        return NDArrays.stack(iouComponent);
    }

    /**
     * Creates a new builder to build a {@link YOLOv3Loss}.
     *
     * @return a new builder;
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Computes the loss of one prediction layer using the given manager for temporaries. */
    private NDArray computeLayerLoss(
            NDManager manager, int componentIndex, NDArray input, NDArray labels) {
        int batchSize = (int) input.getShape().get(0);
        int inW = (int) input.getShape().get(2);
        int inH = (int) input.getShape().get(3);

        NDArray prediction =
                input.reshape(batchSize, ANCHORS_PER_LAYER, boxAttr, inW, inH)
                        .transpose(1, 0, 3, 4, 2); // reshape into (3,batchSize,inW,inH,attrs)

        // the prediction value of x,y,w,h which shape should be (3,batchSize,inW,inH)
        NDArray x = Activation.sigmoid(prediction.get("...,0"));
        NDArray y = Activation.sigmoid(prediction.get("...,1"));
        NDArray w = prediction.get("...,2");
        NDArray h = prediction.get("...,3");

        // Confidence of whether there is an object and conditional probability of each class,
        // both transposed so that the batch axis comes first: (batchSize,3,inW,inH[,classes])
        NDArray conf = Activation.sigmoid(prediction.get("...,4")).transpose(1, 0, 2, 3);
        NDArray predClass = Activation.sigmoid(prediction.get("...,5:")).transpose(1, 0, 2, 3, 4);

        // get an NDList of groundTruth which contains boxLossScale and groundTruth
        NDList truthList = buildTarget(manager, labels, inH, inW);

        /*
           boxLossScale shape should be: (batchSize,3,inW,inH)
           groundTruth shape should be: (3,batchSize,inW,inH,boxAttr - 1)
        */
        NDArray boxLossScale = truthList.get(0).transpose(1, 0, 2, 3);
        NDArray groundTruth = truthList.get(1);

        // iou shape should be: (batchSize,3 ,inW,inH)
        NDArray iou =
                calculateIOU(x, y, groundTruth.get("...,0:4"), componentIndex)
                        .transpose(1, 0, 2, 3);

        // get noObjMask and objMask
        NDArray noObjMask =
                NDArrays.where(
                        iou.lte(ignoreThreshold), manager.ones(iou.getShape()), manager.create(0f));

        // one-hot over the anchor axis: the anchor with the highest IOU in each cell
        NDArray objMask = iou.argMax(1).oneHot(ANCHORS_PER_LAYER).transpose(0, 3, 1, 2);
        objMask =
                NDArrays.where(
                        iou.gte(ignoreThreshold / 2),
                        objMask,
                        manager.zeros(objMask.getShape())); // to get rid of wrong ones
        noObjMask = NDArrays.where(objMask.eq(1f), manager.zeros(noObjMask.getShape()), noObjMask);

        NDArray xTrue = groundTruth.get("...,0");
        NDArray yTrue = groundTruth.get("...,1");
        NDArray wTrue = groundTruth.get("...,2");
        NDArray hTrue = groundTruth.get("...,3");
        NDArray classTrue = groundTruth.get("...,4:").transpose(1, 0, 2, 3, 4);

        // anchor sizes relative to the input picture, shaped (3,1,1,1) so that they broadcast
        // along the anchor axis of the (3,batchSize,inW,inH) predictions
        NDArray widths = anchorSides(manager, componentIndex, 0, inputShape.get(0));
        NDArray heights = anchorSides(manager, componentIndex, 1, inputShape.get(1));

        // three loss parts: box Loss, confidence Loss, and class Loss
        NDArray boxError =
                NDArrays.add(
                                mseLoss(x, xTrue),
                                mseLoss(y, yTrue),
                                mseLoss(w.exp().mul(widths), wTrue),
                                mseLoss(h.exp().mul(heights), hTrue))
                        .transpose(1, 0, 2, 3);
        NDArray boxLoss = objMask.mul(boxLossScale).mul(boxError).sum();

        NDArray classLoss = bceLoss(predClass, classTrue).sum(new int[] {4});
        NDArray confLoss = objMask.mul(conf.add(EPSILON).log().mul(-1).add(classLoss)).sum();

        NDArray noObjLoss = noObjMask.mul(conf.mul(-1).add(1 + EPSILON).log().mul(-1)).sum();

        return boxLoss.add(confLoss).add(noObjLoss).div(batchSize);
    }

    /**
     * Builds the anchor widths ({@code offset == 0}) or heights ({@code offset == 1}) of one
     * output layer, divided by the matching side of the input picture and shaped (3,1,1,1).
     */
    private NDArray anchorSides(
            NDManager manager, int componentIndex, int offset, long imageSide) {
        int base = componentIndex * 2 * ANCHORS_PER_LAYER + offset;
        float[] sides = new float[ANCHORS_PER_LAYER];
        for (int i = 0; i < ANCHORS_PER_LAYER; i++) {
            sides[i] = anchors[base + 2 * i];
        }
        return manager.create(sides).div(imageSide).reshape(ANCHORS_PER_LAYER, 1, 1, 1);
    }

    /** Builds {boxLossScale, groundTruth} for a batch using the given manager. */
    private NDList buildTarget(NDManager manager, NDArray labels, int inH, int inW) {
        int batchSize = (int) labels.size(0);

        // the loss Scale of a box, used to punctuate small boxes
        NDList boxLossComponents = new NDList();

        // the groundTruth of a true object in pictures
        NDList groundTruthComponents = new NDList();

        // Shape of labels:(batchSize,objectNum,5)
        for (int batch = 0; batch < batchSize; batch++) {
            NDArray boxLoss = manager.zeros(new Shape(inW, inH), DataType.FLOAT32);
            NDArray groundTruth = manager.zeros(new Shape(inW, inH, boxAttr - 1), DataType.FLOAT32);

            NDArray picture = labels.get(batch); // shape:(objectNum,5)
            if (picture.size(0) > 0) {
                fillTarget(picture, inH, inW, boxLoss, groundTruth);
            }

            // An image without objects keeps its all-zero target. It must still be added:
            // the stacked result is broadcast to batchSize entries below.
            boxLossComponents.add(boxLoss);
            groundTruthComponents.add(groundTruth);
        }

        NDArray boxLossScale =
                NDArrays.stack(boxLossComponents)
                        .broadcast(ANCHORS_PER_LAYER, batchSize, inW, inH);
        NDArray groundTruth =
                NDArrays.stack(groundTruthComponents)
                        .broadcast(ANCHORS_PER_LAYER, batchSize, inW, inH, boxAttr - 1);

        return new NDList(boxLossScale, groundTruth);
    }

    /**
     * Writes every object of one picture into the per-picture {@code boxLoss} and {@code
     * groundTruth} grids. Each label row is read as (class, x, y, width, height).
     */
    private void fillTarget(
            NDArray picture, int inH, int inW, NDArray boxLoss, NDArray groundTruth) {
        NDArray wgt = picture.get("...,3"); // the width of the ground truth box
        NDArray hgt = picture.get("...,4"); // the height of the ground truth box

        // Center of x is the X value in labels plus half of the width, multiplied by the
        // layer width to get which grid cell it's in. Center of y is computed the same way.
        // Each column is copied to the heap once instead of reading scalars one at a time.
        float[] centerX = picture.get("...,1").add(wgt.div(2)).mul(inW).toFloatArray();
        float[] centerY = picture.get("...,2").add(hgt.div(2)).mul(inH).toFloatArray();
        float[] boxW = wgt.toFloatArray();
        float[] boxH = hgt.toFloatArray();

        // we should transform the presentation of true class, like
        // [[0],[1],[2]]->[[1,0,0,...0],[0,1,0,...,0],[0,0,1,...,0]]
        NDArray objectClass = picture.get("...,0").oneHot(numClasses);

        for (int i = 0; i < centerX.length; i++) {
            // for each object, the middle of the object(x and y) should be in one grid cell
            // tx and ty indicate the grid cell, bx and by the movement from its top-left.
            // The cell is clamped into the grid: a center lying exactly on the far edge would
            // otherwise index past the array, and a negative one would wrap around.
            int tx = clampCell((int) centerX[i], inW);
            int ty = clampCell((int) centerY[i], inH);
            float bx = centerX[i] - tx;
            float by = centerY[i] - ty;

            String index = tx + "," + ty;
            // set groundTruth
            groundTruth.set(new NDIndex(index + ",0"), bx);
            groundTruth.set(new NDIndex(index + ",1"), by);
            groundTruth.set(new NDIndex(index + ",2"), boxW[i]);
            groundTruth.set(new NDIndex(index + ",3"), boxH[i]);
            groundTruth.set(new NDIndex(index + ",4:"), objectClass.get(i));

            // set boxLoss
            boxLoss.set(new NDIndex(index), 2 - boxW[i] * boxH[i]);
        }
    }

    /** Limits a grid cell index to the valid range {@code [0, size - 1]}. */
    private static int clampCell(int cell, int size) {
        return Math.max(0, Math.min(size - 1, cell));
    }

    /** The Builder to construct a {@link YOLOv3Loss} object. */
    public static class Builder {
        private String name = "YOLOv3Loss";
        private float[] anchorsArray = PRESETANCHORS;
        private int numClasses = 20;
        // 416*416 is the picture size the preset anchors were derived for
        private Shape inputShape = new Shape(416, 416);
        private float ignoreThreshold = 0.5f;

        /**
         * Sets the loss name of YoloV3Loss.
         *
         * @param name the name of loss function
         * @return this {@code Builder}
         */
        public Builder setName(String name) {
            this.name = name;
            return this;
        }

        /**
         * Sets the preset anchors for YoloV3.
         *
         * <p>The array is copied, so later changes to {@code anchorsArray} do not affect the
         * builder.
         *
         * @param anchorsArray the anchors in float array
         * @return this {@code Builder}
         * @throws IllegalArgumentException if the array does not have the same length as the
         *     preset anchors
         */
        public Builder setAnchorsArray(float[] anchorsArray) {
            if (anchorsArray.length != PRESETANCHORS.length) {
                throw new IllegalArgumentException(
                        String.format(
                                "setAnchorsArray requires anchors of length %d, but was given"
                                        + " anchors of length %d instead",
                                PRESETANCHORS.length, anchorsArray.length));
            }
            this.anchorsArray = anchorsArray.clone();
            return this;
        }

        /**
         * Sets the number of total classes.
         *
         * @param numClasses the number of total classes
         * @return this {@code Builder}
         */
        public Builder setNumClasses(int numClasses) {
            this.numClasses = numClasses;
            return this;
        }

        /**
         * Sets the shape of the input picture.
         *
         * @param inputShape the shape of input picture.
         * @return this {@code Builder}
         */
        public Builder setInputShape(Shape inputShape) {
            this.inputShape = inputShape;
            return this;
        }

        /**
         * Sets the ignoreThreshold for iou to check if we think it detects a picture.
         *
         * @param ignoreThreshold the ignore threshold
         * @return this {@code Builder}
         */
        public Builder optIgnoreThreshold(float ignoreThreshold) {
            this.ignoreThreshold = ignoreThreshold;
            return this;
        }

        /**
         * Builds a {@link YOLOv3Loss} instance.
         *
         * @return a {@link YOLOv3Loss} instance.
         */
        public YOLOv3Loss build() {
            return new YOLOv3Loss(this);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy