All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.MklInt8Convertible.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.Table

import scala.collection.mutable.ArrayBuffer


/**
 * Trait which provides MKL-DNN functionality to convert a module from FP32 to INT8.
 *
 * Modules mixing in this trait collect scales for their input, output and (where applicable)
 * weight. How many values make up one scale is controlled by the corresponding dimension mask:
 * mask 0 yields a single value (the max absolute value of the whole tensor), a mask with the
 * bits of all dimensions set yields one value per element (the absolute values), and any other
 * mask is delegated to `Utils.calcScales`.
 */
trait MklInt8Convertible {
  // input dimension mask
  protected var inputDimMask: Int = 0
  // output dimension mask
  protected var outputDimMask: Int = 0
  // weight dimension mask
  protected var weightDimMask: Int = 0
  // input activation scales
  private[nn] var inputScalesBuffer: ArrayBuffer[Array[Float]] = ArrayBuffer.empty[Array[Float]]
  // output scales
  private[nn] var outputScalesBuffer: ArrayBuffer[Array[Float]] = ArrayBuffer.empty[Array[Float]]
  // weight scales
  private[nn] var weightScalesBuffer: ArrayBuffer[Array[Float]] = ArrayBuffer.empty[Array[Float]]

  /**
   * Calculate the scales required for converting this module to int8 and append them to the
   * scales buffers. Nothing happens when `inputActvt` is null.
   *
   * The module is not forwarded here. Its current `output` is taken as the output activity,
   * so it must already hold the result of a previous forward pass.
   *
   * Supported modules (both the BLAS `nn` and the `mkldnn` variants unless noted):
   *  1) Graph (BLAS only): input and output, then every int8 convertible node in forward
   *     execution order
   *  2) Linear, SpatialConvolution: input, output and weight
   *  3) ReLU, CAddTable, SpatialBatchNormalization: input and output
   *  4) Sequential: input and output, then every int8 convertible submodule, the first one fed
   *     with the Sequential input and each following one with its predecessor's output
   *  5) ConcatTable: input and output, then every int8 convertible submodule, all fed with the
   *     same input
   *
   * @param inputActvt input activity
   * @throws UnsupportedOperationException if the module type is not supported
   */
  def calcScales(inputActvt: Activity): Unit = {
    if (inputActvt != null) {
      val module = this.asInstanceOf[AbstractModule[_, _, Float]]
      // Do not forward here: the given input may not be the real input. For example, an
      // in-place ReLU(true) has already overwritten its input with the result.
      val outputActvt = module.output.asInstanceOf[Activity]

      module match {
        case graph: Graph[Float@unchecked] => calcGraphScales(inputActvt, outputActvt)
        // handlers for BLAS modules
        case linear: Linear[Float@unchecked] =>
          calcModuleScales(inputActvt, outputActvt, getWeight(linear))
        case spatialConv: SpatialConvolution[Float@unchecked] =>
          calcModuleScales(inputActvt, outputActvt, getWeight(spatialConv))
        case relu: ReLU[Float@unchecked] =>
          calcModuleScales(inputActvt, outputActvt)
        case caddTable: CAddTable[Float@unchecked, Float@unchecked] =>
          calcModuleScales(inputActvt, outputActvt)
        case bn: SpatialBatchNormalization[Float@unchecked] =>
          calcModuleScales(inputActvt, outputActvt)
        case sequential: Sequential[Float@unchecked] =>
          calcSequentialScales(inputActvt, outputActvt)
        case concatTable: ConcatTable[Float@unchecked] =>
          calcConcatTableScales(inputActvt, outputActvt)
        // handlers for DNN modules
        case dnnLinear: mkldnn.Linear =>
          calcModuleScales(inputActvt, outputActvt, getWeight(dnnLinear))
        case dnnSpatialConv: mkldnn.SpatialConvolution =>
          calcModuleScales(inputActvt, outputActvt, getWeight(dnnSpatialConv))
        case dnnSequential: mkldnn.Sequential =>
          calcSequentialScales(inputActvt, outputActvt)
        case dnnConcatTable: mkldnn.ConcatTable =>
          calcConcatTableScales(inputActvt, outputActvt)
        case relu: mkldnn.ReLU =>
          calcModuleScales(inputActvt, outputActvt)
        case bn: mkldnn.SpatialBatchNormalization =>
          calcModuleScales(inputActvt, outputActvt)
        case caddTable: mkldnn.CAddTable =>
          calcModuleScales(inputActvt, outputActvt)
        case _ => throw new UnsupportedOperationException(
          "Int8 conversion is not supported for module: " + module.getName()
        )
      }
    }
  }

  /**
   * Replace all weight scales with the scales of the given weight, computed with the current
   * weight dimension mask.
   * @param weight weight tensor to compute the new scales from
   */
  private[bigdl] def flushWeightScales(weight: Tensor[Float]): Unit = {
    weightScalesBuffer.clear()
    appendWeightScales(calcTensorScale(weight, weightDimMask))
  }

  /**
   * Calculate module's scales given its input and output and append them to the input/output
   * scale buffers (one scale per tensor of the dense activity). Null activities are skipped.
   * @param inputActvt input activity
   * @param outputActvt output activity
   */
  private def calcModuleScales(inputActvt: Activity, outputActvt: Activity): Unit = {
    if (inputActvt != null) {
      val denseIn = mkldnn.Utils.getDenseIn(this, inputActvt)
      calcActivityScales(denseIn, inputDimMask).foreach(appendInputScales)
    }

    if (outputActvt != null) {
      val denseOut = mkldnn.Utils.getDenseOut(this, outputActvt)
      calcActivityScales(denseOut, outputDimMask).foreach(appendOutputScales)
    }
  }

  /**
   * Calculate module's scales given its input, output and weight
   * @param inActivity input activity
   * @param outActivity output activity
   * @param weightTensor weight
   */
  private def calcModuleScales(inActivity: Activity, outActivity: Activity,
                               weightTensor: Tensor[Float]): Unit = {
    // calculate scales for input and output
    calcModuleScales(inActivity, outActivity)
    // calculate scales for weight
    appendWeightScales(calcTensorScale(weightTensor, weightDimMask))
  }

  /**
   * Calculate scales given activity and mask
   * @param activity target activity to get scales; a Tensor or a Table whose values are Tensors
   * @param mask dimension mask associated with target activity
   * @return one scale array per tensor, in table index order for integer-keyed tables
   * @throws IllegalArgumentException if the activity is neither a Tensor nor a Table
   */
  private def calcActivityScales(activity: Activity, mask: Int): Array[Array[Float]] = {
    activity match {
      case tensor: Tensor[Float@unchecked] => Array(calcTensorScale(tensor, mask))
      case table: Table =>
        // NOTE(review): Table.map is not guaranteed to visit elements in index order
        // (presumably it is backed by a hash map). Sort integer keys so that the i-th scale
        // always belongs to the i-th table element; other key types keep iteration order.
        val entries = table.map[(Any, Any)](elem => elem).toSeq
        val ordered =
          if (entries.forall(_._1.isInstanceOf[Int])) entries.sortBy(_._1.asInstanceOf[Int])
          else entries
        ordered.map { case (_, value) =>
          calcTensorScale(value.asInstanceOf[Tensor[Float]], mask)
        }.toArray
      case _ => throw new IllegalArgumentException("Invalid activity " + activity)
    }
  }

  /** Given a tensor and a dimension mask, calculate the scales of this tensor
   * @param tensor tensor of float, stores high dimension data
   * @param mask dimension mask
   * @return scalesBuffer Array, an array stores scales
   */
  private def calcTensorScale(tensor: Tensor[Float], mask: Int): Array[Float] = {
    // abs() works in place, so always clone first to keep the original tensor intact
    if (mask == 0) {
      // no mask: a single scale, the max absolute value over the whole tensor
      Array(tensor.clone().abs().max())
    } else if (mask == (1 << tensor.dim()) - 1) {
      // mask bits are ON for all dimensions: return the absolute values as an array
      tensor.clone().abs().storage().toArray[Float]
    } else {
      Utils.calcScales(tensor, mask)
    }
  }

  /**
   * Scales calculator for Sequential Module
   * @param inputActvt input of the Sequential Module
   * @param outputActvt output of the Sequential Module
   */
  private def calcSequentialScales(inputActvt: Activity, outputActvt: Activity): Unit = {
    require(this.isInstanceOf[Sequential[Float@unchecked]] || this.isInstanceOf[mkldnn.Sequential],
      this.getClass.getName + " is not an instance of Sequential.")

    val container = this.asInstanceOf[DynamicContainer[_, _, Float]]

    // calc scales for the Sequential module itself
    calcModuleScales(inputActvt, outputActvt)

    // Submodules run one after another: each receives the output of its predecessor, and the
    // first one receives the input of the Sequential module.
    var prevOutputActvt: Activity = inputActvt
    container.modules.foreach { currModule =>
      currModule match {
        case convertible: MklInt8Convertible => convertible.calcScales(prevOutputActvt)
        case _ =>
      }
      prevOutputActvt = currModule.output
    }
  }

  /**
   * Scales calculator for ConcatTable module
   * Submodules inside ConcatTable share the same input
   * @param inputActvt input of the ConcatTable Module
   * @param outputActvt output of the ConcatTable Module
   */
  private def calcConcatTableScales(inputActvt: Activity, outputActvt: Activity): Unit = {
    require(this.isInstanceOf[ConcatTable[Float@unchecked]] || this.isInstanceOf[mkldnn.ConcatTable]
      , this.getClass.getName + " is not an instance of ConcatTable.")

    val container = this.asInstanceOf[DynamicContainer[_, _, Float]]

    // calc scales for the ConcatTable module itself
    calcModuleScales(inputActvt, outputActvt)

    // every submodule is fed with the same input as the ConcatTable
    container.modules.foreach {
      case convertible: MklInt8Convertible => convertible.calcScales(inputActvt)
      case _ =>
    }
  }

  /**
   * Scales calculator for Graph module
   * Submodules inside Graph are traversed in topological order, as returned by
   * getForwardExecutions
   * @param inputActvt input activity of the graph module
   * @param outputActvt output activity of the graph module
   */
  private def calcGraphScales(inputActvt: Activity, outputActvt: Activity): Unit = {
    require(this.isInstanceOf[Graph[Float@unchecked]], this.getClass.getName +
    " is not an instance of Graph[Float]")

    // calc scales for the Graph module itself
    calcModuleScales(inputActvt, outputActvt)

    // calc scales for every int8 convertible node, using the input that node sees in the graph
    val graph = this.asInstanceOf[Graph[Float]]
    graph.getForwardExecutions().foreach { node =>
      val nodeInputActvt = graph.findInput(node, inputActvt)
      node.element match {
        case convertible: MklInt8Convertible => convertible.calcScales(nodeInputActvt)
        case _ =>
      }
    }
  }

  /**
   * Helper function to get weight from module parameter
   * @param module the module to get weight from
   * @return a tensor contains weight, or null when module is null
   */
  private def getWeight(module: AbstractModule[_, _, Float]): Tensor[Float] = {
    if (module == null) {
      null
    } else {
      // parameters() is used instead of getParameters(), which would flatten weight and bias
      // into a single tensor
      val weight = module.parameters()._1(0)
      // nn.SpatialConvolution always uses a 5-D weight (group dimension first), even when
      // nGroup is 1, whereas mkldnn treats a single-group weight as 4-D. Drop the leading group
      // dimension in that case so both agree.
      module match {
        case _: SpatialConvolution[Float@unchecked] if weight.size(1) == 1 => weight.select(1, 1)
        case _ => weight
      }
    }
  }

  /**
   * Apply `action` to every direct submodule that is int8 convertible, when this module is a
   * container. Does nothing for non-container modules.
   * @param action callback invoked for each int8 convertible submodule
   */
  private def foreachConvertibleSubmodule(action: MklInt8Convertible => Unit): Unit = {
    this match {
      case container: Container[_, _, Float@unchecked] =>
        container.modules.foreach {
          case convertible: MklInt8Convertible => action(convertible)
          case _ =>
        }
      case _ =>
    }
  }

  /**
   * Get dimension mask of input
   * @return inputDimMask field which stores value of input dimension mask
   */
  def getInputDimMask(): Int = {
    inputDimMask
  }

  /**
   * Set dimension mask of input
   * @param mask value of input dimension mask to be set
   * @param overrideSubmodules when set it to true,
   *             update mask including itself and submodules (recursively),
   *             otherwise only update mask to module itself.
   * @return Unit
   */
  def setInputDimMask(mask: Int, overrideSubmodules: Boolean = false) : Unit = {
    inputDimMask = mask
    if (overrideSubmodules) {
      foreachConvertibleSubmodule(_.setInputDimMask(mask, overrideSubmodules))
    }
  }

  /**
   * Get dimension mask of output
   * @return outputDimMask field which stores value of output dimension mask
   */
  private[bigdl] def getOutputDimMask(): Int = {
    outputDimMask
  }

  /**
   * Set dimension mask of output
   * @param mask value of output dimension mask to be set
   * @param overrideSubmodules when set it to true,
   *             update mask in full scope including itself and submodules (recursively),
   *             otherwise only update mask to module itself.
   * @return Unit
   */
  def setOutputDimMask(mask: Int, overrideSubmodules: Boolean = false): Unit = {
    outputDimMask = mask
    if (overrideSubmodules) {
      foreachConvertibleSubmodule(_.setOutputDimMask(mask, overrideSubmodules))
    }
  }

  /**
   * Get dimension mask of weight
   * @return weightDimMask which stores value of weight mask
   */
  def getWeightDimMask(): Int = {
    weightDimMask
  }

  /**
   * Set dimension mask for weight
   * @param mask value of weight mask to be set
   * @param overrideSubmodules when set it to true,
   *             update mask in full scope including itself and submodules (recursively),
   *             otherwise only update mask to module itself.
   * @return Unit
   */
  def setWeightDimMask(mask: Int, overrideSubmodules: Boolean = false): Unit = {
    weightDimMask = mask
    if (overrideSubmodules) {
      foreachConvertibleSubmodule(_.setWeightDimMask(mask, overrideSubmodules))
    }
  }

  /**
   * Get input scales
   * @return a new array holding the input scales (the inner arrays are shared with the buffer)
   */
  def getInputScales(): Array[Array[Float]] = {
    inputScalesBuffer.toArray
  }

  /**
   * Set input scales
   * Clear existing buffer of input scales, and place updated scales into the cleared buffer
   * @param inScales value of input scales to be set
   * @return Unit
   */
  def setInputScales(inScales: Array[Array[Float]]): Unit = {
    inputScalesBuffer.clear()
    inScales.foreach(appendInputScales)
  }

  /**
   * Get output scales
   * @return a new array holding the output scales (the inner arrays are shared with the buffer)
   */
  def getOutputScales(): Array[Array[Float]] = {
    outputScalesBuffer.toArray
  }

  /**
   * Set output scales
   * Clear existing buffer of output scales, and place updated scales into the cleared buffer
   * @param outScales value of output scales to be set
   * @return Unit
   */
  def setOutputScales(outScales: Array[Array[Float]]): Unit = {
    outputScalesBuffer.clear()
    outScales.foreach(appendOutputScales)
  }

  /**
   * Get weight scales
   * @return a new array holding the weight scales (the inner arrays are shared with the buffer)
   */
  def getWeightScales(): Array[Array[Float]] = {
    weightScalesBuffer.toArray
  }

  /**
   * Set weight scales
   * Clear existing buffer of weight scales, and place updated scales into the cleared buffer
   * @param weightScales value of weight scales to be set
   * @return Unit
   */
  def setWeightScales(weightScales: Array[Array[Float]]): Unit = {
    weightScalesBuffer.clear()
    weightScales.foreach(appendWeightScales)
  }

  /**
   * Append a scale, an array of float, into input scales buffer
   * @param scale value of an input scale to be appended
   * @return Unit
   */
  private def appendInputScales(scale: Array[Float]): Unit = {
    inputScalesBuffer.append(scale)
  }

  /**
   * Append a scale, an array of float, into output scales buffer
   * @param scale value of an output scale to be appended
   * @return Unit
   */
  private def appendOutputScales(scale: Array[Float]): Unit = {
    outputScalesBuffer.append(scale)
  }

  /**
   * Append a scale, an array of float, into weight scales buffer
   * @param scale value of a weight scale to be appended
   * @return Unit
   */
  private def appendWeightScales(scale: Array[Float]): Unit = {
    weightScalesBuffer.append(scale)
  }

  /**
   * Update input scales at specific index with provided new scale
   * @param scale the new scale
   * @param index the index of which the scale need to be updated
   * @return Unit
   */
  private def updateInputScales(scale: Array[Float], index: Int): Unit = {
    updateScalesHelper(inputScalesBuffer, scale, index)
  }

  /**
   * Update output scales at specific index with provided new scale
   * @param scale the new scale
   * @param index the index of which the scale need to be updated
   * @return Unit
   */
  private def updateOutputScales(scale: Array[Float], index: Int): Unit = {
    updateScalesHelper(outputScalesBuffer, scale, index)
  }

  /**
   * Update weight scales at specific index with provided new scale
   * @param scale the new scale
   * @param index the index of which the scale need to be updated
   * @return Unit
   */
  private def updateWeightScales(scale: Array[Float], index: Int): Unit = {
    updateScalesHelper(weightScalesBuffer, scale, index)
  }

  /**
   * Scales update helper. Merge the provided new scale into the scale at a specific index by
   * keeping the element-wise maximum. When index equals the buffer length, a copy of the new
   * scale is appended instead.
   * @param scales the scales arrayBuffer to be updated
   * @param scale the new scale
   * @param index the index of which the scale need to be updated, in [0, scales.length]
   * @return Unit
   * @throws IllegalArgumentException if index is out of range or the scale lengths differ
   */
  private def updateScalesHelper(scales: ArrayBuffer[Array[Float]],
                                 scale: Array[Float], index: Int): Unit = {
    require(index >= 0 && index <= scales.length,
      s"scale index $index is out of range [0, ${scales.length}]")

    if (index == scales.length) {
      // store a copy so later in-place max updates never modify the caller's array
      scales.append(scale.clone())
    } else {
      val current = scales(index)
      require(scale.length == current.length,
        s"scale length ${scale.length} does not match existing length ${current.length}")
      current.indices.foreach { i =>
        if (scale(i) > current(i)) {
          current(i) = scale(i)
        }
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy