/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.bigdl.nn.mkldnn
import com.intel.analytics.bigdl.mkl._
import com.intel.analytics.bigdl.nn.abstractnn.{Activity, DataFormat, Initializable}
import com.intel.analytics.bigdl.nn.mkldnn.Phase.{InferencePhase, TrainingPhase}
import com.intel.analytics.bigdl.nn.{MklInt8Convertible, Ones, VariableFormat, Zeros}
import com.intel.analytics.bigdl.tensor._
import scala.collection.mutable.ArrayBuffer
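/**
 * Batch normalization over 4-D NCHW input, backed by the MKL-DNN
 * batch_normalization primitive. The scale and shift are packed into a single
 * `weightAndBias` tensor of size 2 * nOutput (scale first, then shift), and the
 * running mean / variance are accumulated together with a Caffe-style
 * `scaleFactor`, so they can be normalized before being used at inference time.
 * A ReLU can be fused into the forward primitive via `setReLU(true)`.
 *
 * @param nOutput        number of channels of the input
 * @param eps            value added to the variance for numerical stability
 * @param momentum       momentum used when accumulating the running statistics
 * @param initWeight     optional initial scale, of size nOutput
 * @param initBias       optional initial shift, of size nOutput
 * @param initGradWeight optional initial gradient of the scale
 * @param initGradBias   optional initial gradient of the shift
 * @param format         data format of the input, NCHW by default
 */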
class SpatialBatchNormalization(
val nOutput: Int,
val eps: Double = 1e-5,
val momentum: Double = 0.1,
private val initWeight: Tensor[Float] = null,
private val initBias: Tensor[Float] = null,
private val initGradWeight: Tensor[Float] = null,
private val initGradBias: Tensor[Float] = null,
val format: DataFormat = DataFormat.NCHW
) extends MklDnnLayer with Initializable with MklInt8Convertible {
@transient private var forwardDesc: Long = 0L
private var _relu: Boolean = false
def setReLU(value: Boolean): this.type = {
_relu = value
this
}
def relu: Boolean = _relu
// reminder: runningMean/runningVariance in the blas batch_norm are
// the same as the scaled runningMean/runningVariance in dnn.
private[bigdl] var needScale = false
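// SwitchablePrimitives lazily creates and caches the forward descriptor, the
// primitive descriptor, the in/out formats and the forward memory/compute
// primitives for one phase (training or inference), so that training() and
// evaluate() can switch phases through initFwdPrimitives without rebuilding
// primitives that already exist.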
class SwitchablePrimitives() {
private var _forwardDesc: Long = 0
private var _updateOutputMemoryPrimitives : Array[Long] = _
private var _updateOutputPrimitives: Array[Long] = _
private var _fwdPrimDesc: Long = 0
private var _inputFormat: NativeData = _
private var _outputFormat: NativeData = _
def switchInOutFormats(): Unit = {
if (_inputFormat == null) {
_inputFormat = MemoryData.operationWant(fwdPrimDesc, Query.SrcPd)
}
if (_outputFormat == null) {
_outputFormat = MemoryData.operationWant(fwdPrimDesc, Query.DstPd)
}
_inputFormats(0) = _inputFormat
_outputFormats(0) = _outputFormat
}
def fwdPrimDesc: Long = {
if (_fwdPrimDesc == 0) {
_fwdPrimDesc = if (relu) {
val postOps = MklDnnMemory.CreatePostOps()
MklDnn.PostOpsAppendEltwise(postOps, 1.0f, AlgKind.EltwiseRelu, 0.0f, 0.0f)
val attr = MklDnnMemory.CreateAttr()
MklDnn.AttrSetPostOps(attr, postOps)
MklDnnMemory.PrimitiveDescCreateV2(_forwardDesc, attr, runtime.engine, 0)
} else {
MklDnnMemory.PrimitiveDescCreate(_forwardDesc, runtime.engine, 0)
}
}
_fwdPrimDesc
}
def forwardDesc(gen: () => Long): Long = {
if (_forwardDesc == 0) {
_forwardDesc = gen()
}
_forwardDesc
}
def switchUpdateOutputMemoryPrimitives(gen: () => (Array[Long], Array[Long])): Unit = {
if (_updateOutputMemoryPrimitives == null) {
val generated = gen()
_updateOutputMemoryPrimitives = generated._1
_updateOutputPrimitives = generated._2
}
updateOutputMemoryPrimitives = _updateOutputMemoryPrimitives
updateOutputPrimitives = _updateOutputPrimitives
}
}
@transient private lazy val trainingPrimitives = new SwitchablePrimitives
@transient private lazy val inferencePrimitives = new SwitchablePrimitives
@transient private var updateOutputTensors: Array[Tensor[Float]] = _
@transient private var updateOutputMemoryPrimitives: Array[Long] = _
@transient private var updateGradInputTensors: Array[Tensor[Float]] = _
@transient private var updateGradInputMemoryPrimitives: Array[Long] = _
@transient private var modelPhase: Phase = null
private val mean: DnnTensor[Float] = DnnTensor[Float](nOutput)
private val variance: DnnTensor[Float] = DnnTensor[Float](nOutput)
private[mkldnn] val runningMean = new TensorMMap(Array(nOutput))
private[mkldnn] val runningVariance = new TensorMMap(Array(nOutput))
// TODO we should make it private. Currently, ResNet50 uses it outside of this scope.
val weightAndBias = new TensorMMap(Array(nOutput * 2))
val gradWeightAndBias = new TensorMMap(Array(nOutput * 2))
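// scaleFactor accumulates the same exponential weights as the running statistics
// (s := s * momentum + 1 each iteration), so runningMean / scaleFactor and
// runningVariance / scaleFactor are the normalized estimates consumed at
// inference time. biasFactor holds the m / (m - 1) correction applied to the
// batch variance before accumulation, with m the number of elements per channel.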
// TODO the two should be learnable parameters
var scaleFactor: Float = 1.0f
var biasFactor: Float = 1.0f
private val runningMeanScaled = Tensor[Float].resizeAs(runningMean.dense)
private val runningVarianceScaled = Tensor[Float].resizeAs(runningVariance.dense)
// the blank line should be kept here, otherwise runningVarianceScaled will be parsed as a method
{
val wInit = Ones // RandomUniform(0, 1)
val bInit = Zeros
setInitMethod(wInit, bInit)
}
override def reset(): Unit = {
val init = Tensor[Float]().resize(Array(2, nOutput))
val weight = init.select(1, 1)
val bias = init.select(1, 2)
if (initWeight != null) {
require(initWeight.size(1) == nOutput)
weight.copy(initWeight)
} else {
weightInitMethod.init(weight, VariableFormat.ONE_D)
}
if (initBias != null) {
require(initBias.size(1) == nOutput)
bias.copy(initBias)
} else {
biasInitMethod.init(bias, VariableFormat.ONE_D)
}
weightAndBias.dense.copy(init.view(2 * nOutput))
val zeros = Tensor[Float](Array(nOutput)).fill(0)
mean.copy(zeros)
variance.copy(zeros)
runningMean.copy(zeros)
runningVariance.copy(zeros)
}
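// Positions of the tensors inside updateOutputMemoryPrimitives for the training
// forward primitive (src, scaleshift, dst, mean, variance); initBwdPrimitives
// reuses these entries to share memory primitives with the backward primitive.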
private object Index extends Serializable {
val input = 0
val weight = 1
val output = 2
val mean = 3
val variance = 4
}
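// Reconcile the module's `train` flag with an explicitly requested phase: an
// explicit phase overrides the current flag, and when no phase is given the
// phase is derived from the flag.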
private def initPhase(phase: Phase): Unit = {
if (phase != null) modelPhase = phase
(isTraining(), modelPhase) match {
case (true, InferencePhase) =>
train = false
case (false, TrainingPhase) =>
train = true
case (true, null) =>
modelPhase = TrainingPhase
case (false, null) =>
modelPhase = InferencePhase
case _ =>
}
}
override private[mkldnn] def initFwdPrimitives(inputs: Array[MemoryData], phase: Phase) = {
val m = inputs(0).shape.product / this.nOutput
biasFactor = if (m > 1) { m.toFloat / (m - 1) } else { 1 }
val List(mean, variance, runningMean, runningVariance): List[NativeData] =
(0 until 4).map { _ =>
NativeData(Array(nOutput), Memory.Format.x)
}.toList
// weight and bias should be combined
val weightAndBias: NativeData = NativeData(Array(nOutput * 2), Memory.Format.x)
// batch norm only accepts F32 as input, like lrn
val src = NativeData(inputs.head.shape, inputs.head.layout, DataType.F32)
// init once
if (_inputFormats == null) {
_inputFormats = new Array[MemoryData](1)
require(_outputFormats == null)
_outputFormats = new Array[MemoryData](1)
}
// init phase status
initPhase(phase)
modelPhase match {
case TrainingPhase =>
forwardDesc = trainingPrimitives.forwardDesc(() => MklDnnMemory.BatchNormForwardDescInit(
PropKind.Forward,
src.getMemoryDescription(), eps.toFloat, MklDnn.BatchNormFlag.mkldnn_use_scaleshift))
val fwdPrimDesc = trainingPrimitives.fwdPrimDesc
trainingPrimitives.switchInOutFormats()
trainingPrimitives.switchUpdateOutputMemoryPrimitives(() => {
val srcs = Array(inputFormats()(0), weightAndBias).map(_.getPrimitive(runtime))
val dsts = Array(outputFormats()(0), mean, variance).map(_.getPrimitive(runtime))
val indexes = Array.fill(srcs.length)(0)
val primitive = MklDnnMemory.PrimitiveCreate2(fwdPrimDesc, srcs, indexes,
srcs.length, dsts, dsts.length)
val _updateOutputMemoryPrimitives = srcs ++ dsts
val _updateOutputPrimitives = Array(primitive)
(_updateOutputMemoryPrimitives, _updateOutputPrimitives)
})
case InferencePhase =>
// we always use the weight and bias (scale and shift), so the flags should combine
// use_scaleshift with use_global_stats.
forwardDesc = inferencePrimitives.forwardDesc(() =>
MklDnnMemory.BatchNormForwardDescInit(PropKind.ForwardInference,
src.getMemoryDescription(), eps.toFloat, MklDnn.BatchNormFlag.mkldnn_use_global_stats
| MklDnn.BatchNormFlag.mkldnn_use_scaleshift))
val fwdPrimDesc = inferencePrimitives.fwdPrimDesc
inferencePrimitives.switchInOutFormats()
inferencePrimitives.switchUpdateOutputMemoryPrimitives(() => {
val srcs = Array(inputFormats()(0), mean, variance, weightAndBias)
  .map(_.getPrimitive(runtime))
val dsts = Array(outputFormats()(0).getPrimitive(runtime))
val indexes = Array.fill(srcs.length)(0)
val primitive = MklDnnMemory.PrimitiveCreate2(fwdPrimDesc, srcs, indexes,
srcs.length, dsts, dsts.length)
val _updateOutputMemoryPrimitives = srcs ++ dsts
val _updateOutputPrimitives = Array(primitive)
(_updateOutputMemoryPrimitives, _updateOutputPrimitives)
})
case _ => throw new UnsupportedOperationException
}
// init once
// if the output is not null, it means we have initialized the primitives before,
// so we do not need to create the weightAndBias native space again.
if (output == null || output.isInstanceOf[DnnTensor[_]] &&
output.toTensor[Float].size().deep != outputFormats()(0).shape.deep) {
output = initTensor(outputFormats()(0))
}
if (updateOutputTensors != null) {
updateOutputTensors = null
}
// init once
if (this.weightAndBias.native == null) {
if (modelPhase == InferencePhase) {
this.runningMean.setMemoryData(
HeapData(this.runningMean.size(), Memory.Format.x), runningMean, runtime)
this.runningVariance.setMemoryData(
HeapData(this.runningVariance.size(), Memory.Format.x), runningVariance, runtime)
// for inference, we must copy the heap memory to native first.
this.runningMean.sync()
this.runningVariance.sync()
} else {
this.runningMean.setMemoryData(runningMean,
HeapData(this.runningMean.size(), Memory.Format.x), runtime)
this.runningVariance.setMemoryData(runningVariance,
HeapData(this.runningVariance.size(), Memory.Format.x), runtime)
}
// for runningMean and runningVariance, we should copy them to native first
this.weightAndBias.setMemoryData(HeapData(this.weightAndBias.size(), Memory.Format.x),
weightAndBias, runtime)
}
this.weightAndBias.sync()
(inputFormats(), outputFormats())
}
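// Note: the order of updateOutputTensors below must match the order in which the
// memory primitives were registered in initFwdPrimitives: training expects
// (src, scaleshift, dst, mean, variance), inference expects
// (src, mean, variance, scaleshift, dst).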
override def updateOutput(input: Activity): Activity = {
if (updateOutputTensors == null) {
if (this.isTraining()) {
val buffer = new ArrayBuffer[Tensor[Float]]()
buffer.append(input.asInstanceOf[Tensor[Float]])
buffer.append(weightAndBias.native)
buffer.append(output.asInstanceOf[Tensor[Float]])
buffer.append(mean)
buffer.append(variance)
updateOutputTensors = buffer.toArray
} else {
val buffer = new ArrayBuffer[Tensor[Float]]()
buffer.append(input.asInstanceOf[Tensor[Float]])
buffer.append(mean)
buffer.append(variance)
buffer.append(weightAndBias.native)
buffer.append(output.asInstanceOf[Tensor[Float]])
updateOutputTensors = buffer.toArray
}
}
if (this.isTraining()) {
weightAndBias.sync()
} else {
// we should re-compute the running mean and running variance.
// FIXME should do it at `initFwdPrimitives`
mean.scale(runningMean.native, 1 / scaleFactor)
variance.scale(runningVariance.native, 1 / scaleFactor)
}
updateWithNewTensor(updateOutputTensors, 0, input)
MklDnnOps.streamSubmit(runtime.stream, 1, updateOutputPrimitives, updateOutputPrimitives.length,
updateOutputMemoryPrimitives, updateOutputTensors)
if (this.isTraining()) {
// update running(Mean, Var) and scaleFactor
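// runningMean     := mean + momentum * runningMean
// runningVariance := biasFactor * variance + momentum * runningVariance
// scaleFactor     := 1 + momentum * scaleFactor
// so running(Mean|Variance) / scaleFactor is the normalized estimate used by the
// inference branch above.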
scaleFactor = scaleFactor * momentum.toFloat + 1
mean.axpby(1, momentum.toFloat, runningMean.native)
variance.axpby(biasFactor, momentum.toFloat, runningVariance.native)
runningMean.sync()
runningVariance.sync()
}
output
}
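// The backward pass is only defined for TrainingPhase and reuses the memory
// primitives registered by the training forward pass (see Index) for the input,
// mean, variance and scaleshift arguments.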
override private[mkldnn] def initBwdPrimitives(grad: Array[MemoryData], phase: Phase) = {
_gradOutputFormats = Array(NativeData(outputFormats()(0).shape, outputFormats()(0).layout))
// init phase status
initPhase(phase)
// [PERF] the format of gradInput should be the same as input
val backwardDesc = modelPhase match {
case TrainingPhase =>
MklDnnMemory.BatchNormBackwardDescInit(PropKind.Backward,
inputFormats()(0).getMemoryDescription(),
inputFormats()(0).getMemoryDescription(), eps.toFloat,
MklDnn.BatchNormFlag.mkldnn_use_scaleshift)
case _ => throw new UnsupportedOperationException
}
val gradWeightAndBias: NativeData = NativeData(Array(nOutput * 2), Memory.Format.x)
val gradWeightPrimitive = gradWeightAndBias.getPrimitive(runtime)
val primDesc = MklDnnMemory.PrimitiveDescCreate(backwardDesc, runtime.engine, 0)
_gradInputFormats = Array(MemoryData.operationWant(primDesc, Query.DiffSrcPd))
// may throw a NullPointerException if the forward memory primitives have not been initialized yet
val srcs = Array(updateOutputMemoryPrimitives(Index.input),
updateOutputMemoryPrimitives(Index.mean),
updateOutputMemoryPrimitives(Index.variance),
grad(0).getPrimitive(runtime),
updateOutputMemoryPrimitives(Index.weight))
val indexes = Array.fill(srcs.length)(0)
val dsts = Array(gradInputFormats()(0), gradWeightAndBias).map(_.getPrimitive(runtime))
val primitive = MklDnnMemory.PrimitiveCreate2(primDesc, srcs, indexes, srcs.length,
dsts, dsts.length)
updateGradInputMemoryPrimitives = srcs ++ dsts
updateGradInputPrimitives = Array(primitive)
gradInput = initTensor(gradInputFormats()(0))
this.gradWeightAndBias.setMemoryData(gradWeightAndBias,
HeapData(this.gradWeightAndBias.size(), Memory.Format.x), runtime)
this.gradWeightAndBias.zero()
(_gradOutputFormats, gradInputFormats())
}
override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
if (updateGradInputTensors == null) {
val buffer = new ArrayBuffer[Tensor[Float]]()
buffer.append(input.asInstanceOf[Tensor[Float]])
buffer.append(mean)
buffer.append(variance)
buffer.append(gradOutput.asInstanceOf[Tensor[Float]])
buffer.append(weightAndBias.native)
buffer.append(gradInput.asInstanceOf[Tensor[Float]])
buffer.append(gradWeightAndBias.native)
updateGradInputTensors = buffer.toArray
}
updateWithNewTensor(updateGradInputTensors, 0, input)
updateWithNewTensor(updateGradInputTensors, 3, gradOutput)
MklDnnOps.streamSubmit(runtime.stream, 1, updateGradInputPrimitives,
updateGradInputPrimitives.length, updateGradInputMemoryPrimitives, updateGradInputTensors)
gradWeightAndBias.sync()
gradInput
}
override def accGradParameters(input: Activity, gradOutput: Activity): Unit = {
// do nothing
}
override def zeroGradParameters(): Unit = {
}
override def parameters(): (Array[Tensor[Float]], Array[Tensor[Float]]) = {
(Array(weightAndBias.dense), Array(gradWeightAndBias.dense))
}
override def paramsMMap(): (Array[TensorMMap], Array[TensorMMap]) = {
(Array(weightAndBias), Array(gradWeightAndBias))
}
override def getExtraParameter(): Array[Tensor[Float]] = {
if (needScale) {
runningMeanScaled.copy(runningMean.dense).div(scaleFactor)
runningVarianceScaled.copy(runningVariance.dense).div(scaleFactor)
Array(runningMeanScaled, runningVarianceScaled)
} else {
Array(runningMean.dense, runningVariance.dense)
}
}
override def toString(): String = {
s"nn.mkl.SpatialBatchNormalization($nOutput, $eps, $momentum)"
}
override def evaluate(): this.type = {
if (modelPhase == TrainingPhase) {
initFwdPrimitives(inputFormats(), InferencePhase)
}
this
}
override def training(): this.type = {
if (modelPhase == InferencePhase) {
initFwdPrimitives(inputFormats(), TrainingPhase)
}
this
}
}
object SpatialBatchNormalization {
def apply(
nOutput: Int,
eps: Double = 1e-5,
momentum: Double = 0.1,
affine: Boolean = true,
initWeight: Tensor[Float] = null,
initBias: Tensor[Float] = null,
initGradWeight: Tensor[Float] = null,
initGradBias: Tensor[Float] = null,
format: DataFormat = DataFormat.NCHW): SpatialBatchNormalization = {
new SpatialBatchNormalization(nOutput, eps, momentum, initWeight, initBias, initGradWeight,
initGradBias, format = format)
}
}
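// A minimal usage sketch (illustration only, not part of the library): it assumes
// the mkldnn Sequential / Input / Output containers from this package and an
// arbitrary NCHW input of shape 4 x 64 x 56 x 56.
//
//   import com.intel.analytics.bigdl.mkl.Memory
//   import com.intel.analytics.bigdl.nn.mkldnn.{Input, Output, Sequential}
//   import com.intel.analytics.bigdl.nn.mkldnn.Phase.TrainingPhase
//   import com.intel.analytics.bigdl.tensor.Tensor
//
//   val model = Sequential()
//   model.add(Input(Array(4, 64, 56, 56), Memory.Format.nchw))
//   model.add(SpatialBatchNormalization(64, eps = 1e-5, momentum = 0.1))
//   model.add(Output(Memory.Format.nchw))
//   model.compile(TrainingPhase)
//
//   val input = Tensor[Float](4, 64, 56, 56).rand()
//   val output = model.forward(input)
//   val gradInput = model.backward(input, output)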