
org.nd4j.autodiff.samediff.internal.TrainingSession

package org.nd4j.autodiff.samediff.internal;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.AtomicDouble;

import java.util.*;

/**
 * TrainingSession extends InferenceSession, to add training-specific functionality:
 * - Application of regularization (L1, L2, weight decay etc)
 * - Inline updating of variables, using updater/optimizer (Adam, Nesterov, SGD, etc)
 * - Calculation of regularization scores (Score for L1, L2, etc)
 *
 * @author Alex Black
 */
@Slf4j
public class TrainingSession extends InferenceSession {

    protected TrainingConfig config;
    protected Map<String, String> gradVarToVarMap;
    protected Map<String, GradientUpdater> updaters;
    protected Map<String, Integer> lossVarsToLossIdx;
    protected double[] currIterLoss;
    protected Map<Class<?>, AtomicDouble> currIterRegLoss;
    protected List<Listener> listeners;

    public TrainingSession(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Perform one iteration of training - i.e., do forward and backward passes, and update the parameters
     *
     * @param config        Training configuration
     * @param placeholders  Current placeholders
     * @param paramsToTrain Set of parameters that will be trained
     * @param updaters      Current updater state
     * @param batch         Current data/batch (mainly for listeners, should have already been converted to placeholders map)
     * @param lossVariables Loss variables (names)
     * @param listeners     Listeners (if any)
     * @param at            Current epoch, iteration, etc
     * @return The Loss at the current iteration
     */
    public Loss trainingIteration(TrainingConfig config, Map<String, INDArray> placeholders, Set<String> paramsToTrain,
                                  Map<String, GradientUpdater> updaters, MultiDataSet batch, List<String> lossVariables,
                                  List<Listener> listeners, At at) {
        this.config = config;
        this.updaters = updaters;

        //Preprocess listeners, get the relevant ones
        if (listeners == null) {
            this.listeners = null;
        } else {
            List<Listener> filtered = new ArrayList<>();
            for (Listener l : listeners) {
                if (l.isActive(at.operation())) {
                    filtered.add(l);
                }
            }
            this.listeners = filtered.isEmpty() ? null : filtered;
        }

        List<String> requiredActivations = new ArrayList<>();
        gradVarToVarMap = new HashMap<>();      //Key: gradient variable. Value: variable that the key is gradient for
        for (String s : paramsToTrain) {
            Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s);
            SDVariable v = sameDiff.getVariable(s);
            Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE,
                    "Can only train VARIABLE type variable - \"%s\" has type %s", s, v.getVariableType());

            SDVariable grad = sameDiff.getVariable(s).getGradient();
            if (grad == null) {
                //In some cases, a variable won't actually impact the loss value, and hence won't have a gradient associated with it
                //For example: floatVar -> cast to integer -> cast to float -> sum -> loss
                //In this case, the gradient of floatVar isn't defined (due to no floating point connection to the loss)
                continue;
            }

            requiredActivations.add(grad.name());
            gradVarToVarMap.put(grad.name(), s);
        }

        //Set up losses
        lossVarsToLossIdx = new LinkedHashMap<>();
        List<String> lossVars;
        currIterLoss = new double[lossVariables.size()];
        currIterRegLoss = new HashMap<>();
        for (int i = 0; i < lossVariables.size(); i++) {
            lossVarsToLossIdx.put(lossVariables.get(i), i);
        }

        //Do training iteration
        List<String> outputVars = new ArrayList<>(gradVarToVarMap.keySet());    //TODO this should be empty, and grads calculated in requiredActivations
        Map<String, INDArray> m = output(outputVars, placeholders, batch, requiredActivations, listeners, at);

        double[] finalLoss = new double[currIterLoss.length + currIterRegLoss.size()];
        System.arraycopy(currIterLoss, 0, finalLoss, 0, currIterLoss.length);
        if (currIterRegLoss.size() > 0) {
            lossVars = new ArrayList<>(lossVariables.size() + currIterRegLoss.size());
            lossVars.addAll(lossVariables);
            int s = currIterRegLoss.size();
            //Collect regularization losses
            for (Map.Entry<Class<?>, AtomicDouble> entry : currIterRegLoss.entrySet()) {
                lossVars.add(entry.getKey().getSimpleName());
                finalLoss[s] = entry.getValue().get();
            }
        } else {
            lossVars = lossVariables;
        }

        Loss loss = new Loss(lossVars, finalLoss);
        if (listeners != null) {
            for (Listener l : listeners) {
                if (l.isActive(Operation.TRAINING)) {
                    l.iterationDone(sameDiff, at, batch, loss);
                }
            }
        }

        return loss;
    }

    @Override
    public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
                                 Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch,
                                 Set<String> allReqVariables) {
        //Get outputs from InferenceSession
        INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables);

        List<String> outputs = op.getOutputsOfOp();
        int outIdx = 0;
        for (String s : outputs) {
            //If this is a loss variable - record it
            if (lossVarsToLossIdx.containsKey(s)) {
                int lossIdx = lossVarsToLossIdx.get(s);
                INDArray arr = out[outIdx];
                double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
                currIterLoss[lossIdx] += l;
            }

            //If this is a gradient variable - apply the updater and update the parameter array in-line
            if (gradVarToVarMap.containsKey(s)) {
                String varName = gradVarToVarMap.get(s);
                //log.info("Calculated gradient for variable \"{}\": (grad var name: \"{}\")", varName, s);

                Variable gradVar = sameDiff.getVariables().get(s);
                if (gradVar.getInputsForOp() != null && gradVar.getInputsForOp().isEmpty()) {
                    //Should be rare, and we should handle this by tracking dependencies, and only update when safe
                    // (i.e., dependency tracking)
                    throw new IllegalStateException("Op depends on gradient variable: " + s + " for variable " + varName);
                }

                GradientUpdater u = updaters.get(varName);
                Preconditions.checkState(u != null, "No updater found for variable \"%s\"", varName);

                Variable var = sameDiff.getVariables().get(varName);
                INDArray gradArr = out[outIdx];
                INDArray paramArr = var.getVariable().getArr();

                //Pre-updater regularization (L1, L2)
                List<Regularization> r = config.getRegularization();
                if (r != null && r.size() > 0) {
                    double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0;
                    for (Regularization reg : r) {
                        if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) {
                            if (this.listeners != null) {
                                double score = reg.score(paramArr, at.iteration(), at.epoch());
                                if (!currIterRegLoss.containsKey(reg.getClass())) {
                                    currIterRegLoss.put(reg.getClass(), new AtomicDouble());
                                }
                                currIterRegLoss.get(reg.getClass()).addAndGet(score);
                            }
                            reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch());
                        }
                    }
                }

                u.applyUpdater(gradArr, at.iteration(), at.epoch());

                //Post-apply regularization (weight decay)
                if (r != null && r.size() > 0) {
                    double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0;
                    for (Regularization reg : r) {
                        if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) {
                            if (this.listeners != null) {
                                double score = reg.score(paramArr, at.iteration(), at.epoch());
                                if (!currIterRegLoss.containsKey(reg.getClass())) {
                                    currIterRegLoss.put(reg.getClass(), new AtomicDouble());
                                }
                                currIterRegLoss.get(reg.getClass()).addAndGet(score);
                            }
                            reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch());
                        }
                    }
                }

                if (listeners != null) {
                    for (Listener l : listeners) {
                        if (l.isActive(at.operation()))
                            l.preUpdate(sameDiff, at, var, gradArr);
                    }
                }

                //Update:
                if (config.isMinimize()) {
                    paramArr.subi(gradArr);
                } else {
                    paramArr.addi(gradArr);
                }

                log.trace("Applied updater to gradient and updated variable: {}", varName);
            }

            outIdx++;
        }

        return out;
    }
}
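
For context, TrainingSession is an internal class: user code normally reaches it indirectly, by setting a TrainingConfig on a SameDiff instance and calling fit(), which drives trainingIteration() for each minibatch. Below is a minimal usage sketch under that assumption; the model structure, variable names and shapes are purely illustrative, and the exact loss op and builder options may differ slightly between ND4J versions.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;

public class TrainingSessionUsageSketch {

    /**
     * Hypothetical helper: builds a small softmax classifier and trains it.
     * SameDiff.fit(...) internally invokes TrainingSession.trainingIteration for each minibatch.
     */
    public static void train(DataSetIterator trainData, int numEpochs) {
        SameDiff sd = SameDiff.create();

        //Placeholders for features (4 columns) and one-hot labels (3 classes) - shapes are illustrative
        SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
        SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);

        //Trainable VARIABLE-type parameters - these are the "paramsToTrain" that TrainingSession updates in-line
        SDVariable w = sd.var("w", DataType.FLOAT, 4, 3);
        SDVariable b = sd.var("b", DataType.FLOAT, 1, 3);

        SDVariable out = sd.nn().softmax("out", in.mmul(w).add(b));
        SDVariable loss = sd.loss().logLoss("loss", label, out);

        //Updater + L2 regularization: TrainingSession applies both on every iteration
        TrainingConfig config = new TrainingConfig.Builder()
                .updater(new Adam(1e-3))
                .l2(1e-4)
                .dataSetFeatureMapping("input")
                .dataSetLabelMapping("label")
                .build();
        sd.setTrainingConfig(config);

        //Forward/backward passes, updater application and parameter updates happen inside TrainingSession
        sd.fit(trainData, numEpochs);
    }
}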



