All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.nn.Threshold.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.{IdentityOutputShape, TensorModule}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor._
import com.intel.analytics.bigdl.utils.{Engine, Shape}

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future
import scala.reflect.ClassTag

/**
 * Thresholds the input Tensor element-wise. Every element less than or equal to `th`
 * is replaced with `v`, and every other element is passed through unchanged.
 *
 * In the backward pass, the gradient is zeroed where the input is less than or equal to
 * the threshold and passed through elsewhere. This holds for both the contiguous fast
 * path and the strided fallback.
 *
 * NOTE: the two forward modes intentionally keep their original comparisons, so they
 * differ for NaN inputs. The in-place path keeps NaN (`NaN <= th` is false), while the
 * out-of-place path replaces NaN with `v` (`NaN > th` is false).
 *
 * @param th the threshold to compare with
 * @param v the value to replace with
 * @param ip inplace mode. When true, `updateOutput` overwrites `input` and
 *           `updateGradInput` overwrites `gradOutput`. This mode requires `v <= th`.
 */
@SerialVersionUID(3953292249027271493L)
class Threshold[T: ClassTag](
  private val th: Double = 1e-6, private val v: Double = 0.0, private val ip: Boolean = false)(
  implicit ev: TensorNumeric[T]) extends TensorModule[T] {
  /** Current threshold. Initialised from `th` and may be changed after construction. */
  var threshold: Double = th
  /** Current replacement value. Initialised from `v` and may be changed after construction. */
  var value: Double = v
  /** Current in-place flag. Initialised from `ip` and may be changed after construction. */
  var inPlace: Boolean = ip
  validateParameters()

  /**
   * Applies the threshold to a contiguous `input`, splitting the flat element range
   * across the model thread pool.
   *
   * @param input tensor to threshold; must be contiguous
   * @return `input` itself in in-place mode, otherwise the (resized) `output` tensor
   */
  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    require(input.isContiguous())
    validateParameters()

    val tasks = splitTasks(input.nElement())
    // Snapshot the public, mutable parameters once. All worker tasks then see the same
    // values, and the hot loops read locals instead of fields.
    val thresh = threshold
    val fill = value

    val results: Array[Future[Unit]] = if (inPlace) {
      output = input
      ev.getType() match {
        case DoubleType =>
          val in = input.asInstanceOf[Tensor[Double]]
          val data = in.storage().array()
          val offset = in.storageOffset() - 1
          runTasks(tasks) { (start, end) =>
            var i = offset + start
            val stop = offset + end
            while (i < stop) {
              if (data(i) <= thresh) data(i) = fill
              i += 1
            }
          }
        case FloatType =>
          val in = input.asInstanceOf[Tensor[Float]]
          val data = in.storage().array()
          val offset = in.storageOffset() - 1
          val fillFloat = fill.toFloat
          runTasks(tasks) { (start, end) =>
            var i = offset + start
            val stop = offset + end
            while (i < stop) {
              if (data(i) <= thresh) data(i) = fillFloat
              i += 1
            }
          }
        case _ => throw new UnsupportedOperationException(s"Only Float/Double supported")
      }
    } else {
      ev.getType() match {
        case DoubleType =>
          val in = input.asInstanceOf[Tensor[Double]]
          val out = output.asInstanceOf[Tensor[Double]]
          out.resizeAs(in)
          val inData = in.storage().array()
          val inOffset = in.storageOffset() - 1
          val outData = out.storage().array()
          val outOffset = out.storageOffset() - 1
          runTasks(tasks) { (start, end) =>
            var i = start
            while (i < end) {
              val x = inData(inOffset + i)
              outData(outOffset + i) = if (x > thresh) x else fill
              i += 1
            }
          }
        case FloatType =>
          val in = input.asInstanceOf[Tensor[Float]]
          val out = output.asInstanceOf[Tensor[Float]]
          out.resizeAs(in)
          val inData = in.storage().array()
          val inOffset = in.storageOffset() - 1
          val outData = out.storage().array()
          val outOffset = out.storageOffset() - 1
          val fillFloat = fill.toFloat
          runTasks(tasks) { (start, end) =>
            var i = start
            while (i < end) {
              val x = inData(inOffset + i)
              outData(outOffset + i) = if (x > thresh) x else fillFloat
              i += 1
            }
          }
        case _ => throw new UnsupportedOperationException(s"Only Float/Double supported")
      }
    }

    Engine.model.sync(results)
    output
  }

  /**
   * Strided fallback for the backward pass. It uses `Tensor.map`, which walks both tensors
   * in logical order and is therefore correct for any memory layout.
   */
  private def updateGradInputNoContinuous(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    validateParameters()
    val thresh = threshold
    if (inPlace) {
      gradInput = gradOutput
    } else {
      gradInput.resizeAs(gradOutput)
      gradInput.copy(gradOutput)
    }
    // Use the same `x <= threshold` test as the fast path, so the gradient (including
    // for NaN inputs) does not depend on how the tensors are laid out in memory.
    ev.getType() match {
      case DoubleType =>
        gradInput.asInstanceOf[Tensor[Double]].map(input.asInstanceOf[Tensor[Double]],
          (g, x) => if (x <= thresh) 0.0 else g)
      case FloatType =>
        gradInput.asInstanceOf[Tensor[Float]].map(input.asInstanceOf[Tensor[Float]],
          (g, x) => if (x <= thresh) 0.0f else g)
      case _ => throw new UnsupportedOperationException(s"Only Float/Double supported")
    }
    gradInput
  }

  /**
   * Backward pass. It zeroes `gradOutput` wherever `input <= threshold`.
   *
   * The parallel flat-index fast path is used only when both tensors are contiguous and
   * hold the same number of elements. Otherwise the strided fallback is used, because flat
   * `storageOffset + i` indexing is wrong for non-contiguous tensors even when their
   * strides match.
   */
  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    validateParameters()
    val flatLayout = input.isContiguous() && gradOutput.isContiguous() &&
      input.nElement() == gradOutput.nElement()
    if (flatLayout) {
      updateGradInputFlat(input, gradOutput)
    } else {
      updateGradInputNoContinuous(input, gradOutput)
    }
  }

  /** Parallel backward pass over contiguous, equally sized `input` and `gradOutput`. */
  private def updateGradInputFlat(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    val tasks = splitTasks(gradOutput.nElement())
    val thresh = threshold

    if (inPlace) {
      gradInput = gradOutput
    } else {
      gradInput.resizeAs(gradOutput)
      gradInput.copy(gradOutput)
    }

    val results: Array[Future[Unit]] = ev.getType() match {
      case DoubleType =>
        val grad = gradInput.asInstanceOf[Tensor[Double]]
        val in = input.asInstanceOf[Tensor[Double]]
        val gradData = grad.storage().array()
        val gradOffset = grad.storageOffset() - 1
        val inData = in.storage().array()
        val inOffset = in.storageOffset() - 1
        runTasks(tasks) { (start, end) =>
          var i = start
          while (i < end) {
            if (inData(inOffset + i) <= thresh) gradData(gradOffset + i) = 0.0
            i += 1
          }
        }
      case FloatType =>
        val grad = gradInput.asInstanceOf[Tensor[Float]]
        val in = input.asInstanceOf[Tensor[Float]]
        val gradData = grad.storage().array()
        val gradOffset = grad.storageOffset() - 1
        val inData = in.storage().array()
        val inOffset = in.storageOffset() - 1
        runTasks(tasks) { (start, end) =>
          var i = start
          while (i < end) {
            if (inData(inOffset + i) <= thresh) gradData(gradOffset + i) = 0.0f
            i += 1
          }
        }
      case _ => throw new UnsupportedOperationException(s"Only Float/Double supported")
    }

    Engine.model.sync(results)
    gradInput
  }

  /**
   * Splits the flat range `[0, n)` into contiguous `[start, end)` ranges, one per
   * model-pool thread (or one per element when `n` is smaller than the pool). Range
   * lengths differ by at most one, and the first `n % poolSize` ranges take the extra
   * element. Returns an empty array for `n == 0`.
   */
  private def splitTasks(n: Int): Array[(Int, Int)] = {
    val poolSize = Engine.model.getPoolSize
    val base = n / poolSize
    var extra = n % poolSize
    val tasks = new ArrayBuffer[(Int, Int)]()
    var start = 0
    while (start < n) {
      val len = if (extra > 0) {
        extra -= 1
        base + 1
      } else {
        base
      }
      val end = math.min(n, start + len)
      tasks += ((start, end))
      start = end
    }
    tasks.toArray
  }

  /** Submits `body(start, end)` to the model pool for every range and returns the futures. */
  private def runTasks(tasks: Array[(Int, Int)])(body: (Int, Int) => Unit): Array[Future[Unit]] = {
    val results = new Array[Future[Unit]](tasks.length)
    var t = 0
    while (t < tasks.length) {
      val (start, end) = tasks(t)
      results(t) = Engine.model.invoke(() => body(start, end))
      t += 1
    }
    results
  }

  /**
   * In-place mode requires `value <= threshold`. The in-place backward pass reads the
   * already-overwritten input, so replaced elements must still satisfy
   * `x <= threshold` for their gradient to be zeroed.
   */
  def validateParameters(): Unit = {
    if (inPlace) {
      require(value <= threshold,
        s"in-place processing requires value ($value) to not exceed threshold ($threshold)")
    }
  }

  override def computeOutputShape(inputShape: Shape): Shape = {
    inputShape
  }

  override def equals(obj: Any): Boolean = obj match {
    case other: Threshold[_] if super.equals(obj) =>
      (this eq other) ||
        (threshold == other.threshold && value == other.value && inPlace == other.inPlace)
    case _ => false
  }

  override def hashCode() : Int = {
    val seed = 37
    var hash = super.hashCode()
    hash = hash * seed + threshold.hashCode()
    hash = hash * seed + value.hashCode()
    hash = hash * seed + inPlace.hashCode()

    hash
  }

  /** Reports the current (possibly reassigned) threshold and value, not the constructor ones. */
  override def toString(): String = {
    s"${getPrintName}($threshold, $value)"
  }

  /** In in-place mode, `output`/`gradInput` alias caller tensors, so they are not cleared. */
  override def clearState(): this.type = {
    if (!inPlace) {
      super.clearState()
    }
    this
  }
}

object Threshold {
  /**
   * Factory for [[Threshold]].
   *
   * @param th threshold; elements less than or equal to it are replaced
   * @param v  replacement value
   * @param ip whether to operate in place
   * @return a new Threshold module for element type `T`
   */
  def apply[@specialized(Float, Double) T: ClassTag](
      th: Double = 1e-6,
      v: Double = 0.0,
      ip: Boolean = false)(implicit ev: TensorNumeric[T]): Threshold[T] =
    new Threshold[T](th = th, v = v, ip = ip)
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy