All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.zoo.tfpark.TFTrainingHelperV2.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Analytics Zoo Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.zoo.tfpark

import com.intel.analytics.zoo.common.Utils
import org.tensorflow.DataType

/**
 * A [[TFTrainingHelper]] variant that defers running `trainOp` inside the
 * TensorFlow graph.
 *
 * After a gradient run in training mode the helper only records that an update
 * is pending (`shouldUpdateParameter`). The pending `trainOp` is executed
 * lazily, either at the start of the next gradient run ([[beforeRunGradient]])
 * or when the weights are pulled back out of TensorFlow
 * ([[moveWeightsOutOfTF]]).
 *
 * All constructor arguments except `trainOp` are forwarded unchanged to
 * [[TFTrainingHelper]].
 *
 * @param trainOp name of the graph target executed to apply a pending update;
 *                it is fed the current `weights` tensors under the
 *                `gradVariables` input names
 */
class TFTrainingHelperV2(graphRunner: GraphRunner,
                         checkpointPath: String,
                         inputs: Array[String],
                         inputTypes: Array[Int],
                         additionalInputs: Array[String],
                         additionalInputTypes: Array[Int],
                         labels: Array[String],
                         labelTypes: Array[Int],
                         outputs: Array[String],
                         metrics: Array[String],
                         variables: Array[String],
                         variableTypes: Array[Int],
                         variableAssignPlaceholders: Array[String],
                         assignVariableOp: String,
                         extraVariables: Array[String],
                         extraVariableTypes: Array[Int],
                         extraVariableAssignPlaceholders: Array[String],
                         assignExtraVariableOP: String,
                         gradVariables: Array[String],
                         updateOp: String,
                         private val trainOp: String,
                         initOp: Option[String],
                         defaultTensorValue: Array[Array[Float]])
  extends TFTrainingHelper(graphRunner, checkpointPath, inputs, inputTypes,
    additionalInputs, additionalInputTypes, labels, labelTypes, outputs,
    metrics, variables, variableTypes, variableAssignPlaceholders, assignVariableOp,
    extraVariables, extraVariableTypes, extraVariableAssignPlaceholders, assignExtraVariableOP,
    gradVariables, updateOp, initOp, defaultTensorValue) {

  // True when a training-mode gradient run has finished but `trainOp` has not
  // yet been executed for it. Not serialized (@transient).
  @transient
  private var shouldUpdateParameter = false

  /** Intentionally a no-op override. */
  override protected def evaluateInternal(): Unit = {
    // do nothing
  }

  /**
   * Executes `trainOp` if an update is pending, then clears the pending flag.
   *
   * The current `weights` tensors are fed under the `gradVariables` input
   * names, all declared as FLOAT.
   * NOTE(review): this presumably relies on `weights` holding the values the
   * graph expects for its gradient inputs at this point -- confirm with the
   * optimizer that drives this helper.
   */
  private def runPendingTrainOp(): Unit = {
    if (shouldUpdateParameter) {
      graphRunner.runTargets(targets = Vector(trainOp),
        inputs = weights.toVector, inputNames = gradVariables.toVector,
        inputTypes = Vector.fill(gradVariables.length)(DataType.FLOAT))
      shouldUpdateParameter = false
    }
  }

  /**
   * Prepares the TensorFlow side before a gradient run.
   *
   * In order: pushes `weights` into TensorFlow once (guarded by
   * `weightsRestored`), applies any pending `trainOp`, and pushes
   * `extraParameters` into TensorFlow once (guarded by
   * `extraParameterRestored`).
   */
  override def beforeRunGradient(): Unit = {
    if (!weightsRestored) {
      Utils.timeIt("setTrainingVariableIntoTF") {
        setVariableIntoTF(weights, variableAssignPlaceholders,
          variableTypes.map(TFUtils.tfenum2datatype), assignVariableOp)
      }
      weightsRestored = true
    }

    runPendingTrainOp()

    if (!extraParameterRestored) {
      setVariableIntoTF(extraParameters, extraVariableAssignPlaceholders,
        extraVariableTypes.map(TFUtils.tfenum2datatype), assignExtraVariableOP)
      extraParameterRestored = true
    }
  }

  /**
   * Runs the superclass hook, then marks an update as pending when the helper
   * is in training mode.
   */
  override def afterRunGradient(): Unit = {
    super.afterRunGradient()
    if (this.isTraining()) shouldUpdateParameter = true
  }

  /**
   * Copies the variables held by TensorFlow back into `weights` and, when
   * there are any, `extraParameters`.
   *
   * Does nothing if the weights were never pushed into TensorFlow
   * (`weightsRestored` is false). Any pending `trainOp` is applied first so
   * the values read back include the latest update.
   */
  def moveWeightsOutOfTF(): Unit = {
    if (weightsRestored) {
      runPendingTrainOp()
      getVariableFromTF(weights, variableNames = variables)
      if (extraParameters.length > 0) {
        Utils.timeIt("getExtraVariableFromTF") {
          getVariableFromTF(extraParameters, variableNames = extraVariables)
        }
      }
    }
  }

}






© 2015 - 2025 Weber Informatics LLC | Privacy Policy