Please wait. This may take a few minutes ...
Downloading a project requires significant server resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
Once you have bought this project, you can download and modify it as often as you want.
org.deeplearning4j.nn.layers.objdetect.YoloUtils Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.layers.objdetect;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import static org.nd4j.linalg.indexing.NDArrayIndex.*;
public class YoloUtils {
/**
 * Applies the YOLO output activation functions to {@code input}, treating it as NCHW-ordered activations.
 * For NHWC-ordered activations, use one of the overloads that takes an explicit {@code nchw} flag.
 *
 * @param boundingBoxPriors bounding box priors, forwarded unchanged
 * @param input             pre-activation network output, forwarded unchanged
 * @return the result of {@link #activate(INDArray, INDArray, boolean)} called with {@code nchw == true}
 */
public static INDArray activate(INDArray boundingBoxPriors, INDArray input) {
    final boolean nchw = true;
    return activate(boundingBoxPriors, input, nchw);
}
/**
 * Applies the YOLO output activation functions using the workspace manager returned by
 * {@link LayerWorkspaceMgr#noWorkspaces()}.
 *
 * @param boundingBoxPriors bounding box priors, forwarded unchanged
 * @param input             pre-activation network output, forwarded unchanged
 * @param nchw              {@code true} if {@code input} is in NCHW order, {@code false} if it is in NHWC order
 * @return the result of {@link #activate(INDArray, INDArray, boolean, LayerWorkspaceMgr)}
 */
public static INDArray activate(INDArray boundingBoxPriors, INDArray input, boolean nchw) {
    LayerWorkspaceMgr workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
    return activate(boundingBoxPriors, input, nchw, workspaceMgr);
}
/**
 * Applies the YOLO output activation functions to NCHW-ordered activations, using the supplied
 * workspace manager.
 *
 * @param boundingBoxPriors bounding box priors (annotated {@code @NonNull}), forwarded unchanged
 * @param input             pre-activation network output (annotated {@code @NonNull}), forwarded unchanged
 * @param layerWorkspaceMgr workspace manager, forwarded unchanged
 * @return the result of {@link #activate(INDArray, INDArray, boolean, LayerWorkspaceMgr)} called with
 *         {@code nchw == true}
 */
public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) {
    final boolean nchw = true;
    return activate(boundingBoxPriors, input, nchw, layerWorkspaceMgr);
}
/**
 * Applies the YOLO output activation functions to the raw network output: sigmoid on the box centre X/Y,
 * prior-scaled exp on the box width/height, sigmoid on the confidence, and softmax over the class scores.
 *
 * @param boundingBoxPriors bounding box priors; its size along dimension 0 is used as the number of boxes per
 *                          grid cell, and it is broadcast against the width/height channels
 * @param input             pre-activation network output, rank 4; channels hold b * (5 + C) values per grid cell
 * @param nchw              {@code true} if {@code input} is in NCHW order; {@code false} if it is in NHWC order,
 *                          in which case it is permuted to NCHW for processing and the result is permuted back
 * @param layerWorkspaceMgr used to create the result array as {@link ArrayType#ACTIVATIONS}
 * @return the activated values, in the same dimension order as {@code input}
 */
public static INDArray activate(@NonNull INDArray boundingBoxPriors, @NonNull INDArray input, boolean nchw, LayerWorkspaceMgr layerWorkspaceMgr){
    //All processing below is done in NCHW order
    final INDArray nchwInput = nchw ? input : input.permute(0,3,1,2);   //NHWC to NCHW

    final long minibatch = nchwInput.size(0);
    final long height = nchwInput.size(2);
    final long width = nchwInput.size(3);
    final long numBoxes = boundingBoxPriors.size(0);
    //nchwInput.size(1) == numBoxes * (5 + numClasses)  ->  numClasses = (nchwInput.size(1) / numBoxes) - 5
    final long numClasses = nchwInput.size(1) / numBoxes - 5;
    final long valuesPerBox = 5 + numClasses;

    INDArray result = layerWorkspaceMgr.create(ArrayType.ACTIVATIONS, nchwInput.dataType(), nchwInput.shape(), 'c');
    INDArray result5d = result.reshape('c', minibatch, numBoxes, valuesPerBox, height, width);

    //Working copy of the input; the activations below are applied to sub-arrays of it with dup == false
    //NOTE(review): presumably these sub-arrays are views, so the ops modify 'working' itself - the later
    //result.assign(working) relies on that; confirm against the ND4J docs
    INDArray working = nchwInput.dup('c');
    INDArray working5d = working.reshape('c', minibatch, numBoxes, valuesPerBox, height, width);

    //Box centre X/Y within the grid cell (channels 0-1): sigmoid
    Transforms.sigmoid(working5d.get(all(), all(), interval(0,2), all(), all()), false);

    //Box width/height (channels 2-3): prior * exp(input)
    INDArray boxWH = Transforms.exp(working5d.get(all(), all(), interval(2,4), all(), all()), false);
    INDArray priors = boundingBoxPriors.castTo(nchwInput.dataType());
    Broadcast.mul(boxWH, priors, boxWH, 1, 2);   //Box priors: [b, 2]; boxWH: [mb, b, 2, h, w]

    //Confidence (channel 4): sigmoid
    Transforms.sigmoid(working5d.get(all(), all(), point(4), all(), all()), false);

    //Copy everything across; the class channels are overwritten with the softmax output below
    result.assign(working);

    //Class scores (channels 5 to 5+C): softmax over the class dimension
    //TODO OPTIMIZE?
    INDArray classScores2d = working5d.get(all(), all(), interval(5, valuesPerBox), all(), all())
            .permute(0,1,3,4,2)     //[mb, b, c, h, w] to [mb, b, h, w, c]
            .dup('c')
            .reshape('c', new long[]{minibatch*numBoxes*height*width, numClasses});
    Transforms.softmax(classScores2d, false);
    INDArray classProbs5d = classScores2d.reshape('c', minibatch, numBoxes, height, width, numClasses)
            .permute(0, 1, 4, 2, 3);    //Back to [mb, b, c, h, w]
    result5d.get(all(), all(), interval(5, valuesPerBox), all(), all()).assign(classProbs5d);

    return nchw ? result : result.permute(0,2,3,1);     //NCHW to NHWC if required
}
/**
 * Returns the length of the overlap between the line segments [x1, x2] and [x3, x4].
 * The endpoints of each segment are used as given; they are neither validated nor reordered.
 *
 * @param x1 start of the first segment
 * @param x2 end of the first segment
 * @param x3 start of the second segment
 * @param x4 end of the second segment
 * @return 0 if one segment ends before the other starts, otherwise the smaller end minus the larger start
 */
public static double overlap(double x1, double x2, double x3, double x4) {
    final double laterStart;
    if (x3 < x1) {
        //Second segment starts first: it overlaps only if it reaches the start of the first
        if (x4 < x1) {
            return 0;
        }
        laterStart = x1;
    } else {
        //First segment starts first (or both start together): it overlaps only if it reaches the second
        if (x2 < x3) {
            return 0;
        }
        laterStart = x3;
    }
    return Math.min(x2, x4) - laterStart;
}
/**
 * Returns intersection over union (IOU) between o1 and o2.
 * Each object is treated as an axis-aligned box given by its centre X/Y and its width/height.
 *
 * @param o1 first detected object
 * @param o2 second detected object
 * @return intersection area divided by union area; 0 if the union area is 0 (for example two zero-area
 *         boxes), so that 0/0 never produces NaN
 */
public static double iou(DetectedObject o1, DetectedObject o2) {
    double halfW1 = o1.getWidth() / 2;
    double halfH1 = o1.getHeight() / 2;
    double halfW2 = o2.getWidth() / 2;
    double halfH2 = o2.getHeight() / 2;

    double x1min = o1.getCenterX() - halfW1;
    double x1max = o1.getCenterX() + halfW1;
    double y1min = o1.getCenterY() - halfH1;
    double y1max = o1.getCenterY() + halfH1;

    double x2min = o2.getCenterX() - halfW2;
    double x2max = o2.getCenterX() + halfW2;
    double y2min = o2.getCenterY() - halfH2;
    double y2max = o2.getCenterY() + halfH2;

    //Intersection area is the product of the overlaps along each axis
    double ow = overlap(x1min, x1max, x2min, x2max);
    double oh = overlap(y1min, y1max, y2min, y2max);
    double intersection = ow * oh;

    double union = o1.getWidth() * o1.getHeight() + o2.getWidth() * o2.getHeight() - intersection;
    if (union == 0.0) {
        //Degenerate boxes: avoid returning NaN from 0/0
        return 0.0;
    }
    return intersection / union;
}
/**
 * Performs non-maximum suppression (NMS) on objects, using their IOU with threshold to match pairs.
 * <p>
 * An object is removed when some other object still present in the list has the same predicted class,
 * a strictly higher confidence, and an IOU with it greater than {@code iouThreshold}. Objects with equal
 * confidence never suppress each other, and an object that has already been suppressed does not suppress
 * others.
 *
 * @param objects      detected objects; modified in place (the list must support {@code set} and
 *                     iterator removal). Null entries are ignored and removed.
 * @param iouThreshold IOU above which a lower-confidence object of the same class is suppressed
 */
