/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.graph;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.util.*;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.*;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.graph.util.GraphIndices;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.FrozenVertex;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.workspace.ND4JWorkspaceException;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.common.util.OneTimeLogger;
import org.nd4j.linalg.workspace.WorkspacesCloseable;
import java.io.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
@Slf4j
public class ComputationGraph implements Serializable, Model, NeuralNetwork {
protected ComputationGraphConfiguration configuration;
protected boolean initCalled = false;
protected transient Solver solver; //Used to call optimizers during backprop
protected INDArray flattenedParams; //Params for all layers are a view/subset of this array
@Getter
protected transient INDArray flattenedGradients; //Gradients for all layers are a view/subset of this array
@Getter
@Setter
protected Gradient gradient;
protected double score;
@Setter
private boolean initDone = false;
@Getter
@Setter
protected boolean clearTbpttState = true; //Mainly for unit testing (should be enabled otherwise)
//Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers
@Getter
protected transient Map<String, Pointer> helperWorkspaces = new HashMap<>();
private transient final AtomicLong occupiedBy = new AtomicLong(-1);
/**
* Workspace for working memory for a single layer: forward pass and backward pass
* Note that this is opened/closed once per op (activate/backpropGradient call)
*/
protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
/**
* Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop
* Not used for inference
*/
protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
/**
* Workspace for working memory in RNNs - opened and closed once per RNN time step
*/
protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";
/**
* Workspace for output methods that use OutputAdapter
*/
protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";
protected final WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder()
.initialSize(0)
.overallocationLimit(0.05)
.policyLearning(LearningPolicy.FIRST_LOOP)
.policyReset(ResetPolicy.BLOCK_LEFT)
.policySpill(SpillPolicy.REALLOCATE)
.policyAllocation(AllocationPolicy.OVERALLOCATE)
.build();
protected final WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
.initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT)
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
.policyLearning(LearningPolicy.FIRST_LOOP).build();
protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();
/**
* All GraphVertex objects in the network.
*/
protected GraphVertex[] vertices;
/**
* Map of vertices by name
*/
protected Map<String, GraphVertex> verticesMap;
/**
* Indexes of graph vertices, in topological order. The topological order defines the order in which forward pass
* (and hence also backward pass, which is the opposite to this) is conducted in the network.
*/
protected int[] topologicalOrder;
/**
* Topological sort and vertex index/name + name/index mapping
*/
protected GraphIndices graphIndices;
/**
* A list of layers. Each of these layers is present in a GraphVertex, but are here for easy reference.
* This array also defines the order in which the getLayer(int) method returns layers.
*/
protected Layer[] layers;
/**
* The number of input arrays to the network. Many networks only have 1 input; however, a ComputationGraph may
* have an arbitrary number (>=1) of separate input arrays
*/
private int numInputArrays;
/**
* The number of output arrays to the network. Many networks only have 1 output; however, a ComputationGraph may
* have an arbitrary number (>=1) of separate output arrays
*/
private int numOutputArrays;
//Current inputs, labels, input mask arrays and label mask arrays
private transient INDArray[] inputs;
private transient INDArray[] labels;
private transient INDArray[] inputMaskArrays;
private transient INDArray[] labelMaskArrays;
private transient int[] outputLayerIdxs;
private NeuralNetConfiguration defaultConfiguration;
private Collection<TrainingListener> trainingListeners = new ArrayList<>();
public ComputationGraph(ComputationGraphConfiguration configuration) {
this.configuration = configuration;
this.numInputArrays = configuration.getNetworkInputs().size();
this.numOutputArrays = configuration.getNetworkOutputs().size();
this.inputs = new INDArray[numInputArrays];
this.labels = new INDArray[numOutputArrays];
this.defaultConfiguration = configuration.getDefaultConfiguration();
//Working memory: should learn over course of: (a) full forward pass, and (b) full backward pass
//Working memory should be opened once per vertex, for each of forward and backward passes
int numWorkingMem = 2 * configuration.getVertices().size();
WS_LAYER_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder()
.initialSize(0)
.overallocationLimit(0.02)
.policyLearning(LearningPolicy.OVER_TIME)
.cyclesBeforeInitialization(numWorkingMem)
.policyReset(ResetPolicy.BLOCK_LEFT)
.policySpill(SpillPolicy.REALLOCATE)
.policyAllocation(AllocationPolicy.OVERALLOCATE)
.build();
//Activations memory: opened once per layer - for every second layer (preprocessors are within the loop).
//Technically we could set learning to numLayers / 2, but will set to numLayers for simplicity, and also to
// account for a backward pass
WS_LAYER_ACT_X_CONFIG = WorkspaceConfiguration.builder()
.initialSize(0)
.overallocationLimit(0.02)
.policyLearning(LearningPolicy.OVER_TIME)
.cyclesBeforeInitialization(configuration.getVertices().size())
.policyReset(ResetPolicy.BLOCK_LEFT)
.policySpill(SpillPolicy.REALLOCATE)
.policyAllocation(AllocationPolicy.OVERALLOCATE)
.build();
}
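/*
* Construction usage sketch (illustrative only). The builder and layer classes below
* (NeuralNetConfiguration.Builder, DenseLayer, OutputLayer) belong to the wider DL4J
* configuration API and are assumptions of this example, not defined in this file:
*
* ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
* .graphBuilder()
* .addInputs("in")
* .addLayer("dense", new DenseLayer.Builder().nIn(10).nOut(8).build(), "in")
* .addLayer("out", new OutputLayer.Builder().nIn(8).nOut(3).build(), "dense")
* .setOutputs("out")
* .build();
* ComputationGraph net = new ComputationGraph(conf);
* net.init(); //Allocate and initialize parameters before use
*/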
/**
* This method allows setting the last ETL (data loading) time, useful for performance tracking
*
* @param time ETL time, in milliseconds
*/
public void setLastEtlTime(long time) {
lastEtlTime.set(time);
}
/**
* This method returns the last ETL time value
*
* @return Last ETL time in milliseconds, or 0 if not set
*/
public long getLastEtlTime() {
Long time = lastEtlTime.get();
return time == null ? 0L : time;
}
/**
* This method sets specified CacheMode for all layers within network
*
* @param mode
*/
public void setCacheMode(CacheMode mode) {
if (mode == null)
mode = CacheMode.NONE;
for (Layer layer : layers) {
layer.setCacheMode(mode);
}
}
/**
* This method returns configuration of this ComputationGraph
*
* @return
*/
public ComputationGraphConfiguration getConfiguration() {
return configuration;
}
/**
* Returns the number of layers in the ComputationGraph
*/
public int getNumLayers() {
return (layers != null ? layers.length : 0);
}
/**
* Get the layer by the number of that layer, in range 0 to getNumLayers()-1
* NOTE: This is different from the internal GraphVertex index for the layer
*/
public Layer getLayer(int idx) {
return layers[idx];
}
/**
* Get all layers in the ComputationGraph
*/
public Layer[] getLayers() {
return layers;
}
/**
* Get a given layer by name.
*/
public Layer getLayer(String name) {
Preconditions.checkState(verticesMap.containsKey(name), "Layer with name %s does not exist in the network", name);
return verticesMap.get(name).getLayer();
}
/**
* Returns an array of all GraphVertex objects.
*/
public GraphVertex[] getVertices() {
return vertices;
}
/**
* Return a given GraphVertex by name, or null if no vertex with that name exists
*/
public GraphVertex getVertex(String name) {
return verticesMap.get(name);
}
/**
* The number of inputs to this network
*/
public int getNumInputArrays() {
return numInputArrays;
}
/**
* The number of output (arrays) for this network
*/
public int getNumOutputArrays() {
return numOutputArrays;
}
/**
* Set the specified input for the ComputationGraph
*/
public void setInput(int inputNum, INDArray input) {
if (inputs == null) {
//May be null after clear()
inputs = new INDArray[numInputArrays];
}
inputs[inputNum] = input;
}
/**
* Set all inputs for the ComputationGraph network
*/
public void setInputs(INDArray... inputs) {
if (inputs != null && inputs.length != this.numInputArrays) {
throw new IllegalArgumentException("Invalid input array: network has " + numInputArrays
+ " inputs, but array is of length " + inputs.length);
}
this.inputs = inputs;
}
/**
* Get the previously set input for the ComputationGraph
*/
public INDArray getInput(int inputNum) {
if (inputs == null)
return null;
return inputs[inputNum];
}
/**
* Get the previously set inputs for the ComputationGraph
*/
public INDArray[] getInputs() {
return inputs;
}
/**
* Get the previously set feature/input mask arrays for the ComputationGraph
*/
public INDArray[] getInputMaskArrays() {
return inputMaskArrays;
}
/**
* Get the previously set label/output mask arrays for the ComputationGraph
*/
public INDArray[] getLabelMaskArrays() {
return labelMaskArrays;
}
/**
* Set the specified label for the ComputationGraph
*/
public void setLabel(int labelNum, INDArray label) {
labels[labelNum] = label;
}
/**
* Set all labels for the ComputationGraph network
*/
public void setLabels(INDArray... labels) {
if (labels != null && labels.length != this.numOutputArrays) {
throw new IllegalArgumentException("Invalid output array: network has " + numOutputArrays
+ " outputs, but array is of length " + labels.length);
}
this.labels = labels;
}
/**
* This method allows you to specify the GradientsAccumulator instance to be used with this model
*
* PLEASE NOTE: Do not use this method unless you understand how to use GradientsAccumulator & updates sharing.
* PLEASE NOTE: Do not use this method on a standalone model
*
* @param accumulator
*/
public void setGradientsAccumulator(GradientsAccumulator accumulator) {
if (!initCalled)
init();
solver.getOptimizer().setGradientsAccumulator(accumulator);
}
/**
* Initialize the ComputationGraph network
*/
public void init() {
init(null, false);
}
/**
* Initialize the ComputationGraph, optionally with an existing parameters array.
* If an existing parameters array is specified, it will be used (and the values will not be modified) in the network;
* if no parameters array is specified, parameters will be initialized randomly according to the network configuration.
*
* @param parameters Network parameter. May be null. If null: randomly initialize.
* @param cloneParametersArray Whether the parameter array (if any) should be cloned, or used directly
*/
public void init(INDArray parameters, boolean cloneParametersArray) {
if (initCalled)
return;
DataType netDtype = getConfiguration().getDataType();
if(parameters != null && parameters.dataType() != netDtype){
Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
if(cloneParametersArray){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
parameters = parameters.castTo(netDtype);
}
} else {
throw new IllegalStateException("Error initializing network: Network datatype is set to " + netDtype
+ " but provided array has datatype " + parameters.dataType() + " with cloneParametersArray argument" +
" set to false. Cannot initialize net with specified datatype array if that array does not match network datatype");
}
}
if (configuration.getTrainingWorkspaceMode() == null)
configuration.setTrainingWorkspaceMode(WorkspaceMode.NONE);
if (configuration.getInferenceWorkspaceMode() == null)
configuration.setInferenceWorkspaceMode(WorkspaceMode.NONE);
if (configuration.getCacheMode() == null)
configuration.setCacheMode(CacheMode.NONE);
OneTimeLogger.info(log, "Starting ComputationGraph with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]",
configuration.getTrainingWorkspaceMode(), configuration.getInferenceWorkspaceMode(), configuration.getCacheMode());
//First: build topological ordering, based on configuration. Used for forward pass, backprop and order of parameters/gradients
GraphIndices indices = calculateIndices();
topologicalOrder = indices.getTopologicalSortOrder();
//Initialization: create the GraphVertex objects, based on configuration structure
Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();
//Names of all of the (data) inputs to the ComputationGraph
List<String> networkInputNames = configuration.getNetworkInputs();
//Inputs for each layer and GraphNode:
Map<String, List<String>> vertexInputs = configuration.getVertexInputs();
this.vertices = new GraphVertex[networkInputNames.size() + configuration.getVertices().size()];
//All names: inputs, layers and graph nodes (index to name map)
Map<String, Integer> allNamesReverse = new HashMap<>();
//Create network input vertices:
int vertexNumber = 0;
for (String name : networkInputNames) {
GraphVertex gv = new InputVertex(this, name, vertexNumber, null, netDtype); //Output vertices: set later
allNamesReverse.put(name, vertexNumber);
vertices[vertexNumber++] = gv;
}
//Go through layers, and work out total number of parameters. Then allocate full parameters array
long numParams = 0;
long[] numParamsForVertex = new long[topologicalOrder.length];
int i = 0;
for (; i < configuration.getNetworkInputs().size(); i++) {
numParamsForVertex[i] = 0; //No parameters for input vertices
}
for(; i < topologicalOrder.length; i++) {
String name = indices.getIdxToName().get(i);
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
n.setDataType(netDtype);
numParamsForVertex[i] = n.numParams(true);
if(numParamsForVertex[i] < 0)
throw new DL4JInvalidConfigException("Layer " + name + " had parameters < 0 " + numParamsForVertex[i]);
numParams += numParamsForVertex[i];
}
boolean initializeParams;
if (parameters != null) {
if (numParams > 0 && !parameters.isRowVectorOrScalar())
throw new IllegalArgumentException("Invalid parameters: should be a row vector");
if (parameters.length() != numParams)
throw new IllegalArgumentException("Invalid parameters: expected length " + numParams + ", got length "
+ parameters.length());
if (cloneParametersArray)
flattenedParams = parameters.dup();
else
flattenedParams = parameters;
initializeParams = false;
} else if(numParams > 0){
flattenedParams = Nd4j.create(netDtype, 1, numParams);
initializeParams = true;
} else {
flattenedParams = null;
initializeParams = false;
}
//Set RNG seed, for repeatability between initializations when set
if (initializeParams) {
Nd4j.getRandom().setSeed(conf().getSeed());
}
if(flattenedParams == null)
flattenedParams = Nd4j.zeros(DataType.FLOAT,0);
INDArray flattenedParamsReshape = flattenedParams.reshape(flattenedParams.length());
//Given the topological ordering: work out the subset of the parameters array used for each layer
// Then extract out for use when initializing the Layers
INDArray[] paramsViewForVertex = new INDArray[topologicalOrder.length];
long paramOffsetSoFar = 0;
i = 0;
for (int vertexIdx : topologicalOrder) {
long nParamsThisVertex = numParamsForVertex[vertexIdx];
if (nParamsThisVertex != 0) {
paramsViewForVertex[vertexIdx] = flattenedParamsReshape.get(
NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex));
}
i++;
paramOffsetSoFar += nParamsThisVertex;
}
int numLayers = 0;
List<Layer> tempLayerList = new ArrayList<>();
defaultConfiguration.clearVariables();
List<String> variables = defaultConfiguration.variables(false);
i = configuration.getNetworkInputs().size();
for (; i < topologicalOrder.length; i++) {
String name = indices.getIdxToName().get(i);
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
GraphVertex gv = n.instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber],
initializeParams, netDtype);
if (gv.hasLayer()) {
numLayers++;
Layer l = gv.getLayer();
tempLayerList.add(l);
List<String> layerVariables = l.conf().variables();
if (layerVariables != null) {
for (String s : layerVariables) {
variables.add(gv.getVertexName() + "_" + s);
}
}
}
allNamesReverse.put(name, vertexNumber);
vertices[vertexNumber++] = gv;
}
layers = tempLayerList.toArray(new Layer[numLayers]);
//Create the lookup table, so we can find vertices easily by name
verticesMap = new HashMap<>();
for (GraphVertex gv : vertices) {
verticesMap.put(gv.getVertexName(), gv);
}
//Now: do another pass to set the input and output indices, for each vertex
// These indices are used during forward and backward passes
//To get output indices: need to essentially build the graph in reverse...
Map<String, List<String>> verticesOutputTo = new HashMap<>(); //Key: vertex. Values: vertices that this node is an input for
for (GraphVertex gv : vertices) {
String vertexName = gv.getVertexName();
List<String> vertexInputNames;
vertexInputNames = vertexInputs.get(vertexName);
if (vertexInputNames == null)
continue;
//Build reverse network structure:
for (String s : vertexInputNames) {
List<String> list = verticesOutputTo.get(s);
if (list == null) {
list = new ArrayList<>();
verticesOutputTo.put(s, list);
}
list.add(vertexName); //Edge: s -> vertexName
}
}
for (GraphVertex gv : vertices) {
String vertexName = gv.getVertexName();
int vertexIndex = gv.getVertexIndex();
List<String> vertexInputNames;
vertexInputNames = vertexInputs.get(vertexName);
if (vertexInputNames == null)
continue;
VertexIndices[] inputIndices = new VertexIndices[vertexInputNames.size()];
for (int j = 0; j < vertexInputNames.size(); j++) {
String inName = vertexInputNames.get(j);
int inputVertexIndex = allNamesReverse.get(inName);
//Here: we have x -> gv connection
//gv might have multiple inputs, not just x
//Need to know which input x is
int inputNumber = vertexInputs.get(vertexName).indexOf(inName);
if (inputNumber == -1)
throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of inputs "
+ "for vertex " + gv.getVertexName() + "; error in graph structure?");
inputIndices[j] = new VertexIndices(inputVertexIndex, inputNumber);
}
gv.setInputVertices(inputIndices);
}
//Handle the outputs for this vertex
for (GraphVertex gv : vertices) {
String vertexName = gv.getVertexName();
List<String> thisVertexOutputsTo = verticesOutputTo.get(vertexName);
if (thisVertexOutputsTo == null || thisVertexOutputsTo.isEmpty())
continue; //Output vertex
VertexIndices[] outputIndices = new VertexIndices[thisVertexOutputsTo.size()];
int j = 0;
for (String s : new HashSet<>(thisVertexOutputsTo)) {
//First, we have gv -> s
//Which input in s does gv connect to? s may in general have multiple inputs...
List<String> nextVertexInputNames = vertexInputs.get(s);
for (int k = 0; k < nextVertexInputNames.size(); k++) {
if(vertexName.equals(nextVertexInputNames.get(k))){
int outputVertexIndex = allNamesReverse.get(s);
outputIndices[j++] = new VertexIndices(outputVertexIndex, k);
}
}
}
gv.setOutputVertices(outputIndices);
}
//Mark any output vertices as outputs:
for (String s : configuration.getNetworkOutputs()) {
GraphVertex gv = verticesMap.get(s);
gv.setOutputVertex(true);
}
// now we init solver & optimizer
if (solver == null) {
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
solver.initOptimizer();
}
}
//Mark which layers can safely modify their input in-place. This is a performance optimization for
// dropout and similar operations.
// Safe when the input is: (a) it's not a graph input, and (b) isn't shared by any other layers/vertices
Map<String, List<String>> seenAsInputTo = new HashMap<>();
for (Map.Entry<String, List<String>> entry : configuration.getVertexInputs().entrySet()) {
for (String s : entry.getValue()) {
if (!seenAsInputTo.containsKey(s)) {
seenAsInputTo.put(s, new ArrayList<String>());
}
List<String> seen = seenAsInputTo.get(s);
seen.add(entry.getKey());
}
}
for(Layer l : layers){
String layerName = l.conf().getLayer().getLayerName();
List<String> inputs = configuration.getVertexInputs().get(layerName);
String in = inputs.get(0); //For now: layers should have exactly 1 input
if(configuration.getNetworkInputs().contains(in)){
//TODO When is it safe to NOT allow input modification? It's not always safe...
// For example dropout + iterating over a List<DataSet> that is used for multiple epochs...
continue;
}
List<String> seen = seenAsInputTo.get(in);
if(seen.size() == 1){
l.allowInputModification(true);
} else {
//For the count > 1 case, we can work out if it's the last one in the topological order... at which point,
// it should be safe to use
int thisIdx = indices.getNameToIdx().get(layerName);
int thisTopoPos = ArrayUtils.indexOf(indices.getTopologicalSortOrder(), thisIdx);
int maxTopoPosition = -1;
for(String s : seen){
int idx = indices.getNameToIdx().get(s);
int topoPos = ArrayUtils.indexOf(indices.getTopologicalSortOrder(), idx);
maxTopoPosition = Math.max(maxTopoPosition, topoPos);
}
if(thisTopoPos == maxTopoPosition){
//Last one in the topo sort... all other layers have already consumed this input by the time this layer's
// forward pass is done
l.allowInputModification(true);
} //Otherwise: keep default of false
}
}
synchronizeIterEpochCounts();
initCalled = true;
}
/**
* This method: initializes the flattened gradients array (used in backprop) and sets the appropriate subset in all layers.
* As a general rule, this shouldn't ever need to be called manually when doing training via fit(DataSet), fit(DataSetIterator)
* or fit(MultiDataSet) methods
*/
public void initGradientsView() {
try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
if (!initCalled)
init();
GraphIndices indices = calculateIndices();
//Go through layers, and work out total number of parameters. Then allocate full parameters array
long numParams = 0;
long[] numParamsForVertex = new long[topologicalOrder.length];
int i = 0;
for (; i < configuration.getNetworkInputs().size(); i++) {
numParamsForVertex[i] = 0; //No parameters for input vertices
}
Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();
for (; i < topologicalOrder.length; i++) {
String name = indices.getIdxToName().get(i);
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
numParamsForVertex[i] = n.numParams(true);
numParams += numParamsForVertex[i];
}
if(numParams > 0) {
flattenedGradients = Nd4j.create(flattenedParams.dataType(), 1, numParams);
}
if(flattenedGradients == null)
flattenedGradients = Nd4j.zeros(DataType.FLOAT,0);
INDArray flattenedGradientsReshape = flattenedGradients.reshape(flattenedGradients.length());
//Given the topological ordering: work out the subset of the gradient array used for each layer, and set it
long paramOffsetSoFar = 0;
i = 0;
for (int vertexIdx : topologicalOrder) {
long nParamsThisVertex = numParamsForVertex[vertexIdx];
if (nParamsThisVertex != 0) {
INDArray gradientView = flattenedGradientsReshape.get(
NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex));
vertices[vertexIdx].setBackpropGradientsViewArray(gradientView);
}
i++;
paramOffsetSoFar += nParamsThisVertex;
}
}
}
protected int[] getOutputLayerIndices(){
if(outputLayerIdxs == null) {
outputLayerIdxs = new int[numOutputArrays];
int i = 0;
for (String s : configuration.getNetworkOutputs()) {
outputLayerIdxs[i++] = verticesMap.get(s).getVertexIndex();
}
}
return outputLayerIdxs;
}
/**
* Perform layerwise pretraining for one epoch - see {@link #pretrain(DataSetIterator, int)}
*/
public void pretrain(DataSetIterator iter) {
pretrain(iter, 1);
}
/**
* Pretrain network with a single input and single output. DataSetIterators can only be used if the number of input
* arrays for the ComputationGraph is 1.
* This method performs layerwise pretraining on all pre-trainable layers in the network (VAEs, Autoencoders, etc), for the specified
* number of epochs each. For example, if numEpochs=3, then layer 0 will be fit for 3 epochs, followed by layer 1
* for 3 epochs, and so on.
* For networks with more than one input use {@link #pretrain(MultiDataSetIterator)}
*/
public void pretrain(DataSetIterator iter, int numEpochs) {
if (numInputArrays != 1) {
throw new UnsupportedOperationException(
"Cannot train ComputationGraph network with multiple inputs using a DataSetIterator");
}
pretrain(ComputationGraphUtil.toMultiDataSetIterator(iter), numEpochs);
}
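/*
* Pretraining usage sketch (illustrative only). "iter" is assumed to be any single-input
* DataSetIterator, and "vae0" is a hypothetical name of a pretrainable (e.g. VAE) layer:
*
* DataSetIterator iter = ...; //Single-input training data
* net.pretrain(iter, 3); //Fit each pretrainable layer for 3 epochs
* net.pretrainLayer("vae0", iter); //Or pretrain a single named layer
*/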
/**
* Pretrain network with multiple inputs and/or outputs
*/
public void pretrain(MultiDataSetIterator iter) {
pretrain(iter, 1);
}
/**
* Pretrain network with multiple inputs and/or outputs
* This method performs layerwise pretraining on all pre-trainable layers in the network (VAEs, Autoencoders, etc), for the specified
* number of epochs each. For example, if numEpochs=3, then layer 0 will be fit for 3 epochs, followed by layer 1
* for 3 epochs, and so on.
* Non-pretrainable layers are ignored
*
* @param iter Training data
* @param numEpochs Number of epochs to fit each layer with
* @see #pretrainLayer(String, MultiDataSetIterator)
*/
public void pretrain(MultiDataSetIterator iter, int numEpochs) {
try{
pretrainHelper(iter, numEpochs);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
private void pretrainHelper(MultiDataSetIterator iter, int numEpochs){
if (flattenedGradients == null) {
initGradientsView();
}
//Assume here that all layers are pretrainable layers
for (int i = 0; i < topologicalOrder.length; i++) {
if (!vertices[i].hasLayer())
continue;
if (vertices[i].getLayer() instanceof IOutputLayer)
continue; //Don't pretrain output layer
if (!vertices[i].getLayer().isPretrainLayer())
continue; //Skip layers that aren't pretrainable
pretrainLayerHelper(vertices[i].getVertexName(), iter, numEpochs);
}
}
/**
* Pretrain a specified layer with the given DataSetIterator
*
* @param layerName Layer name
* @param dataSetIterator Data
*/
public void pretrainLayer(String layerName, DataSetIterator dataSetIterator) {
if (numInputArrays != 1) {
throw new UnsupportedOperationException(
"Cannot train ComputationGraph network with multiple inputs using a DataSetIterator");
}
pretrainLayer(layerName, ComputationGraphUtil.toMultiDataSetIterator(dataSetIterator));
}
/**
* Pretrain a specified layer with the given MultiDataSetIterator
*
* @param layerName Layer name
* @param iter Training data
*/
public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
try{
pretrainLayerHelper(layerName, iter, 1);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
private void pretrainLayerHelper(String layerName, MultiDataSetIterator iter, int numEpochs){
if (flattenedGradients == null) {
initGradientsView();
}
if (!verticesMap.containsKey(layerName)) {
throw new IllegalStateException("Invalid vertex name: " + layerName + " - all vertex names: " +
verticesMap.keySet());
}
if (!verticesMap.get(layerName).hasLayer()) {
//No op
return;
}
GraphVertex toTrain = verticesMap.get(layerName);
int idx = toTrain.getVertexIndex();
LayerWorkspaceMgr workspaceMgr;
if(configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
} else {
workspaceMgr = LayerWorkspaceMgr.builder()
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
//Use FF/BP working memory for updater also
.with(ArrayType.UPDATER_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.build();
}
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
if(!iter.hasNext() && iter.resetSupported())
iter.reset();
MultiDataSetIterator withAsync = iter.asyncSupported() ? new AsyncMultiDataSetIterator(iter) : iter;
while(withAsync.hasNext()) {
MultiDataSet mds = withAsync.next();
try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
//FF - note should be using TEST mode here for the layers that feed into the specified layer
Map<String, INDArray> activations = ffToLayerActivationsInWS(false, idx, new int[]{idx}, FwdPassType.STANDARD,
false, mds.getFeatures(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays(), true);
//Get input to the current layer
VertexIndices[] inputsToLayer = toTrain.getInputVertices();
for (VertexIndices vi : inputsToLayer) {
String inName = vertices[vi.getVertexIndex()].getVertexName();
INDArray act = activations.get(inName);
toTrain.setInput(vi.getVertexEdgeNumber(), act, workspaceMgr);
}
Layer layer = toTrain.getLayer();
layer.fit(layer.input(), workspaceMgr);
}
}
}
/**
* Fit the ComputationGraph using a DataSet.
* Note that this method can only be used with ComputationGraphs with 1 input and 1 output.
* For networks with more than one input or output, use {@link #fit(MultiDataSetIterator)}
*/
public void fit(DataSet dataSet) {
if (numInputArrays != 1 || numOutputArrays != 1)
throw new UnsupportedOperationException("Cannot train ComputationGraph network with "
+ " multiple inputs or outputs using a DataSet");
boolean hasMaskArrays = dataSet.hasMaskArrays();
if (hasMaskArrays) {
INDArray[] fMask = (dataSet.getFeaturesMaskArray() != null ? new INDArray[]{dataSet.getFeaturesMaskArray()}
: null);
INDArray[] lMask = (dataSet.getLabelsMaskArray() != null ? new INDArray[]{dataSet.getLabelsMaskArray()}
: null);
fit(new INDArray[]{dataSet.getFeatures()}, new INDArray[]{dataSet.getLabels()}, fMask, lMask);
} else {
fit(new INDArray[]{dataSet.getFeatures()}, new INDArray[]{dataSet.getLabels()});
}
if (hasMaskArrays)
clearLayerMaskArrays();
clearLayersStates();
}
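/*
* Usage sketch (illustrative only), for a network with exactly 1 input and 1 output.
* The "features"/"labels" arrays and "trainIter" iterator are assumptions of this example:
*
* DataSet ds = new org.nd4j.linalg.dataset.DataSet(features, labels);
* net.fit(ds); //Single minibatch
* net.fit(trainIter, 10); //Or: all minibatches in a DataSetIterator, for 10 epochs
*/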
/**
* Perform minibatch training on all minibatches in the DataSetIterator, for the specified number of epochs.
* Equivalent to calling {@link #fit(DataSetIterator)} numEpochs times in a loop
*
* @param iterator Training data (DataSetIterator). Iterator must support resetting
* @param numEpochs Number of training epochs, >= 1
*/
public void fit(@NonNull DataSetIterator iterator, int numEpochs){
Preconditions.checkArgument(numEpochs > 0, "Number of epochs must be > 0. Got numEpochs = %s", numEpochs);
Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs of training using an " +
"iterator that does not support resetting (iterator.resetSupported() returned false)");
for (int i = 0; i < numEpochs; i++) {
fit(iterator);
}
}
/**
* Fit the ComputationGraph using a DataSetIterator.
* Note that this method can only be used with ComputationGraphs with 1 input and 1 output
* Method doesn't do layerwise pretraining.
* For pretraining use {@link #pretrain(DataSetIterator)}
* @param iterator Training data (DataSetIterator)
*/
public void fit(@NonNull DataSetIterator iterator) {
fit(new MultiDataSetIteratorAdapter(iterator));
}
/**
* Fit the ComputationGraph using a MultiDataSet
*/
public void fit(MultiDataSet multiDataSet) {
fit(multiDataSet.getFeatures(), multiDataSet.getLabels(), multiDataSet.getFeaturesMaskArrays(),
multiDataSet.getLabelsMaskArrays());
if (multiDataSet.hasMaskArrays())
clearLayerMaskArrays();
}
/**
* Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs.
* Equivalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop
*
* @param iterator Training data (MultiDataSetIterator). Iterator must support resetting
* @param numEpochs Number of training epochs, >= 1
*/
public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){
Preconditions.checkArgument(numEpochs > 0, "Number of epochs must be > 0. Got numEpochs = %s", numEpochs);
Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs of training using an " +
"iterator that does not support resetting (iterator.resetSupported() returned false)");
for (int i = 0; i < numEpochs; i++) {
fit(iterator);
}
}
/**
* Fit the ComputationGraph using a MultiDataSetIterator.
* Method doesn't do layerwise pretraining.
* For pretraining use {@link #pretrain(MultiDataSetIterator)}
* @param multi Training data (MultiDataSetIterator)
*/
public synchronized void fit(MultiDataSetIterator multi) {
if (flattenedGradients == null) {
initGradientsView();
}
if(!multi.hasNext() && multi.resetSupported()){
multi.reset();
}
for (TrainingListener tl : trainingListeners) {
tl.onEpochStart(this);
}
boolean destructable = false;
MultiDataSetIterator multiDataSetIterator;
if (multi.asyncSupported()) {
multiDataSetIterator = new AsyncMultiDataSetIterator(multi, Math.max(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true);
destructable = true;
} else
multiDataSetIterator = multi;
long time1 = System.currentTimeMillis();
while(multiDataSetIterator.hasNext()){
MultiDataSet mds = multiDataSetIterator.next();
long time2 = System.currentTimeMillis();
lastEtlTime.set((time2 - time1));
fit(mds.getFeatures(),mds.getLabels(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays());
time1 = System.currentTimeMillis();
}
if (destructable)
((AsyncMultiDataSetIterator) multiDataSetIterator).shutdown();
for (TrainingListener tl : trainingListeners) {
tl.onEpochEnd(this);
}
incrementEpochCount();
}
/**
* Fit the ComputationGraph given arrays of inputs and labels.
*
* @param inputs The network inputs
* @param labels The labels
*/
public void fit(INDArray[] inputs, INDArray[] labels) {
fit(inputs, labels, null, null);
}
/**
* Fit the ComputationGraph using the specified inputs and labels (and mask arrays)
*
* @param inputs The network inputs (features)
* @param labels The network labels
* @param featureMaskArrays Mask arrays for inputs/features. Typically used for RNN training. May be null.
* @param labelMaskArrays Mask arrays for the labels/outputs. Typically used for RNN training. May be null.
*/
public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
try{
fitHelper(inputs, labels, featureMaskArrays, labelMaskArrays);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
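/*
* Usage sketch (illustrative only): fitting directly from arrays, with a feature mask for
* variable-length RNN data. All shapes here are assumptions of the example:
*
* INDArray features = Nd4j.rand(DataType.FLOAT, 32, 10, 20); //[minibatch, nIn, timeSteps]
* INDArray labels = Nd4j.rand(DataType.FLOAT, 32, 3, 20); //[minibatch, nOut, timeSteps]
* INDArray featMask = Nd4j.ones(DataType.FLOAT, 32, 20); //1 = present, 0 = padded step
* net.fit(new INDArray[]{features}, new INDArray[]{labels}, new INDArray[]{featMask}, null);
*/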
private synchronized void fitHelper(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
if (numParams() == 0) {
return; //Edge case: net with no params: fitting is a no-op
}
if (flattenedGradients == null) {
initGradientsView();
}
setInputs(inputs);
setLabels(labels);
setLayerMaskArrays(featureMaskArrays, labelMaskArrays);
update(TaskUtils.buildTask(inputs, labels));
LayerWorkspaceMgr workspaceMgr;
if(configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
} else {
workspaceMgr = LayerWorkspaceMgr.builder()
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
//Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM
// as these should be closed by the time updaters are executed
//Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this
.with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.build();
}
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays, workspaceMgr);
} else {
if (solver == null) {
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
}
}
//TODO: cache workspace
solver.optimize(workspaceMgr);
}
if (featureMaskArrays != null || labelMaskArrays != null) {
clearLayerMaskArrays();
}
clearLayersStates();
synchronizeIterEpochCounts();
}
/**
* Calculate a topological sort order for the vertices in the graph.
* Note that this is used for
* (a) working out what order to do forward pass,
* (b) what order to do backprop (i.e., reverse of this)
* (c) order to flatten parameters (and gradients)
*
* Specifically, gradients/params/forward pass are executed on vertex[topologicalSortOrder[i]], for i=0..nVertices-1
*/
public int[] topologicalSortOrder() {
return calculateIndices().getTopologicalSortOrder();
}
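/*
* Usage sketch (illustrative only): the returned values index into getVertices(), in the
* order used for the forward pass (backprop uses the reverse order):
*
* int[] order = net.topologicalSortOrder();
* GraphVertex firstToExecute = net.getVertices()[order[0]]; //An input vertex
*/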
/**
* Calculate the indices needed for the network:
* (a) topological sort order
* (b) Map: vertex index -> vertex name
* (c) Map: vertex name -> vertex index
*
* @return Calculated indices
*/
public GraphIndices calculateIndices(){
if(graphIndices != null)
return graphIndices;
//Get cached topological sort order from config, if present
if(configuration.getTopologicalOrder() != null && configuration.getTopologicalOrderStr() != null){
int[] t = configuration.getTopologicalOrder();
List<String> s = configuration.getTopologicalOrderStr();
Map<String, Integer> m1 = new HashMap<>();
Map<Integer, String> m2 = new HashMap<>();
for (int i = 0; i < t.length; i++) {
m1.put(s.get(i), t[i]);
m2.put(t[i], s.get(i));
}
graphIndices = GraphIndices.builder().topologicalSortOrder(t).nameToIdx(m1).idxToName(m2).build();
return graphIndices;
}
//Otherwise: calculate the topological sort order from scratch
Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeMap = configuration.getVertices();
List<String> networkInputNames = configuration.getNetworkInputs();
int numVertices = networkInputNames.size() + configuration.getVertices().size();
int[] out = new int[numVertices];
int outCounter = 0;
//First: represent the graph more usefully as a Map<Integer, Set<Integer>>, where map represents edges i -> j
// key represents j, set is set of i (inputs) for vertices j
Map<Integer, String> vertexNamesMap = new HashMap<>();
Map<String, Integer> vertexNamesMap2 = new HashMap<>();
int i = 0;
for (String inputName : configuration.getNetworkInputs()) {
vertexNamesMap.put(i, inputName);
vertexNamesMap2.put(inputName, i);
i++;
}
for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : nodeMap.entrySet()) {
String name = entry.getKey();
vertexNamesMap.put(i, name);
vertexNamesMap2.put(name, i);
i++;
}
Map<Integer, Set<Integer>> inputEdges = new HashMap<>(); //key: vertex. Values: vertices that the key vertex receives input from
Map<Integer, Set<Integer>> outputEdges = new HashMap<>(); //key: vertex. Values: vertices that the key vertex outputs to
for (String s : configuration.getNetworkInputs()) {
int idx = vertexNamesMap2.get(s);
inputEdges.put(idx, null);
}
for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : nodeMap.entrySet()) {
String thisVertexName = entry.getKey();
int idx = vertexNamesMap2.get(thisVertexName);
List<String> inputsToThisVertex = configuration.getVertexInputs().get(thisVertexName);
if (inputsToThisVertex == null || inputsToThisVertex.isEmpty()) {
inputEdges.put(idx, null);
continue;
}
Set<Integer> inputSet = new HashSet<>();
for (String s : inputsToThisVertex) {
Integer inputIdx = vertexNamesMap2.get(s);
inputSet.add(inputIdx);
Set<Integer> outputSetForInputIdx = outputEdges.get(inputIdx);
if (outputSetForInputIdx == null) {
outputSetForInputIdx = new HashSet<>();
outputEdges.put(inputIdx, outputSetForInputIdx);
}
outputSetForInputIdx.add(idx); //input vertex outputs to the current vertex
}
inputEdges.put(idx, inputSet);
}
//Now: do topological sort
//Set of all nodes with no incoming edges: (this would be: input vertices)
LinkedList<Integer> noIncomingEdges = new LinkedList<>();
for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
Set<Integer> inputsFrom = entry.getValue();
if (inputsFrom == null || inputsFrom.isEmpty()) {
noIncomingEdges.add(entry.getKey());
}
}
while (!noIncomingEdges.isEmpty()) {
int next = noIncomingEdges.removeFirst();
out[outCounter++] = next; //Add to sorted list
Set<Integer> vertexOutputsTo = outputEdges.get(next);
//Remove edges next -> vertexOutputsTo[...] from graph;
if (vertexOutputsTo != null) {
for (Integer v : vertexOutputsTo) {
Set<Integer> set = inputEdges.get(v);
set.remove(next);
if (set.isEmpty()) {
noIncomingEdges.add(v); //No remaining edges for vertex i -> add to list for processing
}
}
}
}
//If any edges remain in the graph: graph has cycles:
for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
Set<Integer> set = entry.getValue();
if (set == null)
continue;
if (!set.isEmpty())
throw new IllegalStateException(
"Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle ("
+ "cycle includes vertex \"" + vertexNamesMap.get(entry.getKey())
+ "\")");
}
//Store: the topological sort order in the configuration... this is to ensure that when the config is
// deserialized, it has exactly the same topological sort order on all platforms
List<String> s = new ArrayList<>(out.length);
for( int idx : out){
s.add(vertexNamesMap.get(idx));
}
configuration.setTopologicalOrder(out);
configuration.setTopologicalOrderStr(s);
graphIndices = GraphIndices.builder()
.topologicalSortOrder(out)
.nameToIdx(vertexNamesMap2)
.idxToName(vertexNamesMap)
.build();
return graphIndices;
}
@Override
public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr){
computeGradientAndScore();
}
public void computeGradientAndScore() {
synchronizeIterEpochCounts();
LayerWorkspaceMgr workspaceMgr;
if(configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
} else {
workspaceMgr = LayerWorkspaceMgr.builder()
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
//Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM
// as these should be closed by the time updaters are executed
//Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this
.with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.build();
}
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
boolean tbptt = configuration.getBackpropType() == BackpropType.TruncatedBPTT;
FwdPassType fwdType = (tbptt ? FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE : FwdPassType.STANDARD);
synchronizeIterEpochCounts();
//Calculate activations (which are stored in each layer, and used in backprop)
try(MemoryWorkspace wsAllActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
Map<String, INDArray> activations = ffToLayerActivationsInWS(true, -1, getOutputLayerIndices(),
fwdType, tbptt, inputs, inputMaskArrays, labelMaskArrays, false);
if (!trainingListeners.isEmpty()) {
try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
for (TrainingListener tl : trainingListeners) {
tl.onForwardPass(this, activations);
}
}
}
calcBackpropGradients(false,false);
workspaceMgr.assertCurrentWorkspace(ArrayType.ACTIVATIONS, null);
//Score: sum of the scores for the various output layers...
double r = calcRegularizationScore(true);
score = 0.0;
int outNum = 0;
for (String s : configuration.getNetworkOutputs()) {
GraphVertex gv = verticesMap.get(s);
if(gv instanceof LayerVertex) {
//At this point: the input to the output layer might not be set on the layer itself - just the vertex
LayerVertex lv = (LayerVertex) gv;
if(!lv.isSetLayerInput()) {
lv.applyPreprocessorAndSetInput(workspaceMgr);
}
}
Layer vertexLayer = gv.getLayer();
if (vertexLayer instanceof FrozenLayerWithBackprop) {
vertexLayer = ((FrozenLayerWithBackprop) vertexLayer).getInsideLayer();
}
vertexLayer.setMaskArray((labelMaskArrays == null) ? null : labelMaskArrays[outNum]);
try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
score += ((IOutputLayer) vertexLayer).computeScore(r, true, workspaceMgr);
}
//Only want to add l1/l2 component once...
r = 0.0;
outNum++;
}
//Listeners
if (!trainingListeners.isEmpty()) {
try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
for (TrainingListener tl : trainingListeners) {
tl.onBackwardPass(this);
}
}
}
}
for(GraphVertex gv : vertices){
gv.clear();
}
}
/**
* Conduct forward pass using a single input array. Note that this method can only be used with ComputationGraphs
* with a single input array.
*
* @param input The input array
* @param layerTillIndex the layer to feed forward to
* @param train If true: do forward pass at training time
* @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
*/
public Map<String, INDArray> feedForward(INDArray input, int layerTillIndex, boolean train) {
if (numInputArrays != 1)
throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with "
+ numInputArrays + " expected inputs");
setInput(0, input);
return feedForward(train,layerTillIndex);
}
/**
* Conduct forward pass using an array of inputs. This overload allows the forward pass to be conducted, optionally
* (not) clearing the layer input arrays.
* Note: when using clearInputs=false, there can be some performance and memory overhead: this is because the arrays are
* defined outside of workspaces (which are enabled by default) - otherwise, old/invalidated arrays could still be
* accessed after calling this method. Consequently: Don't use clearInputs=false unless you have a use case that
* requires them to remain after feed-forward has been completed
*
* @param input An array of ComputationGraph inputs
* @param layerTillIndex the index of the layer to feed forward to
* @param train If true: do forward pass at training time; false: do forward pass at test time
* @param clearInputs If true (default for other methods): clear the inputs of all layers after doing forward
* pass. False don't clear layer inputs.
* @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
*/
public Map<String, INDArray> feedForward(INDArray[] input, int layerTillIndex, boolean train, boolean clearInputs) {
setInputs(input);
try {
return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerTillIndex, null,
input, inputMaskArrays, labelMaskArrays, clearInputs);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
/**
* Conduct forward pass using an array of inputs
*
* @param input An array of ComputationGraph inputs
* @param layerTillIndex the index of the layer to feed forward to
* @param train If true: do forward pass at training time; false: do forward pass at test time
* @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
*/
public Map<String, INDArray> feedForward(INDArray[] input, int layerTillIndex, boolean train) {
setInputs(input);
return feedForward(train, layerTillIndex);
}
/**
* Conduct forward pass using the stored inputs
*
* @param train If true: do forward pass at training time; false: do forward pass at test time
* @param layerTillIndex the index of the layer to feed forward to
* @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
*/
public Map<String, INDArray> feedForward(boolean train, int layerTillIndex) {
int graphVertexIndexOfLayer = layers[layerTillIndex].getIndex();
try{
return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, graphVertexIndexOfLayer,
null, inputs, inputMaskArrays, labelMaskArrays, true);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
/**
* Conduct forward pass using a single input array. Note that this method can only be used with ComputationGraphs
* with a single input array.
*
* @param input The input array
* @param train If true: do forward pass at training time
* @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
*/
public Map<String, INDArray> feedForward(INDArray input, boolean train) {
if (numInputArrays != 1)
throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with "
+ numInputArrays + " expected inputs");
setInput(0, input);
return feedForward(train);
}
/**
* Conduct forward pass using an array of inputs
*
* @param input An array of ComputationGraph inputs
* @param train If true: do forward pass at training time; false: do forward pass at test time
* @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
*/
public Map<String, INDArray> feedForward(INDArray[] input, boolean train) {
return feedForward(input, train, true);
}
/**
* Conduct forward pass using an array of inputs. This overload allows the forward pass to be conducted, optionally
* (not) clearing the layer input arrays.
* Note: this method should NOT be used with clearInputs = false, unless you know what you are doing. Specifically:
* when using clearInputs=false, in combination with workspaces, the layer input fields may leak outside of the
* workspaces in which they were defined - potentially causing a crash. See
* https://deeplearning4j.konduit.ai/config/config-memory/config-workspaces
* for more details
*
* @param input An array of ComputationGraph inputs
* @param train If true: do forward pass at training time; false: do forward pass at test time
* @param clearInputs If true (default for other methods): clear the inputs of all layers after doing forward
* pass. False don't clear layer inputs.
* @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
*/
public Map<String, INDArray> feedForward(INDArray[] input, boolean train, boolean clearInputs) {
setInputs(input);
try {
return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, vertices.length - 1,
null, input, inputMaskArrays, labelMaskArrays, clearInputs);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
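/*
* Usage sketch (illustrative only): collecting all layer activations. "features" and the
* layer name "dense" are assumptions of this example:
*
* Map<String, INDArray> acts = net.feedForward(new INDArray[]{features}, false, true);
* INDArray denseActivations = acts.get("dense");
*/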
/**
* Conduct forward pass using the stored inputs, at test time
*
* @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
*/
public Map<String, INDArray> feedForward() {
return feedForward(false);
}
/**
* Conduct forward pass using the stored inputs
*
* @param train If true: do forward pass at training time; false: do forward pass at test time
* @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
*/
public Map<String, INDArray> feedForward(boolean train) {
try {
return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, vertices.length - 1,
null, inputs, inputMaskArrays, labelMaskArrays, true);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
/**
* @param train True: training time. False: test time
* @param excludeOutputLayers Should we exclude the output layers during forward pass? (usually: false)
* @param includeNonLayerVertexActivations Include non-layer vertices in the output map?
* @return Map of activations. Key: vertex name. Value: activations.
*/
public Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers,
boolean includeNonLayerVertexActivations) {
int[] exclude = null;
if(excludeOutputLayers){
exclude = getOutputLayerIndices();
}
Map<String, INDArray> m = ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false,
vertices.length-1, exclude, inputs, inputMaskArrays, labelMaskArrays, true);
if(includeNonLayerVertexActivations){
return m;
} else {
//Include only layers - in previous versions, we've always included inputs too for this method...
Map<String, INDArray> out = new HashMap<>();
for (Map.Entry<String, INDArray> e : m.entrySet()) {
GraphVertex v = verticesMap.get(e.getKey());
if(v instanceof LayerVertex || v instanceof InputVertex){
out.put(e.getKey(), e.getValue());
}
}
return out;
}
}
/**
* Return an array of network outputs (predictions) at test time, given the specified network inputs
* Network outputs are for output layers only.
*
* @param input Inputs to the network
* @return Output activations (order: same as defined in network configuration)
*/
public INDArray[] output(INDArray... input) {
return output(false, input, inputMaskArrays, labelMaskArrays);
}
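/*
* Usage sketch (illustrative only), for a network with a single input array ("features" is
* an assumption of this example):
*
* INDArray[] out = net.output(features);
* INDArray predictions = out[0]; //First (here: only) output layer's activations
*/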
/**
* A convenience method that returns a single INDArray, instead of an INDArray[].
* Useful for ComputationGraphs that have only a single output.
* Otherwise identical to {@link #output(INDArray...)}
*
* @param input Inputs to the network
* @return Output activations array
*/
public INDArray outputSingle(INDArray... input) {
return outputSingle(false, input);
}
/**
* Return an array of network outputs (predictions), given the specified network inputs
* Network outputs are for output layers only.
*
* @param train If true: do forward pass at training time; false: do forward pass at test time
* @param input Inputs to the network
* @return Output activations (order: same as defined in network configuration)
*/
public INDArray[] output(boolean train, INDArray... input) {
return output(train, (MemoryWorkspace)null, input);
}
/**
* Return an array of network outputs (predictions), given the specified network inputs
* Network outputs are for output layers only.
* If no memory workspace is provided, the output will be detached (not in any workspace).
* If a memory workspace is provided, the output activation array (i.e., the INDArray returned by this method)
* will be placed in the specified workspace. This workspace must be opened by the user before calling this method -
* and the user is responsible for (a) closing this workspace, and (b) ensuring the output array is not used out
* of scope (i.e., not used after closing the workspace to which it belongs - as this is likely to cause either
* an exception when used, or a crash).
*
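* <p>
* A minimal sketch of the workspace contract described above (illustrative only; the workspace id
* {@code "OUTPUT_WS"} and the {@code graph}/{@code in} variables are assumptions):
* <pre>{@code
* WorkspaceConfiguration wsConf = WorkspaceConfiguration.builder()
*         .policyAllocation(AllocationPolicy.STRICT)
*         .policyLearning(LearningPolicy.FIRST_LOOP)
*         .build();
* try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "OUTPUT_WS")) {
*     INDArray[] out = graph.output(false, ws, in);
*     //The returned arrays live in "OUTPUT_WS" - use them only while the workspace is open
* }
* }</pre>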
* @param train If true: do forward pass at training time; false: do forward pass at test time
* @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling this method.
* @param input Inputs to the network
* @return Output activations (order: same as defined in network configuration)
*/
public INDArray[] output(boolean train, MemoryWorkspace outputWorkspace, INDArray... input) {
return output(train, input, inputMaskArrays, labelMaskArrays, outputWorkspace);
}
/**
* Return an array of network outputs (predictions), given the specified network inputs
* Network outputs are for output layers only.
*
* @param train If true: forward pass for training mode. False: test mode
* @param input Input arrays to the network
* @param inputMasks Optional input mask arrays (may be null)
* @return Network output activations
*/
public INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks){
return output(train, input, inputMasks, null);
}
/**
* Return an array of network outputs (predictions), given the specified network inputs
* Network outputs are for output layers only.
*
* @param train If true: forward pass for training mode. False: test mode
* @param input Input arrays to the network
* @param inputMasks Optional input mask arrays (may be null)
* @param labelMasks Optional label mask arrays (may be null)
* @return Network output activations
*/
public INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks, INDArray[] labelMasks) {
return output(train, input, inputMasks, labelMasks, null);
}
/**
* This method uses the provided OutputAdapter to build and return a custom object from the network's output arrays
*
* PLEASE NOTE: This method uses a dedicated workspace for output generation, to avoid redundant allocations
*
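* <p>
* Hedged sketch (not part of the original source), assuming {@code graph} and {@code in} variables and that
* OutputAdapter can be used as a lambda over its single apply(INDArray...) method - here simply extracting the
* first output array:
* <pre>{@code
* INDArray firstOutput = graph.output(new INDArray[]{in}, null, null,
*         (OutputAdapter<INDArray>) outputs -> outputs[0]);
* }</pre>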
* @param inputs Input arrays to the network
* @param inputMasks Optional input mask arrays (may be null)
* @param labelMasks Optional label mask arrays (may be null)
* @param outputAdapter OutputAdapter instance
* @param <T> Type of the result returned by the OutputAdapter
* @return T instance produced by OutputAdapter
*/
public synchronized <T> T output(@NonNull INDArray[] inputs, INDArray[] inputMasks, INDArray[] labelMasks, @NonNull OutputAdapter<T> outputAdapter) {
try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)) {
if (outputAdapter instanceof ModelAdapter)
return ((ModelAdapter) outputAdapter).apply(this, inputs, inputMasks, labelMasks);
else
return outputAdapter.apply(output(false, inputs, inputMasks, labelMasks, ws));
}
}
/**
* Return an array of network outputs (predictions), given the specified network inputs
* Network outputs are for output layers only.
* If no memory workspace is provided, the output will be detached (not in any workspace).
* If a memory workspace is provided, the output activation array (i.e., the INDArray returned by this method)
* will be placed in the specified workspace. This workspace must be opened by the user before calling this method -
* and the user is responsible for (a) closing this workspace, and (b) ensuring the output array is not used out
* of scope (i.e., not used after closing the workspace to which it belongs - as this is likely to cause either
* an exception when used, or a crash).
*
* @param train If true: forward pass for training mode. False: test mode
* @param input Input arrays to the network
* @param inputMasks Optional input mask arrays (may be null)
* @param labelMasks Optional label mask arrays (may be null)
* @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling this method.
* @return Network output activations
*/
public synchronized INDArray[] output(boolean train, @NonNull INDArray[] input, INDArray[] inputMasks, INDArray[] labelMasks, MemoryWorkspace outputWorkspace){
try {
setLayerMaskArrays(inputMasks, labelMasks);
INDArray[] out = outputOfLayersDetached(train, FwdPassType.STANDARD, getOutputLayerIndices(), input, inputMasks, labelMasks, true, false, outputWorkspace);
clearLayerMaskArrays();
clearLayersStates();
return out;
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
/**
* A convenience method that returns a single INDArray, instead of an INDArray[].
* Useful for ComputationGraphs that have only a single output.
* Otherwise identical to {@link #output(boolean, INDArray...)}
*
* @param train If true: do forward pass at training time; false: do forward pass at test time
* @param input Inputs to the network
* @return Output activations array
*/
public INDArray outputSingle(boolean train, INDArray... input) {
return outputSingle(train, true, input);
}
/**
* Identical to {@link #outputSingle(boolean, boolean, INDArray...)} but has the option of not clearing the input
* arrays (useful when later backpropagating external errors). Most users should use {@link #outputSingle(boolean, INDArray...)}
* in preference to this method.
*/
public INDArray outputSingle(boolean train, boolean clearInputs, INDArray... input){
if (numOutputArrays != 1) {
throw new IllegalStateException(
"Cannot use outputSingle with ComputationGraph that does not have exactly 1 output. nOutputs: "
+ numOutputArrays);
}
return output(train, clearInputs, input)[0];
}
/**
* An output method for the network, with optional clearing of the layer inputs.
* Note: most users should use {@link #output(boolean, INDArray...)} or similar methods, unless they are doing
* non-standard operations (like providing the input arrays externally)
*
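* <p>
* Sketch of the "external errors" use case this method supports (the {@code graph} and {@code in} variables are
* assumptions; see {@link #backpropGradient(INDArray...)} for the corresponding backward pass):
* <pre>{@code
* INDArray[] out = graph.output(true, false, in);   //clearInputs = false: keep layer inputs for later backprop
* //... compute an external error for out[0], then call graph.backpropGradient(...)
* }</pre>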
* @param train If true: output during training. False: output during testing. Affects some things such as
* dropout
* @param clearInputs If true: clear the input arrays for all layers. False: leave the input arrays as-is - which
* can be useful for "external errors" (no output layer) backprop use cases
* @param input Input to the network
* @return Output from the network
*/
public synchronized INDArray[] output(boolean train, boolean clearInputs, INDArray... input){
boolean detachedInputs = !clearInputs; //If !clearInputs, then inputs should be detached (otherwise: will be out of scope)
try {
return outputOfLayersDetached(train, FwdPassType.STANDARD, getOutputLayerIndices(), input, null, null, clearInputs, detachedInputs, null);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
/**
* Generate the output for all examples/batches in the input iterator, and concatenate them into a single array
* per network output
*
* @param iterator Data to pass through the network
* @return output for all examples in the iterator
*/
public INDArray[] output(DataSetIterator iterator){
return output(new MultiDataSetIteratorAdapter(iterator));
}
/**
* Generate the output for all examples/batches in the input iterator, and concatenate them into a single array
* per network output
*
* @param iterator Data to pass through the network
* @return output for all examples in the iterator
*/
public INDArray[] output(MultiDataSetIterator iterator){
List<INDArray[]> outputs = new ArrayList<>();
while(iterator.hasNext()){
MultiDataSet next = iterator.next();
INDArray[] out = output(false, next.getFeatures(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
outputs.add(out);
}
INDArray[][] arr = outputs.toArray(new INDArray[outputs.size()][0]);
return DataSetUtil.mergeFeatures(arr, null).getFirst();
}
/**
* Generate the output for all examples/batches in the input iterator, and concatenate them into a single array.
* Can only be used with ComputationGraphs with 1 output
*
* @param iterator Data to pass through the network
* @return output for all examples in the iterator
*/
public INDArray outputSingle(DataSetIterator iterator){
Preconditions.checkArgument(numOutputArrays == 1, "Cannot use this method with nets that have more" +
" than 1 output array. This network has %s outputs", numOutputArrays);
return output(iterator)[0];
}
/**
* Generate the output for all examples/batches in the input iterator, and concatenate them into a single array.
* Can only be used with ComputationGraphs with 1 output
*
* @param iterator Data to pass through the network
* @return output for all examples in the iterator
*/
public INDArray outputSingle(MultiDataSetIterator iterator){
Preconditions.checkArgument(numOutputArrays == 1, "Cannot use this method with nets that have more" +
" than 1 output array. This network has %s outputs", numOutputArrays);
return output(iterator)[0];
}
/**
* Get the activations for the specific layers only
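* <p>
* Illustrative sketch (the layer name {@code "dense1"} and the {@code graph}/{@code in} variables are assumptions):
* <pre>{@code
* INDArray[] acts = graph.output(Arrays.asList("dense1"), false, new INDArray[]{in}, null);
* INDArray dense1Act = acts[0];   //Same order as the provided layer name list
* }</pre>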
* @param layers Layers to get the specified activations for
* @param train If true: train mode. False: test (inference) mode
* @param features Features array
* @param featureMasks Feature masks array. May be null
* @return Activations of the selected layers, in the same order as the "layers" arg/list
*/
public INDArray[] output(List<String> layers, boolean train, INDArray[] features, INDArray[] featureMasks){
Preconditions.checkState(layers != null && layers.size() > 0, "Layers must not be null: got layer names %s", layers);
int[] layerNums = new int[layers.size()];
        for( int i = 0; i < layers.size(); i++ ){
            String layerName = layers.get(i);
            Preconditions.checkState(verticesMap.containsKey(layerName), "Layer/vertex with name \"%s\" not found in network", layerName);
            layerNums[i] = verticesMap.get(layerName).getVertexIndex();
        }
        return outputOfLayersDetached(train, FwdPassType.STANDARD, layerNums, features, featureMasks, null, true, false, null);
    }
    /**
     * Feed-forward through the network, returning all activations detached from any workspace.
     * Note that no workspace may be open externally when this method is called (an exception is thrown otherwise).
     *
     * @param train              Training mode (true) or test/inference mode (false)
     * @param fwdPassType        Type of forward pass to perform (STANDARD or RNN_ACTIVATE_WITH_STORED_STATE only)
     * @param storeLastForTBPTT  ONLY used if fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE
     * @param layerIndex         Index (inclusive) to stop the forward pass at
     * @param excludeIdxs        Vertices to exclude from the forward pass. May be null.
     * @param features           Input feature arrays
     * @param fMask              Feature mask arrays. May be null.
     * @param lMask              Label mask arrays. May be null.
     * @param clearLayers        Whether the layer inputs should be cleared after the forward pass
     * @return Map of detached activations (including the inputs). Key: vertex name. Value: activations.
     */
    protected synchronized Map<String, INDArray> ffToLayerActivationsDetached(boolean train, @NonNull FwdPassType fwdPassType, boolean storeLastForTBPTT,
int layerIndex, int[] excludeIdxs, @NonNull INDArray[] features,
INDArray[] fMask, INDArray[] lMask, boolean clearLayers){
if(layerIndex < 0 || layerIndex >= topologicalOrder.length){
throw new IllegalArgumentException("Invalid layer index - index must be >= 0 and < " + topologicalOrder.length
+ ", got index " + layerIndex);
}
setInputs(features);
setLayerMaskArrays(fMask, lMask);
//Verify that no workspace is open externally
WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active before call to ffToLayerActivationsDetached", true);
LayerWorkspaceMgr workspaceMgr;
WorkspaceMode wsm = (train ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode());
if (wsm == WorkspaceMode.NONE) {
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
} else {
workspaceMgr = LayerWorkspaceMgr.builder()
.noWorkspaceFor(ArrayType.ACTIVATIONS)
.with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.build();
if(features[0].isAttached()){
//Don't leverage out of async DataMultiSetIterator workspaces
workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
}
}
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
Map<String, INDArray> activations = new HashMap<>();
//Add the inputs:
        for( int i = 0; i < numInputArrays; i++ ){
            activations.put(configuration.getNetworkInputs().get(i), features[i]);
        }
        boolean traceLog = log.isTraceEnabled();
        //Do forward pass according to the topological ordering of the network
        int stopIndex;
        if (layerIndex > 0) {
            stopIndex = ArrayUtils.indexOf(topologicalOrder, layerIndex);
        } else {
            stopIndex = topologicalOrder.length - 1;
        }
        for (int i = 0; i <= stopIndex; i++) {
            GraphVertex current = vertices[topologicalOrder[i]];
            String vName = current.getVertexName();
            int vIdx = current.getVertexIndex();
            if(traceLog){
                log.trace("About forward pass: {} (\"{}\") - {}", i, vName, current.getClass().getSimpleName());
            }
            if(excludeIdxs != null && ArrayUtils.contains(excludeIdxs, vIdx)){
                continue;
            }
            try(MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)){
                VertexIndices[] inputsTo = current.getOutputVertices();
                INDArray out;
                if(current.isInputVertex()){
                    out = features[vIdx];
                } else {
                    if(fwdPassType == FwdPassType.STANDARD){
                        out = current.doForward(train, workspaceMgr);
                    } else if(fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                        if (current.hasLayer()) {
                            Layer l = current.getLayer();
                            if (l instanceof RecurrentLayer) {
                                out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], train,
                                        storeLastForTBPTT, workspaceMgr);
                            } else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer) {
                                RecurrentLayer rl = (RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
                                out = rl.rnnActivateUsingStoredState(current.getInputs()[0], train, storeLastForTBPTT, workspaceMgr);
                            } else if (l instanceof MultiLayerNetwork) {
                                List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(
current.getInputs()[0], train, storeLastForTBPTT);
out = temp.get(temp.size() - 1);
} else {
//non-recurrent layer
out = current.doForward(train, workspaceMgr);
}
} else {
out = current.doForward(train, workspaceMgr);
}
} else {
throw new IllegalArgumentException("Unsupported forward pass type for this method: " + fwdPassType);
}
validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
}
activations.put(current.getVertexName(), out);
if(inputsTo != null) { //May be null for output vertices (which don't feed into any other vertices)
for (VertexIndices v : inputsTo) {
//Note that we don't have to do anything special here: the activations are always detached in
// this method
int inputToIndex = v.getVertexIndex();
int vIdxEdge = v.getVertexEdgeNumber();
vertices[inputToIndex].setInput(vIdxEdge, out, workspaceMgr);
}
}
if(clearLayers) {
current.clear();
}
}
if(traceLog){
log.trace("Completed forward pass: {} (\"{}\") - {}", i, vName, current.getClass().getSimpleName());
}
}
return activations;
}
/**
* Feed-forward through the network - if workspaces are used, all returned activations will be present in workspace
* WS_ALL_LAYERS_ACT.
* Note: if using workspaces for training, requires that WS_ALL_LAYERS_ACT is open externally.
* If using NO workspaces, requires that no external workspace is open
*
* @param train Training mode (true) or test/inference mode (false)
* @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use -1
* @param excludeIdxs Layers (vertices) to exclude from forward pass. These layers will be skipped, and hence
* are usually output layers or at the end of the network. May be null.
* @param fwdPassType Type of forward pass to perform (STANDARD or RNN_ACTIVATE_WITH_STORED_STATE only)
* @param storeLastForTBPTT ONLY used if fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE
* @param input Input feature arrays
* @param fMask Feature mask arrays. May be null.
* @param lMask Label mask array. May be null.
* @param clearInputs Whether the layer inputs should be cleared
* @return Map of activations (including the input), in workspace WS_ALL_LAYERS_ACT if workspaces are used (detached
* otherwise)
*/
protected synchronized Map<String, INDArray> ffToLayerActivationsInWS(boolean train, int layerIndex, int[] excludeIdxs,
FwdPassType fwdPassType, boolean storeLastForTBPTT,
INDArray[] input, INDArray[] fMask, INDArray[] lMask, boolean clearInputs) {
if(layerIndex != -1 && (layerIndex < 0 || layerIndex >= topologicalOrder.length)){
throw new IllegalArgumentException("Invalid input index - index must be >= 0 and < " + topologicalOrder.length
+ ", got index " + layerIndex);
}
setInputs(input);
setLayerMaskArrays(fMask, lMask);
LayerWorkspaceMgr workspaceMgr;
WorkspaceMode wsm = (train ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode());
if(wsm == WorkspaceMode.NONE){
//Verify that no workspace is open externally
WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in ffToLayerActivationsDetached", true);
workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
} else {
WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "ffToLayerActivationsInWs method requires workspace WS_ALL_LAYERS_ACT to be open");
workspaceMgr = LayerWorkspaceMgr.builder()
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.build();
if(input[0].isAttached()){
//Don't leverage out of async DataMultiSetIterator workspaces
workspaceMgr.setNoLeverageOverride(input[0].data().getParentWorkspace().getId());
}
if(configuration.getCacheMode() != CacheMode.NONE){
//For now: store cache mode activations in activations workspace
workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
}
}
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
boolean traceLog = log.isTraceEnabled();
Map<String, INDArray> activations = new HashMap<>();
//Do forward pass according to the topological ordering of the network
int stopIndex;
if (layerIndex > 0) {
stopIndex = ArrayUtils.indexOf(topologicalOrder, layerIndex);
} else {
stopIndex = topologicalOrder.length -1;
}
for (int i = 0; i <= stopIndex; i++) {
GraphVertex current = vertices[topologicalOrder[i]];
String vName = current.getVertexName();
int vIdx = current.getVertexIndex();
if(traceLog){
log.trace("About forward pass: {} (\"{}\") - {}", i, vName, current.getClass().getSimpleName());
}
if(excludeIdxs != null && ArrayUtils.contains(excludeIdxs, vIdx)){
continue;
}
try(MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)){
VertexIndices[] inputsTo = current.getOutputVertices();
INDArray out;
if(current.isInputVertex()){
out = inputs[vIdx];
} else {
if(fwdPassType == FwdPassType.STANDARD){
out = current.doForward(train, workspaceMgr);
} else if(fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
if (current.hasLayer()) {
Layer l = current.getLayer();
if (l instanceof RecurrentLayer) {
out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], train,
storeLastForTBPTT, workspaceMgr);
} else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer) {
RecurrentLayer rl = (RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
out = rl.rnnActivateUsingStoredState(current.getInputs()[0], train,storeLastForTBPTT, workspaceMgr);
} else if (l instanceof MultiLayerNetwork) {
List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(
current.getInputs()[0], train, storeLastForTBPTT);
out = temp.get(temp.size() - 1);
} else {
//non-recurrent layer
out = current.doForward(train, workspaceMgr);
}
} else {
out = current.doForward(train, workspaceMgr);
}
} else {
throw new IllegalStateException("FwdPassType not supported for this method: " + fwdPassType);
}
validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
}
activations.put(current.getVertexName(), out);
if(inputsTo != null) {
//Can be null for output layers
for (VertexIndices v : inputsTo) {
//Note that we don't have to do anything special here: the activations are always detached in
// this method
int inputToIndex = v.getVertexIndex();
int vIdxEdge = v.getVertexEdgeNumber();
vertices[inputToIndex].setInput(vIdxEdge, out, workspaceMgr);
}
}
if(clearInputs) {
current.clear();
}
}
if(traceLog){
log.trace("Completed forward pass: {} (\"{}\") - {}", i, vName, current.getClass().getSimpleName());
}
}
return activations;
}
/**
* Provide the output of the specified layers, detached from any workspace. This is most commonly used at inference/test
* time, and is more memory efficient than {@link #ffToLayerActivationsDetached(boolean, FwdPassType, boolean, int, int[], INDArray[], INDArray[], INDArray[], boolean)}
* and {@link #ffToLayerActivationsInWS(boolean, int, int[], FwdPassType, boolean, INDArray[], INDArray[], INDArray[], boolean)}.
* This method clears all layer inputs.
*
* NOTE: in general, no workspaces should be activated externally for this method!
* This method handles the workspace activation as required
*
* @param train Training mode (true) or test/inference mode (false)
* @param fwdPassType Type of forward pass to perform (STANDARD or RNN_TIMESTEP only)
* @param layerIndexes Indexes of the layers to get the activations for
* @param features Input features for the network
* @param fMask Input/feature mask array. May be null.
* @param lMasks Labels mask array. May be null
* @param clearLayerInputs If true: the layer input fields will be cleared
* @param detachedInputs If true: the layer input fields will be detached. Usually used for external errors cases
* @param outputWorkspace Optional - if provided, outputs should be placed in this workspace. NOTE: this workspace
* must be open
* @return Output of the specified layers, detached from any workspace
*/
protected INDArray[] outputOfLayersDetached(boolean train, @NonNull FwdPassType fwdPassType, @NonNull int[] layerIndexes, @NonNull INDArray[] features,
INDArray[] fMask, INDArray[] lMasks, boolean clearLayerInputs, boolean detachedInputs, MemoryWorkspace outputWorkspace){
if(features.length != numInputArrays){
throw new IllegalArgumentException("Invalid number of input arrays: network has " + numInputArrays
+ " inputs, got " + features.length + " input arrays");
}
for( int i = 0; i < layerIndexes.length; i++) {
if(layerIndexes[i] < 0 || layerIndexes[i] >= topologicalOrder.length) {
throw new IllegalArgumentException("Invalid input index - index must be >= 0 and < " + topologicalOrder.length
+ ", got index " + layerIndexes[i]);
}
}
setInputs(features);
setLayerMaskArrays(fMask, lMasks);
MemoryWorkspace outputPrevious = null;
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) {
//Verify that no workspace is open externally
WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active before call to outputOfLayersDetached");
} else {
Preconditions.checkState(outputWorkspace.isScopeActive(), "Workspace \"" + outputWorkspace.getId() +
"\" was provided for the network/layer outputs. When provided, this workspace must be opened before " +
"calling the output method; furthermore, closing the workspace is the responsibility of the user");
outputPrevious = outputWorkspace.getParentWorkspace();
}
//First: for each vertex, determine the highest index of the vertex that consumes its output
//Then: for each vertex, determine the forward pass step that each vertex's output has been fully consumed on
//In other words, if vertex X -> Y and X -> Z, and topological sort order is X,a,Y,b,Z,
//Then we know X's output activations have been fully consumed by step index 4 in the topological sort
//thus vertexOutputsFullyConsumedByStep[X.index] = IndexOf(topologicalSort, Z.index)
//Position in array: index of vertex. Value at position: the step (in topological order) that the activations
// have been consumed by
//Put another way: this is the step that it's safe to deallocate the layer's activations by closing the
// corresponding workspace
int[] vertexOutputsFullyConsumedByStep = new int[topologicalOrder.length];
for(GraphVertex gv : vertices) {
int idx = gv.getVertexIndex();
int maxStepOfOutputTo = -1;
VertexIndices[] outputsTo = gv.getOutputVertices();
if(outputsTo != null) {
//May be null for final/output layers
for (VertexIndices vi : outputsTo) {
int posInTopoSort = ArrayUtils.indexOf(topologicalOrder, vi.getVertexIndex());
if (posInTopoSort == -1) {
throw new IllegalStateException("Did not find vertex " + vi.getVertexIndex() + " in topological sort array");
}
maxStepOfOutputTo = Math.max(maxStepOfOutputTo, posInTopoSort);
}
} else {
maxStepOfOutputTo = topologicalOrder.length-1;
}
vertexOutputsFullyConsumedByStep[idx] = maxStepOfOutputTo;
}
//Do forward pass according to the topological ordering of the network
INDArray[] outputs = new INDArray[layerIndexes.length];
int stopIndex = -1;
for( int i = 0; i < layerIndexes.length; i++) {
stopIndex = Math.max(stopIndex, ArrayUtils.indexOf(topologicalOrder, layerIndexes[i]));
}
List<LayerWorkspaceMgr> allWorkspaceManagers = new ArrayList<>();
List<LayerWorkspaceMgr> freeWorkspaceManagers = new ArrayList<>(); //Basically used as a stack
Map<MemoryWorkspace, LayerWorkspaceMgr> openActivationsWorkspaces = new IdentityHashMap<>();
WorkspaceMode wsm = (train ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode());
boolean noWS = wsm == WorkspaceMode.NONE;
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[]) new List[topologicalOrder.length];
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Throwable t = null;
try {
for (int i = 0; i <= stopIndex; i++) {
GraphVertex current = vertices[topologicalOrder[i]];
GraphVertex prev = i > 0 ? vertices[topologicalOrder[i - 1]] : null;
String vName = current.getVertexName();
int vIdx = current.getVertexIndex();
//First: determine what workspace manager we should use for forward pass in this vertex
LayerWorkspaceMgr workspaceMgr;
if (noWS) {
workspaceMgr = allNone;
} else {
//First: is there a free forward pass workspace we can use?
if (freeWorkspaceManagers.size() > 0) {
workspaceMgr = freeWorkspaceManagers.remove(freeWorkspaceManagers.size() - 1);
} else {
//No existing free workspace managers for forward pass - create a new one...
String wsName = "WS_LAYER_ACT_" + allWorkspaceManagers.size();
workspaceMgr = LayerWorkspaceMgr.builder()
.with(ArrayType.INPUT, wsName, WS_LAYER_ACT_X_CONFIG)
.with(ArrayType.ACTIVATIONS, wsName, WS_LAYER_ACT_X_CONFIG)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.build();
if (detachedInputs) {
//Sometimes (like: external errors use cases) we don't want the activations/inputs to be
// in a workspace
workspaceMgr.setScopedOutFor(ArrayType.INPUT);
workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
} else {
//Don't leverage out of async MultiDataSetIterator workspaces
if (features[0].isAttached()) {
workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
}
}
allWorkspaceManagers.add(workspaceMgr);
}
}
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
//Is this one of the layers/vertices that we want the output for?
boolean isRequiredOutput = false;
String origWSAct = null;
WorkspaceConfiguration origWSActConf = null;
if (ArrayUtils.contains(layerIndexes, vIdx)) {
isRequiredOutput = true;
if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
//Place activations in user-specified workspace
origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration());
} else {
//Standard case
if (!workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS)) {
//Activations/output to return: don't want this in any workspace
origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
}
}
}
//Open the relevant workspace for the activations.
//Note that this will be closed only once the current vertex's activations have been consumed
MemoryWorkspace wsActivations = null;
if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput) { //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
openActivationsWorkspaces.put(wsActivations, workspaceMgr);
}
//Note that because we're opening activation workspaces not in any defined order (i.e., workspace
// use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we
// close these workspaces, the "current" workspace may be set to the incorrect one
if (wsActivations != null)
wsActivations.setPreviousWorkspace(initialWorkspace);
int closeableAt = vertexOutputsFullyConsumedByStep[vIdx];
if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
if (closeAtEndIteraton[closeableAt] == null) {
closeAtEndIteraton[closeableAt] = new ArrayList<>();
}
closeAtEndIteraton[closeableAt].add(wsActivations);
}
try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
VertexIndices[] inputsTo = current.getOutputVertices();
INDArray out = null;
if (current.isInputVertex()) {
out = features[vIdx];
} else {
if (fwdPassType == FwdPassType.STANDARD) {
//Standard feed-forward case
if(i > 0 && current.hasLayer() && prev.hasLayer() &&
ConvolutionUtils.layerHasConvolutionLayout(prev.getLayer().conf().getLayer())
&& ConvolutionUtils.layerHasConvolutionLayout(current.getLayer().conf().getLayer())) {
                            /*
                             * Not QUITE the proper fix, but getting close: we can detect a CNN data format change
                             * mid-graph and permute the input accordingly. The output sizes may still need some
                             * adjustment to guarantee correct behavior in all cases.
                             */
CNN2DFormat preLayerFormat = ConvolutionUtils.getFormatForLayer(prev.getLayer().conf().getLayer());
CNN2DFormat currLayerFormat = ConvolutionUtils.getFormatForLayer(current.getLayer().conf().getLayer());
if(preLayerFormat != currLayerFormat) {
int inputIdx = -1;
for(int inputVertex = 0; inputVertex < current.getInputVertices().length; inputVertex++) {
if(current.getInputVertices()[inputVertex].getVertexIndex() == prev.getVertexIndex()) {
inputIdx = inputVertex;
}
}
//NHWC case
if(preLayerFormat == CNN2DFormat.NCHW) {
current.setInput(inputIdx,current.getInputs()[inputIdx].permute(0,3,1,2),workspaceMgr);
}
//NCHW case
else if(preLayerFormat == CNN2DFormat.NHWC) {
current.setInput(inputIdx,current.getInputs()[inputIdx].permute(0,2,3,1),workspaceMgr);
}
else
throw new IllegalStateException("No CNN2DDataFormat type found for previous layer!");
out = current.doForward(train, workspaceMgr);
}
else
out = current.doForward(train, workspaceMgr);
} else if(i > 0 && current.hasLayer() && prev.hasLayer() &&
Convolution1DUtils.hasRnnDataFormat(prev.getLayer().conf().getLayer())
&& Convolution1DUtils.hasRnnDataFormat(current.getLayer().conf().getLayer())) {
RNNFormat preLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(prev.getLayer().conf().getLayer());
RNNFormat currLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(current.getLayer().conf().getLayer());
int inputIdx = -1;
for(int inputVertex = 0; inputVertex < current.getInputVertices().length; inputVertex++) {
if(current.getInputVertices()[inputVertex].getVertexIndex() == prev.getVertexIndex()) {
inputIdx = inputVertex;
}
}
//permute for next layer
if(preLayerFormat != currLayerFormat) {
current.setInput(inputIdx,current.getInputs()[inputIdx].permute(0,2,1),workspaceMgr);
}
out = current.doForward(train, workspaceMgr);
} else {
out = current.doForward(train, workspaceMgr);
}
} else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
if (current.hasLayer()) {
//Layer
INDArray input = current.getInputs()[0];
Layer l = current.getLayer();
if (l instanceof RecurrentLayer) {
out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
} else if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) {
RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying());
out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
} else if (l instanceof MultiLayerNetwork) {
out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input));
} else {
//non-recurrent layer
out = current.doForward(train, workspaceMgr);
}
} else {
//GraphNode
out = current.doForward(train, workspaceMgr);
}
} else {
throw new IllegalArgumentException("Unsupported forward pass type for this method: " + fwdPassType);
}
validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
}
if (inputsTo != null) { //Output vertices may not input to any other vertices
for (VertexIndices v : inputsTo) {
//Note that we don't have to do anything special here: the activations are always detached in
// this method
int inputToIndex = v.getVertexIndex();
int vIdxEdge = v.getVertexEdgeNumber();
vertices[inputToIndex].setInput(vIdxEdge, out, workspaceMgr);
}
}
if (clearLayerInputs) {
current.clear();
}
if (isRequiredOutput) {
outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out;
if (origWSAct != null) {
//Reset the configuration, as we may reuse this workspace manager...
workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf);
}
}
}
//Close any activations workspaces that we no longer require
//Note that activations workspaces can be closed only once the corresponding output activations have
// been fully consumed
if (closeAtEndIteraton[i] != null) {
for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
wsAct.close();
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
freeWorkspaceManagers.add(canNowReuse);
}
}
}
} catch (Throwable t2){
t = t2;
} finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
//Though if stopIndex < numLayers, some might still be open
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
while (ws.isScopeActive()) {
//Edge case here: seems that scoping out can increase the tagScope of the current WS
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
// number of times to actually close it, in all cases
try{
ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
}
}
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
if(t != null){
if(t instanceof RuntimeException){
throw ((RuntimeException)t);
}
throw new RuntimeException("Error during neural network forward pass", t);
}
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) {
WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached");
} else {
Preconditions.checkState(outputWorkspace.isScopeActive(), "Expected output workspace to still be open" +
"at end of outputOfLayerDetached, but ");
outputWorkspace.setPreviousWorkspace(outputPrevious);
}
}
return outputs;
}
private INDArray reshapeTimeStepInput(INDArray input) {
if (input.rank() == 2) { // dynamically reshape to 3D input with one time-step.
long[] inShape = input.shape();
input = input.reshape(inShape[0], inShape[1], 1);
}
return input;
}
/**
* Calculate the gradient of the network with respect to some external errors.
* Note that this is typically used for things like reinforcement learning, not typical networks that include
* an OutputLayer or RnnOutputLayer
*
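* <p>
* Hedged sketch of a typical external-errors workflow, assuming a single network output and {@code graph}/{@code in}
* variables (the zero-filled epsilon is only a placeholder for an externally computed error signal):
* <pre>{@code
* INDArray[] out = graph.output(true, false, in);        //Forward pass, keeping the layer inputs
* INDArray externalError = Nd4j.zerosLike(out[0]);       //dL/dOutput, normally supplied by e.g. an RL algorithm
* Gradient grad = graph.backpropGradient(externalError); //Parameter gradients, given the external error
* }</pre>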
* @param epsilons Epsilons (errors) at the output. Same order with which the output layers are defined in configuration setOutputs(String...)
* @return Gradient for the network
*/
public Gradient backpropGradient(INDArray... epsilons) {
if (epsilons == null || epsilons.length != numOutputArrays)
throw new IllegalArgumentException(
"Invalid input: must have epsilons length equal to number of output arrays");
try {
calcBackpropGradients(true, configuration.getBackpropType() == BackpropType.TruncatedBPTT, epsilons);
return gradient;
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
/**
* Do backprop (gradient calculation)
*
* @param clearLayers   If true: clear the layer inputs and epsilons (via GraphVertex.clear()) once backprop is complete
* @param truncatedBPTT false: normal backprop. true: calculate gradients using truncated BPTT for RNN layers
* @param externalEpsilons null usually (for typical supervised learning). If not null (and length > 0) then assume that
* the user has provided some errors externally, as they would do for example in reinforcement
* learning situations.
*/
protected void calcBackpropGradients(boolean clearLayers, boolean truncatedBPTT, INDArray... externalEpsilons) {
if (flattenedGradients == null) {
initGradientsView();
}
/*
Design for workspaces use in backprop for ComputationGraph is similar to MultiLayerNetwork and shares some
features with outputOfLayersDetached
Specifically:
1. We assume forward pass has already been done, and hence layer input fields are set (with all arrays/activations in
workspace WS_ALL_LAYERS_ACT if appropriate)
2. We use a set of small workspaces to contain the activation gradients for a single layer
These are opened once per layer, and are closed only once the corresponding activation gradients have been
consumed by all layers
*/
if((externalEpsilons == null || externalEpsilons.length == 0) && configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE){
WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "Expected workspace WS_ALL_LAYERS_ACT to be active and open" +
" in calcBackpropGradients when workspace mode is not set to NONE");
}
//Validate the network configuration for external errors - no output layers
if(externalEpsilons != null && externalEpsilons.length > 0){
List<String> outputLayers = configuration.getNetworkOutputs();
for(String s : outputLayers ){
GraphVertex gv = getVertex(s);
if(gv instanceof LayerVertex && ((LayerVertex)gv).getLayer() instanceof IOutputLayer){
throw new IllegalStateException("Cannot perform backprop with external errors in conjunction with an output layer:" +
" output layers cannot use external errors for backprop. Layer name: " + s);
}
}
}
//Position in array: index of vertex. Value at position: the step (in topological order) that the activation
// gradients of the specified vertex have been consumed by
//Put another way: this is the step that it's safe to deallocate the layer's activation gradients by closing the
// corresponding workspace
//TODO we can probably cache this...
int[] vertexActGradsFullyConsumedByStep = new int[topologicalOrder.length];
for(GraphVertex gv : vertices){
int idx = gv.getVertexIndex();
int minStepOfInputFrom = Integer.MAX_VALUE;
VertexIndices[] inputsFrom = gv.getInputVertices();
if(inputsFrom != null) {
//inputsFrom may be null for input vertex
for (VertexIndices vi : inputsFrom) {
int posInTopoSort = ArrayUtils.indexOf(topologicalOrder, vi.getVertexIndex());
if (posInTopoSort == -1) {
throw new IllegalStateException("Did not find vertex " + vi.getVertexIndex() + " in topological sort array");
}
minStepOfInputFrom = Math.min(minStepOfInputFrom, posInTopoSort);
}
}
if(minStepOfInputFrom == Integer.MAX_VALUE){
//Input vertex, etc
vertexActGradsFullyConsumedByStep[idx] = 0;
} else {
vertexActGradsFullyConsumedByStep[idx] = minStepOfInputFrom;
}
}
boolean noWS = configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE;
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
List<LayerWorkspaceMgr> allWorkspaceManagers = new ArrayList<>();
List<LayerWorkspaceMgr> freeWorkspaceManagers = new ArrayList<>(); //Basically used as a stack
Map<MemoryWorkspace, LayerWorkspaceMgr> openActivationsWorkspaces = new IdentityHashMap<>();
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[]) new List[topologicalOrder.length];
//Do backprop, in reverse topological order
LinkedList<Triple<String, INDArray, Character>> gradients = new LinkedList<>();
boolean[] setVertexEpsilon = new boolean[topologicalOrder.length]; //If true: already set epsilon for this vertex; later epsilons should be *added* to the existing one, not set
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try {
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
boolean hitFrozen = false;
GraphVertex current = vertices[topologicalOrder[i]];
int vIdx = current.getVertexIndex();
String vertexName = current.getVertexName();
if (traceLog) {
log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
}
//FIXME: make the frozen vertex feature extraction more flexible
if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex) {
hitFrozen = true;
}
if (current.isInputVertex() || hitFrozen) {
//Close any activation gradient workspaces that we no longer require
//Note that activation gradient workspaces can be closed only once the corresponding activations
// gradients have been fully consumed
if (closeAtEndIteraton[i] != null) {
for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
wsAct.close();
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
freeWorkspaceManagers.add(canNowReuse);
}
}
closeAtEndIteraton[i] = null;
continue;
}
//First: determine what workspace manager we should use for the activation gradients from this vertex
LayerWorkspaceMgr workspaceMgr;
if (noWS) {
workspaceMgr = allNone;
} else {
//First: is there a free activation gradient workspace we can use?
if (freeWorkspaceManagers.size() > 0) {
workspaceMgr = freeWorkspaceManagers.remove(freeWorkspaceManagers.size() - 1);
} else {
//No existing free workspace managers for forward pass - create a new one...
String wsName = "WS_LAYER_ACT_" + allWorkspaceManagers.size();
workspaceMgr = LayerWorkspaceMgr.builder()
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.ACTIVATION_GRAD, wsName, WS_LAYER_ACT_X_CONFIG)
.with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) //For forward pass in the context of BP
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.build();
allWorkspaceManagers.add(workspaceMgr);
}
}
workspaceMgr.setHelperWorkspacePointers(helperWorkspaces);
if (current.isOutputVertex()) {
//Two reasons for a vertex to be an output vertex:
//(a) it's an output layer (i.e., instanceof IOutputLayer), or
//(b) it's a normal layer, but it has been marked as an output layer for use in external errors - for reinforcement learning, for example
int thisOutputNumber = configuration.getNetworkOutputs().indexOf(current.getVertexName());
Layer currentLayer = current.getLayer();
if (currentLayer instanceof FrozenLayerWithBackprop) {
currentLayer = ((FrozenLayerWithBackprop) currentLayer).getInsideLayer();
}
if (currentLayer instanceof IOutputLayer) {
IOutputLayer outputLayer = (IOutputLayer) currentLayer;
INDArray currLabels = labels[thisOutputNumber];
outputLayer.setLabels(currLabels);
} else {
if ((externalEpsilons == null || externalEpsilons.length == 0)
&& labels[thisOutputNumber] != null) {
throw new DL4JException("Layer \"" + current.getVertexName() + "\" of type "
+ current.getLayer().getClass().getSimpleName()
+ " is set as network output "
+ "(but isn't an IOutputLayer). Only IOutputLayer layers can be fit via backprop with"
+ " a labels array. ");
}
current.setEpsilon(externalEpsilons[thisOutputNumber]);
setVertexEpsilon[topologicalOrder[i]] = true;
}
}
//Actually execute backprop for the specified vertex
//First: Open the relevant workspace for the activations.
//Note that this will be closed only once the current vertex's activations have been consumed
MemoryWorkspace wsActivationGrads = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
openActivationsWorkspaces.put(wsActivationGrads, workspaceMgr);
//Note that because we're opening activation gradient workspaces not in any defined order (i.e., workspace
// use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we
// close these workspaces, the "current" workspace may be set to the incorrect one
wsActivationGrads.setPreviousWorkspace(initialWorkspace);
int closeableAt = vertexActGradsFullyConsumedByStep[vIdx];
if (closeableAt >= 0) {
if (closeAtEndIteraton[closeableAt] == null) {
closeAtEndIteraton[closeableAt] = new ArrayList<>();
}
closeAtEndIteraton[closeableAt].add(wsActivationGrads);
}
Pair<Gradient, INDArray[]> pair;
INDArray[] epsilons;
try (MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
pair = current.doBackward(truncatedBPTT, workspaceMgr);
epsilons = pair.getSecond();
//Validate workspace location for the activation gradients:
//validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){
for (INDArray epsilon : epsilons) {
if (epsilon != null) {
//May be null for EmbeddingLayer, etc
validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop");
}
}
}
//Inputs to the current GraphVertex:
VertexIndices[] inputVertices = current.getInputVertices();
//Set epsilons for the vertices that provide inputs to this vertex:
if (inputVertices != null) {
int j = 0;
for (VertexIndices v : inputVertices) {
GraphVertex gv = vertices[v.getVertexIndex()];
if (setVertexEpsilon[gv.getVertexIndex()]) {
//This vertex: must output to multiple vertices... we want to add the epsilons here
INDArray currentEps = gv.getEpsilon();
if(currentEps == null){
//Edge case: this can be null for dual embedding layer case - in -> e1, in -> e2
gv.setEpsilon(currentEps);
} else {
gv.setEpsilon(currentEps.addi(epsilons[j++])); //TODO is this always safe?
}
} else {
gv.setEpsilon(epsilons[j++]);
}
setVertexEpsilon[gv.getVertexIndex()] = true;
}
}
if (pair.getFirst() != null) {
Gradient g = pair.getFirst();
Map<String, INDArray> map = g.gradientForVariable();
LinkedList<Triple<String, INDArray, Character>> tempList = new LinkedList<>();
for (Map.Entry<String, INDArray> entry : map.entrySet()) {
String origName = entry.getKey();
String newName = current.getVertexName() + "_" + origName;
tempList.addFirst(new Triple<>(newName, entry.getValue(),
g.flatteningOrderForVariable(origName)));
}
for (Triple<String, INDArray, Character> triple : tempList)
gradients.addFirst(triple);
}
//Close any activation gradient workspaces that we no longer require
//Note that activation gradient workspaces can be closed only once the corresponding activations
// gradients have been fully consumed
if (closeAtEndIteraton[i] != null) {
for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
wsAct.close();
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
freeWorkspaceManagers.add(canNowReuse);
}
closeAtEndIteraton[i] = null;
}
if (traceLog) {
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
}
}
} catch (Throwable t2){
t = t2;
} finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
try{
ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
}
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
if(t != null){
if(t instanceof RuntimeException){
throw ((RuntimeException)t);
}
throw new RuntimeException("Error during neural network backpropagation calculation", t);
}
}
//Now, add the gradients in the order we need them in for flattening (same as params order)
Gradient gradient = new DefaultGradient(flattenedGradients);
for (Triple<String, INDArray, Character> tr : gradients) {
gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
}
this.gradient = gradient;
if(truncatedBPTT && clearTbpttState){
rnnClearPreviousState();
}
//Clear inputs and epsilons:
if(clearLayers) {
for (GraphVertex gv : vertices) {
gv.clear();
}
}
}
@Override
public ComputationGraph clone() {
ComputationGraph cg = new ComputationGraph(configuration.clone());
cg.init(params().dup(), false);
if (solver != null) {
//If solver is null: the updater hasn't been initialized; a getUpdater() call would force initialization, which we don't want here
ComputationGraphUpdater u = this.getUpdater();
INDArray updaterState = u.getStateViewArray();
if (updaterState != null) {
cg.getUpdater().setStateViewArray(updaterState.dup());
}
}
cg.trainingListeners = this.trainingListeners;
for (int i = 0; i < topologicalOrder.length; i++) {
if (!vertices[topologicalOrder[i]].hasLayer())
continue;
String layerName = vertices[topologicalOrder[i]].getVertexName();
if (getLayer(layerName) instanceof FrozenLayer) {
cg.getVertex(layerName).setLayerAsFrozen();
}
}
return cg;
}
public double calcRegularizationScore(boolean backpropParamsOnly){
double scoreSum = 0.0;
for (int i = 0; i < layers.length; i++) {
scoreSum += layers[i].calcRegularizationScore(backpropParamsOnly);
}
return scoreSum;
}
/**
* Set the trainingListeners for the ComputationGraph (and all layers in the network)
*/
public void setListeners(Collection<TrainingListener> listeners) {
if (layers == null)
init();
for (Layer l : layers) {
l.setListeners(listeners);
}
if (solver != null) {
solver.setListeners(listeners);
}
this.trainingListeners.clear();
if (listeners != null) {
this.trainingListeners.addAll(listeners);
}
}
/**
* Set the trainingListeners for the ComputationGraph (and all layers in the network)
*/
public void setListeners(TrainingListener... listeners) {
List<TrainingListener> list = new ArrayList<>();
//Check: user might have done setListeners(null) thinking this would clear the current listeners.
//This results in a TrainingListener[1] with a single null value -> results in an NPE later
if (listeners != null && listeners.length > 0) {
for (TrainingListener i : listeners) {
if (i != null)
list.add(i);
}
}
setListeners(list);
}
/**
* This method ADDS additional TrainingListener to existing listeners
*
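* <p>
* Sketch (illustrative only; the {@code graph} variable and the listener frequency are assumptions):
* <pre>{@code
* graph.addListeners(new ScoreIterationListener(10));   //Log the score every 10 iterations
* }</pre>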
* @param listeners Listeners to add
*/
@Override
public void addListeners(TrainingListener... listeners) {
if (this.trainingListeners == null) {
setListeners(listeners);
return;
} else {
List<TrainingListener> newListeners = new ArrayList<>(this.trainingListeners); //To avoid immutable list issues
Collections.addAll(newListeners, listeners);
setListeners(newListeners);
}
if (solver != null) {
solver.setListeners(this.trainingListeners);
}
}
/**
* Get the trainingListeners for the ComputationGraph
*/
public Collection<TrainingListener> getListeners() {
return trainingListeners;
}
/**
* Get the ComputationGraphUpdater for the network. Creates one on demand, if required
*/
public ComputationGraphUpdater getUpdater() {
return getUpdater(true);
}
/**
* Get the ComputationGraphUpdater for this network
* @param initializeIfAbsent If true: create the updater if one is absent. False: return null if absent.
* @return Updater
*/
public ComputationGraphUpdater getUpdater(boolean initializeIfAbsent){
if (solver == null && initializeIfAbsent) {
solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
}
if(solver != null) {
return solver.getOptimizer().getComputationGraphUpdater(initializeIfAbsent);
}
return null;
}
/**
* Set the computationGraphUpdater for the network
*/
public void setUpdater(ComputationGraphUpdater updater) {
if (solver == null) {
solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
}
solver.getOptimizer().setUpdaterComputationGraph(updater);
}
/**
* Get the specified output layer, by index. The index of the output
* layer may be 0 to {@link #getNumOutputArrays()}-1
*/
public Layer getOutputLayer(int outputLayerIdx) {
if (outputLayerIdx >= numOutputArrays)
throw new IllegalArgumentException("Invalid index: cannot get output layer " + outputLayerIdx
+ ", total number of network outputs = " + numOutputArrays);
return getLayer(configuration.getNetworkOutputs().get(outputLayerIdx));
}
/**
* @deprecated To be removed. Use {@link #params()}
*/
@Deprecated
public INDArray params(boolean backwardOnly) {
return params();
}
/**
* Sets the input and labels and returns a score for the prediction with respect to the true labels
* This is equivalent to {@link #score(DataSet, boolean)} with training==false.
* NOTE: this version of the score function can only be used with ComputationGraph networks that have
* a single input and a single output.
*
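* <p>
* Minimal sketch (the {@code graph} and {@code dataSet} variables are assumptions):
* <pre>{@code
* double testLoss = graph.score(dataSet);   //Equivalent to score(dataSet, false)
* }</pre>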
* @param dataSet the data to score
* @return the score for the given input,label pairs
* @see #score(DataSet, boolean)
*/
public double score(DataSet dataSet) {
return score(dataSet, false);
}
/**
* Sets the input and labels and returns a score for the prediction with respect to the true labels
* NOTE: this version of the score function can only be used with ComputationGraph networks that have
* a single input and a single output. Use {@link #score(MultiDataSet, boolean)} for multiple input/output networks
*
* @param dataSet the data to score
* @param training whether score is being calculated at training time (true) or test time (false)
* @return the score for the given input,label pairs
* @see #score(DataSet, boolean)
*/
public double score(DataSet dataSet, boolean training) {
if (numInputArrays != 1 || numOutputArrays != 1)
throw new UnsupportedOperationException("Cannot score ComputationGraph network with "
+ " DataSet: network does not have 1 input and 1 output arrays");
return score(ComputationGraphUtil.toMultiDataSet(dataSet), training);
}
/**
* Score the network given the MultiDataSet, at test time
*/
public double score(MultiDataSet dataSet) {
return score(dataSet, false);
}
/**
* Sets the input and labels and returns a score for the prediction with respect to the true labels
*
* @param dataSet the data to score
* @param training whether score is being calculated at training time (true) or test time (false)
* @return the score for the given input,label pairs
*/
public double score(MultiDataSet dataSet, boolean training) {
try{
return scoreHelper(dataSet, training);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
private double scoreHelper(MultiDataSet dataSet, boolean training){
LayerWorkspaceMgr mgr;
WorkspaceMode wsm = (training ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode());
if(wsm == WorkspaceMode.NONE){
mgr = LayerWorkspaceMgr.noWorkspaces();
} else {
mgr = LayerWorkspaceMgr.builder()
.noWorkspaceFor(ArrayType.ACTIVATIONS)
.noWorkspaceFor(ArrayType.INPUT)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.build();
}
mgr.setHelperWorkspacePointers(helperWorkspaces);
boolean hasMaskArrays = dataSet.hasMaskArrays();
if (hasMaskArrays) {
setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
}
double score = 0.0;
setInputs(dataSet.getFeatures());
//TODO Can possibly optimize this, in terms of memory use/workspaces
ffToLayerActivationsDetached(training, FwdPassType.STANDARD, false, vertices.length-1,
getOutputLayerIndices(), dataSet.getFeatures(), dataSet.getFeaturesMaskArrays(),dataSet.getLabelsMaskArrays(), false);
//Need to feed forward, but not the output layers
try(WorkspacesCloseable ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS,ArrayType.FF_WORKING_MEM,ArrayType.RNN_FF_LOOP_WORKING_MEM)){
INDArray[] labels = dataSet.getLabels();
setLabels(labels);
//Score: sum of the scores for the various output layers...
double r = calcRegularizationScore(true);
int i = 0;
for (String s : configuration.getNetworkOutputs()) {
GraphVertex gv = verticesMap.get(s);
Layer outLayer = gv.getLayer();
if (outLayer == null || !(outLayer instanceof IOutputLayer)) {
log.warn("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
return 0.0;
}
IOutputLayer ol = (IOutputLayer) outLayer;
ol.setLabels(labels[i++]);
score += ((LayerVertex) gv).computeScore(r, training, mgr);
//Only want to add l1/l2 once...
r = 0.0;
}
}
clearLayersStates(); //Clean up layer inputs/mask arrays - may be invalidated by workspace
return score;
}
/**
* Calculate the score for each example in a DataSet individually. Unlike {@link #score(DataSet)} and {@link #score(DataSet, boolean)}
* this method does not average/sum over examples. This method allows for examples to be scored individually (at test time only), which
* may be useful for example for autoencoder architectures and the like.
* Each row of the output (assuming addRegularizationTerms == true) is equivalent to calling score(DataSet) with a single example.
*
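* <p>
* Illustrative sketch (the {@code graph} and {@code dataSet} variables are assumptions):
* <pre>{@code
* INDArray perExampleLoss = graph.scoreExamples(dataSet, true);   //Column vector: one score (loss value) per example
* double worst = perExampleLoss.maxNumber().doubleValue();        //E.g., the loss of the hardest example
* }</pre>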
* @param data The data to score
* @param addRegularizationTerms If true: add l1/l2 regularization terms (if any) to the score. If false: don't add regularization terms
* @return An INDArray (column vector) of size input.numRows(); the ith entry is the score (loss value) of the ith example
*/
public INDArray scoreExamples(DataSet data, boolean addRegularizationTerms) {
if (numInputArrays != 1 || numOutputArrays != 1)
throw new UnsupportedOperationException("Cannot score ComputationGraph network with "
+ " DataSet: network does not have 1 input and 1 output arrays");
return scoreExamples(ComputationGraphUtil.toMultiDataSet(data), addRegularizationTerms);
}
/**
* Calculate the score for each example in a DataSet individually. Unlike {@link #score(MultiDataSet)} and {@link #score(MultiDataSet, boolean)}
* this method does not average/sum over examples. This method allows for examples to be scored individually (at test time only), which
* may be useful for example for autoencoder architectures and the like.
* Each row of the output (assuming addRegularizationTerms == true) is equivalent to calling score(MultiDataSet) with a single example.
*
* @param dataSet The data to score
* @param addRegularizationTerms If true: add l1/l2 regularization terms (if any) to the score. If false: don't add regularization terms
* @return An INDArray (column vector) of size input.numRows(); the ith entry is the score (loss value) of the ith example
*/
public INDArray scoreExamples(MultiDataSet dataSet, boolean addRegularizationTerms) {
try{
return scoreExamplesHelper(dataSet, addRegularizationTerms);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
private INDArray scoreExamplesHelper(MultiDataSet dataSet, boolean addRegularizationTerms){
LayerWorkspaceMgr mgr;
if(configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE){
mgr = LayerWorkspaceMgr.noWorkspaces();
} else {
mgr = LayerWorkspaceMgr.builder()
.with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
.with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG)
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.build();
}
mgr.setHelperWorkspacePointers(helperWorkspaces);
boolean hasMaskArrays = dataSet.hasMaskArrays();
if (hasMaskArrays) {
setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
}
INDArray out = null;
setInputs(dataSet.getFeatures());
//Need to feed forward, but not the output layers
try(MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
//TODO maybe optimize? We only need *some* of the activations in the WS...
ffToLayerActivationsInWS(false, vertices.length - 1, getOutputLayerIndices(), FwdPassType.STANDARD, false,
dataSet.getFeatures(), dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays(), false);
INDArray[] labels = dataSet.getLabels();
setLabels(labels);
double r = (addRegularizationTerms ? calcRegularizationScore(true) : 0.0);
int i = 0;
for (String s : configuration.getNetworkOutputs()) {
GraphVertex gv = verticesMap.get(s);
Layer outLayer = gv.getLayer();
if (outLayer == null || !(outLayer instanceof IOutputLayer)) {
throw new UnsupportedOperationException(
"Cannot calculate score: vertex \"" + s + "\" is not an output layer");
}
IOutputLayer ol = (IOutputLayer) outLayer;
ol.setLabels(labels[i++]);
INDArray scoreCurrLayer;
try(MemoryWorkspace wsFF = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
scoreCurrLayer =((LayerVertex) gv).computeScoreForExamples(r, mgr);
}
if (out == null)
out = scoreCurrLayer.detach();
else
out.addi(scoreCurrLayer);
//Only want to add l1/l2 once...
r = 0.0;
}
}
if (dataSet.hasMaskArrays())
clearLayerMaskArrays();
clearLayersStates();
return out;
}
//------------------------------------------------------
//Model methods:
@Override
public void fit() {
fit(inputs, labels, inputMaskArrays, labelMaskArrays);
}
@Override
public void update(INDArray gradient, String paramType) {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public void update(Gradient gradient) {
if (gradient.gradient().length() != numParams(true))
throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams(true));
for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
String key = entry.getKey();
INDArray val = entry.getValue();
int idx = key.lastIndexOf('_');
if (idx == -1)
throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + key + "\"");
String layerName = key.substring(0, idx);
String paramType = key.substring(idx + 1);  //Parameter type is everything after the final underscore
// Update graph gradient
this.gradient.gradientForVariable().put(key, val);
// Update layer params
getLayer(layerName).update(val, paramType);
}
// Update layerwise gradient view
setBackpropGradientsViewArray(gradient.gradient());
}
private void update(Task task) {
if (!initDone) {
initDone = true;
Heartbeat heartbeat = Heartbeat.getInstance();
task = ModelSerializer.taskByModel(this);
Environment env = EnvironmentUtils.buildEnvironment();
heartbeat.reportEvent(Event.STANDALONE, env, task);
}
}
@Override
public double score() {
return score;
}
public void setScore(double score) {
this.score = score;
}
@Override
public INDArray params() {
if(flattenedParams == null)
return Nd4j.zeros(DataType.FLOAT,0);
if(flattenedParams.rank() > 1 && !flattenedParams.wasClosed())
return flattenedParams.reshape(flattenedParams.length());
return flattenedParams;
}
@Override
public INDArray updaterState() {
return getUpdater() != null ? getUpdater().getUpdaterStateViewArray() : null;
}
@Override
public long numParams() {
return numParams(true);
}
@Override
public long numParams(boolean backwards) {
int nParams = 0;
for (Layer layer : layers) {
nParams += layer.numParams(backwards);
}
return nParams;
}
@Override
public void setParams(INDArray params) {
if (params == flattenedParams)
return; //No op
if (this.flattenedParams != null && this.flattenedParams.length() == params.length()) {
this.flattenedParams.assign(params);
return;
}
INDArray paramsViewReshape = params.reshape(params.length());
int idx = 0;
for (int i = 0; i < topologicalOrder.length; i++) {
if (!vertices[topologicalOrder[i]].hasLayer())
continue;
Layer layer = vertices[topologicalOrder[i]].getLayer();
long range = layer.numParams();
if (range <= 0)
continue; //Some layers: no parameters (subsampling etc)
INDArray get = paramsViewReshape.get(NDArrayIndex.interval(idx, range + idx));
layer.setParams(get);
idx += range;
}
}
@Override
public void setParamsViewArray(INDArray gradient) {
throw new UnsupportedOperationException("Not supported");
}
@Override
public INDArray getGradientsViewArray() {
return flattenedGradients;
}
@Override
public void setBackpropGradientsViewArray(INDArray gradient) {
INDArray gradientReshape = gradient.reshape(gradient.length());
int paramsSoFar = 0;
for (int i = 0; i < topologicalOrder.length; i++) {
if (!vertices[topologicalOrder[i]].hasLayer())
continue;
Layer layer = vertices[topologicalOrder[i]].getLayer();
long range = layer.numParams();
if (range <= 0)
continue; //Some layers: no parameters (subsampling etc)
layer.setBackpropGradientsViewArray(gradientReshape.get(
NDArrayIndex.interval(paramsSoFar, paramsSoFar + range)));
paramsSoFar += range;
}
}
@Override
public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr){
throw new UnsupportedOperationException("Cannot pretrain ComputationGraph with single INDArray");
}
@Override
public Gradient gradient() {
return gradient;
}
@Override
public Pair<Gradient, Double> gradientAndScore() {
return new Pair<>(gradient(), score());
}
@Override
public int batchSize() {
//In 99+% of cases, the input and labels dimension 0 size should be identical
//The only real exceptions: space to batch, and batch to space layers
//In those cases, we should base it on the labels size, as this impacts gradient calculation
return labels == null || labels[0] == null ? (int) inputs[0].size(0) : (int)labels[0].size(0);
}
@Override
public NeuralNetConfiguration conf() {
return defaultConfiguration;
}
@Override
public void setConf(NeuralNetConfiguration conf) {
throw new UnsupportedOperationException();
}
@Override
public INDArray input() {
if (numInputArrays == 1)
return (inputs != null ? inputs[0] : null);
else
throw new UnsupportedOperationException(
"Cannot return single input: ComputationGraph has multiple inputs");
}
@Override
public ConvexOptimizer getOptimizer() {
return solver.getOptimizer();
}
@Override
public INDArray getParam(String paramName) {
// throw new UnsupportedOperationException("Not implemented");
int idx = paramName.lastIndexOf('_');
if (idx == -1)
throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + paramName + "\"");
String layerName = paramName.substring(0, idx);
String paramType = paramName.substring(idx + 1);
return getLayer(layerName).getParam(paramType);
}
@Override
public Map<String, INDArray> paramTable() {
return paramTable(false);
}
public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
//Get all parameters from all layers/vertices
Map<String, INDArray> allParams = new LinkedHashMap<>();
for(GraphVertex gv : vertices){
Map<String, INDArray> paramMap = gv.paramTable(backpropParamsOnly);
for (Map.Entry<String, INDArray> entry : paramMap.entrySet()) {
String newKey = gv.getVertexName() + "_" + entry.getKey();
allParams.put(newKey, entry.getValue());
}
}
return allParams;
}
@Override
public void setParamTable(@NonNull Map<String, INDArray> paramTable) {
Map<String, INDArray> m = paramTable();
Preconditions.checkArgument(paramTable.keySet().equals(m.keySet()), "Cannot set param table: parameter set keys are not equal");
Map<String, INDArray> current = paramTable();
//Check shapes before doing partial assigment to avoid leaving net in incorrect state
for(String s : current.keySet()){
INDArray arrCurrent = current.get(s);
INDArray arrNew = paramTable.get(s);
val shapeCurrent = arrCurrent.shape();
val shapeNew = arrNew.shape();
Preconditions.checkState(Arrays.equals(shapeCurrent, shapeNew), "Cannot set parameters: shape array for " +
"parameter \"%s\" does not match existing shape: parameter shape = %s, new param shape = %s", s, shapeCurrent, shapeNew);
}
for(String s : current.keySet()) {
INDArray arrCurrent = current.get(s);
INDArray arrNew = paramTable.get(s);
arrCurrent.assign(arrNew);
}
}
@Override
public void setParam(String key, INDArray val) {
// throw new UnsupportedOperationException("Not implemented");
int idx = key.lastIndexOf('_');
if (idx == -1)
throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + key + "\"");
String layerName = key.substring(0, idx);
String paramType = key.substring(idx + 1);
getLayer(layerName).setParam(paramType, val);
}
@Override
public void clear() {
inputs = null;
labels = null;
inputMaskArrays = null;
labelMaskArrays = null;
}
@Override
public void applyConstraints(int iteration, int epoch) {
for(Layer l : layers){
l.applyConstraints(iteration, epoch);
}
}
//------------------------------------------------------------------------------
//RNN-specific functionality
/**
* If this ComputationGraph contains one or more RNN layers: conduct forward pass (prediction)
* but using previous stored state for any RNN layers. The activations for the final step are
* also stored in the RNN layers for use next time rnnTimeStep() is called.
* This method can be used to generate output one or more steps at a time instead of always having to do
* forward pass from t=0. Example uses are for streaming data, and for generating samples from network output
* one step at a time (where samples are then fed back into the network as input)
* If no previous state is present in RNN layers (i.e., initially or after calling rnnClearPreviousState()),
* the default initialization (usually 0) is used.
* Supports mini-batch (i.e., multiple predictions/forward pass in parallel) as well as for single examples.
*
* @param inputs Input to network. May be for one or multiple time steps. For single time step:
* input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]. miniBatchSize=1 for single example.
* For multiple time steps: [miniBatchSize,inputSize,inputTimeSeriesLength]
* @return Output activations. If output is RNN layer (such as RnnOutputLayer): if all inputs have shape [miniBatchSize,inputSize]
* i.e., is 2d, then outputs have shape [miniBatchSize,outputSize] (i.e., also 2d) instead of [miniBatchSize,outputSize,1].
* Otherwise output is 3d [miniBatchSize,outputSize,inputTimeSeriesLength] when using RnnOutputLayer (or unmodified otherwise).
*/
public INDArray[] rnnTimeStep(INDArray... inputs) {
return rnnTimeStepHelper(null, inputs);
}
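/*
* Usage sketch (illustrative only; not part of the original source). Shows step-by-step generation with
* rnnTimeStep, assuming a single-input/single-output recurrent graph "net", an initial input "firstStep",
* and a hypothetical sampleFrom(...) helper that turns output activations into the next input.
*
*   net.rnnClearPreviousState();                    //Start from the default (zero) state
*   INDArray[] out = net.rnnTimeStep(firstStep);    //Internal RNN state is stored after each call
*   for (int t = 0; t < stepsToGenerate; t++) {
*       INDArray nextIn = sampleFrom(out[0]);
*       out = net.rnnTimeStep(nextIn);
*   }
*/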
/**
* See {@link #rnnTimeStep(INDArray...)} for details.
* If no memory workspace is provided, the output will be detached (not in any workspace).
* If a memory workspace is provided, the output activation array (i.e., the INDArray returned by this method)
* will be placed in the specified workspace. This workspace must be opened by the user before calling this method -
* and the user is responsible for (a) closing this workspace, and (b) ensuring the output array is not used out
* of scope (i.e., not used after closing the workspace to which it belongs - as this is likely to cause either
* an exception when used, or a crash).
*
* @param inputs Input activations
* @param outputWorkspace Output workspace. May be null
* @return The output/activations from the network (either detached or in the specified workspace if provided)
*/
public INDArray[] rnnTimeStep(MemoryWorkspace outputWorkspace, INDArray... inputs){
try{
return rnnTimeStepHelper(outputWorkspace, inputs);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
private INDArray[] rnnTimeStepHelper(MemoryWorkspace outputWs, INDArray... inputs){
boolean inputIs2d = true;
for (INDArray i : inputs) {
if (i.rank() != 2) {
inputIs2d = false;
break;
}
}
INDArray[] outputs = outputOfLayersDetached(false, FwdPassType.RNN_TIMESTEP, getOutputLayerIndices(), inputs, null, null, true, false, outputWs);
//As per MultiLayerNetwork.rnnTimeStep(): if inputs are all 2d, then outputs are all 2d
if (inputIs2d) {
for (int i = 0; i < outputs.length; i++) {
if (outputs[i].rank() == 3 && outputs[i].size(2) == 1) {
//Return 2d output with shape [miniBatchSize,nOut]
// instead of 3d output with shape [miniBatchSize,nOut,1]
outputs[i] = outputs[i].tensorAlongDimension(0, 1, 0);
}
}
}
this.inputs = null;
return outputs;
}
/**
* Get the state of the RNN layer, as used in {@link #rnnTimeStep(INDArray...)}.
*
* @param layer Number/index of the layer.
* @return Hidden state, or null if layer is not an RNN layer
*/
public Map<String, INDArray> rnnGetPreviousState(int layer) {
return rnnGetPreviousState(layers[layer].conf().getLayer().getLayerName());
}
/**
* Get the state of the RNN layer, as used in {@link #rnnTimeStep(INDArray...)}.
*
* @param layerName name of the layer
* @return Hidden state, or null if layer is not an RNN layer
*/
public Map<String, INDArray> rnnGetPreviousState(String layerName) {
Layer l = verticesMap.get(layerName).getLayer();
if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer){
l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
}
if (l == null || !(l instanceof RecurrentLayer))
return null;
return ((RecurrentLayer) l).rnnGetPreviousState();
}
/**
* Get a map of states for ALL RNN layers, as used in {@link #rnnTimeStep(INDArray...)}.
* Layers that are not RNN layers will not have an entry in the returned map
*
* @return Map of states (keyed by layer name), for all RNN layers in the network
* @see #rnnSetPreviousStates(Map)
*/
public Map<String, Map<String, INDArray>> rnnGetPreviousStates() {
Map<String, Map<String, INDArray>> states = new HashMap<>();
for (Layer l : layers) {
if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer){
l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
}
if (l instanceof RecurrentLayer) {
states.put(l.conf().getLayer().getLayerName(), ((RecurrentLayer) l).rnnGetPreviousState());
}
}
return states;
}
/**
* Set the state of the RNN layer, for use in {@link #rnnTimeStep(INDArray...)}
*
* @param layer The number/index of the layer.
* @param state The state to set the specified layer to
*/
public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
rnnSetPreviousState(layers[layer].conf().getLayer().getLayerName(), state);
}
/**
* Set the state of the RNN layer, for use in {@link #rnnTimeStep(INDArray...)}
*
* @param layerName The name of the layer.
* @param state The state to set the specified layer to
*/
public void rnnSetPreviousState(String layerName, Map<String, INDArray> state) {
Layer l = verticesMap.get(layerName).getLayer();
if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer){
l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying();
}
if (l == null || !(l instanceof RecurrentLayer)) {
throw new UnsupportedOperationException(
"Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state");
}
((RecurrentLayer) l).rnnSetPreviousState(state);
}
/**
* Set the states for all RNN layers, for use in {@link #rnnTimeStep(INDArray...)}
*
* @param previousStates The previous time step states for all layers (key: layer name. Value: layer states)
* @see #rnnGetPreviousStates()
*/
public void rnnSetPreviousStates(Map<String, Map<String, INDArray>> previousStates) {
for (Map.Entry<String, Map<String, INDArray>> entry : previousStates.entrySet()) {
rnnSetPreviousState(entry.getKey(), entry.getValue());
}
}
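/*
* Usage sketch (illustrative only; not part of the original source). Shows snapshotting and restoring the RNN
* state around a speculative forward pass, assuming a recurrent graph "net" and an input array "in".
*
*   Map<String, Map<String, INDArray>> saved = net.rnnGetPreviousStates();   //Snapshot state of all RNN layers
*   INDArray[] speculative = net.rnnTimeStep(in);                            //Modifies the stored state
*   net.rnnSetPreviousStates(saved);                                         //Roll the state back
*/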
/**
* Clear the previous state of the RNN layers (if any), used in {@link #rnnTimeStep(INDArray...)}
*/
public void rnnClearPreviousState() {
if (layers == null)
return;
for (Layer layer : layers) {
if (layer instanceof RecurrentLayer)
((RecurrentLayer) layer).rnnClearPreviousState();
else if (layer instanceof MultiLayerNetwork) {
((MultiLayerNetwork) layer).rnnClearPreviousState();
}
}
}
/**
* Fit the network using truncated BPTT
*/
protected void doTruncatedBPTT(INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks,
INDArray[] labelMasks, LayerWorkspaceMgr workspaceMgr) {
if (flattenedGradients == null) {
initGradientsView();
}
//Approach used here to implement truncated BPTT: if input is 3d, split it. Otherwise: input is unmodified
long timeSeriesLength = -1;
for (INDArray in : inputs) {
if (in.rank() != 3)
continue;
if (timeSeriesLength == -1)
timeSeriesLength = in.size(2);
else if (timeSeriesLength != in.size(2)) {
log.warn("Cannot do TBPTT with time series of different lengths");
return;
}
}
for (INDArray out : labels) {
if (out.rank() != 3)
continue;
if (timeSeriesLength == -1)
timeSeriesLength = out.size(2);
else if (timeSeriesLength != out.size(2)) {
log.warn("Cannot do TBPTT with time series of different lengths");
return;
}
}
long fwdLen = configuration.getTbpttFwdLength();
long nSubsets = timeSeriesLength / fwdLen;
if (timeSeriesLength % fwdLen != 0)
nSubsets++;
rnnClearPreviousState();
for (int i = 0; i < nSubsets; i++) {
long startTimeIdx = i * fwdLen;
long endTimeIdx = startTimeIdx + fwdLen;
if (endTimeIdx > timeSeriesLength)
endTimeIdx = timeSeriesLength;
if (startTimeIdx > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
List<INDArray[]> list = getSubsetsForTbptt((int) startTimeIdx, endTimeIdx, inputs, labels, featureMasks, labelMasks);
setInputs(list.get(0));
setLabels(list.get(1));
setLayerMaskArrays(list.get(2), list.get(3));
if (solver == null) {
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this)
.build();
}
}
solver.optimize(workspaceMgr);
//Finally, update the state of the RNN layers:
rnnUpdateStateWithTBPTTState();
}
if(clearTbpttState) {
rnnClearPreviousState();
}
clearLayerMaskArrays();
}
private List<INDArray[]> getSubsetsForTbptt(int startTimeIdx, long endTimeIdx, INDArray[] inputs, INDArray[] labels,
INDArray[] featureMasks, INDArray[] labelMasks){
INDArray[] newInputs = new INDArray[inputs.length];
INDArray[] newLabels = new INDArray[labels.length];
INDArray[] newFeatureMasks = (featureMasks != null ? new INDArray[featureMasks.length] : null);
INDArray[] newLabelMasks = (labelMasks != null ? new INDArray[labelMasks.length] : null);
for (int j = 0; j < inputs.length; j++) {
if (inputs[j].rank() != 3)
newInputs[j] = inputs[j];
else {
newInputs[j] = inputs[j].get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
}
}
for (int j = 0; j < labels.length; j++) {
if (labels[j].rank() != 3)
newLabels[j] = labels[j];
else {
newLabels[j] = labels[j].get(NDArrayIndex.all(), NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
}
}
if (featureMasks != null) {
for (int j = 0; j < featureMasks.length; j++) {
if (featureMasks[j] == null)
continue;
newFeatureMasks[j] = featureMasks[j].get(NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
}
}
if (labelMasks != null) {
for (int j = 0; j < labelMasks.length; j++) {
if (labelMasks[j] == null)
continue;
newLabelMasks[j] = labelMasks[j].get(NDArrayIndex.all(),
NDArrayIndex.interval(startTimeIdx, endTimeIdx));
}
}
return Arrays.asList(newInputs, newLabels, newFeatureMasks, newLabelMasks);
}
/**
* Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:
* (a) like rnnTimeStep does forward pass using stored state for RNN layers, and
* (b) unlike rnnTimeStep does not modify the RNN layer state
* Therefore multiple calls to this method with the same input should have the same output.
* Typically used during training only. Use rnnTimeStep for prediction/forward pass at test time.
*
* @param inputs Input to network
* @param training Whether training or not
* @param storeLastForTBPTT set to true if used as part of truncated BPTT training
* @return Activations for each layer (including input, as per feedforward() etc)
*/
public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputs, boolean training,
boolean storeLastForTBPTT) {
return ffToLayerActivationsDetached(training, FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, storeLastForTBPTT, vertices.length-1,
null, inputs, inputMaskArrays, labelMaskArrays, true);
}
/**
* Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
* and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
* within the same minibatch.
* For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
* [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength]
* and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding -
* at a given time step).
* NOTE: This method is not usually used directly. Instead, the various feedForward and fit methods handle setting
* of masking internally.
*
* @param featureMaskArrays Mask array for features (input)
* @param labelMaskArrays Mask array for labels (output)
* @see #clearLayerMaskArrays()
*/
public void setLayerMaskArrays(INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
this.clearLayerMaskArrays();
this.inputMaskArrays = featureMaskArrays;
this.labelMaskArrays = labelMaskArrays;
if (featureMaskArrays != null) {
if (featureMaskArrays.length != numInputArrays) {
throw new IllegalArgumentException("Invalid number of feature mask arrays");
}
long minibatchSize = -1;
for (INDArray i : featureMaskArrays) {
if (i != null) {
minibatchSize = i.size(0);
}
}
//Here: need to do forward pass through the network according to the topological ordering of the network
Map<Integer, Pair<INDArray, MaskState>> map = new HashMap<>();
for (int i = 0; i < topologicalOrder.length; i++) {
GraphVertex current = vertices[topologicalOrder[i]];
if (current.isInputVertex()) {
INDArray fMask = featureMaskArrays[current.getVertexIndex()];
map.put(current.getVertexIndex(), new Pair<>(fMask, MaskState.Active));
} else {
VertexIndices[] inputVertices = current.getInputVertices();
//Now: work out the mask arrays to feed forward...
INDArray[] inputMasks = null; //new INDArray[inputVertices.length];
MaskState maskState = null;
for (int j = 0; j < inputVertices.length; j++) {
Pair<INDArray, MaskState> p = map.get(inputVertices[j].getVertexIndex());
if (p != null) {
if (inputMasks == null) {
inputMasks = new INDArray[inputVertices.length];
}
inputMasks[j] = p.getFirst();
if (maskState == null || maskState == MaskState.Passthrough) {
maskState = p.getSecond();
}
}
}
if (minibatchSize > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
Pair<INDArray, MaskState> outPair =
current.feedForwardMaskArrays(inputMasks, maskState, (int)minibatchSize);
map.put(topologicalOrder[i], outPair);
}
}
}
if (labelMaskArrays != null) {
if (labelMaskArrays.length != numOutputArrays) {
throw new IllegalArgumentException("Invalid number of label mask arrays");
}
for (int i = 0; i < labelMaskArrays.length; i++) {
if (labelMaskArrays[i] == null) {
// This output doesn't have a mask, we can skip it.
continue;
}
String outputName = configuration.getNetworkOutputs().get(i);
GraphVertex v = verticesMap.get(outputName);
Layer ol = v.getLayer();
ol.setMaskArray(labelMaskArrays[i]);
}
}
}
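/*
* Usage sketch (illustrative only; not part of the original source). Shows manually masking a padded time series
* minibatch, assuming features of shape [miniBatchSize, nIn, timeSeriesLength] and masks of shape
* [miniBatchSize, timeSeriesLength] (1.0 where real data is present, 0.0 for padding). Note that the fit/output
* methods that accept mask arrays normally set these internally.
*
*   net.setLayerMaskArrays(new INDArray[]{featureMask}, new INDArray[]{labelMask});
*   INDArray[] out = net.output(features);
*   net.clearLayerMaskArrays();
*/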
/**
* Remove the mask arrays from all layers.
* See {@link #setLayerMaskArrays(INDArray[], INDArray[])} for details on mask arrays.
*/
public void clearLayerMaskArrays() {
for (Layer layer : layers) {
layer.setMaskArray(null);
}
this.inputMaskArrays = null;
this.labelMaskArrays = null;
}
/**
* Update the internal state of RNN layers after a truncated BPTT fit call
*/
protected void rnnUpdateStateWithTBPTTState() {
for (int i = 0; i < layers.length; i++) {
if (layers[i] instanceof RecurrentLayer) {
RecurrentLayer l = ((RecurrentLayer) layers[i]);
l.rnnSetPreviousState(l.rnnGetTBPTTState());
} else if (layers[i] instanceof MultiLayerNetwork) {
((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
}
}
}
/**
* Evaluate the network (classification performance - single output ComputationGraphs only)
*
* @param iterator Iterator to evaluate on
* @return Evaluation object; results of evaluation on all examples in the data set
*/
public <T extends Evaluation> T evaluate(DataSetIterator iterator) {
return (T)evaluate(iterator, (List<String>)null);
}
/**
* Evaluate the network (classification performance - single output ComputationGraphs only)
*
* @param iterator Iterator to evaluate on
* @return Evaluation object; results of evaluation on all examples in the data set
*/
public <T extends Evaluation> T evaluate(MultiDataSetIterator iterator) {
return evaluate(iterator, (List<String>)null);
}
/**
* Evaluate the network on the provided data set (single output ComputationGraphs only). Used for evaluating
* the performance of classifiers
*
* @param iterator Data to undertake evaluation on
* @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
*/
public <T extends Evaluation> T evaluate(DataSetIterator iterator, List<String> labelsList) {
return evaluate(iterator, labelsList, 1);
}
/**
* Evaluate the network on the provided data set (single output ComputationGraphs only). Used for evaluating
* the performance of classifiers
*
* @param iterator Data to undertake evaluation on
* @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
*/
public <T extends Evaluation> T evaluate(MultiDataSetIterator iterator, List<String> labelsList) {
return evaluate(iterator, labelsList, 1);
}
/**
* Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
* For 'standard' accuracy evaluation only, use topN = 1
*
* @param iterator Iterator (data) to evaluate on
* @param labelsList List of labels. May be null.
* @param topN N value for top N accuracy evaluation
* @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
*/
public <T extends Evaluation> T evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
if (labelsList == null)
labelsList = iterator.getLabels();
Layer outputLayer = getOutputLayer(0);
if(getConfiguration().isValidateOutputLayerConfig()){
OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class);
}
return (T)doEvaluation(iterator, new org.deeplearning4j.eval.Evaluation(labelsList, topN))[0];
}
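/*
* Usage sketch (illustrative only; not part of the original source). Shows top-5 accuracy evaluation for a
* single-output classifier "net" on a DataSetIterator "testIter".
*
*   Evaluation eval = net.evaluate(testIter, null, 5);   //null labels list: use the iterator's labels, if any
*   System.out.println(eval.stats());                    //Reports top-1 and top-N accuracy
*/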
/**
* Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
* For 'standard' accuracy evaluation only, use topN = 1
*
* @param iterator Iterator (data) to evaluate on
* @param labelsList List of labels. May be null.
* @param topN N value for top N accuracy evaluation
* @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
*/
public <T extends Evaluation> T evaluate(MultiDataSetIterator iterator, List<String> labelsList, int topN) {
Layer outputLayer = getOutputLayer(0);
if(getConfiguration().isValidateOutputLayerConfig()){
OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class);
}
return (T)doEvaluation(iterator, new org.deeplearning4j.eval.Evaluation(labelsList, topN))[0];
}
/**
* Evaluate the (single output layer only) network for regression performance
*
* @param iterator Data to evaluate on
* @return Regression evaluation
*/
public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator) {
return evaluateRegression(iterator, null);
}
/**
* Evaluate the (single output layer only) network for regression performance
*
* @param iterator Data to evaluate on
* @return Regression evaluation
*/
public <T extends RegressionEvaluation> T evaluateRegression(MultiDataSetIterator iterator) {
return evaluateRegression(iterator, null);
}
/**
* Evaluate the (single output layer only) network for regression performance
*
* @param iterator Data to evaluate on
* @param columnNames Column names for the regression evaluation. May be null.
* @return Regression evaluation
*/
public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator iterator, List<String> columnNames) {
return (T)doEvaluation(iterator, new org.deeplearning4j.eval.RegressionEvaluation(columnNames))[0];
}
/**
* Evaluate the (single output layer only) network for regression performance
*
* @param iterator Data to evaluate on
* @return Regression evaluation
*/
public <T extends RegressionEvaluation> T evaluateRegression(MultiDataSetIterator iterator, List<String> columnNames) {
return (T)doEvaluation(iterator, new org.deeplearning4j.eval.RegressionEvaluation(columnNames))[0];
}
/**
* @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
*/
@Deprecated
public <T extends ROC> T evaluateROC(DataSetIterator iterator) {
return evaluateROC(iterator, 0);
}
/**
* Evaluate the network (must be a binary classifier) on the specified data, using the {@link ROC} class
*
* @param iterator Data to evaluate on
* @param rocThresholdSteps Number of threshold steps to use with {@link ROC}
* @return ROC evaluation on the given dataset
*/
public <T extends ROC> T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) {
Layer outputLayer = getOutputLayer(0);
if(getConfiguration().isValidateOutputLayerConfig()){
OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class);
}
return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0];
}
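/*
* Usage sketch (illustrative only; not part of the original source). Shows ROC/AUC evaluation for a binary
* classifier "net" on a DataSetIterator "testIter".
*
*   ROC roc = net.evaluateROC(testIter, 0);   //0 threshold steps: exact (non-thresholded) ROC calculation
*   double auc = roc.calculateAUC();
*/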
/**
* @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
*/
@Deprecated
public <T extends ROC> T evaluateROC(MultiDataSetIterator iterator) {
return evaluateROC(iterator, 0);
}
/**
* Evaluate the network (must be a binary classifier) on the specified data, using the {@link ROC} class
*
* @param iterator Data to evaluate on
* @param rocThresholdSteps Number of threshold steps to use with {@link ROC}
* @return ROC evaluation on the given dataset
*/
public <T extends ROC> T evaluateROC(MultiDataSetIterator iterator, int rocThresholdSteps) {
Layer outputLayer = getOutputLayer(0);
if(getConfiguration().isValidateOutputLayerConfig()){
OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class);
}
return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0];
}
/**
* @deprecated To be removed - use {@link #evaluateROCMultiClass(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration
*/
@Deprecated
public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator iterator) {
return evaluateROCMultiClass(iterator, 0);
}
/**
* Evaluate the network on the specified data, using the {@link ROCMultiClass} class
*
* @param iterator Data to evaluate on
* @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass}
* @return Multi-class ROC evaluation on the given dataset
*/
public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) {
Layer outputLayer = getOutputLayer(0);
if(getConfiguration().isValidateOutputLayerConfig()){
OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class);
}
return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0];
}
/**
* Evaluate the network on the specified data, using the {@link ROCMultiClass} class
*
* @param iterator Data to evaluate on
* @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass}
* @return Multi-class ROC evaluation on the given dataset
*/
public <T extends ROCMultiClass> T evaluateROCMultiClass(MultiDataSetIterator iterator, int rocThresholdSteps) {
Layer outputLayer = getOutputLayer(0);
if(getConfiguration().isValidateOutputLayerConfig()){
OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class);
}
return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0];
}
/**
* Perform evaluation on the given data (DataSetIterator) with the given {@link IEvaluation} instance
*
* @param iterator Test data to evaluate on
* @param evaluations IEvaluation instances
* @param <T> Type of the IEvaluation instance
* @return The input IEvaluation instance, after performing evaluation on the test data
*/
public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations) {
return doEvaluation(new MultiDataSetIteratorAdapter(iterator), evaluations);
}
/**
* Perform evaluation on the given data (MultiDataSetIterator) with the given {@link IEvaluation} instance
*
* @param iterator Test data to evaluate on
* @param evaluations IEvaluation instances
* @param <T> Type of the IEvaluation instance
* @return The input IEvaluation instance, after performing evaluation on the test data
*/
public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations) {
try{
return doEvaluationHelper(iterator, evaluations);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
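/*
* Usage sketch (illustrative only; not part of the original source). Shows collecting several evaluation types in
* a single pass over the data, assuming a single-output classifier "net" and a MultiDataSetIterator "testIter".
*
*   Evaluation accuracy = new Evaluation();
*   ROCMultiClass roc = new ROCMultiClass(0);
*   net.doEvaluation(testIter, accuracy, roc);   //One iteration over the data updates both evaluations
*/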
/**
* Perform evaluation for networks with multiple outputs.
*
* @param iterator Data to evaluate
* @param evaluations Evaluation instances. Key: the network output number (0 to numOutputs-1). Value: the IEvaluation
* instances to perform evaluation with, for that output only. Note that not every output needs to
* have an IEvaluation[] defined.
* @return The same evaluation map, after performing evaluation
*/
public Map<Integer, IEvaluation[]> evaluate(DataSetIterator iterator, Map<Integer, IEvaluation[]> evaluations){
return evaluate(new MultiDataSetIteratorAdapter(iterator), evaluations);
}
/**
* Perform evaluation for networks with multiple outputs.
*
* @param iterator Data to evaluate
* @param evaluations Evaluation instances. Key: the network output number (0 to numOutputs-1). Value: the IEvaluation
* instances to perform evaluation with, for that output only. Note that not every output needs to
* have an IEvaluation[] defined.
* @return The same evaluation map, after performing evaluation
*/
public Map<Integer, IEvaluation[]> evaluate(MultiDataSetIterator iterator, Map<Integer, IEvaluation[]> evaluations){
try{
return doEvaluationHelper(iterator, evaluations);
} catch (OutOfMemoryError e){
CrashReportingUtil.writeMemoryCrashDump(this, e);
throw e;
}
}
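/*
* Usage sketch (illustrative only; not part of the original source). Shows per-output evaluation for a
* hypothetical two-output graph "net" where output 0 is a classifier and output 1 is a regression head,
* evaluated on a MultiDataSetIterator "testIter".
*
*   Map<Integer, IEvaluation[]> evals = new HashMap<>();
*   evals.put(0, new IEvaluation[]{new Evaluation()});
*   evals.put(1, new IEvaluation[]{new RegressionEvaluation()});
*   Map<Integer, IEvaluation[]> results = net.evaluate(testIter, evals);
*/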
@SuppressWarnings("unchecked")
@SafeVarargs
private final <T extends IEvaluation> T[] doEvaluationHelper(MultiDataSetIterator iterator, T... evaluations) {
Map<Integer, IEvaluation[]> map = Collections.singletonMap(0, (IEvaluation[])evaluations);
return (T[])doEvaluationHelper(iterator, map).get(0);
}
private Map<Integer, IEvaluation[]> doEvaluationHelper(MultiDataSetIterator iterator, Map<Integer, IEvaluation[]> evaluations){
if (layers == null || !(getOutputLayer(0) instanceof IOutputLayer)) {
throw new IllegalStateException("Cannot evaluate network with no output layer");
}
WorkspaceUtils.assertNoWorkspacesOpen("Expected no external workspaces open at start of evaluation (doEvaluationHelper)");
if (iterator.resetSupported() && !iterator.hasNext())
iterator.reset();
MultiDataSetIterator iter =
iterator.asyncSupported() ? new AsyncMultiDataSetIterator(iterator, 2, true) : iterator;
WorkspaceMode cMode = configuration.getTrainingWorkspaceMode();
configuration.setTrainingWorkspaceMode(configuration.getInferenceWorkspaceMode());
boolean useRnnSegments = (configuration.getBackpropType() == BackpropType.TruncatedBPTT);
MemoryWorkspace outputWs;
if(getConfiguration().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED){
outputWs = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM);
} else {
outputWs = new DummyWorkspace();
}
while (iter.hasNext()) {
MultiDataSet next = iter.next();
if (next.getFeatures() == null || next.getLabels() == null)
continue;
if (!useRnnSegments) {
//Standard/non-RNN case
//Assuming single output here
INDArray[] features = next.getFeatures();
INDArray[] featuresMasks = next.getFeaturesMaskArrays();
INDArray[] labels = next.getLabels();
INDArray[] labelMasks = next.getLabelsMaskArrays();
List<Serializable> meta = next.getExampleMetaData();
try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
INDArray[] out = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featuresMasks, labelMasks, true, false, ws);
for (Integer i : evaluations.keySet()) {
Preconditions.checkState(i >= 0 && i < getNumOutputArrays(), "Invalid output index: indices for outputs " +
"must be between 0 and %s inclusive - found index %s", numOutputArrays, (int) i);
IEvaluation[] evalsThisOutput = evaluations.get(i);
if (evalsThisOutput == null)
continue;
INDArray currOut = out[i];
INDArray currLabel = labels[i];
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
for (IEvaluation evaluation : evalsThisOutput)
evaluation.eval(currLabel, currOut, next.getLabelsMaskArray(i), meta);
}
}
}
} else {
rnnClearPreviousState();
int fwdLen = configuration.getTbpttFwdLength();
long tsLength = -1;
long nF = next.getFeatures().length;
for (int i = 0; i < nF; i++) {
if (next.getFeatures(i).rank() == 3) {
tsLength = next.getFeatures(i).size(2);
}
}
if (tsLength < 0) {
throw new IllegalStateException("Invalid configuration: detected TBPTT backprop type without" +
" time series features");
}
long nSubsets = tsLength / fwdLen;
if (tsLength % fwdLen != 0)
nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20)
for (int i = 0; i < nSubsets; i++) {
int startTimeIdx = i * fwdLen;
long endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength);
List<INDArray[]> subset = getSubsetsForTbptt(startTimeIdx, endTimeIdx, next.getFeatures(),
next.getLabels(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
setLayerMaskArrays(subset.get(2), subset.get(3));
try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
INDArray[] outSub = rnnTimeStep(ws, subset.get(0));
for (Integer idx : evaluations.keySet()) {
IEvaluation[] evalsThisOutput = evaluations.get(idx);
if (evalsThisOutput == null)
continue;
INDArray labelSub = (subset.get(1) == null ? null : subset.get(1)[idx]);
INDArray maskSub = subset.get(3) == null ? null : subset.get(3)[idx];
INDArray currOut = outSub[idx];
try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
for (IEvaluation evaluation : evalsThisOutput)
evaluation.eval(labelSub, currOut, maskSub);
}
}
}
}
rnnClearPreviousState();
}
//Clear inputs, masks etc. Important to avoid leaking invalidated/out of scope arrays between iterations
clearLayersStates();
}
if (iterator.asyncSupported())
((AsyncMultiDataSetIterator) iter).shutdown();
configuration.setTrainingWorkspaceMode(cMode);
return (Map<Integer, IEvaluation[]>) evaluations;
}
/**
* String detailing the architecture of the computation graph.
* Vertices are printed in a topological sort order.
* Columns are Vertex Names with layer/vertex type, nIn, nOut, Total number of parameters and the Shapes of the parameters
* And the inputs to the vertex
* Will also give information about frozen layers/vertices, if any.
*
* @return Summary as a string
* @see #memoryInfo(int, InputType...)
*/
public String summary() {
return summary((InputType[])null);
}
/**
* String detailing the architecture of the computation graph.
* Will also display activation size when given an input type.
* Vertices are printed in a topological sort order.
* Columns are Vertex Names with layer/vertex type, nIn, nOut, Total number of parameters and the Shapes of the parameters
* And the inputs to the vertex
* Will also give information about frozen layers/vertices, if any.
*
* @return Summary as a string
* @see #memoryInfo(int, InputType...)
*/
public String summary(InputType... inputTypes) {
StringBuilder ret = new StringBuilder();
ret.append("\n");
int frozenParams = 0;
Map<String, InputType> vertexOutputs = new HashMap<>(); //vertex name and output types
int currLayerIdx = -1;
List<String[]> lines = new ArrayList<>();
if(inputTypes == null){
lines.add(new String[]{"VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs"});
} else {
lines.add(new String[]{"VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs", "InputShape", "OutputShape"});
}
int[] maxLength = new int[inputTypes == null || inputTypes.length == 0 ? 5 : 7];
String[] header = lines.get(0);
for( int i = 0; i < header.length; i++ ){
maxLength[i] = header[i].length();
}
for (int currVertexIdx : topologicalOrder) {
GraphVertex currentVertex = vertices[currVertexIdx];
String currentVertexName = currentVertex.getVertexName();
//String fields for the current summary row
String[] classNameArr = currentVertex.getClass().toString().split("\\.");
String className = classNameArr[classNameArr.length - 1];
String connections = "-";
String inShape = "-";
String outShape = "-";
String paramCount = "-";
String in = "-";
String out = "-";
String paramShape = "-";
if (currentVertex.isInputVertex()) {
if (inputTypes != null) vertexOutputs.put(currentVertexName, inputTypes[configuration.getNetworkInputs().indexOf(currentVertexName)]); //For input vertices, the output type is just the corresponding input type
} else {
connections = configuration.getVertexInputs().get(currentVertexName).toString();
List<InputType> inputTypeList = new ArrayList<>();
if (currentVertex.hasLayer()) {
Layer currentLayer = ((LayerVertex) currentVertex).getLayer();
classNameArr = currentLayer.getClass().getName().split("\\.");
className = classNameArr[classNameArr.length - 1];
paramCount = String.format("%,d", currentLayer.numParams());
//layer with params
if (currentLayer.numParams() > 0) {
paramShape = "";
if (currentLayer instanceof BidirectionalLayer) { // Bidirectional layer is not an FFL
BidirectionalLayer bi = (BidirectionalLayer) currentLayer;
in = String.valueOf(((Bidirectional)bi.conf().getLayer()).getNIn());
out = String.valueOf(((Bidirectional)bi.conf().getLayer()).getNOut());
} else {
try {
in = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn());
out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut());
}
catch (Exception e) { // Some layers, like PReLU, are just BaseLayers (but have parameters)
}
}
List<String> paraNames = currentLayer.conf().variables();
for (String aP : paraNames) {
String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape());
paramShape += aP + ":" + paramS + ", ";
}
paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString();
}
//frozen layer
if (currentLayer instanceof FrozenLayer) {
frozenParams += currentLayer.numParams();
classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
className = "Frozen " + classNameArr[classNameArr.length - 1];
}
if (inputTypes != null) {
//get input type
String inputVertexName = vertices[currentVertex.getInputVertices()[0].getVertexIndex()].getVertexName();
InputType currentInType = vertexOutputs.get(inputVertexName);
inShape = currentInType.toString();
inputTypeList.add(currentInType);
InputPreProcessor layerVertexPreProcesor = ((org.deeplearning4j.nn.conf.graph.LayerVertex)configuration.getVertices().get(currentVertexName)).getPreProcessor();
if (layerVertexPreProcesor != null) {
inShape += "-->" + layerVertexPreProcesor.getOutputType(currentInType);
}
}
currLayerIdx++;
} else {
//get input type
if (inputTypes != null) {
VertexIndices[] inputVertices = currentVertex.getInputVertices();
if (inputVertices != null) {
for (int i = 0; i < inputVertices.length; i++) {
GraphVertex thisInputVertex = vertices[inputVertices[i].getVertexIndex()];
inputTypeList.add(vertexOutputs.get(thisInputVertex.getVertexName()));
}
}
}
}
if (inputTypes != null) {
InputType currentVertexOutputType = configuration.getVertices().get(currentVertexName).getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
outShape = currentVertexOutputType.toString();
vertexOutputs.put(currentVertexName, currentVertexOutputType);
}
}
//Add on to summary string
String[] line;
if (inputTypes == null) {
line = new String[]{currentVertexName + " (" + className + ")", in + "," + out, paramCount, paramShape, connections};
} else {
line = new String[]{currentVertexName + " (" + className + ")", in + "," + out, paramCount, paramShape, connections, inShape, outShape};
}
for( int i = 0; i < line.length; i++) {
maxLength[i] = Math.max(maxLength[i], line[i] == null ? 0 : line[i].length());
}
lines.add(line);
}
StringBuilder sbFormat = new StringBuilder();
int totalLength = 0;
int pos = 0;
for(int length : maxLength){
int currLength;
if(pos++ == maxLength.length-1){
currLength = length;
} else {
currLength = length+3;
}
sbFormat.append("%-").append(currLength).append("s");
totalLength += currLength;
}
sbFormat.append("\n");
String format = sbFormat.toString();
ret.append(StringUtils.repeat("=", totalLength))
.append("\n");
boolean first = true;
for(String[] line : lines){
String formatted = String.format(format, (Object[])line);
ret.append(formatted);
if(first){
ret.append(StringUtils.repeat("=", totalLength)).append("\n");
first = false;
}
}
ret.append(StringUtils.repeat("-", totalLength))
.append(String.format("\n%30s %,d", "Total Parameters: ", params().length()))
.append(String.format("\n%30s %,d", "Trainable Parameters: ", params().length() - frozenParams))
.append(String.format("\n%30s %,d", "Frozen Parameters: ", frozenParams))
.append("\n")
.append(StringUtils.repeat("=", totalLength))
.append("\n");
return ret.toString();
}
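/*
* Usage sketch (illustrative only; not part of the original source). Shows printing the architecture summary with
* per-vertex activation shapes, assuming a graph "net" with a single feed-forward input of size 784.
*
*   System.out.println(net.summary(InputType.feedForward(784)));
*/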
/**
* Generate information regarding memory use for the network, for the given input types and minibatch size.
* Note that when using workspaces or CuDNN, the network should be trained for some iterations so that the memory
* workspaces have time to initialize. Without this, the memory requirements during training may be underestimated.
*
* Note also that this is the same information that is generated during an OOM crash when training or performing
* inference.
*
* @param minibatch Minibatch size to estimate memory for
* @param inputTypes Input types to the network
* @return A String with information about network memory use information
*/
public String memoryInfo(int minibatch, InputType... inputTypes){
return CrashReportingUtil.generateMemoryStatus(this, minibatch, inputTypes);
}
/**
* This method just makes sure there's no state preserved within layers
*/
public void clearLayersStates() {
for (Layer layer : layers) {
layer.clear();
layer.clearNoiseWeightParams();
}
for (GraphVertex vertex : vertices) {
vertex.clearVertex();
}
}
/**
* Increment the epoch count (in the underlying {@link ComputationGraphConfiguration} by 1).
* Note that this is done automatically when using iterator-based fitting methods, such as
* {@link #fit(DataSetIterator)} or {@link #fit(MultiDataSet)}. However, when using non-iterator fit methods
* (DataSet, MultiDataSet, INDArrays etc), the network has no way to know when one epoch ends and another starts.
* In such situations, this method can be used to increment the epoch counter.
* Note that the epoch counter is used for situations such as some learning rate schedules, and the like.
*
* The current epoch count can be obtained using {@code ComputationGraph.getConfiguration().getEpochCount()}
*/
public void incrementEpochCount(){
configuration.setEpochCount(configuration.getEpochCount() + 1);
synchronizeIterEpochCounts();
}
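/*
* Usage sketch (illustrative only; not part of the original source). Shows manual epoch counting when fitting
* directly on MultiDataSet objects rather than an iterator, assuming a List<MultiDataSet> "batches".
*
*   for (int epoch = 0; epoch < numEpochs; epoch++) {
*       for (MultiDataSet mds : batches) {
*           net.fit(mds);                 //fit(MultiDataSet) does not advance the epoch count
*       }
*       net.incrementEpochCount();        //Keep epoch-based LR schedules etc in sync
*   }
*/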
protected void synchronizeIterEpochCounts(){
//TODO: this is necessary for some schedules - but the redundant values are a little ugly...
int currIter = getConfiguration().getIterationCount();
int currEpoch = getConfiguration().getEpochCount();
for(Layer l : layers){
l.setIterationCount(currIter);
l.setEpochCount(currEpoch);
}
}
/**
* Returns the number of iterations (parameter updates) that the ComputationGraph has done
* @return Number of iterations
*/
public int getIterationCount(){
return configuration.getIterationCount();
}
/**
* Returns the number of epochs that the ComputationGraph has done.
* Note that the epoch count is incremented only when {@link #fit(DataSetIterator)}, {@link #fit(MultiDataSetIterator)},
* {@link #fit(DataSetIterator, int)} or {@link #fit(MultiDataSetIterator, int)} are used.
* The epoch count can also be manually incremented using {@link #incrementEpochCount()}
* @return Number of epochs
*/
public int getEpochCount(){
return configuration.getEpochCount();
}
/**
* Save the ComputationGraph to a file. Restore using {@link #load(File, boolean)}.
* Note that this saves the updater (i.e., the state array for momentum/Adam/rmsprop etc), which is desirable
* if further training will be undertaken.
*
* @param f File to save the network to
* @see ModelSerializer ModelSerializer for more details (and saving/loading via streams)
* @see #save(File, boolean)
*/
public void save( File f ) throws IOException {
save(f, true);
}
/**
* Save the ComputationGraph to a file. Restore using {@link #load(File, boolean)}.
*
* @param f File to save the network to
* @param saveUpdater If true: save the updater (i.e., the state array for momentum/Adam/rmsprop etc), which should
* usually be saved if further training is required
* @see ModelSerializer ModelSerializer for more details (and saving/loading via streams)
* @see #save(File, boolean)
*/
public void save(File f, boolean saveUpdater) throws IOException{
ModelSerializer.writeModel(this, f, saveUpdater);
}
/**
* Restore a ComputationGraph to a file, saved using {@link #save(File)} or {@link ModelSerializer}
* @param f File to load the network from
* @param loadUpdater If true: load the updater if it is available (i.e., the state array for momentum/Adam/rmsprop
* etc) - use false if no further training is required, or true if further training
* will be undertaken
* @see ModelSerializer ModelSerializer for more details (and saving/loading via streams)
*/
public static ComputationGraph load(File f, boolean loadUpdater) throws IOException {
return ModelSerializer.restoreComputationGraph(f, loadUpdater);
}
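/*
* Usage sketch (illustrative only; not part of the original source). Shows a save/load round trip that keeps the
* updater state so training can be resumed; the file name is arbitrary.
*
*   net.save(new File("my-graph.zip"), true);
*   ComputationGraph restored = ComputationGraph.load(new File("my-graph.zip"), true);
*/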
/**
* Return a copy of the network with the parameters and activations set to use the specified (floating point) data type.
* If the existing datatype is the same as the requested datatype, the original network will be returned unchanged.
* Only floating point datatypes (DOUBLE, FLOAT, HALF) may be used.
*
* @param dataType Datatype to convert the network to
* @return The network, set to use the specified datatype for the parameters and activations
*/
public ComputationGraph convertDataType(@NonNull DataType dataType){
Preconditions.checkState(dataType.isFPType(), "Invalid DataType: %s. Can only convert network to a floating point type", dataType);
if(dataType == params().dataType()){
return this;
}
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
INDArray newParams = params().castTo(dataType);
String jsonConfig = getConfiguration().toJson();
ComputationGraphConfiguration newConf = ComputationGraphConfiguration.fromJson(jsonConfig);
newConf.setDataType(dataType);
ComputationGraph newNet = new ComputationGraph(newConf);
newNet.init(newParams, false);
Updater u = getUpdater(false);
if(u != null && u.getStateViewArray() != null){
INDArray oldUpdaterState = u.getStateViewArray();
newNet.getUpdater(true).getStateViewArray().assign(oldUpdaterState);
}
return newNet;
}
}
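/*
* Usage sketch (illustrative only; not part of the original source). Shows converting a trained network to HALF
* precision for lower-memory inference; the original network is unchanged and a converted copy is returned.
*
*   ComputationGraph halfPrecisionCopy = net.convertDataType(DataType.HALF);
*/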
/**
* Set the learning rate for all layers in the network to the specified value. Note that if any learning rate
* schedules are currently present, these will be removed in favor of the new (fixed) learning rate.
*
* Note: This method is not free from a performance point of view: a proper learning rate schedule
* should be used in preference to calling this method at every iteration.
*
* @param newLr New learning rate for all layers
* @see #setLearningRate(ISchedule)
* @see #setLearningRate(String, double)
*/
public void setLearningRate(double newLr) {
NetworkUtils.setLearningRate(this, newLr);
}
/**
* Set the learning rate schedule for all layers in the network to the specified schedule.
* This schedule will replace any/all existing schedules, and also any fixed learning rate values.
* Note that the iteration/epoch counts will not be reset. Use {@link ComputationGraphConfiguration#setIterationCount(int)}
* and {@link ComputationGraphConfiguration#setEpochCount(int)} if this is required
*
* @param newLr New learning rate schedule for all layers
* @see #setLearningRate(ISchedule)
* @see #setLearningRate(String, double)
*/
public void setLearningRate(ISchedule newLr) {
NetworkUtils.setLearningRate(this, newLr);
}
/**
* Set the learning rate for a single layer in the network to the specified value. Note that if any learning rate
* schedules are currently present, these will be removed in favor of the new (fixed) learning rate.
*
* Note: This method is not free from a performance point of view: a proper learning rate schedule
* should be used in preference to calling this method at every iteration. Note also that
* {@link #setLearningRate(double)} should also be used in preference, when all layers need to be set to a new LR
*
* @param layerName Name of the layer to set the LR for
* @param newLr New learning rate for a single layer
* @see #setLearningRate(ISchedule)
* @see #setLearningRate(String, double)
*/
public void setLearningRate(String layerName, double newLr) {
NetworkUtils.setLearningRate(this, layerName, newLr);
}
/**
* Set the learning rate schedule for a single layer in the network to the specified value.
* Note also that {@link #setLearningRate(ISchedule)} should also be used in preference, when all layers need
* to be set to a new LR schedule.
* This schedule will replace any/all existing schedules, and also any fixed learning rate values.
* Note also that the iteration/epoch counts will not be reset. Use {@link ComputationGraphConfiguration#setIterationCount(int)}
* and {@link ComputationGraphConfiguration#setEpochCount(int)} if this is required
*
* @param layerName Name of the layer to set the LR schedule for
* @param newLr New learning rate for a single layer
* @see #setLearningRate(ISchedule)
* @see #setLearningRate(String, double)
*/
public void setLearningRate(String layerName, ISchedule newLr) {
NetworkUtils.setLearningRate(this, layerName, newLr);
}
/**
* Get the current learning rate, for the specified layer, from the network.
* Note: If the layer has no learning rate (no parameters, or an updater without a learning rate) then null is returned
* @param layerName Layer name
* @return Learning rate for the specified layer, or null
*/
public Double getLearningRate(String layerName){
return NetworkUtils.getLearningRate(this, layerName);
}
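/*
* Usage sketch (illustrative only; not part of the original source). Shows lowering the learning rate mid-training,
* globally and for a single, hypothetical layer name "dense0".
*
*   net.setLearningRate(1e-4);                        //All layers; removes any existing LR schedules
*   net.setLearningRate("dense0", 1e-5);              //One layer only
*   Double current = net.getLearningRate("dense0");   //null if the layer has no learning rate
*/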
/**
* Return the layer size (number of units) for the specified layer.
* Note that the meaning of the "layer size" can depend on the type of layer. For example:
* - DenseLayer, OutputLayer, recurrent layers: number of units (nOut configuration option)
* - ConvolutionLayer: the channels (number of channels)
* - Subsampling layers, global pooling layers, etc: size of 0 is always returned
*
* @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive
* @return Size of the layer
*/
public long layerSize(int layer) {
if (layer < 0 || layer >= layers.length) {
throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and "
+ (layers.length - 1) + " inclusive");
}
return layerSize(layers[layer].conf().getLayer().getLayerName());
}
/**
* Return the input size (number of inputs) for the specified layer.
* Note that the meaning of the "input size" can depend on the type of layer. For example:
* - DenseLayer, OutputLayer, etc: the feature vector size (nIn configuration option)
* - Recurrent layers: the feature vector size per time step (nIn configuration option)
* - ConvolutionLayer: the channels (number of channels)
* - Subsampling layers, global pooling layers, etc: size of 0 is always returned
*
* @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive
* @return Size of the layer
*/
public long layerInputSize(int layer) {
if (layer < 0 || layer >= layers.length) {
throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and "
+ (layers.length - 1) + " inclusive");
}
return layerInputSize(layers[layer].conf().getLayer().getLayerName());
}
/**
* Return the layer size (number of units) for the specified layer.
* Note that the meaning of the "layer size" can depend on the type of layer. For example:
* - DenseLayer, OutputLayer, recurrent layers: number of units (nOut configuration option)
* - ConvolutionLayer: the channels (number of channels)
* - Subsampling layers, global pooling layers, etc: size of 0 is always returned
*
* @param layerName Name of the layer to get the size of
* @return Size of the layer
*/
public long layerSize(String layerName) {
Layer l = getLayer(layerName);
if(l == null){
throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");
}
org.deeplearning4j.nn.conf.layers.Layer conf = l.conf().getLayer();
if (conf == null || !(conf instanceof FeedForwardLayer)) {
return 0;
}
FeedForwardLayer ffl = (FeedForwardLayer) conf;
return ffl.getNOut();
}
/**
* Return the input size (number of inputs) for the specified layer.
* Note that the meaning of the "input size" can depend on the type of layer. For example:
* - DenseLayer, OutputLayer, etc: the feature vector size (nIn configuration option)
* - Recurrent layers: the feature vector size per time step (nIn configuration option)
* - ConvolutionLayer: the channels (number of channels)
* - Subsampling layers, global pooling layers, etc: size of 0 is always returned
*
* @param layerName Name of the layer to get the size of
* @return Size of the layer
*/
public long layerInputSize(String layerName) {
Layer l = getLayer(layerName);
if(l == null){
throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists");
}
org.deeplearning4j.nn.conf.layers.Layer conf = l.conf().getLayer();
if (conf == null || !(conf instanceof FeedForwardLayer)) {
return 0;
}
FeedForwardLayer ffl = (FeedForwardLayer) conf;
return ffl.getNIn();
}
/**
* Indicates whether some other object is "equal to" this one.
*
* The {@code equals} method implements an equivalence relation
* on non-null object references:
*
* - It is reflexive: for any non-null reference value
* {@code x}, {@code x.equals(x)} should return
* {@code true}.
*
* - It is symmetric: for any non-null reference values
* {@code x} and {@code y}, {@code x.equals(y)}
* should return {@code true} if and only if
* {@code y.equals(x)} returns {@code true}.
*
* - It is transitive: for any non-null reference values
* {@code x}, {@code y}, and {@code z}, if
* {@code x.equals(y)} returns {@code true} and
* {@code y.equals(z)} returns {@code true}, then
* {@code x.equals(z)} should return {@code true}.
*
* - It is consistent: for any non-null reference values
* {@code x} and {@code y}, multiple invocations of
* {@code x.equals(y)} consistently return {@code true}
* or consistently return {@code false}, provided no
* information used in {@code equals} comparisons on the
* objects is modified.
*
* - For any non-null reference value {@code x},
* {@code x.equals(null)} should return {@code false}.
*
*
* The {@code equals} method for class {@code Object} implements
* the most discriminating possible equivalence relation on objects;
* that is, for any non-null reference values {@code x} and
* {@code y}, this method returns {@code true} if and only
* if {@code x} and {@code y} refer to the same object
* ({@code x == y} has the value {@code true}).
*
* Note that it is generally necessary to override the {@code hashCode}
* method whenever this method is overridden, so as to maintain the
* general contract for the {@code hashCode} method, which states
* that equal objects must have equal hash codes.
*
* @param obj the reference object with which to compare.
* @return {@code true} if this object is the same as the obj
* argument; {@code false} otherwise.
* @see #hashCode()
* @see HashMap
*/
@Override
public boolean equals(Object obj) {
if (obj == null)
return false;
if (obj instanceof ComputationGraph) {
ComputationGraph network = (ComputationGraph) obj;
boolean paramsEquals = network.params().equals(params());
boolean confEquals = getConfiguration().equals(network.getConfiguration());
boolean updaterEquals = getUpdater().equals(network.getUpdater());
return paramsEquals && confEquals && updaterEquals;
}
return false;
}
private void writeObject(ObjectOutputStream oos) throws IOException {
ModelSerializer.writeModel(this, oos, true);
}
private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
val cg = ModelSerializer.restoreComputationGraph(ois, true);
this.defaultConfiguration = cg.defaultConfiguration.clone();
this.configuration = cg.configuration.clone();
this.init();
this.flattenedParams.assign(cg.flattenedParams);
if (cg.getUpdater() != null && cg.getUpdater(false).getStateViewArray() != null)
this.getUpdater(true).getStateViewArray().assign(cg.getUpdater(false).getStateViewArray());
}
/**
* Close the network and deallocate all native memory, including: parameters, gradients, updater memory and workspaces
* Note that the network should not be used again for any purpose after it has been closed
*/
@Override
public void close(){
//Close the INDArray and dealloc
if(flattenedParams.closeable())
flattenedParams.close();
if(flattenedGradients != null && flattenedGradients.closeable())
flattenedGradients.close();
Updater u = getUpdater(false);
if(u != null && u.getStateViewArray() != null) {
INDArray state = u.getStateViewArray();
if(state.closeable())
state.close();
}
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
System.gc();
}
}