All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.nn.Graph.scala Maven / Gradle / Ivy

There is a newer version: 0.11.1
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.nn

import java.util

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.tf._
import com.intel.analytics.bigdl.serialization.Bigdl.{AttrValue, BigDLModule}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils._
import com.intel.analytics.bigdl.utils.serializer._
import com.intel.analytics.bigdl.utils.serializer.converters.DataConverter
import com.intel.analytics.bigdl.utils.tf.Tensorflow
import com.intel.analytics.bigdl.visualization.tensorboard.{FileWriter => TFFileWriter}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.reflect.runtime.universe
import scala.language.existentials
import scala.collection.JavaConverters._
import org.tensorflow.framework.GraphDef

/**
 * A graph container. The modules in the container are connected as a directed graph. Each module
 * can output one tensor or multiple tensors (as a table). The edges between modules in the graph
 * define how these tensors are passed. For example, if a module outputs two tensors, you can
 * pass these two tensors together to its following module, or pass only one of them
 * to its following module. If a tensor in the module output is connected to multiple modules, in
 * the back propagation, the gradients from the multiple connections will be accumulated. If
 * multiple edges point to one module, the tensors from these edges will be stacked as a table,
 * then passed to that module. In the back propagation, the gradients will be split based on how
 * the input tensors were stacked.
 *
 * The graph container has multiple inputs and multiple outputs. The order of the input tensors
 * should be the same as the order of the input nodes when you construct the graph container. In
 * the back propagation, the order of the gradient tensors should be the same as the order of the
 * output nodes.
 *
 * If there's one output, the module output is that node's output. If there are multiple outputs,
 * the module output is a table, which is actually a sequence of activities. The order of the
 * output activities is the same as the order of the output modules.
 *
 * All inputs should be able to connect to outputs through some paths in the graph. It is
 * allowed that some successors of the input nodes are not connected to outputs. If so, these
 * nodes will be excluded from the computation.
 *
 * @param inputs input nodes
 * @param outputs output nodes
 * @param variables optional (weights, gradients) pair returned by `parameters()` instead of the
 *                  parameters collected from the sub-modules; used when different nodes of this
 *                  graph share the same weight or bias.
 * @tparam T Numeric type. Only Float and Double are supported now
 */
@SerialVersionUID(- 2896121321564992779L)
abstract class Graph[T: ClassTag](
  val inputs : Seq[ModuleNode[T]],
  private[bigdl] val outputs : Seq[ModuleNode[T]],
  private[bigdl] val variables: Option[(Array[Tensor[T]], Array[Tensor[T]])] = None
)(implicit ev: TensorNumeric[T])
  extends Container[Activity, Activity, T] with MklInt8Convertible {

  /**
   * For a multi-tensor output module, some output tensors may not contribute to the final forward
   * result, so in back propagation the gradients at those positions are missing. Each missing
   * position of `gradOutput` is filled with a zero tensor of the same size as the corresponding
   * output tensor.
   *
   * @param output the module output table; every entry is read as a tensor
   * @param gradOutput the gradient table, patched in place
   */
  protected def addZeroTensorToMissingGradOutput(output: Table, gradOutput: Table): Unit = {
    var i = 0
    while (i < output.length()) {
      if (!gradOutput.contains(i + 1)) {
        val tensor = output[Tensor[T]](i + 1)
        val zero = Tensor(tensor.size())
        gradOutput(i + 1) = zero
      }
      i = i + 1
    }
  }

  /**
   * Sum the forward and backward times of the given (module, forwardTime, backwardTime) entries.
   *
   * @return (total forward time, total backward time)
   */
  private def calcSumTimesOfAllNodes(
    timesOfAllNodes: Array[(AbstractModule[_ <: Activity, _ <: Activity, T], Long, Long)])
  : (Long, Long) = {
    timesOfAllNodes.foldLeft((0L, 0L)) { case ((forward, backward), (_, f, b)) =>
      (forward + f, backward + b)
    }
  }

  /**
   * Times of every sub-module, followed by an entry for this graph that holds only the time
   * spent outside the sub-modules (graph time minus the sum of the sub-module times).
   */
  override def getTimes():
  Array[(AbstractModule[_ <: Activity, _ <: Activity, T], Long, Long)] = {
    val timesOfAllNodes = this.modules.flatMap(_.getTimes()).toArray
    val (sumForward, sumBackward) = calcSumTimesOfAllNodes(timesOfAllNodes)
    timesOfAllNodes ++ Array((this, this.forwardTime - sumForward, this.backwardTime - sumBackward))
  }

  /**
   * The shared `variables` when they were supplied, otherwise the parameters collected by
   * the container.
   */
  override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
    variables match {
      case None => super.parameters()
      case Some((weights, gradients)) => (weights, gradients)
    }
  }

  // todo: expand the graph
  override def toGraph(startNodes: ModuleNode[T]*): Graph[T] = this

  /**
   * Return the node whose module has the given name.
   *
   * @param name module name to look up
   * @return the first matching node
   * @throws NoSuchElementException if no node has that name
   */
  def node(name: String): ModuleNode[T] = {
    forwardNodes.find(_.element.getName() == name).getOrElse(
      throw new NoSuchElementException(s"Can not find node with name $name"))
  }

  // Add a dummy output node to get a forward graph with a single end. Nodes that the outputs
  // do not depend on are then excluded.
  protected val dummyOutput = new ModuleNode[T](new Identity[T]())
  outputs.foreach(_ -> dummyOutput)
  protected val forwardGraph = dummyOutput.graph(reverse = true)
  protected val forwardNodes = forwardGraph.DFS.toArray

  populateModules()

  // Check that all inputs of the graph are passed in
  checkRoots

  protected def populateModules(): Unit

  /**
   * Validate the forward graph.
   *
   * It checks that node names are unique and that the roots are exactly the given inputs. Roots
   * are nodes without predecessors, ignoring `WithoutInput` and `ControlDependency` modules.
   */
  private def checkRoots: Unit = {
    // Names occurring more than once; only used to build the error message below.
    def duplicatedNames(names: Seq[String]): mutable.Set[String] = {
      // Sort so that equal names become adjacent before the pairwise scan.
      val sorted = names.sortWith(_ < _)
      val buffer = new mutable.HashSet[String]()
      var i = 1
      while (i < sorted.length) {
        if (sorted(i) == sorted(i - 1)) buffer.add(sorted(i))
        i += 1
      }
      buffer
    }

    require(forwardNodes.map(_.element.getName()).distinct.length == forwardNodes.length,
      s"the name of node in the graph should be unique, but find duplicated name " +
        s"${duplicatedNames(forwardNodes.map(_.element.getName())).mkString(", ")}")

    val roots = forwardNodes.filter(_.prevNodes.size == 0)
      .filterNot(_.element.isInstanceOf[WithoutInput])
      .filterNot(_.element.isInstanceOf[ControlDependency[_]])

    val realInputs = inputs.filterNot(_.element.isInstanceOf[WithoutInput])
    require(roots.size == realInputs.length, s"There're ${realInputs.length} inputs, " +
      s"but graph has ${roots.size} roots")

    realInputs.foreach(n =>
      require(roots.contains(n), "inputs and graph roots are not match")
    )
  }

  protected var dummyOutputGrad: ModuleNode[T] = _
  protected var backwardGraph: DirectedGraph[AbstractModule[Activity, Activity, T]] = _
  protected var backwardNodes: Array[Node[AbstractModule[Activity, Activity, T]]] = _

  // For each graph input (same order as `inputs`): whether the graph generates a gradInput for it
  private var isGradInputAvailable: Array[Boolean] = _

  /**
   * Generate the backward graph and apply the stop-gradient settings.
   *
   * The forward graph is cloned with reversed edges. Sub-graphs cut off by stop-gradient nodes
   * are removed. The backward graph is then restricted to nodes that lead to a module with
   * parameters or to a graph input.
   *
   * @return this graph
   */
  private[bigdl] def buildBackwardGraph(): this.type = {
    // Clone the forward graph and reverse the edges
    val gradGraph = forwardGraph.cloneGraph(reverseEdge = true)
    dummyOutputGrad = gradGraph.source
    gradGraph.DFS.filter(x => isStopGradient(x.element)).foreach(removeStopNodes(_))
    backwardNodes = gradGraph.DFS
      .filterNot(_.eq(dummyOutputGrad))
      .filterNot(_.element.isInstanceOf[ControlDependency[_]]).toArray

    val inputNames = inputs.map(_.element.getName()).toSet
    val dummyBackwardEnd = Identity().inputs()
    // Backward only needs to reach modules with parameters and the graph inputs
    val backwardTargets = backwardNodes.filter { n =>
      val params = n.element.parameters()
      (params != null && params._1.length != 0) || inputNames.contains(n.element.getName())
    }
    backwardTargets.foreach(_ -> dummyBackwardEnd)
    backwardGraph = dummyBackwardEnd.graph(true)

    // Check whether a gradInput is produced for each input
    isGradInputAvailable = Array.fill(inputs.length)(false)
    backwardGraph.DFS.foreach { curNode =>
      inputs.zipWithIndex.foreach { case (n, i) =>
        if (curNode.element.getName() == n.element.getName() && !isStopGradient(n.element)) {
          isGradInputAvailable(i) = true
        }
      }
    }

    clearState()
    this
  }

  private var stopGradientLayers: util.HashSet[String] = _

  /** Names of the layers whose input gradient is stopped; null if `stopGradient` was never called */
  def getStopGradientLayers(): util.HashSet[String] = stopGradientLayers

  /**
   * Whether propagation of gradInput back through `module` is stopped.
   *
   * @return true if the module's name was registered through `stopGradient`
   */
  protected def isStopGradient(module: AbstractModule[_ <: Activity, _ <: Activity, T]): Boolean = {
    null != stopGradientLayers && stopGradientLayers.contains(module.getName())
  }

  /**
   * Stop the input gradient of layers that match the given `names`: their input gradient is not
   * computed, and they do not contribute to the input gradient computation of the layers that
   * depend on them. The backward graph is rebuilt afterwards.
   *
   * @param names an array of layer names
   * @return current graph model
   */
  def stopGradient(names: Array[String]): this.type = {
    if (stopGradientLayers == null) stopGradientLayers = new util.HashSet[String]()

    names.foreach(name => {
      val layer = this (name)
      require(layer.isDefined, s"cannot find layer match ${name}")
      stopGradientLayers.add(layer.get.getName())
    })
    buildBackwardGraph()
    this
  }

  /**
   * Set the layers that match the given `names` to be "frozen", i.e. their parameters
   * (weight/bias, if they exist) are not changed in the training process. This is done by
   * setting their weight and bias scales to 0.
   *
   * @param names an array of layer names
   * @return current graph model
   */
  def freeze(names: Array[String]): this.type = {
    names.foreach(name => {
      val layer = this (name)
      require(layer.isDefined, s"cannot find layer match ${name}")
      layer.get.setScaleW(0)
      layer.get.setScaleB(0)
    })
    this
  }

  /**
   * Detach `n` from its successors.
   *
   * Any successor left without predecessors is then removed the same way, recursively.
   */
  private[bigdl] def removeStopNodes(n: Node[_]): Unit = {
    val nodes = n.nextNodes
    n.removeNextEdges()
    nodes.filter(_.prevNodes.length == 0).foreach(removeStopNodes(_))
  }

  /**
   * Return the part of the graph input that feeds the root node `node`.
   *
   * With a single graph input, the whole `input` is returned. Otherwise `input` must be a table,
   * and its entry at the node's 1-based position in `inputs` is returned.
   * That entry may itself be a tensor or a table.
   */
  protected def getInput(
    node: Node[AbstractModule[Activity, Activity, T]],
    input: Activity
  ): Activity = {
    if (inputs.length == 1) {
      require(inputs(0).eq(node), "input node is not in the input list")
      input
    } else {
      val i = inputs.indexOf(node)
      require(i != -1, "input node is not in the input list")
      input.toTable[Activity](i + 1)
    }
  }

  /**
   * Build the forward input of `node`.
   *
   * - Modules that take no input (`WithoutInput`) get null.
   * - Root nodes read their part of the graph `input`.
   * - Other nodes collect their predecessors' outputs (skipping control dependencies). When an
   *   edge has a `fromIndex`, it selects that entry of the predecessor's output table; for a
   *   tensor output, index 1 means the tensor itself. A single collected activity is passed as
   *   is; several are packed into a table.
   */
  def findInput(node: ModuleNode[T], input: Activity): Activity = {
    if (node.element.isInstanceOf[WithoutInput]) {
      null
    } else if (node.prevNodes.isEmpty) {
      getInput(node, input)
    } else {
      val prevActivities = node.prevNodesAndEdges
        .filterNot(n => n._1.element.isInstanceOf[ControlDependency[T]])
        .map { case (prevNode, edge) =>
          val prevOutput = prevNode.element.output
          edge.fromIndex match {
            case Some(i) =>
              if (prevOutput == null || (i == 1 && prevOutput.isTensor)) {
                prevOutput
              } else {
                prevOutput.toTable.apply[Activity](i)
              }
            case None => prevOutput
          }
        }
      if (prevActivities.length == 1) {
        prevActivities.head
      } else {
        T.seq(prevActivities)
      }
    }
  }

  /**
   * Compute the gradOutput of `curNode` while walking the backward graph.
   *
   * The backward graph has reversed edges, so the predecessors here are the node's forward
   * successors, and their gradInputs are accumulated.
   * - If a predecessor's gradInput is a table and the predecessor has several outgoing edges,
   *   the table entry matching this edge is used.
   * - An edge `fromIndex` routes the gradient into that slot of a table gradOutput.
   * - For a table output, missing slots are filled with zero tensors.
   * The dummy backward source starts from the external `gradOutput`.
   */
  protected def findGradOutput(curNode: ModuleNode[T], gradOutput: Activity): Activity = {
    var curGradOutput : Activity = if (curNode.eq(dummyOutputGrad)) gradOutput else null

    curNode.prevNodesAndEdges.filterNot(n => n._1.element.isInstanceOf[ControlDependency[T]])
      .foreach(n => {
        val otherActivity = if (n._1.element.gradInput.isTensor || n._1.nextEdges.length == 1) {
          n._1.element.gradInput
        } else {
          val index = n._1.nextEdges.indexOf(n._2) + 1
          n._1.element.gradInput.toTable.apply[Activity](index)
        }

        n._2.fromIndex match {
          case Some(i) =>
            if (i == 1 && curNode.element.output.isTensor) {
              curGradOutput = accActivity(curGradOutput, otherActivity)
            } else {
              if (curNode.element.output.isTable && curGradOutput == null) {
                curGradOutput = T()
              }
              val curActivity = curGradOutput.toTable.getOrElse[Activity](i, null)
              curGradOutput.toTable(i) = accActivity(curActivity, otherActivity)
            }
          case None =>
            curGradOutput = accActivity(curGradOutput, otherActivity)
        }
      })

    if (curNode.element.output.isTable) {
      addZeroTensorToMissingGradOutput(curNode.element.output.toTable, curGradOutput.toTable)
    }

    curGradOutput
  }

  /**
   * Assemble this graph's gradInput.
   *
   * For each input node, this is its gradInput when a gradient reaches it; otherwise it is the
   * value of `Activity.emptyGradInput`. A single input yields that activity directly. Several
   * inputs are packed into a table in input order.
   */
  protected def fetchModelGradInput(): Activity = {
    if (inputs.length == 1) {
      if (isGradInputAvailable.head) {
        inputs.head.element.gradInput
      } else {
        Activity.emptyGradInput(this.getName())
      }
    } else {
      T.seq(inputs.zipWithIndex.map { case (n, i) =>
        if (isGradInputAvailable(i)) {
          n.element.gradInput
        } else {
          Activity.emptyGradInput(this.getName())
        }
      })
    }
  }

  /** Clear the stop-gradient settings, unfreeze all layers and rebuild the backward graph. */
  override def reset(): Unit = {
    if (null != stopGradientLayers) stopGradientLayers.clear()
    unFreeze()
    buildBackwardGraph()
  }

  /**
   * Get forward executions, the dummy node will be filtered.
   *
   * This method will output unsorted executions.
   * @return
   */
  def getForwardExecutions(): Array[Node[AbstractModule[Activity, Activity, T]]] = {
    forwardNodes.filterNot(_.eq(dummyOutput))
  }

  /**
   * Get forward executions, the dummy nodes and control dependency nodes will be filtered.
   *
   * This method will output sorted executions. If the graph contains a loop, it will throw an
   * exception.
   * @return
   */
  def getSortedForwardExecutions(): Array[ModuleNode[T]] = {
    forwardGraph.topologySort
      // todo: convert control dep node to edge
      .filterNot(_.element.isInstanceOf[ControlDependency[T]]).reverse
      .filter(n => !n.eq(dummyOutput))
  }

  /**
   * Accumulate `other` into `activity`, mutating `activity` where possible.
   *
   * - If `activity` is null, `other` is returned.
   * - Tensors are summed in place into `activity`.
   * - Tables are merged recursively key by key.
   *
   * @return the accumulated activity
   */
  @inline
  protected def accActivity(activity: Activity, other: Activity): Activity = {
    if (activity == null) {
      other
    } else {
      if (other.isTensor) {
        require(activity.isTensor, "Cannot add a table to a tensor")
        activity.toTensor[T].add(other.toTensor[T])
      } else {
        // If 'activity' and 'other' are both tables, merge 'other' into 'activity':
        // when both contain an index, update 'activity' by summing;
        // when only 'other' contains it, insert the corresponding entry of 'other' into 'activity'.
        // NOTE(review): confirm Table.insert does not shift existing higher keys of a sparse
        // table; a plain update may be what is intended here.
        val actTable = activity.toTable
        val otherTable = other.toTable
        otherTable.keySet.foreach(index => {
          if (actTable.contains(index)) {
            accActivity(actTable[Activity](index), otherTable[Activity](index))
          } else {
            actTable.insert(index.asInstanceOf[Int], otherTable(index))
          }
        })
        actTable
      }
    }
  }

  /**
   * Save the current model graph to a folder, which can be displayed in TensorBoard by running
   *   tensorboard --logdir logPath
   * The writer is closed even if writing the graph fails.
   *
   * @param logPath output folder
   * @param backward Draw backward graph instead of forward
   * @return this graph
   */
  def saveGraphTopology(logPath: String, backward: Boolean = false): this.type = {
    val writer = new TFFileWriter(logPath)
    try {
      val graphBuilder = GraphDef.newBuilder()
      val nodes = if (backward) {
        backwardNodes.filter(n => !n.eq(dummyOutputGrad))
      } else {
        forwardNodes.filter(n => !n.eq(dummyOutput))
      }
      nodes.foreach { m =>
        val nodeDef = Tensorflow.bigdlModule(m.element, m.prevNodes.map(_.element.getName()).asJava)
        graphBuilder.addNode(nodeDef)
      }
      writer.addGraphDef(graphBuilder.build())
    } finally {
      writer.close()
    }
    this
  }

  /**
   * Clear the original modules and reset them with the modules in the graph.
   *
   * Control dependency nodes and the dummy output are excluded.
   */
  def resetModules(): Unit = {
    modules.clear()
    modules.appendAll(forwardGraph.DFS.toArray
      .filterNot(_.element.isInstanceOf[ControlDependency[T]])
      .filter(n => !n.eq(dummyOutput)).map(_.element)
      // Some tests compare the parameters between sequential and graph; adding a reverse makes
      // them easier to compare
      .reverse
    )
  }
}

object Graph extends GraphSerializable {
  /**
   * Node type used by the graph container: a node wrapping a module whose input and output are
   * both activities (a tensor or a table).
   * @tparam T numeric type of the wrapped module
   */
  type ModuleNode[T] = Node[AbstractModule[Activity, Activity, T]]

  /**
   * Build a static graph container with multiple inputs and multiple outputs.
   * @param input input nodes
   * @param output output nodes
   * @param variables optional shared (weights, gradients) pair forwarded to the graph
   * @return a graph container
   */
  def apply[T: ClassTag](
    input: Array[ModuleNode[T]],
    output: Array[ModuleNode[T]],
    variables: Option[(Array[Tensor[T]], Array[Tensor[T]])] = None
  )(implicit ev: TensorNumeric[T]): Graph[T] =
    new StaticGraph[T](input, output, variables)

  /**
   * Build a graph that runs `preprocessor` and then `trainable`.
   *
   * An Identity layer is placed between them and its gradient is stopped, so no input gradient
   * is propagated back into `preprocessor`.
   * @param preprocessor the leading module
   * @param trainable the trailing module
   * @return a graph container
   */
  def apply[T: ClassTag](preprocessor: Module[T], trainable: Module[T])
    (implicit ev: TensorNumeric[T]): Graph[T] = {
    val head = preprocessor.inputs()
    val barrier = Identity[T]().inputs(head)
    val tail = trainable.inputs(barrier)
    apply[T](head, tail).stopGradient(Array(barrier.element.getName()))
  }

  /**
   * Build a dynamic graph container with multiple inputs and multiple outputs.
   * @param input input nodes
   * @param output output nodes
   * @param variables optional shared (weights, gradients) pair forwarded to the graph
   * @param generateBackward forwarded to the dynamic graph unchanged
   * @return a graph container
   */
  private[bigdl] def dynamic[T: ClassTag](
    input : Array[ModuleNode[T]],
    output : Array[ModuleNode[T]],
    variables: Option[(Array[Tensor[T]], Array[Tensor[T]])] = None,
    generateBackward: Boolean = true)(implicit ev: TensorNumeric[T]): Graph[T] =
    new DynamicGraph[T](input, output, variables, generateBackward)

  /**
   * Build a single input, multiple outputs graph container.
   * @param input input node
   * @param output output nodes
   * @return a graph container
   */
  def apply[T: ClassTag](input: ModuleNode[T], output: Array[ModuleNode[T]])
    (implicit ev: TensorNumeric[T]): Graph[T] =
    new StaticGraph[T](Seq(input), output)

  /** Dynamic counterpart of the single input, multiple outputs constructor. */
  private[bigdl] def dynamic[T: ClassTag](input : ModuleNode[T], output : Array[ModuleNode[T]])
    (implicit ev: TensorNumeric[T]) : Graph[T] =
    dynamic[T](Array(input), output)

  /**
   * Build a multiple inputs, single output graph container.
   * @param input input nodes
   * @param output output node
   * @return a graph container
   */
  def apply[T: ClassTag](input: Array[ModuleNode[T]], output: ModuleNode[T])
    (implicit ev: TensorNumeric[T]): Graph[T] =
    new StaticGraph[T](input, Seq(output))

  /** Dynamic counterpart of the multiple inputs, single output constructor. */
  private[bigdl] def dynamic[T: ClassTag](input : Array[ModuleNode[T]], output : ModuleNode[T])
    (implicit ev: TensorNumeric[T]) : Graph[T] =
    dynamic[T](input, Array(output))

  /**
   * Build a single input, single output graph container.
   * @param input input node
   * @param output output node
   * @return a graph container
   */
  def apply[T: ClassTag](input: ModuleNode[T], output: ModuleNode[T])
    (implicit ev: TensorNumeric[T]): Graph[T] =
    new StaticGraph[T](Seq(input), Seq(output))

  /** Dynamic counterpart of the single input, single output constructor. */
  private[bigdl] def dynamic[T: ClassTag](input : ModuleNode[T], output : ModuleNode[T])
    (implicit ev: TensorNumeric[T]) : Graph[T] =
    dynamic[T](Array(input), Array(output))
}

trait GraphSerializable extends ContainerSerializable {

  /**
   * Rebuild the node structure of a serialized graph without constructing the graph itself.
   *
   * - Every serialized sub-module is loaded and wrapped in a node; control ops get their
   *   dedicated control-node types.
   * - The edges stored in the `<name>_edges` attributes are re-attached.
   * - The input and output nodes are looked up by the names stored in the `inputNames` and
   *   `outputNames` attributes.
   *
   * @param context deserialization context holding the serialized graph module
   * @return a tuple of:
   *         - the serialized module,
   *         - the input nodes,
   *         - the output nodes,
   *         - the raw `generateBackward` attribute (null when absent),
   *         - the shared (weight, bias) arrays, if both attributes are present.
   */
  private[bigdl] def prepareLoadModule[T: ClassTag](context: DeserializeContext)
                                                   (implicit ev: TensorNumeric[T]) = {

    val module = context.bigdlModule
    val subModules = module.getSubModulesList.asScala

    val attributes = module.getAttrMap
    val inputNames = new ArrayBuffer[String]
    val outputNames = new ArrayBuffer[String]
    inputNames ++= DataConverter.getAttributeValue(context, attributes.get("inputNames"))
      .asInstanceOf[Array[String]]
    outputNames ++= DataConverter.getAttributeValue(context, attributes.get("outputNames"))
      .asInstanceOf[Array[String]]

    val inputs = new ArrayBuffer[ModuleNode[T]]
    val outputs = new ArrayBuffer[ModuleNode[T]]

    // layer name -> (layer node, names of its predecessor layers)
    val layerMap = new mutable.HashMap[String, (ModuleNode[T], Seq[String])]()
    subModules.foreach { subModule =>
      val bigDLModule = ModuleSerializer.load(DeserializeContext(subModule,
        context.storages, context.storageType))
      val moduleNode = bigDLModule.module match {
        case controlOps: ControlOps[T] => createControlNode(controlOps)
        case _ => new ModuleNode[T](bigDLModule.module)
      }
      layerMap(bigDLModule.module.getName) = (moduleNode, bigDLModule.pre)
    }

    layerMap.values.foreach { case (node, preNames) =>
      val nodeName = node.element.getName
      val edges = DataConverter.getAttributeValue(context,
        attributes.get(s"${nodeName}_edges")).
        asInstanceOf[mutable.HashMap[String, mutable.HashMap[String, Int]]]
      val edgeMap = edges(nodeName)
      preNames.foreach { pre =>
        // predecessors that were not serialized as sub-modules are skipped
        if (layerMap.contains(pre)) {
          // -1 encodes an edge without a fromIndex
          val edge: Edge = edgeMap(pre) match {
            case -1 => Edge()
            case index: Int => Edge(index)
          }
          layerMap(pre)._1.add(node, edge)
        }
      }
    }

    inputNames.foreach(inputName => inputs.append(layerMap(inputName)._1))
    outputNames.foreach(outputName => outputs.append(layerMap(outputName)._1))

    val sharedVariables: Option[(Array[Tensor[T]], Array[Tensor[T]])] =
      if (attributes.containsKey("sharedWeight") && attributes.containsKey("sharedBias")) {
        val weightArray = DataConverter.getAttributeValue(context, attributes.get("sharedWeight"))
          .asInstanceOf[Array[Tensor[T]]]
        val biasArray = DataConverter.getAttributeValue(context, attributes.get("sharedBias"))
          .asInstanceOf[Array[Tensor[T]]]
        Some((weightArray, biasArray))
      } else {
        None
      }

    val generateBackwardValue = attributes.get("generateBackward")
    (module, inputs, outputs, generateBackwardValue, sharedVariables)
  }

