/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf;
import lombok.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.OutputLayerUtil;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;
@Data
@EqualsAndHashCode(exclude = {"trainingWorkspaceMode", "inferenceWorkspaceMode", "cacheMode", "topologicalOrder", "topologicalOrderStr"})
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
public class ComputationGraphConfiguration implements Serializable, Cloneable {
private static final Logger log = LoggerFactory.getLogger(ComputationGraphConfiguration.class);
protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();
protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();
@Getter
@Setter
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
@Getter
@Setter
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
@Getter
@Setter
protected CacheMode cacheMode;
@Getter
@Setter
protected DataType dataType = DataType.FLOAT; //Default to float for 1.0.0-beta3 and earlier nets
protected boolean validateOutputLayerConfig = true; //Default for 1.0.0-beta3 and earlier nets
/**
* List of inputs to the network, by name
*/
protected List<String> networkInputs;
/**
* List of network outputs, by name
*/
protected List<String> networkOutputs;
protected BackpropType backpropType = BackpropType.Standard;
protected int tbpttFwdLength = 20;
protected int tbpttBackLength = 20;
protected NeuralNetConfiguration defaultConfiguration;
//Counter for the number of parameter updates so far
// This is important for learning rate schedules, for example, and is stored here to ensure it is persisted
// for Spark and model serialization
protected int iterationCount = 0;
//Counter for the number of epochs completed so far. Used for per-epoch schedules
protected int epochCount = 0;
protected int[] topologicalOrder;
protected List<String> topologicalOrderStr;
/**
* @return YAML representation of configuration
*/
public String toYaml() {
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
synchronized (mapper) {
try {
return mapper.writeValueAsString(this);
} catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}
/**
* Create a neural net configuration from YAML
*
* @param json YAML representation of the configuration
* @return {@link ComputationGraphConfiguration}
*/
public static ComputationGraphConfiguration fromYaml(String json) {
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
try {
return mapper.readValue(json, ComputationGraphConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* @return JSON representation of computation graph configuration
*/
public String toJson() {
//As per MultiLayerConfiguration.toJson()
ObjectMapper mapper = NeuralNetConfiguration.mapper();
synchronized (mapper) {
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
try {
return mapper.writeValueAsString(this);
} catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}
/**
* Create a computation graph configuration from json
*
* @param json JSON representation of the configuration
* @return {@link ComputationGraphConfiguration}
*/
public static ComputationGraphConfiguration fromJson(String json) {
//As per MultiLayerConfiguration.fromJson()
ObjectMapper mapper = NeuralNetConfiguration.mapper();
ComputationGraphConfiguration conf;
try {
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
try{
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
return JsonMappers.getLegacyMapper().readValue(json, ComputationGraphConfiguration.class);
} catch (InvalidTypeIdException e2){
//Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
//1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
String msg = e2.getMessage();
if(msg != null && msg.contains("Could not resolve type id")){
throw new RuntimeException("Error deserializing ComputationGraphConfiguration - configuration may have a custom " +
"layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" +
" layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e);
}
throw new RuntimeException(e2);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (Exception e) {
//Check if this exception came from legacy deserializer...
String msg = e.getMessage();
if(msg != null && msg.contains("legacy")){
throw new RuntimeException("Error deserializing ComputationGraphConfiguration - configuration may have a custom " +
"layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be " +
"deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", e);
}
throw new RuntimeException(e);
}
//To maintain backward compatibility after activation function refactoring (configs generated with v0.7.1 or earlier)
// Previously: enumeration used for activation functions. Now: use classes
Map<String, GraphVertex> vertexMap = conf.getVertices();
JsonNode vertices = null;
for (Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()) {
if (!(entry.getValue() instanceof LayerVertex)) {
continue;
}
LayerVertex lv = (LayerVertex) entry.getValue();
if (lv.getLayerConf() != null && lv.getLayerConf().getLayer() != null) {
Layer layer = lv.getLayerConf().getLayer();
if (layer instanceof BaseLayer && ((BaseLayer) layer).getActivationFn() == null) {
String layerName = layer.getLayerName();
try {
if (vertices == null) {
JsonNode jsonNode = mapper.readTree(json);
vertices = jsonNode.get("vertices");
}
JsonNode vertexNode = vertices.get(layerName);
JsonNode layerVertexNode = vertexNode.get("LayerVertex");
if (layerVertexNode == null || !layerVertexNode.has("layerConf")
|| !layerVertexNode.get("layerConf").has("layer")) {
continue;
}
JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
continue;
}
JsonNode layerNode = layerWrapperNode.elements().next();
JsonNode activationFunction = layerNode.get("activationFunction"); //Legacy format: activation function stored as a string, e.g. "relu"
if (activationFunction != null) {
IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
((BaseLayer) layer).setActivationFn(ia);
}
} catch (IOException e) {
log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
e);
}
}
handleLegacyWeightInitFromJson(json, layer, mapper, vertices);
}
}
return conf;
}
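/*
* Illustrative usage sketch for the JSON round trip above (the YAML methods behave analogously);
* assumes a configuration "conf" built elsewhere:
*
*   String asJson = conf.toJson();
*   ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(asJson);
*   //restored.equals(conf) should hold for standard (non-legacy) configurations
*/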
/**
* Handle {@link WeightInit} and {@link Distribution} from legacy configs in JSON format. Copied from the handling
* of {@link Activation} above. Note: the layer is modified in place; nothing is returned.
*/
private static void handleLegacyWeightInitFromJson(String json, Layer layer, ObjectMapper mapper, JsonNode vertices) {
if (layer instanceof BaseLayer && ((BaseLayer) layer).getWeightInitFn() == null) {
String layerName = layer.getLayerName();
try {
if (vertices == null) {
JsonNode jsonNode = mapper.readTree(json);
vertices = jsonNode.get("vertices");
}
JsonNode vertexNode = vertices.get(layerName);
JsonNode layerVertexNode = vertexNode.get("LayerVertex");
if (layerVertexNode == null || !layerVertexNode.has("layerConf")
|| !layerVertexNode.get("layerConf").has("layer")) {
return;
}
JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
return;
}
JsonNode layerNode = layerWrapperNode.elements().next();
JsonNode weightInit = layerNode.get("weightInit"); //Legacy format: weight init stored as an enum name, e.g. "XAVIER"
JsonNode distribution = layerNode.get("dist");
Distribution dist = null;
if(distribution != null) {
dist = mapper.treeToValue(distribution, Distribution.class);
}
if (weightInit != null) {
final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist);
((BaseLayer) layer).setWeightInitFn(wi);
}
} catch (IOException e) {
log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
e);
}
}
}
@Override
public String toString() {
return toJson();
}
@Override
public ComputationGraphConfiguration clone() {
ComputationGraphConfiguration conf = new ComputationGraphConfiguration();
conf.vertices = new LinkedHashMap<>();
for (Map.Entry<String, GraphVertex> entry : this.vertices.entrySet()) {
conf.vertices.put(entry.getKey(), entry.getValue().clone());
}
conf.vertexInputs = new LinkedHashMap<>();
for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
conf.vertexInputs.put(entry.getKey(), new ArrayList<>(entry.getValue()));
}
conf.networkInputs = new ArrayList<>(this.networkInputs);
conf.networkOutputs = new ArrayList<>(this.networkOutputs);
conf.backpropType = backpropType;
conf.tbpttFwdLength = tbpttFwdLength;
conf.tbpttBackLength = tbpttBackLength;
conf.defaultConfiguration = defaultConfiguration.clone();
conf.trainingWorkspaceMode = trainingWorkspaceMode;
conf.inferenceWorkspaceMode = inferenceWorkspaceMode;
conf.cacheMode = this.cacheMode;
conf.defaultConfiguration.cacheMode = this.cacheMode;
conf.validateOutputLayerConfig = this.validateOutputLayerConfig;
conf.dataType = this.dataType;
return conf;
}
/**
* Check the configuration, make sure it is valid
*
* @throws IllegalStateException if configuration is not valid
*/
public void validate() {
validate(false, false);
}
/**
* Check the configuration, make sure it is valid
*
* @param allowDisconnected If true: don't throw an exception on vertices that are 'disconnected'. A disconnected
*                          vertex is one that is not an output, and doesn't connect to any other vertices; i.e.,
*                          its output activations don't go anywhere
* @param allowNoOutput     If true: don't throw an exception when the network has no output vertices set
* @throws IllegalStateException if configuration is not valid
*/
public void validate(boolean allowDisconnected, boolean allowNoOutput){
if (networkInputs == null || networkInputs.isEmpty()) {
throw new IllegalStateException( "Invalid configuration: network has no inputs. " +
"Use .addInputs(String...) to label (and give an ordering to) the network inputs");
}
if ((networkOutputs == null || networkOutputs.isEmpty()) && !allowNoOutput) {
throw new IllegalStateException("Invalid configuration: network has no outputs. " +
"Use .setOutput(String...) to specify (and give an ordering to) the output vertices, " +
"or use allowNoOutputs(true) to disable this check");
}
//Check uniqueness of names for inputs, layers, GraphNodes
for (String s : networkInputs) {
if (vertices.containsKey(s)) {
throw new IllegalStateException("Invalid configuration: name \"" + s
+ "\" is present in both network inputs and graph vertices/layers");
}
}
//Check: each layer & node has at least one input
for (Map.Entry<String, List<String>> e : vertexInputs.entrySet()) {
String nodeName = e.getKey();
if (e.getValue() == null || e.getValue().isEmpty()) {
throw new IllegalStateException("Invalid configuration: vertex \"" + nodeName + "\" has no inputs");
}
for (String inputName : e.getValue()) {
if (!vertices.containsKey(inputName) && !networkInputs.contains(inputName)) {
throw new IllegalStateException("Invalid configuration: Vertex \"" + nodeName + "\" has input \""
+ inputName + "\" that does not exist");
}
}
}
//Check output names:
if(networkOutputs != null) {
for (String s : networkOutputs) {
if (!vertices.containsKey(s)) {
throw new IllegalStateException(
"Invalid configuration: Output name \"" + s + "\" is not a valid vertex");
}
}
}
//Check that there aren't any disconnected vertices
if(!allowDisconnected){
//A vertex is considered disconnected if it is (a) not an output vertex, and (b) isn't used as an input
// to another layer
Set<String> seenAsInput = new HashSet<>();
seenAsInput.addAll(networkOutputs);
for(Map.Entry<String, List<String>> e : vertexInputs.entrySet()){
seenAsInput.addAll(e.getValue());
}
Set<String> disconnected = new HashSet<>();
disconnected.addAll(networkInputs);
disconnected.addAll(vertices.keySet());
disconnected.removeAll(seenAsInput);
if(!disconnected.isEmpty() && !allowNoOutput){ //If allowing no output: by definition we have disconnected vertices
throw new IllegalStateException("Invalid configuration: disconnected vertices found - " + disconnected
+ ". Disconnected vertices are those that do not connect to either another vertex, and are also"
+ " not a network output. This vertex can be set as an output using setOutputs(String...). "
+ "To disable this error (i.e., allow network configurations with" +
" disconnected vertices) use GraphBuilder.allowDisconnected(true)");
}
}
//Check for no graph cycles: done in ComputationGraph.init()
}
/**
* Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
* {@link InputType} class, in the same order in which the inputs were defined in the original configuration.
* For example, in a network with two inputs: a convolutional input (28x28x1 images) and feed forward inputs, use
* {@code .addPreProcessors(InputType.convolutional(28,28,1),InputType.feedForward())}.
* For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be added automatically.
* NOTE: This method will be called automatically when using the
* {@link GraphBuilder#setInputTypes(InputType...)} functionality.
* See that method for details.
*/
public void addPreProcessors(InputType... inputTypes) {
getLayerActivationTypes(true, inputTypes);
}
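/*
* Illustrative usage sketch, assuming a graph "conf" whose inputs were added in the order
* ("cnnIn", "ffIn"); the names and shapes below are hypothetical:
*
*   conf.addPreProcessors(InputType.convolutional(28, 28, 1), InputType.feedForward(100));
*   //CNN->Dense and CNN->RNN transitions now have preprocessors and layer nIn values set automatically
*/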
/**
* Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
* {@link InputType} class, in the same order in which the inputs were defined in the original configuration.
* For example, in a network with two inputs: a convolutional input (28x28x1 images) and feed forward inputs, use
* {@code .addPreProcessors(InputType.convolutional(28,28,1),InputType.feedForward())}.
* For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be added automatically.
* NOTE: This method will be called automatically when using the
* {@link GraphBuilder#setInputTypes(InputType...)} functionality.
* See that method for details.
* @param addPreprocIfNecessary whether to add any required preprocessors while calculating activation types
* @param forceOverrideInputs whether to forcibly override layer nIn values or not
* when setting up pre-processing
* @param inputTypes the input types to set
*/
public void addPreProcessors(boolean addPreprocIfNecessary,boolean forceOverrideInputs,InputType... inputTypes) {
getLayerActivationTypes(addPreprocIfNecessary,forceOverrideInputs, inputTypes);
}
/**
* Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
* {@link InputType} class, in the same order in which the inputs were defined in the original configuration.
* For example, in a network with two inputs: a convolutional input (28x28x1 images) and feed forward inputs, use
* {@code .addPreProcessors(InputType.convolutional(28,28,1),InputType.feedForward())}.
* For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be added automatically.
* NOTE: This method will be called automatically when using the
* {@link GraphBuilder#setInputTypes(InputType...)} functionality.
* See that method for details.
* @param forceOverrideInputs whether to forcibly override layer nIn values or not
* when setting up pre-processing
* @param inputTypes the input types to set
*/
public void addPreProcessors(boolean forceOverrideInputs,InputType... inputTypes) {
getLayerActivationTypes(true,forceOverrideInputs, inputTypes);
}
/**
* For the given input shape/type for the network, return a map of activation sizes for each layer and vertex
* in the graph. Note that this method will automatically add preprocessors if required, to handle (for example)
* the transition between CNN and dense layers.
* @param inputTypes Input types for the network
* @return A map of activation types for the graph (key: vertex name. value: type of activations out of that vertex)
*/
public Map<String, InputType> getLayerActivationTypes(InputType... inputTypes) {
return getLayerActivationTypes(true, inputTypes);
}
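/*
* Illustrative usage sketch, assuming a graph "conf" with a single feed-forward input and a vertex
* named "dense1" (both hypothetical):
*
*   Map<String, InputType> actTypes = conf.getLayerActivationTypes(InputType.feedForward(784));
*   InputType dense1Out = actTypes.get("dense1"); //Type/shape of the activations out of vertex "dense1"
*/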
/**
* For the given input shape/type for the network, return a map of activation sizes for each layer and vertex
* in the graph. Note that this method can also add preprocessors if required (to handle transitions between some
* layer types such as convolutional -> dense, for example)
* @param addPreprocIfNecessary If true: add any required preprocessors, in the process of calculating the layer
* activation sizes
* @param overrideInputs whether to forcibly override layer nIn values when
* setting inputs
* @param inputTypes Input types for the network
* @return A map of activation types for the graph (key: vertex name. value: type of activations out of that vertex)
*/
public Map<String, InputType> getLayerActivationTypes(boolean addPreprocIfNecessary,boolean overrideInputs, InputType... inputTypes) {
if (inputTypes == null || inputTypes.length != networkInputs.size()) {
throw new IllegalArgumentException(
"Invalid number of InputTypes: cannot add preprocessors if number of InputType "
+ "objects differs from number of network inputs");
}
//Now: need to do essentially a forward pass through the network, to work out what type of preprocessors to add
//To do this: need to know what the output types are for each GraphVertex.
//Do topological sort
List<String> topologicalOrdering = topologicalOrdering();
//Now, given the topological sort: do equivalent of forward pass
Map<String, InputType> vertexOutputs = new LinkedHashMap<>();
int currLayerIdx = -1;
for (String s : topologicalOrdering) {
int inputIdx = networkInputs.indexOf(s);
if (inputIdx != -1) {
vertexOutputs.put(s, inputTypes[inputIdx]);
continue;
}
GraphVertex gv = vertices.get(s);
List<InputType> inputTypeList = new ArrayList<>();
if (gv instanceof LayerVertex) {
//Add preprocessor, if necessary:
String in = vertexInputs.get(s).get(0);
InputType layerInput = vertexOutputs.get(in);
inputTypeList.add(layerInput);
LayerVertex lv = (LayerVertex) gv;
Layer l = lv.getLayerConf().getLayer();
//Preprocessors - add if necessary
if (lv.getPreProcessor() == null) {
//But don't override preprocessors that are manually defined; if none has been defined,
//add the appropriate preprocessor for this input type/layer combination
InputPreProcessor preproc = l.getPreProcessorForInputType(layerInput);
lv.setPreProcessor(preproc);
}
//Set nIn value for layer (if not already set)
InputType afterPreproc = layerInput;
if (lv.getPreProcessor() != null && addPreprocIfNecessary) {
InputPreProcessor ip = lv.getPreProcessor();
afterPreproc = ip.getOutputType(layerInput);
}
l.setNIn(afterPreproc, overrideInputs);
currLayerIdx++;
} else {
List<String> inputs = vertexInputs.get(s);
if (inputs != null) {
for (String inputVertexName : inputs) {
inputTypeList.add(vertexOutputs.get(inputVertexName));
}
}
}
InputType outputFromVertex =
gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
vertexOutputs.put(s, outputFromVertex);
}
return vertexOutputs;
}
/**
* For the given input shape/type for the network, return a map of activation sizes for each layer and vertex
* in the graph. Note that this method can also add preprocessors if required (to handle transitions between some
* layer types such as convolutional -> dense, for example)
* @param addPreprocIfNecessary If true: add any required preprocessors, in the process of calculating the layer
* activation sizes
* @param inputTypes Input types for the network
* @return A map of activation types for the graph (key: vertex name. value: type of activations out of that vertex)
*/
public Map<String, InputType> getLayerActivationTypes(boolean addPreprocIfNecessary, InputType... inputTypes) {
return getLayerActivationTypes(addPreprocIfNecessary,true,inputTypes);
}
private Map<String, List<String>> verticesOutputTo() {
Map<String, List<String>> verticesOutputTo = new HashMap<>(); //Key: vertex. Values: vertices that this node is an input for
for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
String vertexName = entry.getKey();
List<String> vertexInputNames;
vertexInputNames = vertexInputs.get(vertexName);
if (vertexInputNames == null)
continue;
//Build reverse network structure:
for (String s : vertexInputNames) {
List<String> list = verticesOutputTo.get(s);
if (list == null) {
list = new ArrayList<>();
verticesOutputTo.put(s, list);
}
list.add(vertexName); //Edge: s -> vertexName
}
}
return verticesOutputTo;
}
private List<String> topologicalOrdering() {
//First step: build network in reverse order (i.e., define map of a -> list(b) instead of list(a) -> b)
Map<String, List<String>> verticesOutputTo = verticesOutputTo();
LinkedList<String> noIncomingEdges = new LinkedList<>(networkInputs); //Queue of all vertices with no remaining incoming edges
List<String> topologicalOrdering = new ArrayList<>();
Map<String, Set<String>> inputEdges = new HashMap<>();
for (Map.Entry<String, List<String>> entry : vertexInputs.entrySet()) {
inputEdges.put(entry.getKey(), new HashSet<>(entry.getValue()));
}
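//Kahn's algorithm: repeatedly take a vertex with no remaining incoming edges and append it to the
// ordering; removing its outgoing edges may leave other vertices with no incoming edges, making
// them available for processing next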
while (!noIncomingEdges.isEmpty()) {
String next = noIncomingEdges.removeFirst();
topologicalOrdering.add(next);
//Remove edges next -> verticesOutputTo[...] from graph
List<String> nextEdges = verticesOutputTo.get(next);
if (nextEdges != null && !nextEdges.isEmpty()) {
for (String s : nextEdges) {
Set<String> set = inputEdges.get(s);
set.remove(next);
if (set.isEmpty()) {
noIncomingEdges.add(s); //No remaining edges for vertex i -> add to list for processing
}
}
}
}
//If any edges remain in the graph: graph has cycles:
for (Map.Entry<String, Set<String>> entry : inputEdges.entrySet()) {
Set<String> set = entry.getValue();
if (set == null)
continue;
if (!set.isEmpty())
throw new IllegalStateException(
"Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle ("
+ "cycle includes vertex \"" + entry.getKey() + "\")");
}
return topologicalOrdering;
}
/**
* Get a {@link MemoryReport} for the given computation graph configuration. This is used to estimate the
* memory requirements for the given network configuration and input
*
* @param inputTypes Input types for the network
* @return Memory report for the network
*/
public NetworkMemoryReport getMemoryReport(InputType... inputTypes) {
Map<String, MemoryReport> memoryReportMap = new LinkedHashMap<>();
List<String> topologicalOrdering = topologicalOrdering();
Map<String, InputType> vertexOutputs = new HashMap<>();
int currLayerIdx = -1;
for (String s : topologicalOrdering) {
int inputIdx = networkInputs.indexOf(s);
if (inputIdx != -1) {
vertexOutputs.put(s, inputTypes[inputIdx]);
continue;
}
GraphVertex gv = vertices.get(s);
List<InputType> inputTypeList = new ArrayList<>();
if (gv instanceof LayerVertex) {
//Add preprocessor, if necessary:
String in = vertexInputs.get(s).get(0);
InputType layerInput = vertexOutputs.get(in);
inputTypeList.add(layerInput);
currLayerIdx++;
} else {
List<String> inputs = vertexInputs.get(s);
if (inputs != null) {
for (String inputVertexName : inputs) {
inputTypeList.add(vertexOutputs.get(inputVertexName));
}
}
}
InputType outputFromVertex =
gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
vertexOutputs.put(s, outputFromVertex);
MemoryReport mr = gv.getMemoryReport(inputTypeList.toArray(new InputType[inputTypeList.size()]));
memoryReportMap.put(s, mr);
}
return new NetworkMemoryReport(memoryReportMap, ComputationGraphConfiguration.class, "ComputationGraph",
inputTypes);
}
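/*
* Illustrative usage sketch, assuming a graph "conf" with one convolutional input (hypothetical shape):
*
*   NetworkMemoryReport report = conf.getMemoryReport(InputType.convolutional(28, 28, 1));
*   System.out.println(report); //Per-vertex and overall memory estimates for the network
*/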
@Data
public static class GraphBuilder {
private static final int DEFAULT_TBPTT_LENGTH = 20;
protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();
/**
* Key: graph node. Values: input to that node
*/
protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();
protected List<String> networkInputs = new ArrayList<>();
protected List<InputType> networkInputTypes = new ArrayList<>();
protected List<String> networkOutputs = new ArrayList<>();
protected BackpropType backpropType = BackpropType.Standard;
protected int tbpttFwdLength = DEFAULT_TBPTT_LENGTH;
protected int tbpttBackLength = DEFAULT_TBPTT_LENGTH;
protected Map<String, InputPreProcessor> inputPreProcessors = new LinkedHashMap<>();
protected NeuralNetConfiguration.Builder globalConfiguration;
protected boolean allowDisconnected = false;
protected boolean allowNoOutput = false;
protected boolean validateOutputConfig = true;
protected boolean validateTbpttConfig = true;
protected String lastAdded = null;
public GraphBuilder(NeuralNetConfiguration.Builder globalConfiguration) {
this.globalConfiguration = globalConfiguration;
}
public GraphBuilder(ComputationGraphConfiguration newConf, NeuralNetConfiguration.Builder globalConfiguration) {
ComputationGraphConfiguration clonedConf = newConf.clone();
this.vertices = clonedConf.getVertices();
this.vertexInputs = clonedConf.getVertexInputs();
this.networkInputs = clonedConf.getNetworkInputs();
this.networkOutputs = clonedConf.getNetworkOutputs();
this.backpropType = clonedConf.getBackpropType();
this.tbpttFwdLength = clonedConf.getTbpttFwdLength();
this.tbpttBackLength = clonedConf.getTbpttBackLength();
this.globalConfiguration = globalConfiguration;
//this.getGlobalConfiguration().setSeed(clonedConf.getDefaultConfiguration().getSeed());
}
/**
* Specify the preprocessor for a given layer.
* Preprocessors are used at each layer for tasks such as normalization and reshaping of the input.
* Note: preprocessors can also be defined using the {@link #addLayer(String, Layer, InputPreProcessor, String...)} method.
*
* @param layer the name of the layer that this preprocessor will be used with
* @param processor the preprocessor to use for the specified layer
*/
public GraphBuilder inputPreProcessor(String layer, InputPreProcessor processor) {
inputPreProcessors.put(layer, processor);
return this;
}
/**
* The type of backprop. Default setting is used for most networks (MLP, CNN etc),
* but optionally truncated BPTT can be used for training recurrent neural networks.
* If using TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength()
*
* @param type Type of backprop. Default: BackpropType.Standard
*/
public GraphBuilder backpropType(BackpropType type) {
this.backpropType = type;
return this;
}
/**
* When doing truncated BPTT: how many steps of forward pass should we do
* before doing (truncated) backprop?
* Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
* Typically the tBPTTForwardLength parameter is the same as the tBPTTBackwardLength parameter,
* but it may be larger in some circumstances (never smaller)
* Ideally your training data time series length should be divisible by this
* This is the k1 parameter on pg23 of
* http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
*
* @param forwardLength Forward length > 0, >= backwardLength
*/
public GraphBuilder tBPTTForwardLength(int forwardLength) {
this.tbpttFwdLength = forwardLength;
return this;
}
/**
* When doing truncated BPTT: how many steps of backward should we do?
* Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
* This is the k2 parameter on pg23 of
* http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
*
* @param backwardLength <= forwardLength
*/
public GraphBuilder tBPTTBackwardLength(int backwardLength) {
this.tbpttBackLength = backwardLength;
return this;
}
/**
* When doing truncated backpropagation through time (tBPTT): how many steps should we do?
* Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
* See: http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
*
* @param tbpttLength length > 0
*/
public GraphBuilder tBPTTLength(int tbpttLength){
tBPTTForwardLength(tbpttLength);
return tBPTTBackwardLength(tbpttLength);
}
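/*
* Illustrative usage sketch for truncated BPTT configuration (the length of 50 is hypothetical):
*
*   graphBuilder.backpropType(BackpropType.TruncatedBPTT)
*               .tBPTTLength(50); //Sets both the forward and backward TBPTT lengths to 50
*/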
/**
* Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* or a combination of the two.
* @see #addLayer(String, Layer, InputPreProcessor, String...)
*/
public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) {
return addLayer(layerName, layer, null, layerInputs);
}
/**
* Add a layer, with no {@link InputPreProcessor}, with the specified name
* and input from the last added layer/vertex.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @see #addLayer(String, Layer, InputPreProcessor, String...)
*/
public GraphBuilder appendLayer(String layerName, Layer layer) {
return appendLayer(layerName, layer, null);
}
/**
* Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
*
* @param layerName Index used as the name/label of the layer to add
* @param layer The layer configuration
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* or a combination of the two.
* @see #addLayer(String, Layer, InputPreProcessor, String...)
*/
public GraphBuilder layer(int layerName, Layer layer, String... layerInputs) {
return addLayer(String.valueOf(layerName), layer, null, layerInputs);
}
/**
* Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* or a combination of the two.
* @see #addLayer(String, Layer, InputPreProcessor, String...)
*/
public GraphBuilder layer(String layerName, Layer layer, String... layerInputs) {
return addLayer(layerName, layer, null, layerInputs);
}
/**
* Add a layer and an {@link InputPreProcessor}, with the specified name and specified inputs.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param preProcessor The InputPreProcessor to use with this layer.
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* or a combination of the two.
*/
public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor,
String... layerInputs) {
NeuralNetConfiguration.Builder builder = globalConfiguration.clone();
builder.layer(layer);
addVertex(layerName, new LayerVertex(builder.build(), preProcessor), layerInputs);
layer.setLayerName(layerName);
return this;
}
/**
* Add a layer and an {@link InputPreProcessor}, with the specified name
* and input from the last added layer/vertex.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param preProcessor The InputPreProcessor to use with this layer.
*/
public GraphBuilder appendLayer(String layerName, Layer layer, InputPreProcessor preProcessor) {
if(lastAdded == null){
throw new IllegalStateException("Can not use appendLayer with no previous layers");
}
addLayer(layerName, layer, preProcessor, lastAdded);
return this;
}
/**
* Add a layer and an {@link InputPreProcessor}, with the specified name and specified inputs.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param preProcessor The InputPreProcessor to use with this layer.
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* or a combination of the two.
*/
public GraphBuilder layer(String layerName, Layer layer, InputPreProcessor preProcessor,
String... layerInputs) {
return addLayer(layerName, layer, preProcessor, layerInputs);
}
/**
* Intended for use with the transfer learning API. Users are discouraged from employing it directly.
* Removes the specified vertex from the vertices list, along with its connections and associated preprocessor.
* If the removed vertex is an output vertex, it will also be removed from the list of outputs.
* @param vertexName Name of the vertex to remove
*/
public GraphBuilder removeVertex(String vertexName) {
removeVertex(vertexName, true);
return this;
}
/**
* Intended for use with the transfer learning API. Users are discouraged from employing it directly.
* Removes the specified vertex from the vertices list.
* If "removeConnections" is true, its connections are also removed (along with the associated preprocessor, and, if it is an output, it is removed from the list of outputs).
* Specifying false can leave the graph in an invalid state, with references to vertices that do not exist, unless a new vertex is added back in with the same name
* @param removeConnections Specify true to remove connections
* @param vertexName Name of the vertex to remove
*/
public GraphBuilder removeVertex(String vertexName, boolean removeConnections) {
vertices.remove(vertexName);
vertexInputs.remove(vertexName);
if (networkInputs.contains(vertexName)) {
networkInputs.remove(vertexName);
}
if (removeConnections) {
if (networkOutputs.contains(vertexName)) {
networkOutputs.remove(vertexName);
}
Map<String, List<String>> newVertexInputs = new LinkedHashMap<>();
for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
List<String> inputs = entry.getValue();
if (inputs.contains(vertexName)) {
//Some lists are not modifiable. So we'll make a new copy, minus the one to be removed
List<String> newList = new ArrayList<>(inputs.size()-1);
for(String s : inputs){
if(!vertexName.equals(s)){
newList.add(s);
}
}
newVertexInputs.put(entry.getKey(), newList);
} else {
newVertexInputs.put(entry.getKey(), entry.getValue());
}
}
this.vertexInputs = newVertexInputs;
if (inputPreProcessors.containsKey(vertexName)) {
inputPreProcessors.remove(vertexName);
}
}
return this;
}
/**
* Specify the inputs to the network, and their associated labels.
*
* @param inputNames The names of the inputs. This also defines their order
*/
public GraphBuilder addInputs(String... inputNames) {
Collections.addAll(networkInputs, inputNames);
lastAdded = networkInputs.get(networkInputs.size() - 1);
return this;
}
/**
* Specify the inputs to the network, and their associated labels.
*
* @param inputNames The names of the inputs. This also defines their order
*/
public GraphBuilder addInputs(Collection<String> inputNames) {
networkInputs.addAll(inputNames);
lastAdded = networkInputs.get(networkInputs.size() - 1);
return this;
}
/**Specify the types of inputs to the network, so that:
* (a) preprocessors can be automatically added, and
* (b) the nIns (input size) for each layer can be automatically calculated and set
* The order here is the same order as .addInputs(). Thus, if you do .addInputs("a","b") and .setInputTypes(InputType.feedForward(),
* InputType.convolutional(28,28,1)) then the input labelled "a" is a feed forward input, whereas the input labelled "b" is a CNN
* input, with 28x28x1 images as input.
* Note: Using setInputTypes is not always necessary, but can be especially helpful for example with CNNs such that
* the calculations on input/output sizes (width, height, channels, etc) don't need to be done manually.
* Note 2: If a preprocessor is manually added for a given layer, it will not be overridden by the automatic
* addition of preprocessors.
* Note 3: If a layer has an nIn set manually, this will not be overridden
*/
public GraphBuilder setInputTypes(InputType... inputTypes) {
if (inputTypes != null && inputTypes.length > 0) {
if (networkInputs.size() > 0 && //If no network inputs have been set yet - can't validate the number of input types here...
networkInputTypes.size() + inputTypes.length != networkInputs.size()) {
throw new IllegalArgumentException(
"Invalid number of InputTypes: " +
"existing inputTypes ("+networkInputTypes.size()+") + additional inputTypes ("+inputTypes.length+")" +
" != number of network inputs ("+networkInputs.size()+")");
}
Collections.addAll(networkInputTypes, inputTypes);
}
return this;
}
/**
* Set the network output labels. These should be the names of the OutputLayer instances in the network
*
* @param outputNames The names of the output layers. This also defines their order.
*/
public GraphBuilder setOutputs(String... outputNames) {
networkOutputs.clear();
Collections.addAll(networkOutputs, outputNames);
return this;
}
/**
* Add a {@link GraphVertex} to the network configuration. A GraphVertex defines forward and backward pass methods,
* and can contain a {@link LayerVertex}, a {@link org.deeplearning4j.nn.conf.graph.ElementWiseVertex} to do element-wise
* addition/subtraction, a {@link MergeVertex} to combine/concatenate the activations out of multiple layers or vertices,
* a {@link org.deeplearning4j.nn.conf.graph.SubsetVertex} to select a subset of the activations out of another layer/GraphVertex.
* Custom GraphVertex objects (that extend the abstract {@link GraphVertex} class) may also be used.
*
* @param vertexName The name of the GraphVertex to add
* @param vertex The GraphVertex to add
* @param vertexInputs The inputs/activations to this GraphVertex.
*/
public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {
Preconditions.checkState(!vertices.containsKey(vertexName), "Cannot add vertex: a vertex with name \"%s\" already exists", vertexName);
vertices.put(vertexName, vertex);
//Automatically insert a MergeVertex if this vertex can only take 1 input (layer vertices, etc)
if (vertex.maxVertexInputs() == 1 && vertexInputs != null && vertexInputs.length > 1) {
String mergeName = vertexName + "-merge";
addVertex(mergeName, new MergeVertex(), vertexInputs);
this.vertexInputs.put(vertexName, Collections.singletonList(mergeName));
} else if (vertexInputs != null) {
this.vertexInputs.put(vertexName, Arrays.asList(vertexInputs));
}
this.lastAdded = vertexName;
return this;
}
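/*
* Illustrative sketch of the automatic merge behaviour implemented above: a LayerVertex accepts only
* one input, so passing multiple inputs causes a MergeVertex named "<vertexName>-merge" to be inserted
* automatically (vertex names below are hypothetical):
*
*   graphBuilder.addVertex("merged", new MergeVertex(), "layerA", "layerB"); //Explicit merge
*   //graphBuilder.addLayer("dense1", denseLayer, "layerA", "layerB"); //Implicitly adds "dense1-merge"
*/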
/**
* Add a {@link GraphVertex} to the network configuration, with input from the last added vertex/layer. A GraphVertex defines forward and backward pass methods,
* and can contain a {@link LayerVertex}, a {@link org.deeplearning4j.nn.conf.graph.ElementWiseVertex} to do element-wise
* addition/subtraction, a {@link MergeVertex} to combine/concatenate the activations out of multiple layers or vertices,
* a {@link org.deeplearning4j.nn.conf.graph.SubsetVertex} to select a subset of the activations out of another layer/GraphVertex.
* Custom GraphVertex objects (that extend the abstract {@link GraphVertex} class) may also be used.
*
* @param vertexName The name of the GraphVertex to add
* @param vertex The GraphVertex to add
*/
public GraphBuilder appendVertex(String vertexName, GraphVertex vertex) {
if(lastAdded == null){
throw new IllegalStateException("Can not use appendLayer with no previous layers");
}
addVertex(vertexName, vertex, lastAdded);
return this;
}
/**
* Used only during validation after building.
* If true: don't throw an exception on configurations containing vertices that are 'disconnected'. A disconnected
* vertex is one that is not an output, and doesn't connect to any other vertices; i.e., its output activations
* don't go anywhere. Most users can (and should) leave this as the default value of false.
*
* @param allowDisconnected Whether to allow disconnected vertices, during validation
*/
public GraphBuilder allowDisconnected(boolean allowDisconnected){
this.allowDisconnected = allowDisconnected;
return this;
}
/**
* Used only during validation after building.
* If true: don't throw an exception on configurations without any outputs. This validation check is enabled
* by default to avoid creating invalid graphs, but can be disabled if required.
* Most users can (and should) leave this as the default value of false.
*
* @param allowNoOutput Whether to allow no outputs, during validation
*/
public GraphBuilder allowNoOutput(boolean allowNoOutput){
this.allowNoOutput = allowNoOutput;
return this;
}
/**
* Enabled by default. If enabled, the output layer configuration will be validated, to throw an exception on
* likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.
* If disabled (false) no output layer validation will be performed.
* Disabling this validation is not recommended, as the configurations that fail validation usually will
* not be able to learn correctly. However, the option to disable this validation is provided for advanced users
* when creating non-standard architectures.
*
* @param validate If true: validate output layer configuration. False: don't validate
*/
public GraphBuilder validateOutputLayerConfig(boolean validate) {
this.validateOutputConfig = validate;
return this;
}
/**
* Enabled by default. If enabled, an exception will be thrown when using the (invalid) combination of truncated
* backpropagation through time (TBPTT) with either a GlobalPoolingLayer or LastTimeStepLayer.
* It is possible to disable this validation to allow what is almost certainly an invalid configuration to be used,
* however this is not recommended.
*
* @param validate Whether TBPTT validation should be performed
*/
public GraphBuilder validateTbpttConfig(boolean validate){
this.validateTbpttConfig = validate;
return this;
}
/**
* For the (perhaps partially constructed) network configuration, return a map of activation sizes for each
* layer and vertex in the graph.
* Note 1: The network configuration may be incomplete, but the inputs must already have been added to the builder.
* Note 2: To use this method, the network input types must have been set using {@link #setInputTypes(InputType...)}
* first
* @return A map of activation types for the graph (key: vertex name. value: type of activations out of that vertex)
*/
public Map<String, InputType> getLayerActivationTypes() {
Preconditions.checkArgument(networkInputs != null && networkInputs.size() > 0,
"Cannot calculate activation types if no inputs have been set (use addInputs(String...))");
Preconditions.checkArgument(networkInputTypes != null && networkInputTypes.size() == networkInputs.size(),
"Cannot calculate layer activation types if network if network input types have not" +
"been set (use ");
//Instantiate temporary ComputationGraphConfiguration and calculate output shapes
ComputationGraphConfiguration conf;
try{
conf = buildConfig();
} catch (Exception e){
throw new RuntimeException("Error calculating activation types for layers: error occured when constructing " +
"temporary ComputationGraphConfiguration)", e);
}
try{
conf.validate(true, true);
} catch (Exception e){
throw new RuntimeException("Error calculating activation types for layers: validation of temporary" +
" ComputationGraphConfiguration failed", e);
}
return conf.getLayerActivationTypes(true, networkInputTypes.toArray(new InputType[networkInputTypes.size()]));
}
private ComputationGraphConfiguration buildConfig(){
//Validate BackpropType setting
if((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH) && backpropType != BackpropType.TruncatedBPTT){
log.warn("Truncated backpropagation through time lengths have been configured with values " + tbpttFwdLength
+ " and " + tbpttBackLength + " but backprop type is set to " + backpropType + ". TBPTT configuration" +
" settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
}
ComputationGraphConfiguration conf = new ComputationGraphConfiguration();
conf.backpropType = backpropType;
conf.tbpttBackLength = tbpttBackLength;
conf.tbpttFwdLength = tbpttFwdLength;
conf.networkInputs = networkInputs;
conf.networkOutputs = networkOutputs;
conf.vertices = this.vertices;
conf.vertexInputs = this.vertexInputs;
conf.trainingWorkspaceMode = globalConfiguration.trainingWorkspaceMode;
conf.inferenceWorkspaceMode = globalConfiguration.inferenceWorkspaceMode;
conf.cacheMode = globalConfiguration.cacheMode;
conf.validateOutputLayerConfig = validateOutputConfig;
conf.dataType = globalConfiguration.dataType;
conf.defaultConfiguration = globalConfiguration.build();
//Add preprocessors that were defined separately to the Layers to which they belong
for (Map.Entry<String, InputPreProcessor> entry : inputPreProcessors.entrySet()) {
GraphVertex gv = vertices.get(entry.getKey());
if (gv instanceof LayerVertex) {
LayerVertex lv = (LayerVertex) gv;
lv.setPreProcessor(entry.getValue());
} else {
throw new IllegalStateException(
"Invalid configuration: InputPreProcessor defined for GraphVertex \""
+ entry.getKey() + "\", but this vertex is not a LayerVertex");
}
}
for (Map.Entry<String, GraphVertex> gv : vertices.entrySet()) {
if (gv.getValue() instanceof SameDiffVertex)
((SameDiffVertex) gv.getValue()).applyGlobalConfig(globalConfiguration);
}
return conf;
}
/**
* Create the ComputationGraphConfiguration from the Builder pattern
*/
public ComputationGraphConfiguration build() {
ComputationGraphConfiguration conf = buildConfig();
conf.validate(allowDisconnected, allowNoOutput); //throws exception for invalid configuration
//Automatically add preprocessors, set nIns for CNN->dense transitions, etc
if (!networkInputTypes.isEmpty()) {
conf.addPreProcessors(networkInputTypes.toArray(new InputType[networkInputTypes.size()]));
}
if(validateOutputConfig) {
//Validate output layer configurations...
for (Map.Entry<String, GraphVertex> e : conf.getVertices().entrySet()) {
if (e.getValue() instanceof LayerVertex) {
Layer l = ((LayerVertex) e.getValue()).getLayerConf().getLayer();
OutputLayerUtil.validateOutputLayer(e.getKey(), l); //No-op for non output/loss layers
}
}
}
if(backpropType == BackpropType.TruncatedBPTT && validateTbpttConfig) {
//Check for invalid combination - TBPTT plus LastTimeStepVertex, LastTimeStep wrapper, or GlobalPoolingLayer
for(Map.Entry<String, GraphVertex> e : vertices.entrySet()){
GraphVertex gv = e.getValue();
Layer l = (gv instanceof LayerVertex ? ((LayerVertex)gv).getLayerConf().getLayer() : null);
if(gv instanceof LastTimeStepVertex || (l != null && (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer))){
String s = (l == null ? gv.getClass().getName() : l.getClass().getName());
String n = e.getKey();
throw new IllegalStateException("Invalid network configuration detected: Truncated backpropagation through time (TBPTT)" +
" cannot be used with layer \"" + n + "\" of type " + s + ": TBPTT is incompatible with this layer type (which is designed " +
"to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n" +
"This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
}
}
}
return conf;
}
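/*
* Illustrative end-to-end usage sketch for this builder; layer names and sizes below are hypothetical,
* and DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer) is not imported in this file:
*
*   ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
*           .graphBuilder()
*           .addInputs("input")
*           .addLayer("dense1", new DenseLayer.Builder().nIn(784).nOut(128).build(), "input")
*           .addLayer("out", new OutputLayer.Builder().nIn(128).nOut(10).build(), "dense1")
*           .setOutputs("out")
*           .build();
*/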
}
}