/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.bigdl.nn.mkldnn
import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.MklInt8Convertible
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.Node
/**
* Add fusion operations for dnn graph nodes. There are three fusion cases:
* case 1: fuse relu with conv (SpatialConvolution) or bn (SpatialBatchNormalization)
* case 2: fuse conv with bn
* case 3: sum conv output with the output of another layer
* To use fusion for inference, set the property "bigdl.mkldnn.fusion" to true.
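* For example (a minimal sketch; any standard way of setting a JVM system property works):
*   System.setProperty("bigdl.mkldnn.fusion", "true")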
*/
private[mkldnn] object Fusion {
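// fusion is enabled by default; it can be turned off with -Dbigdl.mkldnn.fusion=false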
private def fuse = System.getProperty("bigdl.mkldnn.fusion", "true").toBoolean
def fuseModule(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
if (!fuse) return
node.element match {
case relu: ReLU => fusionRelu(node)
case bn: SpatialBatchNormalization => fusionBN(node)
case _ =>
}
}
def fuseCAdd(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
if (!fuse) return
node.element match {
case cadd: CAddTable => fusionCAddTable(node)
case _ =>
}
}
/**
* Fuse a conv (one not already fused with relu or bn) with bn.
* If the bn has already been fused with relu, fuse both relu and bn into the conv.
* @param node the bn node to be fused into its preceding conv
*/
private def fusionBN(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
val bn = node.element.asInstanceOf[SpatialBatchNormalization]
node.prevNodes.foreach(n => {
n.element match {
case conv : SpatialConvolution =>
// reminder: a single conv might be fusable with two bns
if (!conv.relu && !conv.batchNorm) {
if (bn.relu) conv.setReLU(true)
fusionConvBn(conv, bn)
node.element = Identity[Float]().asInstanceOf[AbstractModule[Activity, Activity, Float]]
}
case _ =>
}})
}
/**
* Fuse relu with a preceding conv or bn.
* @param node the relu node to be fused into its predecessor
*/
private def fusionRelu(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
node.prevNodes.foreach(n => {
n.element match {
case conv: SpatialConvolution =>
if (!conv.relu) {
conv.setReLU(true)
conv.setOutputScales(node.element.asInstanceOf[ReLU].getOutputScales())
node.element = Identity[Float]().asInstanceOf[AbstractModule[Activity, Activity, Float]]
}
case bn: SpatialBatchNormalization =>
if (!bn.relu) {
bn.setReLU(true)
bn.setOutputScales(node.element.asInstanceOf[ReLU].getOutputScales())
node.element = Identity[Float]().asInstanceOf[AbstractModule[Activity, Activity, Float]]
}
case _ =>
}})
}
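// walk backwards through chains of Identity nodes (placeholders left behind by
// earlier fusions) and return the first non-Identity node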
private def findPrevious(node: Node[AbstractModule[Activity, Activity, Float]])
: Node[AbstractModule[Activity, Activity, Float]] = {
if (node.element.isInstanceOf[Identity] && node.prevNodes.length == 1) {
findPrevious(node.prevNodes(0))
} else node
}
private def findNext(node: Node[AbstractModule[Activity, Activity, Float]])
: Seq[Node[AbstractModule[Activity, Activity, Float]]] = {
if (node.element.isInstanceOf[Identity]) {
node.nextNodes.flatMap(n => findNext(n))
} else {
Seq(node)
}
}
/**
* If the CAddTable has exactly two previous layers and one of them is a conv layer,
* fuse the output of the other layer into the conv layer.
* @param node the CAddTable node to be fused
*/
private def fusionCAddTable(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
if (node.element.isInstanceOf[CAddTable] && node.prevNodes.length == 2) {
val previousNodes = node.prevNodes.toArray
val node1 = findPrevious(previousNodes(0))
val node2 = findPrevious(previousNodes(1))
var conv: Node[Module[Float]] = null
var otherNumber: Int = 0
if (node1.element.isInstanceOf[SpatialConvolution]) {
if (requirements(node1)) conv = node1
otherNumber = 1
} else if (node2.element.isInstanceOf[SpatialConvolution]) {
if (requirements(node2)) conv = node2
otherNumber = 0
}
// the fusion requirements are met
if (conv != null) {
node.element = conv.element
val element = node.element.asInstanceOf[SpatialConvolution]
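// register the other branch as the sum operand of the fused conv;
// otherNumber + 1 converts the 0-based array index to a 1-based one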
element.setSumOp(previousNodes(otherNumber).element, otherNumber + 1)
conv.element = Identity[Float]().asInstanceOf[AbstractModule[Activity, Activity, Float]]
val next = node.nextNodes(0)
if (next.element.isInstanceOf[ReLU] && !element.relu) {
node.element.asInstanceOf[SpatialConvolution].setReLU(true)
node.element.asInstanceOf[SpatialConvolution].setOutputScales(
next.element.asInstanceOf[ReLU].getOutputScales())
next.element = new Identity()
}
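// propagate the fused conv's output scales back along the other input branch so
// the int8 scales stay consistent across the sum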
val prevIsNotIdentity = findPrevious(previousNodes(otherNumber))
prevIsNotIdentity.element match {
case conv: SpatialConvolution =>
conv.setOutputScales(node.element.asInstanceOf[SpatialConvolution].getOutputScales())
case relu: ReLU =>
relu.setOutputScales(node.element.asInstanceOf[SpatialConvolution].getOutputScales())
prevIsNotIdentity.nextNodes.flatMap(x => findNext(x))
.filter(x => x != node && x.element.isInstanceOf[MklInt8Convertible])
.foreach(_.element.asInstanceOf[MklInt8Convertible].setInputScales(
node.element.asInstanceOf[SpatialConvolution].getOutputScales()))
case _ =>
}
}
}
}
private def requirements(node: Node[AbstractModule[Activity, Activity, Float]]): Boolean = {
val conv = node.element.asInstanceOf[SpatialConvolution]
// a conv that already carries a sum fusion cannot take part in another one
!conv.sum
}
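/**
* Fold the batch norm into the conv weight and bias. With the per-channel affine
* transform bn(x) = alpha * (x - mean) / sqrt(var + eps) + beta, the fused conv uses
*   weight' = weight * alpha / sqrt(var + eps)
*   bias'   = alpha * (bias - mean) / sqrt(var + eps) + beta
* which is exactly what the per-output-channel loop below computes.
*/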
private def fusionConvBn(conv: SpatialConvolution,
bn: SpatialBatchNormalization): Unit = {
conv.setBatchNorm(true)
val originVar = Tensor[Float].resize(bn.runningVariance.size()).copy(bn.runningVariance.dense)
val originMean = Tensor[Float].resize(bn.runningMean.size()).copy(bn.runningMean.dense)
val convWeight = Tensor[Float].resize(conv.weight.size()).copy(conv.weight.dense)
val convBias = Tensor[Float].resize(conv.bias.size()).copy(conv.bias.dense)
val bnWeight = Tensor[Float].resizeAs(bn.weightAndBias.dense).copy(bn.weightAndBias.dense)
(0 until bn.nOutput).foreach { j =>
val variance = originVar.storage().array()(j + originVar.storageOffset() - 1)
val base = Math.sqrt(variance + bn.eps).toFloat
require(base != 0.0f, s"sqrt(var + eps) of ${bn.getName()} should be greater than 0")
val alpha = bnWeight.storage().array()(bnWeight.storageOffset() - 1 + j)
val beta = bnWeight.storage().array()(bnWeight.storageOffset() - 1 + bn.nOutput + j)
val weight = if (conv.nGroup == 1) {
convWeight.select(1, j + 1)
} else {
convWeight.select(2, j + 1)
}
weight.div(base)
weight.mul(alpha)
val bias = convBias.storage().array()(j)
val mean = originMean.storage().array()(j)
convBias.storage().array()(j) = alpha / base * bias + beta - (alpha * mean) / base
}
conv.weight.copy(convWeight)
conv.bias.copy(convBias)
// regenerate the weight scales and output scales
conv.flushWeightScales(conv.weight.dense)
conv.setOutputScales(bn.getOutputScales())
}
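/**
* Mark a conv's input as nonnegative when every real (non-pass-through) predecessor
* ends in a ReLU, or in a conv already fused with ReLU, so that negativeInput can
* safely be set to false.
*/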
def setNegativeInputOfConv(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
def findAllNonIdentityPrevs(node: Node[AbstractModule[Activity, Activity, Float]])
: Seq[Node[AbstractModule[Activity, Activity, Float]]] = {
// TODO currently it only skips Identity, MaxPooling and AvgPooling,
// because these three preserve nonnegativity: if the layer/op before them
// outputs nonnegative values, so do they. It's not an elegant impl.
if (node.element.isInstanceOf[Identity] ||
node.element.isInstanceOf[MaxPooling] ||
node.element.isInstanceOf[AvgPooling]) {
node.prevNodes.flatMap(findAllNonIdentityPrevs)
} else {
Seq(node)
}
}
if (!fuse || !node.element.isInstanceOf[SpatialConvolution]) return
val allInputsFromReLU = node.prevNodes.flatMap(x => findAllNonIdentityPrevs(x))
.map { x =>
x.element match {
case conv: SpatialConvolution => conv.relu
case _: ReLU => true
case _ => false
}
}.forall(identity)
if (allInputsFromReLU) {
node.element.asInstanceOf[SpatialConvolution].negativeInput = false
}
}
}