All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.Nms.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.tensor.Tensor

/**
 * Non-Maximum Suppression (nms) for Object Detection
 * The goal of nms is to collapse groups of overlapping detections clustered around the same
 * real location, ideally keeping only one detection per object
 */
class Nms extends Serializable {

  // Scratch buffers reused across calls. They are @transient (not serialized) and are
  // re-created on demand by init() whenever they are null or too small.
  @transient private var areas: Array[Float] = _
  @transient private var sortedScores: Tensor[Float] = _
  @transient private var sortedInds: Tensor[Float] = _
  @transient private var sortIndBuffer: Array[Int] = _
  @transient private var suppressed: Array[Int] = _

  /**
   * Make sure the scratch buffers can hold at least `size` entries and that the first
   * `size` suppression flags are cleared.
   * @param size number of entries the caller is going to use
   */
  private def init(size: Int): Unit = {
    if (suppressed == null || suppressed.length < size) {
      // freshly allocated arrays are already zero-filled
      suppressed = new Array[Int](size)
      sortIndBuffer = new Array[Int](size)
      areas = new Array[Float](size)
    } else {
      // buffers are big enough: only the flags left over from the previous call need a reset
      java.util.Arrays.fill(suppressed, 0, size, 0)
    }
    if (sortedScores == null) {
      sortedScores = Tensor[Float]
      sortedInds = Tensor[Float]
    }
  }

  /**
   * 1. first sort the scores from highest to lowest and get indices
   * 2. for the bbox of first index,
   * get the overlap between this box and the remaining bboxes
   * put the first index to result buffer
   * 3. update the indices by keeping those bboxes with overlap less than thresh
   * 4. repeat 2 and 3 until the indices are empty
   *
   * The indices written to `indices` are 1-based.
   * @param scores  score tensor
   * @param boxes   box tensor, with the size N*4
   * @param thresh  overlap thresh
   * @param indices buffer to store indices after nms,
   *                must hold at least `scores.nElement()` entries
   * @param sorted whether the scores are sorted
   * @param orderWithBBox whether return indices in the order with original bbox index
   * @return the length of indices after nms (0 when `scores` is empty)
   * @throws IllegalArgumentException if `indices` is too small or `boxes.size(2)` is not 4
   */
  def nms(scores: Tensor[Float], boxes: Tensor[Float], thresh: Float,
    indices: Array[Int], sorted: Boolean = false,
    orderWithBBox: Boolean = false): Int = {
    val total = scores.nElement()
    if (total == 0) return 0
    require(indices.length >= total && boxes.size(2) == 4,
      s"indices must hold at least $total entries (got ${indices.length}) " +
        s"and boxes must be N*4 (got N*${boxes.size(2)})")

    val boxCount = boxes.size(1)
    // getAreas writes one entry per box row, so the buffers must cover the box count as well
    // as the score count; sizing by the scores alone could overflow `areas`.
    init(math.max(total, boxCount))
    val boxArray = boxes.storage().array()
    val offset = boxes.storageOffset() - 1
    val rowLength = boxes.stride(1)
    getAreas(boxArray, offset, rowLength, boxCount, areas)
    // indices start from 0
    val orderLength = if (!sorted) {
      getSortedScoreInds(scores, sortIndBuffer)
    } else {
      // scores are already in order: visit the boxes as they come
      var n = 0
      while (n < total) {
        sortIndBuffer(n) = n
        n += 1
      }
      total
    }

    var indexLength = 0
    var i = 0
    while (i < orderLength) {
      val curInd = sortIndBuffer(i)
      if (suppressed(curInd) != 1) {
        // the best remaining box is kept ...
        indices(indexLength) = curInd + 1
        indexLength += 1
        // ... and every lower ranked box that overlaps it too much is suppressed
        var k = i + 1
        while (k < orderLength) {
          val other = sortIndBuffer(k)
          if (suppressed(other) != 1 &&
            isOverlapRatioGtThresh(boxArray, offset, rowLength,
              areas, curInd, other, thresh)) {
            suppressed(other) = 1
          }
          k += 1
        }
      }
      i += 1
    }

    if (orderWithBBox) {
      // rewrite the result in original box order, using the suppression flags
      var kept = 0
      var box = 0
      while (box < orderLength) {
        if (suppressed(box) == 0) {
          indices(kept) = box + 1
          kept += 1
        }
        box += 1
      }
    }
    indexLength
  }

  /**
   * Check whether box `curInd` may be kept, i.e. whether it does not overlap any of the
   * boxes already kept by more than `adaptiveThresh`.
   * @param boxArray backing array of the box tensor
   * @param offset 0-based position of the first box inside `boxArray`
   * @param rowLength distance in `boxArray` between two consecutive boxes
   * @param areas area of each box, indexed by 0-based box index
   * @param curInd 0-based index of the candidate box
   * @param adaptiveThresh overlap threshold to compare against
   * @param indices 1-based indices of the boxes kept so far
   * @param indexLength number of valid entries in `indices`
   * @param normalized whether the overlap is computed without the "+ 1" size correction
   * @return true if the candidate passes the overlap test against every kept box
   */
  def isKeepCurIndex(boxArray: Array[Float], offset: Int, rowLength: Int, areas: Array[Float],
    curInd: Int, adaptiveThresh: Float, indices: Array[Int], indexLength: Int,
    normalized: Boolean): Boolean = {
    var keep = true
    var k = 0
    // stop at the first kept box that overlaps too much
    while (keep && k < indexLength) {
      keep = !isOverlapRatioGtThresh(boxArray, offset, rowLength, areas, curInd,
        indices(k) - 1, adaptiveThresh, normalized)
      k += 1
    }
    keep
  }

  /**
   * Nms variant which visits the candidates from highest to lowest score and keeps a
   * candidate only if it passes the overlap test against all boxes kept so far.
   *
   * The indices written to `indices` are 1-based.
   * @param scores score tensor
   * @param boxes box tensor, one box (x1, y1, x2, y2) per row
   * @param nmsThresh initial overlap thresh
   * @param scoreThresh when greater than 0, candidates with a lower score are ignored
   * @param indices buffer to store indices after nms
   * @param topk when greater than 0, at most topk highest scoring candidates are considered
   * @param eta the overlap thresh is multiplied by eta after each kept box,
   *            as long as eta is less than 1 and the thresh is greater than 0.5
   * @param normalized whether areas and overlaps are computed without the "+ 1" size correction
   * @return the length of indices after nms (0 when `scores` is empty)
   */
  def nmsFast(scores: Tensor[Float], boxes: Tensor[Float], nmsThresh: Float, scoreThresh: Float,
    indices: Array[Int], topk: Int = -1, eta: Float = 1, normalized: Boolean = true): Int = {
    // same guard as nms: nothing to do, and nothing to read from boxes, without scores
    if (scores.nElement() == 0) return 0

    val boxCount = boxes.size(1)
    // getAreas writes one entry per box row, so the buffers must cover the box count too
    init(math.max(scores.nElement(), boxCount))
    val boxArray = boxes.storage().array()
    val offset = boxes.storageOffset() - 1
    val rowLength = boxes.stride(1)
    getAreas(boxArray, offset, rowLength, boxCount, areas, normalized)
    var adaptiveThresh = nmsThresh
    val orderLength = getSortedScoreInds(scores, sortIndBuffer, scoreThresh, topk)
    var indexLength = 0
    var i = 0
    while (i < orderLength) {
      val curInd = sortIndBuffer(i)
      val keep = isKeepCurIndex(boxArray, offset, rowLength, areas, curInd,
        adaptiveThresh, indices, indexLength, normalized)
      if (keep) {
        indices(indexLength) = curInd + 1
        indexLength += 1
        // tighten the threshold after each kept box
        if (eta < 1 && adaptiveThresh > 0.5) {
          adaptiveThresh *= eta
        }
      }
      i += 1
    }
    indexLength
  }

  /**
   * Fill `resultBuffer` with the 0-based indices of the candidates, ordered from highest to
   * lowest score.
   *
   * NOTE(review): when `scoreThresh` is greater than 0 the scores below it are replaced by 0
   * through `scores.apply1`, which appears to overwrite the caller's tensor - confirm.
   * @param scores score tensor
   * @param resultBuffer buffer receiving the sorted 0-based indices
   * @param scoreThresh when greater than 0, only scores not below it are counted
   * @param topK when greater than 0, upper bound on the number of returned indices
   * @return the number of indices written to `resultBuffer`
   */
  private def getSortedScoreInds(scores: Tensor[Float], resultBuffer: Array[Int],
    scoreThresh: Float = 0, topK: Int = -1): Int = {
    var num = 0
    if (scoreThresh > 0) {
      scores.apply1(x => if (x < scoreThresh) {
        0f
      } else {
        num += 1
        x
      })
    } else {
      num = scores.nElement()
    }
    if (topK > 0) num = Math.min(topK, num)
    if (num > 0) {
      // note that when the score is the same,
      // the order of the indices are different in python and here
      scores.topk(num, dim = 1, increase = false, result = sortedScores,
        indices = sortedInds
      )
      var i = 0
      while (i < num) {
        // sortedInds is read at 1-based positions and holds 1-based indices
        resultBuffer(i) = sortedInds.valueAt(i + 1).toInt - 1
        i += 1
      }
    }
    num
  }

  /**
   * Compute the area of the first `total` boxes.
   * @param boxesArr backing array of the box tensor, each box stored as x1, y1, x2, y2
   * @param offset 0-based position of the first box inside `boxesArr`
   * @param rowLength distance in `boxesArr` between two consecutive boxes
   * @param total number of boxes
   * @param areas output buffer, must hold at least `total` entries
   * @param normalized if true the area is (x2 - x1) * (y2 - y1),
   *                   otherwise 1 is added to both width and height
   * @return `areas`
   */
  private def getAreas(boxesArr: Array[Float], offset: Int, rowLength: Int, total: Int,
    areas: Array[Float], normalized: Boolean = false): Array[Float] = {
    var i = 0
    while (i < total) {
      val base = offset + rowLength * i
      val x1 = boxesArr(base)
      val y1 = boxesArr(base + 1)
      val x2 = boxesArr(base + 2)
      val y2 = boxesArr(base + 3)
      areas(i) = if (!normalized) {
        (x2 - x1 + 1) * (y2 - y1 + 1)
      } else {
        // If bbox is within range [0, 1].
        (x2 - x1) * (y2 - y1)
      }
      i += 1
    }
    areas
  }

  /**
   * Check whether the overlap ratio (intersection area over union area) of two boxes is
   * greater than `thresh`. Boxes whose intersection has a negative width or height are
   * reported as not overlapping.
   * @param boxArr backing array of the box tensor, each box stored as x1, y1, x2, y2
   * @param offset 0-based position of the first box inside `boxArr`
   * @param rowLength distance in `boxArr` between two consecutive boxes
   * @param areas area of each box, indexed by 0-based box index
   * @param ind 0-based index of the first box
   * @param ind2 0-based index of the second box
   * @param thresh overlap thresh
   * @param normalized if false, 1 is added to the width and height of the intersection
   * @return true if the overlap ratio is greater than `thresh`
   */
  private def isOverlapRatioGtThresh(boxArr: Array[Float], offset: Int, rowLength: Int,
    areas: Array[Float], ind: Int, ind2: Int, thresh: Float,
    normalized: Boolean = false): Boolean = {
    val base = offset + rowLength * ind
    val base2 = offset + rowLength * ind2

    // horizontal extent of the intersection: smallest right edge minus largest left edge
    val right = math.min(boxArr(base2 + 2), boxArr(base + 2))
    val left = math.max(boxArr(base2), boxArr(base))
    val w = if (normalized) right - left else right - left + 1
    if (w < 0) {
      false
    } else {
      // vertical extent of the intersection: smallest bottom edge minus largest top edge
      val bottom = math.min(boxArr(base2 + 3), boxArr(base + 3))
      val top = math.max(boxArr(base2 + 1), boxArr(base + 1))
      val h = if (normalized) bottom - top else bottom - top + 1
      if (h < 0) {
        false
      } else {
        val overlap = w * h
        overlap / ((areas(ind2) + areas(ind)) - overlap) > thresh
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy