All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.zoo.tfpark.TFUtils.scala Maven / Gradle / Ivy

/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.tfpark

import java.io.{File, FileInputStream, InputStream}
import java.nio._

import com.intel.analytics.bigdl.dataset.Sample
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractCriterion, AbstractModule, Activity}
import com.intel.analytics.bigdl.optim._
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.transform.vision.image.ImageFeature
import com.intel.analytics.bigdl.utils.{T, Table}
import com.intel.analytics.zoo.feature.common.Preprocessing
import com.intel.analytics.zoo.feature.image.ImageProcessing
import com.intel.analytics.zoo.pipeline.api.keras.{metrics => kmetrics}
import com.intel.analytics.zoo.pipeline.api.keras.metrics.{Accuracy, BinaryAccuracy, CategoricalAccuracy, SparseCategoricalAccuracy}
import org.tensorflow.framework.GraphDef
import org.tensorflow.{DataType, Tensor => TTensor}

import scala.io.Source
import scala.reflect.io.Path

object TFUtils {

  /** Session configuration built from the default arguments of [[SessionConfig]]. */
  val defaultSessionConfig: SessionConfig = SessionConfig()

  /**
   * Reads a training-metadata JSON file and extracts it into a [[TrainMeta]].
   *
   * JSON keys are camelized (json4s `camelizeKeys`) before extraction, so snake_case keys in
   * the file map onto the camelCase fields of [[TrainMeta]].
   *
   * @param trainMetaPath path of the JSON file
   * @return the extracted metadata
   */
  private[zoo] def getTrainMeta(trainMetaPath: Path): TrainMeta = {
    val source = Source.fromFile(trainMetaPath.jfile)
    // Always release the file handle, even if reading fails.
    val jsonStr = try source.getLines().mkString finally source.close()
    import org.json4s._
    import org.json4s.jackson.JsonMethods._
    implicit val formats = DefaultFormats

    parse(jsonStr).camelizeKeys.extract[TrainMeta]
  }

  /**
   * Parses a TensorFlow [[GraphDef]] from a file.
   *
   * Note: despite the parameter name, the file is read with `GraphDef.parseFrom`, i.e. as a
   * binary-serialized protobuf, not as protobuf text format.
   *
   * @param graphProtoTxt path of the serialized graph file
   * @return the parsed graph
   */
  private[zoo] def parseGraph(graphProtoTxt: String): GraphDef = {
    val in: InputStream = new FileInputStream(new File(graphProtoTxt))
    try {
      GraphDef.parseFrom(in)
    } finally {
      in.close()
    }
  }

  /**
   * Decodes an unsigned base-128 varint (7 payload bits per byte, low group first, high bit
   * set on every byte except the last) starting at `offset`.
   *
   * @return (decoded value, index of the first byte after the varint). If the array ends, or
   *         64 bits are consumed, before a terminating byte, the partial value is returned.
   */
  private def decodeUVarInt64(bytes: Array[Byte], offset: Int): (Long, Int) = {
    var shift = 0
    var p = offset
    var result = 0L
    var done = false
    while (!done && shift <= 63 && p < bytes.length) {
      val b = bytes(p)
      p += 1
      // Widen to Long before shifting: an Int shift only uses the low 5 bits of `shift`,
      // so groups at shift >= 32 would otherwise wrap around into the low bits.
      result |= (b & 0x7F).toLong << shift
      done = (b & 0x80) == 0
      shift += 7
    }
    (result, p)
  }

  /**
   * Reads the table of `numElem` 64-bit (native byte order) element offsets found at the start
   * of `buffer`'s backing array, narrowing each offset to Int.
   */
  private def getOffsets(buffer: ByteBuffer, numElem: Int): Array[Int] = {
    val start = buffer.arrayOffset()
    // `slice(from, until)` takes an absolute end index, so it must include `start`.
    val offsetsBuffer = ByteBuffer.wrap(buffer.array().slice(start, start + numElem * 8))
      .order(ByteOrder.nativeOrder())
      .asLongBuffer()
    val offsets = new Array[Long](numElem)
    offsetsBuffer.get(offsets)
    offsets.map(_.toInt)
  }

  /**
   * Copies a TensorFlow tensor into a BigDL tensor. `output` is resized to the TF tensor's
   * shape first.
   *
   * STRING tensors (scalar or 1-D only) are written into `output` as a `Tensor[Array[Byte]]`,
   * one byte array per element. FLOAT, UINT8, INT32, INT64, DOUBLE and BOOL tensors are
   * written into `output` as a `Tensor[Float]`, converting every element to Float.
   *
   * @param t      source TensorFlow tensor
   * @param output destination tensor; its element type must match the description above
   * @throws IllegalArgumentException if a STRING tensor has more than one dimension
   * @throws Exception for any other TensorFlow data type
   */
  private[zoo] def tf2bigdl(t: TTensor[_], output: Tensor[_]): Unit = {
    val shape = t.shape().map(_.toInt)
    output.resize(shape)

    t.dataType() match {
      case DataType.STRING =>
        val outputTensor = output.asInstanceOf[Tensor[Array[Byte]]]
        require(t.numDimensions() <= 1, "only scalar or Vector string are supported")
        val elements = t.numElements()
        val buffer = ByteBuffer.allocate(t.numBytes())
        t.writeTo(buffer)
        // Layout consumed below: `elements` 64-bit offsets, then the string data. Element i
        // starts at offsets(i) past the offset table with a varint length, then its bytes.
        val offsets = getOffsets(buffer, elements)
        val storage = outputTensor.storage().array()
        val strDataOffset = elements * 8
        var i = 0
        while (i < elements) {
          val offset = buffer.arrayOffset() + offsets(i) + strDataOffset
          val (strLen, strStart) = decodeUVarInt64(buffer.array(), offset)
          storage(outputTensor.storageOffset() - 1 + i) = buffer.array()
            .slice(strStart, strStart + strLen.toInt)
          i += 1
        }
      case DataType.FLOAT =>
        // Same element type: write straight into the output storage.
        val outputTensor = output.asInstanceOf[Tensor[Float]]
        val buffer = FloatBuffer.wrap(
          outputTensor.storage().array(),
          outputTensor.storageOffset() - 1,
          shape.product)
        t.writeTo(buffer)
      case DataType.UINT8 =>
        val outputTensor = output.asInstanceOf[Tensor[Float]]
        val arr = new Array[Byte](shape.product)
        t.writeTo(ByteBuffer.wrap(arr))
        // UINT8 is unsigned but JVM bytes are signed; values 128..255 must stay positive.
        ubyte2float(arr, outputTensor.storage().array(), outputTensor.storageOffset() - 1)
      case DataType.INT32 =>
        val outputTensor = output.asInstanceOf[Tensor[Float]]
        val arr = new Array[Int](shape.product)
        t.writeTo(IntBuffer.wrap(arr))
        int2float(arr, outputTensor.storage().array(), outputTensor.storageOffset() - 1)
      case DataType.INT64 =>
        val outputTensor = output.asInstanceOf[Tensor[Float]]
        val arr = new Array[Long](shape.product)
        t.writeTo(LongBuffer.wrap(arr))
        long2float(arr, outputTensor.storage().array(), outputTensor.storageOffset() - 1)
      case DataType.DOUBLE =>
        val outputTensor = output.asInstanceOf[Tensor[Float]]
        val arr = new Array[Double](shape.product)
        t.writeTo(DoubleBuffer.wrap(arr))
        double2float(arr, outputTensor.storage().array(), outputTensor.storageOffset() - 1)
      case DataType.BOOL =>
        val outputTensor = output.asInstanceOf[Tensor[Float]]
        val arr = new Array[Byte](t.numBytes())
        assert(t.numBytes() == shape.product, "sanity check")
        t.writeTo(ByteBuffer.wrap(arr))
        byte2float(arr, outputTensor.storage().array(), outputTensor.storageOffset() - 1)
      case other =>
        throw new Exception(s"data type ${other} are not supported")
    }
  }

  /** Copies `src` into `dest` starting at `offset`, treating each byte as signed. */
  private[zoo] def byte2float(src: Array[Byte], dest: Array[Float], offset: Int): Unit = {
    val length = src.length
    var i = 0
    while (i < length) {
      dest(offset + i) = src(i).toFloat
      i += 1
    }
  }

  /** Copies `src` into `dest` starting at `offset`, treating each byte as unsigned (0..255). */
  private def ubyte2float(src: Array[Byte], dest: Array[Float], offset: Int): Unit = {
    val length = src.length
    var i = 0
    while (i < length) {
      dest(offset + i) = (src(i) & 0xFF).toFloat
      i += 1
    }
  }

  /** Copies `src` into `dest` starting at `offset`, converting each Int to Float. */
  private[zoo] def int2float(src: Array[Int], dest: Array[Float], offset: Int): Unit = {
    val length = src.length
    var i = 0
    while (i < length) {
      dest(offset + i) = src(i).toFloat
      i += 1
    }
  }

  /** Copies `src` into `dest` starting at `offset`, converting each Long to Float. */
  private[zoo] def long2float(src: Array[Long], dest: Array[Float], offset: Int): Unit = {
    val length = src.length
    var i = 0
    while (i < length) {
      dest(offset + i) = src(i).toFloat
      i += 1
    }
  }

  /** Copies `src` into `dest` starting at `offset`, narrowing each Double to Float. */
  private[zoo] def double2float(src: Array[Double], dest: Array[Float], offset: Int): Unit = {
    val length = src.length
    var i = 0
    while (i < length) {
      dest(offset + i) = src(i).toFloat
      i += 1
    }
  }

  /**
   * Maps a numeric TensorFlow data-type code to [[DataType]].
   *
   * @throws IllegalArgumentException for codes other than 1, 2, 3, 4, 7, 9, 10
   */
  def tfenum2datatype(enum: Int): DataType = {
    enum match {
      case 1 => DataType.FLOAT
      case 2 => DataType.DOUBLE
      case 3 => DataType.INT32
      case 4 => DataType.UINT8
      case 7 => DataType.STRING
      case 9 => DataType.INT64
      case 10 => DataType.BOOL
      case _ => throw new IllegalArgumentException(s"unsupported tensorflow datatype $enum")
    }
  }

  /**
   * Inverse of [[tfenum2datatype]]: maps a [[DataType]] to its numeric code.
   *
   * @throws IllegalArgumentException for data types not listed in [[tfenum2datatype]]
   */
  def tfdatatype2enum(dataType: DataType): Int = {
    dataType match {
      case DataType.FLOAT => 1
      case DataType.DOUBLE => 2
      case DataType.INT32 => 3
      case DataType.UINT8 => 4
      case DataType.STRING => 7
      case DataType.INT64 => 9
      case DataType.BOOL => 10
      case _ => throw new IllegalArgumentException(s"unsupported tensorflow datatype $dataType")
    }
  }
}

/**
 * Criterion that reads the loss directly from the model output instead of computing one.
 *
 * The loss is the scalar value of `input` when it is a tensor, or of the last entry of
 * `input` when it is a table. `target` is ignored. `updateGradInput` returns `gradInput`
 * without writing to it.
 */
class IdentityCriterion extends AbstractCriterion[Activity, Activity, Float]() {

  override def updateOutput(input: Activity, target: Activity): Float = {
    val lossTensor =
      if (input.isTensor) {
        input.toTensor[Float]
      } else {
        val entries = input.toTable
        entries[Tensor[Float]](entries.length())
      }
    lossTensor.value()
  }

