All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.pipeline.api.net.TFNet.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.zoo.pipeline.api.net

import java.io.{File, FileInputStream, InputStream}
import java.nio._

import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.{MultiShape, Shape, T}
import org.tensorflow.framework.GraphDef
import org.tensorflow.types.UInt8
import org.tensorflow.{DataType, Graph, Session, Tensor => TTensor}

import scala.collection.JavaConverters._
import scala.io.Source
import scala.reflect.io.Path

import org.json4s._
import org.json4s.jackson.JsonMethods._

/**
 * [[TFNet]] wraps a tensorflow subgraph as a layer, and use tensorflow to
 * calculate the layer's output.
 *
 * This subgraph should not contain any tensorflow Variable and the input/output
 * must be numeric types
 *
 * When used with other layers for training, there should be no trainable layer
 * before this one, as the gradInput of this layer is always zero.
 *
 * @param graphDef serialized representation of a graph
 * @param inputNames the input tensor names of this subgraph
 * @param outputNames the output tensor names of this subgraph
 */
class TFNet private(graphDef: Array[Byte],
            val inputNames: Seq[String],
            val outputNames: Seq[String],
                    config: Array[Byte])
  extends AbstractModule[Activity, Activity, Float] {

  // this is a workaround for a bug in scala 2.10
  // transient lazy vals will null constructor fields
  // https://issues.scala-lang.org/browse/SI-8453
  private def size = graphDef.length


  output = {
    if (outputNames.length == 1) {
      Tensor[Float]()
    } else {
      val t = T()
      var i = 0
      while (i < outputNames.length) {
        t.insert(Tensor[Float]())
        i = i + 1
      }
      t
    }
  }

  gradInput = {
    if (inputNames.length == 1) {
      Tensor[Float]()
    } else {
      val t = T()
      var i = 0
      while (i < inputNames.length) {
        t.insert(Tensor[Float]())
        i = i + 1
      }
      t
    }
  }

  private def getOutput(idx: Int): Tensor[Float] = {
    if (output.isTable) {
      output.toTable[Tensor[Float]](idx)
    } else {
      output.toTensor[Float]
    }
  }

  @transient
  private lazy val graph = {
    val graph = new Graph()
    graph.importGraphDef(graphDef)
    graph
  }

  @transient
  private lazy val sess = {
    val sess = new Session(graph, config)
    sess
  }

  @transient
  private lazy val inputTypes = inputNames.map { name =>
    val Array(op, idx) = name.split(":")
    val operation = graph.operation(op)
    val output = operation.output(idx.toInt)
    output.dataType()
  }

  // add Cast Operation if the output tensor is not of type Float
  @transient
  private lazy val floatOutputNames = outputNames.map { name =>
    val Array(op, idx) = name.split(":")
    val operation = graph.operation(op)
    val output = operation.output(idx.toInt)
    if (output.dataType() != DataType.FLOAT) {
      val name = graph.opBuilder("Cast", s"${op}_to_float")
        .addInput(output)
        .setAttr("DstT", DataType.FLOAT)
        .setAttr("SrcT", output.dataType())
        .build()
        .name()
      s"$name:0"
    } else {
      name
    }
  }

  private def getShape(names: Seq[String]) = {
    val shapes = names.map { name =>
      val Array(op, idx) = name.split(":")
      val shape = graph.operation(op).output(idx.toInt).shape()
      Shape((0 until shape.numDimensions()).map(shape.size(_).toInt).toArray)
    }

    if (shapes.length == 1) {
      shapes.head
    } else {
      MultiShape(shapes.toList)
    }
  }

  override def parameters(): (Array[Tensor[Float]], Array[Tensor[Float]]) = {
    (Array(), Array())
  }

  private def bigdl2Tf(t: Tensor[Float], dataType: DataType): TTensor[_] = {
    val shape = t.size().map(_.toLong)
    val arr = t.storage().array()
    if (dataType == DataType.FLOAT) {
      val buffer = FloatBuffer.wrap(arr)
      TTensor.create(shape, buffer)
    } else if (dataType == DataType.UINT8) {
      val buffer = ByteBuffer.wrap(TFNet.floatToUint8(arr))
      TTensor.create(classOf[UInt8], shape, buffer)
    } else if (dataType == DataType.INT32) {
      val buffer = IntBuffer.wrap(TFNet.floatToInt(arr))
      TTensor.create(shape, buffer)
    } else if (dataType == DataType.INT64) {
      val buffer = LongBuffer.wrap(TFNet.floatToLong(arr))
      TTensor.create(shape, buffer)
    } else if (dataType == DataType.DOUBLE) {
      val buffer = DoubleBuffer.wrap(TFNet.floatToDouble(arr))
      TTensor.create(shape, buffer)
    } else {
      throw new Exception(s"data type ${dataType} are not supported")
    }

  }

  private def tf2bigdl(t: TTensor[Float], output: Tensor[Float]) = {
    val shape = t.shape().map(_.toInt)
    output.resize(shape)
    val buffer = FloatBuffer.wrap(
      output.storage().array(),
      output.storageOffset() - 1,
      shape.product)
    t.writeTo(buffer)
  }

  override def updateOutput(input: Activity): Activity = {
    val data = if (input.isTensor) {
      val tfTensor = bigdl2Tf(input.toTensor[Float], inputTypes.head)
      Seq(tfTensor)
    } else {
      val t = input.toTable
      for (i <- 1 to t.length()) yield {
        bigdl2Tf(t[Tensor[Float]](i), inputTypes(i-1))
      }
    }

    val runner = sess.runner()
    floatOutputNames.foreach(runner.fetch)
    inputNames.zipWithIndex.foreach { case (name, idx) =>
      runner.feed(name, data(idx))
    }

    val outputs = runner.run()
    outputs.asScala.zipWithIndex.foreach { case (t, idx) =>
      tf2bigdl(t.asInstanceOf[TTensor[Float]], getOutput(idx + 1))
    }
    // clean up resources
    data.foreach(_.close())
    outputs.asScala.foreach(_.close())
    output
  }

  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
    if (gradInput.isTable) {
      var i = 0
      while (i < gradInput.toTable.length()) {
        gradInput.toTable[Tensor[Float]](i + 1)
          .resizeAs(input.toTable[Tensor[Float]](i + 1))
        i = i + 1
      }
    } else {
      gradInput.toTensor[Float]
        .resizeAs(input.toTensor[Float])
    }

    gradInput
  }
}

object TFNet {

  implicit val formats = DefaultFormats

  val defaultSessionConfig = Seq(16, 1, 40, 1, 72, 1).map(_.toByte).toArray
  // Ideally we should use the following code, however, importing tensorflow proto
  // will conflict with bigdl.

//  val defaultSessionConfig = ConfigProto.newBuilder()
//    .setInterOpParallelismThreads(1)
//    .setIntraOpParallelismThreads(1)
//    .setUsePerSessionThreads(true)
//    .build().toByteArray

  private def floatToInt(array: Array[Float]): Array[Int] = {
    val result = new Array[Int](array.length)
    var i = 0
    while (i < array.length) {
      result(i) = array(i).toInt
      i = i + 1
    }
    result
  }

  private def floatToLong(array: Array[Float]): Array[Long] = {
    val result = new Array[Long](array.length)
    var i = 0
    while (i < array.length) {
      result(i) = array(i).toLong
      i = i + 1
    }
    result
  }

  private def floatToDouble(array: Array[Float]): Array[Double] = {
    val result = new Array[Double](array.length)
    var i = 0
    while (i < array.length) {
      result(i) = array(i).toDouble
      i = i + 1
    }
    result
  }

  private def floatToUint8(array: Array[Float]): Array[Byte] = {
    val result = new Array[Byte](array.length)
    var i = 0
    while (i < array.length) {
      result(i) = array(i).toByte
      i = i + 1
    }
    result
  }

  /**
   * Create a TFNet
   * @param graphDef the tensorflow GraphDef object
   * @param inputNames the input tensor names of this subgraph
   * @param outputNames the output tensor names of this subgraph
   * @return
   */
  def apply(graphDef: GraphDef, inputNames: Seq[String],
            outputNames: Seq[String], config: Array[Byte] = defaultSessionConfig): TFNet = {
    new TFNet(graphDef.toByteArray, inputNames, outputNames, config)
  }

  /**
   * Create a TFNet
   * @param path the file path of a graphDef
   * @param inputNames the input tensor names of this subgraph
   * @param outputNames the output tensor names of this subgraph
   * @return
   */
  def apply(path: String,
            inputNames: Seq[String],
            outputNames: Seq[String], config: Array[Byte]): TFNet = {
    val graphDef = parseGraph(path)
    TFNet(graphDef, inputNames, outputNames, config)
  }

  /**
   * Create a TFNet
   * @param path the file path of a graphDef
   * @param inputNames the input tensor names of this subgraph
   * @param outputNames the output tensor names of this subgraph
   * @return
   */
  def apply(path: String,
            inputNames: Seq[String],
            outputNames: Seq[String]): TFNet = {
    val graphDef = parseGraph(path)
    TFNet(graphDef, inputNames, outputNames, defaultSessionConfig)
  }


  def apply(folder: String): TFNet = {
    val folderPath = Path(folder)
    if (!folderPath.exists) {
      throw new IllegalArgumentException(s"$folder does not exists")
    }

    val modelPath = folderPath / Path("frozen_inference_graph.pb")
    val metaPath = folderPath / Path("graph_meta.json")

    val jsonStr = Source.fromFile(metaPath.jfile).getLines().mkString

    val meta = parse(jsonStr).camelizeKeys.extract[Meta]

    TFNet(modelPath.toString(), meta.inputNames, meta.outputNames, defaultSessionConfig)
  }

  private case class Meta(inputNames: Array[String], outputNames: Array[String])

  private def parseGraph(graphProtoTxt: String) : GraphDef = {
    var fr: File = null
    var in: InputStream = null
    try {
      fr = new File(graphProtoTxt)
      in = new FileInputStream(fr)

      GraphDef.parseFrom(in)
    } finally {
      if (in != null) in.close()
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy