All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.layers.objdetect;

import lombok.*;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.common.util.ArrayUtil;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import static org.nd4j.linalg.indexing.NDArrayIndex.*;

public class Yolo2OutputLayer extends AbstractLayer implements Serializable, IOutputLayer {
    private static final Gradient EMPTY_GRADIENT = new DefaultGradient();

    //current input and label matrices
    @Setter @Getter
    protected INDArray labels;

    private double fullNetRegTerm;
    private double score;

    public Yolo2OutputLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    @Override
    public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        INDArray epsOut = computeBackpropGradientAndScore(workspaceMgr, false, false);

        return new Pair<>(EMPTY_GRADIENT, epsOut);
    }

    private INDArray computeBackpropGradientAndScore(LayerWorkspaceMgr workspaceMgr, boolean scoreOnly, boolean computeScoreForExamples){
        assertInputSet(true);
        Preconditions.checkState(labels != null, "Cannot calculate gradients/score: labels are null");
        Preconditions.checkState(labels.rank() == 4, "Expected rank 4 labels array with shape [minibatch, 4+numClasses, h, w]" +
                " but got rank %s labels array with shape %s", labels.rank(), labels.shape());

        boolean nchw = layerConf().getFormat() == CNN2DFormat.NCHW;
        INDArray input = nchw ? this.input : this.input.permute(0,3,1,2);   //NHWC to NCHW
        INDArray labels = this.labels.castTo(input.dataType());     //Ensure correct dtype (same as params); no-op if already correct dtype
        if(!nchw)
            labels = labels.permute(0,3,1,2);   //NHWC to NCHW

        double lambdaCoord = layerConf().getLambdaCoord();
        double lambdaNoObj = layerConf().getLambdaNoObj();

        long mb = input.size(0);
        long h = input.size(2);
        long w = input.size(3);
        int b = (int) layerConf().getBoundingBoxes().size(0);
        int c = (int) labels.size(1)-4;



        //Various shape arrays, to reuse
        long[] nhw = new long[]{mb, h, w};

        //Labels shape: [mb, 4+C, H, W]
        //Infer mask array from labels. Mask array is 1_i^B in YOLO paper - i.e., whether an object is present in that
        // grid location or not. Here: we are using the fact that class labels are one-hot, and assume that values are
        // all 0s if no class label is present
        Preconditions.checkState(labels.rank() == 4, "Expected labels array to be rank 4 with shape [minibatch, 4+numClasses, H, W]. Got labels array with shape %ndShape", labels);
        Preconditions.checkState(labels.size(1) > 0, "Invalid labels array: labels.size(1) must be > 4. labels array should be rank 4 with shape [minibatch, 4+numClasses, H, W]. Got labels array with shape %ndShape", labels);

        val size1 = labels.size(1);
        INDArray classLabels = labels.get(all(), interval(4,size1), all(), all());   //Shape: [minibatch, nClasses, H, W]
        INDArray maskObjectPresent = classLabels.sum(Nd4j.createUninitialized(input.dataType(), nhw, 'c'), 1);//.castTo(DataType.BOOL); //Shape: [minibatch, H, W]
        INDArray maskObjectPresentBool = maskObjectPresent.castTo(DataType.BOOL);

        // ----- Step 1: Labels format conversion -----
        //First: Convert labels/ground truth (x1,y1,x2,y2) from "coordinates (grid box units)" format to "center position in grid box" format
        //0.5 * ([x1,y1]+[x2,y2])   ->      shape: [mb, B, 2, H, W]
        INDArray labelTLXY = labels.get(all(), interval(0,2), all(), all());
        INDArray labelBRXY = labels.get(all(), interval(2,4), all(), all());

        INDArray labelCenterXY = labelTLXY.add(labelBRXY).muli(0.5);  //In terms of grid units
        INDArray labelsCenterXYInGridBox = labelCenterXY.dup(labelCenterXY.ordering());         //[mb, 2, H, W]
        labelsCenterXYInGridBox.subi(Transforms.floor(labelsCenterXYInGridBox,true));

        //Also infer size/scale (label w/h) from (x1,y1,x2,y2) format to (w,h) format
        INDArray labelWHSqrt = labelBRXY.sub(labelTLXY);
        labelWHSqrt = Transforms.sqrt(labelWHSqrt, false);



        // ----- Step 2: apply activation functions to network output activations -----
        //Reshape from [minibatch, B*(5+C), H, W] to [minibatch, B, 5+C, H, W]
        long[] expInputShape = new long[]{mb, b*(5+c), h, w};
        long[] newShape = new long[]{mb, b, 5+c, h, w};
        long newLength = ArrayUtil.prodLong(newShape);
        Preconditions.checkState(Arrays.equals(expInputShape, input.shape()), "Unable to reshape input - input array shape does not match" +
                " expected shape. Expected input shape [minibatch, B*(5+C), H, W]=%s but got input of shape %ndShape. This may be due to an incorrect nOut (layer size/channels)" +
                " for the last convolutional layer in the network. nOut of the last layer must be B*(5+C) where B is the number of" +
                " bounding boxes, and C is the number of object classes. Expected B=%s from network configuration and C=%s from labels", expInputShape, input, b, c);
        INDArray input5 = input.dup('c').reshape('c', mb, b, 5+c, h, w);
        INDArray inputClassesPreSoftmax = input5.get(all(), all(), interval(5, 5+c), all(), all());

        // Sigmoid for x/y centers
        INDArray preSigmoidPredictedXYCenterGrid = input5.get(all(), all(), interval(0,2), all(), all());
        INDArray predictedXYCenterGrid = Transforms.sigmoid(preSigmoidPredictedXYCenterGrid, true); //Not in-place, need pre-sigmoid later

        //Exponential for w/h (for: boxPrior * exp(input))      ->      Predicted WH in grid units (0 to 13 usually)
        INDArray predictedWHPreExp = input5.get(all(), all(), interval(2,4), all(), all());
        INDArray predictedWH = Transforms.exp(predictedWHPreExp, true);
        Broadcast.mul(predictedWH, layerConf().getBoundingBoxes().castTo(predictedWH.dataType()), predictedWH, 1, 2);  //Box priors: [b, 2]; predictedWH: [mb, b, 2, h, w]

        //Apply sqrt to W/H in preparation for loss function
        INDArray predictedWHSqrt = Transforms.sqrt(predictedWH, true);



        // ----- Step 3: Calculate IOU(predicted, labels) to infer 1_ij^obj mask array (for loss function) -----
        //Calculate IOU (intersection over union - aka Jaccard index) - for the labels and predicted values
        IOURet iouRet = calculateIOULabelPredicted(labelTLXY, labelBRXY, predictedWH, predictedXYCenterGrid, maskObjectPresent, maskObjectPresentBool);  //IOU shape: [minibatch, B, H, W]
        INDArray iou = iouRet.getIou();

        //Mask 1_ij^obj: isMax (dimension 1) + apply object present mask. Result: [minibatch, B, H, W]
        //In this mask: 1 if (a) object is present in cell [for each mb/H/W], AND (b) it is the box with the highest
        // IOU of any in the grid cell
        //We also need 1_ij^noobj, which is (a) no object, or (b) object present in grid cell, but this box doesn't
        // have the highest IOU
        INDArray mask1_ij_obj = Nd4j.create(DataType.BOOL, iou.shape(), 'c');
        Nd4j.exec(new IsMax(iou, mask1_ij_obj, 1));
        Nd4j.exec(new BroadcastMulOp(mask1_ij_obj, maskObjectPresentBool, mask1_ij_obj, 0,2,3));
        INDArray mask1_ij_noobj = Transforms.not(mask1_ij_obj);
        mask1_ij_obj = mask1_ij_obj.castTo(input.dataType());



        // ----- Step 4: Calculate confidence, and confidence label -----
        //Predicted confidence: sigmoid (0 to 1)
        //Label confidence: 0 if no object, IOU(predicted,actual) if an object is present
        INDArray labelConfidence = iou.mul(mask1_ij_obj);  //Need to reuse IOU array later. IOU Shape: [mb, B, H, W]
        INDArray predictedConfidencePreSigmoid = input5.get(all(), all(), point(4), all(), all());    //Shape: [mb, B, H, W]
        INDArray predictedConfidence = Transforms.sigmoid(predictedConfidencePreSigmoid, true);



        // ----- Step 5: Loss Function -----
        //One design goal here is to make the loss function configurable. To do this, we want to reshape the activations
        //(and masks) to a 2d representation, suitable for use in DL4J's loss functions

        INDArray mask1_ij_obj_2d = mask1_ij_obj.reshape(mb*b*h*w, 1);  //Must be C order before reshaping
        INDArray mask1_ij_noobj_2d = mask1_ij_obj_2d.rsub(1.0);


        INDArray predictedXYCenter2d = predictedXYCenterGrid.permute(0,1,3,4,2)  //From: [mb, B, 2, H, W] to [mb, B, H, W, 2]
                .dup('c').reshape('c', mb*b*h*w, 2);
        //Don't use INDArray.broadcast(int...) until ND4J issue is fixed: https://github.com/deeplearning4j/nd4j/issues/2066
        //INDArray labelsCenterXYInGridBroadcast = labelsCenterXYInGrid.broadcast(mb, b, 2, h, w);
        //Broadcast labelsCenterXYInGrid from [mb, 2, h, w} to [mb, b, 2, h, w]
        INDArray labelsCenterXYInGridBroadcast = Nd4j.createUninitialized(input.dataType(), new long[]{mb, b, 2, h, w}, 'c');
        for(int i=0; i 0.0) {
                scoreForExamples.addi(fullNetRegTerm);
            }

            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, scoreForExamples);
        }

        double positionLoss = layerConf().getLossPositionScale().computeScore(labelXYCenter2d, predictedXYCenter2d, identity, mask1_ij_obj_2d, false );
        double sizeScaleLoss = layerConf().getLossPositionScale().computeScore(labelWHSqrt2d, predictedWHSqrt2d, identity, mask1_ij_obj_2d, false);
        double confidenceLoss = lossConfidence.computeScore(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_obj_2d, false)
                + lambdaNoObj * lossConfidence.computeScore(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_noobj_2d, false);    //TODO: possible to optimize this?
        double classPredictionLoss = layerConf().getLossClassPredictions().computeScore(classLabels2d, classPredictionsPreSoftmax2d, new ActivationSoftmax(), mask1_ij_obj_2d, false);

        this.score = lambdaCoord * (positionLoss + sizeScaleLoss) +
                confidenceLoss  +
                classPredictionLoss;

        this.score /= getInputMiniBatchSize();

        this.score += fullNetRegTerm;

        if(scoreOnly)
            return null;


        //==============================================================
        // ----- Gradient Calculation (specifically: return dL/dIn -----

        INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'c');
        INDArray epsOut5 = Shape.newShapeNoCopy(epsOut, new long[]{mb, b, 5+c, h, w}, false);
        INDArray epsClassPredictions = epsOut5.get(all(), all(), interval(5, 5+c), all(), all());    //Shape: [mb, b, 5+c, h, w]
        INDArray epsXY = epsOut5.get(all(), all(), interval(0,2), all(), all());
        INDArray epsWH = epsOut5.get(all(), all(), interval(2,4), all(), all());
        INDArray epsC = epsOut5.get(all(), all(), point(4), all(), all());


        //Calculate gradient component from class probabilities (softmax)
        //Shape: [minibatch*h*w, c]
        INDArray gradPredictionLoss2d = layerConf().getLossClassPredictions().computeGradient(classLabels2d, classPredictionsPreSoftmax2d, new ActivationSoftmax(), mask1_ij_obj_2d);
        INDArray gradPredictionLoss5d = gradPredictionLoss2d.dup('c').reshape(mb, b, h, w, c).permute(0,1,4,2,3).dup('c');
        epsClassPredictions.assign(gradPredictionLoss5d);


        //Calculate gradient component from position (x,y) loss - dL_position/dx and dL_position/dy
        INDArray gradXYCenter2d = layerConf().getLossPositionScale().computeGradient(labelXYCenter2d, predictedXYCenter2d, identity, mask1_ij_obj_2d);
        gradXYCenter2d.muli(lambdaCoord);
        INDArray gradXYCenter5d = gradXYCenter2d.dup('c')
                .reshape('c', mb, b, h, w, 2)
                .permute(0,1,4,2,3);   //From: [mb, B, H, W, 2] to [mb, B, 2, H, W]
        gradXYCenter5d = new ActivationSigmoid().backprop(preSigmoidPredictedXYCenterGrid.dup(), gradXYCenter5d).getFirst();
        epsXY.assign(gradXYCenter5d);

        //Calculate gradient component from width/height (w,h) loss - dL_size/dW and dL_size/dW
        //Note that loss function gets sqrt(w) and sqrt(h)
        //gradWHSqrt2d = dL/dsqrt(w) and dL/dsqrt(h)
        INDArray gradWHSqrt2d = layerConf().getLossPositionScale().computeGradient(labelWHSqrt2d, predictedWHSqrt2d, identity, mask1_ij_obj_2d);   //Shape: [mb*b*h*w, 2]
            //dL/dW = dL/dsqrtw * dsqrtw / dW = dL/dsqrtw * 0.5 / sqrt(w)
        INDArray gradWH2d = gradWHSqrt2d.muli(0.5).divi(predictedWHSqrt2d);  //dL/dW and dL/dH, w = pw * exp(tw)
            //dL/dinWH = dL/dW * dW/dInWH = dL/dW * pw * exp(tw)
        INDArray gradWH5d = gradWH2d.dup('c').reshape(mb, b, h, w, 2).permute(0,1,4,2,3);   //To: [mb, b, 2, h, w]
        gradWH5d.muli(predictedWH);
        gradWH5d.muli(lambdaCoord);
        epsWH.assign(gradWH5d);


        //Calculate gradient component from confidence loss... 2 parts (object present, no object present)
        INDArray gradConfidence2dA = lossConfidence.computeGradient(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_obj_2d);
        INDArray gradConfidence2dB = lossConfidence.computeGradient(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_noobj_2d);


        INDArray dLc_dC_2d = gradConfidence2dA.addi(gradConfidence2dB.muli(lambdaNoObj));  //dL/dC; C = sigmoid(tc)
        INDArray dLc_dzc_2d = new ActivationSigmoid().backprop( predictedConfidence2dPreSigmoid, dLc_dC_2d).getFirst();
        //Calculate dL/dtc
        INDArray epsConfidence4d = dLc_dzc_2d.dup('c').reshape('c', mb, b, h, w);   //[mb*b*h*w, 2] to [mb, b, h, w]
        epsC.assign(epsConfidence4d);





        //Note that we ALSO have components to x,y,w,h  from confidence loss (via IOU, which depends on all of these values)
        //that is: dLc/dx, dLc/dy, dLc/dW, dLc/dH
        //For any value v, d(I/U)/dv = (U * dI/dv + I * dU/dv) / U^2

        //Confidence loss: sum squared errors + masking.
        //C == IOU when label present

        //Lc = 1^(obj)*(iou - predicted)^2 + lambdaNoObj * 1^(noobj) * (iou - predicted)^2 -> dLc/diou = 2*1^(obj)*(iou-predicted) + 2 * lambdaNoObj * 1^(noobj) * (iou-predicted) = 2*(iou-predicted) * (1^(obj) + lambdaNoObj * 1^(noobj))
        INDArray twoIOUSubPredicted = iou.subi(predictedConfidence).muli(2.0);  //Shape: [mb, b, h, w]. Note that when an object is present, IOU and confidence are the same. In-place to avoid copy op (iou no longer needed)
        INDArray dLc_dIOU = twoIOUSubPredicted.muli(mask1_ij_noobj.castTo(input.dataType()).muli(lambdaNoObj).addi(mask1_ij_obj));


        INDArray dLc_dxy = Nd4j.createUninitialized(iouRet.dIOU_dxy.dataType(), iouRet.dIOU_dxy.shape(), iouRet.dIOU_dxy.ordering());
        Broadcast.mul(iouRet.dIOU_dxy, dLc_dIOU, dLc_dxy, 0, 1, 3, 4);    //[mb, b, h, w] x [mb, b, 2, h, w]

        INDArray dLc_dwh = Nd4j.createUninitialized(iouRet.dIOU_dwh.dataType(), iouRet.dIOU_dwh.shape(), iouRet.dIOU_dwh.ordering());
        Broadcast.mul(iouRet.dIOU_dwh, dLc_dIOU, dLc_dwh, 0, 1, 3, 4);    //[mb, b, h, w] x [mb, b, 2, h, w]


        //Backprop through the wh and xy activation functions...
        //dL/dW and dL/dH, w = pw * exp(tw), //dL/dinWH = dL/dW * dW/dInWH = dL/dW * pw * exp(in_w)
        //as w = pw * exp(in_w) and dW/din_w = w
        INDArray dLc_din_wh = dLc_dwh.muli(predictedWH);
        INDArray dLc_din_xy = new ActivationSigmoid().backprop(preSigmoidPredictedXYCenterGrid, dLc_dxy).getFirst();    //Shape: same as subset of input... [mb, b, 2, h, w]

        //Finally, apply masks: dLc_dwh and dLc_dxy should be 0 if no object is present in that box
        //Apply mask 1^obj_ij with shape [mb, b, h, w]
        Broadcast.mul(dLc_din_wh, mask1_ij_obj, dLc_din_wh, 0, 1, 3, 4);
        Broadcast.mul(dLc_din_xy, mask1_ij_obj, dLc_din_xy, 0, 1, 3, 4);


        epsWH.addi(dLc_din_wh);
        epsXY.addi(dLc_din_xy);

        if(!nchw)
            epsOut = epsOut.permute(0,2,3,1);   //NCHW to NHWC

        return epsOut;
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        boolean nchw = layerConf().getFormat() == CNN2DFormat.NCHW;
        return YoloUtils.activate(layerConf().getBoundingBoxes(), input, nchw, workspaceMgr);
    }

    @Override
    public Layer clone() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public boolean needsLabels() {
        return true;
    }

    @Override
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.fullNetRegTerm = fullNetRegTerm;

        computeBackpropGradientAndScore(workspaceMgr, true, false);
        return score();
    }

    @Override
    public double score(){
        return score;
    }

    /**
     * Calculate IOU(truth, predicted) and gradients. Returns 5d arrays [mb, b, 2, H, W]
     * ***NOTE: All labels - and predicted values - are in terms of grid units - 0 to 12 usually, with default config ***
     *
     * @param labelTL   4d [mb, 2, H, W], label top/left (x,y) in terms of grid boxes
     * @param labelBR   4d [mb, 2, H, W], label bottom/right (x,y) in terms of grid boxes
     * @param predictedWH 5d [mb, b, 2, H, W] - predicted H/W in terms of number of grid boxes.
     * @param predictedXYinGridBox 5d [mb, b, 2, H, W] - predicted X/Y in terms of number of grid boxes. Values 0 to 1, center box value being 0.5
     * @param objectPresentMask 3d [mb, H, W] - mask array, for objects present (1) or not (0) in grid cell
     * @return IOU and gradients
     */
    private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labelBR, INDArray predictedWH, INDArray predictedXYinGridBox, INDArray objectPresentMask, INDArray objectPresentMaskBool){

        long mb = labelTL.size(0);
        long h = labelTL.size(2);
        long w = labelTL.size(3);
        long b = predictedWH.size(1);

        INDArray labelWH = labelBR.sub(labelTL);                //4d [mb, 2, H, W], label W/H in terms of number of grid boxes

        long gridH = labelTL.size(2);
        long gridW = labelTL.size(3);
        //Add grid positions to the predicted XY values (to get predicted XY in terms of grid cell units in image,
        // from (0 to 1 in grid cell) format)
        INDArray linspaceX = Nd4j.linspace(0, gridW-1, gridW, predictedWH.dataType());
        INDArray linspaceY = Nd4j.linspace(0, gridH-1, gridH, predictedWH.dataType());
        INDArray grid = Nd4j.createUninitialized(predictedWH.dataType(), new long[]{2, gridH, gridW}, 'c');
        INDArray gridX = grid.get(point(0), all(), all());
        INDArray gridY = grid.get(point(1), all(), all());
        Broadcast.copy(gridX, linspaceX, gridX, 1);
        Broadcast.copy(gridY, linspaceY, gridY, 0);

        //Calculate X/Y position overall (in grid box units) from "position in current grid box" format
        INDArray predictedXY = predictedXYinGridBox.ulike();;
        Broadcast.add(predictedXYinGridBox, grid, predictedXY, 2,3,4); // [2, H, W] to [mb, b, 2, H, W]


        INDArray halfWH = predictedWH.mul(0.5);
        INDArray predictedTL_XY = halfWH.rsub(predictedXY);     //xy - 0.5 * wh
        INDArray predictedBR_XY = halfWH.add(predictedXY);      //xy + 0.5 * wh

        INDArray maxTL = predictedTL_XY.ulike();   //Shape: [mb, b, 2, H, W]
        Broadcast.max(predictedTL_XY, labelTL, maxTL, 0, 2, 3, 4);
        INDArray minBR = predictedBR_XY.ulike();
        Broadcast.min(predictedBR_XY, labelBR, minBR, 0, 2, 3, 4);

        INDArray diff = minBR.sub(maxTL);
        INDArray intersectionArea = diff.prod(2);   //[mb, b, 2, H, W] to [mb, b, H, W]
        Broadcast.mul(intersectionArea, objectPresentMask, intersectionArea, 0, 2, 3);

        //Need to mask the calculated intersection values, to avoid returning non-zero values when intersection is actually 0
        //No intersection if: xP + wP/2 < xL - wL/2 i.e., BR_xPred < TL_xLab   OR  TL_xPred > BR_xLab (similar for Y axis)
        //Here, 1 if intersection exists, 0 otherwise. This is doing x/w and y/h simultaneously
        INDArray noIntMask1 = Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering());
        INDArray noIntMask2 = Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering());
        //Does both x and y on different dims
        Broadcast.lte(predictedBR_XY, labelTL, noIntMask1, 0, 2, 3, 4);  //Predicted BR <= label TL
        Broadcast.gte(predictedTL_XY, labelBR, noIntMask2, 0, 2, 3, 4);  //predicted TL >= label BR

        noIntMask1 = Transforms.or(noIntMask1.get(all(), all(), point(0), all(), all()), noIntMask1.get(all(), all(), point(1), all(), all()) );    //Shape: [mb, b, H, W]. Values 1 if no intersection
        noIntMask2 = Transforms.or(noIntMask2.get(all(), all(), point(0), all(), all()), noIntMask2.get(all(), all(), point(1), all(), all()) );
        INDArray noIntMask = Transforms.or(noIntMask1, noIntMask2 );

        INDArray intMask = Transforms.not(noIntMask); //Values 0 if no intersection
        Broadcast.mul(intMask, objectPresentMaskBool, intMask, 0, 2, 3);

        //Mask the intersection area: should be 0 if no intersection
        intMask = intMask.castTo(predictedWH.dataType());
        intersectionArea.muli(intMask);


        //Next, union area is simple: U = A1 + A2 - intersection
        INDArray areaPredicted = predictedWH.prod(2);   //[mb, b, 2, H, W] to [mb, b, H, W]
        Broadcast.mul(areaPredicted, objectPresentMask, areaPredicted, 0,2,3);
        INDArray areaLabel = labelWH.prod(1);           //[mb, 2, H, W] to [mb, H, W]

        INDArray unionArea = Broadcast.add(areaPredicted, areaLabel, areaPredicted.dup(), 0, 2, 3);
        unionArea.subi(intersectionArea);
        unionArea.muli(intMask);

        INDArray iou = intersectionArea.div(unionArea);
        BooleanIndexing.replaceWhere(iou, 0.0, Conditions.isNan()); //0/0 -> NaN -> 0

        //Apply the "object present" mask (of shape [mb, h, w]) - this ensures IOU is 0 if no object is present
        Broadcast.mul(iou, objectPresentMask, iou, 0, 2, 3);

        //Finally, calculate derivatives:
        INDArray maskMaxTL = Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering());    //1 if predicted Top/Left is max, 0 otherwise
        Broadcast.gt(predictedTL_XY, labelTL, maskMaxTL, 0, 2, 3, 4);   // z = x > y
        maskMaxTL = maskMaxTL.castTo(predictedWH.dataType());

        INDArray maskMinBR = Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering());    //1 if predicted Top/Left is max, 0 otherwise
        Broadcast.lt(predictedBR_XY, labelBR, maskMinBR, 0, 2, 3, 4);   // z = x < y
        maskMinBR = maskMinBR.castTo(predictedWH.dataType());

        //dI/dx = lambda * (1^(min(x1+w1/2) - 1^(max(x1-w1/2))
        //dI/dy = omega * (1^(min(y1+h1/2) - 1^(max(y1-h1/2))
        //omega = min(x1+w1/2,x2+w2/2) - max(x1-w1/2,x2+w2/2)       i.e., from diff = minBR.sub(maxTL), which has shape [mb, b, 2, h, w]
        //lambda = min(y1+h1/2,y2+h2/2) - max(y1-h1/2,y2+h2/2)
        INDArray dI_dxy = maskMinBR.sub(maskMaxTL);              //Shape: [mb, b, 2, h, w]
        INDArray dI_dwh = maskMinBR.addi(maskMaxTL).muli(0.5);    //Shape: [mb, b, 2, h, w]

        dI_dxy.get(all(), all(), point(0), all(), all()).muli(diff.get(all(), all(), point(1), all(), all()));
        dI_dxy.get(all(), all(), point(1), all(), all()).muli(diff.get(all(), all(), point(0), all(), all()));

        dI_dwh.get(all(), all(), point(0), all(), all()).muli(diff.get(all(), all(), point(1), all(), all()));
        dI_dwh.get(all(), all(), point(1), all(), all()).muli(diff.get(all(), all(), point(0), all(), all()));

        //And derivatives WRT IOU:
        INDArray uPlusI = unionArea.add(intersectionArea);
        INDArray u2 = unionArea.mul(unionArea);
        INDArray uPlusIDivU2 = uPlusI.div(u2);   //Shape: [mb, b, h, w]
        BooleanIndexing.replaceWhere(uPlusIDivU2, 0.0, Conditions.isNan());     //Handle 0/0

        INDArray dIOU_dxy = Nd4j.createUninitialized(predictedWH.dataType(), new long[]{mb, b, 2, h, w}, 'c');
        Broadcast.mul(dI_dxy, uPlusIDivU2, dIOU_dxy, 0, 1, 3, 4);   //[mb, b, h, w] x [mb, b, 2, h, w]

        INDArray predictedHW = Nd4j.createUninitialized(predictedWH.dataType(), new long[]{mb, b, 2, h, w}, predictedWH.ordering());
        //Next 2 lines: permuting the order... WH to HW along dimension 2
        predictedHW.get(all(), all(), point(0), all(), all()).assign(predictedWH.get(all(), all(), point(1), all(), all()));
        predictedHW.get(all(), all(), point(1), all(), all()).assign(predictedWH.get(all(), all(), point(0), all(), all()));

        INDArray Ihw = predictedHW.ulike();;
        Broadcast.mul(predictedHW, intersectionArea, Ihw, 0, 1, 3, 4 );    //Predicted_wh: [mb, b, 2, h, w]; intersection: [mb, b, h, w]

        INDArray dIOU_dwh = Nd4j.createUninitialized(predictedHW.dataType(), new long[]{mb, b, 2, h, w}, 'c');
        Broadcast.mul(dI_dwh, uPlusI, dIOU_dwh, 0, 1, 3, 4);
        dIOU_dwh.subi(Ihw);
        Broadcast.div(dIOU_dwh, u2, dIOU_dwh, 0, 1, 3, 4);
        BooleanIndexing.replaceWhere(dIOU_dwh, 0.0, Conditions.isNan());     //Handle division by 0 (due to masking, etc)

        return new IOURet(iou, dIOU_dxy, dIOU_dwh);
    }


    @AllArgsConstructor
    @Data
    private static class IOURet {
        private INDArray iou;
        private INDArray dIOU_dxy;
        private INDArray dIOU_dwh;

    }

    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr){

        //TODO
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Pair gradientAndScore() {
        return new Pair<>(gradient(), score());
    }

    @Override
    public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
        this.fullNetRegTerm = fullNetRegTerm;
        return computeBackpropGradientAndScore(workspaceMgr, false, true);
    }

    @Override
    public double f1Score(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int numLabels() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSetIterator iter) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int[] predict(INDArray examples) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public List predict(DataSet dataSet) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, int[] labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public void clearNoiseWeightParams() {
        //No op
    }

    /** @see YoloUtils#getPredictedObjects(INDArray, INDArray, double, double) */
    public List getPredictedObjects(INDArray networkOutput, double threshold){
        return YoloUtils.getPredictedObjects(layerConf().getBoundingBoxes(), networkOutput, threshold, 0.0);
    }

    /**
     * Get the confidence matrix (confidence for all x/y positions) for the specified bounding box, from the network
     * output activations array
     *
     * @param networkOutput Network output activations
     * @param example       Example number, in minibatch
     * @param bbNumber      Bounding box number
     * @return Confidence matrix
     */
    public INDArray getConfidenceMatrix(INDArray networkOutput, int example, int bbNumber){

        //Input format: [minibatch, 5B+C, H, W], with order [x,y,w,h,c]
        //Therefore: confidences are at depths 4 + bbNumber * 5

        INDArray conf = networkOutput.get(point(example), point(4+bbNumber*5), all(), all());
        return conf;
    }

    /**
     * Get the probability matrix (probability of the specified class, assuming an object is present, for all x/y
     * positions), from the network output activations array
     *
     * @param networkOutput Network output activations
     * @param example       Example number, in minibatch
     * @param classNumber   Class number
     * @return Confidence matrix
     */
    public INDArray getProbabilityMatrix(INDArray networkOutput, int example, int classNumber){
        //Input format: [minibatch, 5B+C, H, W], with order [x,y,w,h,c]
        //Therefore: probabilities for class I is at depths 5B + classNumber

        val bbs = layerConf().getBoundingBoxes().size(0);
        INDArray conf = networkOutput.get(point(example), point(5*bbs + classNumber), all(), all());
        return conf;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy