/*
*
* * Copyright 2016 Skymind, Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.nn.conf;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import lombok.*;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* ComputationGraphConfiguration is a configuration object for neural networks with arbitrary connection structure.
* It is analogous to {@link MultiLayerConfiguration}, but allows considerably greater flexibility for the network
* architecture.
* Specifically, the network architecture is a directed acyclic graph, where each vertex in the graph is a {@link GraphVertex},
* which may for example be a layer or a vertex/object that defines arbitrary forward and backward pass functionality.
* Note that the ComputationGraph may have an arbitrary number of inputs (multiple independent inputs, possibly of different
* types), and an arbitrary number of outputs (for example, multiple {@link OutputLayer} instances).
* Typical usage:
* {@code ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()....graphBuilder()...build();}
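* <p>A fuller sketch (the layer types and sizes here are illustrative):
* <pre>
* {@code
* ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
*         .graphBuilder()
*         .addInputs("input")
*         .addLayer("dense", new DenseLayer.Builder().nIn(784).nOut(100).build(), "input")
*         .addLayer("out", new OutputLayer.Builder().nIn(100).nOut(10).build(), "dense")
*         .setOutputs("out")
*         .build();
* }
* </pre>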
*
* @author Alex Black
*/
@Data
@EqualsAndHashCode
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
public class ComputationGraphConfiguration implements Serializable, Cloneable {
private static final Logger log = LoggerFactory.getLogger(ComputationGraphConfiguration.class);
private static final AtomicBoolean defaultChangeWarningPrinted = new AtomicBoolean(false);
protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();
protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();
/**
* List of inputs to the network, by name
*/
protected List<String> networkInputs;
/**
* List of network outputs, by name
*/
protected List<String> networkOutputs;
protected boolean pretrain = false;
protected boolean backprop = true;
protected BackpropType backpropType = BackpropType.Standard;
protected int tbpttFwdLength = 20;
protected int tbpttBackLength = 20;
protected NeuralNetConfiguration defaultConfiguration;
//Counter for the number of parameter updates so far
// This is important for learning rate schedules, for example, and is stored here to ensure it is persisted
// for Spark and model serialization
protected int iterationCount = 0;
/**
* @return YAML representation of configuration
*/
public String toYaml() {
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
try {
return mapper.writeValueAsString(this);
} catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* Create a computation graph configuration from YAML
*
* @param yaml the computation graph configuration as YAML
* @return {@link ComputationGraphConfiguration}
*/
public static ComputationGraphConfiguration fromYaml(String yaml) {
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
try {
return mapper.readValue(yaml, ComputationGraphConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* @return JSON representation of computation graph configuration
*/
public String toJson() {
//As per MultiLayerConfiguration.toJson()
ObjectMapper mapper = NeuralNetConfiguration.mapper();
try {
return mapper.writeValueAsString(this);
} catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* Create a computation graph configuration from JSON
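* <p>A configuration produced by {@link #toJson()} round-trips; as a sketch:
* {@code ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(conf.toJson());}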
*
* @param json the computation graph configuration as JSON
* @return {@link ComputationGraphConfiguration}
*/
public static ComputationGraphConfiguration fromJson(String json) {
//As per MultiLayerConfiguration.fromJson()
ObjectMapper mapper = NeuralNetConfiguration.mapper();
ComputationGraphConfiguration conf;
try {
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
} catch (IOException e) {
throw new RuntimeException(e);
}
//To maintain backward compatibility after activation function refactoring (configs generated with v0.7.1 or earlier)
// Previously: enumeration used for activation functions. Now: use classes
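// Legacy format being handled here, roughly (field names as used by the parsing below):
//   "vertices" : { "<layerName>" : { "LayerVertex" : { "layerConf" :
//       { "layer" : { "dense" : { ..., "activationFunction" : "relu" } } } } } }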
Map<String, GraphVertex> vertexMap = conf.getVertices();
JsonNode vertices = null;
for(Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()){
if(!(entry.getValue() instanceof LayerVertex)){
continue;
}
LayerVertex lv = (LayerVertex)entry.getValue();
if(lv.getLayerConf() != null && lv.getLayerConf().getLayer() != null){
Layer layer = lv.getLayerConf().getLayer();
if(layer.getActivationFn() == null){
String layerName = layer.getLayerName();
try{
if(vertices == null){
JsonNode jsonNode = mapper.readTree(json);
vertices = jsonNode.get("vertices");
}
JsonNode vertexNode = vertices.get(layerName);
JsonNode layerVertexNode = vertexNode.get("LayerVertex");
if(layerVertexNode == null || !layerVertexNode.has("layerConf") || !layerVertexNode.get("layerConf").has("layer")){
continue;
}
JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer"); //Should have exactly 1 element, keyed by layer type: "dense", "output", etc
if(layerWrapperNode == null || layerWrapperNode.size() != 1){
continue;
}
JsonNode layerNode = layerWrapperNode.elements().next();
JsonNode activationFunction = layerNode.get("activationFunction"); //Legacy enumeration-style value, e.g. "relu", "softmax"
if(activationFunction != null){
IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
layer.setActivationFn(ia);
}
} catch (IOException e ){
log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",e);
}
}
}
}
return conf;
}
@Override
public String toString() {
return toJson();
}
@Override
public ComputationGraphConfiguration clone() {
ComputationGraphConfiguration conf = new ComputationGraphConfiguration();
conf.vertices = new LinkedHashMap<>();
for (Map.Entry<String, GraphVertex> entry : this.vertices.entrySet()) {
conf.vertices.put(entry.getKey(), entry.getValue().clone());
}
conf.vertexInputs = new LinkedHashMap<>();
for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
conf.vertexInputs.put(entry.getKey(), new ArrayList<>(entry.getValue()));
}
conf.networkInputs = new ArrayList<>(this.networkInputs);
conf.networkOutputs = new ArrayList<>(this.networkOutputs);
conf.pretrain = pretrain;
conf.backprop = backprop;
conf.backpropType = backpropType;
conf.tbpttFwdLength = tbpttFwdLength;
conf.tbpttBackLength = tbpttBackLength;
conf.defaultConfiguration = defaultConfiguration.clone();
return conf;
}
/**
* Check the configuration, make sure it is valid
*
* @throws IllegalStateException if configuration is not valid
*/
public void validate() {
if (networkInputs == null || networkInputs.size() < 1) {
throw new IllegalStateException("Invalid configuration: network has no inputs. Use .addInputs(String...) to label (and give an ordering to) the network inputs");
}
if (networkOutputs == null || networkOutputs.size() < 1) {
throw new IllegalStateException("Invalid configuration: network has no outputs. Use .setOutput(String...) to specify (and give an ordering to) the output vertices");
}
//Check uniqueness of names for inputs, layers, GraphNodes
for (String s : networkInputs) {
if (vertices.containsKey(s)) {
throw new IllegalStateException("Invalid configuration: name \"" + s + "\" is present in both network inputs and graph vertices/layers");
}
}
//Check: each layer & node has at least one input
for (Map.Entry<String, List<String>> e : vertexInputs.entrySet()) {
String nodeName = e.getKey();
if (e.getValue() == null || e.getValue().isEmpty()) {
throw new IllegalStateException("Invalid configuration: vertex \"" + nodeName + "\" has no inputs");
}
for (String inputName : e.getValue()) {
if (!vertices.containsKey(inputName) && !networkInputs.contains(inputName)) {
throw new IllegalStateException("Invalid configuration: Vertex \"" + nodeName + "\" has input \"" +
inputName + "\" that does not exist");
}
}
}
//Check output names:
for (String s : networkOutputs) {
if (!vertices.containsKey(s)) {
throw new IllegalStateException("Invalid configuration: Output name \"" + s + "\" is not a valid vertex");
}
}
boolean warned = false;
if (!pretrain && !defaultChangeWarningPrinted.get()) {
log.warn("Warning: new network default sets pretrain to false.");
warned = true;
}
if(backprop && !defaultChangeWarningPrinted.get()) {
log.warn("Warning: new network default sets backprop to true.");
warned = true;
}
if(warned){
defaultChangeWarningPrinted.set(true);
}
//Check for no graph cycles: done in ComputationGraph.init()
}
/**
* Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
* {@link InputType} class, in the same order in which the inputs were defined in the original configuration.
* For example, in a network with two inputs: a convolutional input (28x28x1 images) and feed forward inputs, use
* {@code .addPreProcessors(InputType.convolutional(28,28,1),InputType.feedForward())}.
* For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be added automatically.
* NOTE: This method will be called automatically when using the
* {@link GraphBuilder#setInputTypes(InputType...)} functionality.
* See that method for details.
*/
public void addPreProcessors(InputType... inputTypes) {
if (inputTypes == null || inputTypes.length != networkInputs.size()) {
throw new IllegalArgumentException("Invalid number of InputTypes: cannot add preprocessors if number of InputType "
+ "objects differs from number of network inputs");
}
//Now: need to do essentially a forward pass through the network, to work out what type of preprocessors to add
//To do this: need to know what the output types are for each GraphVertex.
//First step: build network in reverse order (i.e., define map of a -> list(b) instead of list(a) -> b)
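//For example: if vertexInputs contains {"out" -> ["a","b"]}, then verticesOutputTo will contain {"a" -> ["out"], "b" -> ["out"]}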
Map<String, List<String>> verticesOutputTo = new HashMap<>(); //Key: vertex. Values: vertices that this node is an input for
for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
String vertexName = entry.getKey();
List<String> vertexInputNames = vertexInputs.get(vertexName);
if (vertexInputNames == null) continue;
//Build reverse network structure:
for (String s : vertexInputNames) {
List<String> list = verticesOutputTo.get(s);
if (list == null) {
list = new ArrayList<>();
verticesOutputTo.put(s, list);
}
list.add(vertexName); //Edge: s -> vertexName
}
}
//Now: do topological sort
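//This is Kahn's algorithm: repeatedly remove a vertex with no remaining incoming edges and append it to the
// ordering; any edges still present after the queue empties indicate a cycle (checked below)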
LinkedList<String> noIncomingEdges = new LinkedList<>(networkInputs); //Set of all nodes with no incoming edges
List<String> topologicalOrdering = new ArrayList<>();
Map<String, Set<String>> inputEdges = new HashMap<>();
for (Map.Entry<String, List<String>> entry : vertexInputs.entrySet()) {
inputEdges.put(entry.getKey(), new HashSet<>(entry.getValue()));
}
while (!noIncomingEdges.isEmpty()) {
String next = noIncomingEdges.removeFirst();
topologicalOrdering.add(next);
//Remove edges next -> verticesOutputTo[...] from graph
List<String> nextEdges = verticesOutputTo.get(next);
if (nextEdges != null && !nextEdges.isEmpty()) {
for (String s : nextEdges) {
Set<String> set = inputEdges.get(s);
set.remove(next);
if (set.isEmpty()) {
noIncomingEdges.add(s); //No remaining incoming edges for vertex s -> add to list for processing
}
}
}
}
//If any edges remain in the graph: graph has cycles:
for (Map.Entry<String, Set<String>> entry : inputEdges.entrySet()) {
Set<String> set = entry.getValue();
if (set == null) continue;
if (!set.isEmpty())
throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle ("
+ "cycle includes vertex \"" + entry.getKey() + "\")");
}
//Now, given the topological sort: do equivalent of forward pass
Map<String, InputType> vertexOutputs = new HashMap<>();
int currLayerIdx = -1;
for (String s : topologicalOrdering) {
int inputIdx = networkInputs.indexOf(s);
if (inputIdx != -1) {
vertexOutputs.put(s, inputTypes[inputIdx]);
continue;
}
GraphVertex gv = vertices.get(s);
List<InputType> inputTypeList = new ArrayList<>();
if (gv instanceof LayerVertex) {
//Add preprocessor, if necessary:
String in = vertexInputs.get(s).get(0);
InputType layerInput = vertexOutputs.get(in);
inputTypeList.add(layerInput);
LayerVertex lv = (LayerVertex) gv;
Layer l = lv.getLayerConf().getLayer();
//Preprocessors - add if necessary
if (lv.getPreProcessor() == null){
//But don't override preprocessors that are manually defined; if none has been defined,
//add the appropriate preprocessor for this input type/layer combination
InputPreProcessor preproc = l.getPreProcessorForInputType(layerInput);
lv.setPreProcessor(preproc);
}
//Set nIn value for layer (if not already set)
InputType afterPreproc = layerInput;
if(lv.getPreProcessor() != null){
InputPreProcessor ip = lv.getPreProcessor();
afterPreproc = ip.getOutputType(layerInput);
}
l.setNIn(afterPreproc, false);
currLayerIdx++;
} else {
List<String> inputs = vertexInputs.get(s);
if (inputs != null) {
for (String inputVertexName : inputs) {
inputTypeList.add(vertexOutputs.get(inputVertexName));
}
}
}
InputType outputFromVertex = gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
vertexOutputs.put(s, outputFromVertex);
}
}
@Data
public static class GraphBuilder {
protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();
/**
* Key: graph node. Values: input to that node
*/
protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();
protected List<String> networkInputs = new ArrayList<>();
protected List<InputType> networkInputTypes = new ArrayList<>();
protected List<String> networkOutputs = new ArrayList<>();
protected boolean pretrain = false;
protected boolean backprop = true;
protected BackpropType backpropType = BackpropType.Standard;
protected int tbpttFwdLength = 20;
protected int tbpttBackLength = 20;
protected Map<String, InputPreProcessor> inputPreProcessors = new LinkedHashMap<>();
protected NeuralNetConfiguration.Builder globalConfiguration;
public GraphBuilder(NeuralNetConfiguration.Builder globalConfiguration) {
this.globalConfiguration = globalConfiguration;
}
/**
* @deprecated As of 0.6.0
*/
@Deprecated
public GraphBuilder redistributeParams(boolean redistributeParams) {
return this;
}
/**
* Specify the preprocessor for a given layer.
* Preprocessors are used at each layer for tasks such as normalization and shaping of the input.
* Note: preprocessors can also be defined using the {@link #addLayer(String, Layer, InputPreProcessor, String...)} method.
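* <p>For example (the preprocessor class and shape values here are illustrative):
* {@code .inputPreProcessor("dense", new CnnToFeedForwardPreProcessor(28, 28, 1))}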
*
* @param layer the name of the layer that this preprocessor will be used with
* @param processor the preprocessor to use for the specified layer
*/
public GraphBuilder inputPreProcessor(String layer, InputPreProcessor processor) {
inputPreProcessors.put(layer, processor);
return this;
}
/**
* Whether to do back prop (standard supervised learning) or not
*
* @param backprop whether to do back prop or not
*/
public GraphBuilder backprop(boolean backprop) {
this.backprop = backprop;
return this;
}
/**
* Whether to do layerwise pre training or not
*
* @param pretrain whether to do pre train or not
*/
public GraphBuilder pretrain(boolean pretrain) {
this.pretrain = pretrain;
return this;
}
/**
* The type of backprop. Default setting is used for most networks (MLP, CNN etc),
* but optionally truncated BPTT can be used for training recurrent neural networks.
* If using TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength()
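* For example: {@code .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(20).tBPTTBackwardLength(20)}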
*
* @param type Type of backprop. Default: BackpropType.Standard
*/
public GraphBuilder backpropType(BackpropType type) {
this.backpropType = type;
return this;
}
/**
* When doing truncated BPTT: how many steps of forward pass should we do
* before doing (truncated) backprop?
* Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
* Typically the tBPTTForwardLength parameter is the same as the tBPTTBackwardLength parameter,
* but it may be larger (never smaller) in some circumstances.
* Ideally, the length of your training data time series should be divisible by this value.
* This is the k1 parameter on page 23 of
* http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
*
* @param forwardLength Forward length > 0, >= backwardLength
*/
public GraphBuilder tBPTTForwardLength(int forwardLength) {
this.tbpttFwdLength = forwardLength;
return this;
}
/**
* When doing truncated BPTT: how many steps of backward should we do?
* Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
* This is the k2 parameter on page 23 of
* http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
*
* @param backwardLength <= forwardLength
*/
public GraphBuilder tBPTTBackwardLength(int backwardLength) {
this.tbpttBackLength = backwardLength;
return this;
}
/**
* Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param layerInputs Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
* or a combination of the two.
* @see #addLayer(String, Layer, InputPreProcessor, String...)
*/
public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) {
return addLayer(layerName, layer, null, layerInputs);
}
/**
* Add a layer and an {@link InputPreProcessor}, with the specified name and specified inputs.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param preProcessor The InputPreProcessor to use with this layer.
* @param layerInputs Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
* or a combination of the two.
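* <p>Note: if more than one input is specified, a {@link MergeVertex} named {@code layerName + "-merge"} is
* inserted automatically to concatenate the inputs, since a layer itself takes exactly one input. For example,
* {@code .addLayer("dense", new DenseLayer.Builder().nOut(100).build(), "in1", "in2")} merges "in1" and "in2"
* before the dense layer.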
*/
public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor, String... layerInputs) {
NeuralNetConfiguration.Builder builder = globalConfiguration.clone();
builder.layer(layer);
vertices.put(layerName, new LayerVertex(builder.build(), preProcessor));
//Automatically insert a MergeVertex if layerInputs.length > 1
//Layers can only have 1 input
if (layerInputs != null && layerInputs.length > 1) {
String mergeName = layerName + "-merge";
addVertex(mergeName, new MergeVertex(), layerInputs);
this.vertexInputs.put(layerName, Collections.singletonList(mergeName));
} else if (layerInputs != null) {
this.vertexInputs.put(layerName, Arrays.asList(layerInputs));
}
layer.setLayerName(layerName);
return this;
}
/**
* Specify the inputs to the network, and their associated labels.
*
* @param inputNames The names of the inputs. This also defines their order
*/
public GraphBuilder addInputs(String... inputNames) {
Collections.addAll(networkInputs, inputNames);
return this;
}
/**Specify the types of inputs to the network, so that:
* (a) preprocessors can be automatically added, and
* (b) the nIns (input size) for each layer can be automatically calculated and set
* The order here is the same order as .addInputs(). Thus, if you do .addInputs("a","b") and .setInputTypes(InputType.feedForward(),
* InputType.convolutional(28,28,1)) then the input labelled "a" is a feed forward input, whereas the input labelled "b" is a CNN
* input, with 28x28x1 images as input.
* Note: Using setInputTypes is not always necessary, but can be especially helpful for example with CNNs such that
* the calculations on input/output sizes (width, height, depth, etc) don't need to be done manually.
* Note 2: If a preprocessor is manually added for a given layer, it will not be overridden by the automatic
* addition of preprocessors.
* Note 3: If a layer has an nIn set manually, this will not be overridden
*/
public GraphBuilder setInputTypes(InputType... inputTypes) {
if(inputTypes != null && inputTypes.length > 0) Collections.addAll(networkInputTypes, inputTypes);
return this;
}
/**
* Set the network output labels. These should be the names of the OutputLayer instances in the network
*
* @param outputNames The names of the output layers. This also defines their order.
*/
public GraphBuilder setOutputs(String... outputNames) {
Collections.addAll(networkOutputs, outputNames);
return this;
}
/**
* Add a {@link GraphVertex} to the network configuration. A GraphVertex defines forward and backward pass methods,
* and can contain a {@link LayerVertex}, a {@link org.deeplearning4j.nn.conf.graph.ElementWiseVertex} to do element-wise
* addition/subtraction, a {@link MergeVertex} to combine/concatenate the activations out of multiple layers or vertices,
* a {@link org.deeplearning4j.nn.conf.graph.SubsetVertex} to select a subset of the activations out of another layer/GraphVertex.
* Custom GraphVertex objects (that extend the abstract {@link GraphVertex} class) may also be used.
*
* @param vertexName The name of the GraphVertex to add
* @param vertex The GraphVertex to add
* @param vertexInputs The inputs/activations to this GraphVertex
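* <p>For example, to concatenate two inputs: {@code .addVertex("merge", new MergeVertex(), "input1", "input2")}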
*/
public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {
vertices.put(vertexName, vertex);
this.vertexInputs.put(vertexName, Arrays.asList(vertexInputs));
return this;
}
/**
* Create the ComputationGraphConfiguration from this builder
*/
public ComputationGraphConfiguration build() {
ComputationGraphConfiguration conf = new ComputationGraphConfiguration();
conf.backprop = backprop;
conf.pretrain = pretrain;
conf.backpropType = backpropType;
conf.tbpttBackLength = tbpttBackLength;
conf.tbpttFwdLength = tbpttFwdLength;
conf.networkInputs = networkInputs;
conf.networkOutputs = networkOutputs;
conf.vertices = this.vertices;
conf.vertexInputs = this.vertexInputs;
conf.defaultConfiguration = globalConfiguration.build();
conf.getDefaultConfiguration().setPretrain(pretrain);
//Add preprocessors that were defined separately to the Layers to which they belong
for (Map.Entry<String, InputPreProcessor> entry : inputPreProcessors.entrySet()) {
GraphVertex gv = vertices.get(entry.getKey());
if (gv instanceof LayerVertex) {
LayerVertex lv = (LayerVertex) gv;
lv.setPreProcessor(entry.getValue());
} else {
throw new IllegalStateException("Invalid configuration: InputPreProcessor defined for GraphVertex \"" + entry.getKey()
+ "\", but this vertex is not a LayerVertex");
}
}
for (Map.Entry<String, GraphVertex> gv : vertices.entrySet()) {
if (gv.getValue() instanceof LayerVertex) {
LayerVertex lv = (LayerVertex) gv.getValue();
Layer l = lv.getLayerConf().getLayer();
if (l instanceof BasePretrainNetwork) lv.getLayerConf().setPretrain(pretrain);
}
}
conf.validate(); //throws exception for invalid configuration
//Automatically add preprocessors, set nIns for CNN->dense transitions, etc
if (!networkInputTypes.isEmpty()) {
conf.addPreProcessors(networkInputTypes.toArray(new InputType[networkInputTypes.size()]));
}
return conf;
}
}
}