public static void nms(List<DetectedObject> objects, double iouThreshold) {
    for (int i = 0; i < objects.size(); i++) {
        DetectedObject candidate = objects.get(i);
        if (candidate == null) {
            continue;
        }
        for (int j = 0; j < objects.size(); j++) {
            DetectedObject other = objects.get(j);
            if (other != null
                    && candidate.getPredictedClass() == other.getPredictedClass()
                    && candidate.getConfidence() < other.getConfidence()
                    && iou(candidate, other) > iouThreshold) {
                //Mark as suppressed; a single match is enough, so stop scanning
                objects.set(i, null);
                break;
            }
        }
    }

    //Drop the entries marked as suppressed
    Iterator<DetectedObject> it = objects.iterator();
    while (it.hasNext()) {
        if (it.next() == null) {
            it.remove();
        }
    }
}
/**
* Given the network output and a detection threshold (in range 0 to 1) determine the objects detected by
* the network.
* Supports minibatches - the returned {@link DetectedObject} instances have an example number index.
*
* Note that the dimensions are grid cell units - for example, with 416x416 input, 32x downsampling by the network
* (before getting to the Yolo2OutputLayer) we have 13x13 grid cells (each corresponding to 32 pixels in the input
* image). Thus, a centerX of 5.5 would be xPixels=5.5x32 = 176 pixels from left. Widths and heights are similar:
* in this example, a with of 13 would be the entire image (416 pixels), and a height of 6.5 would be 6.5/13 = 0.5
* of the image (208 pixels).
*
* @param boundingBoxPriors as given to Yolo2OutputLayer
* @param networkOutput 4d activations out of the network
* @param confThreshold Detection threshold, in range 0.0 (least strict) to 1.0 (most strict). Objects are returned
* where predicted confidence is >= confThreshold
* @param nmsThreshold passed to {@link #nms(List, double)} (0 == disabled) as the threshold for intersection over union (IOU)
* @return List of detected objects
*/
public static List getPredictedObjects(INDArray boundingBoxPriors, INDArray networkOutput, double confThreshold, double nmsThreshold){
if(networkOutput.rank() != 4){
throw new IllegalStateException("Invalid network output activations array: should be rank 4. Got array "
+ "with shape " + Arrays.toString(networkOutput.shape()));
}
if(confThreshold < 0.0 || confThreshold > 1.0){
throw new IllegalStateException("Invalid confidence threshold: must be in range [0,1]. Got: " + confThreshold);
}
//Activations format: [mb, 5b+c, h, w]
long mb = networkOutput.size(0);
long h = networkOutput.size(2);
long w = networkOutput.size(3);
long b = boundingBoxPriors.size(0);
long c = (networkOutput.size(1)/b)-5; //input.size(1) == b * (5 + C) -> C = (input.size(1)/b) - 5
//Reshape from [minibatch, B*(5+C), H, W] to [minibatch, B, 5+C, H, W] to [minibatch, B, 5, H, W]
INDArray output5 = networkOutput.dup('c').reshape(mb, b, 5+c, h, w);
INDArray predictedConfidence = output5.get(all(), all(), point(4), all(), all()); //Shape: [mb, B, H, W]
INDArray softmax = output5.get(all(), all(), interval(5, 5+c), all(), all());
List out = new ArrayList<>();
for( int i=0; i 0) {
nms(out, nmsThreshold);
}
return out;
}
}