All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.LSTM.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, TensorModule}
import com.intel.analytics.bigdl.optim.Regularizer
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{T, Table}

import scala.reflect.ClassTag

/**
 * Long Short Term Memory architecture.
 * Ref. A.: http://arxiv.org/pdf/1303.5778v1 (blueprint for this module)
 * B. http://web.eecs.utk.edu/~itamar/courses/ECE-692/Bobby_paper1.pdf
 * C. http://arxiv.org/pdf/1503.04069v1.pdf
 * D. https://github.com/wojzaremba/lstm
 *
 * @param inputSize the size of each input vector
 * @param hiddenSize Hidden unit size in the LSTM
 * @param p is used for [[Dropout]] probability. For more details about
 *           RNN dropouts, please refer to
 *           [RnnDrop: A Novel Dropout for RNNs in ASR]
 *           (http://www.stat.berkeley.edu/~tsmoon/files/Conference/asru2015.pdf)
 *           [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks]
 *           (https://arxiv.org/pdf/1512.05287.pdf)
 * @param activation: activation function, by default to be Tanh if not specified.
 * @param innerActivation: activation function for inner cells,
 *                       by default to be Sigmoid if not specified.
 * @param wRegularizer: instance of [[Regularizer]]
 *                    (eg. L1 or L2 regularization), applied to the input weights matrices.
 * @param uRegularizer: instance [[Regularizer]]
            (eg. L1 or L2 regularization), applied to the recurrent weights matrices.
 * @param bRegularizer: instance of [[Regularizer]]
            applied to the bias.
 */
@SerialVersionUID(- 8176191554025511686L)
class LSTM[T : ClassTag] (
  val inputSize: Int,
  val hiddenSize: Int,
  val p: Double = 0,
  var activation: TensorModule[T] = null,
  var innerActivation: TensorModule[T] = null,
  var wRegularizer: Regularizer[T] = null,
  var uRegularizer: Regularizer[T] = null,
  var bRegularizer: Regularizer[T] = null
)
  (implicit ev: TensorNumeric[T])
  extends Cell[T](
    hiddensShape = Array(hiddenSize, hiddenSize),
    regularizers = Array(wRegularizer, uRegularizer, bRegularizer)
  ) {

  if (activation == null) activation = Tanh[T]()
  if (innerActivation == null) innerActivation = Sigmoid[T]()
  var gates: Sequential[T] = _
  var cellLayer: Sequential[T] = _

  override var cell: AbstractModule[Activity, Activity, T] = buildModel()

  override var preTopology: TensorModule[T] = if (p != 0) {
    null
  } else {
    Linear(inputSize, 4 * hiddenSize,
      wRegularizer = wRegularizer, bRegularizer = bRegularizer)
  }

  override def hiddenSizeOfPreTopo: Int = 4 * hiddenSize

  def buildGates()(input1: ModuleNode[T], input2: ModuleNode[T])
  : (ModuleNode[T], ModuleNode[T], ModuleNode[T], ModuleNode[T]) = {

    var i2g: ModuleNode[T] = null
    var h2g: ModuleNode[T] = null

    if (p != 0) {
      val dropi2g1 = Dropout(p).inputs(input1)
      val dropi2g2 = Dropout(p).inputs(input1)
      val dropi2g3 = Dropout(p).inputs(input1)
      val dropi2g4 = Dropout(p).inputs(input1)

      val lineari2g1 = Linear(inputSize, hiddenSize,
        wRegularizer = wRegularizer, bRegularizer = bRegularizer).inputs(dropi2g1)
      val lineari2g2 = Linear(inputSize, hiddenSize,
        wRegularizer = wRegularizer, bRegularizer = bRegularizer).inputs(dropi2g2)
      val lineari2g3 = Linear(inputSize, hiddenSize,
        wRegularizer = wRegularizer, bRegularizer = bRegularizer).inputs(dropi2g3)
      val lineari2g4 = Linear(inputSize, hiddenSize,
        wRegularizer = wRegularizer, bRegularizer = bRegularizer).inputs(dropi2g4)

      i2g = JoinTable(1, 1).inputs(lineari2g1, lineari2g2, lineari2g3, lineari2g4)

      val droph2g1 = Dropout(p).inputs(input2)
      val droph2g2 = Dropout(p).inputs(input2)
      val droph2g3 = Dropout(p).inputs(input2)
      val droph2g4 = Dropout(p).inputs(input2)

      val linearh2g1 = Linear(hiddenSize, hiddenSize,
        wRegularizer = wRegularizer, bRegularizer = bRegularizer).inputs(droph2g1)
      val linearh2g2 = Linear(hiddenSize, hiddenSize,
        wRegularizer = wRegularizer, bRegularizer = bRegularizer).inputs(droph2g2)
      val linearh2g3 = Linear(hiddenSize, hiddenSize,
        wRegularizer = wRegularizer, bRegularizer = bRegularizer).inputs(droph2g3)
      val linearh2g4 = Linear(hiddenSize, hiddenSize,
        wRegularizer = wRegularizer, bRegularizer = bRegularizer).inputs(droph2g4)

      h2g = JoinTable(1, 1).inputs(linearh2g1, linearh2g2, linearh2g3, linearh2g4)
    } else {
      i2g = input1
      h2g = Linear(hiddenSize, 4 * hiddenSize,
        withBias = false, wRegularizer = uRegularizer).inputs(input2)
    }

    val caddTable = CAddTable(false).inputs(i2g, h2g)
    val reshape = Reshape(Array(4, hiddenSize)).inputs(caddTable)
    val split1 = Select(2, 1).inputs(reshape)
    val split2 = Select(2, 2).inputs(reshape)
    val split3 = Select(2, 3).inputs(reshape)
    val split4 = Select(2, 4).inputs(reshape)

    // make different instances of inner activation
    val innerActivation2 = innerActivation.cloneModule()
    innerActivation2.setName(Integer.toHexString(java.util.UUID.randomUUID().hashCode()))
    val innerActivation3 = innerActivation.cloneModule()
    innerActivation3.setName(Integer.toHexString(java.util.UUID.randomUUID().hashCode()))

    (innerActivation.inputs(split1),
      activation.inputs(split2),
      innerActivation2.inputs(split3),
      innerActivation3.inputs(split4))
  }

  def buildModel(): Sequential[T] = {
    Sequential()
      .add(FlattenTable())
      .add(buildLSTM())
      .add(ConcatTable()
        .add(SelectTable(1))
        .add(NarrowTable(2, 2)))
  }

  def buildLSTM(): Graph[T] = {
    val input1 = Input()
    val input2 = Input()
    val input3 = Input()
    val (in, hid, forg, out) = buildGates()(input1, input2)

    /**
     * g: activation
     * cMult1 = in * hid
     * cMult2 = forg * input3
     * cMult3 = out * g(cMult1 + cMult2)
     */
    val cMult1 = CMulTable().inputs(in, hid)
    val cMult2 = CMulTable().inputs(forg, input3)
    val cadd = CAddTable(true).inputs(cMult1, cMult2)
    val activation2 = activation.cloneModule()
    activation2.setName(Integer.toHexString(java.util.UUID.randomUUID().hashCode()))
    val activate = activation2.inputs(cadd)
    val cMult3 = CMulTable().inputs(activate, out)

    val out1 = Identity().inputs(cMult3)
    val out2 = Identity().inputs(cMult3)
    val out3 = cadd

    /**
     * out1 = cMult3
     * out2 = out1
     * out3 = cMult1 + cMult2
     */
    Graph(Array(input1, input2, input3), Array(out1, out2, out3))
  }

  override def canEqual(other: Any): Boolean = other.isInstanceOf[LSTM[T]]

  override def equals(other: Any): Boolean = other match {
    case that: LSTM[T] =>
      super.equals(that) &&
        (that canEqual this) &&
        inputSize == that.inputSize &&
        hiddenSize == that.hiddenSize &&
        p == that.p
    case _ => false
  }

  override def hashCode(): Int = {
    val state = Seq(super.hashCode(), inputSize, hiddenSize, p)
    state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
  }

  override def reset(): Unit = {
    super.reset()
    cell.reset()
  }


  override def toString: String = s"LSTM($inputSize, $hiddenSize, $p)"
}

object LSTM {
  def apply[@specialized(Float, Double) T: ClassTag](
    inputSize: Int,
    hiddenSize: Int,
    p: Double = 0,
    activation: TensorModule[T] = null,
    innerActivation: TensorModule[T] = null,
    wRegularizer: Regularizer[T] = null,
    uRegularizer: Regularizer[T] = null,
    bRegularizer: Regularizer[T] = null
  )
    (implicit ev: TensorNumeric[T]): LSTM[T] = {
    new LSTM[T](inputSize, hiddenSize, p, activation, innerActivation,
      wRegularizer, uRegularizer, bRegularizer)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy