All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.tfpark.GraphRunner.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.zoo.tfpark

import java.nio._

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.zoo.common.Utils
import com.intel.analytics.zoo.core.TFNetNative
import com.intel.analytics.zoo.pipeline.api.net.TFNet
import org.tensorflow.{DataType, Graph, Session, Tensor => TTensor}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

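/**
 * [[GraphRunner]] builds a tensorflow [[Session]] from a serialized graph
 * definition and runs feeds, fetches and target operations against it.
 *
 * The session is created lazily on first use and is marked transient, so it is
 * not serialized together with this object.
 *
 * @param graphDef serialized representation of a graph
 * @param restoreOp name of the operation run by `restoreFromFile`
 * @param restorePathPlaceholder name of the node fed with the checkpoint path
 *                               when restoring
 * @param saveOp name of the operation run by `saveToFile`
 * @param savePathPlaceholder name of the node fed with the checkpoint path
 *                            when saving
 * @param config configuration bytes passed to the [[Session]] constructor
 */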
class GraphRunner(
            private[zoo] val graphDef: Array[Byte],
            private val restoreOp: String,
            private val restorePathPlaceholder: String,
            private val saveOp: String,
            private val savePathPlaceholder: String,
            private val config: Array[Byte]) extends java.io.Serializable {

  /**
   * Creates a new runner from the same graph definition, op names and config.
   * The copy does not share this runner's session; it creates its own lazily.
   */
  def makeCopy(): GraphRunner = {
    new GraphRunner(graphDef, restoreOp, restorePathPlaceholder,
      saveOp, savePathPlaceholder, config)
  }

  @transient
  private lazy val tensorManager = new TFResourceManager()

  // Not read or written anywhere in this class; kept because it is a public member.
  val output = ArrayBuffer[Tensor[Float]]()

  // Graph backing `sess`. It is only assigned once `sess` has been created, so
  // a null value (including after deserialization, as the field is transient)
  // means there is no native resource to free yet.
  @transient
  @volatile
  private var graph: Graph = _

  /**
   * Session over the imported `graphDef`, created on first access.
   */
  @transient
  private[zoo] lazy val sess: Session = {
    assert(TFNetNative.isLoaded)
    val g = new Graph()
    try {
      g.importGraphDef(graphDef)
      val session = new Session(g, config)
      graph = g
      session
    } catch {
      case e: Throwable =>
        // Do not leak the native graph when the import or session creation fails.
        g.close()
        throw e
    }
  }

  /**
   * Runs `restoreOp`, feeding `checkpointPath` into `restorePathPlaceholder`.
   *
   * @param checkpointPath path handed to the restore operation
   */
  def restoreFromFile(checkpointPath: String): Unit = {
    runCheckpointOp(restoreOp, restorePathPlaceholder, checkpointPath)
  }

  /**
   * Runs `saveOp`, feeding `checkpointPath` into `savePathPlaceholder`.
   *
   * @param checkpointPath path handed to the save operation
   */
  def saveToFile(checkpointPath: String): Unit = {
    runCheckpointOp(saveOp, savePathPlaceholder, checkpointPath)
  }

  /**
   * Runs `op` as a target with the bytes of `checkpointPath` fed into
   * `pathPlaceholder`. The path tensor is closed even when the run fails.
   */
  private def runCheckpointOp(op: String,
                              pathPlaceholder: String,
                              checkpointPath: String): Unit = {
    val runner = sess.runner()
    runner.addTarget(op)
    val pathTensor = TTensor.create(checkpointPath.getBytes())
    try {
      runner.feed(pathPlaceholder, pathTensor)
      // Nothing is fetched, so the result of run() is not used.
      runner.run()
    } finally {
      pathTensor.close()
    }
  }

  /**
   * Runs the given target operations without feeding or fetching anything.
   *
   * @param targets names of the operations to run
   */
  def runTargets(targets: Vector[String]): Unit = {
    run(Vector.empty, Vector.empty, Vector.empty,
      Vector.empty, Vector.empty, Vector.empty, targets)
  }

  /**
   * Runs the given target operations with the given inputs fed; nothing is fetched.
   *
   * @param targets names of the operations to run
   * @param inputs tensors to feed
   * @param inputTypes tensorflow data types used when converting `inputs`
   * @param inputNames names of the nodes to feed, in the order of `inputs`
   */
  def runTargets(targets: Vector[String],
                 inputs: Vector[Tensor[_]],
                 inputTypes: Vector[DataType],
                 inputNames: Vector[String]): Unit = {
    run(inputs, inputNames, inputTypes, Vector.empty, Vector.empty, Vector.empty, targets)
  }

  /**
   * Fetches `outputNames` into `outputs` without feeding inputs or adding targets.
   *
   * @param outputs tensors receiving the fetched values, in the order of `outputNames`
   * @param outputNames names of the nodes to fetch
   * @param outputTypes forwarded to `run`
   */
  def runOutputs(outputs: Vector[Tensor[_]],
                 outputNames: Vector[String], outputTypes: Vector[DataType]): Unit = {
    run(Vector.empty, Vector.empty, Vector.empty,
      outputs, outputNames, outputTypes, Vector.empty)
  }

  /**
   * Executes one session run: feeds `input` under `inputNames`, fetches
   * `outputNames`, runs `targets`, and copies the i-th fetched tensor into
   * `output(i)`.
   *
   * Tensors created for feeding and tensors returned by the session are
   * released whether or not the run or the copy succeeds.
   *
   * @param input tensors to feed
   * @param inputNames names of the nodes to feed, in the order of `input`
   * @param inputTypes tensorflow data types used when converting `input`
   * @param output tensors receiving the fetched values
   * @param outputNames names of the nodes to fetch, in the order of `output`
   * @param outputTypes not read by this method
   * @param targets names of operations to run in addition to the fetches
   * @throws IllegalArgumentException if there are fewer `input` tensors than
   *                                  `inputNames`, or fewer `output` tensors
   *                                  than `outputNames`
   */
  def run(input: Vector[Tensor[_]],
          inputNames: Vector[String],
          inputTypes: Vector[DataType],
          output: Vector[Tensor[_]],
          outputNames: Vector[String],
          outputTypes: Vector[DataType],
          targets: Vector[String]): Unit = {
    // Fail before the session executes anything: a mismatch would otherwise
    // only surface after the targets have already run.
    require(input.length >= inputNames.length,
      s"Got ${input.length} input tensors for ${inputNames.length} input names")
    require(output.length >= outputNames.length,
      s"Got ${output.length} output tensors for ${outputNames.length} output names")

    Utils.timeIt("Graph Runner Run") {
      try {
        val runner = sess.runner()

        val inputTFTensors = new Array[TTensor[_]](inputNames.length)

        tensorManager.tensor2TFTensors(input, inputTypes, inputTFTensors)

        // feed inputs
        inputNames.zipWithIndex.foreach { case (name, idx) =>
          runner.feed(name, inputTFTensors(idx))
        }

        // fetch outputs
        outputNames.foreach(runner.fetch)

        // add targets
        targets.foreach(runner.addTarget)

        val fetched = Utils.timeIt("Session Run") {
          runner.run()
        }.asScala

        try {
          fetched.zipWithIndex.foreach { case (t, idx) =>
            TFUtils.tf2bigdl(t, output(idx))
          }
        } finally {
          // The fetched tensors are returned by tensorflow and are not covered
          // by destructTFTensors(), so release them here even if a copy failed.
          emptyTFTensorArray(fetched)
        }
      } finally {
        tensorManager.destructTFTensors()
      }
    }
  }

  /**
   * Releases every tensor in `arr` through the tensor manager and nulls the
   * slot so the buffer no longer references a released tensor.
   */
  private def emptyTFTensorArray(arr: mutable.Buffer[TTensor[_]]): Unit = {
    var i = 0
    while (i < arr.length) {
      tensorManager.releaseTensor(arr(i))
      arr(i) = null
      i += 1
    }
  }

  /**
   * Closes the session and its graph if, and only if, the session was created.
   * Reading `graph` instead of `sess` keeps this from forcing the lazy session
   * into existence just to close it.
   */
  private def closeNativeResources(): Unit = {
    Option(graph).foreach { g =>
      try {
        sess.close()
      } finally {
        g.close()
      }
    }
  }

  override def finalize(): Unit = {
    try {
      closeNativeResources()
    } finally {
      super.finalize()
    }
  }

  /**
   * Closes the session and graph held by this runner. Does nothing when the
   * session has not been created yet.
   */
  def release(): Unit = {
    closeNativeResources()
  }
}

object GraphRunner {

  // Evaluated once, when this companion object is first initialised.
  assert(TFNetNative.isLoaded)

  /**
   * Factory that forwards every argument unchanged to the [[GraphRunner]]
   * constructor.
   *
   * @param graphDef serialized representation of a graph
   * @param restoreOp name of the restore operation
   * @param restorePathPlaceholder name of the node fed with the path on restore
   * @param saveOp name of the save operation
   * @param savePathPlaceholder name of the node fed with the path on save
   * @param config configuration bytes for the session
   * @return a new [[GraphRunner]]
   */
  def apply(graphDef: Array[Byte],
            restoreOp: String,
            restorePathPlaceholder: String,
            saveOp: String,
            savePathPlaceholder: String,
            config: Array[Byte]): GraphRunner =
    new GraphRunner(
      graphDef = graphDef,
      restoreOp = restoreOp,
      restorePathPlaceholder = restorePathPlaceholder,
      saveOp = saveOp,
      savePathPlaceholder = savePathPlaceholder,
      config = config)

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy