All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.zoo.tfpark.TFTrainingHelper.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.tfpark

import com.intel.analytics.bigdl.nn.Module
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.zoo.common.Utils
import com.intel.analytics.zoo.core.TFNetNative
import org.slf4j.LoggerFactory
import org.tensorflow.DataType
import org.tensorflow.framework.GraphDef

import scala.reflect.io.Path

/**
 * A BigDL module that runs a TensorFlow training graph through a [[GraphRunner]].
 *
 * The module keeps BigDL-side tensors for the graph's variables (`weights`), their gradients
 * (`gradWeights`) and the extra variables (`extraParameters`), and copies them to and from
 * the TensorFlow graph around each run.
 *
 * Note: `variables` and `gradVariables` need to be sorted by name if you want to use multiple
 * optimization methods for a TensorFlow model according to variable names.
 *
 * @param graphRunner runner used for every graph execution, save and restore
 * @param checkpointPath path handed to `graphRunner.saveToFile` / `restoreFromFile`
 * @param inputs names fed first on every run
 * @param inputTypes TensorFlow type enums for `inputs` (converted by `TFUtils.tfenum2datatype`)
 * @param additionalInputs names fed last on every run; their values come from
 *                         `defaultTensorValue`
 * @param additionalInputTypes TensorFlow type enums for `additionalInputs`
 * @param labels names fed after `inputs` on every run
 * @param labelTypes TensorFlow type enums for `labels`
 * @param predictionOutputs names fetched on every run, placed first in `outputNames`
 * @param metrics names fetched on every run, placed after `predictionOutputs`
 * @param variables names of the variables backing `weights`
 * @param variableTypes TensorFlow type enums used when feeding `weights`
 * @param variableAssignPlaceholders names the `weights` tensors are fed to
 * @param assignVariableOp target run to assign the fed `weights`
 * @param extraVariables names of the variables backing `extraParameters`
 * @param extraVariableTypes TensorFlow type enums used when feeding `extraParameters`
 * @param extraVariableAssignPlaceholders names the `extraParameters` tensors are fed to
 * @param assignExtraVariableOP target run to assign the fed `extraParameters`
 * @param gradVariables names fetched (in addition to `outputNames`) while training
 * @param updateOp target run on every call to `updateOutput`
 * @param initOp optional target run once before the first `updateOutput`
 * @param defaultTensorValue one entry per additional input; element 0 is fed while training,
 *                           element 1 otherwise
 */
private[zoo] class TFTrainingHelper protected(val graphRunner: GraphRunner,
                                              val checkpointPath: String,
                                              val inputs: Array[String],
                                              val inputTypes: Array[Int],
                                              val additionalInputs: Array[String],
                                              val additionalInputTypes: Array[Int],
                                              val labels: Array[String],
                                              val labelTypes: Array[Int],
                                              val predictionOutputs: Array[String],
                                              val metrics: Array[String],
                                              val variables: Array[String],
                                              val variableTypes: Array[Int],
                                              val variableAssignPlaceholders: Array[String],
                                              val assignVariableOp: String,
                                              val extraVariables: Array[String],
                                              val extraVariableTypes: Array[Int],
                                              val extraVariableAssignPlaceholders: Array[String],
                                              val assignExtraVariableOP: String,
                                              val gradVariables: Array[String],
                                              val updateOp: String,
                                              val initOp: Option[String],
                                              val defaultTensorValue: Array[Array[Float]])
  extends AbstractModule[Activity, Activity, Float] {

  this.setName("TFParkTraining")

  // Side effect of construction: sets the "bigdl.ModelBroadcastFactory" system property to
  // the TFPark broadcast factory class name.
  System.setProperty("bigdl.ModelBroadcastFactory",
    "com.intel.analytics.zoo.tfpark.TFModelBroadcastFactory")

  /** Returns the variable tensors and their gradient tensors, index-aligned with `variables`. */
  override def parameters(): (Array[Tensor[Float]], Array[Tensor[Float]]) = {
    (weights, gradWeights)
  }

  /** Returns the tensors backing `extraVariables`. */
  override def getExtraParameter(): Array[Tensor[Float]] = {
    extraParameters
  }

  // One (initially empty) tensor per extra variable; filled by getVariableFromTF.
  protected val extraParameters: Array[Tensor[Float]] = initVariables(extraVariables)

  // One (initially empty) tensor per variable; filled by getVariableFromTF.
  protected val weights: Array[Tensor[Float]] = initVariables(variables)

  // Variable name -> weight tensor; for a repeated name the last occurrence wins.
  private val weightsMap: collection.mutable.Map[String, Tensor[Float]] =
    collection.mutable.Map(variables.zip(weights): _*)

  /** Allocates one empty tensor for each of the given variable names. */
  private def initVariables(variableNames: Array[String]): Array[Tensor[Float]] = {
    Array.fill(variableNames.length)(Tensor[Float]())
  }

  /** Names fetched on every run: prediction outputs followed by metrics. */
  val outputNames: Array[String] = predictionOutputs ++ metrics

  private val gradWeights = variables.map(_ => Tensor[Float]())

  // Fetch buffers handed to the graph runner: the first outputNames.length entries receive
  // the outputs, the remaining gradVariables.length entries receive the gradients.
  private val graphOutputs: Vector[Tensor[Float]] =
    Vector.fill(outputNames.length + gradVariables.length)(Tensor[Float]())

  // The gradient part of graphOutputs (shares the same tensor instances).
  private val gradWeightsBuffer = graphOutputs.drop(outputNames.length)

  // The module output aliases the fetch buffers: a single tensor when there is exactly one
  // output name, otherwise a table holding one tensor per output name.
  output = if (outputNames.length == 1) {
    graphOutputs(0)
  } else {
    val out = T()
    graphOutputs.take(outputNames.length).foreach(t => out.insert(t))
    out
  }

  /** Returns all fetch buffers (outputs followed by gradients). */
  def getGraphOut(): Vector[Tensor[Float]] = graphOutputs

  /** Feeds `weights` to the variable-assign placeholders and runs `assignVariableOp`. */
  private def assignWeightsToTF() = {
    setVariableIntoTF(weights, variableAssignPlaceholders,
      variableTypes.map(TFUtils.tfenum2datatype), assignVariableOp)
  }

  /** Feeds `extraParameters` to their placeholders and runs `assignExtraVariableOP`. */
  private def assignExtraParametersToTF() = {
    setVariableIntoTF(extraParameters, extraVariableAssignPlaceholders,
      extraVariableTypes.map(TFUtils.tfenum2datatype), assignExtraVariableOP)
  }

  /** Hook run when switching to evaluation mode: pushes the current weights into the graph. */
  protected def evaluateInternal(): Unit = {
    assignWeightsToTF()
  }

  /** Switches to evaluation mode and then runs `evaluateInternal`. */
  override def evaluate(): TFTrainingHelper.this.type = {
    super.evaluate()
    evaluateInternal()
    this
  }

  /**
   * Fetches the named graph tensors (as FLOAT) into the given tensors.
   *
   * @param weights destination tensors, index-aligned with `variableNames`
   * @param variableNames names to fetch
   */
  protected def getVariableFromTF(weights: Array[Tensor[Float]],
                                  variableNames: Array[String]) = {
    val outputTypes = Vector.fill(variableNames.length)(DataType.FLOAT)
    graphRunner.run(input = Vector.empty, inputNames = Vector.empty, inputTypes = Vector.empty,
      output = weights.toVector, outputNames = variableNames.toVector, outputTypes = outputTypes,
      targets = Vector.empty)
  }

  /**
   * Feeds the given tensors to the named placeholders and runs `assignOp` as the only target.
   *
   * @param weights tensors to feed
   * @param inputNames placeholder names, index-aligned with `weights`
   * @param variableTypes data types to feed the tensors as
   * @param assignOp target to run
   */
  protected def setVariableIntoTF(weights: Array[Tensor[Float]],
                                  inputNames: Array[String],
                                  variableTypes: Array[DataType],
                                  assignOp: String) = {
    graphRunner.run(input = weights.toVector, inputNames = inputNames.toVector,
      inputTypes = variableTypes.toVector, output = Vector.empty,
      outputNames = Vector.empty, outputTypes = Vector.empty, targets = Vector(assignOp))
  }

  /** Pushes weights and extra parameters into the graph, then saves to `checkpointPath`. */
  def saveCheckpoint(): Unit = {
    assignWeightsToTF()
    assignExtraParametersToTF()
    graphRunner.saveToFile(checkpointPath)
  }

  // Whether extraParameters have been pushed into (or pulled from) the graph.
  // Marked @transient, so it is not part of the serialized state.
  @transient
  protected var extraParameterRestored: Boolean = false

  // Whether weights have been pushed into (or pulled from) the graph.
  // Marked @transient, so it is not part of the serialized state.
  @transient
  protected var weightsRestored: Boolean = false

  // Forcing this lazy value runs initOp (when defined) exactly once per instance.
  @transient
  private lazy val tableInited = {
    if (initOp.isDefined) {
      runInitOp()
    }
    true
  }

  /** Runs `initOp` as a target when it is defined. */
  private def runInitOp(): Unit = {
    initOp.foreach(op => graphRunner.runTargets(targets = Vector(op)))
  }

  /**
   * Restores the graph from `checkpointPath` and pulls variables and extra variables into
   * `weights` / `extraParameters`, marking both as restored.
   */
  def restoreFromCheckpoint(): Unit = {
    graphRunner.restoreFromFile(checkpointPath)
    if (weights.nonEmpty) {
      getVariableFromTF(weights, variableNames = variables)
    }

    if (extraParameters.nonEmpty) {
      getVariableFromTF(extraParameters, variableNames = extraVariables)
    }

    weightsRestored = true
    extraParameterRestored = true
  }

  /**
   * Looks up a sub-graph by name.
   *
   * When `name` is this module's own name, the sub-graph covers all variables. Otherwise it
   * covers the variables whose names start with `name`.
   *
   * @param name module name or variable-name prefix
   * @return the matching sub-graph, or `None` when `name` is not this module's name and no
   *         variable name starts with it
   */
  override def apply(name: String): Option[AbstractModule[Activity, Activity, Float]] = {
    val isWholeModel = name == getName()
    val targetVariables =
      if (isWholeModel) variables else variables.filter(_.startsWith(name))
    // `filter` never yields null, so an unmatched prefix has to be detected by emptiness.
    if (!isWholeModel && targetVariables.isEmpty) {
      None
    } else {
      val targetWeights = targetVariables.map(weightsMap)
      Some(new TFSubGraph(targetWeights))
    }
  }

  /**
   * Hook run before each graph execution: pushes weights into the graph when training or when
   * they have not been pushed yet, and pushes extra parameters once.
   */
  protected def beforeRunGradient(): Unit = {
    if (this.isTraining() || !weightsRestored) {
      Utils.timeIt("setTrainingVariableIntoTF") {
        assignWeightsToTF()
      }
      weightsRestored = true
    }

    if (!extraParameterRestored) {
      assignExtraParametersToTF()
      extraParameterRestored = true
    }
  }

  /**
   * Hook run after each graph execution: pulls extra variables back from the graph and, while
   * training, adds the fetched gradients onto `gradWeights`.
   */
  protected def afterRunGradient(): Unit = {
    if (extraParameters.nonEmpty) {
      Utils.timeIt("getExtraVariableFromTF") {
        getVariableFromTF(extraParameters, variableNames = extraVariables)
      }
    }

    if (isTraining()) {
      // Gradients are added to (not copied over) the existing gradWeights content.
      gradWeights.zipWithIndex.foreach { case (grad, idx) =>
        grad.resizeAs(weights(idx)).add(gradWeightsBuffer(idx))
      }
    }
  }

  /**
   * Runs the graph once: feeds `input` plus the default values of the additional inputs,
   * fetches the outputs (and the gradients while training) and runs `updateOp`.
   *
   * @param input activity holding the tensors for `inputs` followed by `labels`
   *              (assumed from the feed-name order — confirm against callers)
   * @return the module output, which aliases the fetch buffers
   */
  override def updateOutput(input: Activity): Activity = {
    Utils.timeIt("updateOutput") {

      // `require` rather than `assert`: assertions can be elided at compile time, which
      // would skip forcing the lazy value and therefore skip running initOp.
      require(tableInited, "TensorFlow init op has not been run")

      this.beforeRunGradient()

      val feeds = Utils.activity2VectorBuilder(input)

      // Each default value holds (training value, evaluation value).
      val phaseIndex = if (this.isTraining()) 0 else 1
      defaultTensorValue.foreach { values =>
        feeds += Tensor.scalar[Float](values(phaseIndex))
      }

      val (names, tensors) = if (isTraining()) {
        (outputNames.toVector ++ gradVariables.toVector, graphOutputs)
      } else {
        (outputNames.toVector, graphOutputs.slice(0, outputNames.length))
      }

      val inputTensorNames = inputs ++ labels ++ additionalInputs
      val inputTensorTypes = (inputTypes ++ labelTypes ++ additionalInputTypes)
        .toVector.map(TFUtils.tfenum2datatype)

      val outputTypes = Vector.fill(names.length)(DataType.FLOAT)
      graphRunner.run(input = feeds.result(), inputNames = inputTensorNames.toVector,
        inputTypes = inputTensorTypes,
        output = tensors, outputNames = names, outputTypes = outputTypes,
        targets = Vector(updateOp))

      this.afterRunGradient()
    }

    output
  }

  /** Returns `gradInput` unchanged; no computation is performed here. */
  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
    gradInput
  }

  /**
   * Loads a saved `TFTrainingHelper` from `path` and copies its weights and extra parameters
   * into this instance.
   *
   * @param path location passed to `Module.load`
   * @throws AssertionError when the loaded module's variable names differ from this module's
   */
  def loadZooCheckpoint(path: String): Unit = {
    val module = Module.load(path).asInstanceOf[TFTrainingHelper]
    // Arrays are rendered with mkString; interpolating them directly prints only the JVM
    // identity string of the array.
    assert(module.variables sameElements this.variables,
      "variables in graphdef does not equal to the current variables." +
        s"\nvariable names in checkpoint: ${module.variables.mkString(", ")}" +
        s"\nvariable names in current graph, ${this.variables.mkString(", ")}")
    this.parameters()._1.zip(module.parameters()._1).foreach { case (target, source) =>
      target.copy(source)
    }
    this.extraParameters.zip(module.extraParameters).foreach { case (target, source) =>
      target.copy(source)
    }
  }
}

/**
 * Factory methods that build a [[TFTrainingHelper]] and restore it from its checkpoint.
 */
object TFTrainingHelper {

  // Checked once, when this companion object is first initialised.
  assert(TFNetNative.isLoaded)

  val logger = LoggerFactory.getLogger(getClass)

  /**
   * Builds a helper from an in-memory graph definition and restores it from the checkpoint.
   *
   * A `TFTrainingHelperV2` is created when `trainMeta.trainOp` is present, a plain
   * `TFTrainingHelper` otherwise. In both cases the metric names passed on are
   * `trainMeta.metricTensors` followed by the batch-size tensor and the loss tensor.
   *
   * @param graphDef graph definition, handed to the graph runner as bytes
   * @param checkpointPath checkpoint location the helper restores from
   * @param trainMeta training metadata describing the graph's tensors and ops
   * @param config session configuration bytes for the graph runner
   * @return the restored helper
   */
  def apply(graphDef: GraphDef,
            checkpointPath: String,
            trainMeta: TrainMeta,
            config: Array[Byte]): TFTrainingHelper = {

    val runner = new GraphRunner(
      graphDef.toByteArray,
      trainMeta.restoreOp,
      trainMeta.restorePathPlaceholder,
      trainMeta.saveOp,
      trainMeta.savePathPlaceholder,
      config)

    // Shared by both constructor calls below.
    val metricNames: Array[String] =
      trainMeta.metricTensors ++ Array(trainMeta.batchSizeTensor, trainMeta.lossTensor)

    val restored: TFTrainingHelper =
      if (trainMeta.trainOp.isEmpty) {
        new TFTrainingHelper(runner,
          checkpointPath,
          trainMeta.inputs,
          trainMeta.inputTypes,
          trainMeta.additionalInputs,
          trainMeta.additionalInputTypes,
          trainMeta.labels,
          trainMeta.labelTypes,
          trainMeta.predictionOutputs,
          metricNames,
          trainMeta.variables,
          trainMeta.variableTypes,
          trainMeta.variableAssignPlaceholders,
          trainMeta.assignVariableOp,
          trainMeta.extraVariables,
          trainMeta.extraVariableTypes,
          trainMeta.extraVariableAssignPlaceholders,
          trainMeta.assignExtraVariableOp,
          trainMeta.gradVariables,
          trainMeta.updateOp,
          trainMeta.initOp,
          trainMeta.defaultTensorValue
        )
      } else {
        new TFTrainingHelperV2(runner,
          checkpointPath,
          trainMeta.inputs,
          trainMeta.inputTypes,
          trainMeta.additionalInputs,
          trainMeta.additionalInputTypes,
          trainMeta.labels,
          trainMeta.labelTypes,
          trainMeta.predictionOutputs,
          metricNames,
          trainMeta.variables,
          trainMeta.variableTypes,
          trainMeta.variableAssignPlaceholders,
          trainMeta.assignVariableOp,
          trainMeta.extraVariables,
          trainMeta.extraVariableTypes,
          trainMeta.extraVariableAssignPlaceholders,
          trainMeta.assignExtraVariableOp,
          trainMeta.gradVariables,
          trainMeta.updateOp,
          trainMeta.trainOp.get,
          trainMeta.initOp,
          trainMeta.defaultTensorValue
        )
      }

    restored.restoreFromCheckpoint()
    restored
  }

  /**
   * Builds a helper from a model folder containing `training_meta.json`, `model.meta` and a
   * checkpoint named `model`.
   *
   * @param modelPath folder holding the exported model files
   * @param sessionConfig session configuration bytes; when `null`,
   *                      `TFUtils.defaultSessionConfig` is used
   * @return the restored helper
   */
  def apply(modelPath: String, sessionConfig: Array[Byte] = null): TFTrainingHelper = {

    val modelDir = Path(modelPath)
    val metaFile = modelDir / Path("training_meta.json")
    val graphFile = modelDir / Path("model.meta")
    val checkpointPrefix = modelDir / Path("model")

    val trainMeta = TFUtils.getTrainMeta(metaFile)
    val graphDef = TFUtils.parseGraph(graphFile.toString())
    // The default configuration is only built when the caller supplied none.
    val config =
      if (sessionConfig == null) TFUtils.defaultSessionConfig.toByteArray() else sessionConfig

    TFTrainingHelper(graphDef, checkpointPrefix.toString(), trainMeta, config)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy