package org.deeplearning4j.nn.updater;
import lombok.Getter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.accum.Norm2;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* BaseMultiLayerUpdater - core functionality for applying updaters to MultiLayerNetwork and ComputationGraph.
*
* This class implements updater combining: any layers (and variables) that
* (a) have contiguous parameters/gradients in the view arrays, and
* (b) have identical updater configurations (same updater type, learning rate, LR/momentum schedules, etc. -
* different L1/L2 values are OK)
* are combined into a single {@link org.nd4j.linalg.learning.GradientUpdater} operation, instead of a set of
* smaller operations. A smaller number of larger operations improves performance, especially on GPUs.
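*
* A minimal usage sketch (saving and restoring the combined updater state; "net" and "restoredNet" are
* hypothetical, identically configured MultiLayerNetwork instances):
* <pre>{@code
* INDArray updaterState = net.getUpdater().getStateViewArray();   //One flattened row vector across all blocks
* //...persist updaterState elsewhere, then copy its values into the second network's updater:
* restoredNet.getUpdater().setStateViewArray(restoredNet, updaterState.dup(), false);
* }</pre>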
*
* @author Alex Black
*/
@Getter
public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater {
protected final T network;
protected Map<String, Layer> layersByName;
protected final List<UpdaterBlock> updaterBlocks;
protected INDArray updaterStateViewArray;
public BaseMultiLayerUpdater(T network) {
this(network, null);
}
/**
*
* @param network Network to create the updater for
* @param updaterState The updater state to use. Note: This array is used *directly* and isn't copied/cloned
*/
public BaseMultiLayerUpdater(T network, INDArray updaterState) {
this.network = network;
Layer[] layers = getOrderedLayers();
int updaterStateSize = 0;
//Iterate through layers, and variables for each layer.
//While the updater configuration is the same: combine into one op, rather than doing a lot of smaller
// (yet identical) ops.
Layer lastLayer = null;
String lastVariable = null;
UpdaterBlock currentBlock = null;
updaterBlocks = new ArrayList<>();
INDArray paramsView = network.params();
INDArray gradientView = getFlattenedGradientsView();
int paramsViewSoFar = 0;
int currentUpdaterOffset = 0;
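//paramsViewSoFar: current offset into the flattened params/gradients view arrays
//currentUpdaterOffset: current offset into the flattened updater state view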
for (int i = 0; i < layers.length; i++) {
Map<String, INDArray> layerParamTable = layers[i].paramTable();
if (layerParamTable != null) {
List<String> variables = new ArrayList<>(layerParamTable.keySet()); //From a set, but iteration order should be fixed per layer as it's from a LinkedHashSet
for (int j = 0; j < variables.size(); j++) {
String var = variables.get(j);
int paramSizeThisVariable = layerParamTable.get(var).length();
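//Updater state size depends on the updater type: for example, Adam keeps 2 values per parameter (m and v), while plain SGD keeps none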
int updaterStateSizeThisVariable = (int) layers[i].conf().getLayer().getIUpdaterByParam(var)
.stateSize(paramSizeThisVariable);
INDArray gradientViewSubset = null;
INDArray paramsViewSubset = null;
if (paramSizeThisVariable > 0) {
paramsViewSubset = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(paramsViewSoFar,
paramsViewSoFar + paramSizeThisVariable));
gradientViewSubset = gradientView.get(NDArrayIndex.point(0), NDArrayIndex
.interval(paramsViewSoFar, paramsViewSoFar + paramSizeThisVariable));
}
//First: decide whether to add to the existing updater block, or create a new one
if (currentBlock == null || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable,
layers[i], var)) {
//Create a new block
List<UpdaterBlock.ParamState> list = new ArrayList<>();
list.add(new UpdaterBlock.ParamState(layers[i], var, paramsViewSoFar,
paramsViewSoFar + paramSizeThisVariable, paramsViewSubset, gradientViewSubset));
currentBlock = new UpdaterBlock(paramsViewSoFar, paramsViewSoFar + paramSizeThisVariable,
currentUpdaterOffset, currentUpdaterOffset + updaterStateSizeThisVariable,
list);
updaterBlocks.add(currentBlock);
} else {
//Add to existing updater block
currentBlock.setParamOffsetEnd(currentBlock.getParamOffsetEnd() + paramSizeThisVariable);
currentBlock.setUpdaterViewOffsetEnd(
currentBlock.getUpdaterViewOffsetEnd() + updaterStateSizeThisVariable);
currentBlock.getLayersAndVariablesInBlock()
.add(new UpdaterBlock.ParamState(layers[i], var, paramsViewSoFar,
paramsViewSoFar + paramSizeThisVariable, paramsViewSubset,
gradientViewSubset));
}
lastLayer = layers[i];
lastVariable = variables.get(j);
updaterStateSize += updaterStateSizeThisVariable;
paramsViewSoFar += paramSizeThisVariable;
currentUpdaterOffset += updaterStateSizeThisVariable;
}
}
}
//Initialize the updater state, if required
boolean updaterRequiresInit = false;
if (updaterState != null) {
updaterStateViewArray = updaterState;
updaterRequiresInit = false;
} else if (updaterStateSize > 0) {
//May be 0 if all SGD or NONE updaters, for example
updaterStateViewArray = Nd4j.createUninitialized(new int[] {1, updaterStateSize}, Nd4j.order());
updaterRequiresInit = true;
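//Note: the uninitialized view is filled in below, when each UpdaterBlock (and its GradientUpdater) is initialized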
}
//Create and set up the updaters, for the updater blocks:
int updaterViewSoFar = 0;
paramsViewSoFar = 0;
for (int i = 0; i < updaterBlocks.size(); i++) {
UpdaterBlock ub = updaterBlocks.get(i);
int viewStateSize = ub.getUpdaterViewOffsetEnd() - ub.getUpdaterViewOffsetStart();
int gradSize = ub.getParamOffsetEnd() - ub.getParamOffsetStart();
if (viewStateSize > 0) {
INDArray updaterViewSubset = updaterStateViewArray.get(NDArrayIndex.point(0),
NDArrayIndex.interval(updaterViewSoFar, updaterViewSoFar + viewStateSize));
ub.setUpdaterView(updaterViewSubset);
ub.setUpdaterViewRequiresInitialization(updaterRequiresInit);
}
if (gradSize > 0) {
INDArray gradientViewSubset = gradientView.get(NDArrayIndex.point(0),
NDArrayIndex.interval(paramsViewSoFar, paramsViewSoFar + gradSize));
ub.setGradientView(gradientViewSubset);
}
ub.init();
updaterViewSoFar += viewStateSize;
paramsViewSoFar += gradSize;
}
}
/**
*
* @return Array of layers, in the correct order (i.e., same order as the parameter/gradient/updater flattening
* order - input to output for MultiLayerNetwork, or topological order for ComputationGraph)
*/
protected abstract Layer[] getOrderedLayers();
/**
* @return The flattened gradient view array for the model
*/
protected abstract INDArray getFlattenedGradientsView();
/**
* @return The flattened parameter array for the model
*/
protected abstract INDArray getParams();
/**
* @return True if the configuration for the model is set to minibatch (divide by minibatch size), false otherwise
*/
protected abstract boolean isMiniBatch();
/**
* Set the view array. Note that this does an assign operation - the provided array is not stored internally.
*
* @param viewArray The new updater state
*/
public void setStateViewArray(INDArray viewArray) {
if (this.updaterStateViewArray.length() != viewArray.length())
throw new IllegalStateException("Invalid input: view arrays differ in length. " + "Expected length "
+ this.updaterStateViewArray.length() + ", got length " + viewArray.length());
this.updaterStateViewArray.assign(viewArray);
}
@Override
public void setStateViewArray(Layer layer, INDArray viewArray, boolean initialize) {
this.setStateViewArray(viewArray);
}
@Override
public INDArray getStateViewArray() {
return updaterStateViewArray;
}
@Override
public void update(Layer layer, Gradient gradient, int iteration, int batchSize) {
update(gradient, iteration, batchSize);
}
/**
* Update the gradient for the model.
* This operates in 3 steps:
* 1. Pre-apply: gradient clipping, etc on a per-layer basis
* 2. Execute the updater (Adam, Nesterov momentum, etc) - in blocks of layers at a time
* 3. Divide by minibatch size
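*
* A minimal sketch of the external-gradient case (gradients computed elsewhere, e.g. RL4J-style; "net",
* "extGradient", "iteration" and "miniBatchSize" are hypothetical):
* <pre>{@code
* net.getUpdater().update(net, extGradient, iteration, miniBatchSize);  //Steps 1-3, applied in-place to extGradient
* net.params().subi(extGradient.gradient());                            //Apply the processed gradient as a parameter update
* }</pre>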
*
* @param gradient Gradient to update
* @param iteration The current iteration (i.e., number of parameter updates so far)
* @param batchSize The current minibatch size (number of examples)
*/
public void update(Gradient gradient, int iteration, int batchSize) {
//First: check if gradient is standard or external...
//In a MultiLayerNetwork, the INDArray returned by .gradient() is always the standard full view array
// hence should be the same object under normal circumstances
boolean isExternal = gradient.gradient() != getFlattenedGradientsView();
//Split up the gradients on a per-layer basis, for pre-apply
Map<String, Gradient> layerGradients = new HashMap<>();
Layer[] layers = getOrderedLayers();
if (layers.length == 1 && isSingleLayerUpdater()) {
layerGradients.put(layers[0].conf().getLayer().getLayerName(), gradient);
} else {
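//Gradient keys have the form "<layerName>_<paramName>" (for example, "0_W"); split on the last '_' to group gradients by layer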
for (Map.Entry<String, INDArray> gradientPair : gradient.gradientForVariable().entrySet()) {
String key = gradientPair.getKey();
int idx = key.lastIndexOf('_');
if (idx == -1)
throw new IllegalStateException(
"Invalid key: Gradient key does not have layer separator: \"" + key + "\"");
String layerName = key.substring(0, idx);
Gradient g = layerGradients.get(layerName);
if (g == null) {
g = new DefaultGradient();
layerGradients.put(layerName, g);
}
String newKey = key.substring(idx + 1);
g.setGradientFor(newKey, gradientPair.getValue());
}
}
//PRE apply (gradient clipping, etc): done on a per-layer basis
for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
String layerName = entry.getKey();
Layer layer = layersByName.get(layerName);
preApply(layer, layerGradients.get(layerName), iteration);
}
//Apply the updaters, in blocks. This also applies the LR and momentum schedules, and L1/L2 regularization
for (UpdaterBlock ub : updaterBlocks) {
if (ub.skipDueToPretrainConfig()) {
//Some updater blocks should be skipped in certain configurations -
//for example, VAE decoder params while doing supervised backprop
continue;
}
if (Nd4j.getWorkspaceManager().checkIfWorkspaceExists(ComputationGraph.workspaceFeedForward)) {
try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
.getAndActivateWorkspace(ComputationGraph.workspaceFeedForward)) {
if (isExternal) {
//External gradient case (e.g. RL4J): gradients are calculated in one network and applied in another
ub.updateExternalGradient(iteration, gradient.gradient(), getParams());
} else {
//Standard case
ub.update(iteration);
}
}
} else {
if (isExternal) {
//External gradient case (e.g. RL4J): gradients are calculated in one network and applied in another
ub.updateExternalGradient(iteration, gradient.gradient(), getParams());
} else {
//Standard case
ub.update(iteration);
}
}
}
//Divide by minibatch size if necessary
if (isMiniBatch()) {
//OK even with pretrain layers: their gradients will get modified during next backprop iteration
if (isExternal) {
gradient.gradient().divi(batchSize);
} else {
//Standard case
getFlattenedGradientsView().divi(batchSize);
}
}
}
protected boolean isSingleLayerUpdater() {
return false;
}
/**
* Pre-apply: Apply gradient normalization/clipping
*
* @param layer Layer to apply gradient normalization/clipping for
* @param gradient Gradient to update
* @param iteration The current iteration (i.e., number of parameter updates so far)
*/
public void preApply(Layer layer, Gradient gradient, int iteration) {
if (!(layer.conf().getLayer() instanceof BaseLayer)) {
//Layer does not have parameters -> no gradient
return;
}
BaseLayer bLayer = (BaseLayer) layer.conf().getLayer();
GradientNormalization normalization = bLayer.getGradientNormalization();
if (normalization == null || normalization == GradientNormalization.None || layer.conf().isPretrain())
return; //no op
final double threshold = bLayer.getGradientNormalizationThreshold();
INDArray layerGradientView = layer.getGradientsViewArray();
switch (normalization) {
case RenormalizeL2PerLayer:
if (layerGradientView != null) {
double l2 = layerGradientView.norm2Number().doubleValue();
layerGradientView.divi(l2);
}
break;
case RenormalizeL2PerParamType:
for (INDArray g : gradient.gradientForVariable().values()) {
double l2 = Nd4j.getExecutioner().execAndReturn(new Norm2(g)).getFinalResult().doubleValue();
g.divi(l2);
}
break;
case ClipElementWiseAbsoluteValue:
if (layerGradientView != null) {
BooleanIndexing.replaceWhere(layerGradientView, threshold, Conditions.greaterThan(threshold));
BooleanIndexing.replaceWhere(layerGradientView, -threshold, Conditions.lessThan(-threshold));
}
break;
case ClipL2PerLayer:
if (layerGradientView != null) {
double layerL2 = layerGradientView.norm2Number().doubleValue();
if (layerL2 > threshold) {
double scalingFactor = threshold / layerL2; // g = g * (threshold / l2)
layerGradientView.muli(scalingFactor);
}
}
break;
case ClipL2PerParamType:
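//Scale each parameter type's gradient so its L2 norm is at most the threshold: g = g * (threshold / l2) when l2 > threshold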
for (INDArray g : gradient.gradientForVariable().values()) {
double l2 = g.norm2Number().doubleValue();
if (l2 > threshold) {
double scalingFactor = l2 / threshold;
g.divi(scalingFactor);
}
}
break;
default:
throw new RuntimeException(
"Unknown (or not implemented) gradient normalization strategy: " + normalization);
}
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
BaseMultiLayerUpdater<?> that = (BaseMultiLayerUpdater<?>) o;
return updaterStateViewArray != null ? updaterStateViewArray.equals(that.updaterStateViewArray)
: that.updaterStateViewArray == null;
}
@Override
public int hashCode() {
int result = layersByName != null ? layersByName.hashCode() : 0;
result = 31 * result + (updaterBlocks != null ? updaterBlocks.hashCode() : 0);
result = 31 * result + (updaterStateViewArray != null ? updaterStateViewArray.hashCode() : 0);
return result;
}
}