All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.SReLU.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn._
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.serializer._
import com.intel.analytics.bigdl.utils.serializer.converters.DataConverter
import com.intel.analytics.bigdl.utils.{T, Table}
import com.intel.analytics.bigdl.serialization.Bigdl.{AttrValue, BigDLModule}

import scala.reflect.ClassTag

/**
 * S-shaped Rectified Linear Unit.
 *  It follows:
 *  `f(x) = t^r + a^r(x - t^r) for x >= t^r`,
 *  `f(x) = x for t^r > x > t^l`,
 *  `f(x) = t^l + a^l(x - t^l) for x <= t^l`.
 *
 * [Deep Learning with S-shaped Rectified Linear Activation Units](http://arxiv.org/abs/1512.07030)
 *
 * The first dimension of the input is iterated as the batch dimension; every other element is
 * mapped onto the learnable parameters (tLeft, aLeft, tRight, aRight), each of shape `shape`.
 * Both the input and the gradOutput must be contiguous.
 *
 * @param shape shape for tleft, aleft, tright, aright.
 *              E.g. for a 4-D input, the shape is the last 3-D.
 *              NOTE: when `sharedAxes` is set, this array is modified in place: the entries
 *              at the shared axes are overwritten with 1.
 * @param sharedAxes the axes along which to share learnable parameters
 *                    for the activation function. Axes are 1-based and are counted
 *                    within `shape`, i.e. excluding the batch dimension.
 *                    For example, if the incoming feature maps are from a 2D convolution
 *                    with output shape `(batch, height, width, channels)`,
 *                    and you wish to share parameters across space
 *                    so that each filter only has one set of parameters,
 *                    set `sharedAxes = Array(1, 2)`.
 */

@SerialVersionUID(7173457290010080259L)
class SReLU[T: ClassTag](val shape: Array[Int], val sharedAxes: Array[Int] = null)(
  implicit ev: TensorNumeric[T]) extends TensorModule[T]
    with Initializable {
  import SReLU._
  val weightsLen = 4
  val weights: Array[Tensor[T]] = Array.fill[Tensor[T]](4)(Tensor[T]())
  val gradWeights: Array[Tensor[T]] = Array.fill[Tensor[T]](4)(Tensor[T]())

  val weightsInit: Array[InitializationMethod] = Array(Zeros, Xavier, Xavier, Ones)

  // Scratch buffer holding the 1-based weight index of the element currently being
  // processed (filled by getIndex). Allocated on first use by ensureIndexes().
  private var indexes: Array[Int] = null

  // NOTE(review): reset() re-runs every init method (without the variable format chosen in
  // init, and without re-applying the tRight offset), so the initialization done in init is
  // overwritten here; confirm which of the two is intended.
  init(shape).reset()

  /**
   * Sizes every weight and gradient tensor to `shape` and initializes the weights.
   *
   * When `sharedAxes` is set, the corresponding entries of `shape` are first overwritten with 1
   * (in place; the constructor passes its own `shape` field, so that field changes too).
   */
  private def init(shape: Array[Int]): this.type = {
    if (sharedAxes != null) {
      var i = 0
      while (i < sharedAxes.length) {
        shape(sharedAxes(i) - 1) = 1
        i += 1
      }
    }

    val variableFormat = shape.length match {
      case 2 => VariableFormat.IN_OUT
      case 4 => VariableFormat.OUT_IN_KW_KH
      case 5 => VariableFormat.OUT_IN_KT_KH_KW
      case _ => VariableFormat.Default
    }

    var i = 0
    while (i < weightsLen) {
      weights(i).resize(shape)
      weightsInit(i).init(weights(i), variableFormat)

      gradWeights(i).resize(shape)
      gradWeights(i).resizeAs(weights(i)).zero()

      i += 1
    }

    // ensure the right part is always to the right of the left
    weights(tRight).abs().add(weights(tLeft))
    this
  }

  /** Re-initializes all four weights with their init methods and clears the gradients. */
  override def reset(): Unit = {
    for ((initMethod, weight) <- weightsInit.zip(weights)) {
      initMethod.init(weight)
    }
    zeroGradParameters()
  }

  /** Returns the `indexes` scratch buffer, allocating it (sized to the weight rank) if needed. */
  private def ensureIndexes(): Array[Int] = {
    if (indexes == null) {
      indexes = new Array[Int](weights(tRight).nDimension())
    }
    indexes
  }

  /**
   * Converts a linear element offset within one batch slice into 1-based per-dimension
   * indexes, written into `indexes`. Shared axes are then forced to 1, because the weights
   * have size 1 along those axes.
   *
   * @param indexes output buffer with at least `ndim` entries
   * @param stride  strides of the (contiguous) slice being walked
   * @param ndim    number of dimensions of the slice
   * @param offset  0-based linear offset of the element within the slice
   */
  private def getIndex(indexes: Array[Int], stride: Array[Int], ndim: Int, offset: Int): Unit = {
    var i = 0
    var tmp = offset
    while (i < ndim) {
      indexes(i) = tmp / stride(i) + 1 // 1 based
      tmp = tmp % stride(i)
      i += 1
    }

    // set back the shared axes
    if (sharedAxes != null) {
      i = 0
      while (i < sharedAxes.length) {
        indexes(sharedAxes(i) - 1) = 1
        i += 1
      }
    }
  }

  /**
   * Accumulates `v` into element `i` (0-based, relative to the tensor's storage offset)
   * of tensor `w(t)`.
   */
  private def addValue(w: Array[Tensor[T]], i: Int, t: Int, v: T): Unit = {
    val arr = w(t).storage().array()
    val pos = w(t).storageOffset() - 1 + i
    arr(pos) = ev.plus(arr(pos), v)
  }

  override def updateOutput(input: Tensor[T]): Tensor[T] = {
    require(input.isContiguous(), s"the input of SReLU must be contiguous")
    // Ensure the right part is always to the right of the left.
    // NOTE(review): the result of this expression is discarded, so it relies on abs()/add()
    // mutating weights(tRight) in place. That makes it non-idempotent: once tLeft is non-zero,
    // every forward call shifts tRight by another tLeft. Confirm whether this is intended.
    weights(tRight).abs().add(weights(tLeft))
    output.resizeAs(input)

    val idx = ensureIndexes()

    var batch = 0
    while (batch < input.size(1)) {
      val sliceInput = input.select(1, batch + 1)
      val sliceOutput = output.select(1, batch + 1)

      val xArr = sliceInput.storage().array()
      val yArr = sliceOutput.storage().array()
      val yOffset = sliceOutput.storageOffset() - 1
      val xOffset = sliceInput.storageOffset() - 1

      var i = 0
      while (i < sliceInput.nElement()) {
        getIndex(idx, sliceInput.stride(), sliceInput.nDimension(), i)

        val tr = weights(tRight).apply(idx)
        val ar = weights(aRight).apply(idx)
        val tl = weights(tLeft).apply(idx)
        val al = weights(aLeft).apply(idx)

        val x = xArr(xOffset + i)

        yArr(yOffset + i) = if (ev.isGreaterEq(x, tr)) {
          // right: x_i >= t_i^r
          ev.plus(tr, ev.times(ar, ev.minus(x, tr)))
        } else if (ev.isGreaterEq(tl, x)) {
          // left: x_i <= t_i^l
          ev.plus(tl, ev.times(al, ev.minus(x, tl)))
        } else {
          // middle: identity
          x
        }

        i += 1
      }

      batch += 1
    }
    output
  }

  override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
    require(input.isContiguous(), s"the input of SReLU must be contiguous")
    require(gradOutput.isContiguous(), s"the gradOutput of SReLU must be contiguous")
    gradInput.resizeAs(input)

    val idx = ensureIndexes()

    var batch = 0
    while (batch < input.size(1)) {
      val sliceInput = input.select(1, batch + 1)
      val sliceGradInput = gradInput.select(1, batch + 1)
      val sliceGradOutput = gradOutput.select(1, batch + 1)

      val xArr = sliceInput.storage().array()
      val xOffset = sliceInput.storageOffset() - 1

      val yArr = sliceGradInput.storage().array()
      val yOffset = sliceGradInput.storageOffset() - 1

      val zArr = sliceGradOutput.storage().array()
      val zOffset = sliceGradOutput.storageOffset() - 1

      var i = 0
      while (i < sliceGradInput.nElement()) {
        getIndex(idx, sliceInput.stride(), sliceInput.nDimension(), i)

        val tr = weights(tRight).apply(idx)
        val ar = weights(aRight).apply(idx)
        val tl = weights(tLeft).apply(idx)
        val al = weights(aLeft).apply(idx)
        val x = xArr(xOffset + i)
        val z = zArr(zOffset + i)

        // Overwrite rather than accumulate: gradInput is reused across calls and resizeAs
        // does not clear it, so adding would leak gradients from the previous backward pass.
        yArr(yOffset + i) = if (ev.isGreaterEq(x, tr)) {
          ev.times(ar, z)
        } else if (ev.isGreaterEq(tl, x)) {
          ev.times(al, z)
        } else {
          z
        }
        i += 1
      }

      batch += 1
    }

    gradInput
  }

  override def accGradParameters(input: Tensor[T], gradOutput: Tensor[T]): Unit = {
    // The storage arrays below are walked linearly, which is only valid for contiguous data.
    require(input.isContiguous(), s"the input of SReLU must be contiguous")
    require(gradOutput.isContiguous(), s"the gradOutput of SReLU must be contiguous")

    val idx = ensureIndexes()
    val one = ev.fromType(1)

    var batch = 0
    while (batch < input.size(1)) {
      val sliceInput = input.select(1, batch + 1)
      val sliceGradOutput = gradOutput.select(1, batch + 1)

      val xArr = sliceInput.storage().array()
      val xOffset = sliceInput.storageOffset() - 1

      val zArr = sliceGradOutput.storage().array()
      val zOffset = sliceGradOutput.storageOffset() - 1

      var i = 0
      while (i < sliceInput.nElement()) {
        getIndex(idx, sliceInput.stride(), sliceInput.nDimension(), i)

        // Linear offset of this element inside each gradWeights tensor (all four share a shape).
        var wOffset = 0
        var j = 0
        while (j < idx.length) {
          // indexes are 1 based, hence the - 1
          wOffset += (idx(j) - 1) * gradWeights(tLeft).stride(j + 1)
          j += 1
        }

        val tr = weights(tRight).apply(idx)
        val ar = weights(aRight).apply(idx)
        val tl = weights(tLeft).apply(idx)
        val al = weights(aLeft).apply(idx)
        val x = xArr(xOffset + i)
        val z = zArr(zOffset + i)

        // right: df/dtr = 1 - ar, df/dar = x - tr
        if (ev.isGreaterEq(x, tr)) {
          addValue(gradWeights, wOffset, tRight, ev.times(ev.minus(one, ar), z))
          addValue(gradWeights, wOffset, aRight, ev.times(ev.minus(x, tr), z))
        }

        // left: df/dtl = 1 - al, df/dal = x - tl
        if (ev.isGreaterEq(tl, x)) {
          addValue(gradWeights, wOffset, tLeft, ev.times(ev.minus(one, al), z))
          addValue(gradWeights, wOffset, aLeft, ev.times(ev.minus(x, tl), z))
        }

        i += 1
      }

      batch += 1
    }
  }

  override def getParametersTable(): Table = {
    T(getName() -> T(
      "tLeft" -> weights(tLeft),
      "aLeft" -> weights(aLeft),
      "tRight" -> weights(tRight),
      "aRight" -> weights(aRight)))
  }

  override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
    (weights, gradWeights)
  }

  /**
   * Replaces the init method of each weight whose entry in `initMethods` is non-null, then
   * calls reset(). Entries are indexed by tLeft, aLeft, tRight, aRight (0 to 3), so
   * `initMethods` needs at least four elements.
   */
  override def setInitMethod(initMethods: Array[InitializationMethod]): this.type = {
    for (i <- Array(tLeft, aLeft, tRight, aRight)) {
      if (initMethods(i) != null) {
        weightsInit(i) = initMethods(i)
      }
    }
    reset()
    this
  }

  /** Not supported: SReLU has four weights and no bias; use the array overload instead. */
  override def setInitMethod(weightInitMethod: InitializationMethod = null,
    biasInitMethod: InitializationMethod = null): this.type = {
    throw new UnsupportedOperationException(
      s"SReLU should call setInitMethod(initMethods: Array[InitializationMethod])")
  }
}


object SReLU extends ModuleSerializable {
  /**
   * Creates an SReLU layer.
   *
   * @param shape     shape of each of the four learnable parameters
   * @param shareAxes 1-based axes (within `shape`) along which parameters are shared;
   *                  `null` means no sharing
   */
  def apply[T: ClassTag](shape: Array[Int], shareAxes: Array[Int] = null)(
    implicit ev: TensorNumeric[T]): SReLU[T] = {
    new SReLU[T](shape, shareAxes)
  }

  /** Positions of the four learnable parameters inside `weights` / `gradWeights`. */
  val (tLeft, aLeft, tRight, aRight) = (0, 1, 2, 3)

  /**
   * Serialized attribute name of each weight, paired with its index in `weights`.
   * Shared by load and save so the two can never disagree on names.
   */
  private val weightAttrs: Seq[(String, Int)] =
    Seq("tLeft" -> tLeft, "aLeft" -> aLeft, "tRight" -> tRight, "aRight" -> aRight)

  /**
   * Loads the module through the default path, then replaces its freshly initialized
   * weights with the tensors stored under the attributes listed in `weightAttrs`.
   */
  override def doLoadModule[T: ClassTag](context: DeserializeContext)
    (implicit ev: TensorNumeric[T]) : AbstractModule[Activity, Activity, T] = {
    val attrMap = context.bigdlModule.getAttrMap
    val srelu = super.doLoadModule(context).asInstanceOf[SReLU[T]]

    weightAttrs.foreach { case (name, index) =>
      srelu.weights(index) = DataConverter
        .getAttributeValue(context, attrMap.get(name))
        .asInstanceOf[Tensor[T]]
    }

    srelu
  }

  /**
   * Serializes the module through the default path, then stores each of the four weight
   * tensors as a tensor attribute named as in `weightAttrs`.
   */
  override def doSerializeModule[T: ClassTag](context: SerializeContext[T],
    sreluBuilder : BigDLModule.Builder)
    (implicit ev: TensorNumeric[T]) : Unit = {

    super.doSerializeModule(context, sreluBuilder)

    val srelu = context.moduleData.module.asInstanceOf[SReLU[T]]

    weightAttrs.foreach { case (name, index) =>
      val weightBuilder = AttrValue.newBuilder
      DataConverter.setAttributeValue(context, weightBuilder,
        srelu.weights(index), ModuleSerializer.tensorType)
      sreluBuilder.putAttr(name, weightBuilder.build)
    }
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy