com.intel.analytics.zoo.pipeline.api.net.NetUtils.scala Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.pipeline.api.net
import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.nn.keras.KerasLayer
import com.intel.analytics.bigdl.nn.{Container, DynamicGraph, Graph, StaticGraph}
import com.intel.analytics.bigdl.serialization.Bigdl.BigDLModule
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.serializer._
import com.intel.analytics.zoo.pipeline.api.keras.layers.KerasLayerWrapper
import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.{GraphRef, KerasUtils}
import com.intel.analytics.zoo.pipeline.api.keras.models.Model
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
class GraphNet[T: ClassTag](graph: Graph[T])(implicit ev: TensorNumeric[T])
extends Container[Activity, Activity, T] with NetUtils[T, GraphNet[T]] {
// need to refer this object to make the register effective
GraphNet
private val labor = graph
modules.append(labor)
def getSubModules(): List[AbstractModule[Activity, Activity, T]] = {
this.labor.modules.toList
}
val outputNodes = NetUtils.getGraphOutputs(graph)
val inputNodes = graph.inputs
override def updateOutput(input: Activity): Activity = {
output = labor.updateOutput(input)
output
}
override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
gradInput = labor.updateGradInput(input, gradOutput)
gradInput
}
override def accGradParameters(input: Activity, gradOutput: Activity): Unit = {
labor.accGradParameters(input, gradOutput)
}
override def node(name: String): ModuleNode[T] = this.graph.node(name)
override def newGraph(output: String): GraphNet[T] = {
newGraph(Seq(output))
}
override def newGraph(outputs: Seq[String]): GraphNet[T] = {
val inputs = graph.inputs
val variables = NetUtils.getGraphVariables(graph)
.asInstanceOf[Option[(Array[Tensor[T]], Array[Tensor[T]])]]
graph match {
case g: StaticGraph[T] =>
val newGraph = Graph(inputs.toArray, nodes(outputs)
.map(_.removeNextEdges()).toArray, variables)
new GraphNet[T](newGraph)
case g =>
val newGraph = NetUtils.dynamic[T](inputs.toArray, nodes(outputs)
.map(_.removeNextEdges()).toArray,
variables, NetUtils.getGenerateBackward(g))
new GraphNet[T](newGraph)
}
}
override def toKeras(): KerasLayer[Activity, Activity, T] = {
new KerasLayerWrapper[T](this)
}
}
object GraphNet extends ContainerSerializable {
ModuleSerializer.registerModule(
"com.intel.analytics.zoo.pipeline.api.net.GraphNet",
GraphNet)
override def doSerializeModule[T: ClassTag](context: SerializeContext[T],
builder: BigDLModule.Builder)
(implicit ev: TensorNumeric[T]): Unit = {
val labor = context.moduleData.module.
asInstanceOf[GraphNet[T]].labor
val subModule = ModuleSerializer.serialize(SerializeContext(ModuleData(labor,
new ArrayBuffer[String](), new ArrayBuffer[String]()), context.storages,
context.storageType, _copyWeightAndBias))
builder.addSubModules(subModule.bigDLModule)
}
override def doLoadModule[T: ClassTag](context: DeserializeContext)
(implicit ev: TensorNumeric[T]): AbstractModule[Activity, Activity, T] = {
val subProtoModules = context.bigdlModule.getSubModulesList.asScala
val subModules = subProtoModules.map(module => {
val subModuleData = ModuleSerializer.load(DeserializeContext(module,
context.storages, context.storageType, _copyWeightAndBias))
subModuleData.module
})
val tGraph = subModules.head.asInstanceOf[StaticGraph[T]]
tGraph
}
}
object NetUtils {
private[zoo] def getGraphOutputs[T](graph: Graph[T]): Seq[ModuleNode[T]] = {
KerasUtils.invokeMethod(graph, "outputs").asInstanceOf[Seq[ModuleNode[T]]]
}
private[zoo] def getGraphVariables[T](graph: Graph[T]) = {
KerasUtils.invokeMethod(graph, "variables")
.asInstanceOf[Option[(Array[Tensor[T]], Array[Tensor[T]])]]
}
private[zoo] def getGenerateBackward[T](graph: Graph[T]): Boolean = {
KerasUtils.invokeMethod(graph, "generateBackward").asInstanceOf[Boolean]
}
private[zoo] def dynamic[T](
input : Array[ModuleNode[T]],
output : Array[ModuleNode[T]],
variables: Option[(Array[Tensor[T]], Array[Tensor[T]])] = None,
generateBackward: Boolean = true)
(implicit ev: TensorNumeric[T], ev2: ClassTag[T]): Graph[T] = {
import scala.reflect.runtime.{universe => ru}
val m = ru.runtimeMirror(Graph.getClass.getClassLoader)
val mirror = m.reflect(Graph)
val dynamic = mirror.symbol.typeSignature
.member(ru.newTermName("dynamic"))
.filter(_.asMethod.paramss.flatten.length == 6)
val result = mirror.reflectMethod(dynamic.asMethod)(input, output,
variables, generateBackward, ev2, ev)
result.asInstanceOf[Graph[T]]
}
}
trait NetUtils[T, D <: Module[T] with NetUtils[T, D]] {
/**
* Return the nodes in the graph as specified by the names
*/
def nodes(names: Seq[String]): Seq[ModuleNode[T]] = {
names.map(node)
}
/**
* Return the node in the graph as specified by the name
*/
def node(name: String): ModuleNode[T]
/**
* Freeze the model from the bottom up to the layers
* specified by names (inclusive).
*
* This is useful for finetune a model
*/
def freezeUpTo(names: String*): this.type = {
dfs(nodes(names)).foreach(_.element.freeze())
this
}
/**
* Specify a node as output and return a new graph using
* the existing nodes
*/
def newGraph(output: String): D
/**
* Specify a seq of nodes as output and return a new graph using
* the existing nodes
*/
def newGraph(outputs: Seq[String]): D
def toKeras(): KerasLayer[Activity, Activity, T]
private def dfs(endPoints: Seq[ModuleNode[T]]): Iterator[ModuleNode[T]] = {
new Iterator[ModuleNode[T]] {
private val stack = new mutable.Stack[ModuleNode[T]]()
endPoints.map(stack.push)
private val visited = new mutable.HashSet[ModuleNode[T]]()
override def hasNext: Boolean = stack.nonEmpty
override def next(): ModuleNode[T] = {
require(hasNext, "No more elements in the graph")
val node = stack.pop()
visited.add(node)
val nextNodes = node.prevNodes
// to preserve order
val nodesSet = mutable.LinkedHashSet[ModuleNode[T]]()
nextNodes.foreach(nodesSet.add)
nodesSet.filter(!visited.contains(_))
.filter(!stack.contains(_)).foreach(stack.push)
node
}
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy