/*-
 *
 *  * Copyright 2016 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.nn.graph;

import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.eval.*;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.OneTimeLogger;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A ComputationGraph network is a neural network with arbitrary (directed acyclic graph) connection structure.
 * A ComputationGraph may also have an arbitrary number of inputs and outputs.
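 * <p>
 * A minimal usage sketch (the vertex names, layer types and sizes below are illustrative only, not
 * part of this class):
 * <pre>{@code
 * ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
 *     .graphBuilder()
 *     .addInputs("in")
 *     .addLayer("dense", new DenseLayer.Builder().nIn(10).nOut(5).build(), "in")
 *     .addLayer("out", new OutputLayer.Builder().nIn(5).nOut(3).build(), "dense")
 *     .setOutputs("out")
 *     .build();
 * ComputationGraph net = new ComputationGraph(conf);
 * net.init();
 * }</pre>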
 *
 * @author Alex Black
 */
public class ComputationGraph implements Serializable, Model, NeuralNetwork {

    private static final Logger log = LoggerFactory.getLogger(ComputationGraph.class);

    protected ComputationGraphConfiguration configuration;
    protected boolean initCalled = false;
    protected transient Solver solver; //Used to call optimizers during backprop
    protected INDArray flattenedParams; //Params for all layers are a view/subset of this array
    @Getter
    protected transient INDArray flattenedGradients; //Gradients for all layers are a view/subset of this array
    protected Gradient gradient;
    protected double score;
    @Setter
    private boolean initDone = false;

    public final static String workspaceCache = "LOOP_CACHE";
    public final static String workspaceExternal = "LOOP_EXTERNAL";
    public final static String workspaceFeedForward = "LOOP_FF";
    public final static String workspacePretrain = "LOOP_PTR";
    public final static String workspaceTBPTT = "LOOP_TBPTT";
    public final static String workspaceLSTM = "LOOP_LSTM";

    public final static WorkspaceConfiguration workspaceConfigurationFeedForward = WorkspaceConfiguration.builder()
                    .initialSize(0).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT)
                    .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
                    .policyLearning(LearningPolicy.OVER_TIME).build();

    public final static WorkspaceConfiguration workspaceConfigurationTBPTT = WorkspaceConfiguration.builder()
                    .initialSize(0).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT)
                    .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
                    .policyLearning(LearningPolicy.OVER_TIME).build();

    public final static WorkspaceConfiguration workspaceConfigurationLSTM = WorkspaceConfiguration.builder()
                    .initialSize(0).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT)
                    .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
                    .policyLearning(LearningPolicy.FIRST_LOOP).build();

    public final static WorkspaceConfiguration workspaceConfigurationExternal = WorkspaceConfiguration.builder()
                    .overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).cyclesBeforeInitialization(3)
                    .policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();

    public final static WorkspaceConfiguration workspaceConfigurationCache = WorkspaceConfiguration.builder()
                    .overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).cyclesBeforeInitialization(3)
                    .policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.REALLOCATE)
                    .policyLearning(LearningPolicy.OVER_TIME).build();

    protected ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();

    /**
     * All GraphVertex objects in the network.
     */
    protected GraphVertex[] vertices;
    /**
     * Map of vertices by name
     */
    protected Map<String, GraphVertex> verticesMap;
    /**
     * Indexes of graph vertices, in topological order. The topological order defines the order in which forward pass
     * (and hence also backward pass, which is the opposite to this) is conducted in the network.
     */
    protected int[] topologicalOrder;
    /**
     * A list of layers. Each of these layers is present in a GraphVertex, but are here for easy reference.
     * This array also defines the order in which the getLayer(int) method returns layers.
     */
    protected Layer[] layers;

    /**
     * The number of input arrays to the network. Many networks only have 1 input; however, a ComputationGraph may
     * have an arbitrary number (>=1) of separate input arrays
     */
    private int numInputArrays;
    /**
     * The number of output arrays to the network. Many networks only have 1 output; however, a ComputationGraph may
     * have an arbitrary number (>=1) of separate output arrays
     */
    private int numOutputArrays;

    //Current inputs, labels, input mask arrays and label mask arrays
    private transient INDArray[] inputs;
    private transient INDArray[] labels;
    private transient INDArray[] inputMaskArrays;
    private transient INDArray[] labelMaskArrays;

    private NeuralNetConfiguration defaultConfiguration;
    private Collection<IterationListener> listeners = new ArrayList<>();
    private Collection<TrainingListener> trainingListeners = new ArrayList<>();


    public ComputationGraph(ComputationGraphConfiguration configuration) {
        this.configuration = configuration;
        this.numInputArrays = configuration.getNetworkInputs().size();
        this.numOutputArrays = configuration.getNetworkOutputs().size();
        this.inputs = new INDArray[numInputArrays];
        this.labels = new INDArray[numOutputArrays];
        this.defaultConfiguration = configuration.getDefaultConfiguration();
    }

    /**
     * This method allows you to set the ETL time field, which is useful for performance tracking
     * @param time the ETL time to record
     */
    public void setLastEtlTime(long time) {
        lastEtlTime.set(time);
    }

    /**
     * This method returns the ETL time field value
     * @return the last recorded ETL time, or 0 if none has been set
     */
    public long getLastEtlTime() {
        Long time = lastEtlTime.get();
        return time == null ? 0L : time;
    }

    /**
     * This method sets the specified CacheMode for all layers within the network
     *
     * @param mode the cache mode to use; null is treated as {@link CacheMode#NONE}
     */
    public void setCacheMode(CacheMode mode) {
        if (mode == null)
            mode = CacheMode.NONE;

        for (Layer layer : layers) {
            layer.setCacheMode(mode);
        }
    }

    /**
     * This method returns the configuration of this ComputationGraph
     * @return the network configuration
     */
    public ComputationGraphConfiguration getConfiguration() {
        return configuration;
    }

    /**
     * Returns the number of layers in the ComputationGraph
     */
    public int getNumLayers() {
        return (layers != null ? layers.length : 0);
    }

    /**
     * Get the layer by the number of that layer, in range 0 to getNumLayers()-1
     * NOTE: This is different from the internal GraphVertex index for the layer
     */
    public Layer getLayer(int idx) {
        return layers[idx];
    }

    /**
     * Get all layers in the ComputationGraph
     */
    public Layer[] getLayers() {
        return layers;
    }

    /**
     * Get a given layer by name.
     */
    public Layer getLayer(String name) {
        return verticesMap.get(name).getLayer(); //TODO checks
    }

    /**
     * Returns an array of all GraphVertex objects.
     */
    public GraphVertex[] getVertices() {
        return vertices;
    }

    /**
     * Return a given GraphVertex by name, or null if no vertex with that name exists
     */
    public GraphVertex getVertex(String name) {
        return verticesMap.get(name);
    }

    /**
     * The number of inputs to this network
     */
    public int getNumInputArrays() {
        return numInputArrays;
    }

    /**
     * The number of output (arrays) for this network
     */
    public int getNumOutputArrays() {
        return numOutputArrays;
    }

    /**
     * Set the specified input for the ComputationGraph
     */
    public void setInput(int inputNum, INDArray input) {
        if (inputs == null) {
            //May be null after clear()
            inputs = new INDArray[numInputArrays];
        }
        inputs[inputNum] = input;
    }

    /**
     * Set all inputs for the ComputationGraph network
     */
    public void setInputs(INDArray... inputs) {
        if (inputs != null && inputs.length != this.numInputArrays) {
            throw new IllegalArgumentException("Invalid input array: network has " + numInputArrays
                            + " inputs, but array is of length " + inputs.length);
        }
        this.inputs = inputs;
    }

    /**
     * Get the previously set input for the ComputationGraph
     */
    public INDArray getInput(int inputNum) {
        if (inputs == null)
            return null;
        return inputs[inputNum];
    }

    /**
     * Get the previously set inputs for the ComputationGraph
     */
    public INDArray[] getInputs() {
        return inputs;
    }

    /**
     * Get the previously set feature/input mask arrays for the ComputationGraph
     */
    public INDArray[] getInputMaskArrays() {
        return inputMaskArrays;
    }

    /**
     * Get the previously set label/output mask arrays for the ComputationGraph
     */
    public INDArray[] getLabelMaskArrays() {
        return labelMaskArrays;
    }

    /**
     * Set the specified label for the ComputationGraph
     */
    public void setLabel(int labelNum, INDArray label) {
        labels[labelNum] = label;
    }

    /**
     * Set all labels for the ComputationGraph network
     */
    public void setLabels(INDArray... labels) {
        if (labels != null && labels.length != this.numOutputArrays) {
            throw new IllegalArgumentException("Invalid output array: network has " + numOutputArrays
                            + " outputs, but array is of length " + labels.length);
        }
        this.labels = labels;
    }

    /**
     * This method allows you to specify a GradientsAccumulator instance to be used with this model
     *
     * PLEASE NOTE: Do not use this method unless you understand how to use GradientsAccumulator & updates sharing.
     * PLEASE NOTE: Do not use this method on a standalone model
     *
     * @param accumulator the gradients accumulator to use
     */
    public void setGradientsAccumulator(GradientsAccumulator accumulator) {
        if (!initCalled)
            init();

        solver.getOptimizer().setGradientsAccumulator(accumulator);
    }

    /**
     * Initialize the ComputationGraph network
     */
    public void init() {
        init(null, false);
    }

    /**
     * Initialize the ComputationGraph, optionally with an existing parameters array.
     * If an existing parameters array is specified, it will be used (and the values will not be modified) in the network;
     * if no parameters array is specified, parameters will be initialized randomly according to the network configuration.
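     *
     * <p>For example, to initialize this network with (a copy of) the parameters of another network
     * (a sketch; {@code pretrained} is assumed to be a ComputationGraph with an identical configuration):
     * <pre>{@code
     * net.init(pretrained.params(), true);
     * }</pre>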
     *
     * @param parameters           Network parameters. May be null. If null: parameters are initialized randomly.
     * @param cloneParametersArray Whether the parameter array (if any) should be cloned, or used directly
     */
    public void init(INDArray parameters, boolean cloneParametersArray) {
        if (initCalled)
            return;

        OneTimeLogger.info(log, "Starting ComputationGraph with WorkspaceModes set to [training: {}; inference: {}]",
                        configuration.getTrainingWorkspaceMode(), configuration.getInferenceWorkspaceMode());

        if (configuration.getCacheMode() == CacheMode.HOST) {
            workspaceConfigurationCache.setPolicyMirroring(MirroringPolicy.HOST_ONLY);
        }

        //First: build topological ordering, based on configuration. Used for forward pass, backprop and order of parameters/gradients
        topologicalOrder = topologicalSortOrder();

        //Initialization: create the GraphVertex objects, based on configuration structure
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();

        //Names of all of the (data) inputs to the ComputationGraph
        List<String> networkInputNames = configuration.getNetworkInputs();

        //Inputs for each layer and GraphNode:
        Map<String, List<String>> vertexInputs = configuration.getVertexInputs();
        this.vertices = new GraphVertex[networkInputNames.size() + configuration.getVertices().size()];

        //All names: inputs, layers and graph nodes (index to name map)
        Map<String, Integer> allNamesReverse = new HashMap<>();

        //Create network input vertices:
        int vertexNumber = 0;
        for (String name : networkInputNames) {
            GraphVertex gv = new InputVertex(this, name, vertexNumber, null); //Output vertices: set later
            allNamesReverse.put(name, vertexNumber);
            vertices[vertexNumber++] = gv;
        }

        //Go through layers, and work out total number of parameters. Then allocate full parameters array
        int numParams = 0;
        int[] numParamsForVertex = new int[topologicalOrder.length];
        int i = 0;
        for (; i < configuration.getNetworkInputs().size(); i++) {
            numParamsForVertex[i] = 0; //No parameters for input vertices
        }
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
            org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
            numParamsForVertex[i] = n.numParams(true);
            numParams += numParamsForVertex[i];
            i++;
        }

        boolean initializeParams;
        if (parameters != null) {
            if (!parameters.isRowVector())
                throw new IllegalArgumentException("Invalid parameters: should be a row vector");
            if (parameters.length() != numParams)
                throw new IllegalArgumentException("Invalid parameters: expected length " + numParams + ", got length "
                                + parameters.length());

            if (cloneParametersArray)
                flattenedParams = parameters.dup();
            else
                flattenedParams = parameters;

            initializeParams = false;
        } else {
            flattenedParams = Nd4j.create(1, numParams);
            initializeParams = true;
        }

        //Set RNG seed, for repeatability between initializations when set
        if (initializeParams) {
            Nd4j.getRandom().setSeed(conf().getSeed());
        }

        //Given the topological ordering: work out the subset of the parameters array used for each layer
        // Then extract out for use when initializing the Layers
        INDArray[] paramsViewForVertex = new INDArray[topologicalOrder.length];
        int paramOffsetSoFar = 0;
        i = 0;
        for (int vertexIdx : topologicalOrder) {
            int nParamsThisVertex = numParamsForVertex[vertexIdx];
            if (nParamsThisVertex != 0) {
                paramsViewForVertex[vertexIdx] = flattenedParams.get(NDArrayIndex.point(0),
                                NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex));
            }
            i++;
            paramOffsetSoFar += nParamsThisVertex;
        }


        int numLayers = 0;
        List<Layer> tempLayerList = new ArrayList<>();
        defaultConfiguration.clearVariables();
        List<String> variables = defaultConfiguration.variables(false);
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
            org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
            String name = nodeEntry.getKey();
            GraphVertex gv = n.instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber],
                            initializeParams);

            if (gv.hasLayer()) {
                numLayers++;
                Layer l = gv.getLayer();
                tempLayerList.add(l);
                List<String> layerVariables = l.conf().variables();
                if (layerVariables != null) {
                    for (String s : layerVariables) {
                        variables.add(gv.getVertexName() + "_" + s);
                    }
                }
            }

            allNamesReverse.put(name, vertexNumber);
            vertices[vertexNumber++] = gv;
        }
        layers = tempLayerList.toArray(new Layer[numLayers]);


        //Create the lookup table, so we can find vertices easily by name
        verticesMap = new HashMap<>();
        for (GraphVertex gv : vertices) {
            verticesMap.put(gv.getVertexName(), gv);
        }

        //Now: do another pass to set the input and output indices, for each vertex
        // These indices are used during forward and backward passes
        //To get output indices: need to essentially build the graph in reverse...
        Map<String, List<String>> verticesOutputTo = new HashMap<>(); //Key: vertex. Values: vertices that this node is an input for
        for (GraphVertex gv : vertices) {
            String vertexName = gv.getVertexName();
            List<String> vertexInputNames = vertexInputs.get(vertexName);

            if (vertexInputNames == null)
                continue;

            //Build reverse network structure:
            for (String s : vertexInputNames) {
                List<String> list = verticesOutputTo.get(s);
                if (list == null) {
                    list = new ArrayList<>();
                    verticesOutputTo.put(s, list);
                }
                list.add(vertexName); //Edge: s -> vertexName
            }
        }


        for (GraphVertex gv : vertices) {
            String vertexName = gv.getVertexName();
            int vertexIndex = gv.getVertexIndex();
            List<String> vertexInputNames = vertexInputs.get(vertexName);

            if (vertexInputNames == null)
                continue;

            VertexIndices[] inputIndices = new VertexIndices[vertexInputNames.size()];
            for (int j = 0; j < vertexInputNames.size(); j++) {
                String inName = vertexInputNames.get(j);
                int inputVertexIndex = allNamesReverse.get(inName);

                //Output of vertex 'inputVertexIndex' is the jth input to the current vertex
                //For input indices, we need to know which output connection of vertex 'inputVertexIndex' this represents
                GraphVertex inputVertex = vertices[inputVertexIndex];
                //First: get the outputs of the input vertex...
                List<String> inputVertexOutputsTo = verticesOutputTo.get(inName);
                int outputNumberOfInput = inputVertexOutputsTo.indexOf(vertexName);


                if (outputNumberOfInput == -1)
                    throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of outputs "
                                    + "for vertex " + inputVertex + "; error in graph structure?");
                //Overall here: the 'outputNumberOfInput'th output of vertex 'inputVertexIndex' is the jth input to the current vertex

                inputIndices[j] = new VertexIndices(inputVertexIndex, outputNumberOfInput);
            }

            gv.setInputVertices(inputIndices);
        }

        //Handle the outputs for this vertex
        for (GraphVertex gv : vertices) {
            String vertexName = gv.getVertexName();

            List<String> thisVertexOutputsTo = verticesOutputTo.get(vertexName);

            if (thisVertexOutputsTo == null || thisVertexOutputsTo.isEmpty())
                continue; //Output vertex
            VertexIndices[] outputIndices = new VertexIndices[thisVertexOutputsTo.size()];
            int j = 0;
            for (String s : thisVertexOutputsTo) {
                //First, we have gv -> s
                //Which input in s does gv connect to? s may in general have multiple inputs...
                List<String> nextVertexInputNames = vertexInputs.get(s);

                int outputVertexInputNumber = nextVertexInputNames.indexOf(vertexName);

                int outputVertexIndex = allNamesReverse.get(s);
                outputIndices[j++] = new VertexIndices(outputVertexIndex, outputVertexInputNumber);
            }
            gv.setOutputVertices(outputIndices);
        }

        // now we init solver & optimizer
        if (solver == null) {
            try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                solver.initOptimizer();
            }
        }

        initCalled = true;
    }

    /**
     * This method: initializes the flattened gradients array (used in backprop) and sets the appropriate subset in all layers.
     * As a general rule, this shouldn't ever need to be called manually when doing training via fit(DataSet), fit(DataSetIterator)
     * or fit(MultiDataSet) methods
     */
    public void initGradientsView() {
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            if (!initCalled)
                init();

            //Go through layers, and work out total number of parameters. Then allocate full parameters array
            int numParams = 0;
            int[] numParamsForVertex = new int[topologicalOrder.length];
            int i = 0;
            for (; i < configuration.getNetworkInputs().size(); i++) {
                numParamsForVertex[i] = 0; //No parameters for input vertices
            }
            Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();
            for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap
                            .entrySet()) {
                org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
                numParamsForVertex[i] = n.numParams(true);
                numParams += numParamsForVertex[i];
                i++;
            }
            flattenedGradients = Nd4j.create(1, numParams);

            //Given the topological ordering: work out the subset of the gradient array used for each layer, and set it
            int paramOffsetSoFar = 0;
            i = 0;
            for (int vertexIdx : topologicalOrder) {
                int nParamsThisVertex = numParamsForVertex[vertexIdx];
                if (nParamsThisVertex != 0) {
                    INDArray gradientView = flattenedGradients.get(NDArrayIndex.point(0),
                                    NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex));
                    vertices[vertexIdx].setBackpropGradientsViewArray(gradientView);
                }
                i++;
                paramOffsetSoFar += nParamsThisVertex;
            }
        }
    }

    /**
     * Pretrain network with a single input and single output. DataSetIterators can only be used if the number of input
     * arrays for the ComputationGraph is 1.
     * For networks with more than one input use {@link #pretrain(MultiDataSetIterator)}
     */
    public void pretrain(DataSetIterator iter) {
        if (numInputArrays != 1) {
            throw new UnsupportedOperationException(
                            "Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
        }

        pretrain(ComputationGraphUtil.toMultiDataSetIterator(iter));
    }

    /**
     * Pretrain network with multiple inputs and/or outputs
     */
    public void pretrain(MultiDataSetIterator iter) {
        if (!configuration.isPretrain())
            return;
        if (flattenedGradients == null) {
            initGradientsView();
        }

        //Assume here that all layers are pretrainable layers
        for (int i = 0; i < topologicalOrder.length; i++) {
            if (!vertices[i].hasLayer())
                continue;
            if (vertices[i].getLayer() instanceof IOutputLayer)
                continue; //Don't pretrain output layer
            if (!vertices[i].getLayer().isPretrainLayer())
                continue; //Skip layers that aren't pretrainable

            pretrainLayer(vertices[i].getVertexName(), iter);
        }
    }

    /**
     * Pretrain a specified layer with the given DataSetIterator
     *
     * @param layerName       Layer name
     * @param dataSetIterator Data
     */
    public void pretrainLayer(String layerName, DataSetIterator dataSetIterator) {
        if (numInputArrays != 1) {
            throw new UnsupportedOperationException(
                            "Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
        }

        pretrainLayer(layerName, ComputationGraphUtil.toMultiDataSetIterator(dataSetIterator));
    }

    /**
     * Pretrain a specified layer with the given MultiDataSetIterator
     *
     * @param layerName       Layer name
     * @param iter Training data
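     *
     * <p>For example (a sketch; "vae" stands in for the name of a pretrainable layer in your graph,
     * and {@code trainIter} for your training data iterator):
     * <pre>{@code
     * net.pretrainLayer("vae", trainIter);
     * }</pre>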
     */
    public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
        if (!configuration.isPretrain())
            return;
        if (flattenedGradients == null) {
            initGradientsView();
        }

        if (!verticesMap.containsKey(layerName)) {
            throw new IllegalStateException("Invalid vertex name: " + layerName);
        }
        if (!verticesMap.get(layerName).hasLayer()) {
            //No op
            return;
        }

        int layerIndex = verticesMap.get(layerName).getVertexIndex();

        //Need to do a partial forward pass. Simply following the topological ordering won't be efficient, as we might
        // end up doing a forward pass on layers we don't need to.
        //However, we can start with the topological order, and prune out any layers we don't need

        LinkedList<Integer> partialTopoSort = new LinkedList<>();
        Set<Integer> seenSoFar = new HashSet<>();
        partialTopoSort.add(topologicalOrder[layerIndex]);
        seenSoFar.add(topologicalOrder[layerIndex]);
        for (int j = layerIndex - 1; j >= 0; j--) {
            //Do we need to do forward pass on this GraphVertex?
            //If it is input to any other layer we need, then yes. Otherwise: no
            VertexIndices[] outputsTo = vertices[topologicalOrder[j]].getOutputVertices();
            boolean needed = false;
            for (VertexIndices vi : outputsTo) {
                if (seenSoFar.contains(vi.getVertexIndex())) {
                    needed = true;
                    break;
                }
            }
            if (needed) {
                partialTopoSort.addFirst(topologicalOrder[j]);
                seenSoFar.add(topologicalOrder[j]);
            }
        }

        int[] fwdPassOrder = new int[partialTopoSort.size()];
        int k = 0;
        for (Integer g : partialTopoSort)
            fwdPassOrder[k++] = g;

        GraphVertex gv = vertices[fwdPassOrder[fwdPassOrder.length - 1]];
        Layer layer = gv.getLayer();

        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }

        MemoryWorkspace workspace =
                configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                        ComputationGraph.workspaceConfigurationExternal, ComputationGraph.workspaceExternal);
        MemoryWorkspace cache = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceConfigurationCache, ComputationGraph.workspaceCache);

        MemoryWorkspace wsFF = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                ? new DummyWorkspace()
                : configuration.getTrainingWorkspaceMode() == WorkspaceMode.SINGLE
                ? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal)
                : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                workspaceConfigurationFeedForward, workspaceFeedForward);

        MemoryWorkspace wsPTR = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                ? new DummyWorkspace()
                : configuration.getTrainingWorkspaceMode() == WorkspaceMode.SINGLE
                ? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal)
                : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                workspaceConfigurationFeedForward, workspacePretrain);

        while (iter.hasNext()) {
            MultiDataSet multiDataSet = iter.next();

            try (MemoryWorkspace wsCache = cache.notifyScopeEntered()) {
                try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                    try (MemoryWorkspace wP = wsPTR.notifyScopeEntered()) {

                        setInputs(multiDataSet.getFeatures());

                        for (int j = 0; j < fwdPassOrder.length - 1; j++) {
                            try (MemoryWorkspace wF = wsFF.notifyScopeEntered()) {
                                GraphVertex current = vertices[fwdPassOrder[j]];
                                if (current.isInputVertex()) {
                                    VertexIndices[] inputsTo = current.getOutputVertices();
                                    INDArray input = inputs[current.getVertexIndex()];

                                    for (VertexIndices v : inputsTo) {
                                        int vIdx = v.getVertexIndex();
                                        int vIdxInputNum = v.getVertexEdgeNumber();
                                        //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                                        vertices[vIdx].setInput(vIdxInputNum, input.dup().leverageTo(workspacePretrain)); //TODO When to dup?
                                    }

                                } else {
                                    //Do forward pass:
                                    INDArray out = current.doForward(true);

                                    //Now, set the inputs for the next vertices:
                                    VertexIndices[] outputsTo = current.getOutputVertices();
                                    if (outputsTo != null) {
                                        for (VertexIndices v : outputsTo) {
                                            int vIdx = v.getVertexIndex();
                                            int inputNum = v.getVertexEdgeNumber();
                                            //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                                            vertices[vIdx].setInput(inputNum, out);
                                        }
                                    }
                                }
                            }
                        }
                        //At this point: have done all of the required forward pass stuff. Can now pretrain layer on current input

                        layer.fit(gv.getInputs()[0]);
                        layer.conf().setPretrain(false);
                    }
                }
            }
        }
    }

    /**
     * Fit the ComputationGraph using a DataSet.
     * Note that this method can only be used with ComputationGraphs with 1 input and 1 output.
     * For networks with more than one input or output, use {@link #fit(MultiDataSetIterator)}
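     *
     * <p>For example (a sketch; {@code features} and {@code labels} are assumed to be appropriately
     * shaped INDArrays):
     * <pre>{@code
     * net.fit(new org.nd4j.linalg.dataset.DataSet(features, labels));
     * }</pre>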
     */
    public void fit(DataSet dataSet) {
        if (numInputArrays != 1 || numOutputArrays != 1)
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with "
                            + " multiple inputs or outputs using a DataSet");

        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            INDArray[] fMask = (dataSet.getFeaturesMaskArray() != null ? new INDArray[] {dataSet.getFeaturesMaskArray()}
                            : null);
            INDArray[] lMask = (dataSet.getLabelsMaskArray() != null ? new INDArray[] {dataSet.getLabelsMaskArray()}
                            : null);
            fit(new INDArray[] {dataSet.getFeatures()}, new INDArray[] {dataSet.getLabels()}, fMask, lMask);
        } else {
            fit(new INDArray[] {dataSet.getFeatures()}, new INDArray[] {dataSet.getLabels()});
        }

        if (hasMaskArrays)
            clearLayerMaskArrays();

        clearLayersStates();
    }

    /**
     * Fit the ComputationGraph using a DataSetIterator.
     * Note that this method can only be used with ComputationGraphs with 1 input and 1 output
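     *
     * <p>For example, using the MNIST iterator from deeplearning4j-datasets (a sketch; assumes that
     * module is on the classpath):
     * <pre>{@code
     * DataSetIterator mnist = new MnistDataSetIterator(32, true, 12345);
     * net.fit(mnist);
     * }</pre>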
     */
    public void fit(DataSetIterator iterator) {
        if (flattenedGradients == null) {
            initGradientsView();
        }
        if (numInputArrays != 1 || numOutputArrays != 1)
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with "
                            + " multiple inputs or outputs using a DataSetIterator");

        boolean destructable = false;

        DataSetIterator dataSetIterator;
        // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
        if (iterator.asyncSupported()) {
            dataSetIterator = new AsyncDataSetIterator(iterator,
                            Math.min(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2),
                            configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE);
            destructable = true;
        } else
            dataSetIterator = iterator;

        if (trainingListeners.size() > 0) {
            for (TrainingListener tl : trainingListeners) {
                tl.onEpochStart(this);
            }
        }

        if (configuration.isPretrain()) {
            pretrain(dataSetIterator);
        }

        MemoryWorkspace workspace =
                        configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal, workspaceExternal);
        MemoryWorkspace cache = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationCache,
                                        workspaceCache);

        if (configuration.isBackprop()) {
            update(TaskUtils.buildTask(dataSetIterator));

            long time1 = System.currentTimeMillis();
            while (dataSetIterator.hasNext()) {
                DataSet next = dataSetIterator.next();
                long time2 = System.currentTimeMillis();

                lastEtlTime.set((time2 - time1));

                if (next.getFeatures() == null || next.getLabels() == null)
                    break;


                //migrate(next);

                boolean hasMaskArrays = next.hasMaskArrays();
                if (hasMaskArrays) {
                    INDArray[] fMask = (next.getFeaturesMaskArray() != null
                                    ? new INDArray[] {next.getFeaturesMaskArray()} : null);
                    INDArray[] lMask = (next.getLabelsMaskArray() != null ? new INDArray[] {next.getLabelsMaskArray()}
                                    : null);
                    setLayerMaskArrays(fMask, lMask);
                }

                if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(new INDArray[] {next.getFeatures()}, new INDArray[] {next.getLabels()},
                                    (hasMaskArrays ? new INDArray[] {next.getFeaturesMaskArray()} : null),
                                    (hasMaskArrays ? new INDArray[] {next.getLabelsMaskArray()} : null));
                } else {
                    setInput(0, next.getFeatures());
                    setLabel(0, next.getLabels());
                    if (solver == null) {
                        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            solver = new Solver.Builder().configure(defaultConfiguration) //TODO; don't like this
                                            .listeners(listeners).model(this).build();
                        }
                    }

                    try (MemoryWorkspace wsCache = cache.notifyScopeEntered()) {
                        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                            solver.optimize();
                        }
                    }
                }

                if (hasMaskArrays) {
                    clearLayerMaskArrays();
                }

                time1 = System.currentTimeMillis();
            }

            Nd4j.getMemoryManager().invokeGcOccasionally();
        }


        if (trainingListeners.size() > 0) {
            for (TrainingListener tl : trainingListeners) {
                tl.onEpochEnd(this);
            }
        }

        clearLayersStates();

        if (destructable)
            ((AsyncDataSetIterator) dataSetIterator).shutdown();
    }

    /**
     * Fit the ComputationGraph using a MultiDataSet
     */
    public void fit(MultiDataSet multiDataSet) {
        fit(multiDataSet.getFeatures(), multiDataSet.getLabels(), multiDataSet.getFeaturesMaskArrays(),
                        multiDataSet.getLabelsMaskArrays());
        if (multiDataSet.hasMaskArrays())
            clearLayerMaskArrays();
    }

    /**
     * Fit the ComputationGraph using a MultiDataSetIterator
     */
    public void fit(MultiDataSetIterator multi) {
        if (flattenedGradients == null) {
            initGradientsView();
        }

        boolean destructable = false;

        MultiDataSetIterator multiDataSetIterator;
        if (multi.asyncSupported()) {
            multiDataSetIterator = new AsyncMultiDataSetIterator(multi,
                            Math.max(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2),
                            configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE);
            destructable = true;
        } else
            multiDataSetIterator = multi;

        if (configuration.isPretrain()) {
            pretrain(multiDataSetIterator);
        }


        MemoryWorkspace workspace =
                        configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal, workspaceExternal);

        MemoryWorkspace cache = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationCache,
                                        workspaceCache);

        if (configuration.isBackprop()) {

            long time1 = System.currentTimeMillis();
            while (multiDataSetIterator.hasNext()) {
                MultiDataSet next = multiDataSetIterator.next();
                long time2 = System.currentTimeMillis();

                lastEtlTime.set((time2 - time1));

                if (next.getFeatures() == null || next.getLabels() == null)
                    break;


                //migrate(next);

                if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArrays(),
                                    next.getLabelsMaskArrays());
                } else {
                    boolean hasMaskArrays = next.hasMaskArrays();
                    if (hasMaskArrays) {
                        setLayerMaskArrays(next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                    }

                    setInputs(next.getFeatures());
                    setLabels(next.getLabels());
                    if (solver == null) {
                        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners)
                                            .model(this).build();
                        }
                    }

                    try (MemoryWorkspace wsCache = cache.notifyScopeEntered()) {
                        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                            solver.optimize();
                        }
                    }

                    if (hasMaskArrays) {
                        clearLayerMaskArrays();
                    }
                }

                Nd4j.getMemoryManager().invokeGcOccasionally();
                time1 = System.currentTimeMillis();
            }
        }

        clearLayersStates();

        if (destructable)
            ((AsyncMultiDataSetIterator) multiDataSetIterator).shutdown();
    }

    protected void migrate(MultiDataSet ds) {
        if (ds.getFeatures() != null)
            for (int i = 0; i < ds.getFeatures().length; i++)
                if (ds.getFeatures()[i] != null && ds.getFeatures()[i].isAttached())
                    ds.getFeatures()[i] = ds.getFeatures()[i].migrate();

        if (ds.getFeaturesMaskArrays() != null)
            for (int i = 0; i < ds.getFeaturesMaskArrays().length; i++)
                if (ds.getFeaturesMaskArrays()[i] != null && ds.getFeaturesMaskArrays()[i].isAttached())
                    ds.getFeaturesMaskArrays()[i] = ds.getFeaturesMaskArrays()[i].migrate();

        if (ds.getLabels() != null)
            for (int i = 0; i < ds.getLabels().length; i++)
                if (ds.getLabels()[i] != null && ds.getLabels()[i].isAttached())
                    ds.getLabels()[i] = ds.getLabels()[i].migrate();

        if (ds.getLabelsMaskArrays() != null)
            for (int i = 0; i < ds.getLabelsMaskArrays().length; i++)
                if (ds.getLabelsMaskArrays()[i] != null && ds.getLabelsMaskArrays()[i].isAttached())
                    ds.getLabelsMaskArrays()[i] = ds.getLabelsMaskArrays()[i].migrate();

    }

    protected void migrate(DataSet ds) {
        if (ds.getFeatures() != null && ds.getFeatures().isAttached())
            ds.setFeatures(ds.getFeatures().migrate());

        if (ds.getLabels() != null && ds.getLabels().isAttached())
            ds.setLabels(ds.getLabels().migrate());

        if (ds.getFeaturesMaskArray() != null && ds.getFeaturesMaskArray().isAttached())
            ds.setFeaturesMaskArray(ds.getFeaturesMaskArray().migrate());

        if (ds.getLabelsMaskArray() != null && ds.getLabelsMaskArray().isAttached())
            ds.setLabelsMaskArray(ds.getLabelsMaskArray().migrate());
    }

    /**
     * Fit the ComputationGraph given arrays of inputs and labels.
     *
     * @param inputs The network inputs
     * @param labels The labels
     */
    public void fit(INDArray[] inputs, INDArray[] labels) {
        fit(inputs, labels, null, null);
    }

    /**
     * Fit the ComputationGraph using the specified inputs and labels (and mask arrays)
     *
     * @param inputs            The network inputs (features)
     * @param labels            The network labels
     * @param featureMaskArrays Mask arrays for inputs/features. Typically used for RNN training. May be null.
     * @param labelMaskArrays   Mask arrays for the labels/outputs. Typically used for RNN training. May be null.
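     *
     * <p>For example, for a hypothetical graph with two inputs and one output (a sketch; the INDArray
     * variables are assumed to be appropriately shaped, with masks omitted):
     * <pre>{@code
     * net.fit(new INDArray[] {input1, input2}, new INDArray[] {labels}, null, null);
     * }</pre>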
     */
    public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        if (flattenedGradients == null) {
            initGradientsView();
        }

        setInputs(inputs);
        setLabels(labels);
        setLayerMaskArrays(featureMaskArrays, labelMaskArrays);
        update(TaskUtils.buildTask(inputs, labels));

        MemoryWorkspace workspace =
                configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                        workspaceConfigurationExternal, workspaceExternal);
        MemoryWorkspace cache = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationCache,
                workspaceCache);

        if (configuration.isPretrain()) {
            MultiDataSetIterator iter = new SingletonMultiDataSetIterator(new org.nd4j.linalg.dataset.MultiDataSet(inputs, labels, featureMaskArrays, labelMaskArrays));


            pretrain(iter);
        }

        if (configuration.isBackprop()) {
            if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays);
            } else {
                if (solver == null) {
                    try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                    }
                }

                try (MemoryWorkspace wsCache = cache.notifyScopeEntered()) {
                    try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                        solver.optimize();
                    }
                }
            }
        }

        if (featureMaskArrays != null || labelMaskArrays != null) {
            clearLayerMaskArrays();
        }

        clearLayersStates();
    }

    /**
     * Calculate a topological sort order for the vertices in the graph.
     * Note that this ordering is used for:
     * (a) working out the order in which to perform the forward pass,
     * (b) working out the order for backprop (i.e., the reverse of the forward order), and
     * (c) the order in which to flatten parameters (and gradients).
     *
     * Specifically, gradients/params/forward pass are executed on vertex[topologicalSortOrder[i]], for i=0..nVertices-1
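     *
     * For example, for a simple graph input -> layerA -> layerB (hypothetical vertex names), a valid
     * ordering is [input, layerA, layerB]: the forward pass visits vertices in that order, and
     * backprop visits them in reverse, i.e., [layerB, layerA, input].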
     */
    public int[] topologicalSortOrder() {
        if (topologicalOrder != null)
            return topologicalOrder;

        //https://en.wikipedia.org/wiki/Topological_sorting#Kahn.27s_algorithm
        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeMap = configuration.getVertices();
        List<String> networkInputNames = configuration.getNetworkInputs();
        int numVertices = networkInputNames.size() + configuration.getVertices().size();
        int[] out = new int[numVertices];
        int outCounter = 0;

        //First: represent the graph more usefully as a Map<Integer, Set<Integer>>, where map represents edges i -> j
        // key represents j, set is the set of i (inputs) for vertex j
        Map<Integer, String> vertexNamesMap = new HashMap<>();
        Map<String, Integer> vertexNamesMap2 = new HashMap<>();
        int i = 0;
        for (String inputName : configuration.getNetworkInputs()) {
            vertexNamesMap.put(i, inputName);
            vertexNamesMap2.put(inputName, i);
            i++;
        }
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : nodeMap.entrySet()) {
            String name = entry.getKey();
            vertexNamesMap.put(i, name);
            vertexNamesMap2.put(name, i);
            i++;
        }

        Map<Integer, Set<Integer>> inputEdges = new HashMap<>(); //key: vertex. Values: vertices that the key vertex receives input from
        Map<Integer, Set<Integer>> outputEdges = new HashMap<>(); //key: vertex. Values: vertices that the key vertex outputs to

        for (String s : configuration.getNetworkInputs()) {
            int idx = vertexNamesMap2.get(s);
            inputEdges.put(idx, null);
        }

        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> entry : nodeMap.entrySet()) {
            String thisVertexName = entry.getKey();
            int idx = vertexNamesMap2.get(thisVertexName);
            List<String> inputsToThisVertex = configuration.getVertexInputs().get(thisVertexName);

            if (inputsToThisVertex == null || inputsToThisVertex.isEmpty()) {
                inputEdges.put(idx, null);
                continue;
            }

            Set<Integer> inputSet = new HashSet<>();
            for (String s : inputsToThisVertex) {
                Integer inputIdx = vertexNamesMap2.get(s);
                if (inputIdx == null) {
                    throw new IllegalStateException("Invalid graph structure: could not find vertex \"" + s
                                    + "\" (specified as an input to vertex \"" + thisVertexName + "\")");
                }
                inputSet.add(inputIdx);
                Set<Integer> outputSetForInputIdx = outputEdges.get(inputIdx);
                if (outputSetForInputIdx == null) {
                    outputSetForInputIdx = new HashSet<>();
                    outputEdges.put(inputIdx, outputSetForInputIdx);
                }
                outputSetForInputIdx.add(idx); //input vertex outputs to the current vertex
            }

            inputEdges.put(idx, inputSet);
        }

        //Now: do topological sort
        //Set of all nodes with no incoming edges: (this would be: input vertices)
        LinkedList<Integer> noIncomingEdges = new LinkedList<>();
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> inputsFrom = entry.getValue();
            if (inputsFrom == null || inputsFrom.isEmpty()) {
                noIncomingEdges.add(entry.getKey());
            }
        }

        while (!noIncomingEdges.isEmpty()) {
            int next = noIncomingEdges.removeFirst();
            out[outCounter++] = next; //Add to sorted list

            Set<Integer> vertexOutputsTo = outputEdges.get(next);

            //Remove edges next -> vertexOutputsTo[...] from graph;
            if (vertexOutputsTo != null) {
                for (Integer v : vertexOutputsTo) {
                    Set<Integer> set = inputEdges.get(v);
                    set.remove(next);
                    if (set.isEmpty()) {
                        noIncomingEdges.add(v); //No remaining edges for vertex i -> add to list for processing
                    }
                }
            }
        }

        //If any edges remain in the graph: graph has cycles:
        for (Map.Entry<Integer, Set<Integer>> entry : inputEdges.entrySet()) {
            Set<Integer> set = entry.getValue();
            if (set == null)
                continue;
            if (!set.isEmpty())
                throw new IllegalStateException(
                                "Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle ("
                                                + "cycle includes vertex \"" + vertexNamesMap.get(entry.getKey())
                                                + "\")");
        }

        return out;
    }

    @Override
    public void computeGradientAndScore() {
        //Calculate activations (which are stored in each layer, and used in backprop)
        if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
            Map<String, INDArray> activations = rnnActivateUsingStoredState(inputs, true, true);
            if (trainingListeners.size() > 0) {
                try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener tl : trainingListeners) {
                        tl.onForwardPass(this, activations);
                    }
                }
            }
            calcBackpropGradients(true);
        } else {
            Map<String, INDArray> activations = feedForward(true, true, false, false);
            if (trainingListeners.size() > 0) {
                try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener tl : trainingListeners) {
                        tl.onForwardPass(this, activations);
                    }
                }
            }
            calcBackpropGradients(false);
        }

        //Score: sum of the scores for the various output layers...
        double l1 = calcL1();
        double l2 = calcL2();

        score = 0.0;
        for (String s : configuration.getNetworkOutputs()) {
            GraphVertex gv = verticesMap.get(s);

            score += ((IOutputLayer) gv.getLayer()).computeScore(l1, l2, true);

            //Only want to add l1/l2 once...
            l1 = 0.0;
            l2 = 0.0;
        }

        //Listeners
        if (trainingListeners.size() > 0) {
            try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                for (TrainingListener tl : trainingListeners) {
                    tl.onBackwardPass(this);
                }
            }
        }
    }

    /**
     * Conduct forward pass using a single input array. Note that this method can only be used with ComputationGraphs
     * with a single input array.
     *
     * @param input The input array
     * @param train If true: do forward pass at training time
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
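     *
     * <p>For example (a sketch; "out" stands in for the name of an output layer in your graph):
     * <pre>{@code
     * Map<String, INDArray> activations = net.feedForward(input, false);
     * INDArray outputActivations = activations.get("out");
     * }</pre>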
     */
    public Map<String, INDArray> feedForward(INDArray input, boolean train) {
        if (numInputArrays != 1)
            throw new UnsupportedOperationException("Cannot feedForward with single input for graph network with "
                            + numInputArrays + " expected inputs");
        setInput(0, input);
        return feedForward(train);
    }

    /**
     * Conduct forward pass using an array of inputs
     *
     * @param input An array of ComputationGraph inputs
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(INDArray[] input, boolean train) {
        if (numInputArrays != input.length)
            throw new UnsupportedOperationException("Cannot feedForward with " + input.length
                            + " inputs for graph network with " + numInputArrays + " expected inputs");
        for (int i = 0; i < input.length; i++)
            setInput(i, input[i]);
        return feedForward(train);
    }

    /**
     * Conduct forward pass using the stored inputs, at test time
     *
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward() {
        return feedForward(false);
    }

    /**
     * Conduct forward pass using the stored inputs
     *
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @return A map of activations for each layer (not each GraphVertex). Keys = layer name, values = layer activations
     */
    public Map<String, INDArray> feedForward(boolean train) {
        return feedForward(train, false, false, true);
    }

    public Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers) {
        return feedForward(train, excludeOutputLayers, false, true);
    }

    /**
     * @param train                            True: training time. False: test time
     * @param excludeOutputLayers              Should we exclude the output layers during forward pass? (usually: false)
     * @param includeNonLayerVertexActivations Include non-layer vertices in the output map?
     * @return Map of activations. Key: vertex name. Value: activations.
     */
    public Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers,
                    boolean includeNonLayerVertexActivations) {
        return feedForward(train, excludeOutputLayers, includeNonLayerVertexActivations, true);
    }

    /**
     * Internal feedForward implementation. Do not use this method directly unless you are sure of the
     * workspace semantics; prefer the public feedForward overloads.
     *
     * @param train                            True: training time. False: test time
     * @param excludeOutputLayers              If true: skip the forward pass through output layers
     * @param includeNonLayerVertexActivations If true: include non-layer vertices in the returned map
     * @param publicApi                        If true: detach the returned activations from any workspace
     * @return Map of activations. Key: vertex name. Value: activations.
     */
    protected Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers,
                    boolean includeNonLayerVertexActivations, boolean publicApi) {
        Map<String, INDArray> layerActivations = new HashMap<>();

        MemoryWorkspace workspace = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                        ? new DummyWorkspace()
                        : configuration.getTrainingWorkspaceMode() == WorkspaceMode.SINGLE
                                        ? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal)
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationFeedForward, workspaceFeedForward);

        //Do forward pass according to the topological ordering of the network
        for (int i = 0; i < topologicalOrder.length; i++) {
            GraphVertex current = vertices[topologicalOrder[i]];
            try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {

                if (current.isInputVertex()) {
                    VertexIndices[] inputsTo = current.getOutputVertices();
                    // pushing out copy to parent workspace
                    INDArray input = inputs[current.getVertexIndex()].leverageTo(workspaceExternal);


                    layerActivations.put(current.getVertexName(), input);

                    for (VertexIndices v : inputsTo) {
                        int vIdx = v.getVertexIndex();
                        int vIdxInputNum = v.getVertexEdgeNumber();
                        //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                        // we're pushing input copies to outer workspace
                        // FIXME: do we REALLY need this dup()?
                        if (Nd4j.getWorkspaceManager().checkIfWorkspaceExists(workspaceExternal)
                                        && Nd4j.getMemoryManager().getCurrentWorkspace() != Nd4j.getWorkspaceManager()
                                                        .getWorkspaceForCurrentThread(
                                                                        ComputationGraph.workspaceExternal)) {
                            try (MemoryWorkspace wsB = Nd4j.getWorkspaceManager()
                                            .getWorkspaceForCurrentThread(workspaceExternal).notifyScopeBorrowed()) {
                                // FIXME: we don't really want detach here
                                vertices[vIdx].setInput(vIdxInputNum, input);
                            }
                        } else {
                            vertices[vIdx].setInput(vIdxInputNum, input);
                        }
                    }

                } else {
                    //Do forward pass:
                    if (excludeOutputLayers && current.isOutputVertex() && current.hasLayer()
                                    && current.getLayer() instanceof IOutputLayer) {
                        //When doing backprop (i.e., excludeOutputLayers == true), we don't need to do a full forward pass through the output layers;
                        // we only need to ensure the input to the output layers is set properly
                        continue;
                    }
                    // once again, pushing stuff out of this workspace
                    INDArray out;
                    if (publicApi) {
                        out = current.doForward(train).detach();
                    } else {
                        out = current.doForward(train).leverageTo(workspaceExternal);
                    }

                    if (includeNonLayerVertexActivations || current.hasLayer()) {
                        layerActivations.put(current.getVertexName(), out);
                    }

                    //Now, set the inputs for the next vertices:
                    VertexIndices[] outputsTo = current.getOutputVertices();
                    if (outputsTo != null) {
                        for (VertexIndices v : outputsTo) {
                            int vIdx = v.getVertexIndex();
                            int inputNum = v.getVertexEdgeNumber();
                            //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                            if (Nd4j.getWorkspaceManager().checkIfWorkspaceExists(workspaceExternal)
                                            && Nd4j.getMemoryManager().getCurrentWorkspace() != Nd4j
                                                            .getWorkspaceManager().getWorkspaceForCurrentThread(
                                                                            ComputationGraph.workspaceExternal)) {
                                try (MemoryWorkspace wsB = Nd4j.getWorkspaceManager()
                                                .getWorkspaceForCurrentThread(workspaceExternal)
                                                .notifyScopeBorrowed()) {
                                    // FIXME: we don't really want detach here.
                                    vertices[vIdx].setInput(inputNum, out);
                                }
                            } else {
                                vertices[vIdx].setInput(inputNum, out);
                            }
                        }
                    }
                }
            }
        }

        if (!train)
            if (configuration.getTrainingWorkspaceMode() == WorkspaceMode.SEPARATE)
                Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceFeedForward).initializeWorkspace();

        return layerActivations;
    }

    /**
     * Return an array of network outputs (predictions) at test time, given the specified network inputs
     * Network outputs are for output layers only.
     *
     * @param input Inputs to the network
     * @return Output activations (order: same as defined in network configuration)
     */
    public INDArray[] output(INDArray... input) {
        return output(false, input);
    }

    /**
     * A convenience method that returns a single INDArray, instead of an INDArray[].
     * Useful for ComputationGraphs that have only a single output.
     * Otherwise identical to {@link #output(INDArray...)}
     *
     * @param input Inputs to the network
     * @return Output activations array
     */
    public INDArray outputSingle(INDArray... input) {
        return outputSingle(false, input);
    }

    /**
     * Return an array of network outputs (predictions), given the specified network inputs
     * Network outputs are for output layers only.
     *
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @param input Inputs to the network
     * @return Output activations (order: same as defined in network configuration)
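     * <p>
     * A minimal sketch (hypothetical graph with two inputs; shapes are illustrative):
     * <pre>{@code
     * INDArray in1 = Nd4j.rand(16, 10);
     * INDArray in2 = Nd4j.rand(16, 20);
     * INDArray[] predictions = graph.output(false, in1, in2);    // order: as per setOutputs(...)
     * }</pre>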
     */
    public INDArray[] output(boolean train, INDArray... input) {
        WorkspaceMode cMode = configuration.getTrainingWorkspaceMode();
        configuration.setTrainingWorkspaceMode(configuration.getInferenceWorkspaceMode());
        MemoryWorkspace workspace =
                        configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal, workspaceExternal);

        try (MemoryWorkspace wsE = workspace.notifyScopeEntered()) {
            INDArray[] tmp = silentOutput(train, input);
            for (int x = 0; x < tmp.length; x++)
                tmp[x] = tmp[x].detach();

            configuration.setTrainingWorkspaceMode(cMode);
            return tmp;
        }
    }

    protected INDArray[] silentOutput(boolean train, INDArray... input) {
        setInputs(input);
        Map<String, INDArray> activations = feedForward(false, false, false, false);
        INDArray[] outputs = new INDArray[numOutputArrays];
        int i = 0;
        for (String s : configuration.getNetworkOutputs()) {
            outputs[i++] = activations.get(s);
        }
        return outputs;
    }

    /**
     * A convenience method that returns a single INDArray, instead of an INDArray[].
     * Useful for ComputationGraphs that have only a single output.
     * Otherwise identical to {@link #output(boolean, INDArray...)}
     *
     * @param train If true: do forward pass at training time; false: do forward pass at test time
     * @param input Inputs to the network
     * @return Output activations array
     */
    public INDArray outputSingle(boolean train, INDArray... input) {
        if (numOutputArrays != 1) {
            throw new IllegalStateException(
                            "Cannot use outputSingle with ComputationGraph that does not have exactly 1 output. nOutputs: "
                                            + numOutputArrays);
        }
        return output(train, input)[0];
    }

    /**
     * Calculate the gradient of the network with respect to some external errors.
     * Note that this is typically used for things like reinforcement learning, not typical networks that include
     * an OutputLayer or RnnOutputLayer
     *
     * @param epsilons Epsilons (errors) at the output. Same order with which the output layers are defined in configuration setOutputs(String...)
     * @return Gradient for the network
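     * <p>
     * Illustrative sketch of the external-errors pattern; computeExternalError() is a hypothetical
     * user-supplied method, and the graph is assumed to have a single output:
     * <pre>{@code
     * graph.setInputs(observations);                       // set the inputs first
     * graph.feedForward(true, false);                      // forward pass, storing activations
     * INDArray externalError = computeExternalError();     // e.g., from a reinforcement learning loss
     * Gradient g = graph.backpropGradient(externalError);  // one epsilon per network output
     * }</pre>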
     */
    public Gradient backpropGradient(INDArray... epsilons) {
        if (epsilons == null || epsilons.length != numOutputArrays)
            throw new IllegalArgumentException(
                            "Invalid input: must have epsilons length equal to number of output arrays");


        calcBackpropGradients(configuration.getBackpropType() == BackpropType.TruncatedBPTT, epsilons);
        return gradient;
    }

    /**
     * Do backprop (gradient calculation)
     *
     * @param truncatedBPTT    false: normal backprop. true: calculate gradients using truncated BPTT for RNN layers
     * @param externalEpsilons null usually (for typical supervised learning). If not null (and length > 0) then assume that
     *                         the user has provided some errors externally, as they would do for example in reinforcement
     *                         learning situations.
     */
    protected void calcBackpropGradients(boolean truncatedBPTT, INDArray... externalEpsilons) {
        if (flattenedGradients == null) {
            initGradientsView();
        }


        MemoryWorkspace workspace =
                        configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : configuration.getTrainingWorkspaceMode() == WorkspaceMode.SINGLE
                                                        ? Nd4j.getWorkspaceManager()
                                                                        .getWorkspaceForCurrentThread(workspaceExternal)
                                                        //: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(wsConf, workspaceBackProp);
                                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                                        workspaceConfigurationFeedForward,
                                                                        workspaceFeedForward);


        LinkedList<Triple<String, INDArray, Character>> gradients = new LinkedList<>();

        //Do backprop according to the reverse of the topological ordering of the network
        boolean[] setVertexEpsilon = new boolean[topologicalOrder.length]; //If true: already set epsilon for this vertex; later epsilons should be *added* to the existing one, not set
        for (int i = topologicalOrder.length - 1; i >= 0; i--) {
            try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                GraphVertex current = vertices[topologicalOrder[i]];

                if (current.isInputVertex())
                    continue; //No op
                //FIXME: make the frozen vertex feature extraction more flexible
                if (current.hasLayer() && current.getLayer() instanceof FrozenLayer)
                    break;

                if (current.isOutputVertex()) {
                    //Two reasons for a vertex to be an output vertex:
                    //(a) it's an output layer (i.e., instanceof IOutputLayer), or
                    //(b) it's a normal layer, but it has been marked as an output layer for use in external errors - for reinforcement learning, for example

                    int thisOutputNumber = configuration.getNetworkOutputs().indexOf(current.getVertexName());
                    if (current.getLayer() instanceof IOutputLayer) {
                        IOutputLayer outputLayer = (IOutputLayer) current.getLayer();

                        INDArray currLabels = labels[thisOutputNumber];
                        outputLayer.setLabels(currLabels);
                    } else {
                        if ((externalEpsilons == null || externalEpsilons.length == 0)
                                        && labels[thisOutputNumber] != null) {
                            throw new DL4JException("Layer \"" + current.getVertexName() + "\" of type "
                                            + current.getLayer().getClass().getSimpleName()
                                            + " is set as network output "
                                            + "(but isn't an IOutputLayer). Only IOutputLayer layers can be fit via backprop with"
                                            + " a labels array. ");
                        }
                        current.setEpsilon(externalEpsilons[thisOutputNumber]);
                        setVertexEpsilon[topologicalOrder[i]] = true;
                    }
                }

                Pair<Gradient, INDArray[]> pair = current.doBackward(truncatedBPTT);
                INDArray[] epsilons = pair.getSecond();

                for (int x = 0; x < epsilons.length; x++) {
                    if (epsilons[x] == null) {
                        continue;
                    }

                    epsilons[x] = epsilons[x].leverageTo(workspaceExternal);
                }

                //Inputs to the current GraphVertex:
                VertexIndices[] inputVertices = current.getInputVertices();

                //Set epsilons for the vertices that provide inputs to this vertex:
                if (inputVertices != null) {
                    int j = 0;
                    for (VertexIndices v : inputVertices) {
                        GraphVertex gv = vertices[v.getVertexIndex()];
                        if (setVertexEpsilon[gv.getVertexIndex()]) {
                            //This vertex: must output to multiple vertices... we want to add the epsilons here
                            INDArray currentEps = gv.getEpsilon().leverageTo(workspaceExternal);
                            if (configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
                                gv.setEpsilon(currentEps.add(epsilons[j++])); //TODO: in some circumstances, it may be safe to do in-place add (but not always)
                            } else {
                                try (MemoryWorkspace wsB = Nd4j.getWorkspaceManager()
                                                .getWorkspaceForCurrentThread(workspaceExternal)
                                                .notifyScopeBorrowed()) {
                                    //try (MemoryWorkspace wsB = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                                    gv.setEpsilon(currentEps.add(epsilons[j++]));
                                }
                            }
                        } else {
                            gv.setEpsilon(epsilons[j++]);
                        }
                        setVertexEpsilon[gv.getVertexIndex()] = true;

                    }
                }

                if (pair.getFirst() != null) {
                    Gradient g = pair.getFirst();
                    Map<String, INDArray> map = g.gradientForVariable();
                    LinkedList<Triple<String, INDArray, Character>> tempList = new LinkedList<>();
                    for (Map.Entry<String, INDArray> entry : map.entrySet()) {
                        String origName = entry.getKey();
                        String newName = current.getVertexName() + "_" + origName;
                        tempList.addFirst(new Triple<>(newName, entry.getValue(),
                                        g.flatteningOrderForVariable(origName)));
                    }
                    for (Triple<String, INDArray, Character> t : tempList)
                        gradients.addFirst(t);
                }
            }
        }

        //Now, add the gradients in the order we need them in for flattening (same as params order)
        Gradient gradient = new DefaultGradient(flattenedGradients);
        for (Triple<String, INDArray, Character> t : gradients) {
            gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
        }

        if (configuration.getTrainingWorkspaceMode() == WorkspaceMode.SEPARATE)
            Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceFeedForward).initializeWorkspace();

        this.gradient = gradient;
    }

    @Override
    public ComputationGraph clone() {
        ComputationGraph cg = new ComputationGraph(configuration.clone());
        cg.init(params().dup(), false);
        if (solver != null) {
            //If solver is null: updater hasn't been initialized, and calling getUpdater() here would needlessly force initialization
            ComputationGraphUpdater u = this.getUpdater();
            INDArray updaterState = u.getStateViewArray();
            if (updaterState != null) {
                cg.getUpdater().setStateViewArray(updaterState.dup());
            }
        }
        cg.listeners = this.listeners;
        for (int i = 0; i < topologicalOrder.length; i++) {
            if (!vertices[topologicalOrder[i]].hasLayer())
                continue;
            String layerName = vertices[topologicalOrder[i]].getVertexName();
            if (getLayer(layerName) instanceof FrozenLayer) {
                cg.getVertex(layerName).setLayerAsFrozen();
            }
        }
        return cg;
    }

    /**
     * Calculate the L2 regularization term for all layers in the entire network. This is the sum of the L2 terms
     * for each layer individually
     */
    public double calcL2() {
        double l2 = 0.0;
        for (Layer l : layers) {
            l2 += l.calcL2(true);
        }
        return l2;
    }

    /**
     * Calculate the L1 regularization term for all layers in the entire network. This is the sum of the L1 terms
     * for each layer individually
     */
    public double calcL1() {
        double l1 = 0.0;
        for (Layer l : layers) {
            l1 += l.calcL1(true);
        }
        return l1;
    }

    /**
     * Set the IterationListeners for the ComputationGraph (and all layers in the network)
     */
    public void setListeners(Collection<IterationListener> listeners) {
        this.listeners = listeners;
        if (layers == null)
            init();

        for (Layer l : layers) {
            l.setListeners(listeners);
        }

        if (solver != null) {
            solver.setListeners(listeners);
        }

        this.trainingListeners.clear();
        if (listeners != null) {
            for (IterationListener il : listeners) {
                if (il instanceof TrainingListener) {
                    this.trainingListeners.add((TrainingListener) il);
                }
            }
        }
    }

    /**
     * Set the IterationListeners for the ComputationGraph (and all layers in the network)
     */
    public void setListeners(IterationListener... listeners) {
        List<IterationListener> list = new ArrayList<>();
        //Check: user might have done setListeners(null) thinking this would clear the current listeners.
        //This results in an IterationListener[1] with a single null value -> results in a NPE later
        if (listeners != null && listeners.length > 0) {
            for (IterationListener i : listeners) {
                if (i != null)
                    list.add(i);
            }
        }
        setListeners(list);
    }

    /**
     * This method ADDS the specified IterationListeners to the existing listeners
     *
     * @param listeners the listeners to add
     */
    @Override
    public void addListeners(IterationListener... listeners) {
        if (this.listeners == null) {
            setListeners(listeners);
            return;
        }

        for (IterationListener listener : listeners) {
            this.listeners.add(listener);
            if (listener instanceof TrainingListener) {
                this.trainingListeners.add((TrainingListener) listener);
            }
        }

        if (solver != null) {
            solver.setListeners(this.listeners);
        }
    }

    /**
     * Get the IterationListeners for the ComputationGraph
     */
    public Collection<IterationListener> getListeners() {
        return listeners;
    }

    /**
     * Get the ComputationGraphUpdater for the network
     */
    public ComputationGraphUpdater getUpdater() {
        if (solver == null) {
            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
        }
        return solver.getOptimizer().getComputationGraphUpdater();
    }

    /**
     * Set the computationGraphUpdater for the network
     */
    public void setUpdater(ComputationGraphUpdater updater) {
        if (solver == null) {
            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        }
        solver.getOptimizer().setUpdaterComputationGraph(updater);
    }

    /**
     * Get the specified output layer, by index. The index of the output
     * layer may be 0 to {@link #getNumOutputArrays()}-1
     */
    public Layer getOutputLayer(int outputLayerIdx) {
        if (outputLayerIdx >= numOutputArrays)
            throw new IllegalArgumentException("Invalid index: cannot get output layer " + outputLayerIdx
                            + ", total number of network outputs = " + numOutputArrays);
        return getLayer(configuration.getNetworkOutputs().get(outputLayerIdx));
    }

    /**
     * Get the parameters for the ComputationGraph
     *
     * @param backwardOnly If true: backprop parameters only (i.e., no visible layer biases used in layerwise pretraining layers)
     */
    public INDArray params(boolean backwardOnly) {
        if (backwardOnly)
            return flattenedParams;

        List<INDArray> list = new ArrayList<>(layers.length);
        for (int i = 0; i < topologicalOrder.length; i++) {
            if (!vertices[topologicalOrder[i]].hasLayer())
                continue;

            Layer l = vertices[topologicalOrder[i]].getLayer();
            INDArray layerParams = l.params();
            if (layerParams != null)
                list.add(layerParams); //may be null: subsampling etc layers
        }

        return Nd4j.toFlattened('f', list);
    }

    /**
     * Sets the input and labels and returns a score for the prediction with respect to the true labels
     * This is equivalent to {@link #score(DataSet, boolean)} with training==false.
     * NOTE: this version of the score function can only be used with ComputationGraph networks that have
     * a single input and a single output.
     *
     * @param dataSet the data to score
     * @return the score for the given input,label pairs
     * @see #score(DataSet, boolean)
     */
    public double score(DataSet dataSet) {
        return score(dataSet, false);
    }

    /**
     * Sets the input and labels and returns a score for the prediction with respect to the true labels
     * NOTE: this version of the score function can only be used with ComputationGraph networks that have
     * a single input and a single output. Use {@link #score(MultiDataSet, boolean)} for multiple input/output networks
     *
     * @param dataSet  the data to score
     * @param training whether score is being calculated at training time (true) or test time (false)
     * @return the score for the given input,label pairs
     * @see #score(DataSet, boolean)
     */
    public double score(DataSet dataSet, boolean training) {
        if (numInputArrays != 1 || numOutputArrays != 1)
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with "
                            + " DataSet: network does not have 1 input and 1 output arrays");
        return score(ComputationGraphUtil.toMultiDataSet(dataSet), training);
    }

    /**
     * Score the network given the MultiDataSet, at test time
     */
    public double score(MultiDataSet dataSet) {
        return score(dataSet, false);
    }

    /**
     * Sets the input and labels and returns a score for the prediction with respect to the true labels
     *
     * @param dataSet  the data to score
     * @param training whether score is being calculated at training time (true) or test time (false)
     * @return the score for the given input,label pairs
     */
    public double score(MultiDataSet dataSet, boolean training) {
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
        }

        double score = 0.0;
        MemoryWorkspace workspace = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                        ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                        workspaceConfigurationExternal, workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
            feedForward(dataSet.getFeatures(), training);
            INDArray[] labels = dataSet.getLabels();
            setLabels(labels);

            //Score: sum of the scores for the various output layers...
            double l1 = calcL1();
            double l2 = calcL2();

            int i = 0;
            for (String s : configuration.getNetworkOutputs()) {
                Layer outLayer = verticesMap.get(s).getLayer();
                if (outLayer == null || !(outLayer instanceof IOutputLayer)) {
                    log.warn("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
                    return 0.0;
                }

                IOutputLayer ol = (IOutputLayer) outLayer;
                ol.setLabels(labels[i++]);

                score += ol.computeScore(l1, l2, training);

                //Only want to add l1/l2 once...
                l1 = 0.0;
                l2 = 0.0;
            }
        }

        if (hasMaskArrays)
            clearLayerMaskArrays();

        return score;
    }

    /**
     * Calculate the score for each example in a DataSet individually. Unlike {@link #score(DataSet)} and {@link #score(DataSet, boolean)}
     * this method does not average/sum over examples. This method allows for examples to be scored individually (at test time only), which
     * may be useful for example for autoencoder architectures and the like.
     * Each row of the output (assuming addRegularizationTerms == true) is equivalent to calling score(DataSet) with a single example.
     *
     * @param data                   The data to score
     * @param addRegularizationTerms If true: add l1/l2 regularization terms (if any) to the score. If false: don't add regularization terms
     * @return An INDArray (column vector) of size input.numRows(); the ith entry is the score (loss value) of the ith example
     */
    public INDArray scoreExamples(DataSet data, boolean addRegularizationTerms) {
        if (numInputArrays != 1 || numOutputArrays != 1)
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with "
                            + " DataSet: network does not have 1 input and 1 output arrays");
        return scoreExamples(ComputationGraphUtil.toMultiDataSet(data), addRegularizationTerms);
    }

    /**
     * Calculate the score for each example in a DataSet individually. Unlike {@link #score(MultiDataSet)} and {@link #score(MultiDataSet, boolean)}
     * this method does not average/sum over examples. This method allows for examples to be scored individually (at test time only), which
     * may be useful for example for autoencoder architectures and the like.
     * Each row of the output (assuming addRegularizationTerms == true) is equivalent to calling score(MultiDataSet) with a single example.
     *
     * @param data                   The data to score
     * @param addRegularizationTerms If true: add l1/l2 regularization terms (if any) to the score. If false: don't add regularization terms
     * @return An INDArray (column vector) of size input.numRows(); the ith entry is the score (loss value) of the ith example
     */
    public INDArray scoreExamples(MultiDataSet data, boolean addRegularizationTerms) {
        boolean hasMaskArray = data.hasMaskArrays();
        if (hasMaskArray)
            setLayerMaskArrays(data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
        feedForward(data.getFeatures(), false);
        setLabels(data.getLabels());

        INDArray out = null;
        double l1 = (addRegularizationTerms ? calcL1() : 0.0);
        double l2 = (addRegularizationTerms ? calcL2() : 0.0);
        int i = 0;
        for (String s : configuration.getNetworkOutputs()) {
            Layer outLayer = verticesMap.get(s).getLayer();
            if (outLayer == null || !(outLayer instanceof IOutputLayer)) {
                throw new UnsupportedOperationException(
                                "Cannot calculate score: vertex \"" + s + "\" is not an output layer");
            }

            IOutputLayer ol = (IOutputLayer) outLayer;
            ol.setLabels(labels[i++]);

            INDArray scoreCurrLayer = ol.computeScoreForExamples(l1, l2);
            if (out == null)
                out = scoreCurrLayer;
            else
                out.addi(scoreCurrLayer);

            //Only want to add l1/l2 once...
            l1 = 0.0;
            l2 = 0.0;
        }

        if (hasMaskArray)
            clearLayerMaskArrays();
        return out;
    }
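    // Illustrative scoring sketch (the graph and iterator below are hypothetical):
    //
    //   DataSet ds = testIterator.next();
    //   double avgScore = graph.score(ds, false);             // averaged over the minibatch
    //   INDArray perExample = graph.scoreExamples(ds, true);  // one score per example, incl. l1/l2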
    //------------------------------------------------------
    //Model methods:

    @Override
    public void fit() {
        fit(inputs, labels, inputMaskArrays, labelMaskArrays);
    }

    @Override
    public void update(INDArray gradient, String paramType) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void update(Gradient gradient) {
        if (gradient.gradient().length() != numParams(true))
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams(true));
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            INDArray val = entry.getValue();
            int idx = key.indexOf('_');
            if (idx == -1)
                throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + key + "\"");
            String layerName = key.substring(0, idx);
            String paramType = key.split("_")[1];
            // Update graph gradient
            this.gradient.gradientForVariable().put(key, val);
            // Update layer params
            getLayer(layerName).update(val, paramType);
        }
        // Update layerwise gradient view
        setBackpropGradientsViewArray(gradient.gradient());
    }

    private void update(Task task) {
        if (!initDone) {
            initDone = true;
            Heartbeat heartbeat = Heartbeat.getInstance();
            task = ModelSerializer.taskByModel(this);
            Environment env = EnvironmentUtils.buildEnvironment();
            heartbeat.reportEvent(Event.STANDALONE, env, task);
        }
    }

    @Override
    public double score() {
        return score;
    }

    public void setScore(double score) {
        this.score = score;
    }

    @Override
    public void accumulateScore(double accum) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public INDArray params() {
        return params(true);
    }

    @Override
    public INDArray updaterState() {
        return getUpdater() != null ? getUpdater().getUpdaterStateViewArray() : null;
    }

    @Override
    public int numParams() {
        return numParams(true);
    }

    @Override
    public int numParams(boolean backwards) {
        int nParams = 0;
        for (Layer layer : layers) {
            nParams += layer.numParams(backwards);
        }
        return nParams;
    }

    @Override
    public void setParams(INDArray params) {
        if (params == flattenedParams)
            return; //No op

        if (this.flattenedParams != null && this.flattenedParams.length() == params.length()) {
            this.flattenedParams.assign(params);
            return;
        }

        int idx = 0;
        for (int i = 0; i < topologicalOrder.length; i++) {
            if (!vertices[topologicalOrder[i]].hasLayer())
                continue;

            Layer layer = vertices[topologicalOrder[i]].getLayer();
            int range = layer.numParams();
            if (range <= 0)
                continue; //Some layers: no parameters (subsampling etc)
            INDArray get = params.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, range + idx));
            layer.setParams(get);
            idx += range;
        }
    }

    @Override
    public void setParamsViewArray(INDArray gradient) {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public INDArray getGradientsViewArray() {
        return flattenedGradients;
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray gradient) {
        int paramsSoFar = 0;
        for (int i = 0; i < topologicalOrder.length; i++) {
            if (!vertices[topologicalOrder[i]].hasLayer())
                continue;

            Layer layer = vertices[topologicalOrder[i]].getLayer();
            int range = layer.numParams();
            if (range <= 0)
                continue; //Some layers: no parameters (subsampling etc)
            layer.setBackpropGradientsViewArray(gradient.get(NDArrayIndex.point(0),
                            NDArrayIndex.interval(paramsSoFar, paramsSoFar + range)));
            paramsSoFar += range;
        }
    }

    @Override
    public void applyLearningRateScoreDecay() {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void fit(INDArray data) {
        throw new UnsupportedOperationException("Cannot pretrain ComputationGraph with single INDArray");
    }

    @Override
    public void iterate(INDArray input) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public Gradient gradient() {
        return gradient;
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(gradient(), score());
    }

    @Override
    public int batchSize() {
        return inputs[0].size(0);
    }

    @Override
    public NeuralNetConfiguration conf() {
        return defaultConfiguration;
    }

    @Override
    public void setConf(NeuralNetConfiguration conf) {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray input() {
        if (numInputArrays == 1)
            return (inputs != null ? inputs[0] : null);
        else
            throw new UnsupportedOperationException(
                            "Cannot return single input: ComputationGraph has multiple inputs");
    }

    @Override
    public void validateInput() {

    }

    @Override
    public ConvexOptimizer getOptimizer() {
        return solver.getOptimizer();
    }

    @Override
    public INDArray getParam(String paramName) {
        //        throw new UnsupportedOperationException("Not implemented");
        int idx = paramName.indexOf('_');
        if (idx == -1)
            throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + paramName + "\"");
        String layerName = paramName.substring(0, idx);
        String paramType = paramName.substring(idx + 1);
        return getLayer(layerName).getParam(paramType);
    }

    @Override
    public void initParams() {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        //Get all parameters from all layers
        Map<String, INDArray> allParams = new LinkedHashMap<>();
        for (Layer layer : layers) {
            Map<String, INDArray> paramMap = layer.paramTable(backpropParamsOnly);
            for (Map.Entry<String, INDArray> entry : paramMap.entrySet()) {
                String newKey = layer.conf().getLayer().getLayerName() + "_" + entry.getKey();
                allParams.put(newKey, entry.getValue());
            }
        }
        return allParams;
    }

    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void setParam(String key, INDArray val) {
        //        throw new UnsupportedOperationException("Not implemented");
        int idx = key.indexOf('_');
        if (idx == -1)
            throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + key + "\"");
        String layerName = key.substring(0, idx);
        String paramType = key.substring(idx + 1);
        getLayer(layerName).setParam(paramType, val);
    }

    @Override
    public void clear() {
        inputs = null;
        labels = null;
        inputMaskArrays = null;
        labelMaskArrays = null;
    }

    //------------------------------------------------------------------------------
    //RNN-specific functionality

    /**
     * If this ComputationGraph contains one or more RNN layers: conduct forward pass (prediction)
     * but using previous stored state for any RNN layers. The activations for the final step are
     * also stored in the RNN layers for use next time rnnTimeStep() is called.
     * This method can be used to generate output one or more steps at a time instead of always having to do
     * forward pass from t=0. Example uses are for streaming data, and for generating samples from network output
     * one step at a time (where samples are then fed back into the network as input)
     * If no previous state is present in RNN layers (i.e., initially or after calling rnnClearPreviousState()),
     * the default initialization (usually 0) is used.
     * Supports mini-batches (i.e., multiple predictions/forward passes in parallel) as well as single examples.
     *
     * @param inputs Input to network. May be for one or multiple time steps. For single time step:
     *               input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]. miniBatchSize=1 for single example.
     *               For multiple time steps: [miniBatchSize,inputSize,inputTimeSeriesLength]
     * @return Output activations. If output is an RNN layer (such as RnnOutputLayer): if all inputs have shape
     *         [miniBatchSize,inputSize] (i.e., are 2d), then outputs have shape [miniBatchSize,outputSize] (i.e., also 2d)
     *         instead of [miniBatchSize,outputSize,1]. Otherwise output is 3d [miniBatchSize,outputSize,inputTimeSeriesLength]
     *         when using RnnOutputLayer (or unmodified otherwise).
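     * <p>
     * Illustrative usage sketch for step-by-step generation (shapes, loop bounds and the sampling
     * function are hypothetical):
     * <pre>{@code
     * net.rnnClearPreviousState();                        // start from the default (zero) state
     * INDArray input = firstTimeStep;                     // shape [miniBatchSize, nIn]
     * for (int t = 0; t < nSteps; t++) {
     *     INDArray[] out = net.rnnTimeStep(input);        // stored RNN state is updated each call
     *     input = sampleFromOutput(out[0]);               // hypothetical user-supplied sampling
     * }
     * }</pre>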
     */
    public INDArray[] rnnTimeStep(INDArray... inputs) {
        this.inputs = inputs;
        //Idea: if 2d in, want 2d out
        boolean inputIs2d = true;
        for (INDArray i : inputs) {
            if (i.rank() != 2) {
                inputIs2d = false;
                break;
            }
        }

        INDArray[] outputs = new INDArray[this.numOutputArrays];

        //Based on: feedForward()
        for (int currVertexIdx : topologicalOrder) {
            GraphVertex current = vertices[currVertexIdx];
            if (current.isInputVertex()) {
                VertexIndices[] inputsTo = current.getOutputVertices();
                INDArray input = inputs[current.getVertexIndex()];

                for (VertexIndices v : inputsTo) {
                    int vIdx = v.getVertexIndex();
                    int vIdxInputNum = v.getVertexEdgeNumber();
                    //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                    vertices[vIdx].setInput(vIdxInputNum, input.dup()); //TODO When to dup?
                }
            } else {
                INDArray out;
                if (current.hasLayer()) {
                    //Layer
                    Layer l = current.getLayer();
                    if (l instanceof RecurrentLayer) {
                        out = ((RecurrentLayer) l).rnnTimeStep(current.getInputs()[0]);
                    } else if (l instanceof MultiLayerNetwork) {
                        out = ((MultiLayerNetwork) l).rnnTimeStep(current.getInputs()[0]);
                    } else {
                        //non-recurrent layer
                        out = current.doForward(false);
                    }
                } else {
                    //GraphNode
                    out = current.doForward(false);
                }

                if (current.isOutputVertex()) {
                    //Get the index of this output vertex...
                    int idx = configuration.getNetworkOutputs().indexOf(current.getVertexName());
                    outputs[idx] = out;
                }

                //Now, set the inputs for the next vertices:
                VertexIndices[] outputsTo = current.getOutputVertices();
                if (outputsTo != null) {
                    for (VertexIndices v : outputsTo) {
                        int vIdx = v.getVertexIndex();
                        int inputNum = v.getVertexEdgeNumber();
                        //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                        vertices[vIdx].setInput(inputNum, out);
                    }
                }
            }
        }

        //As per MultiLayerNetwork.rnnTimeStep(): if inputs are all 2d, then outputs are all 2d
        if (inputIs2d) {
            for (int i = 0; i < outputs.length; i++) {
                if (outputs[i].rank() == 3 && outputs[i].size(2) == 1) {
                    //Return 2d output with shape [miniBatchSize,nOut]
                    // instead of 3d output with shape [miniBatchSize,nOut,1]
                    outputs[i] = outputs[i].tensorAlongDimension(0, 1, 0);
                }
            }
        }

        this.inputs = null;
        return outputs;
    }

    /**
     * Get the state of the RNN layer, as used in {@link #rnnTimeStep(INDArray...)}.
     *
     * @param layer Number/index of the layer.
     * @return Hidden state, or null if layer is not an RNN layer
     */
    public Map<String, INDArray> rnnGetPreviousState(int layer) {
        return rnnGetPreviousState(layers[layer].conf().getLayer().getLayerName());
    }

    /**
     * Get the state of the RNN layer, as used in {@link #rnnTimeStep(INDArray...)}.
     *
     * @param layerName name of the layer
     * @return Hidden state, or null if layer is not an RNN layer
     */
    public Map<String, INDArray> rnnGetPreviousState(String layerName) {
        Layer l = verticesMap.get(layerName).getLayer();
        if (l == null || !(l instanceof RecurrentLayer))
            return null;
        return ((RecurrentLayer) l).rnnGetPreviousState();
    }

    /**
     * Get a map of states for ALL RNN layers, as used in {@link #rnnTimeStep(INDArray...)}.
     * Layers that are not RNN layers will not have an entry in the returned map
     *
     * @return Map of states (keyed by layer name) or null if layer is not an RNN layer
     * @see #rnnSetPreviousStates(Map)
     */
    public Map<String, Map<String, INDArray>> rnnGetPreviousStates() {
        Map<String, Map<String, INDArray>> states = new HashMap<>();
        for (Layer l : layers) {
            if (l instanceof RecurrentLayer) {
                states.put(l.conf().getLayer().getLayerName(), ((RecurrentLayer) l).rnnGetPreviousState());
            }
        }
        return states;
    }

    /**
     * Set the state of the RNN layer, for use in {@link #rnnTimeStep(INDArray...)}
     *
     * @param layer The number/index of the layer.
     * @param state The state to set the specified layer to
     */
    public void rnnSetPreviousState(int layer, Map<String, INDArray> state) {
        rnnSetPreviousState(layers[layer].conf().getLayer().getLayerName(), state);
    }

    /**
     * Set the state of the RNN layer, for use in {@link #rnnTimeStep(INDArray...)}
     *
     * @param layerName The name of the layer.
     * @param state     The state to set the specified layer to
     */
    public void rnnSetPreviousState(String layerName, Map<String, INDArray> state) {
        Layer l = verticesMap.get(layerName).getLayer();
        if (l == null || !(l instanceof RecurrentLayer)) {
            throw new UnsupportedOperationException(
                            "Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state");
        }
        ((RecurrentLayer) l).rnnSetPreviousState(state);
    }

    /**
     * Set the states for all RNN layers, for use in {@link #rnnTimeStep(INDArray...)}
     *
     * @param previousStates The previous time step states for all layers (key: layer name. Value: layer states)
     * @see #rnnGetPreviousStates()
     */
    public void rnnSetPreviousStates(Map<String, Map<String, INDArray>> previousStates) {
        for (Map.Entry<String, Map<String, INDArray>> entry : previousStates.entrySet()) {
            rnnSetPreviousState(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Clear the previous state of the RNN layers (if any), used in {@link #rnnTimeStep(INDArray...)}
     */
    public void rnnClearPreviousState() {
        if (layers == null)
            return;
        for (Layer layer : layers) {
            if (layer instanceof RecurrentLayer)
                ((RecurrentLayer) layer).rnnClearPreviousState();
            else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).rnnClearPreviousState();
            }
        }
    }

    /**
     * Fit the network using truncated BPTT
     */
    protected void doTruncatedBPTT(INDArray[] inputs, INDArray[] labels, INDArray[] featureMasks,
                    INDArray[] labelMasks) {
        if (flattenedGradients == null) {
            initGradientsView();
        }

        //Approach used here to implement truncated BPTT: if input is 3d, split it. Otherwise: input is unmodified
        int timeSeriesLength = -1;
        for (INDArray in : inputs) {
            if (in.rank() != 3)
                continue;
            if (timeSeriesLength == -1)
                timeSeriesLength = in.size(2);
            else if (timeSeriesLength != in.size(2)) {
                log.warn("Cannot do TBPTT with time series of different lengths");
                return;
            }
        }
        for (INDArray out : labels) {
            if (out.rank() != 3)
                continue;
            if (timeSeriesLength == -1)
                timeSeriesLength = out.size(2);
            else if (timeSeriesLength != out.size(2)) {
                log.warn("Cannot do TBPTT with time series of different lengths");
                return;
            }
        }

        int fwdLen = configuration.getTbpttFwdLength();
        int nSubsets = timeSeriesLength / fwdLen;
        if (timeSeriesLength % fwdLen != 0)
            nSubsets++;

        rnnClearPreviousState();

        INDArray[] newInputs = new INDArray[inputs.length];
        INDArray[] newLabels = new INDArray[labels.length];
        INDArray[] newFeatureMasks = (featureMasks != null ? new INDArray[featureMasks.length] : null);
        INDArray[] newLabelMasks = (labelMasks != null ? new INDArray[labelMasks.length] : null);

        workspaceConfigurationExternal.setCyclesBeforeInitialization(0);
        workspaceConfigurationExternal.setPolicyLearning(LearningPolicy.OVER_TIME);

        MemoryWorkspace workspaceT = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                        ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                        workspaceConfigurationTBPTT, workspaceTBPTT);
        MemoryWorkspace workspace = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                        ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                        workspaceConfigurationExternal, workspaceExternal);

        try (MemoryWorkspace wsT = workspaceT.notifyScopeEntered()) {
            for (int i = 0; i < nSubsets; i++) {
                try (MemoryWorkspace wsE = workspace.notifyScopeEntered()) {
                    int startTimeIdx = i * fwdLen;
                    int endTimeIdx = startTimeIdx + fwdLen;
                    if (endTimeIdx > timeSeriesLength)
                        endTimeIdx = timeSeriesLength;

                    for (int j = 0; j < inputs.length; j++) {
                        if (inputs[j].rank() != 3)
                            newInputs[j] = inputs[j];
                        else {
                            newInputs[j] = inputs[j].get(NDArrayIndex.all(), NDArrayIndex.all(),
                                            NDArrayIndex.interval(startTimeIdx, endTimeIdx));
                        }
                    }
                    for (int j = 0; j < labels.length; j++) {
                        if (labels[j].rank() != 3)
                            newLabels[j] = labels[j];
                        else {
                            newLabels[j] = labels[j].get(NDArrayIndex.all(), NDArrayIndex.all(),
                                            NDArrayIndex.interval(startTimeIdx, endTimeIdx));
                        }
                    }
                    if (featureMasks != null) {
                        for (int j = 0; j < featureMasks.length; j++) {
                            if (featureMasks[j] == null)
                                continue;
                            newFeatureMasks[j] = featureMasks[j].get(NDArrayIndex.all(),
                                            NDArrayIndex.interval(startTimeIdx, endTimeIdx));
                        }
                    }
                    if (labelMasks != null) {
                        for (int j = 0; j < labelMasks.length; j++) {
                            if (labelMasks[j] == null)
                                continue;
                            newLabelMasks[j] = labelMasks[j].get(NDArrayIndex.all(),
                                            NDArrayIndex.interval(startTimeIdx, endTimeIdx));
                        }
                    }

                    setInputs(newInputs);
                    setLabels(newLabels);
                    setLayerMaskArrays(newFeatureMasks, newLabelMasks);

                    if (solver == null) {
                        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                            solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this)
                                            .build();
                        }
                    }
                    solver.optimize();

                    //Finally, update the state of the RNN layers:
                    rnnUpdateStateWithTBPTTState();
                }
            }
        }

        if (configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE) {
            workspace.initializeWorkspace();
            workspaceT.initializeWorkspace();
        }

        rnnClearPreviousState();

        if (featureMasks != null || labelMasks != null) {
            clearLayerMaskArrays();
        }
    }
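    // Truncated BPTT is enabled via the graph configuration; a minimal sketch (builder chain
    // abbreviated; other required configuration omitted):
    //
    //   new NeuralNetConfiguration.Builder()
    //           .graphBuilder()
    //           // ... vertices, inputs and outputs ...
    //           .backpropType(BackpropType.TruncatedBPTT)
    //           .tBPTTForwardLength(20).tBPTTBackwardLength(20)
    //           .build();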
    /**
     * Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:
     * (a) like rnnTimeStep does forward pass using stored state for RNN layers, and
     * (b) unlike rnnTimeStep does not modify the RNN layer state
     * Therefore multiple calls to this method with the same input should have the same output.
     * Typically used during training only. Use rnnTimeStep for prediction/forward pass at test time.
     *
     * @param inputs            Input to network
     * @param training          Whether training or not
     * @param storeLastForTBPTT set to true if used as part of truncated BPTT training
     * @return Activations for each layer (including input, as per feedforward() etc)
     */
    public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputs, boolean training,
                    boolean storeLastForTBPTT) {
        Map<String, INDArray> layerActivations = new HashMap<>();

        //Do forward pass according to the topological ordering of the network
        for (int currVertexIdx : topologicalOrder) {
            GraphVertex current = vertices[currVertexIdx];
            if (current.isInputVertex()) {
                VertexIndices[] inputsTo = current.getOutputVertices();
                INDArray input = inputs[current.getVertexIndex()];

                layerActivations.put(current.getVertexName(), input);

                for (VertexIndices v : inputsTo) {
                    int vIdx = v.getVertexIndex();
                    int vIdxInputNum = v.getVertexEdgeNumber();
                    //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                    vertices[vIdx].setInput(vIdxInputNum, input.dup()); //TODO When to dup?
                }
            } else {
                INDArray out;
                if (current.hasLayer()) {
                    Layer l = current.getLayer();
                    if (l instanceof RecurrentLayer) {
                        out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], training,
                                        storeLastForTBPTT);
                    } else if (l instanceof MultiLayerNetwork) {
                        List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(
                                        current.getInputs()[0], training, storeLastForTBPTT);
                        out = temp.get(temp.size() - 1);
                    } else {
                        //non-recurrent layer
                        out = current.doForward(training);
                    }
                    layerActivations.put(current.getVertexName(), out);
                } else {
                    out = current.doForward(training);
                }

                //Now, set the inputs for the next vertices:
                VertexIndices[] outputsTo = current.getOutputVertices();
                if (outputsTo != null) {
                    for (VertexIndices v : outputsTo) {
                        int vIdx = v.getVertexIndex();
                        int inputNum = v.getVertexEdgeNumber();
                        //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                        vertices[vIdx].setInput(inputNum, out);
                    }
                }
            }
        }

        return layerActivations;
    }

    /**
     * Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many
     * and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths
     * within the same minibatch.
     * For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
     * [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength]
     * and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding -
     * at a given time step).
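     * <p>
     * A small illustrative sketch (shapes hypothetical: a minibatch of 5 sequences padded to length 10,
     * where example 3 only has 7 real time steps):
     * <pre>{@code
     * INDArray featureMask = Nd4j.ones(5, 10);
     * featureMask.get(NDArrayIndex.point(3), NDArrayIndex.interval(7, 10)).assign(0);
     * graph.setLayerMaskArrays(new INDArray[] {featureMask}, null);
     * }</pre>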
     * NOTE: This method is not usually used directly. Instead, the various feedForward and fit methods handle setting
     * of masking internally.
     *
     * @param featureMaskArrays Mask array for features (input)
     * @param labelMaskArrays   Mask array for labels (output)
     * @see #clearLayerMaskArrays()
     */
    public void setLayerMaskArrays(INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
        this.clearLayerMaskArrays();
        this.inputMaskArrays = featureMaskArrays;
        this.labelMaskArrays = labelMaskArrays;

        if (featureMaskArrays != null) {
            if (featureMaskArrays.length != numInputArrays) {
                throw new IllegalArgumentException("Invalid number of feature mask arrays");
            }

            int minibatchSize = -1;
            for (INDArray i : featureMaskArrays) {
                if (i != null) {
                    minibatchSize = i.size(0);
                }
            }

            //Here: need to do forward pass through the network according to the topological ordering of the network
            Map<Integer, Pair<INDArray, MaskState>> map = new HashMap<>();
            for (int i = 0; i < topologicalOrder.length; i++) {
                GraphVertex current = vertices[topologicalOrder[i]];

                if (current.isInputVertex()) {
                    INDArray fMask = featureMaskArrays[current.getVertexIndex()];
                    map.put(current.getVertexIndex(), new Pair<>(fMask, MaskState.Active));
                } else {
                    VertexIndices[] inputVertices = current.getInputVertices();

                    //Now: work out the mask arrays to feed forward...
                    INDArray[] inputMasks = null; //new INDArray[inputVertices.length];
                    MaskState maskState = null;
                    for (int j = 0; j < inputVertices.length; j++) {
                        Pair<INDArray, MaskState> p = map.get(inputVertices[j].getVertexIndex());
                        if (p != null) {
                            if (inputMasks == null) {
                                inputMasks = new INDArray[inputVertices.length];
                            }
                            inputMasks[j] = p.getFirst();
                            if (maskState == null || maskState == MaskState.Passthrough) {
                                maskState = p.getSecond();
                            }
                        }
                    }

                    Pair<INDArray, MaskState> outPair =
                                    current.feedForwardMaskArrays(inputMasks, maskState, minibatchSize);
                    map.put(topologicalOrder[i], outPair);
                }
            }
        }

        if (labelMaskArrays != null) {
            if (labelMaskArrays.length != numOutputArrays) {
                throw new IllegalArgumentException("Invalid number of label mask arrays");
            }
            for (int i = 0; i < labelMaskArrays.length; i++) {
                if (labelMaskArrays[i] == null) {
                    // This output doesn't have a mask, we can skip it.
                    continue;
                }
                String outputName = configuration.getNetworkOutputs().get(i);
                GraphVertex v = verticesMap.get(outputName);
                Layer ol = v.getLayer();
                ol.setMaskArray(labelMaskArrays[i]);
            }
        }
    }
    /**
     * Remove the mask arrays from all layers.
     * See {@link #setLayerMaskArrays(INDArray[], INDArray[])} for details on mask arrays.
     */
    public void clearLayerMaskArrays() {
        for (Layer layer : layers) {
            layer.setMaskArray(null);
        }
        this.inputMaskArrays = null;
        this.labelMaskArrays = null;
    }

    /**
     * Update the internal state of RNN layers after a truncated BPTT fit call
     */
    protected void rnnUpdateStateWithTBPTTState() {
        for (int i = 0; i < layers.length; i++) {
            if (layers[i] instanceof RecurrentLayer) {
                RecurrentLayer l = ((RecurrentLayer) layers[i]);
                l.rnnSetPreviousState(l.rnnGetTBPTTState());
            } else if (layers[i] instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Evaluate the network (classification performance - single output ComputationGraphs only)
     *
     * @param iterator Iterator to evaluate on
     * @return Evaluation object; results of evaluation on all examples in the data set
     */
    public Evaluation evaluate(DataSetIterator iterator) {
        return evaluate(iterator, null);
    }

    /**
     * Evaluate the network (classification performance - single output ComputationGraphs only)
     *
     * @param iterator Iterator to evaluate on
     * @return Evaluation object; results of evaluation on all examples in the data set
     */
    public Evaluation evaluate(MultiDataSetIterator iterator) {
        return evaluate(iterator, null);
    }

    /**
     * Evaluate the network on the provided data set (single output ComputationGraphs only). Used for evaluating
     * the performance of classifiers
     *
     * @param iterator Data to undertake evaluation on
     * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
     */
    public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList) {
        return evaluate(iterator, labelsList, 1);
    }

    /**
     * Evaluate the network on the provided data set (single output ComputationGraphs only). Used for evaluating
     * the performance of classifiers
     *
     * @param iterator Data to undertake evaluation on
     * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
     */
    public Evaluation evaluate(MultiDataSetIterator iterator, List<String> labelsList) {
        return evaluate(iterator, labelsList, 1);
    }

    /**
     * Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
     * For 'standard' accuracy evaluation only, use topN = 1
     *
     * @param iterator   Iterator (data) to evaluate on
     * @param labelsList List of labels. May be null.
     * @param topN       N value for top N accuracy evaluation
     * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
     */
    public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
        if (labelsList == null)
            labelsList = iterator.getLabels();
        return doEvaluation(iterator, new Evaluation(labelsList, topN))[0];
    }
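    // Illustrative evaluation sketch (graph and testIter are hypothetical):
    //
    //   Evaluation eval = graph.evaluate(testIter);
    //   System.out.println(eval.stats());    // accuracy, precision, recall, F1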
    /**
     * Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
     * For 'standard' accuracy evaluation only, use topN = 1
     *
     * @param iterator   Iterator (data) to evaluate on
     * @param labelsList List of labels. May be null.
     * @param topN       N value for top N accuracy evaluation
     * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
     */
    public Evaluation evaluate(MultiDataSetIterator iterator, List<String> labelsList, int topN) {
        return doEvaluation(iterator, new Evaluation(labelsList, topN))[0];
    }

    /**
     * Evaluate the (single output layer only) network for regression performance
     *
     * @param iterator Data to evaluate on
     * @return Regression evaluation
     */
    public RegressionEvaluation evaluateRegression(DataSetIterator iterator) {
        return evaluateRegression(iterator, null);
    }

    /**
     * Evaluate the (single output layer only) network for regression performance
     *
     * @param iterator Data to evaluate on
     * @return Regression evaluation
     */
    public RegressionEvaluation evaluateRegression(MultiDataSetIterator iterator) {
        return evaluateRegression(iterator, null);
    }

    /**
     * Evaluate the (single output layer only) network for regression performance
     *
     * @param iterator    Data to evaluate on
     * @param columnNames Column names for the regression evaluation. May be null.
     * @return Regression evaluation
     */
    public RegressionEvaluation evaluateRegression(DataSetIterator iterator, List<String> columnNames) {
        return doEvaluation(iterator, new RegressionEvaluation(columnNames))[0];
    }

    /**
     * Evaluate the (single output layer only) network for regression performance
     *
     * @param iterator Data to evaluate on
     * @return Regression evaluation
     */
    public RegressionEvaluation evaluateRegression(MultiDataSetIterator iterator, List<String> columnNames) {
        return doEvaluation(iterator, new RegressionEvaluation(columnNames))[0];
    }

    /**
     * Evaluate the network (must be a binary classifier) on the specified data, using the {@link ROC} class
     *
     * @param iterator          Data to evaluate on
     * @param rocThresholdSteps Number of threshold steps to use with {@link ROC}
     * @return ROC evaluation on the given dataset
     */
    public ROC evaluateROC(DataSetIterator iterator, int rocThresholdSteps) {
        return doEvaluation(iterator, new ROC(rocThresholdSteps))[0];
    }

    /**
     * Evaluate the network (must be a binary classifier) on the specified data, using the {@link ROC} class
     *
     * @param iterator          Data to evaluate on
     * @param rocThresholdSteps Number of threshold steps to use with {@link ROC}
     * @return ROC evaluation on the given dataset
     */
    public ROC evaluateROC(MultiDataSetIterator iterator, int rocThresholdSteps) {
        return doEvaluation(iterator, new ROC(rocThresholdSteps))[0];
    }

    /**
     * Evaluate the network on the specified data, using the {@link ROCMultiClass} class
     *
     * @param iterator          Data to evaluate on
     * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass}
     * @return Multi-class ROC evaluation on the given dataset
     */
    public ROCMultiClass evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) {
        return doEvaluation(iterator, new ROCMultiClass(rocThresholdSteps))[0];
    }

    /**
     * Evaluate the network on the specified data, using the {@link ROCMultiClass} class
     *
     * @param iterator          Data to evaluate on
     * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass}
     * @return Multi-class ROC evaluation on the given dataset
     */
    public ROCMultiClass evaluateROCMultiClass(MultiDataSetIterator iterator, int rocThresholdSteps) {
        return doEvaluation(iterator, new ROCMultiClass(rocThresholdSteps))[0];
    }

    /**
     * Perform evaluation on the given data (DataSetIterator) with the given {@link IEvaluation} instance
     *
     * @param iterator    Test data to evaluate on
     * @param evaluations IEvaluation instance(s)
     * @param <T>         Type of the IEvaluation instance
    /**
     * Perform evaluation on the given data (DataSetIterator) with the given {@link IEvaluation} instance
     *
     * @param iterator    Test data to evaluate on
     * @param evaluations IEvaluation instance(s)
     * @param <T>         Type of the IEvaluation instance
     * @return The input IEvaluation instance(s), after performing evaluation on the test data
     */
    public <T extends IEvaluation> T[] doEvaluation(DataSetIterator iterator, T... evaluations) {
        if (layers == null || !(getOutputLayer(0) instanceof IOutputLayer)) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }

        if (getNumOutputArrays() != 1) {
            throw new IllegalStateException("Cannot evaluate a model with > 1 output arrays from a DataSetIterator");
        }

        if (iterator.resetSupported() && !iterator.hasNext())
            iterator.reset();

        DataSetIterator iter = iterator.asyncSupported() ? new AsyncDataSetIterator(iterator, 2, true) : iterator;

        WorkspaceMode cMode = configuration.getTrainingWorkspaceMode();
        configuration.setTrainingWorkspaceMode(configuration.getInferenceWorkspaceMode());

        MemoryWorkspace workspace = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                        ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal,
                                        workspaceExternal);

        while (iter.hasNext()) {
            DataSet next = iter.next();

            if (next.getFeatures() == null || next.getLabels() == null)
                break;

            try (MemoryWorkspace wsB = workspace.notifyScopeEntered()) {
                //Assuming single output here
                INDArray features = next.getFeatures();
                INDArray featuresMask = next.getFeaturesMaskArray();
                INDArray labels = next.getLabels();
                INDArray labelMask = next.getLabelsMaskArray();

                setLayerMaskArrays(featuresMask == null ? null : new INDArray[] {featuresMask},
                                labelMask == null ? null : new INDArray[] {labelMask});

                INDArray[] out = silentOutput(false, features);

                for (T evaluation : evaluations)
                    evaluation.eval(labels, out[0], labelMask);
            }

            clearLayerMaskArrays();
        }

        if (iterator.asyncSupported())
            ((AsyncDataSetIterator) iter).shutdown();

        configuration.setTrainingWorkspaceMode(cMode);

        return evaluations;
    }
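    // doEvaluation accepts several IEvaluation instances at once, so multiple
    // metrics can be collected in a single pass over the data rather than once
    // per metric. A sketch, again assuming hypothetical `net` and `testIter`:
    //
    //   Evaluation eval = new Evaluation();
    //   ROC roc = new ROC(100);
    //   net.doEvaluation(testIter, eval, roc); // one pass; both instances are populated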
    /**
     * Perform evaluation on the given data (MultiDataSetIterator) with the given {@link IEvaluation} instance
     *
     * @param iterator    Test data to evaluate on
     * @param evaluations IEvaluation instance(s)
     * @param <T>         Type of the IEvaluation instance
     * @return The input IEvaluation instance(s), after performing evaluation on the test data
     */
    public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations) {
        if (layers == null || !(getOutputLayer(0) instanceof IOutputLayer)) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }

        if (getNumOutputArrays() != 1) {
            throw new IllegalStateException("Cannot evaluate a model using this method with > 1 output arrays");
        }

        if (iterator.resetSupported() && !iterator.hasNext())
            iterator.reset();

        MultiDataSetIterator iter =
                        iterator.asyncSupported() ? new AsyncMultiDataSetIterator(iterator, 2, true) : iterator;

        WorkspaceMode cMode = configuration.getTrainingWorkspaceMode();
        configuration.setTrainingWorkspaceMode(configuration.getInferenceWorkspaceMode());

        MemoryWorkspace workspace = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
                        ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal,
                                        workspaceExternal);

        while (iter.hasNext()) {
            MultiDataSet next = iter.next();

            if (next.getFeatures() == null || next.getLabels() == null)
                break;

            try (MemoryWorkspace wsB = workspace.notifyScopeEntered()) {
                //Assuming single output here
                INDArray[] features = next.getFeatures();
                INDArray[] featuresMasks = next.getFeaturesMaskArrays();

                INDArray labels = next.getLabels(0);
                INDArray[] labelMasks = next.getLabelsMaskArrays();
                INDArray labelMask = next.getLabelsMaskArray(0);

                setLayerMaskArrays(featuresMasks, labelMasks);

                INDArray[] out = silentOutput(false, features);

                try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                    for (T evaluation : evaluations)
                        evaluation.eval(labels, out[0], labelMask);
                }
            }

            clearLayerMaskArrays();
        }

        if (iterator.asyncSupported())
            ((AsyncMultiDataSetIterator) iter).shutdown();

        configuration.setTrainingWorkspaceMode(cMode);

        return evaluations;
    }

    /**
     * String detailing the architecture of the computation graph.
     * Vertices are printed in a topological sort order.
     * Columns are: vertex name (with layer/vertex type), nIn, nOut, total number of parameters,
     * the shapes of those parameters, and the inputs to the vertex.
     * Will also give information about frozen layers/vertices, if any.
     *
     * @return Summary as a string
     */
    public String summary() {
        String ret = "\n";
        ret += StringUtils.repeat("=", 140);
        ret += "\n";
        ret += String.format("%-40s%-15s%-15s%-30s %s\n", "VertexName (VertexType)", "nIn,nOut", "TotalParams",
                        "ParamsShape", "Vertex Inputs");
        ret += StringUtils.repeat("=", 140);
        ret += "\n";

        int frozenParams = 0;
        for (int currVertexIdx : topologicalOrder) {
            GraphVertex current = vertices[currVertexIdx];

            String name = current.getVertexName();
            String[] classNameArr = current.getClass().toString().split("\\.");
            String className = classNameArr[classNameArr.length - 1];

            String connections = "-";
            if (!current.isInputVertex()) {
                connections = configuration.getVertexInputs().get(name).toString();
            }

            String paramCount = "-";
            String in = "-";
            String out = "-";
            String paramShape = "-";
            if (current.hasLayer()) {
                Layer currentLayer = ((LayerVertex) current).getLayer();
                classNameArr = currentLayer.getClass().getName().split("\\.");
                className = classNameArr[classNameArr.length - 1];
                paramCount = String.valueOf(currentLayer.numParams());
                if (currentLayer.numParams() > 0) {
                    paramShape = "";
                    in = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn());
                    out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut());
                    Set<String> paraNames = currentLayer.conf().getLearningRateByParam().keySet();
                    for (String aP : paraNames) {
                        String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape());
                        paramShape += aP + ":" + paramS + ", ";
                    }
                    paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString();
                }
                if (currentLayer instanceof FrozenLayer) {
                    frozenParams += currentLayer.numParams();
                    classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
                    className = "Frozen " + classNameArr[classNameArr.length - 1];
                }
            }

            ret += String.format("%-40s%-15s%-15s%-30s %s", name + " (" + className + ")", in + "," + out,
                            paramCount, paramShape, connections);
            ret += "\n";
        }
        ret += StringUtils.repeat("-", 140);
        ret += String.format("\n%30s %d", "Total Parameters: ", params().length());
        ret += String.format("\n%30s %d", "Trainable Parameters: ", params().length() - frozenParams);
        ret += String.format("\n%30s %d", "Frozen Parameters: ", frozenParams);
        ret += "\n";
        ret += StringUtils.repeat("=", 140);
        ret += "\n";
        return ret;
    }
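    // summary() is typically printed once, after the graph has been initialized.
    // A sketch (assumes a ComputationGraphConfiguration `conf` built elsewhere):
    //
    //   ComputationGraph net = new ComputationGraph(conf);
    //   net.init();
    //   System.out.println(net.summary());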
    /**
     * This method just makes sure there's no state preserved within layers
     */
    protected void clearLayersStates() {
        for (int f = 0; f < layers.length; f++) {
            layers[f].setInput(null);
        }

        for (int f = 0; f < vertices.length; f++) {
            vertices[f].clearVertex();
        }
    }

    /**
     * Indicates whether some other object is "equal to" this one.
     * <p>
     * The {@code equals} method implements an equivalence relation
     * on non-null object references:
     * <ul>
     * <li>It is <i>reflexive</i>: for any non-null reference value
     * {@code x}, {@code x.equals(x)} should return
     * {@code true}.</li>
     * <li>It is <i>symmetric</i>: for any non-null reference values
     * {@code x} and {@code y}, {@code x.equals(y)}
     * should return {@code true} if and only if
     * {@code y.equals(x)} returns {@code true}.</li>
     * <li>It is <i>transitive</i>: for any non-null reference values
     * {@code x}, {@code y}, and {@code z}, if
     * {@code x.equals(y)} returns {@code true} and
     * {@code y.equals(z)} returns {@code true}, then
     * {@code x.equals(z)} should return {@code true}.</li>
     * <li>It is <i>consistent</i>: for any non-null reference values
     * {@code x} and {@code y}, multiple invocations of
     * {@code x.equals(y)} consistently return {@code true}
     * or consistently return {@code false}, provided no
     * information used in {@code equals} comparisons on the
     * objects is modified.</li>
     * <li>For any non-null reference value {@code x},
     * {@code x.equals(null)} should return {@code false}.</li>
     * </ul>
     * <p>
     * The {@code equals} method for class {@code Object} implements
     * the most discriminating possible equivalence relation on objects;
     * that is, for any non-null reference values {@code x} and
     * {@code y}, this method returns {@code true} if and only
     * if {@code x} and {@code y} refer to the same object
     * ({@code x == y} has the value {@code true}).
     * <p>
     * Note that it is generally necessary to override the {@code hashCode}
     * method whenever this method is overridden, so as to maintain the
     * general contract for the {@code hashCode} method, which states
     * that equal objects must have equal hash codes.
     *
     * @param obj the reference object with which to compare.
     * @return {@code true} if this object is the same as the obj
     *         argument; {@code false} otherwise.
     * @see #hashCode()
     * @see java.util.HashMap
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == null)
            return false;
        if (obj instanceof ComputationGraph) {
            ComputationGraph network = (ComputationGraph) obj;
            boolean paramsEquals = network.params().equals(params());
            boolean confEquals = getConfiguration().equals(network.getConfiguration());
            boolean updaterEquals = getUpdater().equals(network.getUpdater());
            return paramsEquals && confEquals && updaterEquals;
        }
        return false;
    }
}
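// Note on equals(): two ComputationGraph instances compare equal only when their
// parameters, configuration and updater state all match. A sketch, assuming
// `path` points at a model previously saved via ModelSerializer:
//
//   ComputationGraph a = ModelSerializer.restoreComputationGraph(path);
//   ComputationGraph b = ModelSerializer.restoreComputationGraph(path);
//   boolean same = a.equals(b); // expected true: identical params, conf and updater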