/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.updater;
import lombok.Getter;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.learning.config.IUpdater;
import java.util.*;
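/**
* Base updater for models with multiple trainable layers/vertices: iterates over all parameters of the model,
* grouping contiguous parameters that share an identical updater configuration into {@link UpdaterBlock}s, so
* that a single updater op can be applied per block rather than one op per parameter array.
* <p>
* A rough usage sketch (in normal use the updater is created and driven internally by {@code fit()}; direct
* access via {@code getUpdater()} is shown only for illustration, and {@code conf}/{@code trainData} are
* assumed to exist):
* <pre>{@code
* MultiLayerNetwork net = new MultiLayerNetwork(conf);
* net.init();
* net.fit(trainData); //gradient updates are applied via this updater internally
* INDArray updaterState = net.getUpdater().getStateViewArray(); //e.g., inspect or persist the updater state
* }</pre>
*/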
@Getter
public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater {
protected final T network;
protected Map<String, Trainable> layersByName;
protected final List<UpdaterBlock> updaterBlocks;
protected INDArray updaterStateViewArray;
protected boolean initializedMinibatchDivision;
protected List<INDArray> gradientsForMinibatchDivision;
public BaseMultiLayerUpdater(T network) {
this(network, null);
}
/**
*
* @param network Network to create the updater for
* @param updaterState The updater state to use. Note: This array is used *directly* and isn't copied/cloned
*/
public BaseMultiLayerUpdater(T network, INDArray updaterState) {
this.network = network;
Trainable[] layers = getOrderedLayers(); //May also include vertices
int updaterStateSize = 0;
//Iterate through layers, and variables for each layer.
//While the updater configuration is the same: combine into one op, rather than doing a lot of smaller
// (yet identical) ops.
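//For example, if several consecutive layers all use an identically-configured Adam updater, their parameters
// are grouped into a single UpdaterBlock and one Adam op is applied over the combined view, instead of one op
// per parameter array.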
Trainable lastLayer = null;
String lastVariable = null;
UpdaterBlock currentBlock = null;
updaterBlocks = new ArrayList<>();
INDArray paramsView = network.params();
INDArray gradientView = getFlattenedGradientsView();
int paramsViewSoFar = 0;
int currentUpdaterOffset = 0;
for (int i = 0; i < layers.length; i++) {
Map<String, INDArray> layerParamTable = layers[i].paramTable(false);
if (layerParamTable != null) {
List<String> variables = new ArrayList<>(layerParamTable.keySet()); //Is from a set, but iteration order should be fixed per layer as it's from a LinkedHashSet
for (int j = 0; j < variables.size(); j++) {
String var = variables.get(j);
long paramSizeThisVariable = layerParamTable.get(var).length();
IUpdater u = layers[i].getConfig().getUpdaterByParam(var);
Preconditions.checkNotNull(u, "Updater for parameter %s, layer \"%s\" was null", var, layers[i].getConfig().getLayerName());
int updaterStateSizeThisVariable = (int) u.stateSize(paramSizeThisVariable);
INDArray gradientViewSubset = null;
INDArray paramsViewSubset = null;
INDArray paramsViewReshape = paramsView.reshape(paramsView.length());
INDArray gradientViewReshape = gradientView.reshape(gradientView.length());
if (paramSizeThisVariable > 0) {
paramsViewSubset = paramsViewReshape.get(NDArrayIndex.interval(paramsViewSoFar,
paramsViewSoFar + paramSizeThisVariable));
gradientViewSubset = gradientViewReshape.get( NDArrayIndex
.interval(paramsViewSoFar, paramsViewSoFar + paramSizeThisVariable));
}
//First: decide whether to add to the existing updater block, or create a new one
if (currentBlock == null || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable,
layers[i], var)) {
if (paramsViewSoFar + paramSizeThisVariable > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
//Create a new block
List<UpdaterBlock.ParamState> list = new ArrayList<>();
list.add(new UpdaterBlock.ParamState(layers[i], var, paramsViewSoFar,
(int) (paramsViewSoFar + paramSizeThisVariable), paramsViewSubset, gradientViewSubset));
currentBlock = new UpdaterBlock(paramsViewSoFar, (int) (paramsViewSoFar + paramSizeThisVariable),
currentUpdaterOffset, currentUpdaterOffset + updaterStateSizeThisVariable,
list);
updaterBlocks.add(currentBlock);
} else {
long newOffset = currentBlock.getParamOffsetEnd() + paramSizeThisVariable;
if (newOffset > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
//Add to existing updater block
currentBlock.setParamOffsetEnd((int) newOffset);
currentBlock.setUpdaterViewOffsetEnd(
currentBlock.getUpdaterViewOffsetEnd() + updaterStateSizeThisVariable);
currentBlock.getLayersAndVariablesInBlock()
.add(new UpdaterBlock.ParamState(layers[i], var, paramsViewSoFar,
(int) (paramsViewSoFar + paramSizeThisVariable), paramsViewSubset,
gradientViewSubset));
}
lastLayer = layers[i];
lastVariable = variables.get(j);
updaterStateSize += updaterStateSizeThisVariable;
paramsViewSoFar += paramSizeThisVariable;
currentUpdaterOffset += updaterStateSizeThisVariable;
}
}
}
//Initialize the updater state, if required
boolean updaterRequiresInit = false;
if (updaterState != null) {
updaterStateViewArray = updaterState;
updaterRequiresInit = false;
} else if (updaterStateSize > 0) {
//May be 0 if all SGD or NONE updaters, for example
updaterStateViewArray = Nd4j.createUninitialized(network.params().dataType(), new long[] { updaterStateSize}, Nd4j.order());
updaterRequiresInit = true;
}
//Create and set up the updaters, for the updater blocks:
int updaterViewSoFar = 0;
paramsViewSoFar = 0;
for (int i = 0; i < updaterBlocks.size(); i++) {
UpdaterBlock ub = updaterBlocks.get(i);
int viewStateSize = ub.getUpdaterViewOffsetEnd() - ub.getUpdaterViewOffsetStart();
int gradSize = ub.getParamOffsetEnd() - ub.getParamOffsetStart();
if (viewStateSize > 0) {
INDArray updaterViewSubset = updaterStateViewArray.get(
NDArrayIndex.interval(updaterViewSoFar, updaterViewSoFar + viewStateSize));
ub.setUpdaterView(updaterViewSubset);
ub.setUpdaterViewRequiresInitialization(updaterRequiresInit);
}
if (gradSize > 0) {
INDArray gradientViewSubset = gradientView.reshape(gradientView.length()).get(
NDArrayIndex.interval(paramsViewSoFar, paramsViewSoFar + gradSize));
ub.setGradientView(gradientViewSubset);
}
ub.init();
updaterViewSoFar += viewStateSize;
paramsViewSoFar += gradSize;
}
}
/**
*
* @return Array of layers, in the correct order (i.e., same order as the parameter/gradient/updater flattening
* order - input to output for MultiLayerNetwork, or topological order for ComputationGraph)
*/
protected abstract Trainable[] getOrderedLayers();
/**
* @return The flattened gradient view array for the model
*/
public abstract INDArray getFlattenedGradientsView();
/**
* @return The flattened parameter array for the model
*/
protected abstract INDArray getParams();
/**
* @return True if the configuration for the model is set to minibatch (divide by minibatch size), false otherwise
*/
protected abstract boolean isMiniBatch();
/**
* Set the view array. Note that this does an assign operation - the provided array is not stored internally.
*
* @param viewArray The new updater state
*/
public void setStateViewArray(INDArray viewArray) {
if(this.updaterStateViewArray == null){
if(viewArray == null)
return; //No op - for example, SGD and NoOp updater - i.e., no stored state
else {
throw new IllegalStateException("Attempting to set updater state view array with null value");
}
}
if (this.updaterStateViewArray.length() != viewArray.length())
throw new IllegalStateException("Invalid input: view arrays differ in length. " + "Expected length "
+ this.updaterStateViewArray.length() + ", got length " + viewArray.length());
this.updaterStateViewArray.assign(viewArray);
}
@Override
public void setStateViewArray(Trainable layer, INDArray viewArray, boolean initialize) {
this.setStateViewArray(viewArray);
}
@Override
public INDArray getStateViewArray() {
return updaterStateViewArray;
}
/**
* A synchronized version of {@link #getStateViewArray()} that duplicates the view array internally.
* This should be used in preference to {@link #getStateViewArray()} when the updater state is accessed in one
* thread while another thread is using the updater for training.
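* <p>
* A rough usage sketch (the checkpointing thread shown here is hypothetical, not part of this class):
* <pre>{@code
* //Training thread: net.fit(...) drives this updater and mutates the state view array in place
* //Checkpointing thread: take a consistent snapshot of the updater state, e.g., for saving the model
* INDArray stateSnapshot = updater.getStateViewArrayCopy();
* }</pre>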
* @return A copy (duplicate) of the updater state
*/
public synchronized INDArray getStateViewArrayCopy(){
Nd4j.getExecutioner().commit();
return updaterStateViewArray.dup();
}
@Override
public void update(Trainable layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr) {
update(gradient, iteration, epoch, batchSize, workspaceMgr);
}
/**
* Update the gradient for the model.
* This operates in 3 steps:
* 1. Divide the gradients by the minibatch size, if the model is configured for minibatch training
* 2. Pre-apply: gradient clipping, etc on a per-layer basis
* 3. Execute the updater (Adam, Nesterov momentum, etc) - in blocks of layers at a time
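* <p>
* Gradient keys in the provided {@link Gradient} are expected to follow the {@code layerName_paramName}
* convention; the text after the last underscore is treated as the parameter name. For example (assuming a
* layer named "0" with a weight parameter "W"), the key {@code "0_W"} is split into layer {@code "0"} and
* parameter {@code "W"}.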
*
* @param gradient Gradient to update
* @param iteration The current iteration (i.e., number of parameter updates so far)
* @param epoch The current epoch
* @param batchSize The current minibatch size (number of examples)
* @param workspaceMgr Workspace manager for the updater working memory
*/
public synchronized void update(Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr) {
//First: check if gradient is standard or external...
//In a MultiLayerNetwork, the INDArray returned by .gradient() is always the standard full view array
// hence should be the same object under normal circumstances
boolean isExternal = gradient.gradient() != getFlattenedGradientsView();
//Split up the gradients on a per-layer basis, for pre-apply
Map<String, Gradient> layerGradients = new HashMap<>();
Trainable[] layers = getOrderedLayers();
if (layers.length == 1 && isSingleLayerUpdater()) {
layerGradients.put(layers[0].getConfig().getLayerName(), gradient);
} else {
for (Map.Entry<String, INDArray> gradientPair : gradient.gradientForVariable().entrySet()) {
String key = gradientPair.getKey();
int idx = key.lastIndexOf('_');
if (idx == -1)
throw new IllegalStateException(
"Invalid key: Gradient key does not have layer separator: \"" + key + "\"");
String layerName = key.substring(0, idx);
Gradient g = layerGradients.get(layerName);
if (g == null) {
g = new DefaultGradient();
layerGradients.put(layerName, g);
}
String newKey = key.substring(idx + 1);
g.setGradientFor(newKey, gradientPair.getValue());
}
}
if(isMiniBatch()){
divideByMinibatch(isExternal, gradient, batchSize);
}
//PRE apply (gradient clipping, etc): done on a per-layer basis
for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
String layerName = entry.getKey();
Trainable layer = layersByName.get(layerName);
preApply(layer, layerGradients.get(layerName), iteration);
}
//Apply the updaters in blocks. This also applies LR and momentum schedules, L1 and L2
if(getClass() != LayerUpdater.class){
//OK for LayerUpdater as this is part of layerwise pretraining
workspaceMgr.assertNotOpen(ArrayType.UPDATER_WORKING_MEM, "Updater working memory");
}
for (UpdaterBlock ub : updaterBlocks) {
if (ub.skipDueToPretrainConfig(this instanceof LayerUpdater)) {
//Should skip some updater blocks sometimes
//For example, VAE decoder params while doing supervised backprop
continue;
}
try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.UPDATER_WORKING_MEM)){
if (isExternal) {
//RL4J etc type case: calculate gradients in 1 net, update them in another
ub.updateExternalGradient(iteration, epoch, gradient.gradient(), getParams());
} else {
//Standard case
ub.update(iteration, epoch);
}
}
}
}
protected void divideByMinibatch(boolean isExternal, Gradient gradient, int batchSize){
//Challenge here: most gradients are actual gradients, and should be divided by the minibatch to get the average
//However, some 'gradients' are actually updates - an example being BatchNorm mean/variance estimates... these
// shouldn't be modified
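//For example, with batchSize = 32 a summed weight gradient is scaled by 1/32 to give the average over the
// minibatch, whereas batch norm mean/variance estimates (which are really direct updates, not gradients) are
// left unscaled - see Trainable.updaterDivideByMinibatch(String)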
if(!initializedMinibatchDivision){
gradientsForMinibatchDivision = getMinibatchDivisionSubsets(getFlattenedGradientsView());
initializedMinibatchDivision = true;
}
List<INDArray> toDivide;
if(isExternal){
toDivide = getMinibatchDivisionSubsets(gradient.gradient());
} else {
toDivide = gradientsForMinibatchDivision;
}
for(INDArray arr : toDivide){
arr.divi(batchSize);
}
}
protected List<INDArray> getMinibatchDivisionSubsets(INDArray from){
from = from.reshape(from.length());
List<INDArray> out = new ArrayList<>();
long paramsSoFar = 0;
long currentStart = 0;
long currentEnd = 0;
for(Trainable t : getOrderedLayers()){
Set<String> layerParams = t.paramTable(false).keySet();
Map<String, INDArray> paramTable = t.paramTable(false);
for(String s : layerParams) {
if(t.updaterDivideByMinibatch(s)){
long l = paramTable.get(s).length();
currentEnd += l;
} else {
//This param/gradient subset should be excluded
if(currentEnd > currentStart) {
INDArray subset = from.get( NDArrayIndex.interval(currentStart, currentEnd));
out.add(subset);
}
currentStart = paramsSoFar + paramTable.get(s).length();
currentEnd = currentStart;
}
paramsSoFar += paramTable.get(s).length();
}
}
if(currentEnd > currentStart && currentStart < from.length()){
//Process last part of the gradient view array
INDArray subset = from.reshape(from.length()).get( NDArrayIndex.interval(currentStart, currentEnd));
out.add(subset);
}
return out;
}
protected boolean isSingleLayerUpdater() {
return false;
}
/**
* Pre-apply: Apply gradient normalization/clipping
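*
* A rough configuration sketch showing how a layer ends up with a normalization strategy and threshold (method
* names are assumed from the standard DL4J configuration builder API, and the threshold value is illustrative):
* <pre>{@code
* new NeuralNetConfiguration.Builder()
*     .gradientNormalization(GradientNormalization.ClipL2PerLayer)
*     .gradientNormalizationThreshold(1.0)
*     //... remainder of the configuration
*     .build();
* }</pre>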
*
* @param layer Layer to apply gradient normalization/clipping for
* @param gradient Gradient to update
* @param iteration The current iteration (i.e., number of parameter updates so far)
*/
public void preApply(Trainable layer, Gradient gradient, int iteration) {
if (layer.getConfig() == null || layer.numParams() == 0) {
//Layer does not have parameters -> no gradient
return;
}
GradientNormalization normalization = layer.getConfig().getGradientNormalization();
if (normalization == null || normalization == GradientNormalization.None)
return; //no op
final double threshold = layer.getConfig().getGradientNormalizationThreshold();
INDArray layerGradientView = layer.getGradientsViewArray();
switch (normalization) {
case RenormalizeL2PerLayer:
if (layerGradientView != null) {
double l2 = layerGradientView.norm2Number().doubleValue();
if (l2 == 0.0)
l2 = 1e-5; //Avoid 0/0 -> NaN
layerGradientView.divi(l2);
}
break;
case RenormalizeL2PerParamType:
for (INDArray g : gradient.gradientForVariable().values()) {
double l2 = Nd4j.getExecutioner().execAndReturn(new Norm2(g)).getFinalResult().doubleValue();
if (l2 == 0.0)
l2 = 1e-5; //Avoid 0/0 -> NaN
g.divi(l2);
}
break;
case ClipElementWiseAbsoluteValue:
if (layerGradientView != null) {
CustomOp op = DynamicCustomOp.builder("clipbyvalue")
.addInputs(layerGradientView)
.callInplace(true)
.addFloatingPointArguments(-threshold, threshold)
.build();
Nd4j.getExecutioner().exec(op);
}
break;
case ClipL2PerLayer:
if (layerGradientView != null) {
double layerL2 = layerGradientView.norm2Number().doubleValue();
if (layerL2 > threshold) {
double scalingFactor = threshold / layerL2; //Scale so that the layer's L2 norm equals the threshold: g = g * (threshold / l2)
layerGradientView.muli(scalingFactor);
}
}
break;
case ClipL2PerParamType:
for (INDArray g : gradient.gradientForVariable().values()) {
double l2 = g.norm2Number().doubleValue();
if (l2 > threshold) {
double scalingFactor = l2 / threshold;
g.divi(scalingFactor);
}
}
break;
default:
throw new RuntimeException(
"Unknown (or not implemented) gradient normalization strategy: " + normalization);
}
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
BaseMultiLayerUpdater<?> that = (BaseMultiLayerUpdater<?>) o;
return updaterStateViewArray != null ? updaterStateViewArray.equals(that.updaterStateViewArray)
: that.updaterStateViewArray == null;
}
@Override
public int hashCode() {
int result = layersByName != null ? layersByName.hashCode() : 0;
result = 31 * result + (updaterBlocks != null ? updaterBlocks.hashCode() : 0);
result = 31 * result + (updaterStateViewArray != null ? updaterStateViewArray.hashCode() : 0);
return result;
}
}