/*
 *
 *  * Copyright 2016 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */
package org.deeplearning4j.nn.conf;

import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass;
import org.nd4j.shade.jackson.databind.jsontype.NamedType;
import lombok.*;
import org.apache.commons.lang3.ClassUtils;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.reflections.Reflections;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * ComputationGraphConfiguration is a configuration object for neural networks with arbitrary connection structure.
 * It is analogous to {@link MultiLayerConfiguration}, but allows considerably greater flexibility for the network
 * architecture.
 * Specifically, the network architecture is a directed acyclic graph, where each vertex in the graph is a
 * {@link GraphVertex}, which may for example be a layer or a vertex/object that defines arbitrary forward and
 * backward pass functionality.
 * Note that the ComputationGraph may have an arbitrary number of inputs (multiple independent inputs, possibly
 * of different types), and an arbitrary number of outputs (for example, multiple {@link OutputLayer} instances).
 * Typical usage:
 * {@code ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()....graphBuilder()...build();}
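 * <p>
 * A fuller sketch (illustrative only; the layer names and sizes below are hypothetical, not part of the original source):
 * <pre>
 * {@code
 * ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
 *         .graphBuilder()
 *         .addInputs("input")
 *         .addLayer("dense", new DenseLayer.Builder().nIn(784).nOut(100).build(), "input")
 *         .addLayer("out", new OutputLayer.Builder().nIn(100).nOut(10).build(), "dense")
 *         .setOutputs("out")
 *         .build();
 * }
 * </pre>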
 *
 * @author Alex Black
 */
@Data
@EqualsAndHashCode
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
public class ComputationGraphConfiguration implements Serializable, Cloneable {
    private static Logger log = LoggerFactory.getLogger(ComputationGraphConfiguration.class);

    private static final AtomicBoolean defaultChangeWarningPrinted = new AtomicBoolean(false);

    protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();
    protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();

    /**
     * List of inputs to the network, by name
     */
    protected List<String> networkInputs;

    /**
     * List of network outputs, by name
     */
    protected List<String> networkOutputs;

    protected boolean pretrain = false;
    protected boolean backprop = true;
    protected BackpropType backpropType = BackpropType.Standard;
    protected int tbpttFwdLength = 20;
    protected int tbpttBackLength = 20;

    protected NeuralNetConfiguration defaultConfiguration;

    //Counter for the number of parameter updates so far
    // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted
    // for Spark and model serialization
    protected int iterationCount = 0;

    /**
     * @return YAML representation of configuration
     */
    public String toYaml() {
        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
        try {
            return mapper.writeValueAsString(this);
        } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create a neural net configuration from YAML
     *
     * @param yaml the neural net configuration, as YAML
     * @return {@link ComputationGraphConfiguration}
     */
    public static ComputationGraphConfiguration fromYaml(String yaml) {
        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
        try {
            return mapper.readValue(yaml, ComputationGraphConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * @return JSON representation of computation graph configuration
     */
    public String toJson() {
        //As per MultiLayerConfiguration.toJson()
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        try {
            return mapper.writeValueAsString(this);
        } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }
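    /*
     * Illustrative sketch (not part of the original source): round-tripping a configuration
     * through JSON; the YAML methods above behave analogously. Assumes "conf" is a valid
     * ComputationGraphConfiguration built elsewhere:
     *
     *   String json = conf.toJson();
     *   ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(json);
     *   //equals() is generated by Lombok's @EqualsAndHashCode, so this should hold:
     *   boolean same = restored.equals(conf);
     */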
    /**
     * Create a computation graph configuration from JSON
     *
     * @param json the computation graph configuration, as JSON
     * @return {@link ComputationGraphConfiguration}
     */
    public static ComputationGraphConfiguration fromJson(String json) {
        //As per MultiLayerConfiguration.fromJson()
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        ComputationGraphConfiguration conf;
        try {
            conf = mapper.readValue(json, ComputationGraphConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        //To maintain backward compatibility after activation function refactoring (configs generated with v0.7.1 or earlier):
        // previously an enumeration was used for activation functions; now classes are used
        Map<String, GraphVertex> vertexMap = conf.getVertices();
        JsonNode vertices = null;
        for (Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()) {
            if (!(entry.getValue() instanceof LayerVertex)) {
                continue;
            }
            LayerVertex lv = (LayerVertex) entry.getValue();
            if (lv.getLayerConf() != null && lv.getLayerConf().getLayer() != null) {
                Layer layer = lv.getLayerConf().getLayer();
                if (layer.getActivationFn() == null) {
                    String layerName = layer.getLayerName();
                    try {
                        if (vertices == null) {
                            JsonNode jsonNode = mapper.readTree(json);
                            vertices = jsonNode.get("vertices");
                        }
                        JsonNode vertexNode = vertices.get(layerName);
                        JsonNode layerVertexNode = vertexNode.get("LayerVertex");
                        if (layerVertexNode == null || !layerVertexNode.has("layerConf")
                                        || !layerVertexNode.get("layerConf").has("layer")) {
                            continue;
                        }
                        JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");
                        if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                            continue;
                        }
                        JsonNode layerNode = layerWrapperNode.elements().next(); //Should only have 1 element: "dense", "output", etc
                        JsonNode activationFunction = layerNode.get("activationFunction");
                        if (activationFunction != null) {
                            IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                            layer.setActivationFn(ia);
                        }
                    } catch (IOException e) {
                        log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
                    }
                }
            }
        }
        return conf;
    }

    @Override
    public String toString() {
        return toJson();
    }

    @Override
    public ComputationGraphConfiguration clone() {
        ComputationGraphConfiguration conf = new ComputationGraphConfiguration();

        conf.vertices = new LinkedHashMap<>();
        for (Map.Entry<String, GraphVertex> entry : this.vertices.entrySet()) {
            conf.vertices.put(entry.getKey(), entry.getValue().clone());
        }

        conf.vertexInputs = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
            conf.vertexInputs.put(entry.getKey(), new ArrayList<>(entry.getValue()));
        }

        conf.networkInputs = new ArrayList<>(this.networkInputs);
        conf.networkOutputs = new ArrayList<>(this.networkOutputs);

        conf.pretrain = pretrain;
        conf.backprop = backprop;
        conf.backpropType = backpropType;
        conf.tbpttFwdLength = tbpttFwdLength;
        conf.tbpttBackLength = tbpttBackLength;
        conf.defaultConfiguration = defaultConfiguration.clone();

        return conf;
    }
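    /*
     * Illustrative sketch (not part of the original source): clone() produces a deep copy, so
     * mutating the copy leaves the original configuration untouched:
     *
     *   ComputationGraphConfiguration copy = conf.clone();
     *   copy.getVertices().clear();   //does not affect conf.getVertices()
     */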
    /**
     * Check the configuration, make sure it is valid
     *
     * @throws IllegalStateException if configuration is not valid
     */
    public void validate() {
        if (networkInputs == null || networkInputs.size() < 1) {
            throw new IllegalStateException("Invalid configuration: network has no inputs. "
                            + "Use .addInputs(String...) to label (and give an ordering to) the network inputs");
        }
        if (networkOutputs == null || networkOutputs.size() < 1) {
            throw new IllegalStateException("Invalid configuration: network has no outputs. "
                            + "Use .setOutputs(String...) to specify (and give an ordering to) the output vertices");
        }

        //Check uniqueness of names for inputs, layers, GraphNodes
        for (String s : networkInputs) {
            if (vertices.containsKey(s)) {
                throw new IllegalStateException("Invalid configuration: name \"" + s
                                + "\" is present in both network inputs and graph vertices/layers");
            }
        }

        //Check: each layer & node has at least one input
        for (Map.Entry<String, List<String>> e : vertexInputs.entrySet()) {
            String nodeName = e.getKey();
            if (e.getValue() == null || e.getValue().isEmpty()) {
                throw new IllegalStateException("Invalid configuration: vertex \"" + nodeName + "\" has no inputs");
            }
            for (String inputName : e.getValue()) {
                if (!vertices.containsKey(inputName) && !networkInputs.contains(inputName)) {
                    throw new IllegalStateException("Invalid configuration: Vertex \"" + nodeName + "\" has input \""
                                    + inputName + "\" that does not exist");
                }
            }
        }

        //Check output names:
        for (String s : networkOutputs) {
            if (!vertices.containsKey(s)) {
                throw new IllegalStateException("Invalid configuration: Output name \"" + s + "\" is not a valid vertex");
            }
        }

        boolean warned = false;
        if (!pretrain && !defaultChangeWarningPrinted.get()) {
            log.warn("Warning: new network default sets pretrain to false.");
            warned = true;
        }
        if (backprop && !defaultChangeWarningPrinted.get()) {
            log.warn("Warning: new network default sets backprop to true.");
            warned = true;
        }
        if (warned) {
            defaultChangeWarningPrinted.set(true);
        }

        //Check for no graph cycles: done in ComputationGraph.init()
    }
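    /*
     * Illustrative sketch (not part of the original source): a configuration that fails
     * validate(), because the declared output name matches no vertex (all names hypothetical);
     * note that GraphBuilder.build() calls validate() internally:
     *
     *   new NeuralNetConfiguration.Builder().graphBuilder()
     *           .addInputs("in")
     *           .addLayer("out", new OutputLayer.Builder().nIn(4).nOut(3).build(), "in")
     *           .setOutputs("output")   //no vertex named "output" -> IllegalStateException
     *           .build();
     */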
    /**
     * Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified
     * using the {@link InputType} class, in the same order in which the inputs were defined in the original
     * configuration.
     * For example, in a network with two inputs: a convolutional input (28x28x1 images) and a feed forward input,
     * use {@code .addPreProcessors(InputType.convolutional(1,28,28), InputType.feedForward())}.
     * For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be set automatically.
     * NOTE: This method will be called automatically when using the
     * {@link GraphBuilder#setInputTypes(InputType...)} functionality. See that method for details.
     */
    public void addPreProcessors(InputType... inputTypes) {
        if (inputTypes == null || inputTypes.length != networkInputs.size()) {
            throw new IllegalArgumentException("Invalid number of InputTypes: cannot add preprocessors if number of InputType "
                            + "objects differs from number of network inputs");
        }

        //Now: need to do essentially a forward pass through the network, to work out what type of preprocessors to add
        //To do this: need to know what the output types are for each GraphVertex.

        //First step: build network in reverse order (i.e., define map of a -> list(b) instead of list(a) -> b)
        Map<String, List<String>> verticesOutputTo = new HashMap<>(); //Key: vertex. Values: vertices that this node is an input for
        for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
            String vertexName = entry.getKey();
            List<String> vertexInputNames = vertexInputs.get(vertexName);
            if (vertexInputNames == null)
                continue;

            //Build reverse network structure:
            for (String s : vertexInputNames) {
                List<String> list = verticesOutputTo.get(s);
                if (list == null) {
                    list = new ArrayList<>();
                    verticesOutputTo.put(s, list);
                }
                list.add(vertexName); //Edge: s -> vertexName
            }
        }

        //Now: do topological sort
        LinkedList<String> noIncomingEdges = new LinkedList<>(networkInputs); //Set of all nodes with no incoming edges
        List<String> topologicalOrdering = new ArrayList<>();

        Map<String, Set<String>> inputEdges = new HashMap<>();
        for (Map.Entry<String, List<String>> entry : vertexInputs.entrySet()) {
            inputEdges.put(entry.getKey(), new HashSet<>(entry.getValue()));
        }

        while (!noIncomingEdges.isEmpty()) {
            String next = noIncomingEdges.removeFirst();
            topologicalOrdering.add(next);

            //Remove edges next -> verticesOutputTo[...] from graph
            List<String> nextEdges = verticesOutputTo.get(next);

            if (nextEdges != null && !nextEdges.isEmpty()) {
                for (String s : nextEdges) {
                    Set<String> set = inputEdges.get(s);
                    set.remove(next);
                    if (set.isEmpty()) {
                        noIncomingEdges.add(s); //No remaining edges for vertex s -> add to list for processing
                    }
                }
            }
        }

        //If any edges remain in the graph: graph has cycles:
        for (Map.Entry<String, Set<String>> entry : inputEdges.entrySet()) {
            Set<String> set = entry.getValue();
            if (set == null)
                continue;
            if (!set.isEmpty())
                throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate "
                                + "topological ordering with graph cycle (cycle includes vertex \"" + entry.getKey() + "\")");
        }
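        //Illustrative note (not in the original source): for a hypothetical graph
        //  in -> a, in -> b, a -> c, b -> c
        //verticesOutputTo is {in: [a, b], a: [c], b: [c]}, and the Kahn-style sort above
        //yields a valid topological ordering such as [in, a, b, c]; any remaining edge
        //would have triggered the cycle exception immediately above.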
        //Now, given the topological sort: do equivalent of forward pass
        Map<String, InputType> vertexOutputs = new HashMap<>();
        int currLayerIdx = -1;
        for (String s : topologicalOrdering) {
            int inputIdx = networkInputs.indexOf(s);
            if (inputIdx != -1) {
                vertexOutputs.put(s, inputTypes[inputIdx]);
                continue;
            }
            GraphVertex gv = vertices.get(s);
            List<InputType> inputTypeList = new ArrayList<>();
            if (gv instanceof LayerVertex) {
                //Add preprocessor, if necessary:
                String in = vertexInputs.get(s).get(0);
                InputType layerInput = vertexOutputs.get(in);
                inputTypeList.add(layerInput);
                LayerVertex lv = (LayerVertex) gv;
                Layer l = lv.getLayerConf().getLayer();

                //Preprocessors - add if necessary
                if (lv.getPreProcessor() == null) {
                    //But don't override preprocessors that are manually defined; if none has been defined,
                    //add the appropriate preprocessor for this input type/layer combination
                    InputPreProcessor preproc = l.getPreProcessorForInputType(layerInput);
                    lv.setPreProcessor(preproc);
                }

                //Set nIn value for layer (if not already set)
                InputType afterPreproc = layerInput;
                if (lv.getPreProcessor() != null) {
                    InputPreProcessor ip = lv.getPreProcessor();
                    afterPreproc = ip.getOutputType(layerInput);
                }
                l.setNIn(afterPreproc, false);

                currLayerIdx++;
            } else {
                List<String> inputs = vertexInputs.get(s);
                if (inputs != null) {
                    for (String inputVertexName : inputs) {
                        inputTypeList.add(vertexOutputs.get(inputVertexName));
                    }
                }
            }
            InputType outputFromVertex =
                            gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
            vertexOutputs.put(s, outputFromVertex);
        }
    }

    @Data
    public static class GraphBuilder {
        protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();

        /**
         * Key: graph node. Values: input to that node
         */
        protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();

        protected List<String> networkInputs = new ArrayList<>();
        protected List<InputType> networkInputTypes = new ArrayList<>();
        protected List<String> networkOutputs = new ArrayList<>();

        protected boolean pretrain = false;
        protected boolean backprop = true;
        protected BackpropType backpropType = BackpropType.Standard;
        protected int tbpttFwdLength = 20;
        protected int tbpttBackLength = 20;

        protected Map<String, InputPreProcessor> inputPreProcessors = new LinkedHashMap<>();

        protected NeuralNetConfiguration.Builder globalConfiguration;

        public GraphBuilder(NeuralNetConfiguration.Builder globalConfiguration) {
            this.globalConfiguration = globalConfiguration;
        }

        /**
         * @deprecated As of 0.6.0
         */
        @Deprecated
        public GraphBuilder redistributeParams(boolean redistributeParams) {
            return this;
        }

        /**
         * Specify the processors for a given layer.
         * These are used at each layer for doing things like normalization and shaping of input.
         * Note: preprocessors can also be defined using the {@link #addLayer(String, Layer, InputPreProcessor, String...)} method.
         *
         * @param layer     the name of the layer that this preprocessor will be used with
         * @param processor the preprocessor to use for the specified layer
         */
        public GraphBuilder inputPreProcessor(String layer, InputPreProcessor processor) {
            inputPreProcessors.put(layer, processor);
            return this;
        }

        /**
         * Whether to do back prop (standard supervised learning) or not
         *
         * @param backprop whether to do back prop or not
         */
        public GraphBuilder backprop(boolean backprop) {
            this.backprop = backprop;
            return this;
        }

        /**
         * Whether to do layerwise pre training or not
         *
         * @param pretrain whether to do pre train or not
         */
        public GraphBuilder pretrain(boolean pretrain) {
            this.pretrain = pretrain;
            return this;
        }

        /**
         * The type of backprop. Default setting is used for most networks (MLP, CNN etc),
         * but optionally truncated BPTT can be used for training recurrent neural networks.
         * If using TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength()
         *
         * @param type Type of backprop. Default: BackpropType.Standard
         */
        public GraphBuilder backpropType(BackpropType type) {
            this.backpropType = type;
            return this;
        }
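        /*
         * Illustrative sketch (not part of the original source): configuring truncated BPTT for
         * a recurrent network; the lengths shown are hypothetical. As documented above, both
         * lengths should be set when using TruncatedBPTT:
         *
         *   .backpropType(BackpropType.TruncatedBPTT)
         *   .tBPTTForwardLength(100)
         *   .tBPTTBackwardLength(100)
         */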
        /**
         * When doing truncated BPTT: how many steps of forward pass should we do
         * before doing (truncated) backprop?
         * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
         * Typically the tBPTTForwardLength parameter is the same as the tBPTTBackwardLength parameter,
         * but may be larger than it in some circumstances (but never smaller)
         * Ideally your training data time series length should be divisible by this.
         * This is the k1 parameter on pg23 of
         * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
         *
         * @param forwardLength Forward length > 0, >= backwardLength
         */
        public GraphBuilder tBPTTForwardLength(int forwardLength) {
            this.tbpttFwdLength = forwardLength;
            return this;
        }

        /**
         * When doing truncated BPTT: how many steps of backward should we do?
         * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
         * This is the k2 parameter on pg23 of
         * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
         *
         * @param backwardLength <= forwardLength
         */
        public GraphBuilder tBPTTBackwardLength(int backwardLength) {
            this.tbpttBackLength = backwardLength;
            return this;
        }

        /**
         * Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
         *
         * @param layerName   Name/label of the layer to add
         * @param layer       The layer configuration
         * @param layerInputs Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
         *                    or a combination of the two.
         * @see #addLayer(String, Layer, InputPreProcessor, String...)
         */
        public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) {
            return addLayer(layerName, layer, null, layerInputs);
        }

        /**
         * Add a layer and an {@link InputPreProcessor}, with the specified name and specified inputs.
         *
         * @param layerName    Name/label of the layer to add
         * @param layer        The layer configuration
         * @param preProcessor The InputPreProcessor to use with this layer.
         * @param layerInputs  Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
         *                     or a combination of the two.
         */
        public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor, String... layerInputs) {
            NeuralNetConfiguration.Builder builder = globalConfiguration.clone();
            builder.layer(layer);
            vertices.put(layerName, new LayerVertex(builder.build(), preProcessor));

            //Automatically insert a MergeNode if layerInputs.length > 1
            //Layers can only have 1 input
            if (layerInputs != null && layerInputs.length > 1) {
                String mergeName = layerName + "-merge";
                addVertex(mergeName, new MergeVertex(), layerInputs);
                this.vertexInputs.put(layerName, Collections.singletonList(mergeName));
            } else if (layerInputs != null) {
                this.vertexInputs.put(layerName, Arrays.asList(layerInputs));
            }
            layer.setLayerName(layerName);
            return this;
        }
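        /*
         * Illustrative sketch (not part of the original source): a layer can have only one
         * input, so passing two input names here (hypothetical) makes addLayer insert a
         * MergeVertex named "L1-merge" automatically:
         *
         *   .addLayer("L1", new DenseLayer.Builder().nIn(8).nOut(4).build(), "inA", "inB")
         */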
        /**
         * Specify the inputs to the network, and their associated labels.
         *
         * @param inputNames The names of the inputs. This also defines their order
         */
        public GraphBuilder addInputs(String... inputNames) {
            Collections.addAll(networkInputs, inputNames);
            return this;
        }

        /**
         * Specify the types of inputs to the network, so that:
         * (a) preprocessors can be automatically added, and
         * (b) the nIns (input size) for each layer can be automatically calculated and set
         * The order here is the same order as .addInputs(). Thus, if you do .addInputs("a","b") and
         * .setInputTypes(InputType.feedForward(), InputType.convolutional(1,28,28)) then the input labelled "a" is
         * a feed forward input, whereas the input labelled "b" is a CNN input, with 28x28x1 images as input.
         * Note: Using setInputTypes is not always necessary, but can be especially helpful for example with CNNs, such
         * that the calculations on input/output sizes (width, height, depth, etc) don't need to be done manually.
         * Note 2: If a preprocessor is manually added for a given layer, it will not be overridden by the automatic
         * addition of preprocessors.
         * Note 3: If a layer has an nIn set manually, this will not be overridden
         */
        public GraphBuilder setInputTypes(InputType... inputTypes) {
            if (inputTypes != null && inputTypes.length > 0)
                Collections.addAll(networkInputTypes, inputTypes);
            return this;
        }

        /**
         * Set the network output labels. These should be the names of the OutputLayer instances in the network
         *
         * @param outputNames The names of the output layers. This also defines their order.
         */
        public GraphBuilder setOutputs(String... outputNames) {
            Collections.addAll(networkOutputs, outputNames);
            return this;
        }
        /**
         * Add a {@link GraphVertex} to the network configuration. A GraphVertex defines forward and backward pass
         * methods, and may for example be a {@link LayerVertex}, an
         * {@link org.deeplearning4j.nn.conf.graph.ElementWiseVertex} to do element-wise addition/subtraction, a
         * {@link MergeVertex} to combine/concatenate the activations out of multiple layers or vertices, or a
         * {@link org.deeplearning4j.nn.conf.graph.SubsetVertex} to select a subset of the activations out of another
         * layer/GraphVertex.
         * Custom GraphVertex objects (that extend the abstract {@link GraphVertex} class) may also be used.
         *
         * @param vertexName   The name of the GraphVertex to add
         * @param vertex       The GraphVertex to add
         * @param vertexInputs The inputs/activations to this GraphVertex
         */
        public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {
            vertices.put(vertexName, vertex);
            this.vertexInputs.put(vertexName, Arrays.asList(vertexInputs));
            return this;
        }

        /**
         * Create the ComputationGraphConfiguration from the Builder pattern
         */
        public ComputationGraphConfiguration build() {
            ComputationGraphConfiguration conf = new ComputationGraphConfiguration();
            conf.backprop = backprop;
            conf.pretrain = pretrain;
            conf.backpropType = backpropType;
            conf.tbpttBackLength = tbpttBackLength;
            conf.tbpttFwdLength = tbpttFwdLength;

            conf.networkInputs = networkInputs;
            conf.networkOutputs = networkOutputs;

            conf.vertices = this.vertices;
            conf.vertexInputs = this.vertexInputs;

            conf.defaultConfiguration = globalConfiguration.build();
            conf.getDefaultConfiguration().setPretrain(pretrain);

            //Add preprocessors that were defined separately to the Layers to which they belong
            for (Map.Entry<String, InputPreProcessor> entry : inputPreProcessors.entrySet()) {
                GraphVertex gv = vertices.get(entry.getKey());
                if (gv instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) gv;
                    lv.setPreProcessor(entry.getValue());
                } else {
                    throw new IllegalStateException("Invalid configuration: InputPreProcessor defined for GraphVertex \""
                                    + entry.getKey() + "\", but this vertex is not a LayerVertex");
                }
            }

            for (Map.Entry<String, GraphVertex> gv : vertices.entrySet()) {
                if (gv.getValue() instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) gv.getValue();
                    Layer l = lv.getLayerConf().getLayer();
                    if (l instanceof BasePretrainNetwork)
                        lv.getLayerConf().setPretrain(pretrain);
                }
            }

            conf.validate(); //throws exception for invalid configuration

            //Automatically add preprocessors, set nIns for CNN->dense transitions, etc
            if (!networkInputTypes.isEmpty()) {
                conf.addPreProcessors(networkInputTypes.toArray(new InputType[networkInputTypes.size()]));
            }

            return conf;
        }
    }
}
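
/*
 * Illustrative end-to-end sketch (not part of the original source): a two-input graph with an
 * explicit MergeVertex, mirroring the addVertex javadoc above. All names and sizes are
 * hypothetical:
 *
 *   ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
 *           .graphBuilder()
 *           .addInputs("inA", "inB")
 *           .addLayer("denseA", new DenseLayer.Builder().nIn(5).nOut(8).build(), "inA")
 *           .addLayer("denseB", new DenseLayer.Builder().nIn(5).nOut(8).build(), "inB")
 *           .addVertex("merge", new MergeVertex(), "denseA", "denseB")
 *           .addLayer("out", new OutputLayer.Builder().nIn(16).nOut(3).build(), "merge")   //16 = 8 + 8 concatenated
 *           .setOutputs("out")
 *           .build();
 */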