  /**
   * Load a serialized graph.
   *
   * - If the `generateBackward` attribute is present, a dynamic graph is built.
   * - Otherwise a static graph is built.
   * - Stop-gradient layers recorded in the `stopGradientLayers` attribute are re-applied.
   */
  override def doLoadModule[T: ClassTag](context: DeserializeContext)
    (implicit ev: TensorNumeric[T]): AbstractModule[Activity, Activity, T] = {
    val (module, inputs, outputs, generateBackwardValue, sharedVariables) =
      prepareLoadModule[T](context)
    val attributes = module.getAttrMap
    val graph = if (generateBackwardValue != null) {
      val generateBackward = DataConverter.getAttributeValue(context, generateBackwardValue)
        .asInstanceOf[Boolean]
      Graph.dynamic[T](inputs.toArray, outputs.toArray, sharedVariables, generateBackward)
    } else {
      new StaticGraph[T](inputs, outputs, sharedVariables, false)
    }
    // The attribute may be missing; checking for it keeps older serialized graphs loadable.
    if (attributes.containsKey("stopGradientLayers")) {
      val serializedStopGradientLayers = DataConverter
        .getAttributeValue(context, attributes.get("stopGradientLayers"))
        .asInstanceOf[Array[String]]
      if (serializedStopGradientLayers != null) {
        graph.stopGradient(serializedStopGradientLayers)
      }
    }
    graph
  }

  /** Wrap a control op in its matching control-node type (switch, merge, or a plain node). */
  private def createControlNode[T: ClassTag](controlOps: ControlOps[T]): ModuleNode[T] = {
    controlOps match {
      case switchOps: SwitchOps[T] => new SwitchControlNode[Module[T]](switchOps)
      case mergeOps: MergeOps[T] => new MergeControlNode[Module[T]](mergeOps)
      case _ => new Node[Module[T]](controlOps)
    }
  }

  /**
   * Serialize a graph.
   *
   * Each forward execution node becomes a sub-module that records its predecessor and successor
   * names. Its incoming edges are stored in a `<name>_edges` attribute, with -1 for edges
   * without a fromIndex. Graph-level data (shared variables, input/output names,
   * `generateBackward` for dynamic graphs, stop-gradient layers) is stored as attributes.
   */
  override def doSerializeModule[T: ClassTag](context: SerializeContext[T],
      graphBuilder: BigDLModule.Builder)
    (implicit ev: TensorNumeric[T]): Unit = {
    val module = context.moduleData
    // NOTE(review): the graph's own pre/next module names are not added to graphBuilder here;
    // confirm that the generic module serializer records them.
    val graph = module.module.asInstanceOf[Graph[T]]
    val inputsNames = graph.inputs.map(_.element.getName).toArray
    val outputsNames = graph.outputs.map(_.element.getName).toArray
    val executions = graph.getForwardExecutions
    // Names of every node that is serialized as a sub-module
    val executionNames = executions.map(_.element.getName).toSet
    executions.foreach(execution => {

      val edgeMap = new mutable.HashMap[String, mutable.Map[String, Int]]

      val preNodesAndEdges = execution.prevNodesAndEdges
      val preNodes = preNodesAndEdges.map(_._1.element.getName)
      // Real successors, restricted to serialized nodes. This drops nodes that
      // getForwardExecutions leaves out, such as the graph's internal dummy output.
      val nextNodes = execution.nextNodes.map(_.element.getName).filter(executionNames.contains)
      val currNode = execution.element
      val subModel = ModuleSerializer.serialize(SerializeContext(
        ModuleData(currNode, preNodes, nextNodes), context.storages, context.storageType))
      // add edges
      val preNodeEdges = new mutable.HashMap[String, Int]()

      preNodesAndEdges.foreach { case (preNode, edge) =>
        preNodeEdges(preNode.element.getName) = edge.fromIndex.getOrElse(-1)
      }
      edgeMap(execution.element.getName) = preNodeEdges
      val attrBuilder = AttrValue.newBuilder
      DataConverter.setAttributeValue(context, attrBuilder, edgeMap)

      graphBuilder.putAttr(s"${execution.element.getName}_edges", attrBuilder.build)
      graphBuilder.addSubModules(subModel.bigDLModule)
    })

    if (graph.variables.isDefined) {
      val (weights, bias) = graph.variables.get
      val weightAttrBuilder = AttrValue.newBuilder
      DataConverter.setAttributeValue(context, weightAttrBuilder, weights,
        universe.typeOf[Array[Tensor[_ <: scala.Any]]])
      graphBuilder.putAttr("sharedWeight", weightAttrBuilder.build)

      val biasAttrBuilder = AttrValue.newBuilder
      DataConverter.setAttributeValue(context, biasAttrBuilder, bias,
        universe.typeOf[Array[Tensor[_ <: scala.Any]]])
      graphBuilder.putAttr("sharedBias", biasAttrBuilder.build)
    }

    val inputNamesAttrBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, inputNamesAttrBuilder,
      inputsNames, universe.typeOf[Array[String]])
    graphBuilder.putAttr("inputNames", inputNamesAttrBuilder.build)

    val outputNamesBuilder = AttrValue.newBuilder
    DataConverter.setAttributeValue(context, outputNamesBuilder,
      outputsNames, universe.typeOf[Array[String]])
    graphBuilder.putAttr("outputNames", outputNamesBuilder.build)

    graph match {
      case dynamicGraph: DynamicGraph[_] =>
        val generateBackwardBuilder = AttrValue.newBuilder
        DataConverter.setAttributeValue(context, generateBackwardBuilder,
          dynamicGraph.generateBackward, universe.typeOf[Boolean])
        graphBuilder.putAttr("generateBackward", generateBackwardBuilder.build)
      case _ =>
    }

    val stopGradientLayers = graph.getStopGradientLayers

    if (stopGradientLayers != null && stopGradientLayers.size > 0) {
      val stopGradientLayersBuilder = AttrValue.newBuilder
      DataConverter.setAttributeValue(context, stopGradientLayersBuilder,
        stopGradientLayers.toArray(new Array[String](stopGradientLayers.size)),
        universe.typeOf[Array[String]])
      graphBuilder.putAttr("stopGradientLayers", stopGradientLayersBuilder.build)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy