/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff;
import com.google.flatbuffers.FlatBufferBuilder;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.*;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
import org.nd4j.autodiff.samediff.config.OutputConfig;
import org.nd4j.autodiff.samediff.internal.*;
import org.nd4j.autodiff.samediff.ops.*;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.graph.*;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.common.util.ND4JFileUtils;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Sets;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.GraphDef;
import java.io.*;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.nd4j.autodiff.util.SameDiffUtils.stackOutputs;
/**
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
*
* You define a graph symbolically.
*
* That graph accumulates operations.
*
* In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)}
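*
* A minimal usage sketch (variable names, shapes, and data here are illustrative only):
*
* {@code
* SameDiff sd = SameDiff.create();
* SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
* SDVariable w = sd.var("w", DataType.FLOAT, 4, 3);
* SDVariable out = in.mmul(w);
*
* Map<String, INDArray> phMap = Collections.singletonMap("input", Nd4j.rand(DataType.FLOAT, 2, 4));
* INDArray result = sd.output(phMap, out.name()).get(out.name());
* }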
*/
@Slf4j
public class SameDiff extends SDBaseOps {
protected static final String GRAD_FN_KEY = "grad";
//Fields for graph structure and execution
@Getter
private final Map<String, Variable> variables = new LinkedHashMap<>(); //Use linked hash map to guarantee iteration order based on order they were added. Used in inputs() and flatbuffers serde
@Getter
private final Map<String, SameDiffOp> ops = new LinkedHashMap<>();
@Getter
private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<>(); //Key: thread ID
private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true);
private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true);
private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them
private final List<String> lossVariables = new ArrayList<>();
private final List<Listener> listeners = new ArrayList<>();
private final List<NameScope> nameScopes = new ArrayList<>(); //Used as a stack
private List<String> outputs; //Names of the output variables, set by the user.
///////////////////////////////////////
//Fields related to training
@Getter
private TrainingConfig trainingConfig; //Configuration for training. Must be set for training/evaluation, but not for other operations
@Getter
private boolean initializedTraining; //True if training setup has been done
@Getter
private Map<String, GradientUpdater> updaterMap; //GradientUpdater instance for each trainable parameter
////////////////////////////////////////
// private DifferentialFunctionFactory functionFactory;
// counter for auto-naming variables
private int variableId = 0;
////////////////////////////////////////
/**
* Op creator object for math operations
*/
public final SDMath math = new SDMath(this);
/**
* Op creator object for random number generation operations
*/
public final SDRandom random = new SDRandom(this);
/**
* Op creator object for general neural network operations
*/
public final SDNN nn = new SDNN(this);
/**
* Op creator object for convolutional neural network operations
*/
public final SDCNN cnn = new SDCNN(this);
/**
* Op creator object for recurrent neural network operations
*/
public final SDRNN rnn = new SDRNN(this);
/**
* Op creator object for loss function operations
*/
public final SDLoss loss = new SDLoss(this);
/**
* Op creator object for image operations
*/
public final SDImage image = new SDImage(this);
/**
* Op creator object for bitwise operations
*/
public final SDBitwise bitwise = new SDBitwise(this);
/**
* Op creator object for linalg operations
*/
public final SDLinalg linalg = new SDLinalg(this);
/**
* Op creator object for math operations
*/
public SDMath math() {
return math;
}
/**
* Op creator object for random number generation operations
*/
public SDRandom random() {
return random;
}
/**
* Op creator object for general neural network operations
*/
public SDNN nn() {
return nn;
}
/**
* Op creator object for convolutional neural network operations
*/
public SDCNN cnn() {
return cnn;
}
/**
* Op creator object for recurrent neural network operations
*/
public SDRNN rnn() {
return rnn;
}
/**
* Op creator object for loss function operations
*/
public SDLoss loss() {
return loss;
}
/**
* Op creator object for image operations
*/
public SDImage image() {
return image;
}
/**
* Op creator object for bitwise operations
*/
public SDBitwise bitwise(){
return bitwise;
}
/**
* Op creator object for linalg operations
*/
public SDLinalg linalg(){
return linalg;
}
private Map<String, SameDiff> sameDiffFunctionInstances;
private Table<String, String, String> fieldVariableResolutionMapping;
// flag, shows if graph was already registered with libnd4j
private transient AtomicBoolean wasRegistered = new AtomicBoolean(false);
//debug mode variables
@Getter
private boolean debugMode;
@Getter
private Stack<ArgumentInterceptor> argumentInterceptors = new Stack<>();
@Getter
private Set<ArgumentInterceptor> pausedArgumentInterceptors = new HashSet<>();
private Set<String> blockNames = new HashSet<>();
@Getter
@Setter
boolean logExecution = true;
@Getter
private SameDiff parent;
@Getter
private SameDiff child;
/**
* Clears debugging state and disables debug mode.
*/
public SameDiff disableDebugging() {
debugMode = false;
return this;
}
/**
* Enables debug mode (graph tracing) for this SameDiff instance.
*/
public SameDiff enableDebugMode() {
debugMode = true;
return this;
}
/**
* Set the current SameDiff-wide {@link Listener} instances.
*
* Note that this will overwrite the current listener list.
* If you want to use additional listeners for a single operation,
* use the listener arguments in those methods (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}).
*
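* For example, a minimal sketch (ScoreListener is one of the built-in listener implementations; the reporting frequency of 20 iterations is arbitrary):
* {@code
* sd.setListeners(new ScoreListener(20));
* }
*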
* @param listeners Listeners
*/
public void setListeners(Listener... listeners) {
this.listeners.clear();
addListeners(listeners);
}
/**
* See {@link #setListeners(Listener...)}.
*/
public void setListeners(Collection<? extends Listener> listeners) {
this.listeners.clear();
addListeners(listeners);
}
/**
* Add SameDiff-wide {@link Listener} instances.
*
* If you want to use additional listeners for a single operation,
* use the listener arguments in those methods (e.g. {@link #fit()} and {@link FitConfig#listeners(Listener...)}).
*
* @param listeners Listeners
*/
public void addListeners(Listener... listeners) {
addListeners(Arrays.asList(listeners));
}
/**
* See {@link #addListeners(Listener...)}.
*/
public void addListeners(Collection<? extends Listener> listeners) {
this.listeners.addAll(listeners);
}
/**
* Gets the current SameDiff-wide listeners.
*/
public List<Listener> getListeners() {
return listeners;
}
/**
* Set the array holders for variable and constant arrays
* NOTE: this is usually reserved for developers and internal use, and should not be needed by almost all users
* See {@link ArrayHolder} for more details
*
* @param variableArrayHolder Array holder for variable arrays
* @param constantArrayHolder Array holder for constant arrays
* @param initialize If true: transfer any arrays from the current array holders to the new/specified ones
*/
public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize){
if(initialize){
variableArrayHolder.initFrom(this.variablesArrays);
constantArrayHolder.initFrom(this.constantArrays);
}
this.variablesArrays = variableArrayHolder;
this.constantArrays = constantArrayHolder;
}
/**
* @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details.
*/
public String currentNameScope() {
if (nameScopes.isEmpty())
return null;
//Would use String.join but that is Java 8+
StringBuilder sb = new StringBuilder();
boolean first = true;
for (NameScope ns : nameScopes) {
if (!first) {
sb.append("/");
}
sb.append(ns.getName());
first = false;
}
return sb.toString();
}
/**
* @return The name with the current name scope (if any) appended. See {@link #withNameScope(String)}
*/
protected String nameWithScope(String name) {
String scope = currentNameScope();
if (scope == null) {
return name;
}
if (!name.startsWith(scope + "/"))
return scope + "/" + name;
else
return name;
}
//Intentionally package private
void addNameScope(NameScope nameScope) {
nameScopes.add(nameScope);
}
//Intentionally package private
void closeNameScope(NameScope nameScope) {
//Check that the name scope is closed correctly/in order
Preconditions.checkState(!nameScopes.isEmpty(), "Cannot close name scope: no name scopes are currently defined");
Preconditions.checkState(nameScopes.get(nameScopes.size() - 1).equals(nameScope),
"Cannot close name scope %s: Name scopes must be closed in order. Current name scopes: \"%s\"", nameScope, currentNameScope());
nameScopes.remove(nameScopes.size() - 1);
}
/**
* Create a name scope. Name scopes append a prefix to the names of any variables and ops created while they are open.
*
* {@code
* SameDiff sd = SameDiff.create();
* SDVariable x = sd.var("x", DataType.FLOAT, 5);
* SDVariable y;
* try(NameScope ns = sd.withNameScope("myScope")){
* y = sd.var("y", DataType.FLOAT, 5);
* }
* SDVariable z = sd.var("z", DataType.FLOAT, 5);
*
* String xName = x.name(); //RESULT: "x"
* String yName = y.name(); //RESULT: "myScope/y"
* String zName = z.name(); //RESULT: "z"
* }
*
*
* Note that name scopes can also be nested:
*
* {@code
* SameDiff sd = SameDiff.create();
* SDVariable x;
* try(NameScope ns = sd.withNameScope("first")){
* try(NameScope ns2 = sd.withNameScope("second")){
* x = sd.var("x", DataType.FLOAT, 5);
* }
* }
* String xName = x.name(); //RESULT: "first/second/x"
* }
*
*
* @param nameScope Name of the name scope to open/create
* @return The NameScope object
*/
public NameScope withNameScope(String nameScope) {
NameScope ns = new NameScope(this, nameScope);
addNameScope(ns);
return ns;
}
/**
* Gets all operations in a given name scope.
*/
public List<SameDiffOp> getOpsInScope(NameScope scope) {
ArrayList<SameDiffOp> ops = new ArrayList<>();
for (SameDiffOp v : this.ops.values()) {
if (v.getName().startsWith(scope.getName()))
ops.add(v);
}
return ops;
}
/**
* See {@link #getOpsInScope(NameScope)}.
*/
public List<SameDiffOp> getOpsInScope(String scope){
return getOpsInScope(new NameScope(this, scope));
}
/**
* Gets all variables in a given name scope.
*/
public List<SDVariable> getVariablesInScope(NameScope scope) {
ArrayList<SDVariable> vars = new ArrayList<>();
for (SDVariable v : variables()) {
if (v.name().startsWith(scope.getName()))
vars.add(v);
}
return vars;
}
/**
* See {@link #getVariablesInScope(NameScope)}.
*/
public List<SDVariable> getVariablesInScope(String scope){
return getVariablesInScope(new NameScope(this, scope));
}
/**
* Copy this instance's variables and ops onto the specified SameDiff instance.
*
* @param sameDiff The SameDiff instance to copy this graph's variables and ops onto
* @return The last variable of the target SameDiff instance after copying
*/
public SDVariable invokeGraphOn(SameDiff sameDiff) {
//map the new vertices on to the old ones
Map<Integer, Integer> thisVertexIdToNew = new HashMap<>();
int idx = 1;
for (val var : variables()) {
SDVariable clone = var.clone(this);
SDVariable newVar = sameDiff.var(clone);
if (var.getVariableType() != VariableType.ARRAY && var.getArr() != null ) { //ARRAY type = "activations" - are overwritten anyway
sameDiff.associateArrayWithVariable(var.getArr(), newVar);
}
thisVertexIdToNew.put(idx, idx);
clone.setSameDiff(sameDiff);
idx++;
}
Map<String, Integer> reverseMap = new HashMap<>();
int count = 0;
for( Variable v : variables.values()){
reverseMap.put(v.getName(), count++);
}
val newFunctions = new LinkedHashMap<String, DifferentialFunction>();
for (SameDiffOp op : ops.values()) {
DifferentialFunction function = op.getOp();
//Clone the op
DifferentialFunction clone = FlatBuffersMapper.cloneViaSerialize(this, function, reverseMap);
clone.setSameDiff(sameDiff);
clone.setOwnName(function.getOwnName());
if (sameDiff.opExists(function.getOwnName()))
sameDiff.putOpForId(function.getOwnName(), function);
newFunctions.put(function.getOwnName(), clone);
val argsForFunction = function.args();
val outputsForFunction = function.outputVariables();
//note that these have the same variable names
sameDiff.addArgsFor(argsForFunction, clone);
sameDiff.addOutgoingFor(outputsForFunction, function);
for (val arg : clone.args()) {
arg.setSameDiff(sameDiff);
}
for (val output : clone.outputVariables()) {
output.setSameDiff(sameDiff);
}
sameDiff.ops.put(function.getOwnName(), op);
}
return sameDiff.variables().get(sameDiff.variables().size() - 1);
}
/**
* Returns true if the given function id exists
*
* @param id the function id to test for
* @return true if the function id exists, false otherwise
*/
public boolean opExists(String id) {
return ops.containsKey(id);
}
/**
* Get the differential function (if any) that this variable is the output for
*
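* For example, a sketch (the variable {@code x} is assumed to already exist in this graph):
* {@code
* SDVariable y = sd.math.square(x);
* DifferentialFunction squareOp = sd.getVariableOutputOp(y.name()); //The op that produces y
* }
*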
* @param variableName Name of the variable
* @return The differential function that this variable is an output of, or null if it is not the output of a function
*/
public DifferentialFunction getVariableOutputOp(String variableName) {
Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
if (variables.get(variableName).getOutputOfOp() == null)
return null;
return ops.get(variables.get(variableName).getOutputOfOp()).getOp();
}
/**
* Get the function by the {@link DifferentialFunction#getOwnName()}
*
* @param id the id of the function
* @return the function for the given id if it exists
*/
public DifferentialFunction getOpById(@NonNull String id) {
if (!ops.containsKey(id)) {
throw new ND4JIllegalStateException("No function with id " + id + " found!");
}
return ops.get(id).getOp();
}
/**
* Put the function for the given id
*
* @param id the id of the function
* @param function the function
*/
public void putOpForId(String id, DifferentialFunction function) {
if (ops.containsKey(id) && ops.get(id).getOp() == null) {
throw new ND4JIllegalStateException("Function by id already exists!");
}
if (!ops.containsKey(id)) {
ops.put(id, SameDiffOp.builder().name(id).op(function).build());
}
}
/**
* Returns the name(s) of the inputs for the given function
*
* @param function the function to get the inputs for
* @return the input ids for a given function
*/
public String[] getInputsForOp(@NonNull DifferentialFunction function) {
if (!ops.containsKey(function.getOwnName()))
throw new ND4JIllegalStateException("Unknown function instance id found: \"" + function.getOwnName() + "\"");
List<String> inputs = ops.get(function.getOwnName()).getInputsToOp();
return inputs == null ? null : inputs.toArray(new String[inputs.size()]);
}
/**
* Returns the name(s) of the outputs for the given function
*
* @param function the function to get the outputs for
* @return the output ids for a given function
*/
public String[] getOutputsForOp(DifferentialFunction function) {
if (!ops.containsKey(function.getOwnName()))
throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
List<String> outputs = ops.get(function.getOwnName()).getOutputsOfOp();
return outputs == null ? null : outputs.toArray(new String[outputs.size()]);
}
/**
* Get the output variable(s) for the specified differential function
*
* @param function the function reference to get the output variable(s) for
* @return the output variables for the given function
*/
public SDVariable[] getOutputVariablesForOp(DifferentialFunction function) {
val outputs = getOutputsForOp(function);
if (outputs == null) {
throw new ND4JIllegalStateException("No outputs found for function " + function);
}
val vars = new SDVariable[outputs.length];
for (int i = 0; i < outputs.length; i++) {
vars[i] = getVariable(outputs[i]);
}
return vars;
}
/**
* Get the input variable(s) for the specified differential function
*
* @param function the function reference to get the input variable(s) for
* @return the input variables for the given function
*/
public SDVariable[] getInputVariablesForOp(DifferentialFunction function) {
val inputs = getInputsForOp(function);
if (inputs == null) {
throw new ND4JIllegalStateException("No inputs found for function " + function);
}
val vars = new SDVariable[inputs.length];
for (int i = 0; i < inputs.length; i++) {
vars[i] = getVariable(inputs[i]);
if (vars[i] == null) {
throw new ND4JIllegalStateException("Found null variable at index " + i);
}
}
return vars;
}
/**
* Set the stored {@link INDArray} for a variable. Only works if the variable is of type
* {@link VariableType#CONSTANT}, {@link VariableType#PLACEHOLDER}, or {@link VariableType#VARIABLE}.
*/
public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) {
Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName);
SDVariable v = getVariable(varName);
if (v.isConstant()) {
constantArrays.setArray(varName, arr);
} else if (v.getVariableType() == VariableType.VARIABLE) {
variablesArrays.setArray(varName, arr);
} else if (v.isPlaceHolder()) {
long tid = Thread.currentThread().getId();
if (!placeholdersPerThread.containsKey(tid)) {
placeholdersPerThread.put(tid, new HashMap<String, INDArray>());
}
placeholdersPerThread.get(tid).put(varName, arr);
} else {
throw new UnsupportedOperationException("Cannot set variable of type " + v.getVariableType() + " using this method");
}
}
/**
* Returns true if an {@link INDArray} is already associated with the given variable name.
*
* @param varName the variable name
* @return true if a variable with the given name exists and it has an INDArray associated with it
*/
public boolean arrayAlreadyExistsForVarName(String varName) {
SDVariable var = getVariable(varName);
switch (var.getVariableType()) {
case VARIABLE:
return variablesArrays.hasArray(varName);
case ARRAY:
long tid = Thread.currentThread().getId();
return sessions.containsKey(tid) && sessions.get(tid).contains(varName, InferenceSession.OUTER_FRAME, 0, null);
case CONSTANT:
return constantArrays.hasArray(varName);
case PLACEHOLDER:
return placeholdersPerThread.containsKey(Thread.currentThread().getId()) &&
placeholdersPerThread.get(Thread.currentThread().getId()).containsKey(varName);
default:
throw new RuntimeException("Unknown variable type: " + var.getVariableType());
}
}
/**
* Get an {@link INDArray} for a given vertex id, or null if none exists
*
* @param varName Variable name to get the array for
* @return Array, or null if none exists
*/
public INDArray getArrForVarName(@NonNull String varName) {
Preconditions.checkState(variables.containsKey(varName), "No variable found with name \"%s\"", varName);
SDVariable v = variables.get(varName).getVariable();
switch (v.getVariableType()) {
case VARIABLE:
return variablesArrays.getArray(varName);
case CONSTANT:
if (!constantArrays.hasArray(varName))
return null;
return constantArrays.getArray(varName);
case ARRAY:
//Only stored in inference session...
InferenceSession s = sessions.get(Thread.currentThread().getId());
if (s == null)
return null;
return s.get(varName, InferenceSession.OUTER_FRAME, 0, null, false);
case PLACEHOLDER:
long tid = Thread.currentThread().getId();
if (placeholdersPerThread.get(tid) == null || !placeholdersPerThread.get(tid).containsKey(varName))
return null;
return placeholdersPerThread.get(tid).get(varName);
default:
throw new RuntimeException("Unknown variable type: " + v.getVariableType());
}
}
/**
* Associate the array with the given variable.
*
* @param arr the array to associate with the variable
* @param variable the name of the variable to associate the array with
*/
public void associateArrayWithVariable(INDArray arr, @NonNull String variable) {
Preconditions.checkState(variables.containsKey(variable), "Cannot associate array with variable \"%s\": " +
"variable \"%s\" does not exist in this SameDiff instance", variable, variable);
associateArrayWithVariable(arr, this.getVariable(variable));
}
/**
* Associate the array with the given variable.
*
* @param arr the array to associate with the variable
* @param variable the variable to associate the array with
*/
public void associateArrayWithVariable(INDArray arr, SDVariable variable) {
if (variable == null) {
throw new ND4JIllegalArgumentException("Variable must not be null!");
}
if (arr == null) {
throw new ND4JIllegalArgumentException("Array must not be null");
}
if (variable.dataType() != arr.dataType())
arr = arr.castTo(variable.dataType());
Preconditions.checkState(variable.dataType() == arr.dataType(), "Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable",
variable.name(), variable.dataType(), arr.dataType());
if (sessions.get(Thread.currentThread().getId()) == null) {
sessions.put(Thread.currentThread().getId(), new InferenceSession(this));
}
if (arr.isAttached()) {
arr = arr.detach();
}
switch (variable.getVariableType()) {
case VARIABLE:
variablesArrays.setArray(variable.name(), arr);
break;
case CONSTANT:
constantArrays.setArray(variable.name(), arr);
break;
case ARRAY:
throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" +
" this type of variable are calculated during execution");
case PLACEHOLDER:
//Validate placeholder shapes:
long[] phShape = variable.placeholderShape();
Preconditions.checkState(phShape == null || Shape.shapeMatchesPlaceholder(phShape, arr.shape()),
"Invalid array shape: cannot associate an array with shape %ndShape with a placeholder of shape %s:" +
"shape is wrong rank or does not match on one or more dimensions", arr, phShape);
long tid = Thread.currentThread().getId();
if (!placeholdersPerThread.containsKey(tid)) {
placeholdersPerThread.put(tid, new HashMap<String, INDArray>());
}
placeholdersPerThread.get(tid).put(variable.name(), arr);
break;
default:
throw new IllegalStateException("Unknown variable type: " + variable.getVariableType());
}
//putOrUpdateShapeForVarName(variable.name(), arr.shape(), true);
//Also update nested SameDiff instances (such as gradient function)
if (sameDiffFunctionInstances != null && sameDiffFunctionInstances.size() > 0) {
for (Map.Entry<String, SameDiff> e : sameDiffFunctionInstances.entrySet()) {
SameDiff sd = e.getValue();
SDVariable v = sd.getVariable(variable.name());
if (v != null) {
sd.associateArrayWithVariable(arr, v);
}
}
}
}
/**
* Update the constant or variable type SDVariable with the values from the specified
* array. Note that unlike {@link #associateArrayWithVariable(INDArray, String)} this method will take the
* values from the argument array and assign it to the current array.
* The actual array (INDArray object) will not be stored or otherwise used within the SameDiff instance.
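* For example, a sketch of updating a VARIABLE's values in place (names and shapes are illustrative):
* {@code
* SDVariable w = sd.var("w", DataType.FLOAT, 3, 4);
* sd.assignArray(Nd4j.ones(DataType.FLOAT, 3, 4), w); //w's existing array now holds all ones
* }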
* @param arr Array values to set
* @param variable Variable to update the array of. Must be CONSTANT or VARIABLE type SDVariable
*/
public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable){
Preconditions.checkState(variable.getVariableType() == VariableType.VARIABLE || variable.getVariableType() == VariableType.CONSTANT,
"assignArray method can only be used with VARIBLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.name(), variable.getVariableType());
//DeviceLocal doesn't work with views
if(arr.isView())
arr = arr.dup();
if(variable.getVariableType() == VariableType.VARIABLE ){
variablesArrays.setArray(variable.name(), arr);
} else {
constantArrays.setArray(variable.name(), arr);
}
}
/**
* Associate a {@link SameDiff} namespace as a sub function.
*
* @param name the opName of the function
* @param nameSpace the namespace
*/
public void putSubFunction(String name, SameDiff nameSpace) {
if (sameDiffFunctionInstances.containsKey(name) && sameDiffFunctionInstances.get(name) != nameSpace) {
throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
}
sameDiffFunctionInstances.put(name, nameSpace);
}
/**
* Return a copy of the internal variable map
*
* @return Map of variables by name
*/
public Map<String, SDVariable> variableMap() {
Map<String, SDVariable> ret = new LinkedHashMap<>();
for (Variable v : variables.values()) {
ret.put(v.getName(), v.getVariable());
}
return ret;
}
/**
* The set of defined SameDiff function names. SameDiff function instances should not be confused
* with DifferentialFunction ops; an example of a SameDiff function instance is the gradient "grad" function
*
* @return Set of defined SameDiff function instance names
*/
public Collection<String> definedFunctionNames() {
return this.sameDiffFunctionInstances.keySet();
}
private SameDiff() {
super(null);
super.sd = this;
sameDiffFunctionInstances = new LinkedHashMap<>();
fieldVariableResolutionMapping = HashBasedTable.create();
}
/**
* Attempts to insert the {@link DifferentialFunction} reference in to this {@link SameDiff} instance.
* If the given array field with the given index already exists, it will do a reference check to ensure that the 2
* array fields are the same. If not, an exception is thrown.
* If the instances are the same (by semantics, not reference) then it will just return the original instance.
* This is to ensure that instances that are created are unique and reference checked.
*
* @param function the array field to attempt to create
* @return Original instance
*/
public <X extends SDVariable> X setupFunction(X function) {
Preconditions.checkNotNull(function, "Passed in function must not be null!");
if (function instanceof SDVariable) {
if (function.getSameDiff() != this) {
function.setSameDiff(this);
}
return function;
}
return function;
}
/**
* Adds outgoing arguments to the graph for the specified DifferentialFunction
* Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared.
*
* @param variables Variables - arguments for the specified differential function
* @param function Differential function
*/
public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) {
String[] varNames = new String[variables.length];
for (int i = 0; i < varNames.length; i++) {
varNames[i] = variables[i].name();
}
addOutgoingFor(varNames, function);
}
/**
* Adds outgoing arguments to the graph for the specified DifferentialFunction
* Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared.
*
* @param varNames Name of the variables that are outputs of the specified differential function
* @param function Differential function
*/
public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
if (function.getOwnName() == null)
throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
if (ops.get(function.getOwnName()).getOutputsOfOp() != null && !ops.get(function.getOwnName()).getOutputsOfOp().isEmpty()) {
throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
}
if (varNames == null)
throw new ND4JIllegalStateException("Var names can not be null!");
for (int i = 0; i < varNames.length; i++) {
if (varNames[i] == null)
throw new ND4JIllegalStateException("Variable name elements can not be null!");
}
ops.get(function.getOwnName()).setOutputsOfOp(Arrays.asList(varNames));
for (String resultName : varNames) {
variables.get(resultName).setOutputOfOp(function.getOwnName());
}
}
/**
* Add a new argument interceptor to the interceptor stack
*
* For internal use only.
*
* When an op is added with arguments, the most recent argument interceptor is called on it.
* If ops are added in that interceptor, the next most recent will be called on their args, and so on.
*
* @param interceptor the argument interceptor to add
*/
public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
argumentInterceptors.push(interceptor);
}
private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor interceptor) {
return pausedArgumentInterceptors.contains(interceptor);
}
private ArgumentInterceptor getArgumentInterceptorToUse() {
if (argumentInterceptors.isEmpty())
return null;
ArgumentInterceptor use = argumentInterceptors.peek();
int i = 1;
while (isArgumentInterceptorPaused(use)) {
if (argumentInterceptors.size() - i < 0)
return null;
use = argumentInterceptors.elementAt(argumentInterceptors.size() - i);
i++;
}
return use;
}
/**
* Remove the top (most recently added) argument interceptor
*
* For internal use only.
*/
public void removeArgumentInterceptor() {
if (!argumentInterceptors.isEmpty())
argumentInterceptors.pop();
}
/**
* Pause the top (most recently added) argument interceptor
*
* For internal use only.
*/
public void pauseArgumentInterceptor() {
pausedArgumentInterceptors.add(argumentInterceptors.peek());
}
/**
* Pause the given argument interceptor
*
* For internal use only.
*
* @param interceptor the argument interceptor to pause
*/
public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
pausedArgumentInterceptors.add(interceptor);
}
/**
* Unpause the top (most recently added) argument interceptor
*
* For internal use only.
*/
public void unpauseArgumentInterceptor() {
pausedArgumentInterceptors.remove(argumentInterceptors.peek());
}
/**
* Unpause the given argument interceptor
*
* For internal use only.
*
* @param interceptor the argument interceptor to unpause
*/
public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
pausedArgumentInterceptors.remove(interceptor);
}
/**
* Adds incoming arguments for the specified differential function to the graph
*
* @param variables Name of the variables that are arguments (inputs) to the specified function
* @param function Function
*/
public void addArgsFor(String[] variables, DifferentialFunction function) {
ArgumentInterceptor interceptor = getArgumentInterceptorToUse();
if (interceptor != null) {
pauseArgumentInterceptor(interceptor);
for (int i = 0; i < variables.length; i++) {
variables[i] = interceptor.intercept(getVariable(variables[i])).name();
}
unpauseArgumentInterceptor(interceptor);
}
if (function.getOwnName() == null)
throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
//Add function if it doesn't exist
//TODO could "not existing" be a bug sometimes?
if (!ops.containsKey(function.getOwnName())) {
ops.put(function.getOwnName(), SameDiffOp.builder().name(function.getOwnName()).op(function).build());
}
//Update variable 'inputs to op' accounting for repeated inputs (like y = x+x)
ops.get(function.getOwnName()).setInputsToOp(Arrays.asList(variables)); //Duplicate variables OK/required here
for (String variableName : variables) {
List<String> funcs = this.variables.get(variableName).getInputsForOp();
if (funcs == null) {
funcs = new ArrayList<>();
this.variables.get(variableName).setInputsForOp(funcs);
}
if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names.
funcs.add(function.getOwnName());
}
}
/**
* Adds incoming arguments for the specified differential function to the graph
*
* @param variables variables that are arguments (inputs) to the specified function
* @param function Function
*/
public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
String[] varNames = new String[variables.length];
for (int i = 0; i < varNames.length; i++) {
if (variables[i] == null)
throw new ND4JIllegalStateException("Found null variable at index " + i);
varNames[i] = variables[i].name();
}
addArgsFor(varNames, function);
}
/**
* Replaces the argument at i with newArg for function
* Does not use (or remove) ArgumentInterceptor stuff
*/
public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function) {
Preconditions.checkArgument(i < function.args().length, "Index out of range: function " +
function.getOwnName() + " only has " + function.args().length + " args but you are trying" +
"to replace the argument at " + i);
String oldName = function.arg(i).name();
String newName = newArg.name();
List<String> oldArgs = ops.get(function.getOwnName()).getInputsToOp();
oldArgs = new ArrayList<>(oldArgs);
oldArgs.set(i, newName);
ops.get(function.getOwnName()).setInputsToOp(oldArgs);
List<String> funcs = this.variables.get(newName).getInputsForOp();
if (funcs == null) {
funcs = new ArrayList<>();
this.variables.get(newName).setInputsForOp(funcs);
}
if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names.
funcs.add(function.getOwnName());
List<String> oldFuncs = this.variables.get(oldName).getInputsForOp();
if (oldFuncs != null) {
if (!ArrayUtils.contains(function.argNames(), oldName))
oldFuncs.remove(function.getOwnName());
}
}
/**
* Returns true if this function already has defined arguments
*
* @param function the function to check
* @return true if the function has args, false otherwise
*/
public boolean hasArgs(DifferentialFunction function) {
List<String> vertexIdArgs = ops.get(function.getOwnName()).getInputsToOp();
return vertexIdArgs != null && vertexIdArgs.size() > 0;
}
/**
* Clear the placeholder arrays from the SameDiff instance
*
* @param allThreads If true: clear the placeholders for all threads. False: clear only for current thread
*/
public void clearPlaceholders(boolean allThreads) {
if (allThreads) {
this.placeholdersPerThread.clear();
} else {
long tid = Thread.currentThread().getId();
this.placeholdersPerThread.remove(tid);
}
for (SameDiff sd : this.sameDiffFunctionInstances.values()) {
sd.clearPlaceholders(allThreads);
}
}
/**
* Clear the input arrays to each op.
* This is usually not required under normal SameDiff use.
*/
public void clearOpInputs() {
for (SameDiffOp op : ops.values()) {
if (op.getOp() instanceof Op) {
Op o = ((Op) op.getOp());
o.setX(null);
if (o.y() != null) {
o.setY(null);
}
} else if (op.getOp() instanceof DynamicCustomOp) {
DynamicCustomOp o = (DynamicCustomOp) op.getOp();
o.setInputArguments((INDArray[]) null);
}
}
for (SameDiff sd : this.sameDiffFunctionInstances.values()) {
sd.clearOpInputs();
}
}
/**
* Get an array of differential functions that have been defined for this SameDiff instance
*
* @return Array of differential functions
*/
public DifferentialFunction[] ops() {
List<DifferentialFunction> out = new ArrayList<>(ops.size());
for (SameDiffOp op : ops.values()) {
out.add(op.getOp());
}
return out.toArray(new DifferentialFunction[out.size()]);
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + (variables != null ? variables.hashCode() : 0);
return result;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass())
return false;
SameDiff sameDiff = (SameDiff) o;
boolean eqVars = variables.equals(sameDiff.variables);
boolean eqOps = ops.equals(sameDiff.ops);
return eqVars && eqOps;
}
/**
* Create a new (empty) SameDiff instance without any functions or variables
*
* @return New SameDiff instance
*/
public static SameDiff create() {
return new SameDiff();
}
/**
* Clone/duplicate the SameDiff instance, including arrays etc. The returned SameDiff instance should have no
* shared state with the original instance
*
* @return The cloned SameDiff instance
*/
public SameDiff dup() {
ByteBuffer bb = asFlatBuffers(true);
try {
return fromFlatBuffers(bb);
} catch (IOException e){
throw new RuntimeException(e);
}
}
/**
* Count the number of elements in all arrays, according to {@link SDVariable#getShape()}
*
* @return Number of array elements for all variables
*/
public long numElements() {
long ret = 0;
for (SDVariable variable : variables()) {
long[] shape = variable.getShape();
if (shape != null) {
ret += ArrayUtil.prod(shape);
}
}
return ret;
}
/**
* Returns the inputs (placeholders) for the SameDiff graph
*
* @return the inputs for this graph
*/
public List<String> inputs() {
List<String> out = new ArrayList<>();
for (String s : variables.keySet()) {
if (isPlaceHolder(s))
out.add(s);
}
return out;
}
/**
* Outputs are the names of the predictions of the network.
* Note that the outputs must be set using {@link #setOutputs(List)} first
*
* @return The outputs of the SameDiff instance, or null if no outputs have been set
*/
public List<String> outputs() {
return this.outputs;
}
/**
* See {@link #setOutputs(List)}
*/
public void setOutputs(String... outputs){
setOutputs(outputs == null ? null : Arrays.asList(outputs));
}
/**
* Set the outputs of the SameDiff instance.
* Outputs are the names of the variables that are the predictions of the neural network.
* Note that this is merely a convenience, and does not impact execution at all. Outputs can be retrieved (after
* setting here) using {@link #outputs()}
* @param outputs Outputs to set. Must be valid variable names in this SameDiff instance
*/
public void setOutputs(List<String> outputs){
if(outputs != null){
for(String s : outputs){
Preconditions.checkArgument(variables.containsKey(s), "Cannot set variable \"%s\" as an output: SameDiff instance does not contain a variable with this name", s);
}
}
this.outputs = outputs;
}
/**
* The list of all variables in the graph
*
* @return All variables in the graph
*/
public List<SDVariable> variables() {
return new ArrayList<>(variableMap().values());
}
/**
* Get the names of variables (if any) that have been marked as loss variables to be minimized.
* Variables can be marked as loss variables in a few different ways:
* (a) Losses are automatically added when creating loss functions via {@link #loss()}
* (b) Via {@link #setLossVariables(String...)}, @link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}
* (c) Via {@link TrainingConfig#setLossVariables(List)}
*/
public List<String> getLossVariables() {
return Collections.unmodifiableList(this.lossVariables);
}
/**
* Clear/remove any existing loss variables, and set the loss variables to the specified variable names.
* See {@link #addLossVariable(String)} for more details
*
* @param lossVariableNames Names of variables to be loss function variables
*/
public void setLossVariables(@NonNull String... lossVariableNames) {
this.lossVariables.clear();
for (String s : lossVariableNames) {
addLossVariable(s);
}
//After changing loss function variables, we (probably) need to recreate gradient function - as gradient
// function is defined with respect to specific loss function variables
sameDiffFunctionInstances.remove("grad");
}
/**
* See {@link #setLossVariables(String...)}
*/
public void setLossVariables(@NonNull SDVariable... lossVariables) {
String[] varNames = new String[lossVariables.length];
for (int i = 0; i < lossVariables.length; i++)
varNames[i] = lossVariables[i].name();
setLossVariables(varNames);
}
/**
* Mark the specified variable as a loss function variable. This means that this variable will be minimized via backprop during training.
* This will add the variable as a loss to any others - i.e., if multiple variables are marked as losses, their values will be summed
* to give the total network loss.
* Note that only floating point (Float16/32/64) variables may be marked as a loss.
* Note also that only ARRAY type SDVariables can be marked as losses to be minimized. That is, we cannot mark the value
* of a constant, variable or placeholder to be minimized as doing so would not make sense.
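*
* For example, a sketch ("loss" is assumed to be the name of an ARRAY type variable computed by this graph):
* {@code
* sd.addLossVariable("loss");
* //Equivalent alternatives: lossVar.markAsLoss(); or sd.setLossVariables("loss");
* }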
*/
public void addLossVariable(@NonNull String variableName) {
Preconditions.checkState(hasVariable(variableName), "No variable with name \"%s\" exists", variableName);
SDVariable v = getVariable(variableName);
Preconditions.checkState(v.dataType().isFPType(), "Only floating point type variables can be marked as losses to be minimized." +
" SDVariable \"%s\" has datatype %s", variableName, v.dataType());
Preconditions.checkState(v.getVariableType() == VariableType.ARRAY, "Only ARRAY type SDVariables can be marked as losses to be minimized." +
" SDVariable \"%s\" has variable type %s", variableName, v.getVariableType());
if (!lossVariables.contains(variableName)) {
lossVariables.add(variableName);
}
}
/**
* See {@link #addLossVariable(String)}
*/
public void addLossVariable(@NonNull SDVariable variable) {
addLossVariable(variable.name());
}
/**
* Set the training configuration ({@link TrainingConfig}) for the SameDiff instance.
* A TrainingConfig must be set before the SameDiff instance can be trained via the fit methods
*
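* For example, a minimal sketch (assumes placeholders named "input" and "label" exist in the graph; the updater and mappings shown are illustrative):
* {@code
* TrainingConfig config = TrainingConfig.builder()
* .updater(new Adam(1e-3))
* .dataSetFeatureMapping("input")
* .dataSetLabelMapping("label")
* .build();
* sd.setTrainingConfig(config);
* }
*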
* @param trainingConfig Training configuration
*/
public void setTrainingConfig(TrainingConfig trainingConfig) {
this.trainingConfig = trainingConfig;
}
/**
* Fit the SameDiff instance based on a single DataSet (i.e., a single minibatch for one iteration).
* This method can only be used for single input, single output SameDiff instances as DataSet only supports a
* single input and a single output.
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can
* be performed.
*
* @param dataSet The DataSet (single minibatch) to perform training on
* @param listeners Additional listeners to use during this operation
* @return a {@link History} object containing the history information for this training operation
* (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
*/
public History fit(@NonNull DataSet dataSet, @NonNull Listener... listeners) {
return fit(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), 1, false,
null, 1, listeners);
}
/**
* Fit the SameDiff instance based on a single MultiDataSet (i.e., a single minibatch for one iteration).
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can
* be performed.
*
* @param dataSet The MultiDataSet (single minibatch) to perform training on
* @param listeners Additional listeners to use during this operation
* @return a {@link History} object containing the history information for this training operation
* (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
*/
public History fit(@NonNull MultiDataSet dataSet, @NonNull Listener... listeners) {
return fit(new SingletonMultiDataSetIterator(dataSet), 1, false,
null, 1, listeners);
}
/**
* Fit the SameDiff instance based on DataSetIterator for the specified number of epochs.
* This method can only be used for single input, single output SameDiff instances as DataSet only supports a
* single input and a single output.
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can
* be performed.
*
* A special case of {@link #fit()}.
*
* @param iter The iterator to train the SameDiff instance with
* @param numEpochs The number of epochs for training. Must be > 0
* @param validationIter The DataSetIterator to use for validation (null to skip validation)
* @param validationFrequency The frequency with which to run validation. 1 is every epoch, 2 is every other, etc.
* @param listeners Additional listeners to use during this operation
* @return a {@link History} object containing the history information for this training operation
* (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
*/
public History fit(@NonNull DataSetIterator iter, int numEpochs, DataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners) {
return fit().train(iter, numEpochs).validate(validationIter, validationFrequency).listeners(listeners).exec();
}
/**
* See {@link #fit(DataSetIterator, int, DataSetIterator, int, Listener...)}, does not perform validation.
*
* A special case of {@link #fit()}.
*
* @param iter The iterator to train the SameDiff instance with
* @param numEpochs The number of epochs for training. Must be > 0
* @param listeners Additional listeners to use during this operation
* @return a {@link History} object containing the history information for this training operation
* (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
*/
public History fit(@NonNull DataSetIterator iter, int numEpochs, @NonNull Listener... listeners) {
return fit().train(iter, numEpochs).listeners(listeners).exec();
}
/**
* Fit the SameDiff instance based on MultiDataSetIterator for the specified number of epochs.
* This method can be used for both single input, single output and multi-input, multi-output SameDiff instances.
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can
* be performed.
*
* A special case of {@link #fit()}.
*
* @param iter The iterator to train the SameDiff instance with
* @param numEpochs The number of epochs for training. Must be > 0
* @param validationIter The MultiDataSetIterator to use for validation (null to skip validation)
* @param validationFrequency The frequency with which to run validation. 1 is every epoch, 2 is every other, etc.
* @param listeners Additional listeners to use during this operation
* @return a {@link History} object containing the history information for this training operation
* (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
*/
public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, MultiDataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners) {
return fit(iter, numEpochs, true, validationIter, validationFrequency, listeners);
}
/**
* See {@link #fit(MultiDataSetIterator, int, MultiDataSetIterator, int, Listener...)}, does not perform validation.
*
* A special case of {@link #fit()}.
*
* @param iter The iterator to train the SameDiff instance with
* @param numEpochs The number of epochs for training. Must be > 0
* @param listeners Additional listeners to use during this operation
* @return a {@link History} object containing the history information for this training operation
* (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
*/
public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, @NonNull Listener... listeners) {
return fit().train(iter, numEpochs).listeners(listeners).exec();
}
/**
* Set up for a fit operation using {@link FitConfig}.
*
* Supports the setting of training data ({@link MultiDataSetIterator} or {@link DataSetIterator}), number of epochs,
* validation data ({@link MultiDataSetIterator} or {@link DataSetIterator}), validation frequency, and additional listeners.
*
* Example: train on data for 5 epochs, validating on valData every 2nd epoch
*
* {@code
* SameDiff sd = ...;
* MultiDataSet data = ...;
* MultiDataSet valData = ...;
*
* History hist = sd.fit()
* .train(data, 5)
* .validate(valData, 2)
* .exec();
* }
*
*/
public FitConfig fit() {
return new FitConfig(this);
}
//Synchronized for thread safety
protected synchronized History fit(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount,
MultiDataSetIterator validationData, int validationFrequency, @NonNull Listener... listeners) {
boolean async = iter.asyncSupported();
boolean validationAsync = false;
if (validationData != null)
validationAsync = validationData.asyncSupported();
if (async) {
iter = new AsyncMultiDataSetIterator(iter, 3, true);
}
if (validationAsync) {
validationData = new AsyncMultiDataSetIterator(validationData, 3, true);
}
try {
return fitHelper(iter, numEpochs, incrementEpochCount, validationData, validationFrequency, Arrays.asList(listeners));
} finally {
if (async) {
((AsyncMultiDataSetIterator) iter).shutdown();
}
if (validationAsync) {
((AsyncMultiDataSetIterator) validationData).shutdown();
}
}
}
//fitHelper should only be called from fit method above
protected synchronized History fitHelper(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount,
MultiDataSetIterator validationData, int validationFrequency, @NonNull List<Listener> listeners) {
Preconditions.checkNotNull(iter, "Iterator must not be null");
Preconditions.checkState(numEpochs > 0, "Number of training epochs must be a positive number. Got: %s", numEpochs);
Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " +
"be set before training. Use setTrainingConfig(TrainingConfig)");
Preconditions.checkState(numEpochs == 1 || iter.resetSupported(), "Cannot train for multiple epochs on an iterator that" +
" does not support resetting");
HistoryListener history = new HistoryListener(trainingConfig);
List<Listener> activeListeners = new ArrayList<>();
activeListeners.add(history);
for (Listener l : this.listeners)
if (l.isActive(Operation.TRAINING))
activeListeners.add(l);
for (Listener l : listeners)
if (l.isActive(Operation.TRAINING))
activeListeners.add(l);
validateListenerActivations(activeListeners, Operation.TRAINING);
validateListenerActivations(activeListeners, Operation.TRAINING_VALIDATION);
if (!iter.hasNext() && iter.resetSupported())
iter.reset();
boolean performedValidation = false;
int trainThreadNum = 0;
long jThreadId = Thread.currentThread().getId();
boolean hasListeners = !activeListeners.isEmpty();
At at = At.builder()
.epoch(trainingConfig.getEpochCount())
.iteration(trainingConfig.getIterationCount())
.trainingThreadNum(trainThreadNum)
.javaThreadNum(jThreadId)
.operation(Operation.TRAINING)
.build();
LossCurve lossCurve = null;
Set<String> requiredVars = new HashSet<>();
for (Listener l : activeListeners) {
ListenerVariables lv = l.requiredVariables(this);
if(lv != null) {
Set<String> s = lv.trainingVariables();
if(s != null) {
requiredVars.addAll(s);
}
}
}
List<Listener> listenersWitHistory = new ArrayList<>(listeners);
for(Listener l : this.listeners){
if(!listenersWitHistory.contains(l))
listenersWitHistory.add(l);
}
listenersWitHistory.add(history);
SameDiff gradInstance = getFunction("grad");
if(gradInstance == null){
createGradFunction();
gradInstance = getFunction("grad");
}
TrainingSession ts = new TrainingSession(gradInstance);
gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it
for(Listener l : activeListeners){
l.operationStart(gradInstance, Operation.TRAINING);
}
Set<String> paramsToTrain = new LinkedHashSet<>();
for(Variable v : variables.values()){
if(v.getVariable().getVariableType() == VariableType.VARIABLE){
//TODO not all variable type are needed - i.e., variable that doesn't impact loss should be skipped
paramsToTrain.add(v.getName());
}
}
Loss lastLoss = null;
for (int i = 0; i < numEpochs; i++) {
if (incrementEpochCount && hasListeners) {
at.setEpoch(trainingConfig.getEpochCount());
for (Listener l : activeListeners) {
l.epochStart(this, at);
}
}
long epochStartTime = System.currentTimeMillis();
double[] lossSums = null;
List<String> lossNames = null;
int lossCount = 0;
while (iter.hasNext()) {
long dataStart = hasListeners ? System.currentTimeMillis() : 0;
org.nd4j.linalg.dataset.api.MultiDataSet ds = iter.next();
long dataEnd = hasListeners ? System.currentTimeMillis() : 0;
if (!performedValidation) {
Preconditions.checkState(trainingConfig.getDataSetFeatureMapping().size() == ds.numFeatureArrays(),
"The number of dataset feature mapping variables set in the training configuration (%s) must match" +
" the number of dataset feature arrays (%s)", trainingConfig.getDataSetFeatureMapping().size(), ds.numFeatureArrays());
List<String> labelMapping = trainingConfig.getDataSetLabelMapping();
int lblSize = labelMapping == null ? 0 : labelMapping.size();
Preconditions.checkState(lblSize == ds.numLabelsArrays(),
"The number of dataset label mapping variables set in the training configuration (%s) must match" +
" the number of dataset label arrays (%s)", lblSize, ds.numLabelsArrays());
performedValidation = true;
}
if (hasListeners) {
at.setIteration(trainingConfig.getIterationCount());
for (Listener l : activeListeners) {
l.iterationStart(this, at, ds, (dataEnd - dataStart));
}
}
//Create placeholder variable map
Map<String, INDArray> placeholders = toPlaceholderMap(ds);
Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training");
//Call TrainingSession to perform training
if (!initializedTraining)
initializeTraining();
lastLoss = ts.trainingIteration(
trainingConfig,
placeholders,
paramsToTrain,
updaterMap,
ds,
getLossVariables(),
listenersWitHistory,
at);
if (lossSums == null) {
lossSums = lastLoss.getLosses().clone();
} else {
for (int j = 0; j < lossSums.length; j++) {
lossSums[j] += lastLoss.getLosses()[j];
}
}
lossCount++;
trainingConfig.incrementIterationCount();
}
long epochTime = System.currentTimeMillis() - epochStartTime;
if (incrementEpochCount) {
lossNames = lastLoss.getLossNames();
for (int j = 0; j < lossSums.length; j++)
lossSums[j] /= lossCount;
if (lossCurve != null)
lossCurve = lossCurve.addLossAndCopy(lossSums, lossNames);
else
lossCurve = new LossCurve(lossSums, lossNames);
}
if (incrementEpochCount) {
if (hasListeners) {
boolean doStop = false;
Listener stopped = null;
for (Listener l : activeListeners) {
ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime);
if (res == ListenerResponse.STOP && (i < numEpochs - 1)) {
doStop = true;
stopped = l;
}
}
if (doStop) {
log.info("Stopping training early. Listener " + stopped + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration());
for (Listener l1 : activeListeners)
l1.operationEnd(this, Operation.TRAINING);
if (i < numEpochs - 1) {
iter.reset();
}
if (incrementEpochCount)
trainingConfig.incrementEpochCount();
return history.getReport();
}
//validation evaluation
if (validationData != null && (validationFrequency <= 0 || i % validationFrequency == 0)) {
long validationStart = System.currentTimeMillis();
outputHelper(validationData, new At(at.epoch(), 0, 0, 0, null, Operation.TRAINING_VALIDATION),
listenersWitHistory);
long validationTime = System.currentTimeMillis() - validationStart;
boolean doStopV = false;
Listener stoppedV = null;
for (Listener l : activeListeners) {
ListenerResponse res = l.validationDone(this, at, validationTime);
if (res == ListenerResponse.STOP && (i < numEpochs - 1)) {
doStopV = true;
stoppedV = l;
}
}
if (doStopV) {
log.info("Stopping training early from validation. Listener " + stoppedV + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration());
for (Listener l1 : activeListeners)
l1.operationEnd(this, Operation.TRAINING);
if (i < numEpochs - 1) {
iter.reset();
}
if (incrementEpochCount)
trainingConfig.incrementEpochCount();
return history.getReport();
}
}
}
trainingConfig.incrementEpochCount();
}
if (i < numEpochs - 1) {
iter.reset();
}
}
for (Listener l1 : activeListeners)
l1.operationEnd(this, Operation.TRAINING);
return history.getReport();
}
/**
* Ensure the specified listeners do not request any activations that aren't present for the given operation
*/
private void validateListenerActivations(List<Listener> listeners, Operation op) {
for (Listener l : listeners) {
ListenerVariables lv = l.requiredVariables(this);
if(lv != null) {
for (String s : lv.requiredVariables(op)) {
if (!variables.containsKey(s)) {
Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s);
}
}
}
}
}
/**
* Calculate the regularization (L1, L2 and/or WeightDecay) component of the loss function for the current parameters.
* Note that the training configuration must be set (via {@link #setTrainingConfig(TrainingConfig)}) before this
* method can be called
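*
* Usage sketch (assumes a {@code SameDiff} instance {@code sd} whose training configuration, including any regularization, has already been set):
*
* {@code
* double regLoss = sd.calcRegularizationScore();
* }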
*
* @return The regularization component of the score/loss function
*/
public double calcRegularizationScore() {
Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " +
"be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)");
if (trainingConfig.getRegularization() == null || trainingConfig.getRegularization().isEmpty()) {
return 0.0;
}
List<Regularization> l = trainingConfig.getRegularization();
double loss = 0.0;
for (Variable v : variables.values()) {
SDVariable sdv = v.getVariable();
if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) {
//Only trainable parameters (FP and variable type vars) contribute to regularization score
continue;
}
for (Regularization r : l) {
INDArray arr = sdv.getArr();
loss += r.score(arr, trainingConfig.getIterationCount(), trainingConfig.getEpochCount());
}
}
return loss;
}
/**
* Perform setup for training. Does the following:
* 1. Infer the set of trainable parameters - unless specified manually by the user
* 2. Set up the updaters
*/
protected void initializeTraining() {
if (!initializedTraining) {
if (trainingConfig == null) {
throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig");
}
updaterMap = new HashMap<>();
for (Variable v : variables.values()) {
if (v.getVariable().getVariableType() != VariableType.VARIABLE || !v.getVariable().dataType().isFPType()) {
//Skip non-trainable parameters
continue;
}
INDArray arr = v.getVariable().getArr();
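//Allocate the updater state view for this parameter (e.g., momentum/Adam buffers), if the updater requires state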
long stateSize = trainingConfig.getUpdater().stateSize(arr.length());
INDArray view = stateSize == 0 ? null : Nd4j.createUninitialized(arr.dataType(), 1, stateSize);
GradientUpdater gu = trainingConfig.getUpdater().instantiate(view, false);
gu.setStateViewArray(view, arr.shape(), arr.ordering(), true);
updaterMap.put(v.getName(), gu);
}
initializedTraining = true;
}
}
/**
* Convert the MultiDataSet to a {@code Map<String, INDArray>} based on the TrainingConfig settings.
* The key is the placeholder/variable that the value INDArray should be associated with.
*
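* For example (a sketch; the placeholder names "in" and "label" are illustrative), if the training configuration's
* dataset feature mapping is {@code ["in"]} and its dataset label mapping is {@code ["label"]}, the returned map
* contains {@code ds.getFeatures(0)} under the key "in" and {@code ds.getLabels(0)} under the key "label".
*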
* @param ds MultiDataSet - source of the features/labels
* @return MultiDataSet converted to a Map, based on TrainingConfig
*/
private Map<String, INDArray> toPlaceholderMap(org.nd4j.linalg.dataset.api.MultiDataSet ds) {
Map<String, INDArray> placeholders = new HashMap<>();
int count = 0;
for (String s : trainingConfig.getDataSetFeatureMapping()) {
placeholders.put(s, ds.getFeatures(count++));
}
count = 0;
if (trainingConfig.getDataSetLabelMapping() != null) {
//Labels may be null in some models (unsupervised etc)
for (String s : trainingConfig.getDataSetLabelMapping()) {
placeholders.put(s, ds.getLabels(count++));
}
}
if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().size() > 0) {
count = 0;
for (String s : trainingConfig.getDataSetFeatureMaskMapping()) {
if (s == null) {
count++;
continue;
}
placeholders.put(s, ds.getFeaturesMaskArray(count++));
}
}
if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().size() > 0) {
count = 0;
for (String s : trainingConfig.getDataSetLabelMaskMapping()) {
if (s == null) {
count++;
continue;
}
placeholders.put(s, ds.getLabelsMaskArray(count++));
}
}
return placeholders;
}
/**
* Evaluate the performance of a single variable's prediction.
* For example, if the variable to evaluate was called "softmax" you would use:
*
* {@code Evaluation e = new Evaluation();
* sameDiff.evaluate(iterator, "softmax", e);}
*
*
* A special case of {@link #evaluate()}.
*
* @param iterator Iterator as source of data to evaluate
* @param outputVariable The variable to evaluate
* @param listeners Additional listeners to use during this operation.
* @param evaluations The evaluations to perform
*/
public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull List<Listener> listeners, @NonNull IEvaluation... evaluations) {
Preconditions.checkArgument(evaluations != null && evaluations.length > 0, "No evaluations were passed to the evaluate method");
evaluate().data(iterator).evaluate(outputVariable, evaluations).listeners(listeners.toArray(new Listener[0])).exec();
}
/**
* See {@link #evaluate(DataSetIterator, String, List, IEvaluation[])}.
*
* A special case of {@link #evaluate()}.
*/
public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull IEvaluation... evaluations) {
evaluate().data(iterator).evaluate(outputVariable, evaluations).exec();
}
/**
* Evaluation for multiple-output networks.
* See {@link #evaluate(MultiDataSetIterator, Map, Map, Listener[])}.
*
* A special case of {@link #evaluate()}.
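*
* Usage sketch (the variable name "softmax" and the {@code iterator} are illustrative):
*
* {@code
* Map<String, IEvaluation> evals = Collections.singletonMap("softmax", new Evaluation());
* sameDiff.evaluate(iterator, evals);
* }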
*/
public void evaluate(@NonNull DataSetIterator iterator, @NonNull Map<String, IEvaluation> variableEvals, @NonNull Listener... listeners) {
Map<String, Integer> map = new HashMap<>();
Map<String, List<IEvaluation>> variableEvalsList = new HashMap<>();
for (String s : variableEvals.keySet()) {
map.put(s, 0); //A DataSetIterator provides only a single label array, so every variable maps to label index 0
variableEvalsList.put(s, Collections.singletonList(variableEvals.get(s)));
}
evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvalsList, map, listeners);
}
/**
* Evaluation for multiple-output networks, with one or more evaluations per output variable.
* See {@link #evaluate(MultiDataSetIterator, Map, Map, Listener[])}.
*
* A special case of {@link #evaluate()}.
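*
* Usage sketch (the variable name "softmax" and the {@code iterator} are illustrative):
*
* {@code
* Map<String, List<IEvaluation>> evals =
*         Collections.singletonMap("softmax", Arrays.<IEvaluation>asList(new Evaluation(), new ROC()));
* sameDiff.evaluateMultiple(iterator, evals);
* }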
*/
public void evaluateMultiple(DataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, @NonNull Listener... listeners) {
Map<String, Integer> map = new HashMap<>();
for (String s : variableEvals.keySet()) {
map.put(s, 0); //A DataSetIterator provides only a single label array, so every variable maps to label index 0
}
evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvals, map, listeners);
}
/**
* Evaluate the performance of a single variable's prediction.
* For example, if the variable to evaluate was called "softmax" you would use:
*
* {@code Evaluation e = new Evaluation();
* sameDiff.evaluate(iterator, "softmax", e);}
*
*
* A special case of {@link #evaluate()}.
*
* @param iterator Iterator as source of data to evaluate
* @param outputVariable The variable to evaluate
* @param labelIndex The index of the target variable's labels in the iterator
* @param listeners Additional listeners to use during this operation.
* @param evaluations The evaluations to perform
*/
public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex,
@NonNull List<Listener> listeners, @NonNull IEvaluation... evaluations) {
Preconditions.checkArgument(evaluations != null && evaluations.length > 0, "No evaluations were passed to the evaluate method");
evaluate().data(iterator).evaluate(outputVariable, labelIndex, evaluations).listeners(listeners.toArray(new Listener[0])).exec();
}
/**
* See {@link #evaluate(MultiDataSetIterator, String, int, List, IEvaluation[])}.
*
* A special case of {@link #evaluate()}.
*/
public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull IEvaluation... evaluations) {
evaluate().data(iterator).evaluate(outputVariable, labelIndex, evaluations).exec();
}
/**
* Perform evaluation using classes such as {@link Evaluation} for classifier outputs
* and {@link org.nd4j.evaluation.regression.RegressionEvaluation} for regression outputs.
*
* Example: classifier evaluation
* Predictions variable name: "softmaxOutput"
* Evaluations to perform: {@link Evaluation}
* Data: single input, single output MultiDataSets
* Code:
*
* {@code
* MultiDataSetIterator data = ...
* Map<String, List<IEvaluation>> evals = Collections.singletonMap("softmaxOutput", Collections.singletonList(new Evaluation()));
* Map<String, Integer> labelMapping = Collections.singletonMap("softmaxOutput", 0); //Compare: "softmaxOutput" vs. MultiDataSet.getLabels(0)
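* sameDiff.evaluate(data, evals, labelMapping); //Perform the evaluation (additional listeners are optional varargs)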
* }
*
*
* A special case of {@link #evaluate()}.
*
* @param iterator The iterator - the source of the data for evaluation
* @param variableEvals The evaluations to perform. Key: the name of the variable. Value: the evaluations to perform
* @param predictionLabelMapping The output/label mapping. Key: the name of the variable. Value: the index of the corresponding label array in the MultiDataSet
* @param listeners Additional listeners to use during this operation.
*/
public void evaluate(MultiDataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Map<String, Integer> predictionLabelMapping, Listener... listeners) {
evaluateHelper(iterator, variableEvals, predictionLabelMapping, At.defaultAt(Operation.EVALUATION), listeners);
}
/**
* Set up an evaluation operation using EvaluationConfig.
*
* Supports the setting of the data ({@link MultiDataSetIterator} or {@link DataSetIterator}),
* adding evaluations for variables (with optional label index setting), setting label indices,
* and setting additional listeners.
* Does not require setting label indices when using a {@link DataSetIterator}.
*
* Also supports using {@link SDVariable} instances instead of variable names.
*
*
* Example: evaluate "pred" with {@link Evaluation} and {@link ROC}, using label 0.
*
* {@code
* SameDiff sd = ...;
* MultiDataSetIterator data = ...;
*
* EvaluationRecord results = sd.evaluate()
* .data(data)
* .evaluate("pred", 0, new Evaluation(), new ROC()),
* .exec();
* }
*
* Example: evaluate "pred" with {@link Evaluation}, using the only label from a DataSetIterator.
*
* {@code
* SameDiff sd = ...;
* DataSetIterator singleData = ...;
*
* EvaluationRecord results = sd.evaluate()
* .data(singleData)
* .evaluate("pred", new Evaluation()),
* .exec();
* }
*
*/
public EvaluationConfig evaluate() {
return new EvaluationConfig(this);
}
/**
* Helper method for evaluations. Should only be called from the above evaluate method
*/
private void evaluateHelper(MultiDataSetIterator iterator,
Map<String, List<IEvaluation>> variableEvals, Map<String, Integer> predictionLabelMapping, At at, @NonNull Listener... listeners) {
Preconditions.checkState(trainingConfig != null, "Training config has not been set");
Preconditions.checkState(variableEvals.keySet().equals(predictionLabelMapping.keySet()), "Keysets for variable evaluations" +
" and for the prediction label mapping must be equal. Keys for variables to evaluate: %s vs. keys for label mapping: %s", variableEvals.keySet(), predictionLabelMapping.keySet());
List<Listener> activeListeners = new ArrayList<>();
for (Listener l : listeners)
if (l.isActive(at.operation()))
activeListeners.add(l);
for (Listener l : this.listeners)
if (l.isActive(at.operation()))
activeListeners.add(l);
validateListenerActivations(activeListeners, at.operation());
for (Listener l : activeListeners)
l.operationStart(this, at.operation());
boolean hasListeners = !activeListeners.isEmpty();
if (!iterator.hasNext() && iterator.resetSupported())
iterator.reset();
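//Determine which output variables must be computed: the variables being evaluated, plus any extra variables requested by active listeners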
Set<String> requiredVars = new HashSet<>(variableEvals.keySet());
if (hasListeners) {
for (Listener l : activeListeners) {
ListenerVariables v = l.requiredVariables(this);
if(v != null) {
requiredVars.addAll(v.evaluationVariables());
}
}
}
String[] requiredVarsArr = requiredVars.toArray(new String[0]);
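//Iterate over the evaluation data: execute the graph for the required variables and feed predictions/labels to each IEvaluation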
while (iterator.hasNext()) {
MultiDataSet ds = iterator.next();
Map<String, INDArray> placeholderMap = toPlaceholderMap(ds);
Map<String, INDArray> m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr);
for (Map.Entry<String, List<IEvaluation>> e : variableEvals.entrySet()) {
INDArray prediction = m.get(e.getKey());
for (IEvaluation eval : e.getValue()) {
//TODO time series, etc
INDArray label = ds.getLabels(predictionLabelMapping.get(e.getKey()));
INDArray mask = ds.getLabelsMaskArray(predictionLabelMapping.get(e.getKey()));
eval.eval(label, prediction, mask);
}
}
at.setIteration(at.iteration() + 1);
}
for (Listener l : activeListeners)
l.operationEnd(this, at.operation());
}
/**
* Do a single batch inference on a network with a single input.
* For example, if the variable to infer was called "softmax" you would use:
*
* {@code
* sameDiff.output(dataSet, "softmax");}
*
*
* @param dataSet The data to evaluate
* @param outputs The variables to evaluate
*/
public Map<String, INDArray> output(@NonNull DataSet dataSet, @NonNull String... outputs) {
return outputBatches(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0);
}
/**
* Do a single batch inference on a network.
* For example, if the variable to infer was called "softmax" you would use:
*
* {@code
* sameDiff.output(dataSet, "softmax");}
*
*
* @param dataSet The data to evaluate
* @param outputs The variables to evaluate
*/
public Map<String, INDArray> output(@NonNull MultiDataSet dataSet, @NonNull String... outputs) {
return outputBatches(new SingletonMultiDataSetIterator(dataSet), outputs).get(0);
}
/**
* Do inference on a network with a single input.
* For example, if the variable to infer was called "softmax" you would use:
*
* {@code
* sameDiff.output(iterator, "softmax");}
*
*
* Uses concatenation on the outputs of {@link #outputBatches(DataSetIterator, String...)} which may cause issues with some inputs.
* RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.
*
* Special case of {@link #output()}.
*
* @param iterator Iterator as source of data to evaluate
* @param listeners Additional listeners to use during this operation.
* @param outputs The variables to evaluate
*/
public Map<String, INDArray> output(@NonNull DataSetIterator iterator, @NonNull List<Listener> listeners, @NonNull String... outputs) {
return output().data(iterator).output(outputs).listeners(listeners.toArray(new Listener[0])).exec();
}
/**
* See {@link #output(DataSetIterator, List, String...)}. No additional listeners.
*
* Special case of {@link #output()}.
*/
public Map<String, INDArray> output(@NonNull DataSetIterator dataSet, @NonNull String... outputs) {
return output().data(dataSet).output(outputs).exec();
}
/**
* See {@link #output(DataSetIterator, List, String...)}, but without the concatenation of batches.
*
* Special case of {@link #output()}.
*/
public List