  override def updateGradInput(input: Activity, target: Activity): Activity = gradInput
}

/**
 * Wraps a BigDL [[ValidationMethod]] so it can be evaluated on the output table of a TFPark
 * training graph. That table carries the predictions, the labels and the loss together.
 *
 * @param valMethod     the validation method to evaluate
 * @param name          prefix prepended to the wrapped method's name in `format()`
 * @param outputIndices 0-based positions of the prediction tensor(s) in the output table
 * @param labelIndices  0-based positions of the label tensor(s) in the output table
 */
class TFValidationMethod(val valMethod: ValidationMethod[Float],
                         name: String,
                         outputIndices: java.util.List[Int],
                         labelIndices: java.util.List[Int]) extends ValidationMethod[Float] {

  /**
   * Picks the entries at the given 0-based `indices` out of the 1-based `table`.
   * A single index yields that tensor directly; several indices yield a new [[Table]].
   */
  private def toActivity(indices: java.util.List[Int], table: Table): Activity = {
    if (indices.size() == 1) {
      table[Tensor[Float]](indices.get(0) + 1)
    } else {
      var i = 0
      val outputs = T()
      while (i < indices.size()) {
        outputs.insert(table(indices.get(i) + 1))
        i += 1
      }
      outputs
    }
  }

  /**
   * Returns a copy of the label(s) in `activity` with 1 added to every element. This converts
   * 0-based class indices into the 1-based ones some BigDL metrics expect.
   *
   * A copy is required because the labels are views into the model output table. That table
   * may be passed to several validation methods (e.g. Top1 and Top5), so adding in place would
   * shift the labels once more for every 1-based metric evaluated on it.
   */
  private def oneBasedLabel(activity: Activity): Activity = {
    if (activity.isTensor) {
      activity.toTensor[Float].clone().add(1.0f)
    } else {
      val labels = activity.toTable
      val shifted = T()
      var i = 1
      while (i <= labels.length()) {
        shifted.insert(labels[Tensor[Float]](i).clone().add(1.0f))
        i += 1
      }
      shifted
    }
  }

  override def apply(output: Activity, target: Activity): ValidationResult = {
    // the output layout [grads..., outputs..., labels..., loss]
    val outputT = output.toTable

    if (valMethod.isInstanceOf[Loss[Float]]) {
      // The loss is the last entry of the output table.
      val loss = outputT[Tensor[Float]](outputT.length()).value()
      new LossResult(loss, 1)
    } else {
      val outputActivity: Activity = toActivity(outputIndices, outputT)
      val targetActivity: Activity = toActivity(labelIndices, outputT)

      // Labels from the TF graph are treated as 0-based; metrics expecting 1-based labels
      // receive a shifted copy.
      val to1basedLabel = valMethod match {
        case _: SparseCategoricalAccuracy[Float] => false
        case _: CategoricalAccuracy[Float] => false
        case _: BinaryAccuracy[Float] => false
        case v: kmetrics.Top5Accuracy[Float] => !v.zeroBasedLabel
        case v: Accuracy[Float] => !v.zeroBasedLabel
        case _: Top1Accuracy[Float] => true
        case _: Top5Accuracy[Float] => true
        case _: TreeNNAccuracy[Float] => true
        case _ => false
      }

      val labels = if (to1basedLabel) oneBasedLabel(targetActivity) else targetActivity
      valMethod.apply(outputActivity, labels)
    }
  }

  override protected def format(): String = {
    (name + " " + valMethod.toString()).trim
  }
}

/**
 * Reports a metric that the model has already computed.
 *
 * The metric value and its sample count are read as scalars from the output table at the
 * 0-based positions `idx` and `countIdx`. The result is `ContiguousResult(value * count,
 * count, name)`, presumably so per-batch results can be aggregated as a count-weighted
 * average.
 */
class StatelessMetric(name: String, idx: Int, countIdx: Int) extends ValidationMethod[Float] {

  override def apply(output: Activity, target: Activity): ValidationResult = {
    // the output layout [grads..., metrics]
    val entries = output.toTable
    def scalarAt(zeroBased: Int): Float = entries[Tensor[Float]](zeroBased + 1).value()

    val metricValue = scalarAt(idx)
    val sampleCount = scalarAt(countIdx).toInt
    new ContiguousResult(metricValue * sampleCount, sampleCount, name)
  }

  override protected def format(): String = name
}

/**
 * Image transformer that folds a sample's label sizes into its feature sizes.
 */
class MergeFeatureLabel() extends ImageProcessing {

  /**
   * Builds a new [[Sample]] over the same data array. Its feature sizes are the original
   * feature sizes followed by the original label sizes, and `null` is passed as the label size.
   */
  def createNewMergedSample(sample: Sample[Float]): Sample[Float] = {
    val featureSizes = sample.getFeatureSize()
    val labelSizes = sample.getLabelSize()
    val mergedSizes = featureSizes ++ labelSizes
    Sample(sample.getData(), mergedSizes, null)
  }

  /**
   * Returns a fresh [[ImageFeature]] holding only the merged sample. No other entries of the
   * input feature are copied over.
   */
  override def transform(feature: ImageFeature): ImageFeature = {
    val merged = createNewMergedSample(feature[Sample[Float]](ImageFeature.sample))
    val result = new ImageFeature()
    result(ImageFeature.sample) = merged
    result
  }
}

/**
 * Applies [[MergeFeatureLabel]] to every element of a stream whose elements are either
 * [[ImageFeature]]s or `Sample[Float]`s.
 *
 * @throws IllegalArgumentException (lazily, while iterating) for any other element type
 */
class MergeFeatureLabelFeatureTransformer() extends Preprocessing[Any, Any] {

  private val mergeFun = new MergeFeatureLabel()

  override def apply(prev: Iterator[Any]): Iterator[Any] = prev.map(transform)

  private def transform(element: Any): Any = element match {
    case feature: ImageFeature =>
      mergeFun.transform(feature)
    // The type argument is erased at runtime, so any Sample matches; it is assumed to hold
    // Float data.
    case sample: Sample[Float @unchecked] =>
      mergeFun.createNewMergedSample(sample)
    case _ =>
      // Guard against null so the intended error is raised instead of an NPE.
      val typeName = if (element == null) "null" else element.getClass.toString
      throw new IllegalArgumentException(
        s"Only element types ImageFeature and Sample[Float] are supported. " +
          s"Element type ${typeName} is not supported.")
  }
}


/**
 * Training metadata extracted by `TFUtils.getTrainMeta` from a JSON file. JSON keys are
 * camelized before extraction, so a key such as `input_types` fills `inputTypes`.
 *
 * NOTE(review): the field semantics below are inferred from names and are not verified in
 * this file; confirm against the code that writes the JSON.
 *  - The `*Types` arrays presumably hold numeric TensorFlow data-type codes (the codes
 *    decoded by `TFUtils.tfenum2datatype`), parallel to the matching name arrays.
 *  - The String and Array[String] fields presumably name tensors/ops in the TF graph.
 *  - `trainOp` and `initOp` are `Option`s, so json4s extraction tolerates their absence.
 *  - `defaultTensorValue` holds one Float array per entry; what it is indexed by is not
 *    visible here.
 */
case class TrainMeta(inputs: Array[String],
                     inputTypes: Array[Int],
                     additionalInputs: Array[String],
                     additionalInputTypes: Array[Int],
                     labels: Array[String],
                     labelTypes: Array[Int],
                     predictionOutputs: Array[String],
                     metricTensors: Array[String],
                     batchSizeTensor: String,
                     lossTensor: String,
                     variables: Array[String],
                     variableTypes: Array[Int],
                     variableAssignPlaceholders: Array[String],
                     assignVariableOp: String,
                     extraVariables: Array[String],
                     extraVariableTypes: Array[Int],
                     extraVariableAssignPlaceholders: Array[String],
                     assignExtraVariableOp: String,
                     gradVariables: Array[String],
                     restoreOp: String,
                     restorePathPlaceholder: String,
                     saveOp: String,
                     savePathPlaceholder: String,
                     updateOp: String,
                     trainOp: Option[String],
                     initOp: Option[String],
                     defaultTensorValue: Array[Array[Float]],
                     metricsNames: Array[String])

/**
 * Placeholder module standing for one portion of a TensorFlow model's weights.
 *
 * It is used only by DistriOptimizer, to train a TensorFlow model with several optimization
 * methods chosen by variable name. Applying a TFTrainingHelper2 layer by name yields the
 * matching TFSubGraph. DistriOptimizer.optimize() uses it only to learn each weight portion's
 * size and offset, slice the original weights and gradients accordingly, and apply the
 * corresponding optimization method.
 *
 * Its gradients are never consumed, so `parameters()` pairs each weight with a fresh empty
 * Tensor as a stand-in gradient. `updateOutput` returns its input unchanged and
 * `updateGradInput` returns `gradInput` untouched.
 */
private[zoo] class TFSubGraph(
        weights: Array[Tensor[Float]]) extends AbstractModule[Activity, Activity, Float] {

  override def updateOutput(input: Activity): Activity = input

  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = gradInput

  override def parameters(): (Array[Tensor[Float]], Array[Tensor[Float]]) = {
    val placeholderGrads = Array.fill(weights.length)(Tensor[Float]())
    (weights, placeholderGrads)
  }
}

/**
 * The subset of TensorFlow's `ConfigProto` used when creating sessions, serialized by hand.
 *
 * @param intraOpParallelismThreads value of `intra_op_parallelism_threads`; omitted if <= 0
 * @param interOpParallelismThreads value of `inter_op_parallelism_threads`; omitted if <= 0
 * @param usePerSessionThreads      when true, `use_per_session_threads = true` is emitted
 */
case class SessionConfig(intraOpParallelismThreads: Int = 1,
                         interOpParallelismThreads: Int = 1,
                         usePerSessionThreads: Boolean = true) {

  // Ideally we should use the following code, however, importing tensorflow proto
  // will conflict with bigdl.

  //  val defaultSessionConfig = ConfigProto.newBuilder()
  //    .setInterOpParallelismThreads(1)
  //    .setIntraOpParallelismThreads(1)
  //    .setUsePerSessionThreads(true)
  //    .build().toByteArray

  /**
   * Serializes the configured fields in protobuf wire format, mirroring the builder code above.
   *
   * Each present field is a one-byte key (`field_number << 3 | 0`, the varint wire type)
   * followed by its value as a base-128 varint:
   * 16 = field 2, 40 = field 5, 72 = field 9.
   */
  def toByteArray(): Array[Byte] = {
    // Base-128 varint of a non-negative Int: low 7-bit group first, continuation bit (0x80)
    // on every byte but the last. Values >= 128 need several bytes; writing them as a single
    // `.toByte` would corrupt the message.
    def varint(value: Int): List[Byte] =
      if ((value & ~0x7F) == 0) List(value.toByte)
      else ((value & 0x7F) | 0x80).toByte :: varint(value >>> 7)

    def field(key: Int, value: Int): List[Byte] = key.toByte :: varint(value)

    val intra =
      if (intraOpParallelismThreads > 0) field(16, intraOpParallelismThreads) else Nil
    val inter =
      if (interOpParallelismThreads > 0) field(40, interOpParallelismThreads) else Nil
    val perSession = if (usePerSessionThreads) field(72, 1) else Nil

    (intra ++ inter ++ perSession).toArray
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy