com.intel.analytics.bigdl.nn.abstractnn.AbstractModule.scala Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.bigdl.nn.abstractnn
import java.nio.ByteOrder
import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.dataset._
import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.quantized.Quantization
import com.intel.analytics.bigdl.nn.{Module, _}
import com.intel.analytics.bigdl.optim._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor.{QuantizedTensor, Tensor, TensorDataType}
import com.intel.analytics.bigdl.transform.vision.image.{DistributedImageFrame, ImageFeature, ImageFrame, LocalImageFrame}
import com.intel.analytics.bigdl.utils.TorchObject.TYPE_MODULE
import com.intel.analytics.bigdl.utils._
import com.intel.analytics.bigdl.utils.caffe.CaffePersister
import com.intel.analytics.bigdl.utils.intermediate.ConversionUtils
import com.intel.analytics.bigdl.utils.serializer._
import com.intel.analytics.bigdl.utils.tf.{TensorflowDataFormat, TensorflowSaver}
import org.apache.commons.lang3.SerializationUtils
import org.apache.spark.rdd.RDD
import scala.collection.mutable
import scala.reflect.ClassTag
/**
 * [[TensorModule]] is the abstract base for modules that consume a single [[Tensor]] and
 * produce a single [[Tensor]]: it pins both activity type parameters of [[AbstractModule]]
 * to `Tensor[T]`, leaving only the numeric type open.
 *
 * @tparam T The numeric type of this module's parameters
 */
abstract class TensorModule[T: ClassTag](implicit ev: TensorNumeric[T])
  extends AbstractModule[Tensor[T], Tensor[T], T]
/**
 * Module is the basic component of a neural network. It forwards activities and backwards
 * gradients. Modules can connect to others to construct a complex neural network.
 *
 * @tparam A Input data type
 * @tparam B Output data type
 * @tparam T The numeric type in this module parameters.
 */
abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag, T: ClassTag](
  implicit ev: TensorNumeric[T]) extends Serializable with InferShape {

  // ================================= Public APIs =============================================

  /**
   * The cached output, so it is not recomputed when needed again.
   */
  var output: B = Activity.allocate[B, T]()

  /**
   * The cached gradient of activities, so it is not recomputed when needed again.
   */
  var gradInput: A = Activity.allocate[A, T]()

  // Formats attached to the inputs/outputs of this module. `toGraph` forwards them to the
  // generated graph when they are non-null; the meaning of each Int code is defined outside
  // this file.
  protected var inputsFormats: Seq[Int] = null
  protected var outputsFormats: Seq[Int] = null

  /**
   * Set input formats for graph.
   *
   * @param formats input format codes, stored as-is
   * @return this
   */
  def setInputFormats(formats: Seq[Int]): this.type = {
    inputsFormats = formats
    this
  }

  /**
   * Set output formats for graph.
   *
   * @param formats output format codes, stored as-is
   * @return this
   */
  def setOutputFormats(formats: Seq[Int]): this.type = {
    outputsFormats = formats
    this
  }

  /**
   * Get the scale of gradientWeight
   */
  final def getScaleW(): Double = {
    scaleW
  }

  /**
   * Get the scale of gradientBias
   */
  final def getScaleB(): Double = {
    scaleB
  }

  /**
   * Set the scale of gradientWeight
   *
   * @param w the value of the scale of gradientWeight
   * @return this
   */
  def setScaleW(w: Double): this.type = {
    scaleW = w
    this
  }

  /**
   * Set the scale of gradientBias
   *
   * @param b the value of the scale of gradientBias
   * @return this
   */
  def setScaleB(b: Double): this.type = {
    scaleB = b
    this
  }

  /**
   * Clear cached activities to save storage space or network bandwidth. Note that we use
   * Tensor.set to keep some information like tensor share. Only Tensor-typed output and
   * gradInput are cleared here; other activity types are left untouched.
   *
   * The subclass should override this method if it allocates some extra resource, and call
   * super.clearState in the override method.
   *
   * @return this
   */
  def clearState() : this.type = {
    if (output.isInstanceOf[Tensor[_]]) {
      output.asInstanceOf[Tensor[_]].set()
    }
    if (gradInput.isInstanceOf[Tensor[_]]) {
      gradInput.asInstanceOf[Tensor[_]].set()
    }
    this
  }

  /**
   * Whether the user set a name on the module before.
   * @return true if setName was called with a non-null name
   */
  final def hasName: Boolean = name != null

  /**
   * Set the module name
   *
   * @param name the new name
   * @return this
   */
  final def setName(name : String) : this.type = {
    this.name = name
    this
  }

  /**
   * Get the module name. The default name is the simple class name followed by namePostfix.
   *
   * @return the user-set name, or the generated default
   */
  final def getName() : String = {
    if (this.name == null) {
      s"${this.getClass.getSimpleName}${namePostfix}"
    } else {
      this.name
    }
  }

  override def toString(): String = getPrintName

  /**
   * Get the forward/backward cost time (nanoseconds) for the module or its submodules.
   * @return tuples of (module, forward time, backward time)
   */
  def getTimes(): Array[(AbstractModule[_ <: Activity, _ <: Activity, T], Long, Long)] = {
    Array((this, forwardTime, backwardTime))
  }

  /**
   * Get the forward/backward cost time for the module or its submodules
   * and group by module type (fully qualified class name).
   * @return (module type name, forward time, backward time), sorted by total time descending
   */
  final def getTimesGroupByModuleType():
  Array[(String, Long, Long)] = {
    this.getTimes().map(v => (v._1.getClass().getName(), v._2, v._3)).groupBy(_._1)
      .map(v => (v._1, v._2.reduce((a, b) => (v._1, a._2 + b._2, a._3 + b._3))))
      .map(v => (v._1, v._2._2, v._2._3))
      .toArray
      .sortWith((a, b) => (a._2 + a._3) > (b._2 + b._3))
  }

  /**
   * Reset the forward/backward record time for the module or its submodules
   */
  def resetTimes(): Unit = {
    forwardTime = 0
    backwardTime = 0
  }

  /**
   * Freeze the module, i.e. its parameters (weight/bias, if they exist) are not changed in
   * the training process. This is done by setting scaleW/scaleB to 0, caching the previous
   * non-zero values so that unFreeze can restore them.
   *
   * If names is not empty, each submodule found via `apply(name)` is frozen instead.
   *
   * @param names an array of layer names
   * @throws Exception if a given name does not match any module
   * @return current graph model
   */
  def freeze(names: String*): this.type = {
    if (names.isEmpty) {
      // in case when freeze is called many times: only cache a non-zero scale
      if (scaleW != 0) {
        scaleWCache = scaleW
        scaleW = 0
      }
      if (scaleB != 0) {
        scaleBCache = scaleB
        scaleB = 0
      }
    } else {
      names.foreach(name => {
        this (name) match {
          case Some(x) => x.freeze()
          case _ => throw new Exception(s"cannot match module named $name")
        }
      })
    }
    this
  }

  /**
   * "Unfreeze" the module, i.e. make the module parameters (weight/bias, if they exist)
   * trainable again by restoring scaleW/scaleB from the values cached by freeze (initially
   * the scales the module was constructed with).
   *
   * If names is not empty, unfreeze the submodules that match the given names.
   *
   * @param names array of module names to unFreeze
   * @throws Exception if a given name does not match any module
   */
  def unFreeze(names: String*): this.type = {
    if (names.isEmpty) {
      scaleW = scaleWCache
      scaleB = scaleBCache
    } else {
      names.foreach(name => {
        this (name) match {
          case Some(x) => x.unFreeze()
          case _ => throw new Exception(s"cannot match module named $name")
        }
      })
    }
    this
  }

  /**
   * Takes an input object, and computes the corresponding output of the module. After a
   * forward, the output state variable should have been updated to the new value.
   *
   * Before computing, pending synchronized gradients are applied (see updateParameter).
   * Elapsed time (nanoseconds) is accumulated into forwardTime on success. Any failure is
   * rethrown as a LayerException whose message is prefixed with this module's name.
   *
   * @param input input data
   * @return output data
   */
  final def forward(input: A): B = {
    val before = System.nanoTime()
    try {
      updateParameter
      updateOutput(input)
    } catch {
      case l: LayerException =>
        l.layerMsg = this.toString() + "/" + l.layerMsg
        throw l
      case e: Throwable =>
        throw new LayerException(this.toString(), e)
    }
    forwardTime += System.nanoTime() - before

    output
  }

  /**
   * Performs a back-propagation step through the module, with respect to the given input. In
   * general this method makes the assumption forward(input) has been called before, with the
   * same input. This is necessary for optimization reasons. If you do not respect this rule,
   * backward() will compute incorrect gradients.
   *
   * Elapsed time (nanoseconds) is accumulated into backwardTime, then the gradient is handed
   * to the parameter synchronizer if one is set (see asyncGradient).
   *
   * @param input input data
   * @param gradOutput gradient of next layer
   * @return gradient corresponding to input data
   */
  def backward(input: A, gradOutput: B): A = {
    val before = System.nanoTime()
    updateGradInput(input, gradOutput)
    accGradParameters(input, gradOutput)
    backwardTime += System.nanoTime() - before
    asyncGradient
    gradInput
  }

  // Notify the parameter synchronizer (if any) that this module's gradients are ready.
  private[bigdl] def asyncGradient(): Unit = {
    if (this.getParameterSynchronizer() != null) {
      if (this.parameters() != null) {
        this.getParameterSynchronizer.put(this.getName)
      }
    }
  }

  /**
   * Computes the output using the current parameter set of the class and input. This function
   * returns the result which is stored in the output field.
   *
   * @param input input data
   * @return output data
   */
  def updateOutput(input: A): B

  /**
   * Computing the gradient of the module with respect to its own input. This is returned in
   * gradInput. Also, the gradInput state variable is updated accordingly.
   *
   * @param input input data
   * @param gradOutput gradient of next layer
   * @return gradient corresponding to input data
   */
  def updateGradInput(input: A, gradOutput: B): A

  /**
   * Computing the gradient of the module with respect to its own parameters. Many modules do
   * not perform this step as they do not have any parameters. The state variable name for the
   * parameters is module dependent. The module is expected to accumulate the gradients with
   * respect to the parameters in some variable.
   *
   * @param input input data
   * @param gradOutput gradient of next layer
   */
  def accGradParameters(input: A, gradOutput: B): Unit = {}

  /**
   * If the module has parameters, this will zero the accumulation of the gradients with
   * respect to these parameters (resizing each gradient to its weight first). Otherwise, it
   * does nothing.
   */
  def zeroGradParameters(): Unit = {
    // Evaluate parameters() once: containers may rebuild the arrays on every call.
    val params = parameters()
    if (params != null) {
      params._1.zip(params._2).foreach { case (weight, grad) =>
        grad.resizeAs(weight).zero()
      }
    }
  }

  /**
   * This function returns two arrays. One for the weights and the other the gradients.
   * Custom modules should override this function if they have parameters.
   *
   * @return (Array of weights, Array of grad), or null if the module has no parameters
   */
  def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = null

  /**
   * Get extra parameter in this module.
   * Extra parameter means the trainable parameters beside weight and bias. Such as runningMean
   * and runningVar in BatchNormalization.
   *
   * The subclass should override this method if it has some parameters besides weight and bias.
   *
   * @return an array of tensor, or null if there is none
   */
  def getExtraParameter(): Array[Tensor[T]] = null

  /**
   * Set extra parameter to this module by copying each given tensor into the corresponding
   * tensor returned by getExtraParameter.
   * Extra parameter means the trainable parameters beside weight and bias. Such as runningMean
   * and runningVar in BatchNormalization.
   *
   * @param extraParam tensors to copy from; must be null iff getExtraParameter() is null
   * @throws IllegalArgumentException if the lengths differ or exactly one side is null
   * @return this
   */
  def setExtraParameter(extraParam: Array[Tensor[T]]): this.type = {
    val currentExtraParam = this.getExtraParameter()
    if (extraParam != null && currentExtraParam != null) {
      require(extraParam.length == currentExtraParam.length,
        "state's length doesn't match, expected: " +
          s"${currentExtraParam.length}, but got ${extraParam.length}")
      var i = 0
      while (i < extraParam.length) {
        currentExtraParam(i).copy(extraParam(i))
        i += 1
      }
      this
    } else if (extraParam == null && currentExtraParam == null) {
      this
    } else {
      throw new IllegalArgumentException(s"module's extraParameter is $currentExtraParam" +
        s", while setting param is ${extraParam}")
    }
  }

  /**
   * This function returns a table containing the module name, the parameter names and
   * parameter values in this module.
   *
   * The result table is a structure of Table(ModuleName -> Table(ParameterName -> ParameterValue)),
   * and the type is Table[String, Table[String, Tensor[T]]].
   *
   * For example, get the weight of a module named conv1:
   *   table[Table]("conv1")[Tensor[T]]("weight").
   *
   * The names of the parameters follow such convention:
   *
   * 1. If there's one parameter, the parameter is named as "weight", the gradient is named as
   * "gradWeight"
   *
   * 2. If there're two parameters, the first parameter is named as "weight", the first gradient
   * is named as "gradWeight"; the second parameter is named as "bias", the second gradient is
   * named as "gradBias"
   *
   * 3. Otherwise, the weight is named as "weight" with a seq number (starting at 0) as suffix,
   * the gradient is named as "gradient" with a seq number as suffix
   *
   * Custom modules should override the default impl if the convention doesn't meet
   * the requirement.
   *
   * @return Table, or null if the module has no parameters
   */
  def getParametersTable(): Table = {
    val params = parameters()
    if (params == null) {
      null
    } else {
      val (weights, gradients) = params
      require(gradients.length == weights.length, "weight number is not equal to grad number")

      if (weights.length == 1) {
        T(getName() -> T("weight" -> weights(0), "gradWeight" -> gradients(0)))
      } else if (weights.length == 2) {
        T(getName() -> T("weight" -> weights(0), "bias" -> weights(1),
          "gradWeight" -> gradients(0), "gradBias" -> gradients(1)))
      } else {
        val result = T()
        weights.zip(gradients).zipWithIndex.foreach { case ((w, g), i) =>
          result(s"weight$i") = w
          result(s"gradient$i") = g
        }
        T(getName() -> result)
      }
    }
  }

  /**
   * Set the module to training mode
   * @return this
   */
  def training(): this.type = {
    train = true
    this
  }

  /**
   * Set the module to evaluate mode
   * @return this
   */
  def evaluate(): this.type = {
    train = false
    this
  }

  /**
   * Check if the model is in training mode
   * @return true when in training mode
   */
  final def isTraining(): Boolean = {
    this.train
  }

  /**
   * Reset module parameters, which is re-initialize the parameter with given initMethod
   */
  def reset(): Unit = {}

  /**
   * Set the line separator when printing the module
   * @param line the separator
   * @return this
   */
  final def setLine(line: String): this.type = {
    this.line = line
    this
  }

  /**
   * Clone the model via Java serialization (a full deep copy).
   * @return the cloned module
   */
  final def cloneModule(): this.type = {
    SerializationUtils.clone(this)
  }

  /**
   * Clone the module through the protobuf serializer, then wire up weights/bias.
   *
   * @param deepCopy if true, the copy's weight/bias tensors receive a copy of this module's
   *                 values (Tensor.copy); otherwise they share storage (Tensor.set)
   * @return the cloned module
   */
  final def clone(deepCopy : Boolean): AbstractModule[A, B, T] = {
    val moduleData = ModuleData[T](this.asInstanceOf[AbstractModule[Activity, Activity, T]],
      Seq[String](), Seq[String]())
    val storages = new mutable.HashMap[Int, Any]()
    val context = SerializeContext(moduleData, storages, ProtoStorageType, false)
    val serializedModule = ModuleSerializer.serialize[T](context).bigDLModule
    ModulePersister.setTensorStorage(serializedModule, storages)

    storages.clear()

    val deserializeContext = DeserializeContext(serializedModule.build,
      storages, ProtoStorageType, false)
    ModuleLoader.initTensorStorage[T](deserializeContext)
    val copy = ModuleSerializer.load[T](deserializeContext).module
      .asInstanceOf[AbstractModule[A, B, T]]
    setWeightAndBias(copy, deepCopy)
    copy
  }

  override def equals(other: Any): Boolean = other match {
    case that: AbstractModule[A, B, T] =>
      (that canEqual this) &&
        (that.getClass equals this.getClass) &&
        output == that.output &&
        gradInput == that.gradInput &&
        name == that.name
    case _ => false
  }

  override def hashCode(): Int = {
    def getHashCode(a: Object): Int = if (a == null) 0 else a.hashCode()
    val state = Seq(output, gradInput, this.getClass, this.name)
    state.map(getHashCode).foldLeft(0)((a, b) => 31 * a + b)
  }

  /**
   * Save this module to path (clears cached state first).
   * @param path path to save module, local file system, HDFS and Amazon S3 is supported.
   *             HDFS path should be like "hdfs://[host]:[port]/xxx"
   *             Amazon S3 path should be like "s3a://bucket/xxx"
   * @param overWrite if overwrite
   * @return self
   */
  @deprecated("please use recommended saveModule(path, overWrite)", "0.3.0")
  def save(path : String, overWrite: Boolean = false) : this.type = {
    this.clearState()
    File.save(this, path, overWrite)
    this
  }

  /**
   * Save this module to path with protobuf format (clears cached state first).
   * @param path path to save module, local file system, HDFS and Amazon S3 is supported.
   *             HDFS path should be like "hdfs://[host]:[port]/xxx"
   *             Amazon S3 path should be like "s3a://bucket/xxx"
   * @param weightPath where to store weight
   * @param overWrite if overwrite
   * @return self
   */
  final def saveModule(path : String, weightPath : String = null,
    overWrite: Boolean = false) : this.type = {
    this.clearState()
    ModulePersister.saveToFile(path, weightPath, this, overWrite)
    this
  }

  /**
   * Save this module definition to path (clears cached state first).
   * @param path path to save module, local file system, HDFS and Amazon S3 is supported.
   *             HDFS path should be like "hdfs://[host]:[port]/xxx"
   *             Amazon S3 path should be like "s3a://bucket/xxx"
   * @param overWrite if overwrite
   * @return self
   */
  final def saveDefinition(path : String, overWrite: Boolean = false) : this.type = {
    this.clearState()
    ModulePersister.saveModelDefinitionToFile(path, this, overWrite)
    this
  }

  /**
   * Save this module to path in torch7 readable format (clears cached state first).
   * @param path destination path
   * @param overWrite if overwrite
   * @return self
   */
  final def saveTorch(path : String, overWrite: Boolean = false) : this.type = {
    this.clearState()
    File.saveTorch(this, path, TYPE_MODULE, overWrite)
    this
  }

  /**
   * Save this module to path in caffe readable format (clears cached state first).
   * @param prototxtPath destination of the network definition
   * @param modelPath destination of the weights
   * @param useV2 forwarded to CaffePersister
   * @param overwrite if overwrite
   * @return self
   */
  final def saveCaffe(prototxtPath: String, modelPath: String,
    useV2 : Boolean = true, overwrite : Boolean = false) : this.type = {
    this.clearState()
    CaffePersister.persist[T](prototxtPath, modelPath, this, useV2, overwrite)
    this
  }

  /**
   * Save this module to path in tensorflow readable format. Only Graph containers are
   * supported. The module is temporarily switched to evaluate mode while saving, and the
   * previous training mode is restored afterwards, even if saving fails.
   *
   * @param inputs input names with their shapes
   * @param path destination path
   * @param byteOrder byte order of the saved tensors
   * @param dataFormat tensorflow data format
   * @return self
   */
  final def saveTF(
    inputs : Seq[(String, Seq[Int])],
    path: String,
    byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN,
    dataFormat: TensorflowDataFormat = TensorflowDataFormat.NHWC): this.type = {
    require(this.isInstanceOf[Graph[T]], "only Graph container can be saved as Tensorflow model")
    this.clearState()
    val inTrainMode = train
    if (inTrainMode) {
      this.evaluate()
    }
    try {
      TensorflowSaver.saveGraph(this.asInstanceOf[Graph[T]], inputs, path, byteOrder, dataFormat)
    } finally {
      // Restore the caller's mode even when saving throws.
      if (inTrainMode) {
        this.training()
      }
    }
    this
  }

  /**
   * Get numeric type of module parameters
   * @return the TensorDataType of T
   */
  final def getNumericType(): TensorDataType = {
    ev.getType()
  }

  /**
   * module predict, return the probability distribution
   * @param dataset dataset for prediction
   * @param batchSize total batchSize for all partitions.
   *                  if -1, default is 4 * partitionNumber of datatset
   * @param shareBuffer whether to share same memory for each batch predict results
   */
  final def predict(dataset: RDD[Sample[T]],
    batchSize: Int = -1,
    shareBuffer: Boolean = false): RDD[Activity] = {
    Predictor(this).predict(dataset, batchSize, shareBuffer)
  }

  /**
   * module predict, return the predict label
   * @param dataset dataset for prediction
   * @param batchSize total batchSize for all partitions.
   *                  if -1, default is 4 * partitionNumber of dataset
   */
  final def predictClass(dataset: RDD[Sample[T]], batchSize: Int = -1): RDD[Int] = {
    Predictor(this).predictClass(dataset, batchSize)
  }

  /**
   * model predict images, return imageFrame with predicted tensor,
   * if you want to call predictImage multiple times,
   * it is recommended to use Predictor for DistributedImageFrame
   * or LocalPredictor for LocalImageFrame
   * @param imageFrame imageFrame that contains images
   * @param outputLayer if outputLayer is not null, the output of layer that matches
   *                      outputLayer will be used as predicted output
   * @param shareBuffer whether to share same memory for each batch predict results
   * @param batchPerPartition batch size per partition, default is 4
   * @param predictKey key to store predicted result
   * @param featurePaddingParam featurePaddingParam if the inputs have variant size
   * @return the ImageFrame returned by the predictor
   */
  final def predictImage(imageFrame: ImageFrame,
    outputLayer: String = null,
    shareBuffer: Boolean = false,
    batchPerPartition: Int = 4,
    predictKey: String = ImageFeature.predict,
    featurePaddingParam: Option[PaddingParam[T]] = None): ImageFrame = {
    imageFrame match {
      case distributedImageFrame: DistributedImageFrame =>
        Predictor(this, featurePaddingParam, batchPerPartition)
          .predictImage(distributedImageFrame, outputLayer, shareBuffer, predictKey)
      case localImageFrame: LocalImageFrame =>
        val predictor = LocalPredictor(this, featurePaddingParam, batchPerPartition)
        // Shut the local predictor down even if prediction fails.
        try {
          predictor.predictImage(localImageFrame, outputLayer, shareBuffer, predictKey)
        } finally {
          predictor.shutdown()
        }
    }
  }

  /**
   * Set weight and bias for the module by copying each given tensor into the corresponding
   * weight/bias tensor of this module.
   *
   * @param newWeights array of weights and bias, in the order returned by parameters()._1
   * @throws IllegalArgumentException if the module has no parameters, newWeights is null,
   *                                  or the counts differ
   * @return this
   */
  final def setWeightsBias(newWeights: Array[Tensor[T]]): this.type = {
    val params = parameters()
    require(params != null, "this layer does not have weight/bias")
    require(newWeights != null, "input weight/bias should not be null")
    val weights = params._1
    require(weights.length == newWeights.length,
      "the number of input weight/bias is not consistent with " +
        "number of weight/bias of this layer, " +
        s"number of input ${newWeights.length}," +
        s" number of weight/bias of this layer ${weights.length}")
    for (i <- newWeights.indices) {
      // TODO: enable this checking as we don't respect shape right now.
      //      require(weights(i).size().deep == newWeights(i).size().deep,
      //        s"Mismatch shape, ${weights(i).size().mkString(",")}" +
      //          s" vs ${newWeights(i).size().mkString(",")} ")
      weights(i).copy(newWeights(i))
    }
    this
  }

  /**
   * Get weight and bias for the module
   * @return array of weights and bias, or null if the module has no parameters
   */
  final def getWeightsBias(): Array[Tensor[T]] = {
    val params = parameters()
    if (params != null) {
      params._1
    } else {
      null
    }
  }

  /**
   * Save weights and bias to file. Only the "weight" and "bias" entries of each layer's
   * parameter table are written; gradients are skipped. A module without parameters writes
   * an empty table.
   *
   * @param path file to save
   * @param overWrite whether to overwrite or not
   */
  final def saveWeights(path: String, overWrite: Boolean): Unit = {
    val parameterTable = getParametersTable()
    val weightsBiasTable = T()
    // getParametersTable returns null for parameterless modules; avoid an NPE.
    if (parameterTable != null) {
      parameterTable.foreach {
        case (name: String, params: Table) =>
          val wb = T()
          if (params.contains("weight")) {
            wb("weight") = params("weight")
          }
          if (params.contains("bias")) {
            wb("bias") = params("bias")
          }
          weightsBiasTable(name) = wb
        case _ => throw new UnsupportedOperationException("invalid parameter table")
      }
    }
    weightsBiasTable.save(path, overWrite)
  }

  /**
   * Load pretrained weights and bias to current module, matching layers by name.
   * @param weightPath file to store weights and bias
   * @param matchAll whether to match all layers' weights and bias,
   *                 if not, only load existing pretrained weights and bias
   * @throws Exception if matchAll is true and some layer has no entry in the file
   * @return current module
   */
  final def loadWeights(weightPath: String, matchAll: Boolean = true): this.type = {
    val srcParameter = File.load[Table](weightPath)
    val targetParameter = getParametersTable()
    copyWeights(targetParameter, srcParameter, matchAll)
    this
  }

  /**
   * Copy weights from another model, mapping by layer name.
   *
   * NOTE(review): srcModel is typed Module[Float] regardless of T — TODO confirm behavior
   * when this module's T is not Float.
   *
   * @param srcModel model to copy from
   * @param matchAll whether to match all layers' weights and bias,
   * @throws Exception if matchAll is true and some layer is missing from srcModel
   * @return current module
   */
  final def loadModelWeights(srcModel: Module[Float], matchAll: Boolean = true): this.type = {
    val srcParameter = srcModel.getParametersTable()
    val targetParameter = getParametersTable()
    copyWeights(targetParameter, srcParameter, matchAll)
    this
  }

  // Wrap this module in a new node and connect every given node to it with a default Edge.
  protected def processInputs(nodes: Seq[ModuleNode[T]]): ModuleNode[T] = {
    val curNode = new ModuleNode[T](this)
    nodes.foreach(node => {
      node.add(curNode, Edge())
    })
    curNode
  }

  // Wrap this module in a new node and connect each given node to it via an Edge carrying
  // the paired output index.
  protected def processInputs(first: (ModuleNode[T], Int),
      nodesWithIndex : (ModuleNode[T], Int)*): ModuleNode[T] = {
    val curNode = new ModuleNode[T](this)
    first._1.add(curNode, Edge(first._2))
    nodesWithIndex.foreach(nodeWithIndex => {
      nodeWithIndex._1.add(curNode, Edge(nodeWithIndex._2))
    })
    curNode
  }

  /**
   * Build graph: some other modules point to current module
   * @param nodes upstream module nodes
   * @return node containing current module
   */
  def inputs(nodes : ModuleNode[T]*): ModuleNode[T] = {
    validateInput(nodes.map(_.element))
    processInputs(nodes)
  }

  /**
   * Build graph: some other modules point to current module
   * @param nodes upstream module nodes in an array
   * @return node containing current module
   */
  def inputs(nodes : Array[ModuleNode[T]]): ModuleNode[T] = {
    validateInput(nodes.map(_.element))
    processInputs(nodes)
  }

  /**
   * Build graph: some other modules point to current module
   * @param first distinguish from another inputs when input parameter list is empty
   * @param nodesWithIndex upstream module nodes and the output tensor index. The start index is 1.
   * @return node containing current module
   */
  def inputs(first: (ModuleNode[T], Int), nodesWithIndex : (ModuleNode[T], Int)*): ModuleNode[T] = {
    validateInput(List(first._1.element))
    validateInput(nodesWithIndex.map(_._1.element))
    processInputs(first, nodesWithIndex: _*)
  }

  /**
   * Generate graph module with start nodes. If no start node is given, a single Input node
   * is created. StaticGraph results are flattened into a single (non-nested) graph, and any
   * input/output formats set on this module are forwarded to the graph.
   *
   * @param startNodes start nodes of the graph
   * @return the generated graph
   */
  def toGraph(startNodes: ModuleNode[T]*): Graph[T] = {
    val starts = if (startNodes.isEmpty) Array(Input[T]()) else startNodes.toArray
    val endNodes = this.getEndNodes(starts)
    var graph = Graph(starts, endNodes)
    if (graph.isInstanceOf[StaticGraph[T]]) {
      // Merge nested graphs inside to make the whole graph non-nested
      graph = graph.asInstanceOf[StaticGraph[T]].toSingleGraph()
    }
    if (inputsFormats != null) {
      graph.setInputFormats(inputsFormats)
    }

    if (outputsFormats != null) {
      graph.setOutputFormats(outputsFormats)
    }
    graph
  }

  /**
   * Find a module with given name. If there is no module with given name, it will return None.
   * If there are multiple modules with the given name, an exception will be thrown. The base
   * implementation only checks this module itself.
   *
   * @param name module name to look for
   * @return Some(module) if found, otherwise None
   */
  def apply(name : String): Option[AbstractModule[Activity, Activity, T]] = {
    if (this.getName() == name) {
      Some(this)
    } else {
      None
    }
  }

  /**
   * use ValidationMethod to evaluate module on the given rdd dataset
   * @param dataset dataset for test
   * @param vMethods validation methods
   * @param batchSize total batchsize of all partitions,
   *                  optional param and default 4 * partitionNum of dataset
   * @return validation results paired with their methods
   */
  final def evaluate(
    dataset: RDD[Sample[T]],
    vMethods: Array[_ <:ValidationMethod[T]],
    batchSize: Option[Int] = None
  ): Array[(ValidationResult, ValidationMethod[T])] = {
    Evaluator(this).test(dataset, vMethods.map(v => v), batchSize)
  }

  /**
   * use ValidationMethod to evaluate module on the given rdd dataset of mini-batches
   * @param dataset dataset for test
   * @param vMethods validation methods
   * @return validation results paired with their methods
   */
  final def evaluate(
    dataset: RDD[MiniBatch[T]],
    vMethods: Array[_ <:ValidationMethod[T]]
  ): Array[(ValidationResult, ValidationMethod[T])] = {
    Evaluator(this).testMiniBatch(dataset, vMethods.map(v => v))
  }

  /**
   * use ValidationMethod to evaluate module on the given ImageFrame. Only valid image
   * features are evaluated, and each must carry a Sample under ImageFeature.sample.
   *
   * @param imageFrame ImageFrame for validation; must be distributed
   * @param vMethods validation methods
   * @param batchSize total batch size of all partitions
   * @return validation results paired with their methods
   */
  final def evaluateImage(imageFrame: ImageFrame,
    vMethods: Array[_ <:ValidationMethod[T]],
    batchSize: Option[Int] = None
  ): Array[(ValidationResult, ValidationMethod[T])] = {
    require(imageFrame.isDistributed(), "ImageFrame must be distributed")
    val rdd = imageFrame.toDistributed().rdd.map(imageFeature => {
      if (imageFeature.isValid) {
        require(imageFeature.contains(ImageFeature.sample), "ImageFeature must have sample")
        imageFeature[Sample[T]](ImageFeature.sample)
      } else {
        null
      }
    }).filter(_ != null)
    evaluate(rdd, vMethods, batchSize)
  }

  /**
   * use ValidationMethod to evaluate module on the given local dataset
   * @param dataSet local dataset of mini-batches
   * @param vMethods validation methods
   * @return validation results paired with their methods
   */
  final def evaluate(
    dataSet: LocalDataSet[MiniBatch[T]],
    vMethods: Array[_ <:ValidationMethod[T]]
  ): Array[(ValidationResult, ValidationMethod[T])] = {
    Validator(this, dataSet).test(vMethods.map(v => v))
  }

  /**
   * Quantize this module, which reduces the precision of the parameter. Get a higher speed
   * with a little accuracy cost.
   * @return the converted module
   */
  final def quantize(): Module[T] = {
    ConversionUtils.convert[T](this, true)
  }

  // ================================= Internal APIs ===========================================

  // Random hex suffix used to build the default module name.
  private var namePostfix = Integer.toHexString(java.util.UUID.randomUUID().hashCode())

  final private[bigdl] def getNamePostfix : String = namePostfix

  final private[bigdl] def setNamePostfix(namePostfix : String) : Unit =
    this.namePostfix = namePostfix

  /**
   * The scale of gradient weight and gradient bias
   * before gradParameters being accumulated.
   */
  protected var scaleW: Double = 1.0
  protected var scaleB: Double = 1.0

  // Allocate an empty activity of the same kind (Tensor or Table) as dest.
  private[nn] final def allocateAs(dest: Activity): Activity = dest match {
    case tensor: Tensor[T] => Tensor[T]()
    case table: Table => T()
    case _ => throw new IllegalArgumentException("Activity only support tensor and table now")
  }

  /**
   * The name of the module
   */
  private var name : String = null

  private var id: Int = 0

  private[bigdl] def setId(id: Int): Unit = {
    this.id = id
  }

  private[bigdl] def getId(): Int = this.id

  // "SimpleClassName[name-or-postfix]", used by toString and in LayerException messages.
  protected final def getPrintName(): String = {
    val postfix = if (name == null) {
      namePostfix
    } else {
      name
    }
    s"${this.getClass.getSimpleName}[${postfix}]"

  }

  // Accumulated forward/backward time in nanoseconds (see forward/backward).
  protected var forwardTime = 0L

  protected var backwardTime = 0L

  // Scales remembered by freeze() so unFreeze() can restore them. These are declared after
  // scaleW/scaleB, so they start out equal to the construction-time scales.
  private var scaleWCache: Double = scaleW
  private var scaleBCache: Double = scaleB

  /**
   * This function returns two tensors. One for the flattened trainable parameters
   * flatParameters and another for the gradients of the energy wrt to the trainable
   * parameters flatGradParameters.
   *
   * Custom modules should not override this function. They should instead override
   * parameters(...) which is, in turn, called by the present function.
   *
   * This function will go over all the weights and gradWeights and make them view into a
   * single tensor (one for weights and one for gradWeights).
   *
   * @throws IllegalArgumentException if the module has no trainable parameters
   * @return (flattened weights, flattened gradients)
   */
  final private[bigdl] def getParameters(): (Tensor[T], Tensor[T]) = {
    val params = this.parameters()
    // parameters() may be null if there are no weights in this module; check it before
    // destructuring so callers get this message rather than a MatchError.
    require(params != null && params._1 != null && params._1.length > 0,
      s"model ${this.getName()} doesn't have any trainable parameters.")
    val (weightParameters, gradParameters) = params

    // If some gradParameters are not allocated storage, allocate it
    require(weightParameters.size == gradParameters.size,
      "weights and gradient number are not match")
    weightParameters.zip(gradParameters).foreach{ case(w, g) => g.resizeAs(w)}
    (Module.flatten[T](weightParameters), Module.flatten[T](gradParameters))
  }

  /**
   * Module status. It is useful for modules like dropout/batch normalization
   */
  protected var train: Boolean = true

  protected var line = "\n"

  // Engine type captured when the module was constructed.
  private val engineType: EngineType = Engine.getEngineType()

  /**
   * Check that the engine type captured at construction still matches the global one.
   * @throws Error if they differ
   */
  private[bigdl] def checkEngineType(): this.type = {
    if (engineType != Engine.getEngineType()) {
      throw new Error("Module's EngineType doesn't match global EngineType")
    }
    this
  }

  // Copy (deepCopy) or share (!deepCopy) every parameter of this module into `copy`,
  // matching layers by the names in getParametersTable.
  final private def setWeightAndBias(copy : AbstractModule[A, B, T], deepCopy : Boolean): Unit = {
    val parameterTable = this.getParametersTable
    val copiedModuleParamTable = copy.getParametersTable
    if (parameterTable != null) {
      require(copiedModuleParamTable != null, "cloned module should have params")
      parameterTable.foreach {
        case (name: String, params: Table) =>
          require(copiedModuleParamTable.get(name) != None,
            s"cloned module should have params for $name")
          setLayerWeightAndBias(params,
            copiedModuleParamTable.get(name).get.asInstanceOf[Table], deepCopy)
        case (name, params) =>
          throw new UnsupportedOperationException(s"unsupported $name and $params")
      }
    }
  }

  // Copy/share every entry of one layer's parameter table into the cloned layer's table.
  final private def setLayerWeightAndBias(params : Table,
    copyParams : Table, deepCopy : Boolean): Unit = {
    params.foreach(param => {
      copyParam(params, copyParams, deepCopy, param._1.toString)
    })
  }

  // Copy/share a single named parameter, which is either a Tensor or an Array of Tensors.
  final private def copyParam(params : Table, copyParams : Table,
    deepCopy : Boolean, paraName : String) : Unit = {
    if (params.contains(paraName)) {
      // this is for quantization tensors where the weight might be an array
      if (params.get(paraName).get
        .isInstanceOf[Array[Tensor[T]]]) {
        val copies = copyParams.get(paraName).get
          .asInstanceOf[Array[Tensor[T]]]
        val origins = params.get(paraName).get
          .asInstanceOf[Array[Tensor[T]]]
        var i = 0
        while (i < copies.length) {
          copyTensor(origins(i), copies(i), deepCopy)
          i += 1
        }
      } else {
        // For normal layers, their params are just tensors
        copyTensor(params.get(paraName).get.asInstanceOf[Tensor[T]],
          copyParams.get(paraName).get.asInstanceOf[Tensor[T]], deepCopy)
      }
    }
  }

  // Copy t1's values into t2 (deepCopy) or make t2 share t1's storage; a QuantizedTensor
  // target is released first.
  final private def copyTensor(t1 : Tensor[T], t2 : Tensor[T], deepCopy : Boolean) = {
    if (t2.isInstanceOf[QuantizedTensor[_]]) {
      t2.asInstanceOf[QuantizedTensor[_]].release()
    }
    if (deepCopy) {
      t2.copy(t1)
    } else {
      t2.set(t1)
    }
  }

  // Copy "weight"/"bias" from src into target for every layer name present in target.
  // A null target (parameterless module) means there is nothing to copy into; a null src
  // is treated as having no layers.
  final private def copyWeights(target: Table, src: Table, matchAll: Boolean): Unit = {
    if (target != null) {
      target.foreach {
        case (name: String, targetParams: Table) =>
          if (src != null && src.contains(name)) {
            val srcParams = src[Table](name)
            if (srcParams.contains("weight")) {
              val w = srcParams[Tensor[T]]("weight")
              targetParams[Tensor[T]]("weight").resizeAs(w).copy(w)
            }
            if (srcParams.contains("bias")) {
              val b = srcParams[Tensor[T]]("bias")
              targetParams[Tensor[T]]("bias").resizeAs(b).copy(b)
            }
          } else if (matchAll) {
            // Previously the exception was constructed but never thrown.
            throw new Exception(s"module $name cannot find corresponding weight bias")
          }
        case (name, targetParams) =>
          throw new UnsupportedOperationException(s"unsupported $name and $targetParams")
      }
    }
  }

  private[bigdl] def canEqual(other: Any): Boolean = other.isInstanceOf[AbstractModule[A, B, T]]

  /**
   * Generate end nodes of current module with start nodes
   * @param startNodes: current start nodes
   * @return current end nodes
   */
  private[bigdl] def getEndNodes(startNodes: Array[ModuleNode[T]]): Array[ModuleNode[T]] = {
    val endNodes = Array(this.processInputs(startNodes))
    endNodes
  }

  /**
   * Return classTag numerics for module serialization. If your module contains multiple
   * classtag in the constructor, you should override this method
   * @return (class tags, tensor numerics)
   */
  private[bigdl] def getClassTagNumerics() : (Array[ClassTag[_]], Array[TensorNumeric[_]]) = {
    (Array(scala.reflect.classTag[T]), Array(ev))
  }

  /**
   * Check if some module is duplicated in the model (by identity hash code). For a layer it
   * cannot be duplicated. Container should override this method.
   *
   * @param record identity hash codes of modules already visited; this module is added to it
   */
  private[bigdl] def checkDuplicate(
    record: mutable.HashSet[Int] = mutable.HashSet()
  ): Unit = {
    val errMsg = "Some module is duplicate in the current model: "
    val curId = System.identityHashCode(this)
    require(this.skipDuplicateCheck() || !record.contains(curId), errMsg + this.getName())
    record.add(curId)
  }

  /**
   * Sometimes, some layer need skip the duplicate check process, e.g. Keras-like input layer
   * @return true to skip the duplicate check
   */
  private[nn] def skipDuplicateCheck(): Boolean = false

  /**
   * if the model contains native resources such as aligned memory, we should release it by
   * manual. JVM GC can't release them reliably.
   */
  def release(): Unit = {}

  /**
   * parameter synchronizer for gradient synchronization
   */
  private var _parameterSynchronizer: DistriParameterSynchronizer[T] = null

  /**
   * set parameter synchronizer
   * @param parameterSynchronizer parameter synchronizer
   */
  private[bigdl] def setParameterSynchronizer(parameterSynchronizer:
    DistriParameterSynchronizer[T]): Unit = {
    _parameterSynchronizer = parameterSynchronizer
  }

  /**
   * get parameter synchronizer
   * @return parameter synchronizer
   */
  private[bigdl] def getParameterSynchronizer():
    DistriParameterSynchronizer[T] = _parameterSynchronizer

  private var _optimMethod: OptimMethod[T] = null

  /**
   * set optim method
   */
  private[bigdl] def setOptimMethod(optimMethod: OptimMethod[T]): Unit = {
    _optimMethod = optimMethod
  }

  /**
   * get optim method for layer
   */
  private[bigdl] def getOptimMethod(): OptimMethod[T] = _optimMethod

  // In training mode with a synchronizer set, fetch this module's (weights, grads) from the
  // synchronizer and, if grads are available, apply one optimMethod step and zero the
  // local gradients.
  private[bigdl] def updateParameter(): Unit = {
    if (this.getParameterSynchronizer() != null && this.isTraining) {
      if (this.parameters() != null) {
        val (weights, grads) = this.getParameterSynchronizer.get(this.getName)
        if (grads != null) {
          val optimMethod = this.getOptimMethod
          require(optimMethod != null, s"optim method for ${this.getName} cannot be null")
          optimMethod.optimize(_ => (ev.fromType(0.0f), grads),
            weights)
          this.zeroGradParameters
        }
      }
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy