All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.samediff.SameDiff Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.samediff;

import com.google.flatbuffers.FlatBufferBuilder;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.*;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
import org.nd4j.autodiff.samediff.config.OutputConfig;
import org.nd4j.autodiff.samediff.internal.*;
import org.nd4j.autodiff.samediff.ops.*;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.graph.*;
import org.nd4j.imports.VariableUtils;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.common.util.ND4JFileUtils;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Sets;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.GraphDef;

import java.io.*;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.nd4j.autodiff.util.SameDiffUtils.stackOutputs;
import static org.nd4j.imports.VariableUtils.stripVarSuffix;

@Slf4j
public class SameDiff extends SDBaseOps {
    protected static final String GRAD_FN_KEY = "grad";

    //Fields for graph structure and execution
    @Getter
    private final Map variables = new LinkedHashMap<>();         //Use linked hash map to guarantee iteration order based on order they were added. Used in inputs() and flatbuffers serde
    @Getter
    private final Map ops = new LinkedHashMap<>();
    @Getter
    private final Map sessions = new ConcurrentHashMap<>();      //Key: thread ID

    private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true);
    private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true);
    private final Map> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them

    private final List lossVariables = new ArrayList<>();

    private final List listeners = new ArrayList<>();

    private final List nameScopes = new ArrayList<>();  //Used as a stack

    private List outputs;       //Names of the output variables, set by the user.

    ///////////////////////////////////////
    //Fields related to training
    /** Configuration for training. Must be set for training/evaluation, but not for other operations. */
    @Getter
    private TrainingConfig trainingConfig;
    /** True if training setup has been done. */
    @Getter
    private boolean initializedTraining;
    /** GradientUpdater instance for each trainable parameter, keyed by parameter name. */
    @Getter
    private Map<String, GradientUpdater> updaterMap;

    ////////////////////////////////////////

//    private DifferentialFunctionFactory functionFactory;

    // counter for auto-naming variables
    private int variableId = 0;

    ////////////////////////////////////////

    /**
     * Op creator object for math operations.
     * Bound to this SameDiff instance; also returned by {@link #math()}.
     */
    public final SDMath math = new SDMath(this);
    /**
     * Op creator object for random number generation operations.
     * Bound to this SameDiff instance; also returned by {@link #random()}.
     */
    public final SDRandom random = new SDRandom(this);
    /**
     * Op creator object for general neural network operations.
     * Bound to this SameDiff instance; also returned by {@link #nn()}.
     */
    public final SDNN nn = new SDNN(this);
    /**
     * Op creator object for convolutional neural network operations.
     * Bound to this SameDiff instance; also returned by {@link #cnn()}.
     */
    public final SDCNN cnn = new SDCNN(this);
    /**
     * Op creator object for recurrent neural network operations.
     * Bound to this SameDiff instance; also returned by {@link #rnn()}.
     */
    public final SDRNN rnn = new SDRNN(this);
    /**
     * Op creator object for loss function operations.
     * Bound to this SameDiff instance; also returned by {@link #loss()}.
     */
    public final SDLoss loss = new SDLoss(this);
    /**
     * Op creator object for image operations.
     * Bound to this SameDiff instance; also returned by {@link #image()}.
     */
    public final SDImage image = new SDImage(this);

    /**
     * Op creator object for bitwise operations.
     * Bound to this SameDiff instance; also returned by {@link #bitwise()}.
     */
    public final SDBitwise bitwise = new SDBitwise(this);

    /**
     * Op creator object for linalg operations.
     * Bound to this SameDiff instance; also returned by {@link #linalg()}.
     */
    public final SDLinalg linalg = new SDLinalg(this);

    /**
     * Accessor for the math op creator.
     *
     * @return the {@link SDMath} instance stored in the {@code math} field
     */
    public SDMath math() {
        return this.math;
    }

    /**
     * Accessor for the random number generation op creator.
     *
     * @return the {@link SDRandom} instance stored in the {@code random} field
     */
    public SDRandom random() {
        return this.random;
    }

    /**
     * Accessor for the general neural network op creator.
     *
     * @return the {@link SDNN} instance stored in the {@code nn} field
     */
    public SDNN nn() {
        return this.nn;
    }

    /**
     * Accessor for the convolutional neural network op creator.
     *
     * @return the {@link SDCNN} instance stored in the {@code cnn} field
     */
    public SDCNN cnn() {
        return this.cnn;
    }

    /**
     * Accessor for the recurrent neural network op creator.
     *
     * @return the {@link SDRNN} instance stored in the {@code rnn} field
     */
    public SDRNN rnn() {
        return this.rnn;
    }

    /**
     * Accessor for the loss function op creator.
     *
     * @return the {@link SDLoss} instance stored in the {@code loss} field
     */
    public SDLoss loss() {
        return this.loss;
    }

    /**
     * Accessor for the image op creator.
     *
     * @return the {@link SDImage} instance stored in the {@code image} field
     */
    public SDImage image() {
        return this.image;
    }

    /**
     * Accessor for the bitwise op creator.
     *
     * @return the {@link SDBitwise} instance stored in the {@code bitwise} field
     */
    public SDBitwise bitwise() {
        return this.bitwise;
    }

    /**
     * Accessor for the linalg op creator.
     *
     * @return the {@link SDLinalg} instance stored in the {@code linalg} field
     */
    public SDLinalg linalg() {
        return this.linalg;
    }

    /** SameDiff function instances, keyed by function name. */
    private Map<String, SameDiff> sameDiffFunctionInstances;

    /** Field/variable resolution mapping table. */
    private Table<String, String, String> fieldVariableResolutionMapping;

    // flag, shows if graph was already registered with libnd4j
    private transient AtomicBoolean wasRegistered = new AtomicBoolean(false);


    //debug mode variables
    /** Toggled by enableDebugMode() / disableDebugging(). */
    @Getter
    private boolean debugMode;

    /** Stack of argument interceptors. */
    @Getter
    private Stack<ArgumentInterceptor> argumentInterceptors = new Stack<>();
    /** Argument interceptors that are currently paused. */
    @Getter
    private Set<ArgumentInterceptor> pausedArgumentInterceptors = new HashSet<>();

    /** Set of block names. */
    private Set<String> blockNames = new HashSet<>();

    /** Execution logging flag; defaults to true. */
    @Getter
    @Setter
    boolean logExecution = true;

    /** Parent SameDiff instance, if any. */
    @Getter
    private SameDiff parent;

    /** Child SameDiff instance, if any. */
    @Getter
    private SameDiff child;


    /**
     * Turns debug mode off by resetting the {@code debugMode} flag.
     *
     * @return this instance, so calls can be chained
     */
    public SameDiff disableDebugging() {
        this.debugMode = false;
        return this;
    }

    /**
     * Turns debug mode on by setting the {@code debugMode} flag.
     *
     * @return this instance, so calls can be chained
     */
    public SameDiff enableDebugMode() {
        this.debugMode = true;
        return this;
    }

    /**
     * Replace the SameDiff-wide {@link Listener} instances with the given ones.
     *
     * Any listeners registered before this call are discarded first.
     * To use extra listeners for a single operation only, pass them through the listener arguments
     * of that operation instead (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}).
     *
     * @param listeners Listeners that make up the new listener list
     */
    public void setListeners(Listener... listeners) {
        //Drop whatever was registered before, then register the new listeners
        this.listeners.clear();
        this.addListeners(listeners);
    }

    /**
     * Replace the SameDiff-wide {@link Listener} instances with the given ones.
     * Collection-based variant of {@link #setListeners(Listener...)}: the current listener list is cleared,
     * then all elements of the collection are added.
     *
     * @param listeners Listeners that make up the new listener list
     */
    public void setListeners(Collection<? extends Listener> listeners) {
        this.listeners.clear();
        addListeners(listeners);
    }


    /**
     * Register additional SameDiff-wide {@link Listener} instances, keeping the ones already present.
     *
     * To use extra listeners for a single operation only, pass them through the listener arguments
     * of that operation instead (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}).
     *
     * @param listeners Listeners to add
     */
    public void addListeners(Listener... listeners) {
        //Delegate to the collection-based overload
        final List<Listener> toAdd = Arrays.asList(listeners);
        addListeners(toAdd);
    }

    /**
     * Register additional SameDiff-wide {@link Listener} instances, keeping the ones already present.
     * Collection-based variant of {@link #addListeners(Listener...)}.
     *
     * @param listeners Listeners to add
     */
    public void addListeners(Collection<? extends Listener> listeners) {
        this.listeners.addAll(listeners);
    }

    /**
     * Gets the current SameDiff-wide listeners.
     *
     * @return the internal listener list itself (not a copy), so changes made to it affect this instance
     */
    public List<Listener> getListeners() {
        return listeners;
    }

    /**
     * Set the array holders for variable and constant arrays
* NOTE: this is usually reserved for developers and internal use, and should not be needed by almost all users
* See {@link ArrayHolder} for more details * * @param variableArrayHolder Array holder for variable arrays * @param constantArrayHolder Array holder for constant arrays * @param initialize If true: transfer any arrays from the current array holders to the new/specified ones */ public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize){ if(initialize){ variableArrayHolder.initFrom(this.variablesArrays); constantArrayHolder.initFrom(this.constantArrays); } this.variablesArrays = variableArrayHolder; this.constantArrays = constantArrayHolder; } /** * @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details. */ public String currentNameScope() { if (nameScopes.isEmpty()) return null; //Would use String.join but that is Java 8+ StringBuilder sb = new StringBuilder(); boolean first = true; for (NameScope ns : nameScopes) { if (!first) { sb.append("/"); } sb.append(ns.getName()); first = false; } return sb.toString(); } /** * @return The name with the current name scope (if any) appended. See {@link #withNameScope(String)} */ protected String nameWithScope(String name) { String scope = currentNameScope(); if (scope == null) { return name; } if (!name.startsWith(scope + "/")) return scope + "/" + name; else return name; } //Intentionally package private void addNameScope(NameScope nameScope) { nameScopes.add(nameScope); } //Intentionally package private void closeNameScope(NameScope nameScope) { //Check that the name scope is closed correctly/in order Preconditions.checkState(!nameScopes.isEmpty(), "Cannot close name scope: no name scopes are currently defined"); Preconditions.checkState(nameScopes.get(nameScopes.size() - 1).equals(nameScope), "Cannot close name scope %s: Name scopes must be closed in order. Current name scopes: \"%s\"", nameScope, currentNameScope()); nameScopes.remove(nameScopes.size() - 1); } /** * Create a name scope. 
Name scopes append a prefix to the names of any variables and ops created while they are open. *
     *  {@code
     *  SameDiff sd = SameDiff.create();
     *  SDVariable x = sd.var("x", DataType.FLOAT, 5);
     *  SDVariable y;
     *  try(NameScope ns = sd.withNameScope("myScope"){
     *      y = sd.var("y", DataType.FLOAT, 5);
     *  }
     *  SDVariable z = sd.var("z", DataType.FLOAT, 5);
     *
     *  String xName = x.name();      //RESULT: "x"
     *  String yName = y.name();      //RESULT: "myScope/y"
     *  String zName = z.name();      //RESULT: "z"
     *  }
     * 
*

* Note that name scopes can also be nested: *

     *  {@code
     *  SameDiff sd = SameDiff.create();
     *  SDVariable x;
     *  try(NameScope ns = sd.withNameScope("first"){
     *      try(NameScope ns2 = sd.withNameScope("second"){
     *          x = sd.var("x", DataType.FLOAT, 5);
     *      }
     *  }
     *  String xName = x.name();      //RESULT: "first/second/x"
     *  }
     * 
* * @param nameScope Name of the name scope to open/create * @return The NameScope object */ public NameScope withNameScope(String nameScope) { NameScope ns = new NameScope(this, nameScope); addNameScope(ns); return ns; } /** * Gets all operations in a given name scope. */ public List getOpsInScope(NameScope scope) { ArrayList ops = new ArrayList<>(); for (SameDiffOp v : this.ops.values()) { if (v.getName().startsWith(scope.getName())) ops.add(v); } return ops; } /** * See {@link #getOpsInScope(NameScope)}. */ public List getOpsInScope(String scope){ return getOpsInScope(new NameScope(this, scope)); } /** * Gets all variables in a given name scope. */ public List getVariablesInScope(NameScope scope) { ArrayList vars = new ArrayList<>(); for (SDVariable v : variables()) { if (v.name().startsWith(scope.getName())) vars.add(v); } return vars; } /** * See {@link #getVariablesInScope(NameScope)}. */ public List getVariablesInScope(String scope){ return getVariablesInScope(new NameScope(this, scope)); } /** * @param sameDiff * @return */ public SDVariable invokeGraphOn(SameDiff sameDiff) { //map the new vertices on to the old ones Map thisVertexIdToNew = new HashMap<>(); int idx = 1; for (val var : variables()) { SDVariable clone = var.clone(this); SDVariable newVar = sameDiff.var(clone); if (var.getVariableType() != VariableType.ARRAY && var.getArr() != null ) { //ARRAY type = "activations" - are overwritten anyway sameDiff.associateArrayWithVariable(var.getArr(), newVar); } thisVertexIdToNew.put(idx, idx); clone.setSameDiff(sameDiff); idx++; } Map reverseMap = new HashMap<>(); int count = 0; for( Variable v : variables.values()){ reverseMap.put(v.getName(), count++); } val newFunctions = new LinkedHashMap(); for (SameDiffOp op : ops.values()) { DifferentialFunction function = op.getOp(); //Clone the op DifferentialFunction clone = FlatBuffersMapper.cloneViaSerialize(this, function, reverseMap); clone.setSameDiff(sameDiff); clone.setOwnName(function.getOwnName()); if 
(sameDiff.opExists(function.getOwnName())) sameDiff.putOpForId(function.getOwnName(), function); newFunctions.put(function.getOwnName(), clone); val argsForFunction = function.args(); val outputsForFunction = function.outputVariables(); //note that these have the same variable names sameDiff.addArgsFor(argsForFunction, clone); sameDiff.addOutgoingFor(outputsForFunction, function); for (val arg : clone.args()) { arg.setSameDiff(sameDiff); } for (val output : clone.outputVariables()) { output.setSameDiff(sameDiff); } sameDiff.ops.put(function.getOwnName(), op); } return sameDiff.variables().get(sameDiff.variables().size() - 1); } /** * Returns true if the given function id exists * * @param id the function id to test for * @return true if the function id exists, false otherwise */ public boolean opExists(String id) { return ops.containsKey(id); } /** * Get the differential function (if any) that this variable is the output for * * @param variableName Name of the variable * @return The differential function that this variable is an output of, or null if it is not the output of a function */ public DifferentialFunction getVariableOutputOp(String variableName) { Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName); if (variables.get(variableName).getOutputOfOp() == null || ops.get(stripVarSuffix(variables.get(variableName).getOutputOfOp())) == null) return null; return ops.get(stripVarSuffix(variables.get(variableName).getOutputOfOp())).getOp(); } /** * Get the function by the {@link DifferentialFunction#getOwnName()} * * @param id the id of the function * @return the function for the given id if it exists */ public DifferentialFunction getOpById(@NonNull String id) { if (!ops.containsKey(id)) { throw new ND4JIllegalStateException("No function with id " + id + " found!"); } return ops.get(id).getOp(); } /** * Put the function for the given id * * @param id the id of the function * @param function the 
function */
    public void putOpForId(String id, DifferentialFunction function) {
        // NOTE(review): this throws only when an entry for the id exists AND its op is null, which
        // does not obviously match the "already exists" message - confirm the intended condition.
        if (ops.containsKey(id) && ops.get(id).getOp() == null) {
            throw new ND4JIllegalStateException("Function by id already exists!");
        }

        // An existing registration is left untouched; only a missing id gets a new entry
        if (!ops.containsKey(id)) {
            ops.put(id, SameDiffOp.builder().name(id).op(function).build());
        }
    }

    /**
     * Returns the name(s) of the inputs for the given function
     *
     * @param function the function to get the inputs for
     * @return the input ids for a given function, or null if no input list is stored for it
     * @throws ND4JIllegalStateException if no op is registered under the function's own name
     */
    public String[] getInputsForOp(@NonNull DifferentialFunction function) {
        if (!ops.containsKey(function.getOwnName()))
            throw new ND4JIllegalStateException("Unknown function instance id found: \"" + function.getOwnName() + "\"");
        List<String> inputs = ops.get(function.getOwnName()).getInputsToOp();
        return inputs == null ? null : inputs.toArray(new String[inputs.size()]);
    }

    /**
     * Returns the name(s) of the outputs for the given function
     *
     * @param function the function to get the outputs for
     * @return the outputs ids for a given function, or null if no output list is stored for it
     * @throws ND4JIllegalStateException if no op is registered under the function's own name
     */
    public String[] getOutputsForOp(DifferentialFunction function) {
        if (!ops.containsKey(function.getOwnName()))
            throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
        List<String> outputs = ops.get(function.getOwnName()).getOutputsOfOp();
        return outputs == null ? null : outputs.toArray(new String[outputs.size()]);
    }

    /**
     * Get the output variable(s) for the specified differential function.
     * Unlike {@link #getInputVariablesForOp(DifferentialFunction)}, the looked-up variables are
     * not checked for null.
     *
     * @param function the function reference to get the output variable(s) for
     * @return the output variables for the given function
     * @throws ND4JIllegalStateException if the function is unknown or has no output list stored
     */
    public SDVariable[] getOutputVariablesForOp(DifferentialFunction function) {
        val outputs = getOutputsForOp(function);
        if (outputs == null) {
            throw new ND4JIllegalStateException("No outputs found for function " + function);
        }

        val vars = new SDVariable[outputs.length];
        for (int i = 0; i < outputs.length; i++) {
            vars[i] = getVariable(outputs[i]);
        }

        return vars;
    }

    /**
     * Get the input variable(s) for the specified differential function
     *
     * @param function the function reference to get the input variable(s) for
     * @return the input variables for the given function
     * @throws ND4JIllegalStateException if the function is unknown, has no input list stored,
     *                                   or one of its input names does not resolve to a variable
     */
    public SDVariable[] getInputVariablesForOp(DifferentialFunction function) {
        val inputs = getInputsForOp(function);
        if (inputs == null) {
            throw new ND4JIllegalStateException("No inputs found for function " + function);
        }

        val vars = new SDVariable[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            vars[i] = getVariable(inputs[i]);
            if (vars[i] == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + i);
            }
        }

        return vars;
    }

    /**
     * Set the stored {@link INDArray} for a variable. Only works if the variable is of type
     * {@link VariableType#CONSTANT}, {@link VariableType#PLACEHOLDER}, or {@link VariableType#VARIABLE}.
*/ public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) { Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName); SDVariable v = getVariable(varName); if (v.isConstant()) { constantArrays.setArray(varName, arr); } else if (v.getVariableType() == VariableType.VARIABLE) { variablesArrays.setArray(varName, arr); } else if (v.isPlaceHolder()) { long tid = Thread.currentThread().getId(); if (!placeholdersPerThread.containsKey(tid)) { placeholdersPerThread.put(tid, new HashMap()); } placeholdersPerThread.get(tid).put(varName, arr); } else { throw new UnsupportedOperationException("Cannot set variable of type " + v.getVariableType() + " using this method"); } } /** * Returns true if the given vertex id and {@link INDArray} already exist. * * @param varName the vertex id * @return true if a vertex with the given INDArray exists, and it has an INDArray associated with it */ public boolean arrayAlreadyExistsForVarName(String varName) { SDVariable var = getVariable(varName); switch (var.getVariableType()) { case VARIABLE: return variablesArrays.hasArray(varName); case ARRAY: long tid = Thread.currentThread().getId(); return sessions.containsKey(tid) && sessions.get(tid).contains(varName, InferenceSession.OUTER_FRAME, 0, null); case CONSTANT: return constantArrays.hasArray(varName); case PLACEHOLDER: return placeholdersPerThread.containsKey(Thread.currentThread().getId()) && placeholdersPerThread.get(Thread.currentThread().getId()).containsKey(varName); default: throw new RuntimeException("Unknown variable type: " + var.getVariableType()); } } /** * Get an {@link INDArray} for a given vertex id, or null if none exists * * @param varName Variable name to get the array for * @return Array, or null if none exists */ public INDArray getArrForVarName(@NonNull String varName) { Preconditions.checkState(variables.containsKey(varName), "No variable found with name \"%s\"", varName); SDVariable v = 
variables.get(varName).getVariable(); switch (v.getVariableType()) { case VARIABLE: return variablesArrays.getArray(varName); case CONSTANT: if (!constantArrays.hasArray(varName)) return null; return constantArrays.getArray(varName); case ARRAY: //Only stored in inference session... InferenceSession s = sessions.get(Thread.currentThread().getId()); if (s == null) return null; return s.get(varName, InferenceSession.OUTER_FRAME, 0, null, false); case PLACEHOLDER: long tid = Thread.currentThread().getId(); if (placeholdersPerThread.get(tid) == null || !placeholdersPerThread.get(tid).containsKey(varName)) return null; return placeholdersPerThread.get(tid).get(varName); default: throw new RuntimeException("Unknown variable type: " + v.getVariableType()); } } /** * Associate the array with the given variable. * * @param arr the array to get the variable for * @param variable the name of the variable to associate the array with */ public void associateArrayWithVariable(INDArray arr, @NonNull String variable) { Preconditions.checkState(variables.containsKey(variable), "Cannot associate array with variable \"%s\": " + "variable \"%s\" does not exist in this SameDiff instance", variable, variable); associateArrayWithVariable(arr, this.getVariable(variable)); } /** * Associate the array with the given variable. 
* * @param arr the array to get the variable for * @param variable the variable to associate the array with */ public void associateArrayWithVariable(INDArray arr, SDVariable variable) { if (variable == null) { throw new ND4JIllegalArgumentException("Variable must not be null!"); } if (arr == null) { throw new ND4JIllegalArgumentException("Array must not be null"); } if (variable.dataType() != arr.dataType()) arr = arr.castTo(variable.dataType()); Preconditions.checkState(variable.dataType() == arr.dataType(), "Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable", variable.name(), variable.dataType(), arr.dataType()); if (sessions.get(Thread.currentThread().getId()) == null) { sessions.put(Thread.currentThread().getId(), new InferenceSession(this)); } if (arr.isAttached()) { arr = arr.detach(); } switch (variable.getVariableType()) { case VARIABLE: variablesArrays.setArray(variable.name(), arr); break; case CONSTANT: constantArrays.setArray(variable.name(), arr); break; case ARRAY: throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" + " this type of variable is calculated "); case PLACEHOLDER: //Validate placeholder shapes: long[] phShape = variable.placeholderShape(); Preconditions.checkState(phShape == null || Shape.shapeMatchesPlaceholder(phShape, arr.shape()), "Invalid array shape: cannot associate an array with shape %ndShape with a placeholder of shape %s:" + "shape is wrong rank or does not match on one or more dimensions", arr, phShape); long tid = Thread.currentThread().getId(); if (!placeholdersPerThread.containsKey(tid)) { placeholdersPerThread.put(tid, new HashMap()); } placeholdersPerThread.get(tid).put(variable.name(), arr); break; default: throw new IllegalStateException("Unknown variable type: " + variable.getVariableType()); } //putOrUpdateShapeForVarName(variable.name(), arr.shape(), true); //Also update nested SameDiff instances (such as gradient function) 
if (sameDiffFunctionInstances != null && sameDiffFunctionInstances.size() > 0) { for (Map.Entry e : sameDiffFunctionInstances.entrySet()) { SameDiff sd = e.getValue(); SDVariable v = sd.getVariable(variable.name()); if (v != null) { sd.associateArrayWithVariable(arr, v); } } } } /** * Update the constant or variable type SDVariable with the values from the specified * array. Note that unlike {@link #associateArrayWithVariable(INDArray, String)} this method will take the * values from the argument array and assign it to the current array. * The actual array (INDArray object) will not be stored or otherwise used within the SameDiff instance. * @param arr Array values to set * @param variable Variable to update the array of. Must be CONSTANT or VARIBLE type SDVariable */ public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable){ Preconditions.checkState(variable.getVariableType() == VariableType.VARIABLE || variable.getVariableType() == VariableType.CONSTANT, "assignArray method can only be used with VARIBLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.name(), variable.getVariableType()); //DeviceLocal doesn't work with views if(arr.isView()) arr = arr.dup(); if(variable.getVariableType() == VariableType.VARIABLE ){ variablesArrays.setArray(variable.name(), arr); } else { constantArrays.setArray(variable.name(), arr); } } /** * Associate a {@link SameDiff} namespace as a sub function. * * @param name the opName of the function * @param nameSpace the namespace */ public void putSubFunction(String name, SameDiff nameSpace) { if (sameDiffFunctionInstances.containsKey(name) && sameDiffFunctionInstances.get(name) != nameSpace) { throw new ND4JIllegalStateException("Unable to replace samediff namespace. 
Please choose another opName"); } sameDiffFunctionInstances.put(name, nameSpace); } /** * Return a copy of the internal variable map * * @return Map of variables by name */ public Map variableMap() { Map ret = new LinkedHashMap<>(); for (Variable v : variables.values()) { ret.put(v.getName(), v.getVariable()); } return ret; } /** * The set of defined SameDiff function names. SameDiff function instances should not be confused * with DifferentialFunction ops; an example of a SameDiff function instance is the gradient "grad" function * * @return Set of defined SameDiff function instance names */ public Collection definedFunctionNames() { return this.sameDiffFunctionInstances.keySet(); } private SameDiff() { super(null); super.sd = this; sameDiffFunctionInstances = new LinkedHashMap<>(); fieldVariableResolutionMapping = HashBasedTable.create(); } /** * Attempts to insert the {@link DifferentialFunction} reference in to this {@link SameDiff} instance. * If the given array field with the given index already exists, it will do a reference check to ensure that the 2 * array fields are the same. If not, an exception is thrown.
* If the instances are the same (by semantics, not reference) then it will just return the original instance. * This is to ensure that instances that are created are unique and reference checked. * * @param function the array field to attempt to create * @return Original instance */ public X setupFunction(X function) { Preconditions.checkNotNull(function, "Passed in function must not be null!"); if (function instanceof SDVariable) { if (function.getSameDiff() != this) { function.setSameDiff(this); } return function; } return function; } /** * Adds outgoing arguments to the graph for the specified DifferentialFunction * Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared. * * @param variables Variables - arguments for the specified differential function * @param function Differential function */ public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) { String[] varNames = new String[variables.length]; for (int i = 0; i < varNames.length; i++) { varNames[i] = variables[i].name(); } addOutgoingFor(varNames, function); } /** * Adds outgoing arguments to the graph for the specified DifferentialFunction * Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared. * * @param varNames Name of the variables that are outputs of the specified differential function * @param function Differential function */ public void addOutgoingFor(String[] varNames, DifferentialFunction function) { if (function.getOwnName() == null) throw new ND4JIllegalStateException("Instance id can not be null. 
Function not initialized properly"); if (ops.get(function.getOwnName()).getOutputsOfOp() != null && !ops.get(function.getOwnName()).getOutputsOfOp().isEmpty()) { throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function); } if (varNames == null) throw new ND4JIllegalStateException("Var names can not be null!"); for (int i = 0; i < varNames.length; i++) { if (varNames[i] == null) throw new ND4JIllegalStateException("Variable name elements can not be null!"); } ops.get(function.getOwnName()).setOutputsOfOp(Arrays.asList(varNames)); for (String resultName : varNames) { variables.get(resultName).setOutputOfOp(function.getOwnName()); } } /** * Add a new argument interceptor to the interceptor stack *

* For internal use only. *

* When an op is added with arguments, the most recent argument interceptor is called on it.
     * If ops are added in that interceptor, the next most recent will be called on their args, and so on.
     *
     * @param interceptor the argument interceptor to add
     */
    public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
        argumentInterceptors.push(interceptor);
    }

    /**
     * @param interceptor the interceptor to check
     * @return true if the interceptor is currently in the paused collection
     */
    private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor interceptor) {
        return pausedArgumentInterceptors.contains(interceptor);
    }

    /**
     * Find the interceptor that should handle newly added args: the most recently added one that
     * is not paused.
     *
     * @return the interceptor to use, or null if the stack is empty or every interceptor is paused
     */
    private ArgumentInterceptor getArgumentInterceptorToUse() {
        //Walk the stack from the top (most recently added) down to the bottom
        for (int idx = argumentInterceptors.size() - 1; idx >= 0; idx--) {
            ArgumentInterceptor candidate = argumentInterceptors.elementAt(idx);
            if (!isArgumentInterceptorPaused(candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Remove the top (most recently added) argument interceptor
     *

* For internal use only. Does nothing if the interceptor stack is empty.
     */
    public void removeArgumentInterceptor() {
        if (argumentInterceptors.isEmpty()) {
            return;
        }
        argumentInterceptors.pop();
    }

    /**
     * Pause the top (most recently added) argument interceptor
     *

* For internal use only.
     */
    public void pauseArgumentInterceptor() {
        //No emptiness check here (unlike removeArgumentInterceptor): the stack is peeked directly
        ArgumentInterceptor top = argumentInterceptors.peek();
        pausedArgumentInterceptors.add(top);
    }

    /**
     * Pause the given argument interceptor
     *

* For internal use only.
     *
     * @param interceptor the argument interceptor to pause
     */
    public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
        //Pausing only records the interceptor as paused; it is not removed from the interceptor stack
        pausedArgumentInterceptors.add(interceptor);
    }

    /**
     * Unpause the top (most recently added) argument interceptor
     *

* For internal use only.
     */
    public void unpauseArgumentInterceptor() {
        //No emptiness check here (unlike removeArgumentInterceptor): the stack is peeked directly
        ArgumentInterceptor top = argumentInterceptors.peek();
        pausedArgumentInterceptors.remove(top);
    }

    /**
     * Unpause the given argument interceptor
     *

* For internal use only. * * @param interceptor the argument interceptor to unpause */ public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) { pausedArgumentInterceptors.remove(interceptor); } /** * Adds incoming arguments for the specified differential function to the graph * * @param variables Name of the variables that are arguments (inputs) to the specified function * @param function Function */ public void addArgsFor(String[] variables, DifferentialFunction function) { ArgumentInterceptor interceptor = getArgumentInterceptorToUse(); if (interceptor != null) { pauseArgumentInterceptor(interceptor); for (int i = 0; i < variables.length; i++) { variables[i] = interceptor.intercept(getVariable(variables[i])).name(); } unpauseArgumentInterceptor(interceptor); } if (function.getOwnName() == null) throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly"); //Add function if it doesn't exist //TODO could "not existing" be a bug sometimes? if (!ops.containsKey(function.getOwnName())) { ops.put(function.getOwnName(), SameDiffOp.builder().name(function.getOwnName()).op(function).build()); } //Update variable 'inputs to op' accounting for repeated inputs (like y = x+x) ops.get(function.getOwnName()).setInputsToOp(Arrays.asList(variables)); //Duplicate variables OK/required here for (String variableName : variables) { if(this.variables.containsKey(variableName)) { List funcs = this.variables.get(variableName).getInputsForOp(); if (funcs == null) { funcs = new ArrayList<>(); this.variables.get(variableName).setInputsForOp(funcs); } if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names. 
funcs.add(function.getOwnName()); } } } /** * Adds incoming arguments for the specified differential function to the graph * * @param variables variables that are arguments (inputs) to the specified function * @param function Function */ public void addArgsFor(SDVariable[] variables, DifferentialFunction function) { String[] varNames = new String[variables.length]; for (int i = 0; i < varNames.length; i++) { if (variables[i] == null) throw new ND4JIllegalStateException("Found null variable at index " + i); varNames[i] = variables[i].name(); } addArgsFor(varNames, function); } /** * Replaces the argument at i with newArg for function * Does not use (or remove) ArgumentInterceptor stuff */ public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function) { Preconditions.checkArgument(i < function.args().length, "Index out of range: function " + function.getOwnName() + " only has " + function.args().length + " args but you are trying" + "to replace the argument at " + i); String oldName = function.arg(i).name(); String newName = newArg.name(); List oldArgs = ops.get(function.getOwnName()).getInputsToOp(); oldArgs = new ArrayList<>(oldArgs); oldArgs.set(i, newName); ops.get(function.getOwnName()).setInputsToOp(oldArgs); List funcs = this.variables.get(newName).getInputsForOp(); if (funcs == null) { funcs = new ArrayList<>(); this.variables.get(newName).setInputsForOp(funcs); } if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names. 
funcs.add(function.getOwnName()); List oldFuncs = this.variables.get(oldName).getInputsForOp(); if (oldFuncs != null) { if (!ArrayUtils.contains(function.argNames(), oldName)) oldFuncs.remove(function.getOwnName()); } } /** * Returns true if this function already has defined arguments * * @param function the function to check * @return true if the function has args, false otherwise */ public boolean hasArgs(DifferentialFunction function) { List vertexIdArgs = ops.get(function.getOwnName()).getInputsToOp(); return vertexIdArgs != null && vertexIdArgs.size() > 0; } /** * Clear the placeholder arrays from the SameDiff instance * * @param allThreads If true: clear the placeholders for all threads. False: clear only for current thread */ public void clearPlaceholders(boolean allThreads) { if (allThreads) { this.placeholdersPerThread.clear(); } else { long tid = Thread.currentThread().getId(); this.placeholdersPerThread.remove(tid); } for (SameDiff sd : this.sameDiffFunctionInstances.values()) { sd.clearPlaceholders(allThreads); } } /** * Clear the input arrays to each op. 
* This is usually not required, under normal SameDiff use
     */
    public void clearOpInputs() {
        for (SameDiffOp entry : ops.values()) {
            DifferentialFunction df = entry.getOp();
            if (df instanceof Op) {
                Op xyOp = (Op) df;
                xyOp.setX(null);
                if (xyOp.y() != null) {
                    xyOp.setY(null);
                }
            } else if (df instanceof DynamicCustomOp) {
                ((DynamicCustomOp) df).setInputArguments((INDArray[]) null);
            }
        }

        //Also clear the op inputs in the SameDiff function instances
        for (SameDiff nested : this.sameDiffFunctionInstances.values()) {
            nested.clearOpInputs();
        }
    }

    /**
     * Get an array of differential functions that have been defined for this SameDiff instance
     *
     * @return Array of differential functions
     */
    public DifferentialFunction[] ops() {
        DifferentialFunction[] out = new DifferentialFunction[ops.size()];
        int idx = 0;
        for (SameDiffOp op : ops.values()) {
            out[idx++] = op.getOp();
        }
        return out;
    }

    /**
     * Hash of the superclass hash combined with the hash of the variables map.
     * NOTE(review): {@link #equals(Object)} also compares the ops, and the superclass hash is mixed
     * in here - confirm that equal instances are guaranteed to produce equal hash codes.
     */
    @Override
    public int hashCode() {
        int varsHash = (variables == null) ? 0 : variables.hashCode();
        return 31 * super.hashCode() + varsHash;
    }

    /**
     * Two SameDiff instances are equal if they are of the same class and both their variables and
     * their ops are equal.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }

        SameDiff other = (SameDiff) o;

        boolean eqVars = variables.equals(other.variables);
        boolean eqOps = ops.equals(other.ops);
        return eqVars && eqOps;
    }

    /**
     * Create a new (empty) SameDiff instance without any functions or variables
     *
     * @return New SameDiff instance
     */
    public static SameDiff create() {
        return new SameDiff();
    }

    /**
     * Clone/duplicate the SameDiff instance, including arrays etc.
The returned SameDiff instance should have no * shared state with the original instance * * @return The cloned SameDiff instance */ public SameDiff dup() { ByteBuffer bb = asFlatBuffers(true); try { return fromFlatBuffers(bb); } catch (IOException e){ throw new RuntimeException(e); } } /** * Count the number of elements in all arrays, according to {@link SDVariable#getShape()} * * @return Number of array elements for all variables */ public long numElements() { long ret = 0; for (SDVariable variable : variables()) { long[] shape = variable.getShape(); if (shape != null) { ret += ArrayUtil.prod(shape); } } return ret; } /** * Returns the inputs (placeholders) for the SameDiff graph * * @return the inputs for this graph */ public List inputs() { List out = new ArrayList<>(); for (String s : variables.keySet()) { if (isPlaceHolder(s)) out.add(s); } return out; } /** * Outputs are the names of the predictions of the network. * Note that the outputs must be set using {@link #setOutputs(List)} first * * @return The outputs of the SameDiff instance, or null if no outputs have been set */ public List outputs() { return this.outputs; } /** * See {@link #setOutputs(List)} */ public void setOutputs(String... outputs){ setOutputs(outputs == null ? null : Arrays.asList(outputs)); } /** * Set the outputs of the SameDiff instance. * Outputs are the names of the variables that are the predictions of the neural network. * Note that this is merely a convenience, and does not impact execution at all. Outputs can be retrieved (after * setting here) using {@link #outputs()} * @param outputs Outputs to set. 
Must be valid variable names in this SameDiff instance.
     *                May be null, in which case the outputs are cleared (set to null)
     */
    public void setOutputs(List<String> outputs){
        if(outputs != null){
            for(String s : outputs){
                //The name is passed as a format argument so that the "%s" in the message is filled in
                Preconditions.checkArgument(variables.containsKey(s), "Cannot set variable \"%s\" as an output: SameDiff instance does not contain a variable with this name", s);
            }
        }
        //The passed list is stored as-is (no copy is made)
        this.outputs = outputs;
    }

    /**
     * The list of all variables in the graph
     *
     * @return All variables in the graph, as a new list built from {@link #variableMap()}
     */
    public List<SDVariable> variables() {
        return new ArrayList<>(variableMap().values());
    }

    /**
     * Get the names of variables (if any) that have been marked as loss variables to be minimized.
* Variables can be marked as loss variables in a few different ways:
* (a) Losses are automatically added when creating loss functions via {@link #sd()}
* (b) Via {@link #setLossVariables(String...)}, {@link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}
* (c) Via {@link TrainingConfig#setLossVariables(List)}
*/
    public List<String> getLossVariables() {
        //Read-only view: the list itself is changed via setLossVariables / addLossVariable
        return Collections.unmodifiableList(this.lossVariables);
    }

    /**
     * Clear/remove any existing loss variables, and set the loss variables to the specified variable names.
* See {@link #addLossVariable(String)} for more details * * @param lossVariableNames Names of variables to be loss function variables */ public void setLossVariables(@NonNull String... lossVariableNames) { this.lossVariables.clear(); for (String s : lossVariableNames) { addLossVariable(s); } //After changing loss function variables, we (probably) need to recreate gradient function - as gradient // function is defined with respect to specific loss function variables sameDiffFunctionInstances.remove("grad"); } /** * See {@link #setLossVariables(String...)} */ public void setLossVariables(@NonNull SDVariable... lossVariables) { String[] varNames = new String[lossVariables.length]; for (int i = 0; i < lossVariables.length; i++) varNames[i] = lossVariables[i].name(); setLossVariables(varNames); } /** * Mark the specified variable as a loss function variable. This means that this variable will be minimized via backprop during training.
* This will add the variable as a loss to any others - i.e., if multiple variables are marked as losses, their values will be summed * to give the total network loss.
* Note that only floating point (Float16/32/64) variables may be marked as a loss.
* Note also that only ARRAY type SDVariables can be marked as losses to be minimized. That is, we cannot mark the value * of a constant, variable or placeholder to be minimized as doing so would not make sense.
*/
    public void addLossVariable(@NonNull String variableName) {
        Preconditions.checkState(hasVariable(variableName), "No variable with name \"%s\" exists", variableName);
        SDVariable candidate = getVariable(variableName);

        boolean isFloatingPoint = candidate.dataType().isFPType();
        Preconditions.checkState(isFloatingPoint, "Only floating point type variables can be marked as losses to be minimized." +
                " SDVariable \"%s\" has datatype %s", variableName, candidate.dataType());

        boolean isArrayType = candidate.getVariableType() == VariableType.ARRAY;
        Preconditions.checkState(isArrayType, "Only ARRAY type SDVariables can be marked as losses to be minimized." +
                " SDVariable \"%s\" has variable type %s", variableName, candidate.getVariableType());

        //Each name is recorded at most once
        if (lossVariables.contains(variableName)) {
            return;
        }
        lossVariables.add(variableName);
    }

    /**
     * See {@link #addLossVariable(String)}. The variable is identified by its name.
     */
    public void addLossVariable(@NonNull SDVariable variable) {
        String name = variable.name();
        addLossVariable(name);
    }

    /**
     * Set the training configuration ({@link TrainingConfig}) for the SameDiff instance.
     * A TrainingConfig must be set before the SameDiff instance can be trained via the fit methods
     *
     * @param trainingConfig Training configuration
     */
    public void setTrainingConfig(TrainingConfig trainingConfig) {
        this.trainingConfig = trainingConfig;
    }

    /**
     * Fit the SameDiff instance based on a single DataSet (i.e., a single minibatch for one iteration).
* This method can only be used for single input, single output SameDiff instances as DataSet only supports a * single input and a single output.
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can * be performed. * * @param dataSet The DataSet (single minibatch) to peform training on * @param listeners Additional listeners to use during this operation * @return a {@link History} object containing the history information for this training operation * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). */ public History fit(@NonNull DataSet dataSet, @NonNull Listener... listeners) { return fit(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), 1, false, null, 1, listeners); } /** * Fit the SameDiff instance based on a single MultiDataSet (i.e., a single minibatch for one iteration).
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can * be performed. * * @param dataSet The MultiDataSet (single minibatch) to peform training on * @param listeners Additional listeners to use during this operation * @return a {@link History} object containing the history information for this training operation * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). */ public History fit(@NonNull MultiDataSet dataSet, @NonNull Listener... listeners) { return fit(new SingletonMultiDataSetIterator(dataSet), 1, false, null, 1, listeners); } /** * Fit the SameDiff instance based on DataSetIterator for the specified number of epochs.
* This method can only be used for single input, single output SameDiff instances as DataSet only supports a * single input and a single output.
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can * be performed. *

* A special case of {@link #fit()}.
     *
     * @param iter                The iterator to train the SameDiff instance with
     * @param numEpochs           The number of epochs for training. Must be > 0
     * @param validationIter      The DataSetIterator to use for validation (null to skip validation)
     * @param validationFrequency The frequency with which to run validation. 1 is every epoch, 2 is every other, etc.
     * @param listeners           Additional listeners to use during this operation
     * @return a {@link History} object containing the history information for this training operation
     * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
     */
    public History fit(@NonNull DataSetIterator iter, int numEpochs, DataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners) {
        //Configure and execute through the FitConfig returned by fit()
        return fit().train(iter, numEpochs).validate(validationIter, validationFrequency).listeners(listeners).exec();
    }

    /**
     * See {@link #fit(DataSetIterator, int, DataSetIterator, int, Listener...)}, does not perform validation.
     *

* A special case of {@link #fit()}.
     *
     * @param iter      The iterator to train the SameDiff instance with
     * @param numEpochs The number of epochs for training. Must be > 0
     * @param listeners Additional listeners to use during this operation
     * @return a {@link History} object containing the history information for this training operation
     * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
     */
    public History fit(@NonNull DataSetIterator iter, int numEpochs, @NonNull Listener... listeners) {
        //Configure and execute through the FitConfig returned by fit(); no validation data is set
        return fit().train(iter, numEpochs).listeners(listeners).exec();
    }

    /**
     * Fit the SameDiff instance based on MultiDataSetIterator for the specified number of epochs.
* This method can be used for both single input, single output and multi-input, multi-output SameDiff instances
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can * be performed. *

* A special case of {@link #fit()}.
     *
     * @param iter                The iterator to train the SameDiff instance with
     * @param numEpochs           The number of epochs for training. Must be > 0
     * @param validationIter      The MultiDataSetIterator to use for validation (null to skip validation)
     * @param validationFrequency The frequency with which to run validation. 1 is every epoch, 2 is every other, etc.
     * @param listeners           Additional listeners to use during this operation
     * @return a {@link History} object containing the history information for this training operation
     * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
     */
    public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, MultiDataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners) {
        //Calls the protected fit overload directly; 'true' is its incrementEpochCount argument
        return fit(iter, numEpochs, true, validationIter, validationFrequency, listeners);
    }

    /**
     * See {@link #fit(MultiDataSetIterator, int, MultiDataSetIterator, int, Listener...)}, does not perform validation.
     *

* A special case of {@link #fit()}.
     *
     * @param iter      The iterator to train the SameDiff instance with
     * @param numEpochs The number of epochs for training. Must be > 0
     * @param listeners Additional listeners to use during this operation
     * @return a {@link History} object containing the history information for this training operation
     * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
     */
    public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, @NonNull Listener... listeners) {
        //Configure and execute through the FitConfig returned by fit(); no validation data is set
        return fit().train(iter, numEpochs).listeners(listeners).exec();
    }

    /**
     * Set up for a fit operation using {@link FitConfig}.
     *

* Supports the setting of training data ({@link MultiDataSetIterator} or {@link DataSetIterator}), number of epochs, * validation data ({@link MultiDataSetIterator} or {@link DataSetIterator}), validation frequency, and additional listeners. *

* Example: train on data for 5 epochs, validating on valData every 2nd epoch
     * <pre>
     *     {@code
     *     SameDiff sd = ...;
     *     MultiDataSetIterator data = ...;
     *     MultiDataSetIterator valData = ...;
     *
     *     History hist = sd.fit()
     *         .train(data, 5)
     *         .validate(valData, 2)
     *         .exec();
     *     }
     * </pre>
*/ public FitConfig fit() { return new FitConfig(this); } //Synchronized for thread safety protected synchronized History fit(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull Listener... listeners) { boolean async = iter.asyncSupported(); boolean validationAsync = false; if (validationData != null) validationAsync = validationData.asyncSupported(); if (async) { iter = new AsyncMultiDataSetIterator(iter, 3, true); } if (validationAsync) { validationData = new AsyncMultiDataSetIterator(validationData, 3, true); } try { return fitHelper(iter, numEpochs, incrementEpochCount, validationData, validationFrequency, Arrays.asList(listeners)); } finally { if (async) { ((AsyncMultiDataSetIterator) iter).shutdown(); } if (validationAsync) { ((AsyncMultiDataSetIterator) validationData).shutdown(); } } } //fitHelper should only be called from fit method above protected synchronized History fitHelper(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull List listeners) { Preconditions.checkNotNull(iter, "Iterator must not be null"); Preconditions.checkState(numEpochs > 0, "Number of training epochs must be a positive number. Got: %s", numEpochs); Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " + "be set before training. 
Use setTrainingConfig(TrainingConfig)"); Preconditions.checkState(numEpochs == 1 || iter.resetSupported(), "Cannot train for multiple epochs on an iterator that" + " does not support resetting"); HistoryListener history = new HistoryListener(trainingConfig); List activeListeners = new ArrayList<>(); activeListeners.add(history); for (Listener l : this.listeners) if (l.isActive(Operation.TRAINING)) activeListeners.add(l); for (Listener l : listeners) if (l.isActive(Operation.TRAINING)) activeListeners.add(l); validateListenerActivations(activeListeners, Operation.TRAINING); validateListenerActivations(activeListeners, Operation.TRAINING_VALIDATION); if (!iter.hasNext() && iter.resetSupported()) iter.reset(); boolean performedValidation = false; int trainThreadNum = 0; long jThreadId = Thread.currentThread().getId(); boolean hasListeners = !activeListeners.isEmpty(); At at = At.builder() .epoch(trainingConfig.getEpochCount()) .iteration(trainingConfig.getIterationCount()) .trainingThreadNum(trainThreadNum) .javaThreadNum(jThreadId) .operation(Operation.TRAINING) .build(); LossCurve lossCurve = null; Set requiredVars = new HashSet<>(); for (Listener l : activeListeners) { ListenerVariables lv = l.requiredVariables(this); if(lv != null) { Set s = lv.trainingVariables(); if(s != null) { requiredVars.addAll(s); } } } List listenersWitHistory = new ArrayList<>(listeners); for(Listener l : this.listeners){ if(!listenersWitHistory.contains(l)) listenersWitHistory.add(l); } listenersWitHistory.add(history); SameDiff gradInstance = getFunction("grad"); if(gradInstance == null){ createGradFunction(); gradInstance = getFunction("grad"); } TrainingSession ts = new TrainingSession(gradInstance); gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it for(Listener l : activeListeners){ l.operationStart(gradInstance, Operation.TRAINING); } Set paramsToTrain = new LinkedHashSet<>(); for(Variable v : variables.values()){ 
if(v.getVariable().getVariableType() == VariableType.VARIABLE){ //TODO not all variable type are needed - i.e., variable that doesn't impact loss should be skipped paramsToTrain.add(v.getName()); } } Loss lastLoss = null; for (int i = 0; i < numEpochs; i++) { if (incrementEpochCount && hasListeners) { at.setEpoch(trainingConfig.getEpochCount()); for (Listener l : activeListeners) { l.epochStart(this, at); } } long epochStartTime = System.currentTimeMillis(); double[] lossSums = null; List lossNames = null; int lossCount = 0; while (iter.hasNext()) { long dataStart = hasListeners ? System.currentTimeMillis() : 0; org.nd4j.linalg.dataset.api.MultiDataSet ds = iter.next(); long dataEnd = hasListeners ? System.currentTimeMillis() : 0; if (!performedValidation) { Preconditions.checkState(trainingConfig.getDataSetFeatureMapping().size() == ds.numFeatureArrays(), "The number of dataset feature mapping variables set in the training configuration (%s) must match" + " the number of dataset feature arrays (%s)", trainingConfig.getDataSetFeatureMapping().size(), ds.numFeatureArrays()); List labelMapping = trainingConfig.getDataSetLabelMapping(); int lblSize = labelMapping == null ? 
0 : labelMapping.size(); Preconditions.checkState(lblSize == ds.numLabelsArrays(), "The number of dataset label mapping variables set in the training configuration (%s) must match" + " the number of dataset label arrays (%s)", lblSize, ds.numLabelsArrays()); performedValidation = true; } if (hasListeners) { at.setIteration(trainingConfig.getIterationCount()); for (Listener l : activeListeners) { l.iterationStart(this, at, ds, (dataEnd - dataStart)); } } //Create placeholder variable map Map placeholders = toPlaceholderMap(ds); Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training"); //Call TrainingSession to perform training if (!initializedTraining) initializeTraining(); lastLoss = ts.trainingIteration( trainingConfig, placeholders, paramsToTrain, updaterMap, ds, getLossVariables(), listenersWitHistory, at); if (lossSums == null) { lossSums = lastLoss.getLosses().clone(); } else { for (int j = 0; j < lossSums.length; j++) { lossSums[j] += lastLoss.getLosses()[j]; } } lossCount++; trainingConfig.incrementIterationCount(); } long epochTime = System.currentTimeMillis() - epochStartTime; if (incrementEpochCount) { lossNames = lastLoss.getLossNames(); for (int j = 0; j < lossSums.length; j++) lossSums[j] /= lossCount; if (lossCurve != null) lossCurve = lossCurve.addLossAndCopy(lossSums, lossNames); else lossCurve = new LossCurve(lossSums, lossNames); } if (incrementEpochCount) { if (hasListeners) { boolean doStop = false; Listener stopped = null; for (Listener l : activeListeners) { ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime); if (res == ListenerResponse.STOP && (i < numEpochs - 1)) { doStop = true; stopped = l; } } if (doStop) { log.info("Stopping training early. 
Listener " + stopped + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration()); for (Listener l1 : activeListeners) l1.operationEnd(this, Operation.TRAINING); if (i < numEpochs - 1) { iter.reset(); } if (incrementEpochCount) trainingConfig.incrementEpochCount(); return history.getReport(); } //validation evaluation if (validationData != null && (validationFrequency <= 0 || i % validationFrequency == 0)) { long validationStart = System.currentTimeMillis(); outputHelper(validationData, new At(at.epoch(), 0, 0, 0, null, Operation.TRAINING_VALIDATION), listenersWitHistory); long validationTime = System.currentTimeMillis() - validationStart; boolean doStopV = false; Listener stoppedV = null; for (Listener l : activeListeners) { ListenerResponse res = l.validationDone(this, at, validationTime); if (res == ListenerResponse.STOP && (i < numEpochs - 1)) { doStopV = true; stoppedV = l; } } if (doStopV) { log.info("Stopping training early from validation. Listener " + stoppedV + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration()); for (Listener l1 : activeListeners) l1.operationEnd(this, Operation.TRAINING); if (i < numEpochs - 1) { iter.reset(); } if (incrementEpochCount) trainingConfig.incrementEpochCount(); return history.getReport(); } } } trainingConfig.incrementEpochCount(); } if (i < numEpochs - 1) { iter.reset(); } } for (Listener l1 : activeListeners) l1.operationEnd(this, Operation.TRAINING); return history.getReport(); } /** * Ensure the specified listeners do not request any activations that aren't present for the given operation */ private void validateListenerActivations(List listeners, Operation op) { for (Listener l : listeners) { ListenerVariables lv = l.requiredVariables(this); if(lv != null) { for (String s : lv.requiredVariables(op)) { if (!variables.containsKey(s)) { Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s); } } } } } /** 
* Calculate the regularization (L1, L2 and/or WeightDecay) component of the loss function for the current parameters.. * Note that the training configuration must be set (via {@link #setTrainingConfig(TrainingConfig)}) before this * method can be called * * @return The regularization component of the score/loss function */ public double calcRegularizationScore() { Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " + "be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)"); if (trainingConfig.getRegularization() == null || trainingConfig.getRegularization().isEmpty()) { return 0.0; } List l = trainingConfig.getRegularization(); double loss = 0.0; for (Variable v : variables.values()) { SDVariable sdv = v.getVariable(); if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) { //Only trainable parameters (FP and variable type vars) contribute to regularization score continue; } for (Regularization r : l) { INDArray arr = sdv.getArr(); loss += r.score(arr, trainingConfig.getIterationCount(), trainingConfig.getEpochCount()); } } return loss; } /** * Perform setup for training. Does the following: * 1. Infer the set of trainable parameters - unless specified manually by the user * 2. Set up the updaters */ protected void initializeTraining() { if (!initializedTraining) { if (trainingConfig == null) { throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig"); } updaterMap = new HashMap<>(); for (Variable v : variables.values()) { if (v.getVariable().getVariableType() != VariableType.VARIABLE || !v.getVariable().dataType().isFPType()) { //Skip non-trainable parameters continue; } INDArray arr = v.getVariable().getArr(); long stateSize = trainingConfig.getUpdater().stateSize(arr.length()); INDArray view = stateSize == 0 ? 
null : Nd4j.createUninitialized(arr.dataType(), 1, stateSize); GradientUpdater gu = trainingConfig.getUpdater().instantiate(view, false); gu.setStateViewArray(view, arr.shape(), arr.ordering(), true); updaterMap.put(v.getName(), gu); } initializedTraining = true; } } /** * Convert the MultiDataSet to a {@code Map} based on the TrainingConfig settings. * The key is the placeholder/variable that the value INDArray should be associated with. * * @param ds MultiDataSet - source of the features/labels * @return MultiDataSet converted to a Map, based on TrainingConfig */ private Map toPlaceholderMap(org.nd4j.linalg.dataset.api.MultiDataSet ds) { Map placeholders = new HashMap<>(); int count = 0; for (String s : trainingConfig.getDataSetFeatureMapping()) { placeholders.put(s, ds.getFeatures(count++)); } count = 0; if (trainingConfig.getDataSetLabelMapping() != null) { //Labels may be null in some models (unsupervised etc) for (String s : trainingConfig.getDataSetLabelMapping()) { placeholders.put(s, ds.getLabels(count++)); } } if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().size() > 0) { count = 0; for (String s : trainingConfig.getDataSetFeatureMaskMapping()) { if (s == null) { count++; continue; } placeholders.put(s, ds.getFeaturesMaskArray(count++)); } } if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().size() > 0) { count = 0; for (String s : trainingConfig.getDataSetLabelMaskMapping()) { if (s == null) { count++; continue; } placeholders.put(s, ds.getLabelsMaskArray(count++)); } } return placeholders; } /** * Evaluate the performance of a single variable's prediction.
* For example, if the variable to evaluatate was called "softmax" you would use: *
     * {@code Evaluation e = new Evaluation();
     * sameDiff.evaluate(iterator, "softmax", e);}
     * 
*

* A special case of {@link #evaluate()}. * * @param iterator Iterator as source of data to evaluate * @param outputVariable The variable to evaluate * @param listeners Additional listeners to use during this operation. * @param evaluations The evaluations to perform */ public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull List listeners, @NonNull IEvaluation... evaluations) { Preconditions.checkArgument(evaluations != null && evaluations.length > 0, "No evaluations were passed to the evaluate method"); evaluate().data(iterator).evaluate(outputVariable, evaluations).listeners(listeners.toArray(new Listener[0])).exec(); } /** * See {@link #evaluate(DataSetIterator, String, List, IEvaluation[])}. *

* Convenience overload of {@link #evaluate()} without additional listeners.
     */
    public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull IEvaluation... evaluations) {
        // Configure first, then run; the result record is not used by this overload
        final EvaluationConfig config = evaluate().data(iterator).evaluate(outputVariable, evaluations);
        config.exec();
    }

    /**
     * Evaluation for multiple-output networks.<br>
* See {@link #evaluate(MultiDataSetIterator, Map, Map, Listener[])}. *

* A special case of {@link #evaluate()}. */ public void evaluate(@NonNull DataSetIterator iterator, @NonNull Map variableEvals, @NonNull Listener... listeners) { Map map = new HashMap<>(); Map> variableEvalsList = new HashMap<>(); for (String s : variableEvals.keySet()) { map.put(s, 0); //Only 1 possible output here with DataSetIterator variableEvalsList.put(s, Collections.singletonList(variableEvals.get(s))); } evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvalsList, map, listeners); } /** * Evaluation for multiple output networks - one or more. * See {@link #evaluate(MultiDataSetIterator, Map, Map, Listener[])}. *

* A special case of {@link #evaluate()}. */ public void evaluateMultiple(DataSetIterator iterator, Map> variableEvals, @NonNull Listener... listeners) { Map map = new HashMap<>(); for (String s : variableEvals.keySet()) { map.put(s, 0); //Only 1 possible output here with DataSetIterator } evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvals, map, listeners); } /** * Evaluate the performance of a single variable's prediction.
* For example, if the variable to evaluatate was called "softmax" you would use: *

     * {@code Evaluation e = new Evaluation();
     * sameDiff.evaluate(iterator, "softmax", e);}
     * 
*

* A special case of {@link #evaluate()}. * * @param iterator Iterator as source of data to evaluate * @param outputVariable The variable to evaluate * @param labelIndex The index of the target variable's labels in the iterator * @param listeners Additional listeners to use during this operation. * @param evaluations The evaluations to perform */ public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull List listeners, @NonNull IEvaluation... evaluations) { Preconditions.checkArgument(evaluations != null && evaluations.length > 0, "No evaluations were passed to the evaluate method"); evaluate().data(iterator).evaluate(outputVariable, labelIndex, evaluations).listeners(listeners.toArray(new Listener[0])).exec(); } /** * See {@link #evaluate(MultiDataSetIterator, String, int, List, IEvaluation[])}. *

* Convenience overload of {@link #evaluate()} without additional listeners.
     */
    public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex,
                         @NonNull IEvaluation... evaluations) {
        // Configure first, then run; the result record is not used by this overload
        final EvaluationConfig config = evaluate().data(iterator).evaluate(outputVariable, labelIndex, evaluations);
        config.exec();
    }

    /**
     * Perform evaluation using classes such as {@link Evaluation} for classifier outputs
     * and {@link org.nd4j.evaluation.regression.RegressionEvaluation} for regression outputs.<br>
*
* Example: classifier evaluation
* Predictions variable name: "softmaxOutput"
* Evaluations to perform: {@link Evaluation}
* Data: single input, single output MultiDataSets
* Code:
*

     * {@code
     * MultiDataSetIterator data = ...
     * Map> evals = Collections.singletonMap("softmaxOutput",Collections.singletonList(new Evaluation()));
     * Map labelMapping = Collections.singletonMap("softmaxOutput",0);  //Compare: "softmaxOutput" vs. MultiDataSet.getLabels(0)
     * }
     * 
*

* A special case of {@link #evaluate()}. * * @param iterator The iterator - the source of the data for evaluation * @param variableEvals The evaluations to perform. Key: the name of the variable. Value: the evaluations to perform * @param predictionLabelMapping The output/label mapping. Key: the name of the variable. * @param listeners Additional listeners to use during this operation. */ public void evaluate(MultiDataSetIterator iterator, Map> variableEvals, Map predictionLabelMapping, Listener... listeners) { evaluateHelper(iterator, variableEvals, predictionLabelMapping, At.defaultAt(Operation.EVALUATION), listeners); } /** * Set up for a evaluation operation using EvaluationConfig. *

* Supports the setting of the data ({@link MultiDataSetIterator} or {@link DataSetIterator}), * adding evaluations for variables (with optional label index setting), setting label indices, * and setting additional listeners. * Does not require setting label indices when using a {@link DataSetIterator}. *

* Also supports using {@link SDVariable} instances instead of variable names. * *

* Example: evaluate "pred" with {@link Evaluation} and {@link ROC}, using label 0.
     * <pre>
     *      {@code
     *     SameDiff sd = ...;
     *     MultiDataSetIterator data = ...;
     *
     *     EvaluationRecord results = sd.evaluate()
     *         .data(data)
     *         .evaluate("pred", 0, new Evaluation(), new ROC())
     *         .exec();
     *      }
     *  </pre>
     * Example: evaluate "pred" with {@link Evaluation}, using the only label from a DataSetIterator.
     * <pre>
     *      {@code
     *     SameDiff sd = ...;
     *     DataSetIterator singleData = ...;
     *
     *     EvaluationRecord results = sd.evaluate()
     *         .data(singleData)
     *         .evaluate("pred", new Evaluation())
     *         .exec();
     *      }
     *  </pre>
*/ public EvaluationConfig evaluate() { return new EvaluationConfig(this); } /** * Helper method for evaluations. Should only be called from the above evaluate method */ private void evaluateHelper(MultiDataSetIterator iterator, Map> variableEvals, Map predictionLabelMapping, At at, @NonNull Listener... listeners) { Preconditions.checkState(trainingConfig != null, "Training config has not been set"); Preconditions.checkState(variableEvals.keySet().equals(predictionLabelMapping.keySet()), "Keysets for variable evaluations" + " and for the prediction label mapping must be equal. Keys for variables to evaluate: %s vs. keys for label mapping: %s", variableEvals.keySet(), predictionLabelMapping.keySet()); List activeListeners = new ArrayList<>(); for (Listener l : listeners) if (l.isActive(at.operation())) activeListeners.add(l); for (Listener l : this.listeners) if (l.isActive(at.operation())) activeListeners.add(l); validateListenerActivations(activeListeners, at.operation()); for (Listener l : activeListeners) l.operationStart(this, at.operation()); boolean hasListeners = !activeListeners.isEmpty(); if (!iterator.hasNext() && iterator.resetSupported()) iterator.reset(); Set requiredVars = new HashSet<>(variableEvals.keySet()); if (hasListeners) { for (Listener l : activeListeners) { ListenerVariables v = l.requiredVariables(this); if(v != null) { requiredVars.addAll(v.evaluationVariables()); } } } String[] requiredVarsArr = requiredVars.toArray(new String[0]); while (iterator.hasNext()) { MultiDataSet ds = iterator.next(); Map placeholderMap = toPlaceholderMap(ds); Map m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); for (Map.Entry> e : variableEvals.entrySet()) { INDArray prediction = m.get(e.getKey()); for (IEvaluation eval : e.getValue()) { //TODO time series, etc INDArray label = ds.getLabels(predictionLabelMapping.get(e.getKey())); INDArray mask = 
ds.getLabelsMaskArray(predictionLabelMapping.get(e.getKey())); eval.eval(label, prediction, mask); } } at.setIteration(at.iteration() + 1); } for (Listener l : activeListeners) l.operationEnd(this, at.operation()); } /** * Do a single batch inference on a network with a single input.
* For example, if the variable to infer was called "softmax" you would use: *
     * {@code
     * sameDiff.output(iterator, "softmax");}
     * 
* * @param dataSet The data to evaluate * @param outputs The variables to evaluate */ public Map output(@NonNull DataSet dataSet, @NonNull String... outputs) { return outputBatches(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0); } /** * Do a single batch inference on a network.
* For example, if the variable to infer was called "softmax" you would use: *
     * {@code
     * sameDiff.output(iterator, "softmax");}
     * 
* * @param dataSet The data to evaluate * @param outputs The variables to evaluate */ public Map output(@NonNull MultiDataSet dataSet, @NonNull String... outputs) { return outputBatches(new SingletonMultiDataSetIterator(dataSet), outputs).get(0); } /** * Do inference on a network with a single input.
* For example, if the variable to infer was called "softmax" you would use: *
     * {@code
     * sameDiff.output(iterator, "softmax");}
     * 
*

* Uses concatenation on the outputs of {@link #outputBatches(DataSetIterator, String...)} which may cause issues with some inputs. * RNNs with variable time series length and CNNs with variable image sizes will most likely have issues. *

* Special case of {@link #output()}. * * @param iterator Iterator as source of data to evaluate * @param listeners Additional listeners to use during this operation. * @param outputs The variables to evaluate */ public Map output(@NonNull DataSetIterator iterator, @NonNull List listeners, @NonNull String... outputs) { return output().data(iterator).output(outputs).listeners(listeners.toArray(new Listener[0])).exec(); } /** * See {@link #output(DataSetIterator, List, String...)}. No additional listeners. *

* Special case of {@link #output()}. */ public Map output(@NonNull DataSetIterator dataSet, @NonNull String... outputs) { return output().data(dataSet).output(outputs).exec(); } /** * See {@link #output(DataSetIterator, List, String...)}, but without the concatenation of batches. *

* Special case of {@link #output()}. */ public List> outputBatches(DataSetIterator iterator, List listeners, String... outputs) { return output().data(iterator).output(outputs).listeners(listeners.toArray(new Listener[0])).execBatches(); } /** * See {@link #output(DataSetIterator, String...)}, but without the concatenation of batches. *

* Special case of {@link #output()}. */ public List> outputBatches(DataSetIterator iterator, String... outputs) { return output().data(iterator).output(outputs).execBatches(); } /** * Perform inference.
*
* Example: classifier inference
* Predictions variable name: "softmaxOutput"
* Evaluations to perform: {@link Evaluation}
* Data: single output MultiDataSets
* Code:
*

     * {@code
     * MultiDataSetIterator data = ...
     * sameDiff.output(iterator, "softmaxOutput);
     * }
     * 
*

* Special case of {@link #output()}. * * @param iterator The iterator - the source of the data for inference * @param listeners Additional listeners to use during this operation. * @param outputs The set of outputs to report. If null, defaults to all outputs of this SameDiff. */ public Map output(@NonNull MultiDataSetIterator iterator, @NonNull List listeners, @NonNull String... outputs) { return stackOutputs(outputHelper(iterator, At.defaultAt(Operation.INFERENCE), listeners, outputs)); } /** * See {@link #output(MultiDataSetIterator, List, String...)}. No additional listeners. *

* Special case of {@link #output()}. */ public Map output(@NonNull MultiDataSetIterator dataSet, @NonNull String... outputs) { return output().data(dataSet).output(outputs).exec(); } /** * Perform inference.
*
* Example: classifier inference
* Predictions variable name: "softmaxOutput"
* Evaluations to perform: {@link Evaluation}
* Data: single output MultiDataSets
* Code:
*

     * {@code
     * MultiDataSetIterator data = ...
     * sameDiff.output(iterator, "softmaxOutput);
     * }
     * 
*

* Uses concatenation on the outputs of {@link #outputBatches(MultiDataSetIterator, List, String...)} which may cause issues with some inputs. * RNNs with variable time series length and CNNs with variable image sizes will most likely have issues. *

* Special case of {@link #output()}. * * @param iterator The iterator - the source of the data for inference * @param listeners Additional listeners to use during this operation. * @param outputs The set of outputs to report. If null, defaults to all outputs of this SameDiff. */ public List> outputBatches(MultiDataSetIterator iterator, List listeners, String... outputs) { return outputHelper(iterator, At.defaultAt(Operation.INFERENCE), listeners, outputs); } /** * See {@link #outputBatches(MultiDataSetIterator, List, String...)}. No additional listeners. *

* Special case of {@link #output()}. */ public List> outputBatches(MultiDataSetIterator iterator, String... outputs) { return output().data(iterator).output(outputs).execBatches(); } /** * Set up for an inference operation using OutputConfig. * Supports the setting of variables to output, the input data ({@link MultiDataSetIterator} or {@link DataSetIterator}), * and additional listeners. * Has exec methods to get results in batches or concatenated, or to get results when there is only * a single output (again in batches or concatenated). *

* Also supports using {@link SDVariable} instances instead of variable names. * *

* Example: get the output of pred, with batches concatenated together *

     *     {@code
     *     SameDiff sd = ...;
     *     MultiDataSet data = ...;
     *
     *     INDArray out = sd.output()
     *         .data(data)
     *         .output("pred")
     *         .outputSingle();
     *     }
     * 
*/ public OutputConfig output() { return new OutputConfig(this); } /** * Helper method to run inference. Also used for validation */ private List> outputHelper(MultiDataSetIterator iterator, At at, @NonNull List listeners, @NonNull String... outputs) { Preconditions.checkState(trainingConfig != null, "Training config has not been set"); List activeListeners = new ArrayList<>(); for (Listener l : listeners) if (l.isActive(at.operation())) activeListeners.add(l); for (Listener l : this.listeners) if (l.isActive(at.operation())) activeListeners.add(l); validateListenerActivations(activeListeners, at.operation()); for (Listener l : activeListeners) l.operationStart(this, at.operation()); boolean hasListeners = !activeListeners.isEmpty(); List neededOutputs; if (outputs != null && outputs.length != 0) { neededOutputs = Arrays.asList(outputs); } else { neededOutputs = getLossVariables(); } String[] neededOutputsArr = neededOutputs.toArray(new String[0]); List> predictions = new ArrayList<>(); if (!iterator.hasNext() && iterator.resetSupported()) iterator.reset(); Set requiredVars = new HashSet<>(); for (Listener l : activeListeners) { if (at.operation() == Operation.TRAINING_VALIDATION) requiredVars.addAll(l.requiredVariables(this).validationVariables()); else requiredVars.addAll(l.requiredVariables(this).inferenceVariables()); } while (iterator.hasNext()) { long dataStart = hasListeners ? System.currentTimeMillis() : 0; MultiDataSet ds = iterator.next(); long dataEnd = hasListeners ? 
System.currentTimeMillis() : 0; Map placeholderMap = toPlaceholderMap(ds); if (hasListeners) { for (Listener l : activeListeners) { l.iterationStart(this, at, ds, (dataEnd - dataStart)); } Map outs = directExecHelper(placeholderMap, at, ds, requiredVars, activeListeners, neededOutputsArr); for (Listener l : activeListeners) { l.iterationDone(this, at, ds, null); } predictions.add(outs); } else { predictions.add(directExecHelper(placeholderMap, at, ds, requiredVars, activeListeners, neededOutputsArr)); } at.setIteration(at.iteration() + 1); } for (Listener l : activeListeners) l.operationEnd(this, at.operation()); return predictions; } /** * Set up for a single batch inference operation using OutputConfig. * Supports the setting of placeholder inputs, outputs, and additional listeners. * Has exec methods to get the single output if only one is requested, or all requested outputs. *

* Also supports using {@link SDVariable} instances instead of variable names. *

* Example: get the value of "out" with placeholders x and y *

     *     {@code
     *     SameDiff sd = ...;
     *     INDArray xValue = ...;
     *     INDArray yValue = ...;
     *     SDVariable y = ...;
     *
     *     INDArray outValue = sd.batchOutput()
     *         .output("out")
     *         .input("x", xValue)
     *         .input(y, yValue)
     *         .outputSingle();
     *     }
     * 
*/
    public BatchOutputConfig batchOutput() {
        // Every call hands out a fresh configuration bound to this instance
        final BatchOutputConfig config = new BatchOutputConfig(this);
        return config;
    }

    /**
     * Do inference for all variables for a single batch.
     *

* See {@link #output(Map, List, String...)}. *

* Special case of {@link #batchOutput()}. */ public Map outputAll(Map placeholders) { return batchOutput().outputAll().inputs(placeholders).output(); } /** * Do inference for a single variable for a single batch. *

* See {@link #output(Map, List, String...)}. *

* Special case of {@link #batchOutput()}. */ public INDArray outputSingle(Map placeholders, String output) { return batchOutput().output(output).inputs(placeholders).outputSingle(); } /** * Do inference for the given variables for a single batch. *

* See {@link #output(Map, List, String...)}. *

* Special case of {@link #batchOutput()}. */ public Map output(Map placeholders, @NonNull List outputs) { return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).output(); } /** * Do inference for the given variables for a single batch. *

* See {@link #output(Map, List, String...)}. *

* Special case of {@link #batchOutput()}. */ public Map output(Map placeholders, String... outputs) { return batchOutput().output(outputs).inputs(placeholders).output(); } /** * Do inference for the given variables for a single batch. *

* Special case of {@link #batchOutput()}. * * @param placeholders The values to use for placeholders. * @param listeners Additional listeners to use during this operation. * @param outputs The variables to output and return. */ public Map output(Map placeholders, List listeners, String... outputs) { return batchOutputHelper(placeholders, listeners, Operation.INFERENCE, outputs); } protected Map batchOutputHelper(Map placeholders, List listeners, Operation operation, String... outputs) { List activeListeners = new ArrayList<>(); if(operation == null) operation = Operation.INFERENCE; for (Listener l : this.listeners) if (l.isActive(operation)) activeListeners.add(l); if(listeners != null) { for (Listener l : listeners) if (l.isActive(operation)) activeListeners.add(l); } for (Listener l : activeListeners) { l.operationStart(this, operation); } validateListenerActivations(activeListeners, operation); Map ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.emptyList(), activeListeners, outputs); for (Listener l : activeListeners) { l.operationEnd(this, operation); } return ret; } /** * Do inference for the given variables for a single batch, with training information */ protected Map directExecHelper(Map placeholders, At at, MultiDataSet batch, Collection requiredActivations, List activeListeners, String... outputs) { if (at == null) at = At.defaultAt(); Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified"); long threadId = Thread.currentThread().getId(); if (!sessions.containsKey(threadId)) { log.info("Creating new InferenceSession for thread {}", threadId); sessions.put(threadId, new InferenceSession(this)); } List phNames = inputs(); if (placeholders == null && phNames != null) { //Maybe user set placeholders before calling exec method? 
placeholders = placeholdersPerThread.get(Thread.currentThread().getId()); } //Placeholder validation is performed in InferenceSession InferenceSession is = sessions.get(threadId); return is.output(outputs == null ? Collections.emptyList() : Arrays.asList(outputs), placeholders, batch, requiredActivations, activeListeners, at); } /** * See {@link #one(String, DataType, int...)}. * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable one(String name, int... shape) { return one(name, Nd4j.defaultFloatingPointType(), shape); } /** * See {@link #one(String, DataType, long...)}. * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable one(String name, long... shape) { return one(name, Nd4j.defaultFloatingPointType(), shape); } /** * Create a new variable with the specified shape, with all values initialized to 1.0. * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable one(String name, org.nd4j.linalg.api.buffer.DataType dataType, int... shape) { return one(name, dataType, ArrayUtil.toLongArray(shape)); } /** * Create a new variable with the specified shape, with all values initialized to 1.0. * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable one(String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { return constant(name, Nd4j.ones(dataType, shape)); } /** * See {@link #zero(String, DataType, long...)}. * Creates a constant - i.e., CONSTANT type SDVariable. 
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable zero(String name, long... shape) { return zero(name, Nd4j.defaultFloatingPointType(), shape); } /** * See {@link #zero(String, DataType, int...)}. * Creates a constant - i.e., CONSTANT type SDVariable. * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}). */ public SDVariable zero(String name, int... shape) { return zero(name, Nd4j.defaultFloatingPointType(), shape); } /** * Create a new variable with the specified shape, with all values initialized to 0. * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable zero(String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { return constant(name, Nd4j.zeros(dataType, shape)); } /** * Create a new variable with the specified shape, with all values initialized to 0. * Creates a constant - i.e., CONSTANT type SDVariable. * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable zero(String name, org.nd4j.linalg.api.buffer.DataType dataType, int... shape) { return zero(name, dataType, ArrayUtil.toLongArray(shape)); } /** * Create an SDVariable with a fixed/constant value, with a generated name
* Constants are not modified by training/backprop. See {@link VariableType} for more details.
     *
     * @param constant Value for the constant SDVariable
     * @return The created variable
     */
    public SDVariable constant(@NonNull INDArray constant) {
        final String generatedName = getNewVarName();
        return constant(generatedName, constant);
    }

    /**
     * Create an SDVariable with a fixed/constant value
* Constants are not modified by training/backprop. See {@link VariableType} for more details. * * @param name Name of the constant SDVariable * @param constant Value for the constant SDVariable * @return The created variable */ public SDVariable constant(String name, @NonNull INDArray constant) { Preconditions.checkState(!variables.containsKey(name), "Variable with name \"%s\" already exists", name); if (name == null || name.length() < 1) name = getNewVarName(); if(constant.isView()) { try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()){ constant = constant.dup(); } } SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType()); name = v.name(); variables.put(name, Variable.builder().name(name).variable(v).build()); constantArrays.setArray(name, constant); return v; } /** * Create a a placeholder variable. Placeholders are variables that expect an array to be provided during training * and inference.
* For example, the SDVariables for your input/features and labels should be placeholders.
* See also: {@link VariableType} * * @param name the name of the variable * @param dataType Data type of the new placeholder * @param shape the shape of the variable if any * @return SDVariable placeholder */ public SDVariable placeHolder(@NonNull String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { Preconditions.checkState(!variables.containsKey(name), "Variable already exists with name %s", name); SDVariable ret = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType); variables.put(name, Variable.builder().name(name).variable(ret).build()); return ret; } /** * Variable initialization with a specified {@link WeightInitScheme} * This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details. * * @param name the name of the variable * @param shape the shape of the array to be created * @param weightInitScheme the weight initialization scheme * @return the created variable */ public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull org.nd4j.linalg.api.buffer.DataType dataType, @NonNull long... shape) { return var(name, VariableType.VARIABLE, weightInitScheme, dataType, shape); } /** * Variable initialization with a specified {@link WeightInitScheme} * This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details. * * @param name the name of the variable * @param variableType the SameDiff variable type of the variable (e.g. CONSTANT, PLACEHOLDER, etc.) 
* @param weightInitScheme the weight initialization scheme * @param dataType the data type of the variable (float, int, etc) * @param shape the shape of the array to be created * @return the created variable */ public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { if(shape != null) { for (long l : shape) { Preconditions.checkArgument(l != 0, "Cannot create variable with a shape that contains zeros (empty array shape) - got shape %s", shape); } } if (name == null || name.length() < 1) name = getNewVarName(); else name = generateNewVarName(name, 0); if (variables.containsKey(name)) { if (nameScopes.isEmpty()) { throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \"" + currentNameScope() + "\""); } else { throw new IllegalArgumentException("Another variable with the name " + name + " already exists."); } } Preconditions.checkState(variableType != VariableType.VARIABLE || weightInitScheme != null, "A weight initalization scheme must be provided" + " when creating a VARIABLE type SDVariables - variable name: \"%s\"", name); SDVariable ret = new SDVariable(name, variableType, this, shape, dataType); addVariable(ret); if(variableType == VariableType.VARIABLE){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { INDArray vArr = weightInitScheme.create(dataType, shape); variablesArrays.setArray(name, vArr); } } return ret; } /** * Creates a {@link SDVariable} with the given shape and name
* The underlying array will be initialized using the specified weight initialization scheme
     * This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
     *
     * @param name             the name of the variable
     * @param shape            the shape descriptor of the variable; supplies both the data type and the shape
     * @param weightInitScheme Weight initialization scheme to use to initialize the underlying array
     * @return the created variable
     */
    public SDVariable var(@NonNull String name, @NonNull LongShapeDescriptor shape, WeightInitScheme weightInitScheme) {
        final DataType descriptorType = shape.dataType();
        final long[] descriptorShape = shape.getShape();
        return var(name, weightInitScheme, descriptorType, descriptorShape);
    }

    /**
     * Creates a {@link SDVariable} with the given shape and name
* Any array will be generated with all zeros for the values
     * This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
     * If the shape is a placeholder shape (per {@code Shape.isPlaceholderShape}), a placeholder is created instead.
     *
     * @param name     the name of the variable
     * @param dataType the data type of the variable
     * @param shape    the shape of the variable. May not be null.
     * @return the created variable
     */
    public SDVariable var(String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
        //Check the array itself: checking the boolean "shape != null" would never fail
        Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null");
        if (Shape.isPlaceholderShape(shape)) {
            return placeHolder(name, dataType, shape);
        }
        return var(name, new ZeroInitScheme(), dataType, shape);
    }

    /**
     * Creates a {@link SDVariable} with the given shape and name
* Any array will be generated with all zeros for the values
     * This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
     *
     * @param name      the name of the variable
     * @param shapeDesc the shape of the variable. May not be null.
     * @return the created variable
     */
    public SDVariable var(String name, LongShapeDescriptor shapeDesc) {
        //Check the descriptor itself: checking the boolean "shapeDesc != null" would never fail
        Preconditions.checkNotNull(shapeDesc, "Invalid shape: shape may not be null");
        return var(name, shapeDesc, new ZeroInitScheme());
    }

    /**
     * Creates a {@link SDVariable} with the given shape and name
* Any array will be generated with all zeros for the values. Data type will be given by {@link Nd4j#defaultFloatingPointType()}
     * This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
     *
     * @param name  the name of the variable
     * @param shape the shape of the variable
     * @return the created variable
     */
    public SDVariable var(String name, int... shape) {
        final DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(name, defaultType, shape);
    }

    /**
     * Creates a {@link SDVariable} with the given shape and name
* Any array will be generated with all zeros for the values. Data type will be given by {@link Nd4j#defaultFloatingPointType()}
     * This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
     *
     * @param name  the name of the variable
     * @param shape the shape of the variable
     * @return the created variable
     */
    public SDVariable var(String name, long... shape) {
        final DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(name, defaultType, shape);
    }

    /**
     * Variable initialization with a specified {@link WeightInitScheme}. Data type will be given by {@link Nd4j#defaultFloatingPointType()}
* This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
     *
     * @param name             the name of the variable
     * @param weightInitScheme the weight initialization scheme
     * @param shape            the shape of the array to be created
     * @return the created variable
     */
    public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull long... shape) {
        final DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(name, weightInitScheme, defaultType, shape);
    }

    /**
     * Creates a {@link SDVariable} with the given shape and name
* Any array will be generated with all zeros for the values
     * If the shape is a placeholder shape (per {@code Shape.isPlaceholderShape}), a placeholder is created instead.
     *
     * @param name     the name of the variable
     * @param dataType the data type of the variable
     * @param shape    the shape of the variable. May not be null.
     * @return the created variable
     */
    public SDVariable var(String name, org.nd4j.linalg.api.buffer.DataType dataType, int... shape) {
        Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null");

        final boolean placeholderShape = Shape.isPlaceholderShape(shape);
        final long[] longShape = ArrayUtil.toLongArray(shape);
        return placeholderShape
                ? placeHolder(name, dataType, longShape)
                : var(name, new ZeroInitScheme(), dataType, longShape);
    }

    /**
     * Initialize a {@link SDVariable} reference tying this variable to this samediff instance.
     *

* {@link NDArraySupplierInitScheme} is used to ensure that if the array is allocated anywhere * and {@link SameDiff} instance to exist as a copy of the variable. * * @param v Variable * @return */ public SDVariable var(@NonNull final SDVariable v) { if (variables.containsKey(v.name()) && variables.get(v.name()).getVariable().getArr() != null) return variables.get(v.name()).getVariable(); if (v.name() == null || v.name().length() < 1) throw new IllegalArgumentException("Name for variable must be defined"); VariableType vt = v.getVariableType(); NDArraySupplierInitScheme s = null; switch (vt) { case VARIABLE: SDVariable r = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType()); addVariable(r); try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ variablesArrays.setArray(v.name(), v.getArr().dup()); } return r; case ARRAY: SDVariable ret = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType()); return addVariable(ret); case CONSTANT: return constant(v.name(), v.getArr()); case PLACEHOLDER: return placeHolder(v.name(), v.dataType(), v.placeholderShape()); default: throw new RuntimeException("Unknown/not supported variable type: " + vt); } } private String getNewVarName() { return generateNewVarName("sd_var", 0, false); } /** * Creates a {@link SDVariable} with the specified shape and a generated name
* Any array will be generated with all zeros for the values
     * This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
     *
     * @param dataType the data type of the variable
     * @param shape    the shape of the variable
     * @return the created variable
     */
    public SDVariable var(org.nd4j.linalg.api.buffer.DataType dataType, int... shape) {
        final String generatedName = getNewVarName();
        return var(generatedName, dataType, shape);
    }

    /**
     * Creates a {@link SDVariable} with the specified shape and a generated name
* Any array will be generated with all zeros for the values
     * This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
     *
     * @param dataType the data type of the variable
     * @param shape    the shape of the variable
     * @return the created variable
     */
    public SDVariable var(org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
        final String generatedName = getNewVarName();
        return var(generatedName, dataType, shape);
    }

    /**
     * Creates a {@link SDVariable} with the specified shape and a generated name. The associated array will
     * then be generated using the specified weight initialization scheme
     *
     * @param weightInitScheme The weight initialization scheme to use when generating an INDArray
     * @param dataType         the data type of the variable
     * @param shape            the shape of the variable
     * @return the created variable
     */
    public SDVariable var(WeightInitScheme weightInitScheme, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
        final String generatedName = getNewVarName();
        return var(generatedName, weightInitScheme, dataType, shape);
    }

    /**
     * Create an {@link SDVariable} with a generated name, and associate the specified array with it.
* This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
     *
     * @param arr Array to associate with the new variable
     * @return New SDVariable
     * @see #var(String, INDArray)
     */
    public SDVariable var(INDArray arr) {
        final String generatedName = getNewVarName();
        return var(generatedName, arr);
    }

    /**
     * Create an {@link SDVariable} with the specified name, and associate the specified array with it
* This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details. * * @param arr Array to associate with the new variable * @return New SDVariable with the specified name and array */ public SDVariable var(String name, @NonNull INDArray arr) { if (variables.containsKey(name) && variables.get(name).getVariable().getArr() != null) throw new IllegalArgumentException("Another variable with the name " + name + " already exists."); Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type:" + " provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\n" + "For non floating point types, these should be created as placeholders or constants instead.", arr.dataType()); Preconditions.checkArgument(!arr.isEmpty(), "Empty arrays cannot be used when creating variables. Array shape: %ndShape", arr); if (name == null || name.length() < 1) name = getNewVarName(); boolean duped = false; if (arr.isAttached()) { arr = arr.detach(); duped = true; } if (!duped) { for (String s : variablesArrays.arrayNames()) { if (variablesArrays.getArray(s) == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour) arr = arr.dup(); break; } } } SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType()); associateArrayWithVariable(arr, ret); addVariable(ret); return ret; } /** * Convert the specified variable to a constant. This is equivalent to "freezing" a variable so that it's value * won't be changed by further training.
* This can only be done for variables and placeholders, not ARRAY type variables (which are usually network activations).
     * As a constant, this variable will no longer be modified by any subsequent training.
     * See also: {@link VariableType}
     *
     * @param variable Variable to convert to a constant
     * @return The (now constant) SDVariable - the same instance that was passed in
     */
    public SDVariable convertToConstant(@NonNull SDVariable variable) {
        final List<SDVariable> single = Collections.singletonList(variable);
        convertToConstants(single);
        return variable;
    }

    /**
     * Convert all of the specified variables to constants. This is equivalent to "freezing" the variables so that their values
     * won't be changed by further training.
* This can only be done for variables and placeholders, not ARRAY type variables (which are usually network activations). * As constants, these variables will no longer be modified by any subsequent training.
* See also: {@link VariableType} * * @param variables Variables to convert to constants * @return The (now constant) SDVariables */ public void convertToConstants(List variables) { if (variables.size() == 0) return; boolean allConst = true; for (SDVariable variable : variables) { if (variable.getVariableType() != VariableType.CONSTANT) { allConst = false; Preconditions.checkState(variable.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a constant: %s", variable); } } if (allConst) { return; //No op } //Remove all sessions in case they have any cached arrays/state sessions.clear(); //If gradient function has been defined, remove it (so it will be recreated later) sameDiffFunctionInstances.remove(GRAD_FN_KEY); for (SDVariable variable : variables) { String n = variable.name(); INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); constantArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads variablesArrays.removeArray(n); if (!placeholdersPerThread.isEmpty()) { for (Map m : placeholdersPerThread.values()) { m.remove(n); } } variable.setVariableType(VariableType.CONSTANT); } if (trainingConfig != null && initializedTraining) { //Remove updater state for now constant variables for (SDVariable v : variables) { GradientUpdater gu = updaterMap.remove(v.name()); Map m = gu == null ? null : gu.getState(); if (m != null) { for (INDArray arr : m.values()) { if (arr.closeable()) arr.close(); } } //Also check dataset feature/label mapping - remove any placeholders here... 
if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(v.name())) { List newFM = new ArrayList<>(trainingConfig.getDataSetFeatureMapping()); //New list in case of immutable list newFM.remove(v.name()); trainingConfig.setDataSetFeatureMapping(newFM); } if (trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(v.name())) { List newLM = new ArrayList<>(trainingConfig.getDataSetLabelMapping()); newLM.remove(v.name()); trainingConfig.setDataSetLabelMapping(newLM); } if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(v.name())) { List newFMM = new ArrayList<>(trainingConfig.getDataSetFeatureMaskMapping()); newFMM.remove(v.name()); trainingConfig.setDataSetFeatureMaskMapping(newFMM); } if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(v.name())) { List newLMM = new ArrayList<>(trainingConfig.getDataSetLabelMaskMapping()); newLMM.remove(v.name()); trainingConfig.setDataSetLabelMaskMapping(newLMM); } } } } /** * Convert the specified variable to a VARIABLE type SDVariable.
* This can only be done for constants and placeholders, not ARRAY type variables (which are usually network activations).
     * As a variable, this variable will be modified during any subsequent training.
     * See also: {@link VariableType}
     *
     * @param constant The SDVariable to convert. Must have a floating point data type.
     * @return This variable (now a variable type SDVariable)
     */
    public SDVariable convertToVariable(@NonNull SDVariable constant) {
        final boolean floatingPoint = constant.dataType().isFPType();
        Preconditions.checkState(floatingPoint, "Only floating point SDVariables can be converted to variables," +
                " datatype of %s is %s", constant.name(), constant.dataType());

        final List<SDVariable> single = Collections.singletonList(constant);
        convertToVariables(single);
        return constant;
    }

    /**
     * Convert the specified variables to VARIABLE type SDVariables.
* This can only be done for constants and placeholders, not ARRAY type variables (which are usually network activations). * As variables, this variable will modified during any subsequent training.
* See also: {@link VariableType} */ public void convertToVariables(@NonNull List constants) { if (constants.size() == 0) return; boolean allConst = true; for (SDVariable variable : constants) { if (variable.getVariableType() != VariableType.VARIABLE) { allConst = false; } Preconditions.checkState(variable.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a variable: %s", variable); } if (allConst) { return; //No op } //Remove all sessions in case they have any cached arrays/state sessions.clear(); //If gradient function has been defined, remove it (so it will be recreated later) sameDiffFunctionInstances.remove(GRAD_FN_KEY); for (SDVariable variable : constants) { String n = variable.name(); INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); variablesArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads constantArrays.removeArray(n); if (!placeholdersPerThread.isEmpty()) { for (Map m : placeholdersPerThread.values()) { m.remove(n); } } variable.setVariableType(VariableType.VARIABLE); } //For training: need to add new updater state if (trainingConfig != null && initializedTraining) { //Add updater state for this variable: updaterState, updaterViews, updaterMap for (SDVariable v : constants) { if (!updaterMap.containsKey(v.name())) { //Create new updater state INDArray arr = v.getArr(); long thisSize = trainingConfig.getUpdater().stateSize(arr.length()); if (thisSize > 0) { INDArray stateArr = Nd4j.create(arr.dataType(), 1, thisSize); GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, false); u.setStateViewArray(stateArr, arr.shape(), arr.ordering(), true); //TODO eventually this should be 1 call... 
updaterMap.put(v.name(), u); } else { GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray) null, true); updaterMap.put(v.name(), u); } } } } } /** * Convert the datatypes of the specified constants, placeholders and variables.
* After conversion, the downstream datatypes are changed. * For example, {@code z(float) = x(float)+y(float)}, changing both x and y to double results in {@code z(double) = x(double)+y(double)} * without doing anything to change z's datatype directly (z datatype is inferred from x + y + add op).
* ARRAY type SDVariables cannot be converted directly, as their datatypes are determined by the function + * input datatypes. * Note that this method should be used with caution: incorrect datatype modifications may leave your network * in an incorrect state. For example, {@code op(x(float),y(float)) -> op(x(double),y(float))} may not be * supported by all ops. * * @param dataTypeMap Map of SDVariables to change the datatype for. Key = SDVariable name, Value = new datatype */ public void convertDataTypes(@NonNull Map dataTypeMap) { if (dataTypeMap.isEmpty()) return; //First: check these are all either constants, variables or placeholders. for (Map.Entry e : dataTypeMap.entrySet()) { String s = e.getKey(); Preconditions.checkState(variables.containsKey(s), "Cannot change datatype of variable \"%s\": No variable with this name exists", s); SDVariable v = variables.get(s).getVariable(); Preconditions.checkState(v.getVariableType() != VariableType.ARRAY, "Cannot change datatype of ARRAY type variable \"%s\": " + "datatype of ARRAY type variables is determined by the datatypes of their inputs plus corresponding "); if (v.getVariableType() != VariableType.PLACEHOLDER) { //Can't convert constant or variable between numerical and non-numerical type (not possible to cast) Preconditions.checkState(v.dataType().isNumerical() == e.getValue().isNumerical(), "Cannot convert variables between numerical " + "and non-numerical types: attempting to convert variable \"%s\" from %s to %s", e.getKey(), v.dataType(), e.getValue()); } } boolean anyChanged = false; for (Map.Entry e : dataTypeMap.entrySet()) { String s = e.getKey(); DataType d = e.getValue(); SDVariable v = variables.get(s).getVariable(); if (v.dataType() == d) continue; //No-op v.setDataType(d); switch (v.getVariableType()) { case VARIABLE: INDArray arr = variablesArrays.removeArray(e.getKey()); INDArray newArr = arr.castTo(d); variablesArrays.setArray(e.getKey(), newArr); //DeviceLocal with delayed initialization, in 
case we don't actually need multiple threads break; case CONSTANT: INDArray arr2 = constantArrays.removeArray(e.getKey()); INDArray newArr2 = arr2.castTo(d); constantArrays.setArray(e.getKey(), newArr2); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads break; case PLACEHOLDER: Map m = placeholdersPerThread.get(Thread.currentThread().getId()); if (m != null && m.containsKey(e.getKey())) { m.put(e.getKey(), m.get(e.getKey()).castTo(d)); } break; case ARRAY: default: throw new IllegalStateException("Cannot convert array type variable"); //Should never happen } anyChanged = true; } if (anyChanged) { sessions.clear(); //Recalculate datatypes of outputs, and dynamically update them Set allSeenOps = new HashSet<>(); Queue queueOps = new LinkedList<>(); for(String s : dataTypeMap.keySet()){ Variable v = variables.get(s); v.getVariable().setDataType(dataTypeMap.get(s)); List inToOp = v.getInputsForOp(); if(inToOp != null){ for(String op : inToOp) { if (!allSeenOps.contains(op)) { allSeenOps.add(op); queueOps.add(op); } } } } while(!queueOps.isEmpty()){ String op = queueOps.remove(); SameDiffOp o = ops.get(op); List inVars = o.getInputsToOp(); List inDTypes = new ArrayList<>(); if(inVars != null) { for (String s : inVars) { SDVariable v = variables.get(s).getVariable(); inDTypes.add(v.dataType()); } } List outDtypes = o.getOp().calculateOutputDataTypes(inDTypes); List outVars = o.getOutputsOfOp(); for( int i=0; i newInputs = new ArrayList<>(op.getInputsToOp()); while (newInputs.contains(from)) { newInputs.set(newInputs.indexOf(from), to); } op.setInputsToOp(newInputs); } } if (v.getControlDepsForOp() != null) { for (String opName : v.getControlDepsForOp()) { SameDiffOp op = ops.get(opName); List newCDs = new ArrayList<>(op.getControlDeps()); while (newCDs.contains(from)) { newCDs.set(newCDs.indexOf(from), to); } op.setControlDeps(newCDs); } } if (v.getControlDepsForVar() != null) { for (String varName : v.getControlDepsForVar()) { 
Variable var = variables.get(varName); List newCDs = new ArrayList<>(var.getControlDeps()); while (newCDs.contains(from)) { newCDs.set(newCDs.indexOf(from), to); } var.setControlDeps(newCDs); } } if (v.getControlDeps() != null) { for (String varName : v.getControlDeps()) { Variable var = variables.get(varName); List newCDsFor = new ArrayList<>(var.getControlDepsForVar()); while (newCDsFor.contains(from)) { newCDsFor.set(newCDsFor.indexOf(from), to); } var.setControlDepsForVar(newCDsFor); } } if (v.getOutputOfOp() != null) { SameDiffOp op = ops.get(v.getOutputOfOp()); List newOuts = new ArrayList<>(op.getOutputsOfOp()); while (newOuts.contains(from)) { newOuts.set(newOuts.indexOf(from), to); } op.setOutputsOfOp(newOuts); } variables.remove(from); variables.put(to, v); if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.hasArray(from)) { constantArrays.rename(from, to); } if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.hasArray(from)) { variablesArrays.rename(from, to); } if(v.getVariable().getVariableType() == VariableType.PLACEHOLDER) { for(Map e : placeholdersPerThread.values()){ //Not really thread safe - but renaming variables during execution in other threads can never be thread safe :) if(e != null && e.containsKey(from)){ INDArray arr = e.remove(from); e.put(to, arr); } } } if (trainingConfig != null) { if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(from)) { List l = new ArrayList<>(trainingConfig.getDataSetFeatureMapping()); while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setDataSetFeatureMapping(l); } if (trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(from)) { List l = new ArrayList<>(trainingConfig.getDataSetLabelMapping()); while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setDataSetLabelMapping(l); } if 
(trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(from)) { List l = new ArrayList<>(trainingConfig.getDataSetFeatureMaskMapping()); while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setDataSetFeatureMaskMapping(l); } if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(from)) { List l = new ArrayList<>(trainingConfig.getDataSetLabelMaskMapping()); while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setDataSetLabelMaskMapping(l); } if (trainingConfig.getLossVariables() != null && trainingConfig.getLossVariables().contains(from)) { List l = new ArrayList<>(trainingConfig.getLossVariables()); while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setLossVariables(l); } } for (SameDiff sd : sameDiffFunctionInstances.values()) { if (sd.hasVariable(from)) { sd.renameVariable(from, to); } } //Check losses: if(lossVariables.contains(from)) { int idx = lossVariables.indexOf(from); lossVariables.set(idx, to); } } /** * Rename the specified variable to the new name. * * @param from The variable to rename - this variable must exist * @param to The new name for the variable - no variable with this name must already exist */ public void renameVariable(String from, String to) { SameDiffOp op = ops.get(stripVarSuffix(from)); renameVariable(op,from,to); } /** * Remove an argument for a function. Note that if this function does not contain the argument, it will just be a no op. * * @param varName the variable name to remove * @param function the function to remove the argument from */ public void removeArgFromOp(String varName, DifferentialFunction function) { val args = function.args(); for (int i = 0; i < args.length; i++) { if (args[i].name().equals(varName)) { /** * Since we are removing the variable reference * from the arguments we need to update both * the reverse and forward arguments. 
*/ List reverseArgs = ops.get(function.getOwnName()).getInputsToOp(); val newArgs = new ArrayList(args.length - 1); for (int arg = 0; arg < args.length; arg++) { if (!reverseArgs.get(arg).equals(varName)) { newArgs.add(reverseArgs.get(arg)); } } ops.get(function.getOwnName()).setInputsToOp(newArgs); break; } } variables.get(varName).getInputsForOp().remove(function.getOwnName()); } /** * Get the variable based on the opName * * @param name the opName of the variable * @return the variabel instance if there is one */ public SDVariable getVariable(String name) { Variable v = variables.get(name); return v == null ? null : v.getVariable(); } public boolean hasVariable(String name) { return variables.containsKey(name); } /** * Get the gradient for the variable with the specified name.
* The gradient variable is the variable that represents the derivative of the loss function with respect * to the output of this variable. I.e., if this variable is X and loss function is L, then gradient() returns the * variable representing dL/dX
* Note that only floating point variables can have gradients.
* Note also that a gradient may not yet be defined, and/or if no loss function variables have been set.
* You can set the loss function variables using {@link SameDiff#setLossVariables(String...)} and then create the * gradient functions using {@link SameDiff#createGradFunction()}. Alternatively, the gradient function will be * created automatically when training is performed. * * @param varName the vertex id * @return the gradient for this variable or null */ public SDVariable getGradForVariable(String varName) { Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName); SDVariable v = getVariable(varName); Preconditions.checkState(v.dataType().isFPType(), "Cannot get gradient of %s variable \"%s\": only floating" + " point variables have gradients", varName, v.dataType()); //Gradients are being placed in the inner "grad" function SameDiff instance, but not the outer one if (variables.containsKey(varName) && variables.get(varName).getGradient() != null) { return variables.get(varName).getGradient(); } else if (sameDiffFunctionInstances.containsKey(GRAD_FN_KEY) && sameDiffFunctionInstances.get(GRAD_FN_KEY).variables.containsKey(varName)) { return sameDiffFunctionInstances.get(GRAD_FN_KEY).variables.get(varName).getGradient(); } return null; } /** * Determine if the specified variable has a gradient with respect to the current loss. Note that: * (a) Non-floating-point variables (integer, string, etc) will never have gradients
* (b) This method will return false if no gradient function has been created yet. See {@link SameDiff#createGradFunction()} * and {@link SameDiff#setLossVariables(String...)}
* (c) Floating point variables may not have any gradient if the specified loss variables does not depend on the * specified variable at all. In this case, "no gradient" for floating point is equivalent to "always 0"
* * @param varName Name of the variable to check the existence of a gradient variable for * @return True if a gradient variable exists for the specified variable, for the current loss */ public boolean variableHasGradient(String varName) { Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName); SDVariable v = getVariable(varName); if (!v.dataType().isFPType() || v.isConstant()) return false; return getGradForVariable(varName) != null; } /** * Assign a SDVariable to represent the gradient of the SDVariable with the specified name * * @param variableName the variable name to assign the gradient variable for * @param variable the gradient variable */ public void setGradientForVariableName(String variableName, SDVariable variable) { Preconditions.checkState(variables.containsKey(variableName), "No variable exists with name \"%s\"", variableName); if (variable == null) { throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + variableName); } variables.get(variableName).setGradient(variable); } /** * Get the gradient for the variable with the specified variable name. * Note that in order to run this function, {@link #execBackwards(Map, Operation, MultiDataSet, Collection, List)} must be executed first. * All gradient functions are obtained from the results of the execBackwards call. * * @param varName the variable name to get the gradient variable for. 
* @return The gradient variable for the specified variable */ public SDVariable grad(String varName) { if (!sameDiffFunctionInstances.containsKey(GRAD_FN_KEY)) { createGradFunction(); } SameDiff grad = getFunction(GRAD_FN_KEY); SDVariable var = grad.getVariable(varName); return getFunction(GRAD_FN_KEY).getGradForVariable(var.name()); } /** * Create a new double scalar (rank 0) SDVariable with the specified value * * @param name Name of the SDVariable * @param value Value to initialize the variable with * @return SDVariable */ public SDVariable scalar(String name, double value) { try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return var(name, Nd4j.scalar(value)); } } /** * Create a new float scalar (rank 0) SDVariable with the specified value * * @param name Name of the SDVariable * @param value Value to initialize the variable with * @return SDVariable */ public SDVariable scalar(String name, float value) { try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return var(name, Nd4j.scalar(value)); } } /** * Create a new integer scalar (rank 0) SDVariable with the specified value * * @param name Name of the SDVariable * @param value Value to initialize the variable with * @return SDVariable */ public SDVariable scalar(String name, int value) { try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return var(name, Nd4j.scalar(value)); } } /** * Create a new long scalar (rank 0) SDVariable with the specified value * * @param name Name of the SDVariable * @param value Value to initialize the variable with * @return SDVariable */ public SDVariable scalar(String name, long value) { try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return var(name, Nd4j.scalar(value)); } } /** * Create a new scalar (rank 0) SDVariable with the specified value and datatype * * @param name Name of the SDVariable * @param dataType Data type of the scalar * @param value Value to initialize the 
variable with
     * @return SDVariable
     */
    public SDVariable scalar(String name, DataType dataType, Number value) {
        //Build the backing array while scoped out of any active workspace, then wrap it as a variable
        try (MemoryWorkspace outOfWorkspaces = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            final INDArray scalarArr = Nd4j.scalar(dataType, value);
            return var(name, scalarArr);
        }
    }

    /**
     * Create an unnamed double scalar constant (rank 0) holding the given value.
     * Constants are not modified by training/backprop. See {@link VariableType} for more details.
     *
     * @param value Value to initialize the constant with
     * @return SDVariable
     */
    public SDVariable constant(double value) {
        final String unnamed = null;
        return constant(unnamed, value);
    }

    /**
     * Create a double scalar constant (rank 0) with the given name and value.
     *
     * @param name  Name of the SDVariable
     * @param value Value to initialize the constant with
     * @return SDVariable
     */
    public SDVariable constant(String name, double value) {
        try (MemoryWorkspace outOfWorkspaces = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            final INDArray scalarArr = Nd4j.scalar(value);
            return constant(name, scalarArr);
        }
    }

    /**
     * Create a new float scalar constant (rank 0) with the specified value
* Constants are not modified by training/backprop. See {@link VariableType} for more details. * * @param value Value to initialize the constant with * @return SDVariable */ public SDVariable constant(float value) { return constant(null, value); } /** * Create a new float scalar constant (rank 0) with the specified value * * @param name Name of the SDVariable * @param value Value to initialize the constant with * @return SDVariable */ public SDVariable constant(String name, float value) { try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return constant(name, Nd4j.scalar(value)); } } /** * Create a new integer scalar constant (rank 0) with the specified value * * @param value Value to initialize the constant with */ public SDVariable constant(int value) { return constant(null, value); } /** * Create a new integer scalar constant (rank 0) with the specified value * * @param name Name of the SDVariable * @param value Value to initialize the constant with * @return SDVariable */ public SDVariable constant(String name, int value) { try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return constant(name, Nd4j.scalar(value)); } } /** * Create a new long scalar constant (rank 0) with the specified value * * @param value Value to initialize the constant with */ public SDVariable constant(long value) { return constant(null, value); } /** * Create a new long scalar constant (rank 0) with the specified value * * @param name Name of the SDVariable * @param value Value to initialize the constant with */ public SDVariable constant(String name, long value) { try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return constant(name, Nd4j.scalar(value)); } } /** * Create a new scalar constant (rank 0) with the specified value and datatype * * @param name Name of the SDVariable * @param dataType Data type of the scalar constant * @param value Value to initialize the constant with */ public SDVariable 
constant(String name, DataType dataType, Number value) { try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return constant(name, Nd4j.scalar(dataType, value)); } } /** * Add the specified variable to this SameDiff instance * * @param variable Variable to add */ public SDVariable addVariable(SDVariable variable) { Preconditions.checkState(variable.getSameDiff() == this, "Samediff instance must be the same."); if (variables.containsKey(variable.name()) && !variables.get(variable.name()).getVariable().equals(variable)) { throw new IllegalArgumentException("Variable with name \"" + variable.name() + "\" already exists"); } Preconditions.checkState(variable.getSameDiff() == this, "Same diff instance for variable must be the same!"); variables.put(variable.name(), Variable.builder().name(variable.name()).variable(variable).build()); return variable; } /** * Generate the variables based on the given input op and return the output variable names. * * @param function the function to generate the output * variable names for * @return the set of names generated for each output of the function. */ public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport) { if (baseName == null) baseName = function.getOwnName(); if (baseName == null) baseName = function.opName(); //First: calculate output data types. 
We can always calculate output data types, even if the input arrays //are not available - *except for sometimes during import, until all ops/variables have been added* List outputDataTypes = null; if (!isImport) { List inputDataTypes = new ArrayList<>(); List fnInputs = ops.get(function.getOwnName()).getInputsToOp(); if (fnInputs != null) { for (String var : fnInputs) { inputDataTypes.add(variables.get(var).getVariable().dataType()); } } outputDataTypes = function.calculateOutputDataTypes(inputDataTypes); } //Determine number of output variables if (function instanceof CustomOp) { CustomOp customOp = (CustomOp) function; int num_outputs = function.getNumOutputs(); //Use this in preference - if set. Descriptor might specify 2, but it can sometimes be 2+ if (num_outputs <= 0) { val descriptor = customOp.getDescriptor(); if (descriptor != null) { num_outputs = descriptor.getNumOutputs(); } if (num_outputs <= 0) { throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op " + function.getOwnName() + " - " + function.getClass().getSimpleName() + ". Ops can override" + " getNumOutputs() to specify number of outputs if required"); } } SDVariable[] ret = new SDVariable[num_outputs]; //Infer the output types: we can always determine datatype but not always shapes if(isImport || (outputDataTypes != null && outputDataTypes.size() == num_outputs)) log.trace( "Incorrect number of output datatypes: got %s but expected datatypes for %s outputs - %s (op: %s), could be due to variable input types.", (outputDataTypes == null ? null : outputDataTypes.size()), num_outputs, outputDataTypes, function.getClass().getSimpleName()); //dynamic shapes //When importing from TF: convention is "unstack", "unstack:1", "unstack:2", ... for (int i = 0; i < ret.length; i++) { SDVariable var = (i == 0 ? 
getVariable(baseName) : getVariable(baseName + ":" + i)); if (var == null) { //Generate new variable name if one with the specified name doesn't exist //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme org.nd4j.linalg.api.buffer.DataType dataType = isImport ? null : outputDataTypes.get(i); var = var(generateNewVarName(baseName, i), VariableType.ARRAY, null, dataType, (long[]) null); } var.setCreator(function); ret[i] = var; } //Update the internal state: outgoing variables for function if (getOutputsForOp(function) == null) addOutgoingFor(ret, function); return ret; } //this is for unresolved shapes, we know xyz is always 1 output else if (function instanceof BaseOp) { SDVariable[] ret = new SDVariable[1]; SDVariable checkGet = getVariable(baseName); SDVariable[] args = function.args(); if (checkGet == null) { //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[]) null); } if (checkGet == null) { //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[]) null); } checkGet.setCreator(function); ret[0] = checkGet; //Update the internal state: outgoing variables for function if (getOutputsForOp(function) == null) addOutgoingFor(ret, function); return ret; } else { throw new RuntimeException("Unknown op type: " + function.getClass()); } } /** * Generate the variables based on the given input op * and return the output variable names. * * @param function the function to generate the output * variable names for * @return the set of names generated for each output of the function. 
*/ public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) { return generateOutputVariableForOp(function, function.getOwnName() != null ? function.getOwnName() : function.opName(), false); } /** * Get a SameDiff function instance given the name of the function * * @param functionName the name of the function * @return the same diff function instance defined for the given name */ public SameDiff getFunction(String functionName) { return sameDiffFunctionInstances.get(functionName); } /** * Create a new TensorArray. */ public TensorArray tensorArray(DataType dataType) { TensorArray ta = new TensorArray(this, dataType); SDVariable[] outVars = ta.outputVariables(); return ta; } /** * @param functionName * @param with */ public SDVariable invokeFunctionOn(String functionName, SameDiff with) { SameDiff instance = sameDiffFunctionInstances.get(functionName); SDVariable ret = instance.invokeGraphOn(with); return ret; } /** * @param function */ public SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables) { if (!sameDiffFunctionInstances.containsKey(function)) { SameDiff sub = SameDiff.create(); this.child = sub; sub.parent = this; //setup subgraph //re execute to populate subgraph SDVariable[] ret = new SDVariable[variables.length]; for (int i = 0; i < ret.length; i++) { ret[i] = sub.var(variables[i]); } functionDefinition.define(sub, null, ret); sameDiffFunctionInstances.put(function, sub); } this.child = null; return sameDiffFunctionInstances.get(function); } /** * @param function */ public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition) { defineFunction(function, functionDefinition, new LinkedHashMap()); } /** * @param function * @param functionDefinition * @param inputs */ public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition, Map inputs) { if (!sameDiffFunctionInstances.containsKey(function)) { SameDiff sub = 
SameDiff.create(); //setup subgraph //re execute to populate subgraph functionDefinition.define(sub, inputs, null); sameDiffFunctionInstances.put(function, sub); } } /** * See {@link #calculateGradients(Map, Collection)} */ public Map calculateGradients(Map placeholderVals, @NonNull String... variables) { Preconditions.checkArgument(variables.length > 0, "No variables were specified"); return calculateGradients(placeholderVals, Arrays.asList(variables)); } /** * Calculate and return the gradients for the specified variables * * @param placeholderVals Placeholders. May be null * @param variables Names of the variables that you want the gradient arrays for * @return Gradients as a map, keyed by the variable name */ public Map calculateGradients(Map placeholderVals, @NonNull Collection variables) { Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified"); OutAndGrad oag = calculateGradientsAndOutputs(placeholderVals, null, variables); return oag.getGradients(); } /** * Calculate the activations and the gradients for the specified variables, in one execution call. * This is equivalent to calling {@link #output(Map, List)} and {@link #calculateGradients(Map, Collection)}, but * is more efficient than calling both separately. * * @param placeholderVals Placeholders. May be null * @param outputVars Names of the variables that you want the activations/outputs for. May be null * @param gradientVars Names of the variables that you want the gradient arrays for. 
May be null * @return Activations and gradients, keyed by variable name */ public OutAndGrad calculateGradientsAndOutputs(Map placeholderVals, Collection outputVars, Collection gradientVars){ Preconditions.checkArgument((outputVars != null && !outputVars.isEmpty()) || (gradientVars != null && !gradientVars.isEmpty()), "No variables were specified for either output or gradients"); if (getFunction(GRAD_FN_KEY) == null) { createGradFunction(); } List varNames = new ArrayList<>(); if(outputVars != null){ varNames.addAll(outputVars); } if(gradientVars != null) { for (String s : gradientVars) { Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s); SDVariable v = getVariable(s).getGradient(); if (v != null) { //In a few cases (like loss not depending on trainable parameters) we won't have gradient array for parameter variable varNames.add(v.name()); } } } //Key is gradient variable name SameDiff gradFn = getFunction(GRAD_FN_KEY); gradFn.setListeners(listeners); Map grads = gradFn.batchOutputHelper(placeholderVals, null, Operation.TRAINING, varNames.toArray(new String[0])); Map outOutputs = outputVars == null ? null : new HashMap(); Map outGrads = gradientVars == null ? 
null : new HashMap(); if(outputVars != null){ for(String s : outputVars){ outOutputs.put(s, grads.get(s)); } } if(gradientVars != null) { for (String s : gradientVars) { if (getVariable(s).getGradient() != null) { String gradVar = getVariable(s).getGradient().name(); outGrads.put(s, grads.get(gradVar)); } } } return new OutAndGrad(outOutputs, outGrads); } /** * Returns true if the gradient function has been created - i.e., {@link #createGradFunction()} or {@link #createGradFunction(String...)} * has been called at all * * @return True if gradient (backprop) function exists */ public boolean hasGradientFunction() { return sameDiffFunctionInstances.containsKey(GRAD_FN_KEY); } /** * Create the gradient function (for calculating gradients via {@link #calculateGradients(Map, Collection)}) if it is not already defined. * Users do not usually need to call this function manually, as it is called as required in the aforementioned method. *

* If the gradient function already exists, this method is a no-op.
     * After this method returns, the SameDiff function instance for the gradient can be accessed using {@link #getFunction(String)}
     * with name "grad" as the argument.
     * Note that the gradient array (after execBackwards has been called) can be accessed via {@code SDVariable.gradient().getArr()}
     */
    public void createGradFunction() {
        //The cast passes a null array to the varargs overload (rather than a one-element array holding null),
        // i.e. "no explicitly requested gradient variables"
        createGradFunction((String[]) null);
    }

    /**
     * As per {@link #createGradFunction()}, but this method allows a set of variables requiring gradients to be specified.
     * By default, only parameter gradients will be calculated; placeholder gradients may not be defined (unless they happen
     * to be calculated in the same op as calculating a parameter gradient).
     * This method allows you to override this behaviour by passing the name of the placeholder you want the gradients for.
     * The specified gradient variables still need to be floating point variables.
     *
     * @param variablesRequiringGradients May be null. If non-null: the gradients for the variables with these names will
     *                                    be calculated and available after backprop has been done
     */
    public void createGradFunction(final String... variablesRequiringGradients) {
        //No loss variables set on this instance: fall back to the training config, then to inferring them
        if (lossVariables.isEmpty()) {
            if (trainingConfig != null && trainingConfig.getLossVariables() != null && !trainingConfig.getLossVariables().isEmpty()) {
                lossVariables.addAll(trainingConfig.getLossVariables());
            } else {
                List lossInferred = bestGuessLossVariables();
                if (lossInferred.size() == 1) {
                    String outName = lossInferred.get(0);
                    String opName = variables.get(outName).getOutputOfOp();
                    //Only log the inference when the inferred loss is not the output of an ExternalErrorsFunction op
                    if (opName == null || !(ops.get(opName).getOp() instanceof ExternalErrorsFunction)) {
                        log.info("Inferring output \"{}\" as loss variable as none were previously set."
+ "Use SameDiff.setLossVariables() or SDVariable.markAsLoss() to override", lossInferred.get(0)); } lossVariables.add(lossInferred.get(0)); } else if(lossInferred.isEmpty()){ //Check for external errors function for(SameDiffOp o : ops.values()){ if(o.getOp() instanceof ExternalErrorsFunction){ List l = o.getOutputsOfOp(); lossVariables.add(l.get(0)); } } } } } Preconditions.checkState(!lossVariables.isEmpty(), "Cannot create gradient function: " + "No loss variables (variables to minimize) have been specified. Loss variables are the variables that" + " represent the loss/cost/score to be minimized during training, and that all gradients are calculated with respect to.\n" + " Losses can be specified either in TrainingConfiguration (Builder.minimize(...)) or via SameDiff.setLossVariables()/addLossVariable()"); if (log.isTraceEnabled()) { log.trace("Defining function \"grad\""); } if (variablesRequiringGradients != null && variablesRequiringGradients.length > 0) { //Check that they are FP variables... for (String s : variablesRequiringGradients) { Preconditions.checkArgument(variables.containsKey(s), "Cannot ensure gradient exists for variable: no variable with name \"%s\" exists", s); DataType dt = variables.get(s).getVariable().dataType(); Preconditions.checkState(dt.isFPType(), "Cannot ensure gradient exists for variable \"%s\": variable is not a floating point SDVariable." + " Only floating point SDVariables have gradients defined - variable has type %s", s, dt); } } /* Defining gradient function: Starting point: (a) Set of loss function variables - i.e., one or more variables representing loss to be minimized (b) Set of floating point variables we want to train (Variable type SDVariables only - not constants, arrays, placeholders) Observation: A trainable parameter only has a gradient defined if there is a floating point path between the variable and the loss. 
for example: X(fp) -> cast(int) -> cast(fp) -> loss - X has no gradient defined Algorithm for backprop: Step 1: Determine if variable requires a gradient (is trainable) How? Walk backward on op graph starting at loss variable(s), along FP variables only. Collect FP variables in set as we go. This gives us a subgraph "connected to loss by FP path" - gradient for FP variable is defined only if it's in that set/subgraph. Step 2: Determine minimal set of variables (including array type SDVariables - i.e., activations) we need gradients for Consider following graph: X(fp) -> cast(int) -> cast(fp) -> lots of FP ops -> loss unless we need them for other variables, there's zero point calculating the activation gradients for the "cast(fp) -> lots of FP ops" part of the graph, as the gradient from that branch won't go anywhere. How to determine minimal subset? Start with FP graph from step 1... then keep pruning leaves until the only remaining leaves are those FP variables that we need gradients for. Note that the user can also specify variables that they need gradients for (like placeholders) that normally wouldn't get gradients. Step 3: Differentiate ops in minimal subgraph The only major issue here is with multiple output ops, where only one of the outputs lead to the loss. For example, X -> slice -> (A,B); B -> loss, with A being unused (in that it doesn't contribute to the loss function) But to do split op backprop, we need gradient variables/arrays for both outputs (A and B). We know the shape and type of dL/dA must be exactly the same as A; we also know that the loss function doesn't depend on A. 
Hence, dL/dA == zerosLike(A) */ final SameDiff outer = this; defineFunction(GRAD_FN_KEY, new SameDiffFunctionDefinition() { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); //Training isn't thread safe, no need to use DeviceLocal, even with lazy init //Propagate graph to this samediff instance which will also contain the backward if (SameDiff.this.debugMode) { sameDiff.enableDebugMode(); } outer.invokeGraphOn(sameDiff); if (debugMode) { //Expect incoming args and outgoing args to be the same Preconditions.checkState(sameDiff.ops.keySet().equals(ops.keySet()), "ops keysets not equal"); } List allFunctions = new ArrayList<>(sameDiff.ops.values()); if (allFunctions.isEmpty()) { throw new ND4JIllegalStateException("No ops found!"); } for (SameDiffOp op : allFunctions) { DifferentialFunction func = op.getOp(); val args = func.args(); for (val arg : args) arg.setSameDiff(sameDiff); val outputs = func.outputVariables(); for (val output : outputs) output.setSameDiff(sameDiff); func.setSameDiff(sameDiff); } List finalOutputs = new ArrayList<>(lossVariables.size()); SDVariable initialGrad = sameDiff.var("one-var", Nd4j.scalar(1.0f)); for (String s : lossVariables) { Preconditions.checkNotNull(s, "Encountered null value in loss variables. Null loss variables are not allowed." + " Use SameDiff.setLossVariables with non-null array names to fix"); Preconditions.checkState(variables.containsKey(s), "Specified loss function variable \"%s\" does not exist", s); SDVariable v = variables.get(s).getVariable(); Preconditions.checkState(v.dataType().isFPType(), "Specified loss function variable \"%s\" is not a floating" + "point variable (datatype: %s). 
Only floating point variables may be used as loss function variable", s, v.dataType()); v = v.sum(); //If output is not a scalar: we'll use loss = v.sum(), same as adding loss for multiple outputs. We don't always know for sure if output is scalar at this point if (v.dataType() == initialGrad.dataType()) { sameDiff.setGradientForVariableName(v.name(), initialGrad); } else { sameDiff.setGradientForVariableName(v.name(), initialGrad.castTo(v.dataType())); } if (finalOutputs.contains(v)) { log.warn("Loss function variable \"{}\" appears multiple times in list of loss variables - using only first instance", s); } else { finalOutputs.add(v); } } if (log.isTraceEnabled()) { String[] initialOutputsStr = allFunctions.get(allFunctions.size() - 1).getOp().outputVariablesNames(); String s = initialOutputsStr == null ? "null" : Arrays.toString(initialOutputsStr); log.trace("Defining backward function: initial outputs {}", s); } //----- Step 1: Determine FP variables connected to loss ----- // Find all FP variables that are connected to loss by an floating point (FP16/32/64) path Set allFpVarsConnectedToLoss = new HashSet<>(); Queue toProcess = new LinkedList<>(); for (String s : lossVariables) { if (!toProcess.contains(s)) { toProcess.add(s); } } while (!toProcess.isEmpty()) { String next = toProcess.remove(); if (!allFpVarsConnectedToLoss.contains(next)) { Variable v = variables.get(next); if (v.getVariable().dataType().isFPType()) { allFpVarsConnectedToLoss.add(v.getName()); //Work out what op (if any) this is an output of... 
and add the inputs to that op to be processed if (v.getOutputOfOp() != null) { String opName = v.getOutputOfOp(); SameDiffOp op = ops.get(opName); List opInputs = op.getInputsToOp(); if (opInputs != null) { for (String s : opInputs) { Variable inputVar = variables.get(s); if (inputVar.getVariable().dataType().isFPType()) { //Add this connected floating point type to the list to be processed toProcess.add(s); } } } } } } } //----- Step 2: Determine minimal set of FP variables actually required ----- // Keep removing leaf nodes until only Variable type SDVariables remain Set minimalSubgraphVars = new HashSet<>(allFpVarsConnectedToLoss); Queue leafFPVars = new LinkedList<>(); for (String s : allFpVarsConnectedToLoss) { //First: determine if is a FP leaf (Array type SDVariable) Variable v = variables.get(s); if (v.getVariable().getVariableType() == VariableType.ARRAY) { String opName = v.getOutputOfOp(); //Always defined for array type SameDiffOp op = ops.get(opName); List inputsToOp = op.getInputsToOp(); boolean anyInputsInSubgraph = false; if (inputsToOp != null) { for (String s2 : inputsToOp) { if (allFpVarsConnectedToLoss.contains(s2)) { //Connection s2 -> s exists... therefore s is not a leaf (yet) anyInputsInSubgraph = true; break; } } } if (!anyInputsInSubgraph) { //Mark s as a leaf to be removed leafFPVars.add(s); } } VariableType vt = v.getVariable().getVariableType(); boolean isUserRequested = variablesRequiringGradients != null && ArrayUtils.contains(variablesRequiringGradients, s); if ((vt == VariableType.CONSTANT || vt == VariableType.PLACEHOLDER) && !isUserRequested) { leafFPVars.add(s); } } while (!leafFPVars.isEmpty()) { String nextLeaf = leafFPVars.remove(); Variable v = variables.get(nextLeaf); minimalSubgraphVars.remove(nextLeaf); //Now, after removing: check what this variable is input to... 
//If nextLeaf is input to some op X, then if none of inputs y->X are present in subgraph, then // output variables X->z must now be leafs //Note that any time we remove a variable, the only possible new leafs are those that this one // is connected to. List inputsTo = v.getInputsForOp(); if (inputsTo != null && !inputsTo.isEmpty()) { for (String opName : inputsTo) { SameDiffOp op = ops.get(opName); List inputsToOp = op.getInputsToOp(); boolean anyPresent = false; for (String s : inputsToOp) { if (minimalSubgraphVars.contains(s) || (variablesRequiringGradients != null && ArrayUtils.contains(variablesRequiringGradients, s))) { //Note second condition: means user explicitly specified that they want gradients for that input variable... hence we need to diff this op anyPresent = true; break; } } if (!anyPresent) { //All inputs to op X are not in subgraph. Therefore outputs of op must be new leaves List outVars = op.getOutputsOfOp(); if (outVars != null) { for (String s : outVars) { if (!leafFPVars.contains(s)) { //Mark this variable to be processed next leafFPVars.add(s); } } } } } } } Preconditions.checkState(!minimalSubgraphVars.isEmpty(), "Cannot differentiate graph relative to the specified loss function variables %s:" + " graph does not contain any trainable SDVariables (floating point VARIABLE type SDVariables) that the loss function depend on.", lossVariables); //At this point: we know the set of variables that are connected to the loss - these all (and only) need gradients Queue availableForDiff = new LinkedList<>(); for (SDVariable lossVar : finalOutputs) { Variable v = sameDiff.variables.get(lossVar.name()); if (v.getOutputOfOp() != null) { String opName = v.getOutputOfOp(); availableForDiff.add(opName); } } // Collect all the the ops that have to be traversed before we can conclude that the gradient for // a variable is fully available //For example, if we have X -> op -> Y, and Y -> (A,B) we need gradient contribution from BOTH // Y->A and Y->B connections 
before we can do differentiation of op "op" final HashMap> prerequisites = new HashMap<>(); //Key: variable name. Value: list of op names for (String var : minimalSubgraphVars) { Variable variable = variables.get(var); // Copy the collection, as the original one will be modified during backprop final List inputsForOp = variable.getInputsForOp(); if (inputsForOp != null) { List req = new ArrayList<>(); for (String opName : inputsForOp) { //Need to filter ops here //For example, if we have: var -> Op1, and var -> Op2 //we might not need to differentiate Op2 if output of Op2 doesn't impact loss function SameDiffOp o = ops.get(opName); List opOutputs = o.getOutputsOfOp(); boolean anyOpOutputsRequired = false; if (opOutputs != null) { for (String s : opOutputs) { if (minimalSubgraphVars.contains(s)) { anyOpOutputsRequired = true; break; } } } if (anyOpOutputsRequired) { req.add(opName); } } prerequisites.put(variable.getName(), req); } } Set differentiatedOps = new HashSet<>(); while (!availableForDiff.isEmpty()) { String dfName = availableForDiff.remove(); DifferentialFunction df = sameDiff.ops.get(dfName).getOp(); //Get the inputs and outputs of the op List inputsToOp; List outputsOfOp; if (df instanceof GradientBackwardsMarker) { SameDiffOp op = sameDiff.ops.get(df.getOwnName()); inputsToOp = op.getInputsToOp(); outputsOfOp = Collections.emptyList(); } else { inputsToOp = sameDiff.ops.get(df.getOwnName()).getInputsToOp(); outputsOfOp = sameDiff.ops.get(df.getOwnName()).getOutputsOfOp(); } //Get gradients for all output variables: List grads = new ArrayList<>(); for (String s : outputsOfOp) { SDVariable v = sameDiff.getVariable(s); SDVariable g = v.hasGradient() ? 
v.gradient() : null; if (g == null) { //If no gradient exists at this point, 3 possibilities: // (a) we have a bug // (b) output of this op isn't used in calculating the loss // (c) output isn't a FP type //In the FP case, we should create a zero variable to backprop, because we can't perform backprop // for this op otherwise... if (!v.dataType().isFPType()) { grads.add(null); } else { //See "Step 3: Differentiate ops in minimal subgraph" above for explanation on why this should be zerosLike here... SDVariable gTemp = sameDiff.zerosLike(v); grads.add(gTemp); } } else { grads.add(g); } } //Differentiate: List currFnGrads = df.diff(grads); differentiatedOps.add(df.getOwnName()); //Check the inputs to this op, see if we can differentiate those ops now (and if so: add to queue) for (String s : inputsToOp) { Variable v = sameDiff.variables.get(s); String opName = v.getOutputOfOp(); if (opName == null || differentiatedOps.contains(opName)) { //Skip placeholder/constant etc; also skip if we've previously differentiated this op continue; } //Next: we've just differentiated OpX //For s -> OpX: we now have gradient for s after df.diff(grads) call earlier //Now, do we also need to differentiate OpY, where OpY -> s? //If any input variables x (x -> OpY) exist, if they are in the minimal subgraph, then we // need to differentiate OpY too //Note that just because we *need to* doesn't mean we *can* yet boolean isRequiredOp = false; SameDiffOp op = ops.get(opName); if (op.getInputsToOp() != null) { List opInputs = op.getInputsToOp(); boolean anyInputsRequired = false; for (String s2 : opInputs) { if (minimalSubgraphVars.contains(s2)) { anyInputsRequired = true; break; } } if (anyInputsRequired) { if (!differentiatedOps.contains(op.getName())) { isRequiredOp = true; } } } if (!isRequiredOp) { continue; } //Now that we know we need this op - check if we can actually differentiate it... 
//We can differentiate it if, for all variables that are outputs of this op: //(a) we have gradient already, OR //(b) it's not a FP variable, OR //(c) it's a FP variable but not one that the loss depends on //Note that for "output array is used multiple times" case (i.e., X->opY->Y, X->opZ->Z) we need all gradient // contributions - i.e., we need to have differentiated both opY and opZ boolean allAvailable = true; SameDiffOp o = sameDiff.ops.get(opName); for (String opOutput : o.getOutputsOfOp()) { Variable outVar = variables.get(opOutput); if (outVar.getVariable().dataType().isFPType()) { if (minimalSubgraphVars.contains(outVar.getName())) { //Need gradient for this variable to be available before we can differentiate if (outVar.getVariable().gradient() == null) { allAvailable = false; break; } //However, when a variable is used multiple times, we need ALL gradient contributions available: List prereqs = prerequisites.get(outVar.getName()); if (prereqs != null) { allAvailable &= differentiatedOps.containsAll(prereqs); if (!allAvailable) break; } } //If in't not in the minimal subgraph, loss doesn't depend on it, so we don't care about it } } if (allAvailable && !availableForDiff.contains(o.getOp().getOwnName())) { availableForDiff.add(o.getOp().getOwnName()); } } } //Let's validate we actually differentiated everything correctly: for (String s : minimalSubgraphVars) { if (lossVariables.contains(s)) continue; SDVariable v = variables.get(s).getVariable(); SDVariable g = v.gradient(); if (g == null) { throw new IllegalStateException("Error encountered during differentiation: no gradient for required variable \"" + s + "\" was calculated"); } } return new SDVariable[]{sameDiff.var(GRAD_FN_KEY, org.nd4j.linalg.api.buffer.DataType.FLOAT, 1)}; } }); associateSameDiffWithOpsAndVariables(); } /** * Try to infer the loss variable/s (usually loss variables). Note that this is not reliable in general. 
*/ protected List bestGuessLossVariables() { List out = new ArrayList<>(); for (Variable v : variables.values()) { if (v.getVariable().isConstant() || v.getVariable().isPlaceHolder() || //Exclude constants and placeholders (v.getInputsForOp() != null && !v.getInputsForOp().isEmpty()) || //Exclude variables that are inputs to ops (v.getControlDepsForOp() != null && !v.getControlDepsForOp().isEmpty()) || //Exclude variables that are control dependency inputs to ops (v.getControlDepsForVar() != null && !v.getControlDepsForVar().isEmpty())) { //Exclude variables that are control dependency inputs to other variables (mainly for import of cond etc ops) continue; } //Also exclude assert etc ops - doesn't make sense to return these "outputs" to user if (v.getOutputOfOp() != null && v.getVariable().dataType().isFPType()) { String opName = v.getOutputOfOp(); SameDiffOp o = ops.get(opName); if (o.getOp() instanceof Assert) { continue; } //A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed // by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here // This applies to SameDiff while loops as well if (o.getOp() instanceof Switch) { continue; } } out.add(v.getName()); } return out; } /** * Returns true if this vertex id is a place holder variable or not
* A place holder variable is one where the array shape(s) are currently unknown and can't yet be calculated
     *
     * @param varName the vertex id to test
     * @return True if the variable is a placeholder, false otherwise
     */
    public boolean isPlaceHolder(String varName) {
        // The name must be registered with this instance before it can be inspected
        final boolean registered = variables.containsKey(varName);
        Preconditions.checkState(registered, "No variable present in SameDiff instance with name \"%s\"", varName);

        final SDVariable sdVariable = variables.get(varName).getVariable();
        return sdVariable.isPlaceHolder();
    }

    /**
     * Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable.
     *

* Note that if null for the new variable is passed in, it will just return the original input variable. * @param opToRename note we pass in the op here for times when an op may have multiple outputs * when this is the case, we need to pass in the op to rename otherwise context gets lost * and subsequent rename attempts will not operate on the op. * @param varToUpdate the variable to update * @param newVarName the new variable name * @return the passed in variable */ public SDVariable updateVariableNameAndReference(SameDiffOp opToRename,SDVariable varToUpdate, String newVarName) { if (varToUpdate == null) { throw new NullPointerException("Null input: No variable found for updating!"); } if (newVarName != null) { String nameScope = currentNameScope(); if (nameScope != null) { if (!newVarName.startsWith(nameScope + "/")) { newVarName = nameScope + "/" + newVarName; } } } if (newVarName != null && variables.containsKey(newVarName) && varToUpdate != variables.get(newVarName).getVariable()) { throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable"); } if (newVarName == null && variables.containsKey(varToUpdate.name()) && variables.get(varToUpdate.name()).getVariable() != varToUpdate) { //Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name // "mean" and consequently a new variable name needs to be generated newVarName = generateNewVarName(varToUpdate.name(), 0); } if (newVarName == null || varToUpdate.name().equals(newVarName)) { return varToUpdate; } val oldVarName = varToUpdate.name(); varToUpdate.setVarName(newVarName); renameVariable(opToRename,oldVarName, newVarName); return varToUpdate; } /** * Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable. *

* Note that if null for the new variable is passed in, it will just return the original input variable. * * @param varToUpdate the variable to update * @param newVarName the new variable name * @return the passed in variable */ public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { SameDiffOp op = ops.get(varToUpdate.name()); return updateVariableNameAndReference(op,varToUpdate,newVarName); } /** * Updates the variable name property on the passed in variables, its reference in samediff, and returns the variable. * * @param variablesToUpdate the variable to update * @param newVariableNames the new variable name * @return the updated, passed in variables */ public SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames) { int numVariables = variablesToUpdate.length; SDVariable[] updatedVariables = new SDVariable[numVariables]; for (int i = 0; i < numVariables; i++) { SDVariable varToUpdate = variablesToUpdate[i]; String name = newVariableNames == null ? null : newVariableNames[i]; updatedVariables[i] = updateVariableNameAndReference(varToUpdate, name); } return updatedVariables; } /** * Associate the current SameDiff instance with all ops and variables. * This is necessary to ensure that when dealing with shared state (usually with a SameDiff function such * as "grad" - the backward function) we have the correct SameDiff instance set for all ops/SDVariables.
* If this is not done, arrays and shapes could be fetched from the incorrect SameDiff instance for some methods */ protected void associateSameDiffWithOpsAndVariables() { for (SDVariable var : variableMap().values()) { var.setSameDiff(this); } // for(DifferentialFunction df : functionInstancesById.values()){ for (SameDiffOp op : ops.values()) { DifferentialFunction df = op.getOp(); df.setSameDiff(this); //TODO: This is ugly but seemingly necessary //Finally, also set the SDVariable for each op //Otherwise: could have an op pointing to this SameDiff instance, but op's SDVariable's sameDiff field pointing // to another SameDiff instance. At which point, they could fetch shapes and arrays from some other instance // (i.e., not from this one that is currently executing) SDVariable[] args = df.args(); if (args != null) { for (SDVariable arg : args) { arg.setSameDiff(this); } } SDVariable[] outputs = df.outputVariables(); if (outputs != null) { for (SDVariable out : outputs) { out.setSameDiff(this); } } } } protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) { int scopeName = bufferBuilder.createString(name); int flatNode = FlatNode.createFlatNode(bufferBuilder, scopeName, scopeName, OpType.LOGIC, 10, // hardcoded value 0, 0, 0, (byte) 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); return flatNode; } /** * Note: INTENDED FOR DEVELOPER USE
* This method extract base variable name and output index (if exists) from raw variable name. * I.e: * - if variable name is "Unstack_2", result will be Pair("Unstack_2", 0) * - if variable name is "Unstack_2:12", result will be Pair("Unstack_2", 12) * * @param varName * @return */ public static Pair parseVariable(@NonNull String varName) { if (!varName.contains(":")) { return Pair.pairOf(varName, 0); } else { val split = varName.split(":"); val index = Integer.valueOf(split[split.length - 1]); if (split.length == 2) return Pair.pairOf(split[0], index); else { val builder = new StringBuilder(); for (int e = 0; e < split.length - 1; e++) { builder.append(split[e]); if (e < split.length - 2) builder.append(":"); } return Pair.pairOf(builder.toString(), index); } } } /** * This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and * all arrays as a ByteBuffer containing the FlatBuffers format data * * @param configuration - ExecutorConfiguration to be embedded into serialized graph * @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc) * @return a ByteBuffer holding the exported FlatBuffers representation of the graph */ public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) { return asFlatBuffers(0, configuration, includeUpdaterState); } /** * This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and * all arrays as a ByteBuffer containing the FlatBuffers format data * * @param configuration - ExecutorConfiguration to be embedded into serialized graph * @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc) * @return a ByteBuffer holding the exported FlatBuffers representation of the graph */ public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration configuration, boolean 
includeUpdaterState) { Nd4j.getExecutioner().commit(); val bufferBuilder = new FlatBufferBuilder(1024); val idCounter = new AtomicInteger(0); val flatVariables = new ArrayList(); val flatOffsets = new ArrayList(); val flatNodes = new ArrayList(); // first of all we build VariableSpace dump val variableList = new ArrayList<>(variables()); val reverseMap = new LinkedHashMap(); val forwardMap = new LinkedHashMap(); val framesMap = new LinkedHashMap(); int idx = 0; val idxForOps = new IdentityHashMap(); List allVars = variables(); for (SDVariable variable : allVars) { INDArray arr = variable.getVariableType() == VariableType.ARRAY ? null : variable.getArr(); log.trace("Exporting variable: [{}]", variable.name()); //If variable is the output of some op - let's use the ONE index for exporting, and properly track the output // numbers. For example, unstack(x) -> y0, y1, y2 -> the y's should be say (3,0), (3,1), (3,2) NOT (4,0), (5,0), (6,0) String varName = variable.name(); int varIdx; int outputNum; if (variables.get(varName).getOutputOfOp() != null) { //This variable is the output of a node DifferentialFunction df = ops.get(variables.get(varName).getOutputOfOp()).getOp(); if (!idxForOps.containsKey(df)) { varIdx = idCounter.incrementAndGet(); idxForOps.put(df, varIdx); } else { varIdx = idxForOps.get(df); } String[] outNames = df.outputVariablesNames(); outputNum = ArrayUtils.indexOf(outNames, varName); Preconditions.checkState(outputNum >= 0, "Variable name \"%s\" not found in list of outputs: %s", varName, outNames); } else { varIdx = idCounter.incrementAndGet(); outputNum = 0; } reverseMap.put(variable.name(), varIdx); log.trace("Adding [{}] as [{}]", variable.name(), varIdx); int shape = 0; int name = bufferBuilder.createString(variable.name()); int array = 0; int id = IntPair.createIntPair(bufferBuilder, varIdx, outputNum); byte varType = (byte) variable.getVariableType().ordinal(); if (variable.isConstant() || variable.isPlaceHolder() || variable.getVariableType() 
== VariableType.VARIABLE) { //Don't export array type (i.e., activations), these are always replaced/re-calculated on each step array = arr == null ? 0 : arr.toFlatArray(bufferBuilder); } if (variable.getVariableType() == VariableType.PLACEHOLDER) { val shp = variable.getShape(); if(shp != null) { //Some models may have no shape defined, not ever a placeholder type shape shape = FlatVariable.createShapeVector(bufferBuilder, shp); } } int controlDeps = 0; int controlDepsForOp = 0; int controlDepsForVar = 0; Variable v = variables.get(varName); int[] cds = FlatBuffersMapper.mapOrNull(v.getControlDeps(), bufferBuilder); if(cds != null) controlDeps = FlatVariable.createControlDepsVector(bufferBuilder, cds); int[] cdsForOp = FlatBuffersMapper.mapOrNull(v.getControlDepsForOp(), bufferBuilder); if(cdsForOp != null) controlDepsForOp = FlatVariable.createControlDepForOpVector(bufferBuilder, cdsForOp); int[] cdsForVar = FlatBuffersMapper.mapOrNull(v.getControlDepsForVar(), bufferBuilder); if(cdsForVar != null) controlDepsForVar = FlatVariable.createControlDepsForVarVector(bufferBuilder, cdsForVar); int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType, controlDeps, controlDepsForOp, controlDepsForVar); flatVariables.add(flatVariable); } //add functions for (SameDiffOp op : ops.values()) { DifferentialFunction func = op.getOp(); Integer fnId = idxForOps.get(func); flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId)); } int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets)); int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables)); int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes)); int numPlaceholders = 0; for (SDVariable v : variables()) { if (v.isPlaceHolder()) { 
numPlaceholders++; } } int[] placeholderOffsets = new int[numPlaceholders]; if (numPlaceholders > 0) { int i = 0; for (SDVariable v : variables()) { if (!v.isPlaceHolder()) continue; placeholderOffsets[i++] = bufferBuilder.createString(v.name()); } } int placeholdersOffset = FlatGraph.createPlaceholdersVector(bufferBuilder, placeholderOffsets); List lossVars = getLossVariables(); int[] lossVarOffsets = new int[lossVars == null ? 0 : lossVars.size()]; for (int i = 0; i < lossVarOffsets.length; i++) { lossVarOffsets[i] = bufferBuilder.createString(lossVars.get(i)); } int lossVarOffset = FlatGraph.createLossVariablesVector(bufferBuilder, lossVarOffsets); int trainingConfigOffset = 0; int updaterStateOffset = 0; if (trainingConfig != null) { String json = trainingConfig.toJson(); trainingConfigOffset = bufferBuilder.createString(json); } if (includeUpdaterState && updaterMap != null && !updaterMap.isEmpty()) { int[] updaterOffsets = new int[updaterMap.size()]; int updaterNum = 0; for (Map.Entry g : updaterMap.entrySet()) { int paramNameOffset = bufferBuilder.createString(g.getKey()); int stateKeyOffset = 0; int stateValuesOffset = 0; Map state = g.getValue().getState(); if (state != null && !state.isEmpty()) { int[] keysOffsets = new int[state.size()]; int[] valuesOffsets = new int[state.size()]; int i = 0; for (Map.Entry e : state.entrySet()) { keysOffsets[i] = bufferBuilder.createString(e.getKey()); valuesOffsets[i] = e.getValue().toFlatArray(bufferBuilder); i++; } stateKeyOffset = UpdaterState.createUpdaterStateKeysVector(bufferBuilder, keysOffsets); stateValuesOffset = UpdaterState.createUpdaterStateValuesVector(bufferBuilder, valuesOffsets); } updaterOffsets[updaterNum++] = UpdaterState.createUpdaterState(bufferBuilder, paramNameOffset, stateKeyOffset, stateValuesOffset); } updaterStateOffset = FlatGraph.createUpdaterStateVector(bufferBuilder, updaterOffsets); } int fg = FlatGraph.createFlatGraph(bufferBuilder, graphId, variablesOffset, nodesOffset, outputsOffset, 
configuration.getFlatConfiguration(bufferBuilder), placeholdersOffset, lossVarOffset, trainingConfigOffset, updaterStateOffset); bufferBuilder.finish(fg); synchronized (this) { for (Map.Entry e : reverseMap.entrySet()) { this.variables.get(e.getKey()).setVariableIndex(e.getValue()); } } return bufferBuilder.dataBuffer(); } /** * See {@link #asFlatGraph(long, ExecutorConfiguration, boolean)}. * * Uses the default {@link ExecutorConfiguration} with output mode as * {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL}, * with profiling disabled and gather timings enabled. */ public FlatGraph asFlatGraph(boolean includeUpdaterState) { return FlatGraph.getRootAsFlatGraph(this.asFlatBuffers(includeUpdaterState)); } /** * This method returns FlatGraph structure * * @param configuration * @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc) * @return */ public FlatGraph asFlatGraph(long graphId, ExecutorConfiguration configuration, boolean includeUpdaterState) { return FlatGraph.getRootAsFlatGraph(asFlatBuffers(graphId, configuration, includeUpdaterState)); } /** * This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and * all arrays as a ByteBuffer containing the FlatBuffers format data * * Uses the default {@link ExecutorConfiguration} with output mode as * {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL}, * with profiling disabled and gather timings enabled. 
* * @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc) * @return a ByteBuffer holding the exported FlatBuffers representation of the graph */ public ByteBuffer asFlatBuffers(boolean includeUpdaterState) { val configuration = ExecutorConfiguration.builder() .outputMode(OutputMode.VARIABLE_SPACE) .executionMode(org.nd4j.autodiff.execution.conf.ExecutionMode.SEQUENTIAL) .profilingMode(OpExecutioner.ProfilingMode.DISABLED) .gatherTimings(true) .build(); return asFlatBuffers(configuration, includeUpdaterState); } /** * Save the SameDiff instance to a file. Files can be loaded using {@link #load(File, boolean)} * * @param file File to save to * @param saveUpdaterState If true: save the updater state (arrays etc for Adam, Nesterov, RmsProp etc). If false: don't save * the updater state. If you want to continue training after loading your model, this should be true, * however may increase the file size significantly. * If the network is to be used for inference only, set this to false to save space */ public void save(@NonNull File file, boolean saveUpdaterState) { try { asFlatFile(file, saveUpdaterState); } catch (IOException e) { throw new RuntimeException("Error saving SameDiff instance to file", e); } } /** * As per {@link #save(File, boolean)} but the serialized SameDiff instance is written to the output stream instead. * Note that this temporarily saves to disk (using {@link ND4JFileUtils#createTempFile(String, String)} then copies all * file bytes to the stream * * @param outputStream Stream to write the serialized SameDiff instance to * @param saveUpdater If true: save the updater state (arrays etc for Adam, Nesterov, RmsProp etc). If false: don't save * the updater state. If you want to continue training after loading your model, this should be true, * however may increase the file size significantly. * If the network is to be used for inference only, set this to false to save space. 
*/
    public void save(@NonNull OutputStream outputStream, boolean saveUpdater) {
        final File scratchFile = ND4JFileUtils.createTempFile("SameDiffFile", "temp");
        try {
            save(scratchFile, saveUpdater);

            // Only add buffering if the caller has not already provided it
            final OutputStream buffered = (outputStream instanceof BufferedOutputStream)
                    ? outputStream
                    : new BufferedOutputStream(outputStream);

            // Both streams, including the caller's, are closed when the copy finishes
            try (OutputStream sink = buffered;
                 InputStream source = new BufferedInputStream(new FileInputStream(scratchFile))) {
                IOUtils.copy(source, sink);
            } catch (IOException ioFailure) {
                throw new RuntimeException("Error writing to output stream (or reading from temp file)", ioFailure);
            }
        } finally {
            scratchFile.delete();
        }
    }

    /**
     * Load the SameDiff instance previously saved with {@link #save(File, boolean)}
     *
     * @param file             The file to load the network from
     * @param loadUpdaterState If true - load the updater state (history etc for updaters such as Adam, Nesterov momentum, RMSProp etc).
     *                         For inference only, this should be false, as the updater state will take more memory, but
     *                         is not required for inference.
     *                         If the network is to be trained further, this should be true.
     *                         The updater state can only be loaded if it was saved with the network.
     * @return The loaded SameDiff network
     */
    public static SameDiff load(@NonNull File file, boolean loadUpdaterState) {
        try {
            return fromFlatFile(file, loadUpdaterState);
        } catch (IOException ioFailure) {
            throw new RuntimeException("Error loading SameDiff instance from file", ioFailure);
        }
    }

    /**
     * As per {@link #load(File, boolean)} but the SameDiff instance is read from an input stream
     *
     * @param is               Input stream to load the saved network from
     * @param loadUpdaterState If true - load the updater state (history etc for updaters such as Adam, Nesterov momentum, RMSProp etc).
     *                         For inference only, this should be false, as the updater state will take more memory, but
     *                         is not required for inference.
     *                         If the network is to be trained further, this should be true.
     *                         The updater state can only be loaded if it was saved with the network.
* @return The loaded SameDiff network
     */
    public static SameDiff load(@NonNull InputStream is, boolean loadUpdaterState) {
        // Spool the stream to a temporary file, then load from that file
        final File scratchFile = ND4JFileUtils.createTempFile("SameDiffFile", "temp");
        try {
            try (OutputStream spool = new BufferedOutputStream(new FileOutputStream(scratchFile))) {
                IOUtils.copy(is, spool);
            }
            final SameDiff loaded = fromFlatFile(scratchFile, loadUpdaterState);
            return loaded;
        } catch (IOException ioFailure) {
            throw new RuntimeException("Error loading SameDiff instance from file", ioFailure);
        } finally {
            scratchFile.delete();
        }
    }

    /**
     * This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later
* This includes the updater state, if applicable. * * Uses the default {@link ExecutorConfiguration} with output mode as * {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL}, * with profiling disabled and gather timings enabled. * * @param file File to save the FlatBuffers serialized graph (including arrays) to */ public void asFlatFile(@NonNull File file) throws IOException { asFlatFile(file, true); } /** * See {@link #asFlatFile(File, ExecutorConfiguration, boolean)}. * * Uses the default {@link ExecutorConfiguration} with output mode as * {@link OutputMode#VARIABLE_SPACE}, execution mode as {@link ExecutionMode#SEQUENTIAL}, * with profiling disabled and gather timings enabled. */ public void asFlatFile(@NonNull File file, boolean withUpdaterState) throws IOException { val fb = asFlatBuffers(withUpdaterState); val offset = fb.position(); val array = fb.array(); try (val fos = new FileOutputStream(file); val bos = new BufferedOutputStream(fos); val dos = new DataOutputStream(bos)) { dos.write(array, offset, array.length - offset); } } /** * This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later * * @param file File to save the FlatBuffers serialized graph (including arrays) to * @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc) */ public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) throws IOException { val fb = asFlatBuffers(configuration, includeUpdaterState); val offset = fb.position(); val array = fb.array(); try (val fos = new FileOutputStream(file); val bos = new BufferedOutputStream(fos); val dos = new DataOutputStream(bos)) { dos.write(array, offset, array.length - offset); } } /** * Create a {@link SameDiff} instance from a file, including the updater state * The method to save the file is {@link #save(File, boolean)} * * @param file the file 
to load from * @return the loaded same diff instance * @throws IOException */ public static SameDiff fromFlatFile(@NonNull File file) throws IOException { return fromFlatFile(file, true); } /** * Create a {@link SameDiff} instance from a file, optionally also loading the updater state * The method to save the file is {@link #save(File, boolean)} * * @param file the file to load from * @param loadUpdaterState If true, load the updater state (Adam etc state). For training, use true. For inference, use false * @return the loaded same diff instance * @throws IOException */ public static SameDiff fromFlatFile(@NonNull File file, boolean loadUpdaterState) throws IOException { byte[] bytes; try (InputStream is = new BufferedInputStream(new FileInputStream(file))) { bytes = IOUtils.toByteArray(is); } ByteBuffer bbIn = ByteBuffer.wrap(bytes); return fromFlatBuffers(bbIn, loadUpdaterState); } /** * Create a {@link SameDiff} * instance from a byte buffers * instance. * * See {@link #fromFlatBuffers(ByteBuffer, boolean)}. Loads updater state (loadUpdaterState is true). * * @param bbIn the input byte buffer * @return the created samediff instance * @throws IOException */ public static SameDiff fromFlatBuffers(ByteBuffer bbIn) throws IOException { return fromFlatBuffers(bbIn, true); } /** * Create a {@link SameDiff} * instance from a byte buffers * instance. * * @param bbIn the input byte buffer * @param loadUpdaterState If true, load the updater state (Adam etc state). For training, use true. 
For inference, use false * @return the created samediff instance * @throws IOException */ public static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState) throws IOException { FlatGraph fg = FlatGraph.getRootAsFlatGraph(bbIn); int numOps = fg.nodesLength(); int numVars = fg.variablesLength(); List ops = new ArrayList<>(numOps); for (int i = 0; i < numOps; i++) { ops.add(fg.nodes(i)); } List vars = new ArrayList<>(numVars); for (int i = 0; i < numVars; i++) { vars.add(fg.variables(i)); } // FlatConfiguration conf = fg.configuration(); /* Reconstruct the graph We'll do the reconstruction manually here, rather than using sd.var(...), so that we have more control over the final result. */ SameDiff sd = SameDiff.create(); //Reconstruct placeholders int numPlaceholders = fg.placeholdersLength(); Set ph = new LinkedHashSet<>(); for (int i = 0; i < numPlaceholders; i++) { ph.add(fg.placeholders(i)); } //Reconstruct variables: Map varNodeIds = new HashMap<>(); Map, SDVariable> variablesByNodeAndOutNum = new HashMap<>(); Map> variablesByName = new HashMap<>(); for (FlatVariable v : vars) { int shapeLength = v.shapeLength(); long[] shape = new long[shapeLength]; for (int i = 0; i < shapeLength; i++) { shape[i] = v.shape(i); } String n = v.name(); byte dtypeByte = v.dtype(); org.nd4j.linalg.api.buffer.DataType dtype = FlatBuffersMapper.getDataTypeFromByte(dtypeByte); //TODO Infer this properly! Could be constant, etc. 
VariableType vt = VariableType.values()[v.variabletype()]; SDVariable var = new SDVariable(n, vt, sd, shape, dtype); sd.variables.put(n, Variable.builder().name(n).variable(var).build()); Variable v2 = sd.variables.get(n); //Reconstruct control dependencies if(v.controlDepsLength() > 0) { int num = v.controlDepsLength(); List l = new ArrayList<>(num); for(int i = 0; i < num; i++) { l.add(v.controlDeps(i)); } v2.setControlDeps(l); } if(v.controlDepForOpLength() > 0) { int num = v.controlDepForOpLength(); List l = new ArrayList<>(num); for( int i = 0; i < num; i++) { l.add(v.controlDepForOp(i)); } v2.setControlDepsForOp(l); } if(v.controlDepsForVarLength() > 0) { int num = v.controlDepsForVarLength(); List l = new ArrayList<>(num); for(int i = 0; i < num; i++) { l.add(v.controlDepsForVar(i)); } v2.setControlDepsForVar(l); } FlatArray fa = v.ndarray(); if (fa != null && vt != VariableType.ARRAY) { INDArray arr; try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { arr = Nd4j.createFromFlatArray(fa); } sd.setArrayForVariable(n, arr); } IntPair id = v.id(); //First value: node (op) id. 
Second: output number variablesByNodeAndOutNum.put(new Pair<>(id.first(), id.second()), var); if (!variablesByName.containsKey(n)) { variablesByName.put(n, new ArrayList()); } List list = variablesByName.get(n); list.add(var); } //Reconstruct ops: for (FlatNode fn : ops) { DifferentialFunction df = FlatBuffersMapper.fromFlatNode(fn); String name = fn.name(); df.setSameDiff(sd); df.setOwnName(name); if (sd.ops.containsKey(name)) { sd.ops.get(name).setOp(df); } else { sd.ops.put(name, SameDiffOp.builder().name(name).op(df).build()); } int outLength = fn.outputLength(); int[] outs = new int[outLength]; for (int i = 0; i < outLength; i++) { outs[i] = fn.output(i); } int opId = fn.id(); //Work out inputs and outputs: int[] output = new int[fn.outputLength()]; for (int i = 0; i < output.length; i++) { output[i] = fn.output(i); } int[] input = new int[fn.inputLength()]; for (int i = 0; i < input.length; i++) { input[i] = fn.input(i); } IntPair[] inputPaired = new IntPair[fn.inputPairedLength()]; List> intPairList = new ArrayList<>(); for (int i = 0; i < inputPaired.length; i++) { inputPaired[i] = fn.inputPaired(i); intPairList.add(new Pair<>(inputPaired[i].first(), inputPaired[i].second())); } String[] inputNames = new String[inputPaired.length]; for (int i = 0; i < inputPaired.length; i++) { int nodeId = inputPaired[i].first(); int nodeOutNum = inputPaired[i].second(); SDVariable varIn = variablesByNodeAndOutNum.get(new Pair<>(nodeId, nodeOutNum)); if (varIn == null) { //The variable corresponding to this op was not } inputNames[i] = varIn.name(); } SameDiffOp op = sd.ops.get(df.getOwnName()); op.setInputsToOp(Arrays.asList(inputNames)); //Reconstruct control dependencies if (fn.controlDepsLength() > 0) { int l = fn.controlDepsLength(); List list = new ArrayList<>(l); for( int i = 0; i < l; i++ ){ list.add(fn.controlDeps(i)); } op.setControlDeps(list); } if (fn.varControlDepsLength() > 0) { int l = fn.varControlDepsLength(); List list = new ArrayList<>(l); for( int i = 
0; i 0) { int l = fn.controlDepForLength(); List list = new ArrayList<>(l); for( int i=0; i()); } if (!v.getInputsForOp().contains(df.getOwnName())) { v.getInputsForOp().add(df.getOwnName()); } } List varsForOp = variablesByName.get(name); //Can't assume that variables for the op have all been defined. For example, if we export before execution in SameDiff //In theory, we can reconstruct the output variables (minus names) if we know the number of op outputs //And we can calculate the op outputs - in most cases - after the op has been created and parameters set int numOutputs = df.getNumOutputs(); if (numOutputs <= 0) { numOutputs = fn.outputLength(); } String[] varNames = null; if (varsForOp != null && varsForOp.size() == numOutputs) { varNames = new String[varsForOp.size()]; for (int i = 0; i < varNames.length; i++) { varNames[i] = varsForOp.get(i).name(); sd.getVariables().get(varNames[i]).setOutputOfOp(df.getOwnName()); } sd.ops.get(df.getOwnName()).setOutputsOfOp(Arrays.asList(varNames)); } else { //We're missing some variables... int outputNamesLength = fn.outputNamesLength(); varNames = new String[outputNamesLength]; for (int i = 0; i < outputNamesLength; i++) { String n = fn.outputNames(i); varNames[i] = n; if (!sd.variables.containsKey(n)) { //Need to create the variable - perhaps it wasn't exported. 
Note output of node -> can only be VARIABLE type SDVariable var = new SDVariable(n, VariableType.VARIABLE, sd, null, null); sd.variables.put(n, Variable.builder().name(n).variable(var).build()); variablesByNodeAndOutNum.put(new Pair<>(opId, i), var); } sd.getVariables().get(varNames[i]).setOutputOfOp(df.getOwnName()); } sd.ops.get(df.getOwnName()).setOutputsOfOp(Arrays.asList(varNames)); } //Check the op mapping int he variablesByNodeAndOutputNum //For multi-output ops, variables will have their own index, not related to the op index for (int i = 0; i < varNames.length; i++) { Pair p = new Pair<>(opId, i); if (!variablesByNodeAndOutNum.containsKey(p)) { variablesByNodeAndOutNum.put(p, sd.getVariable(varNames[i])); } } } //Reconstruct loss variables if (fg.lossVariablesLength() > 0) { for (int i = 0; i < fg.lossVariablesLength(); i++) { sd.addLossVariable(fg.lossVariables(i)); } } //Reconstruct training config String tc = fg.trainingConfig(); if (tc != null) { sd.trainingConfig = TrainingConfig.fromJson(tc); } if (loadUpdaterState) { //Reconstruct updater state if (fg.updaterStateLength() > 0) { sd.updaterMap = new HashMap<>(); int n = fg.updaterStateLength(); for (int i = 0; i < n; i++) { UpdaterState us = fg.updaterState(i); String name = us.paramName(); int nKeys = us.updaterStateKeysLength(); Map m = new HashMap<>(); for (int j = 0; j < nKeys; j++) { String key = us.updaterStateKeys(j); FlatArray fa = us.updaterStateValues(j); INDArray stateArr = Nd4j.createFromFlatArray(fa); m.put(key, stateArr); } //Initialize the updater GradientUpdater gu = sd.trainingConfig.getUpdater().instantiate(m, false); sd.updaterMap.put(name, gu); } sd.initializedTraining = true; } } return sd; } /** * This method returns a text representation of the "flattened" graph. 
* * @return String representation of the graph * @see #summary() */ public String asFlatPrint() { val sb = new StringBuilder(); val fb = asFlatBuffers(false); val graph = FlatGraph.getRootAsFlatGraph(fb); sb.append("\nExternal variables:\n\n"); for (int e = 0; e < graph.variablesLength(); e++) { FlatVariable var = graph.variables(e); INDArray ndarray = null; try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { FlatArray fa = var.ndarray(); if (fa != null) { ndarray = Nd4j.createFromFlatArray(fa); } } sb.append(var.id().first()) .append(":<").append(var.name()).append("> "); if (ndarray == null) { sb.append("").append("; Values: ").append("").append(";\n"); } else { sb.append(Arrays.toString(ndarray.shapeInfoDataBuffer().asInt())).append("; Values: "); if (ndarray.data() == null) { //Empty array sb.append(""); } else if (ndarray.dataType() == DataType.UTF8) { sb.append(""); } else { if (ndarray.length() < 50) { sb.append(Arrays.toString(ndarray.data().asFloat()).replaceAll(" ", "")); } else { //Array is too long - only tak. last few values... 
sb.append("["); for (int i = 0; i < 50; i++) { if (i > 0) sb.append(","); sb.append(ndarray.data().getFloat(i)); } sb.append("]"); } } sb.append(";\n"); } } val map = Nd4j.getExecutioner().getCustomOperations(); sb.append("\nOps sequence:\n\n"); for (int e = 0; e < graph.nodesLength(); e++) { FlatNode node = graph.nodes(e); log.info("{}:<{}>", node.id(), node.name()); sb.append(node.id()) .append(":<").append(node.name()).append("> ").append(FlatBuffersMapper.getTypeFromByte(node.opType())); if (FlatBuffersMapper.getTypeFromByte(node.opType()) != Op.Type.CUSTOM) sb.append(": ").append(node.opNum()); else { val keys = map.keySet(); String opName = null; for (val k : keys) { val d = map.get(k); if (d.getHash() == node.opNum()) opName = k; } if (opName == null) opName = "unknown"; sb.append(": ").append(opName); } sb.append("; Inputs: {"); for (int i = 0; i < node.inputPairedLength(); i++) { IntPair pair = node.inputPaired(i); sb.append("[").append(pair.first()).append(":").append(pair.second()).append("]"); if (i < node.inputPairedLength() - 1) sb.append(", "); } sb.append("};"); sb.append(" OpNum: {").append(node.opNum()).append("};"); sb.append("\n"); } return sb.toString(); } /** * Generate and return a String representation of the current SameDiff instance
* Reports variables, ops, SameDiff function instances, and (where possible) array shapes.
* For ops, the input and output variables are reported.
* For variables, the ops that they are inputs to - or outputs of - are also reported * * @return A String representation of the SameDiff instance */ public String summary() { Map varMap = variableMap(); DifferentialFunction[] functions = ops(); int countVarsWithArrays = 0; for (String s : varMap.keySet()) { if (getArrForVarName(s) != null) { countVarsWithArrays++; } } StringBuilder sb = new StringBuilder(); String format = "%-25s%-20s"; sb.append("--- Summary ---\n"); sb.append(String.format(format, "Variables:", varMap.size())).append(" (").append(countVarsWithArrays).append(" with arrays)").append("\n") .append(String.format(format, "Functions:", functions.length)).append("\n") .append(String.format(format, "SameDiff Function Defs:", sameDiffFunctionInstances.size())).append("\n") .append("Loss function variables: ").append(getLossVariables()) .append("\n\n"); sb.append("--- Variables ---\n"); //Work out which function - if any - this arg is an output of... Map outputOfFn = new HashMap<>(); int maxLengthOutputOf = 22; //Length of "- Output Of Function -" int maxLengthOfName = 8; //Length of "- Name -" for (String s : varMap.keySet()) { String outputOf = null; for (SameDiffOp op : ops.values()) { List outputsOfOp = op.getOutputsOfOp(); if (outputsOfOp != null && outputsOfOp.contains(s)) { outputOf = op.getName(); break; } } if (outputOf == null) { outputOf = ""; } else { DifferentialFunction d = getOpById(outputOf); outputOf = d.getOwnName() + "(" + d.opName() + ")"; } outputOfFn.put(s, outputOf); maxLengthOutputOf = Math.max(maxLengthOutputOf, outputOf.length()); maxLengthOfName = Math.max(maxLengthOfName, s.length()); } maxLengthOutputOf += 2; maxLengthOfName += 2; //Create the output for values: format = "%-" + maxLengthOfName + "s%-20s%-20s%-20s%-" + maxLengthOutputOf + "s%-20s"; sb.append(String.format(format, "- Name -", "- Array Shape -", "- Variable Type -", "- Data Type-", "- Output Of Function -", "- Inputs To Functions -")).append("\n"); for (String s : 
varMap.keySet()) { INDArray arr = getArrForVarName(s); String arrayShape = "-"; if (arr != null) { arrayShape = Arrays.toString(arr.shape()); } else if (varMap.get(s).isPlaceHolder()) { SDVariable v = varMap.get(s); long[] phShape = v.placeholderShape(); if (phShape != null) { arrayShape = Arrays.toString(phShape); } } String varType = getVariable(s).getVariableType().toString(); String dtype = getVariable(s).dataType().toString(); List argNames = variables.get(s).getInputsForOp(); String dfArrStr = ""; if (argNames != null) { dfArrStr = argNames.toString(); } String outputOfStr = outputOfFn.get(s); sb.append(String.format(format, s, arrayShape, varType, dtype, outputOfStr, dfArrStr)).append("\n"); } sb.append("\n\n--- Functions ---\n"); //First: work out the amount of space we need for inputs and outputs... List dfInputStr = new ArrayList<>(); List dfOutputStr = new ArrayList<>(); int maxInLength = 10; //Length of "- Inputs -" int maxOutLength = 11; //Length of "- Outputs -" int maxOpNameLength = 17; //Default to min of 17 - length of "- Function Name -" int maxDfClassNameLength = 10; //Default to min of 10 for (DifferentialFunction df : functions) { String[] argNames = df.argNames(); String[] outNames = df.outputVariablesNames(); String argStr = Arrays.toString(argNames); String outStr = Arrays.toString(outNames); maxInLength = Math.max(maxInLength, argStr.length()); maxOutLength = Math.max(maxOutLength, outStr.length()); dfInputStr.add(argStr); dfOutputStr.add(outStr); String name = df.getOwnName() == null ? 
df.opName() : df.getOwnName(); maxOpNameLength = Math.max(maxOpNameLength, name.length()); maxDfClassNameLength = Math.max(maxDfClassNameLength, df.getClass().getSimpleName().length()); } //Extra padding space maxInLength += 2; maxOutLength += 2; maxOpNameLength += 2; maxDfClassNameLength += 2; format = "%-5s%-" + maxOpNameLength + "s%-" + maxDfClassNameLength + "s%-" + maxInLength + "s%-" + maxOutLength + "s"; sb.append(String.format(format, "", "- Function Name -", "- Op -", "- Inputs -", "- Outputs -")).append("\n"); for (int i = 0; i < functions.length; i++) { DifferentialFunction df = functions[i]; String fnName = df.getOwnName() == null ? df.opName() : df.getOwnName(); sb.append(String.format(format, i, fnName, df.getClass().getSimpleName(), dfInputStr.get(i), dfOutputStr.get(i))).append("\n"); } if (sameDiffFunctionInstances.size() > 0) { sb.append("\n\n--- SameDiff Defined Functions ---\n"); format = "%-20s%-15s%-15s%-15s"; sb.append(String.format(format, "- Name -", "- Variables -", "- Functions -", "- Fn Defs -")).append("\n"); for (Map.Entry e : sameDiffFunctionInstances.entrySet()) { SameDiff sd = e.getValue(); int vars = sd.variableMap().size(); int fns = (sd.ops() == null ? 0 : sd.ops().length); int defFns = sd.definedFunctionNames().size(); sb.append(String.format(format, e.getKey(), vars, fns, defFns)).append("\n"); } } return sb.toString(); } /** * For internal use only. * Creates a new distinct block name from baseName. * Block names are used by If and While */ public String newBlockName(String baseName) { if (baseName == null) return null; if (!blockNames.contains(baseName)) { blockNames.add(baseName); return baseName; } else { int i = 1; while (blockNames.contains(baseName + "_" + i)) { i++; } blockNames.add(baseName + "_" + i); return baseName + "_" + i; } } /** * Import a frozen Tensorflow graph to a new SameDiff graph. 
*
     * @param graphFile The text or binary file containing the graph
     * @return The imported graph
     */
    public static SameDiff importFrozenTF(File graphFile) {
        return TFGraphMapper.importGraph(graphFile);
    }

    /**
     * Import a frozen Tensorflow graph, supplied as a {@code GraphDef}, to a new SameDiff graph.
     * See {@link #importFrozenTF(File)}
     *
     * @param graphDef The graph definition; passed unchanged to {@code TFGraphMapper.importGraph}
     * @return The graph returned by {@code TFGraphMapper.importGraph}
     */
    public static SameDiff importFrozenTF(GraphDef graphDef) {
        return TFGraphMapper.importGraph(graphDef);
    }

    /**
     * Import a frozen Tensorflow graph, read from a stream, to a new SameDiff graph.
     * See {@link #importFrozenTF(File)}
     * <p>
     * Again, the input can be text or binary.
     *
     * @param graph The stream containing the graph; passed unchanged to {@code TFGraphMapper.importGraph}.
     *              NOTE(review): whether the stream is closed afterwards is decided by TFGraphMapper - confirm there
     * @return The graph returned by {@code TFGraphMapper.importGraph}
     */
    public static SameDiff importFrozenTF(InputStream graph) {
        return TFGraphMapper.importGraph(graph);
    }

    /**
     * Generate a new, distinct op name of the form <base>_#.
     *

* Applies name scope if active. * * @param base The base name to use * @param force Whether to force the result name to be the same as base. */ public String getOpName(String base, boolean force) { base = nameWithScope(base); if (force && ops.containsKey(base)) throw new IllegalArgumentException("Op with name \"" + base + "\" already exists"); else if (force) return base; int start = 1; // if we already have a name like "op_2", start from trying "op_3" if (base.contains("_") && base.matches(".*_\\d+")) { // extract number used to generate base Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base); // extract argIndex used to generate base if (num.find()) { start = Integer.parseInt(num.group(2)); base = num.group(1); } } String name = base; for (int i = start; true; i++) { // ensure that there are no variables that look like they are outputs of this op boolean varWithName = false; for (String varName : variables.keySet()) if (varName.startsWith(name + ":") || varName.equals(name)) varWithName = true; if (!ops.containsKey(name) && !varWithName) break; name = base + "_" + i; } return name; } /** * See {@link #getOpName(String, boolean)} * force is false */ public String getOpName(String base) { return getOpName(base, false); } /** * Generate a new, distinct variable name of the form <base>_#[:#]. *

* Applies name scopes if active. * * @param base The base of the name. * @param argIndex The argument index, used in the ":#". A value of 0 (or negative) does not include the ":#" part. * @param existingOp Whether to generate an distinct operation name from base (if false), or just use base (if true). */ public String generateNewVarName(String base, int argIndex, boolean existingOp) { base = nameWithScope(base); if (argIndex > 0 && base.contains(":")) { Matcher num = Pattern.compile("(.*):(\\d+)").matcher(base); // extract argIndex used to generate base if (num.find()) { argIndex = Integer.parseInt(num.group(2)) + 1; base = num.group(1); } } if (!existingOp) base = getOpName(base); if (argIndex > 0) base += ":" + argIndex; if (variables.containsKey(base)) throw new IllegalArgumentException("Variable with name \"" + base + "\" already exists"); return base; } /** * See {@link #generateNewVarName(String, int, boolean)} * existingOp is true. */ public String generateNewVarName(String base, int argIndex) { return generateNewVarName(base, argIndex, true); } /** * Returns an unused variable name of the format <base>_#. * * Intended to be used for custom variables (like weights), arguments and op outputs should use {@link #generateNewVarName(String, int)}. 
*/
    public String generateDistinctCustomVariableName(String base) {
        // Try base first, then base_1, base_2, ... and return the first name no variable uses
        String candidate = base;
        for (int suffix = 1; variables.containsKey(candidate); suffix++) {
            candidate = base + "_" + suffix;
        }
        return candidate;
    }

    /**
     * Short description of this instance: the number of variables and the number of ops.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("SameDiff(nVars=");
        text.append(variables.size());
        text.append(",nOps=");
        text.append(ops.size());
        text.append(")");
        return text.toString();
    }

    /**
     * Builds an If block with no output rename and the default block name.
     * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
     */
    public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond,
                             @NonNull SameDiffNoArgSingleLambda trueBody,
                             @NonNull SameDiffNoArgSingleLambda falseBody) {
        final String noOutputName = null;
        final String defaultIfName = null;
        return ifCond(noOutputName, defaultIfName, cond, trueBody, falseBody);
    }

    /**
     * Builds an If block with the given block name and no output rename.
     * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
     */
    public SDVariable ifCond(String ifName,
                             @NonNull SameDiffNoArgSingleLambda cond,
                             @NonNull SameDiffNoArgSingleLambda trueBody,
                             @NonNull SameDiffNoArgSingleLambda falseBody) {
        final String noOutputName = null;
        return ifCond(noOutputName, ifName, cond, trueBody, falseBody);
    }

    /**
     * Constructs a If statement using the tensorflow style control flow operations (Switch and Merge)
     *
     * If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody
     *
     * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used to evaluate.
     *
     * See Tensorflow Control Flow Implementation
     *
     * @param outputName Name to give the output variable. If null, doesn't rename
     * @param ifName     The name of the if block.
If null, uses "if" * @param cond A lambda evaluating to the if condition * @param trueBody A lambda to be executed if cond is true (the if block) * @param falseBody A lambda to be executed if cond is false (the else block) * @return The value of trueBody if cond is true, or falseBody if it isn't */ public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ ifName = newBlockName(ifName == null ? "if" : ifName); NameScope ifScope = sd.withNameScope(ifName); NameScope condScope = withNameScope("cond"); final SDVariable pred = cond.define(this); condScope.close(); if (pred.dataType() != DataType.BOOL) { //cleanup partially added block for(SDVariable v : getVariablesInScope(ifScope)) this.getVariables().remove(v.name()); for(SameDiffOp op : this.getOpsInScope(ifScope)) { for(String in : op.getInputsToOp()){ this.removeArgFromOp(in, op.getOp()); } this.getOps().remove(op.getName()); } throw new IllegalStateException("Can not use " + pred.name() + " as the condition of an If statement, the condition must be a boolean."); } final Map switches = new HashMap<>(); final Set declared = Sets.newHashSet(this.variableMap().keySet()); this.addArgumentInterceptor(new ArgumentInterceptor() { @Override public SDVariable intercept(SDVariable argument) { // if its declared in the if, we don't care acout it if(!declared.contains(argument.name())) return argument; // if we've already added a switch, move on if(switches.containsKey(argument.name())) return switches.get(argument.name())[1]; SDVariable[] s = switchOp(argument, pred); switches.put(argument.name(), s); return s[1]; } }); NameScope trueScope = this.withNameScope("trueBody"); SDVariable trueOut = trueBody.define(this); this.removeArgumentInterceptor(); if(declared.contains(trueOut.name())) { SDVariable[] s = switchOp(trueOut, pred); switches.put(trueOut.name(), s); trueOut = s[1]; } 
trueScope.close(); final Set declared2 = Sets.newHashSet(variableMap().keySet()); sd.addArgumentInterceptor(new ArgumentInterceptor() { @Override public SDVariable intercept(SDVariable argument) { // if its declared in the if, we don't care acout it if(!declared2.contains(argument.name())) return argument; // if we've already added a switch, move on if(switches.containsKey(argument.name())) return switches.get(argument.name())[0]; SDVariable[] s = switchOp(argument, pred); switches.put(argument.name(), s); return s[0]; } }); NameScope falseScope = this.withNameScope("falseBody"); SDVariable falseOut = falseBody.define(this); this.removeArgumentInterceptor(); if(declared2.contains(falseOut.name())) { SDVariable[] s = switchOp(falseOut, pred); switches.put(falseOut.name(), s); falseOut = s[0]; } falseScope.close(); SDVariable output = merge(trueOut, falseOut); ifScope.close(); return updateVariableNameAndReference(output, outputName); } /** * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} */ public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ return whileLoop(null, null, loopVars, cond, body); } /** * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} */ public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ return whileLoop(null, loopName, loopVars, cond, body); } /** * Constructs a While loop using the tensorflow style control flow operations (Switch, Merge, Enter, Exit, and NextIteration) * * Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false * * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used for further iterations. 
* * See Tensorflow Control Flow Implementation * * @param outputNames Names to give the output variables. If null, doesn't rename * @param loopName The name of the loop block and frame (must be unique). If null, uses "if" * @param loopVars Loop variables' inputs * @param cond A lambda evaluating to the loop condition * @param body A lambda doing the loop operation and returning the new loop variable values * @return The values of the loop variables once condition is false */ public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ final String frameName = this.newBlockName(loopName == null ? "while" : loopName); NameScope loopScope = this.withNameScope(frameName); //SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0); SDVariable[] entered = new SDVariable[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ entered[i] = new Enter(this, frameName, loopVars[i]).outputVariable(); } SDVariable[] merged = new SDVariable[loopVars.length]; Merge[] mergeOps = new Merge[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ // the second arg will later be replaced with the output of NextIteration // but that isn't available yet (and can't be, as it depends on this) mergeOps[i] = new Merge(this, entered[i], entered[i]); merged[i] = mergeOps[i].outputVariable(); } //Merge counterMerge = new Merge(SD, counter, counter); //counter = counterMerge.outputVariable(); NameScope condScope = this.withNameScope("cond"); SDVariable cond_result = cond.define(this, merged); condScope.close(); if (cond_result.dataType() != DataType.BOOL) throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean."); final Set alreadyEntered = Sets.newHashSet(); SDVariable[] trueSwitches = new SDVariable[loopVars.length]; SDVariable[] exits = new SDVariable[loopVars.length]; 
for(int i = 0 ; i < loopVars.length ; i++){ SDVariable[] s = switchOp(merged[i], cond_result); trueSwitches[i] = s[1]; alreadyEntered.add(s[1].name()); exits[i] = new Exit(this, s[0]).outputVariable(); } final Set declared = Sets.newHashSet(this.variableMap().keySet()); final Map done = new HashMap<>(); final SameDiff sd = this; this.addArgumentInterceptor(new ArgumentInterceptor() { @Override public SDVariable intercept(SDVariable argument) { if(!declared.contains(argument.name())) return argument; if(alreadyEntered.contains(argument.name())) return argument; if(done.containsKey(argument.name())) return done.get(argument.name()); SDVariable e = new Enter(sd, frameName, argument, true).outputVariable(); done.put(argument.name(), e); return e; } }); NameScope bodyScope = this.withNameScope("body"); SDVariable[] outs = body.define(this, trueSwitches); bodyScope.close(); this.removeArgumentInterceptor(); //counter.add(1); for(int i = 0 ; i < loopVars.length ; i++){ SDVariable n = new NextIteration(this, outs[i]).outputVariable(); mergeOps[i].replaceArg(1,n); } //counterMerge.replaceArg(1, counter); loopScope.close(); return updateVariableNamesAndReferences(exits, outputNames); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy