Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
com.intel.analytics.bigdl.nn.ConvLSTMPeephole.scala Maven / Gradle / Ivy
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.bigdl.nn
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.optim.Regularizer
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{T, Table}
import scala.reflect.ClassTag
/**
* Convolution Long Short Term Memory architecture with peephole.
* Ref. A.: https://arxiv.org/abs/1506.04214 (blueprint for this module)
* B. https://github.com/viorik/ConvLSTM
*
* @param inputSize number of input planes in the image given into forward()
* @param outputSize number of output planes the convolution layer will produce
* @param kernelI Convolutional filter size to convolve input
* @param kernelC Convolutional filter size to convolve cell
* @param stride The step of the convolution
* @param wRegularizer: instance of [[Regularizer]]
(eg. L1 or L2 regularization), applied to the input weights matrices.
* @param uRegularizer: instance [[Regularizer]]
(eg. L1 or L2 regularization), applied to the recurrent weights matrices.
* @param bRegularizer: instance of [[Regularizer]]
applied to the bias.
* @param withPeephole: whether use last cell status control a gate.
*/
class ConvLSTMPeephole[T : ClassTag] (
val inputSize: Int,
val outputSize: Int,
val kernelI: Int,
val kernelC: Int,
val stride: Int,
var wRegularizer: Regularizer[T] = null,
var uRegularizer: Regularizer[T] = null,
var bRegularizer: Regularizer[T] = null,
val withPeephole: Boolean = true
)(implicit ev: TensorNumeric[T])
extends Cell[T](
hiddensShape = Array(outputSize, outputSize),
regularizers = Array(wRegularizer, uRegularizer, bRegularizer)
) {
var inputGate: Sequential[T] = _
var forgetGate: Sequential[T] = _
var outputGate: Sequential[T] = _
var hiddenLayer: Sequential[T] = _
var cellLayer: Sequential[T] = _
// val joinDim = 2
override var cell: AbstractModule[Activity, Activity, T] = buildConvLSTM()
// override def preTopology: AbstractModule[Activity, Activity, T] =
// Sequential()
// .add(TimeDistributed(SpatialConvolution(inputSize, outputSize*4, kernelI, kernelI,
// stride, stride, kernelI/2, kernelI/2, wRegularizer = wRegularizer,
// bRegularizer = bRegularizer)))
// def buildGate(offset: Int, length: Int): Sequential[T] = {
// val i2g = Narrow(joinDim, offset, length)
def buildGate(): Sequential[T] = {
val i2g = Sequential()
.add(Contiguous())
.add(SpatialConvolution(inputSize, outputSize, kernelI, kernelI,
stride, stride, kernelI/2, kernelI/2, wRegularizer = wRegularizer,
bRegularizer = bRegularizer))
val h2g = SpatialConvolution(outputSize, outputSize, kernelC, kernelC,
stride, stride, kernelC/2, kernelC/2, withBias = false,
wRegularizer = uRegularizer)
val gate = Sequential()
if (withPeephole) {
gate
.add(ParallelTable()
.add(i2g)
.add(h2g)
.add(CMul(Array(1, outputSize, 1, 1))))
} else {
gate.add(NarrowTable(1, 2))
gate
.add(ParallelTable()
.add(i2g)
.add(h2g))
}
gate.add(CAddTable())
.add(Sigmoid())
}
def buildInputGate(): Sequential[T] = {
// inputGate = buildGate(1 + outputSize, outputSize)
inputGate = buildGate()
inputGate
}
def buildForgetGate(): Sequential[T] = {
// forgetGate = buildGate(1, outputSize)
forgetGate = buildGate()
forgetGate
}
def buildOutputGate(): Sequential[T] = {
// outputGate = buildGate(1 + 3 * outputSize, outputSize)
outputGate = buildGate()
outputGate
}
def buildHidden(): Sequential[T] = {
val hidden = Sequential()
.add(NarrowTable(1, 2))
// val i2h = Narrow(joinDim, 1 + 2 * outputSize, outputSize)
val i2h = Sequential()
.add(Contiguous())
.add(SpatialConvolution(inputSize, outputSize, kernelI, kernelI,
stride, stride, kernelI/2, kernelI/2, wRegularizer = wRegularizer,
bRegularizer = bRegularizer))
val h2h = SpatialConvolution(outputSize, outputSize, kernelC, kernelC,
stride, stride, kernelC/2, kernelC/2, withBias = false,
wRegularizer = uRegularizer)
hidden
.add(ParallelTable()
.add(i2h)
.add(h2h))
.add(CAddTable())
.add(Tanh())
this.hiddenLayer = hidden
hidden
}
def buildCell(): Sequential[T] = {
buildInputGate()
buildForgetGate()
buildHidden()
val forgetLayer = Sequential()
.add(ConcatTable()
.add(forgetGate)
.add(SelectTable(3)))
.add(CMulTable())
val inputLayer = Sequential()
.add(ConcatTable()
.add(inputGate)
.add(hiddenLayer))
.add(CMulTable())
val cellLayer = Sequential()
.add(ConcatTable()
.add(forgetLayer)
.add(inputLayer))
.add(CAddTable())
this.cellLayer = cellLayer
cellLayer
}
def buildConvLSTM(): Sequential[T] = {
buildCell()
buildOutputGate()
val convlstm = Sequential()
.add(FlattenTable())
.add(ConcatTable()
.add(NarrowTable(1, 2))
.add(cellLayer))
.add(FlattenTable())
.add(ConcatTable()
.add(Sequential()
.add(ConcatTable()
.add(outputGate)
.add(Sequential()
.add(SelectTable(3))
.add(Tanh())))
.add(CMulTable())
.add(Contiguous()))
.add(SelectTable(3)))
.add(ConcatTable()
.add(SelectTable(1))
.add(Identity()))
cell = convlstm
convlstm
}
override def canEqual(other: Any): Boolean = other.isInstanceOf[ConvLSTMPeephole[T]]
override def equals(other: Any): Boolean = other match {
case that: ConvLSTMPeephole[T] =>
super.equals(that) &&
(that canEqual this) &&
inputSize == that.inputSize &&
outputSize == that.outputSize &&
kernelI == that.kernelI &&
kernelC == that.kernelC &&
stride == that.stride
case _ => false
}
override def hashCode(): Int = {
val state = Seq(super.hashCode(), inputSize, outputSize, kernelI, kernelC, stride)
state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
}
override def reset(): Unit = {
super.reset()
cell.reset()
}
override def toString: String = s"ConvLSTMPeephole($inputSize, $outputSize," +
s"$kernelI, $kernelC, $stride)"
}
object ConvLSTMPeephole {
def apply[@specialized(Float, Double) T: ClassTag](
inputSize: Int,
outputSize: Int,
kernelI: Int,
kernelC: Int,
stride: Int = 1,
wRegularizer: Regularizer[T] = null,
uRegularizer: Regularizer[T] = null,
bRegularizer: Regularizer[T] = null,
withPeephole: Boolean = true
)(implicit ev: TensorNumeric[T]): ConvLSTMPeephole[T] = {
new ConvLSTMPeephole[T](inputSize, outputSize, kernelI, kernelC, stride,
wRegularizer, uRegularizer, bRegularizer, withPeephole)
}
}