All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.mkldnn.RNN.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.nn.mkldnn

import com.intel.analytics.bigdl.mkl._
import com.intel.analytics.bigdl.nn.{InitializationMethod, RandomUniform, VariableFormat, Zeros}
import com.intel.analytics.bigdl.nn.abstractnn.{Activity, Initializable}
import com.intel.analytics.bigdl.tensor.{DnnTensor, Tensor}

import scala.collection.mutable.ArrayBuffer

/**
 * @param mode       : the type of RNN cell (LSTM / GRU)
 * @param inputSize  : the size of input vector
 * @param hiddenSize : the size of hidden state
 * @param f          : the type of output activation function
 *                   (AlgKind.EltwiseTanh or AlgKind.EltwiseRelu)
 * @param direction  : the direction to run RNN
 *                   (e.g. Direction.UnidirectionalLeft2Right or Direction.BidirectionalConcat)
 * @param layers     : the number of RNN layers
 */

class RNN(
  val mode: Int,
  val inputSize: Int,
  val hiddenSize: Int,
  val f: Int,
  val direction: Int,
  val layers: Int = 1,
  val flags: Int = RNNCellFlags.RNNCellWithRelu,
  val alpha: Float = 0F,
  val clipping: Float = 0F,
  private val initWeight: Tensor[Float] = null,
  private val initWeightIter: Tensor[Float] = null,
  private val initBias: Tensor[Float] = null
) extends MklDnnLayer with Initializable {
  private var updateOutputMemoryPrimitives: Array[Long] = _
  private var updateOutputTensors: Array[Tensor[Float]] = _
  private var updateGradInputMemoryPrimitives: Array[Long] = _
  private var updateGradInputTensors: Array[Tensor[Float]] = _
  private var fwdPD: Long = _
  private var rnnCellDesc : Long = 0L

  private[mkldnn] var weight: TensorMMap = _
  private[mkldnn] var weight_i: TensorMMap = _
  private[mkldnn] var bias: TensorMMap = _
  private[mkldnn] var gradWeight: TensorMMap = _
  private[mkldnn] var gradWeight_i: TensorMMap = _
  private[mkldnn] var gradBias: TensorMMap = _

  private var workSpaceFormat: MemoryData = _
  private var workSpace : Tensor[Float] = _

  @transient private lazy val reorderManager = new ReorderManager
  private var weightForBackward: DnnTensor[Float] = _
  private var weightForBackwardMemoryData: MemoryData = _
  private var weightIterForBackward: DnnTensor[Float] = _
  private var weightIterForBackwardMemoryData: MemoryData = _

  private var batchSize: Int = _
  private var stepSize: Int = _
  private var inputShape: Array[Int] = _
  private var outputShape: Array[Int] = _
  private var weightShape: Array[Int] = _
  private var weightIterShape: Array[Int] = _
  private var biasShape: Array[Int] = _
  private var commonIterShape: Array[Int] = _

  private var src_i: DnnTensor[Float] = _
  private var dst_i: DnnTensor[Float] = _
  private var gradsrc_i: DnnTensor[Float] = _
  private var graddst_i: DnnTensor[Float] = _

  if(layers > 1) {
    require(inputSize == hiddenSize,
      "If layer number of RNN is more than 1, the input size and the hidden size should equal.\n"
      + "inputSize: " + inputSize + '\n'
      + "hiddenSize: " + hiddenSize)
  }

  var (ngates, nstates) = mode match {
    case AlgKind.VanillaLstm => (4, 2)
    case AlgKind.VanillaGru => (3, 1)
    case _ =>
      throw new UnsupportedOperationException("Not support such RNN Cell. Cell type: " + mode)
  }

  /** TODO: Multi-layer Bidirectional Sum RNN is available in MKLDNN,
   *  TODO: but the current version of BigDL BLAS does not support it.
   */

  val (numOfDirections, outputSizeFactor) = direction match {
    case Direction.UnidirectionalLeft2Right
         | Direction.UnidirectionalRight2Left => (1, 1)
    case Direction.BidirectionalConcat =>
      require(layers == 1, "Bidirectional Concat RNN does not support multiple layers. " +
        "layers = " + layers)
      (2, 2)
    case Direction.BidirectionalSum => (2, 1)
    case _ => throw new UnsupportedOperationException("Not support such direction")
  }

  /**
   * Gate order matching between MKLDNN LSTM and nn/LSTM:
   * MKLDNN Gate 1 -> nn/LSTM Gate 1 (input gate)
   * MKLDNN Gate 2 -> nn/LSTM Gate 3 (forget gate)
   * MKLDNN Gate 3 -> nn/LSTM Gate 2 (hidden)
   * MKLDNN Gate 4 -> nn/LSTM Gate 4 (output gate)
   *
   * Gate order matching between MKLDNN GRU and nn/GRU:
   * MKLDNN Gate 1 -> nn/GRU Gate 2
   * MKLDNN Gate 2 -> nn/GRU Gate 1
   * MKLDNN Gate 3 -> nn/GRU Gate 3
   */

  weightShape = Array(layers, numOfDirections, inputSize, ngates, hiddenSize)
  weightIterShape = Array(layers, numOfDirections, hiddenSize, ngates, hiddenSize)
  biasShape = Array(layers, numOfDirections, ngates, hiddenSize)

  weight = new TensorMMap(weightShape)
  weight_i = new TensorMMap(weightIterShape)
  bias = new TensorMMap(biasShape)
  gradWeight = new TensorMMap(weightShape)
  gradWeight_i = new TensorMMap(weightIterShape)
  gradBias = new TensorMMap(biasShape)

  {
    val stdv = 1.0 / math.sqrt(hiddenSize)
    val wInit: InitializationMethod = RandomUniform(-stdv, stdv)
    val bInit: InitializationMethod = Zeros
    setInitMethod(wInit, bInit)
  }

  override def reset(): Unit = {
    if (initWeight == null) {
      weightInitMethod.init(weight.dense, VariableFormat.Default)
    } else {
      weight.dense.copy(initWeight)
    }

    if (initWeightIter == null) {
      weightInitMethod.init(weight_i.dense, VariableFormat.Default)
    } else {
      weight_i.dense.copy(initWeightIter)
    }

    if (initBias == null) {
      biasInitMethod.init(bias.dense, VariableFormat.Default)
    } else {
      bias.dense.copy(initBias)
    }
  }

  override private[mkldnn] def initFwdPrimitives(inputs: Array[MemoryData], phase: Phase) = {
    val kind = if (!isTraining()) {
      PropKind.ForwardInference
    } else {
      PropKind.ForwardTraining
    }

    /**
     * TODO: The default format of input is TNC
     * Batch size of input is needed by creating memory descriptors of src iter and dst iter.
     * Step size of input is needed by creating memory descriptor of dst layer.
     * By default, batch size of input is the second element of inputShape
     * and step size is the first element of inputShape.
     */

    inputs(0).layout match {
      case Memory.Format.tnc =>
        batchSize = inputs(0).shape(1)
        stepSize = inputs(0).shape(0)
      case Memory.Format.ntc =>
        batchSize = inputs(0).shape(0)
        stepSize = inputs(0).shape(1)
      case _ =>
        throw new UnsupportedOperationException("Unsupported input format: " +
          inputs(0).layout)
    }

    inputShape = Array(stepSize, batchSize, inputSize)
    outputShape = Array(stepSize, batchSize, outputSizeFactor * hiddenSize)
    commonIterShape = Array(layers, numOfDirections, nstates, batchSize, hiddenSize)

    val src_layer = NativeData(inputShape, Memory.Format.any)
    val src_iter = NativeData(commonIterShape, Memory.Format.any)
    val wei_layer = NativeData(weightShape, Memory.Format.any)
    val wei_iter = NativeData(weightIterShape, Memory.Format.any)
    val bis = NativeData(biasShape, Memory.Format.any)
    val dst_layer = NativeData(outputShape, Memory.Format.any)
    val dst_iter = NativeData(commonIterShape, Memory.Format.any)

    val src_layer_MD = src_layer.getMemoryDescription()
    val src_iter_MD = src_iter.getMemoryDescription()
    val weights_layer_MD = wei_layer.getMemoryDescription()
    val weights_iter_MD = wei_iter.getMemoryDescription()
    val bis_MD = bis.getMemoryDescription()
    val dist_layer_MD = dst_layer.getMemoryDescription()
    val dist_iter_MD = dst_iter.getMemoryDescription()

    rnnCellDesc = mode match {
      case AlgKind.VanillaLstm =>
        MklDnnMemory.RNNCellDescInit(AlgKind.VanillaLstm, f, flags, alpha, clipping)
      case AlgKind.VanillaGru =>
        MklDnnMemory.RNNCellDescInit(AlgKind.VanillaGru, f, flags, alpha, clipping)
      case _ => throw new UnsupportedOperationException("Not support such RNN cell. " +
        "Cell type: " + mode)
    }

    val description = MklDnnMemory.RNNForwardDescInit(kind, rnnCellDesc, direction, src_layer_MD,
      src_iter_MD, weights_layer_MD, weights_iter_MD, bis_MD, dist_layer_MD, dist_iter_MD)

    fwdPD = MklDnnMemory.PrimitiveDescCreate(description, runtime.engine, 0L)

    val realSrc = MemoryData.operationWant(fwdPD, Query.SrcPd, 0)
    val realSrc_iter = MemoryData.operationWant(fwdPD, Query.SrcPd, 1)
    val realWei = MemoryData.operationWant(fwdPD, Query.WeightsPd, 0)
    val realWei_iter = MemoryData.operationWant(fwdPD, Query.WeightsPd, 1)
    val realBias = MemoryData.operationWant(fwdPD, Query.WeightsPd, 2)
    val realDst = MemoryData.operationWant(fwdPD, Query.DstPd, 0)
    val realDst_iter = MemoryData.operationWant(fwdPD, Query.DstPd, 1)

    require(weight.size().product == realWei.shape.product,
      s"${getName} weight shape is not correct.")
    require(weight_i.size().product == realWei_iter.shape.product,
      s"${getName} weight iter shape is not correct.")
    require(bias.size().product == realBias.shape.product,
      s"${getName} bias shape is not correct.")

    weight.setMemoryData(HeapData(weightShape, Memory.Format.ldigo), realWei, runtime)
    weight_i.setMemoryData(HeapData(weightIterShape, Memory.Format.ldigo), realWei_iter, runtime)
    bias.setMemoryData(HeapData(biasShape, Memory.Format.ldgo), realBias, runtime)

    weight.sync()
    weight_i.sync()
    bias.sync()

    src_i = initTensor(realSrc_iter).asInstanceOf[DnnTensor[Float]]
    dst_i = initTensor(realDst_iter).asInstanceOf[DnnTensor[Float]]
    src_i.zero()
    dst_i.zero()

    val srcs = Array(realSrc.getPrimitive(runtime), realSrc_iter.getPrimitive(runtime),
      realWei.getPrimitive(runtime), realWei_iter.getPrimitive(runtime),
      realBias.getPrimitive(runtime))
    val indexes = Array.fill(srcs.length)(0)

    if (isTraining()) {
      workSpaceFormat = MemoryData.operationWant(fwdPD, Query.WorkspacePd, 0)
      workSpace = initTensor(workSpaceFormat).asInstanceOf[Tensor[Float]]
    }

    val dsts = if (isTraining()) {
      Array(realDst.getPrimitive(runtime),
        realDst_iter.getPrimitive(runtime),
        workSpaceFormat.getPrimitive(runtime))
    }
    else {
      Array(realDst.getPrimitive(runtime),
        realDst_iter.getPrimitive(runtime))
    }

    val primitive = MklDnnMemory.PrimitiveCreate2(fwdPD, srcs, indexes,
      srcs.length, dsts, dsts.length)

    updateOutputMemoryPrimitives = srcs ++ dsts
    updateOutputPrimitives = Array(primitive)
    output = initTensor(realDst)

    _inputFormats = Array(realSrc)
    _outputFormats = Array(realDst)

    (_inputFormats, _outputFormats)
  }

  override def updateOutput(input: Activity): Activity = {
    if (updateOutputTensors == null) {
      val buffer = new ArrayBuffer[Tensor[Float]]()
      buffer.append(input.asInstanceOf[Tensor[Float]])
      buffer.append(src_i)
      buffer.append(weight.native)
      buffer.append(weight_i.native)
      buffer.append(bias.native)
      buffer.append(output.asInstanceOf[Tensor[Float]])
      buffer.append(dst_i)
      if (isTraining()) {
        buffer.append(workSpace)
      }

      updateOutputTensors = buffer.toArray
    }

    if (isTraining()) {
      weight.sync()
      weight_i.sync()
      bias.sync()
    }

    updateWithNewTensor(updateOutputTensors, 0, input)

    MklDnnOps.streamSubmit(runtime.stream, 1, updateOutputPrimitives, updateOutputPrimitives.length,
      updateOutputMemoryPrimitives, updateOutputTensors)

    output
  }

  override private[mkldnn] def initBwdPrimitives(grad: Array[MemoryData], phase: Phase) = {
    reorderManager.setRuntime(runtime)

    val src_layer_bw = NativeData(inputShape, Memory.Format.any)
    val src_iter_bw = NativeData(commonIterShape, Memory.Format.any)
    val wei_layer_bw = NativeData(weightShape, Memory.Format.any)
    val wei_iter_bw = NativeData(weightIterShape, Memory.Format.any)
    val bis_bw = NativeData(biasShape, Memory.Format.any)
    val dst_layer_bw = NativeData(outputShape, Memory.Format.any)
    val dst_iter_bw = NativeData(commonIterShape, Memory.Format.any)
    val diff_src_layer = NativeData(inputShape, Memory.Format.any)
    val diff_src_iter = NativeData(commonIterShape, Memory.Format.any)
    val diff_weights_layer = NativeData(weightShape, Memory.Format.ldigo)
    // IMPORTANT : it has to be ldigo
    val diff_weights_iter = NativeData(weightIterShape, Memory.Format.ldigo)
    // IMPORTANT : it has to be ldigo
    val diff_bias = NativeData(biasShape, Memory.Format.any)
    val diff_dist_layer = NativeData(outputShape, Memory.Format.any)
    val diff_dist_iter = NativeData(commonIterShape, Memory.Format.any)

    val src_layer_bw_MD = src_layer_bw.getMemoryDescription()
    val src_iter_bw_MD = src_iter_bw.getMemoryDescription()
    val weights_layer_bw_MD = wei_layer_bw.getMemoryDescription()
    val weights_iter_bw_MD = wei_iter_bw.getMemoryDescription()
    val bis_bw_MD = bis_bw.getMemoryDescription()
    val dist_layer_bw_MD = dst_layer_bw.getMemoryDescription()
    val dist_iter_bw_MD = dst_iter_bw.getMemoryDescription()
    val diff_src_layer_MD = diff_src_layer.getMemoryDescription()
    val diff_src_iter_MD = diff_src_iter.getMemoryDescription()
    val diff_weights_layer_MD = diff_weights_layer.getMemoryDescription()
    val diff_weights_iter_MD = diff_weights_iter.getMemoryDescription()
    val diff_bis_MD = diff_bias.getMemoryDescription()
    val diff_dist_layer_MD = diff_dist_layer.getMemoryDescription()
    val diff_dist_iter_MD = diff_dist_iter.getMemoryDescription()

    val description = MklDnnMemory.RNNBackwardDescInit(PropKind.Backward, rnnCellDesc,
      direction, src_layer_bw_MD,
      src_iter_bw_MD, weights_layer_bw_MD,
      weights_iter_bw_MD, bis_bw_MD,
      dist_layer_bw_MD, dist_iter_bw_MD,

      diff_src_layer_MD, diff_src_iter_MD,
      diff_weights_layer_MD, diff_weights_iter_MD,
      diff_bis_MD, diff_dist_layer_MD,
      diff_dist_iter_MD
    )

    val bwdPD = MklDnnMemory.PrimitiveDescCreate(description, runtime.engine, fwdPD)

    val realSrc = MemoryData.operationWant(bwdPD, Query.SrcPd, 0)
    val realSrc_iter = MemoryData.operationWant(bwdPD, Query.SrcPd, 1)
    val realWei = MemoryData.operationWant(bwdPD, Query.WeightsPd, 0)
    val realWei_iter = MemoryData.operationWant(bwdPD, Query.WeightsPd, 1)
    val realBias = MemoryData.operationWant(bwdPD, Query.WeightsPd, 2)
    val realDst = MemoryData.operationWant(bwdPD, Query.DstPd, 0)
    val realDst_iter = MemoryData.operationWant(bwdPD, Query.DstPd, 1)
    val realDiffDst = MemoryData.operationWant(bwdPD, Query.DiffDstPd, 0)
    val realDiffDst_iter = MemoryData.operationWant(bwdPD, Query.DiffDstPd, 1)

    val realDiffSrc = MemoryData.operationWant(bwdPD, Query.DiffSrcPd, 0)
    val realDiffSrc_iter = MemoryData.operationWant(bwdPD, Query.DiffSrcPd, 1)
    val realDiffWei = MemoryData.operationWant(bwdPD, Query.DiffWeightsPd, 0)
    val realDiffWei_iter = MemoryData.operationWant(bwdPD, Query.DiffWeightsPd, 1)
    val realDiffBias = MemoryData.operationWant(bwdPD, Query.DiffWeightsPd, 2)

    weightForBackwardMemoryData = realWei
    reorderManager.register(weight.heapData, realWei)
    weightForBackward = reorderManager
      .infer(Array(weight.heapData), Array(weightForBackwardMemoryData), weight.dense)
      .asInstanceOf[DnnTensor[Float]]

    weightIterForBackwardMemoryData = realWei_iter
    reorderManager.register(weight_i.heapData, realWei_iter)
    weightIterForBackward = reorderManager
      .infer(Array(weight_i.heapData), Array(weightIterForBackwardMemoryData), weight_i.dense)
      .asInstanceOf[DnnTensor[Float]]

    gradWeight.setMemoryData(realDiffWei, HeapData(weightShape, Memory.Format.ldigo), runtime)
    gradWeight_i.setMemoryData(realDiffWei_iter, HeapData(weightIterShape, Memory.Format.ldigo),
      runtime)
    gradBias.setMemoryData(realDiffBias, HeapData(biasShape, Memory.Format.ldgo), runtime)

    gradWeight.zero()
    gradWeight_i.zero()
    gradBias.zero()

    gradsrc_i = initTensor(realDiffSrc_iter).asInstanceOf[DnnTensor[Float]]
    graddst_i = initTensor(realDiffDst_iter).asInstanceOf[DnnTensor[Float]]
    gradsrc_i.zero()
    graddst_i.zero()

    val srcs = Array(realSrc.getPrimitive(runtime), realSrc_iter.getPrimitive(runtime),
      realWei.getPrimitive(runtime), realWei_iter.getPrimitive(runtime),
      realBias.getPrimitive(runtime), realDst.getPrimitive(runtime),
      realDst_iter.getPrimitive(runtime), realDiffDst.getPrimitive(runtime),
      realDiffDst_iter.getPrimitive(runtime), workSpaceFormat.getPrimitive(runtime))
    val indexes = Array.fill(srcs.length)(0)

    val dsts = Array(realDiffSrc.getPrimitive(runtime), realDiffSrc_iter.getPrimitive(runtime),
      realDiffWei.getPrimitive(runtime), realDiffWei_iter.getPrimitive(runtime),
      realDiffBias.getPrimitive(runtime)
    )

    val primitive = MklDnnMemory.PrimitiveCreate2(bwdPD, srcs, indexes, srcs.length,
      dsts, dsts.length)

    updateGradInputMemoryPrimitives = srcs ++ dsts
    updateGradInputPrimitives = Array(primitive)
    gradInput = initTensor(realDiffSrc)

    _gradInputFormats = Array(realDiffSrc)
    _gradOutputFormats = Array(realDiffDst)
    (_gradOutputFormats, _gradInputFormats)
  }

  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
    if (updateGradInputTensors == null) {
      val buffer = new ArrayBuffer[Tensor[Float]]()
      buffer.append(input.asInstanceOf[Tensor[Float]])
      buffer.append(src_i)
      buffer.append(weightForBackward)
      buffer.append(weightIterForBackward)
      buffer.append(bias.native)
      buffer.append(output.asInstanceOf[Tensor[Float]])
      buffer.append(dst_i)
      buffer.append(gradOutput.asInstanceOf[Tensor[Float]])
      buffer.append(graddst_i)
      buffer.append(workSpace)

      buffer.append(gradInput.asInstanceOf[Tensor[Float]])
      buffer.append(gradsrc_i)
      buffer.append(gradWeight.native)
      buffer.append(gradWeight_i.native)
      buffer.append(gradBias.native)

      updateGradInputTensors = buffer.toArray
    }

    updateWithNewTensor(updateGradInputTensors, 0, input)
    updateWithNewTensor(updateGradInputTensors, 7, gradOutput)

    MklDnnOps.streamSubmit(runtime.stream, 1, updateGradInputPrimitives,
      updateGradInputPrimitives.length, updateGradInputMemoryPrimitives, updateGradInputTensors)

    gradWeight.sync()
    gradWeight_i.sync()
    gradBias.sync()

    gradInput
  }

  override def accGradParameters(input: Activity, gradOutput: Activity): Unit = {
    // Do nothing
  }

  override def parameters(): (Array[Tensor[Float]], Array[Tensor[Float]]) = {
    (Array(weight.dense, bias.dense, weight_i.dense),
      Array(gradWeight.dense, gradBias.dense, gradWeight_i.dense))
  }

  override def zeroGradParameters(): Unit = {
  }

}

object RNN{
  def apply(
   mode: Int,
   inputSize: Int,
   hiddenSize: Int,
   f: Int,
   direction: Int,
   layers: Int = 1,
   flags: Int = RNNCellFlags.RNNCellWithRelu,
   alpha: Float = 0F,
   clipping: Float = 0F,
   initWeight: Tensor[Float] = null,
   initWeightIter: Tensor[Float] = null,
   initBias: Tensor[Float] = null
 ): RNN = new RNN(mode, inputSize, hiddenSize, f, direction, layers, flags, alpha,
    clipping, initWeight, initWeightIter, initBias)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy