All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.nn.StaticGraph.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.mkl.Memory
import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.tf.ControlDependency
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.intermediate.{BlasToIR, IRGraph}
import com.intel.analytics.bigdl.utils.{Node, Util}
import com.intel.analytics.bigdl.optim.DistriOptimizer._

import scala.reflect.ClassTag

/**
 * A graph container. The modules in the container are connected as a DAG graph.
 *
 * The forward and backward execution orders are computed once by topological sort and
 * replayed by index on every call. Per-node inputs and gradOutputs are cached in arrays
 * aligned with those orders, so the backward pass can reuse them.
 *
 * @param _inputs inputs modules, user can feed data into these modules in the forward method
 * @param _outputs output modules
 * @param _variables optional pair of tensor arrays passed unchanged to [[Graph]]
 *                   (presumably (weights, gradWeights) -- confirm against Graph)
 * @param enableExcludeChecking if true, `excludeInvalidLayers` is run over the modules in
 *                              forward-execution order when the graph is constructed
 * @tparam T Numeric type. Only support float/double now
 */
class StaticGraph[T: ClassTag](
  private val _inputs : Seq[ModuleNode[T]],
  private val _outputs : Seq[ModuleNode[T]],
  private val _variables: Option[(Array[Tensor[T]], Array[Tensor[T]])] = None,
  private val enableExcludeChecking: Boolean = true
)(implicit ev: TensorNumeric[T]) extends Graph[T](_inputs, _outputs, _variables) {
  // Node order used by updateOutput; fixed once at construction.
  private val forwardExecution = forwardGraph.topologySort.reverse
  private var backwardExecution: Array[Node[AbstractModule[Activity, Activity, T]]] = _
  // inputCache(i) is the input fed to forwardExecution(i) in the most recent forward pass.
  private val inputCache = new Array[Activity](forwardExecution.length)
  // backId2ForwardId(i) is the forwardExecution index of backwardExecution(i).
  private var backId2ForwardId: Array[Int] = _
  // gradOutputCache(i) is the gradOutput given to backwardExecution(i) in the last backward.
  private var gradOutputCache: Array[Activity] = _

  if (enableExcludeChecking) {
    excludeInvalidLayers(forwardExecution.map {_.element})
  }

  buildBackwardGraph()

  /**
   * Runs every node in forward-execution order, caching each node's input for the
   * backward pass, and returns the output of the dummy output node.
   */
  override def updateOutput(input: Activity): Activity = {
    var i = 0
    while(i < forwardExecution.length) {
      val node = forwardExecution(i)
      val nodeInput = findInput(node, input)
      inputCache(i) = nodeInput
      node.element.forward(nodeInput)
      i += 1
    }

    output = dummyOutput.element.output
    output
  }

  /** Full backward pass (gradInput and parameter gradients); time is added to backwardTime. */
  override def backward(input: Activity, gradOutput: Activity): Activity = {
    val before = System.nanoTime()
    val gradients = runBackward(input, gradOutput, true)
    backwardTime += System.nanoTime() - before
    gradients
  }

  /** Computes gradInput only; parameter gradients are not accumulated. */
  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
    runBackward(input, gradOutput, false)
  }


  /**
   * Rebuilds the backward graph and its execution order, then maps each backward node
   * to the forward node with the same name so cached forward inputs can be reused.
   *
   * @throws IllegalArgumentException if a backward layer has no same-named forward layer
   */
  override def buildBackwardGraph(): this.type = {
    super.buildBackwardGraph()
    backwardExecution = backwardGraph.topologySort.reverse
    backId2ForwardId = new Array[Int](backwardExecution.length)
    gradOutputCache = new Array[Activity](backwardExecution.length)

    // Name -> forward index. On duplicate names the later index wins, which matches
    // the previous linear scan that kept the last match.
    val forwardIdByName: Map[String, Int] = forwardExecution.iterator.zipWithIndex
      .map { case (node, j) => node.element.getName() -> j }
      .toMap

    // The last backward node is the dummy backward end and has no forward counterpart.
    var i = 0
    while (i < backwardExecution.length - 1) {
      val forwardId = forwardIdByName.getOrElse(backwardExecution(i).element.getName(), -1)
      require(forwardId >= 0, "Cannot find backward layer in forward executions")
      backId2ForwardId(i) = forwardId
      i += 1
    }

    this
  }

  /**
   * Accumulates parameter gradients using the inputs and gradOutputs cached by the most
   * recent forward and backward/updateGradInput passes (gradOutputCache is only filled
   * by runBackward).
   */
  override def accGradParameters(input: Activity, gradOutput: Activity): Unit = {
    var i = 0
    while (i < backwardExecution.length - 1) {
      val curNode = backwardExecution(i)
      val curInput = inputCache(backId2ForwardId(i))
      curNode.element.accGradParameters(curInput, gradOutputCache(i))
      i += 1
    }
  }

  /** Registers the graph's modules, excluding ControlDependency nodes and the dummy output. */
  override def populateModules(): Unit = {
    modules.appendAll(
      forwardGraph.topologySort
        // todo: convert control dep node to edge
        .filterNot(_.element.isInstanceOf[ControlDependency[T]])
        .filter(n => !n.eq(dummyOutput)).map(_.element)
        .reverse
    )
    checkDuplicate()
  }


  /**
   * Propagates gradOutput through the backward execution order, skipping the trailing
   * dummy node.
   *
   * Stop-gradient nodes do not compute gradInput. When executeBackward is true they still
   * get accGradParameters.
   *
   * @param executeBackward true to call backward (gradInput + parameter gradients),
   *                        false to call updateGradInput only
   * @return the graph's gradInput
   */
  private def runBackward(input: Activity, gradOutput: Activity,
    executeBackward: Boolean): Activity = {
    dummyOutputGrad.element.gradInput = gradOutput

    var i = 0
    while (i < backwardExecution.length - 1) {  // do not execute the dummy backward end
      val curNode = backwardExecution(i)
      val curGradOutput = findGradOutput(curNode, gradOutput)
      gradOutputCache(i) = curGradOutput
      val curInput = inputCache(backId2ForwardId(i))
      if (!isStopGradient(curNode.element)) {
        if (executeBackward) {
          curNode.element.backward(curInput, curGradOutput)
        } else {
          curNode.element.updateGradInput(curInput, curGradOutput)
        }
      } else if (executeBackward) {
        curNode.element.accGradParameters(curInput, curGradOutput)
      }
      i += 1
    }

    gradInput = fetchModelGradInput()
    gradInput
  }

  /**
   * Converts this static graph to an IR graph and builds it.
   *
   * If input/output formats were not set, a warning is logged and NCHW (inputs) and
   * NC (outputs) are used.
   *
   * @return the built IR graph, or null if BlasToIR reports the graph cannot be converted
   */
  def toIRgraph() : IRGraph[T] = {
    val inFormats = if (inputsFormats == null) {
      logger.warn("Input formats NCHW by default, Please set explicitly if needed")
      Seq(Memory.Format.nchw)
    } else inputsFormats

    val outFormats = if (outputsFormats == null) {
      logger.warn("Output formats NC by default, Please set explicitly if needed")
      Seq(Memory.Format.nc)
    } else outputsFormats

    val allNodes = forwardExecution
    if (!BlasToIR[T].convertingCheck(allNodes)) {
      null
    } else {
      val nodeMap = BlasToIR[T].convert(allNodes)
      val inputsIR = inputs.toArray.map(n => nodeMap.get(n).get)
      val outputsIR = outputs.toArray.map(n => nodeMap.get(n).get)

      val model = IRGraph(inputsIR, outputsIR, variables, true, inFormats, outFormats)
      model.build()
    }
  }

  /**
   * Merges nested StaticGraphs into a single non-nested StaticGraph.
   *
   * Works on a clone. Each nested-graph node is replaced by the (recursively flattened)
   * subgraph's own nodes. Returns `this` unchanged when nothing is nested.
   */
  private[bigdl] def toSingleGraph(): StaticGraph[T] = {
    if (this.isNestedGraph()) {
      val graph = this.cloneModule()
      val fwdExecution = graph.getSortedForwardExecutions()
      val dmOutput = fwdExecution(fwdExecution.length - 1).nextNodes(0)

      var i = 0
      while (i < fwdExecution.length) {
        val node = fwdExecution(i)
        if (node.element.isInstanceOf[StaticGraph[T]]) {
          val g = node.element.asInstanceOf[StaticGraph[T]].toSingleGraph()
          node.element = g

          // The k-th predecessor feeds the k-th input of the subgraph. Snapshot the list
          // first because delete/add mutate it. A predecessor wired twice gets both edges.
          // NOTE(review): replacing the input module with Identity assumes subgraph
          // inputs are pass-through Input() nodes -- confirm.
          val preNodes = node.prevNodes.toArray
          preNodes.distinct.foreach(_.delete(node))
          var k = 0
          while (k < preNodes.length) {
            val inputNode = g.inputs(k)
            inputNode.element = Identity()
            preNodes(k).add(inputNode)
            k += 1
          }

          // NOTE(review): all successors are attached to g.outputs(0) with fresh default
          // edges. Successors that selected another output of a multi-output subgraph are
          // not remapped.
          for (outputIndex <- 0 until g.outputs.length) {
            val outputNode = g.outputs(outputIndex)
            outputNode.removeNextEdges()
            while (node.nextNodes.length != 0) {
              val nextNode = node.nextNodes(0)
              node.delete(nextNode)
              outputNode.add(nextNode)
            }
          }
        }
        i += 1
      }

      // prevNodes returns a new sequence, so the deletes below do not affect it.
      val resultOutputNodes = dmOutput.prevNodes
      resultOutputNodes.foreach(_.delete(dmOutput))
      // Keep every input of the outer graph, not only the first one.
      new StaticGraph[T](graph.inputs, resultOutputNodes,
        enableExcludeChecking = this.enableExcludeChecking)
    } else {
      this
    }
  }

  /** True if any node in forward-execution order wraps another StaticGraph. */
  private def isNestedGraph(): Boolean =
    forwardExecution.exists(_.element.isInstanceOf[StaticGraph[T]])
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy