/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import com.rits.cloning.Cloner;
import com.rits.cloning.IFastCloner;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.samediff.internal.*;
import org.nd4j.autodiff.samediff.ops.*;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.autodiff.util.cloner.DataBufferFastCloner;
import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.graph.*;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.collection.IntArrayKeyMap;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import java.io.*;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
/**
* SameDiff is the entry point for ND4J's automatic differentiation functionality.
*
* A graph is defined symbolically: each op method call adds operations and variables to the graph,
* but nothing is computed at that point.
*
* To execute the graph, you run one of the execution methods, such as {@link #exec(Map, String...)}
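*
* A rough usage sketch (the method names and signatures used here, such as {@code placeHolder}, {@code var}
* and the op namespaces, are illustrative only and may differ between versions):
* <pre>{@code
* SameDiff sd = SameDiff.create();
* SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);    //Placeholder: value provided at execution time
* SDVariable w = sd.var("w", DataType.FLOAT, 4, 2);               //Trainable variable
* SDVariable out = sd.nn().softmax(in.mmul(w));                   //Ops are only recorded here, not executed
* Map<String, INDArray> placeholders = Collections.singletonMap("in", Nd4j.rand(3, 4));
* INDArray result = sd.exec(placeholders, out.getVarName()).get(out.getVarName());
* }</pre>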
*/
@AllArgsConstructor
@Builder
@Slf4j
public class SameDiff extends SDBaseOps {
//Fields for graph structure and execution
@Getter //TODO use package private instead of public getters?
private final Map<String,Variable> variables = new LinkedHashMap<>(); //Use linked hash map to guarantee iteration order based on order they were added. Used in inputs() and flatbuffers serde
@Getter
private final Map<String,SameDiffOp> ops = new LinkedHashMap<>();
@Getter
private final Map<Long,InferenceSession> sessions = new ConcurrentHashMap<>(); //Key: thread ID
private final Map<String,DeviceLocalNDArray> constantArrays = new ConcurrentHashMap<>();
private final Map<String,DeviceLocalNDArray> variablesArrays = new ConcurrentHashMap<>(); //TODO issues with DeviceLocal + mutable / changed during training?
private final Map<Long,Map<String,INDArray>> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them
private final List<String> lossVariables = new ArrayList<>();
///////////////////////////////////////
//Fields related to training
@Getter
private TrainingConfig trainingConfig; //Configuration for training. Must be set for training/evaluation, but not for other operations
@Getter
private boolean initializedTraining; //True if training setup has been done
@Getter
private INDArray updaterState; //Updater state array (1d, length equal to number of trainable parameters)
@Getter
private Map<String,INDArray> updaterViews; //Views of updaterState array for each trainable parameter
@Getter
private Map<String,GradientUpdater> updaterMap; //GradientUpdater instance for each trainable parameter
////////////////////////////////////////
//map a function's instance id to a base name, used for propagating variable names
//for output during import
private Map<String,String> baseNameForFunctionInstanceId;
private DifferentialFunctionFactory functionFactory;
@Deprecated //TO BE REMOVED - to ShapeSession
private Map<String,long[]> variableNameToShape; //Key: SDVariable name. Value: shape for that variable
@Deprecated //TO BE REMOVED - to Variable
private Map<String,SDVariable> forwardVarForGrad;
// counter for auto-naming variables
private int variableId = 0;
////////////////////////////////////////
/** Op creator object for math operations */
public final SDMath math = new SDMath(this);
/** Op creator object for random number generation operations */
public final SDRandom random = new SDRandom(this);
/** Op creator object for general neural network operations */
public final SDNN nn = new SDNN(this);
/** Op creator object for convolutional neural network operations */
public final SDCNN cnn = new SDCNN(this);
/** Op creator object for recurrent neural network operations */
public final SDRNN rnn = new SDRNN(this);
/** Op creator object for loss function operations */
public final SDLoss loss = new SDLoss(this);
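//Illustrative sketch only (exact op names/signatures may vary between versions): ops are typically created
// through these namespaces, e.g. sd.math().square(x) or sd.nn().sigmoid(x), each returning a new SDVariable
// that is added to the graph symbolically.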
/** Op creator object for math operations */
public SDMath math(){
return math;
}
/** Op creator object for random number generation operations */
public SDRandom random(){
return random;
}
/** Op creator object for general neural network operations */
public SDNN nn(){
return nn;
}
/** Op creator object for convolutional neural network operations */
public SDCNN cnn(){
return cnn;
}
/** Op creator object for recurrent neural network operations */
public SDRNN rnn(){
return rnn;
}
/** Op creator object for loss function operations */
public SDLoss loss(){
return loss;
}
/**
* For import, we often have variables
* that map to properties. Most commonly,
* an input to a function is mapped to an ndarray,
* and that ndarray usually has a scalar shape.
*
* An array with a scalar shape can represent something like an axis.
*
* We often don't know that array's value until run time.
* This map stores variable names that we should resolve
* from samediff. We use the value of that array
* to update the properties.
*/
private Map<String,List<String>> propertiesToResolve;
/**
* A map of own name to
* the properties of the function (things like execution axes etc)
* The valid values can be:
* int
* long
* INDArray
*/
private Map<String,Map<String,Object>> propertiesForFunction;
@Deprecated //TO BE REMOVED - to Variable
private Map<String,long[]> placeHolderOriginalShapes;
private Map<String,SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap;
private Map<String,SameDiff> sameDiffFunctionInstances;
private Set<String> placeHolderFunctions;
private static Cloner cloner = newCloner();
private static Map<String,Method> opMethods;
private Table<String,String,String> fieldVariableResolutionMapping;
// flag, shows if graph was already registered with libnd4j
private transient AtomicBoolean wasRegistered = new AtomicBoolean(false);
//debug mode variables
@Getter
private boolean debugMode;
private Map<int[],Op> opsForResult;
private boolean resolvedVariables = false;
@Getter
@Setter
boolean logExecution = true;
@Getter
private SameDiff parent;
@Getter
private SameDiff child;
public final static String TRAINING_CONFIG_JSON_ZIP_ENTRY_NAME = "trainingConfig.json";
public final static String SAMEDIFF_FILE_ENTRY_NAME = "samediff.fb";
static {
opMethods = new HashMap<>();
Method[] methods = SameDiff.class.getDeclaredMethods();
for (Method method : methods) {
if (method.getReturnType().equals(SDVariable.class)) {
opMethods.put(method.getName(), method);
}
}
}
/**
* @return New cloner object. NOTE: INTENDED FOR DEVELOPER USE ONLY
*/
public static Cloner newCloner() {
Cloner cloner = new Cloner();
//Implement custom cloning for INDArrays (default can have problems with off-heap and pointers)
//Sadly: the cloner library does NOT support interfaces here, hence we need to use the actual classes
//cloner.registerFastCloner(INDArray.class, new INDArrayFastCloner()); //Does not work due to interface
IFastCloner fc = new INDArrayFastCloner();
cloner.registerFastCloner(Nd4j.getBackend().getNDArrayClass(), fc);
//Same thing with DataBuffers: off heap -> cloner library chokes on them, but need to know the concrete
// buffer classes, not just the interface
IFastCloner fc2 = new DataBufferFastCloner();
DataBufferFactory d = Nd4j.getDataBufferFactory();
doReg(cloner, fc2, d.intBufferClass());
doReg(cloner, fc2, d.longBufferClass());
doReg(cloner, fc2, d.halfBufferClass());
doReg(cloner, fc2, d.floatBufferClass());
doReg(cloner, fc2, d.doubleBufferClass());
doReg(cloner, fc2, CompressedDataBuffer.class);
return cloner;
}
private static void doReg(Cloner cl, IFastCloner fc, Class c) {
if (c != null)
cl.registerFastCloner(c, fc);
}
/**
* Rename the variable with the given name, updating all references
* to it (op inputs/outputs and internal maps) to use the new name
*
* @param varName the current name of the variable
* @param withName the new name for the variable
*/
public void updateVariableName(String varName, String withName) {
SDVariable oldVarNameRef = getVariable(varName);
Variable v = variables.remove(varName);
String oldVarName = varName;
oldVarNameRef.setVarName(withName);
v.setName(withName);
variables.put(withName, v);
for(SameDiffOp op : ops.values()){
List<String> outputsOfOp = op.getOutputsOfOp();
if(outputsOfOp != null && !outputsOfOp.isEmpty()) {
for (int i = 0; i < outputsOfOp.size(); i++) {
if (outputsOfOp.get(i).equals(oldVarName)) {
outputsOfOp.set(i, withName);
}
}
}
List<String> inputsToOp = op.getInputsToOp();
if(inputsToOp != null && !inputsToOp.isEmpty()) {
for (int i = 0; i < inputsToOp.size(); i++) {
if (inputsToOp.get(i).equals(oldVarName)) {
inputsToOp.set(i, withName);
}
}
}
}
// if (variableNameToArr.containsKey(oldVarName)) {
// val arr = variableNameToArr.remove(oldVarName);
// variableNameToArr.put(withName, arr);
// }
if (variableNameToShape.containsKey(oldVarName)) {
val shape = variableNameToShape.remove(oldVarName);
variableNameToShape.put(withName, shape);
}
if (forwardVarForGrad.containsKey(oldVarName)) {
val forwardGrad = forwardVarForGrad.remove(oldVarName);
forwardVarForGrad.put(withName, forwardGrad);
}
if (v.getInputsForOp() != null) {
List<String> funcNames = v.getInputsForOp();
for (String s : funcNames) {
DifferentialFunction func = ops.get(s).getOp();
if (func instanceof BaseOp) {
BaseOp baseOp = (BaseOp) func;
if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldVarName)) {
baseOp.setXVertexId(withName);
}
if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldVarName)) {
baseOp.setYVertexId(withName);
}
if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldVarName)) {
baseOp.setZVertexId(withName);
}
}
}
}
if (v.getOutputOfOp() != null) {
DifferentialFunction func = ops.get(v.getOutputOfOp()).getOp();
if (func instanceof BaseOp) {
BaseOp baseOp = (BaseOp) func;
if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldVarName)) {
baseOp.setXVertexId(withName);
}
if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldVarName)) {
baseOp.setYVertexId(withName);
}
if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldVarName)) {
baseOp.setZVertexId(withName);
}
}
}
}
/**
* Disables debug mode for this SameDiff instance.
*/
public SameDiff disableDebugging() {
debugMode = false;
return this;
}
/**
* Enables debug mode for this SameDiff instance.
*/
public SameDiff enableDebugMode() {
debugMode = true;
return this;
}
/**
* Returns this samediff instance's {@link DifferentialFunctionFactory}
*
* @return the function factory for this instance
*/
public DifferentialFunctionFactory f() {
return functionFactory;
}
/**
* Copy the variables and functions of this SameDiff instance onto the specified SameDiff instance
*
* @param sameDiff the SameDiff instance to copy this graph onto
* @return the last variable of the copied graph
*/
public SDVariable invokeGraphOn(SameDiff sameDiff) {
//map the new vertices on to the old ones
Map<Integer,Integer> thisVertexIdToNew = new HashMap<>();
int idx = 1;
for (val var : variables()) {
SDVariable clone = cloner.deepCloneDontCloneInstances(var, var.getSameDiff());
SDVariable newVar = sameDiff.var(clone);
if (var.getArr() != null && var.getVariableType() != VariableType.ARRAY) { //ARRAY type = "activations" - are overwritten anyway
sameDiff.associateArrayWithVariable(var.getArr(), newVar);
}
thisVertexIdToNew.put(idx, idx);
clone.setSameDiff(sameDiff);
idx++;
}
val newFunctions = new LinkedHashMap<String,DifferentialFunction>();
for (SameDiffOp op : ops.values()) {
DifferentialFunction function = op.getOp();
if (function instanceof SDVariable) {
continue;
}
DifferentialFunction clone = cloner.deepCloneDontCloneInstances(
function,
function.getSameDiff());
clone.setSameDiff(sameDiff);
clone.setOwnName(function.getOwnName());
if (sameDiff.functionExists(function.getOwnName()))
sameDiff.putFunctionForId(function.getOwnName(), function);
newFunctions.put(function.getOwnName(), clone);
val argsForFunction = function.args();
val outputsForFunction = function.outputVariables();
//note that these have the same variable names
sameDiff.addArgsFor(argsForFunction, clone);
sameDiff.addOutgoingFor(outputsForFunction, function);
for (val arg : clone.args()) {
arg.setSameDiff(sameDiff);
}
for (val output : clone.outputVariables()) {
output.setSameDiff(sameDiff);
}
sameDiff.ops.put(function.getOwnName(), op);
}
return sameDiff.variables().get(sameDiff.variables().size() - 1);
}
/**
* Returns true if the given function id exists
*
* @param id the function id to test for
* @return true if the function id exists, false otherwise
*/
public boolean functionExists(String id) {
return ops.containsKey(id);
}
public DifferentialFunction functionOutputFor(String varName){
if(variables.get(varName).getOutputOfOp() == null)
return null;
String outName = variables.get(varName).getOutputOfOp();
if(outName == null)
return null;
return ops.get(outName).getOp();
}
/**
* Get the function by the {@link DifferentialFunction#getOwnName()}
*
* @param id the id of the function
* @return the function for the given id if it exists
*/
public DifferentialFunction getFunctionById(@NonNull String id) {
if (!ops.containsKey(id)) {
throw new ND4JIllegalStateException("No function with id " + id + " found!");
}
return ops.get(id).getOp();
}
/**
* Put the function for the given id
*
* @param id the id of the function
* @param function the function
*/
public void putFunctionForId(String id, DifferentialFunction function) {
if (ops.containsKey(id) && ops.get(id).getOp() == null) {
throw new ND4JIllegalStateException("Function by id already exists!");
} else if (function instanceof SDVariable) {
throw new ND4JIllegalStateException("Function must not be a variable!");
}
if(!ops.containsKey(id)){
ops.put(id, SameDiffOp.builder().name(id).op(function).build());
}
}
/**
* Returns the name(s) of the inputs for the given function
*
* @param function the function to get the inputs for
* @return the input ids for a given function
*/
public String[] getInputsForFunction(DifferentialFunction function) {
if (!ops.containsKey(function.getOwnName()))
throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
List<String> inputs = ops.get(function.getOwnName()).getInputsToOp();
return inputs == null ? null : inputs.toArray(new String[inputs.size()]);
}
/**
* Returns the name(s) of the outputs for the given function
*
* @param function the function to get the outputs for
* @return the outputs ids for a given function
*/
public String[] getOutputsForFunction(DifferentialFunction function) {
if (!ops.containsKey(function.getOwnName()))
throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
List<String> outputs = ops.get(function.getOwnName()).getOutputsOfOp();
return outputs == null ? null : outputs.toArray(new String[outputs.size()]);
}
/**
* Get the output variable(s) for the specified differential function
*
* @param function the function reference to get the output variable(s) for
* @return the output variables for the given function
*/
public SDVariable[] getOutputVariablesForFunction(DifferentialFunction function) {
val outputs = getOutputsForFunction(function);
if (outputs == null) {
throw new ND4JIllegalStateException("No outputs found for function " + function);
}
val vars = new SDVariable[outputs.length];
for (int i = 0; i < outputs.length; i++) {
vars[i] = getVariable(outputs[i]);
}
return vars;
}
/**
* Get the input variable(s) for the specified differential function
*
* @param function the function reference to get the input variable(s) for
* @return the input variables for the given function
*/
public SDVariable[] getInputVariablesForFunction(DifferentialFunction function) {
val inputs = getInputsForFunction(function);
if (inputs == null) {
throw new ND4JIllegalStateException("No inputs found for function " + function);
}
val vars = new SDVariable[inputs.length];
for (int i = 0; i < inputs.length; i++) {
vars[i] = getVariable(inputs[i]);
if (vars[i] == null) {
throw new ND4JIllegalStateException("Found null variable at index " + i);
}
}
return vars;
}
public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr){
Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName);
SDVariable v = getVariable(varName);
if(v.isConstant()) {
constantArrays.put(varName, new DeviceLocalNDArray(arr));
} else if(v.getVariableType() == VariableType.VARIABLE) {
variablesArrays.put(varName, new DeviceLocalNDArray(arr));
} else if(v.isPlaceHolder()){
long tid = Thread.currentThread().getId();
if(!placeholdersPerThread.containsKey(tid)){
placeholdersPerThread.put(tid, new HashMap<String, INDArray>());
}
placeholdersPerThread.get(tid).put(varName, arr);
} else {
throw new UnsupportedOperationException("Cannot set variable of type " + v.getVariableType() + " using this method");
}
}
/**
* Get the shape for the given vertex id.
* Note that if an array is defined, it will use the shape of the array instead.
*
* A shape *and* an array should not be defined at the same time.
* This wastes memory. The internal map used for tracking shapes for particular
* vertex ids should also delete redundant shapes stored to avoid redundant sources of information.
*
* @param varName the vertex id to get the shape for
* @return the shape for the given vertex if any.
*/
public long[] getShapeForVarName(String varName) {
if (arrayAlreadyExistsForVarName(varName)) {
return getVariable(varName).getArr().shape();
}
return variableNameToShape.get(varName);
}
public LongShapeDescriptor getShapeDescriptorForVarName(String varName) {
if (getVariable(varName).getArr() != null) {
return getVariable(varName).getArr().shapeDescriptor();
}
// FIXME: do we really want this Nd4j.dataType() here?
return LongShapeDescriptor.fromShape(variableNameToShape.get(varName), Nd4j.dataType());
}
/**
* Associate a vertex id with the given shape.
*
* @param varName the vertex id to associate
* @param shape the shape to associate with
* @see #putShapeForVarName(String, long[])
* @see #putOrUpdateShapeForVarName(String, long[], boolean)
*/
@Deprecated
public void putShapeForVarName(String varName, long[] shape) {
if (shape == null) {
throw new ND4JIllegalStateException("Shape must not be null!");
}
if (variableNameToShape.containsKey(varName)) {
throw new ND4JIllegalStateException("Shape for " + varName + " already exists!");
}
variableNameToShape.put(varName, shape);
}
public void putShapeForVarName(String varName, LongShapeDescriptor shape) {
val v = getVariable(varName);
putShapeForVarName(varName, shape.getShape());
v.setDataType(shape.dataType());
}
/**
* Put or update the shape for the given variable name. Optionally supports clearing the specified variable's
* INDArray if its shape does not match the new shape
* @param varName Variable name
* @param shape Shape to put
* @param clearArrayOnShapeMismatch If false: no change to arrays. If true: if an INDArray is defined for the specified
* variable name, it will be removed from the graph (to be later re-generated) if
* its shape does not match the specified shape
*/
@Deprecated
public void putOrUpdateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch){
Preconditions.checkNotNull(shape, "Cannot put null shape for variable: %s", varName);
if(variableNameToShape.containsKey(varName)){
// updateShapeForVarName(varName, shape, clearArrayOnShapeMismatch);
//TODO
} else {
putShapeForVarName(varName, shape);
}
}
/**
* Returns true if a shape (or an array, from which the shape can be inferred) is already defined for the given variable name.
*
* @param varName the variable name
* @return true if a shape or an array already exists for the given variable name
*/
public boolean shapeAlreadyExistsForVarName(String varName) {
return variableNameToShape.containsKey(varName) || arrayAlreadyExistsForVarName(varName);
}
/**
* Returns true if the given variable name has an {@link INDArray} associated with it.
*
* @param varName the variable name
* @return true if a variable with the given name exists and has an INDArray associated with it
*/
public boolean arrayAlreadyExistsForVarName(String varName) {
SDVariable var = getVariable(varName);
switch(var.getVariableType()){
case VARIABLE:
return variablesArrays.containsKey(varName);
case ARRAY:
long tid = Thread.currentThread().getId();
return sessions.containsKey(tid) && sessions.get(tid).contains(varName, InferenceSession.OUTER_FRAME, 0, null);
case CONSTANT:
return constantArrays.containsKey(varName);
case PLACEHOLDER:
return placeholdersPerThread.containsKey(Thread.currentThread().getId()) &&
placeholdersPerThread.get(Thread.currentThread().getId()).containsKey(varName);
default:
throw new RuntimeException("Unknown variable type: " + var.getVariableType());
}
}
/**
* Get an {@link INDArray} for a given vertex id, or null if none exists
*
* @param varName Variable name to get the array for
* @return Array, or null if none exists
*/
public INDArray getArrForVarName(@NonNull String varName) {
Preconditions.checkState(variables.containsKey(varName), "No variable found with name \"%s\"", varName);
SDVariable v = variables.get(varName).getVariable();
switch(v.getVariableType()){
case VARIABLE:
if(!variablesArrays.containsKey(varName)) {
//VARIABLE type arrays should have a parameter initializer...
// we should use this to lazy init the array if none is present
v.storeAndAllocateNewArray();
}
return variablesArrays.get(varName).get();
case CONSTANT:
if(!constantArrays.containsKey(varName))
return null;
return constantArrays.get(varName).get();
case ARRAY:
//Only stored in inference session...
InferenceSession s = sessions.get(Thread.currentThread().getId());
if(s == null)
return null;
return s.get(varName, InferenceSession.OUTER_FRAME, 0, null, false);
case PLACEHOLDER:
long tid = Thread.currentThread().getId();
if(placeholdersPerThread.get(tid) == null || !placeholdersPerThread.get(tid).containsKey(varName))
return null;
return placeholdersPerThread.get(tid).get(varName);
default:
throw new RuntimeException("Unknown variable type: " + v.getVariableType());
}
}
/**
* Associate the array with the given variable.
*
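* For example (a sketch; assumes {@code sd} is this SameDiff instance and a variable named "w" has already been defined):
* <pre>{@code
* sd.associateArrayWithVariable(Nd4j.ones(3, 4), "w");
* }</pre>
*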
* @param arr the array to associate with the variable
* @param variable the name of the variable to associate the array with
*/
public void associateArrayWithVariable(INDArray arr, @NonNull String variable) {
Preconditions.checkState(variables.containsKey(variable), "Cannot associate array with variable \"%s\": " +
"variable \"%s\" does not exist in this SameDiff instance", variable, variable);
associateArrayWithVariable(arr, this.getVariable(variable));
}
/**
* Associate the array with the given variable.
*
* @param arr the array to associate with the variable
* @param variable the variable to associate the array with
*/
public void associateArrayWithVariable(INDArray arr, SDVariable variable) {
if (variable == null) {
throw new ND4JIllegalArgumentException("Variable must not be null!");
}
if (arr == null) {
throw new ND4JIllegalArgumentException("Array must not be null");
}
if (variable.dataType() != arr.dataType())
arr = arr.castTo(variable.dataType());
Preconditions.checkState(variable.dataType() == arr.dataType(), "Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable",
variable.getVarName(), variable.dataType(), arr.dataType());
// FIXME: remove this before release
if (sessions.get(Thread.currentThread().getId()) == null) {
sessions.put(Thread.currentThread().getId(), new InferenceSession(this));
}
boolean duped = false;
if(arr.isAttached()) {
arr = arr.detach();
duped = true;
}
if(arr.isView()) {
arr = arr.dup();
duped = true;
}
if(!duped && variable.getVariableType() == VariableType.VARIABLE) {
for (DeviceLocalNDArray otherArr : variablesArrays.values()) {
if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
arr = arr.dup();
break;
}
}
}
switch(variable.getVariableType()){
case VARIABLE:
variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
break;
case CONSTANT:
constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
break;
case ARRAY:
// FIXME: remove this before release
val session = sessions.get(Thread.currentThread().getId());
val varId = session.newVarId(variable.getVarName(), AbstractSession.OUTER_FRAME, 0, null);
session.getNodeOutputs().put(varId, arr);
//throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY");
break;
case PLACEHOLDER:
long tid = Thread.currentThread().getId();
if(!placeholdersPerThread.containsKey(tid)){
placeholdersPerThread.put(tid, new HashMap<String, INDArray>());
}
placeholdersPerThread.get(tid).put(variable.getVarName(), arr);
break;
default:
throw new IllegalStateException("Unknown variable type: " + variable.getVariableType());
}
//putOrUpdateShapeForVarName(variable.getVarName(), arr.shape(), true);
//Also update nested SameDiff instances (such as gradient function)
if(sameDiffFunctionInstances != null && sameDiffFunctionInstances.size() > 0){
for(Map.Entry<String,SameDiff> e : sameDiffFunctionInstances.entrySet()){
SameDiff sd = e.getValue();
SDVariable v = sd.getVariable(variable.getVarName());
if(v != null){
sd.associateArrayWithVariable(arr, v);
}
}
}
}
/**
* Associate a {@link SameDiff} namespace as a sub function.
*
* @param name the opName of the function
* @param nameSpace the namespace
*/
public void putSubFunction(String name, SameDiff nameSpace) {
if (sameDiffFunctionInstances.containsKey(name) && sameDiffFunctionInstances.get(name) != nameSpace) {
throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
}
sameDiffFunctionInstances.put(name, nameSpace);
}
/**
* Return a copy of the internal variable map
*
* @return Map of variables by name
*/
public Map<String,SDVariable> variableMap() {
Map<String,SDVariable> ret = new LinkedHashMap<>();
for(Variable v : variables.values()){
ret.put(v.getName(), v.getVariable());
}
return ret;
}
/**
* Invoke an op by opName
*
* @param op the op
* @param x the first input
* @param y the second input
* @return the result variable
*/
@Deprecated //TO BE REMOVED - should not be part of public API
public SDVariable invoke(Op op, SDVariable x, SDVariable y) {
if (!opMethods.containsKey(op.opName())) {
throw new ND4JIllegalStateException("Illegal method opName " + op.opName());
}
if (x != null && y != null) {
try {
return (SDVariable) opMethods.get(op.opName()).invoke(this, x, y);
} catch (Exception e) {
}
} else {
try {
return (SDVariable) opMethods.get(op.opName()).invoke(this, x);
} catch (Exception e) {
}
}
throw new ND4JIllegalStateException("Illegal method opName " + op.opName());
}
/**
* The set of defined SameDiff function names. SameDiff function instances should not be confused
* with DifferentialFunction ops; an example of a SameDiff function instance is the gradient "grad" function
*
* @return Set of defined SameDiff function instance names
*/
public Collection<String> definedFunctionNames() {
return this.sameDiffFunctionInstances.keySet();
}
/**
* Invoke an op by opName
*
* @param op the op
* @param x the first input
* @return the result variable
*/
public SDVariable invoke(Op op, SDVariable x) {
return invoke(op, x, null);
}
private SameDiff() {
functionFactory = new DifferentialFunctionFactory(this);
sameDiffFunctionDefinitionMap = new LinkedHashMap<>();
sameDiffFunctionInstances = new LinkedHashMap<>();
forwardVarForGrad = new LinkedHashMap<>();
opsForResult = new IntArrayKeyMap<>();
variableNameToShape = new LinkedHashMap<>();
placeHolderOriginalShapes = new LinkedHashMap<>();
placeHolderFunctions = new LinkedHashSet<>();
baseNameForFunctionInstanceId = new LinkedHashMap<>();
propertiesToResolve = new LinkedHashMap<>();
propertiesForFunction = new LinkedHashMap<>();
fieldVariableResolutionMapping = HashBasedTable.create();
}
/**
* Adds a property that needs to be resolved later.
* These are typically named arrays
* whose values are not known until execution time.
*
* This is very common for model import.
*
* @param forFunction the function to add the property to resolve for
* @param arrayName the array name
*/
public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName) {
if (!propertiesToResolve.containsKey(forFunction.getOwnName())) {
List<String> newVal = new ArrayList<>();
newVal.add(arrayName);
propertiesToResolve.put(forFunction.getOwnName(), newVal);
} else {
List<String> newVal = propertiesToResolve.get(forFunction.getOwnName());
newVal.add(arrayName);
}
}
/**
* Return the properties to resolve for the given function.
* This is typically used right before execution in model import in
* {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()}
*
* @param function the function to get the properties to resolve for
* @return the properties to resolve for the given function
*/
public List<String> propertiesToResolveForFunction(DifferentialFunction function) {
if (!propertiesToResolve.containsKey(function.getOwnName()))
return Collections.emptyList();
return propertiesToResolve.get(function.getOwnName());
}
private void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, Object propertyValue) {
if (!propertiesForFunction.containsKey(functionFor.getOwnName())) {
Map<String,Object> fields = new LinkedHashMap<>();
fields.put(propertyName, propertyValue);
propertiesForFunction.put(functionFor.getOwnName(), fields);
} else {
val fieldMap = propertiesForFunction.get(functionFor.getOwnName());
if (fieldMap.containsKey(propertyName)) {
throw new ND4JIllegalStateException("Attempting to override property " + propertyName);
}
fieldMap.put(propertyName, propertyValue);
}
}
/**
* Adds a field name -> variable name mapping for a given function.
* This is used for model import where there is an unresolved variable at the time of calling any
* {@link org.nd4j.imports.graphmapper.GraphMapper#importGraph(File)}
* .
*
* This data structure is typically accessed during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()}
*
* When a function attempts to resolve variables right before execution, there needs to be a way of knowing
* which variable in a samediff graph should map to a function's particular field name
*
* @param function the function to map
* @param fieldName the field name for the function to map
* @param varName the variable name of the array to get from samediff
*/
public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName) {
fieldVariableResolutionMapping.put(function.getOwnName(), fieldName, varName);
}
/**
* Get the variable name to use
* for resolving a given field
* for a given function during import time.
* This method is used during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()}
*
* @param function the function to get the variable name for
* @param fieldName the field name to resolve for
* @return the resolve variable name if any
*/
public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName) {
return fieldVariableResolutionMapping.get(function.getOwnName(), fieldName);
}
/**
* Sets a base name for the function id.
* This is used when calling {@link #generateOutputVariableForOp(DifferentialFunction, String)}
* for ensuring original names for model import map to current samediff names
* when names are generated.
*
* @param baseName the base name to add
* @param function the function to declare a base name for.
*/
public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function) {
baseNameForFunctionInstanceId.put(function.getOwnName(), baseName);
}
/**
* Returns the base name for the given function
* if any (may return null)
*
* @param function the function to get the base name for
* @return the base name for the given function (if any) based
* on the function's instance id.
*/
public String getBaseNameForFunction(DifferentialFunction function) {
return baseNameForFunctionInstanceId.get(function.getOwnName());
}
/**
* Attempts to insert the {@link DifferentialFunction} reference in to this {@link SameDiff} instance.
* If the given array field with the given index already exists, it will do a reference check to ensure that the 2
* array fields are the same. If not, an exception is thrown.
* If the instances are the same (by semantics, not reference) then it will just return the original instance.
* This is to ensure that instances that are created are unique and reference checked.
*
* @param function the array field to attempt to create
* @return Original instance
*/
public <X extends SDVariable> X setupFunction(X function) {
Preconditions.checkNotNull(function, "Passed in function must not be null!");
if (function instanceof SDVariable) {
if (function.getSameDiff() != this) {
function.setSameDiff(this);
}
return function;
}
return function;
}
/**
* Adds outgoing arguments to the graph for the specified DifferentialFunction
* Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared.
*
* @param variables Variables that are the outputs of the specified differential function
* @param function Differential function
*/
public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) {
String[] varNames = new String[variables.length];
for (int i = 0; i < varNames.length; i++) {
varNames[i] = variables[i].getVarName();
}
addOutgoingFor(varNames, function);
}
/**
* Adds outgoing arguments to the graph for the specified DifferentialFunction
* Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared.
*
* @param varNames Name of the variables that are outputs of the specified differential function
* @param function Differential function
*/
public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
if (function.getOwnName() == null)
throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
if (ops.get(function.getOwnName()).getOutputsOfOp() != null && !ops.get(function.getOwnName()).getOutputsOfOp().isEmpty()) {
throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
}
if (varNames == null)
throw new ND4JIllegalStateException("Var names can not be null!");
for (int i = 0; i < varNames.length; i++) {
if (varNames[i] == null)
throw new ND4JIllegalStateException("Variable name elements can not be null!");
}
ops.get(function.getOwnName()).setOutputsOfOp(Arrays.asList(varNames));
for (String resultName : varNames) {
variables.get(resultName).setOutputOfOp(function.getOwnName());
}
}
/**
* Adds incoming arguments for the specified differential function to the graph
*
* @param variables Name of the variables that are arguments (inputs) to the specified function
* @param function Function
*/
public void addArgsFor(String[] variables, DifferentialFunction function) {
if (function.getOwnName() == null)
throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
//double check if function contains placeholder args
for (val varName : variables) {
if (isPlaceHolder(varName)) {
placeHolderFunctions.add(function.getOwnName());
}
}
//Add function if it doesn't exist
//TODO could "not existing" be a bug sometimes?
if(!ops.containsKey(function.getOwnName())){
ops.put(function.getOwnName(), SameDiffOp.builder().name(function.getOwnName()).op(function).build());
}
//Update variable 'inputs to op' accounting for repeated inputs (like y = x+x)
ops.get(function.getOwnName()).setInputsToOp(Arrays.asList(variables)); //Duplicate variables OK/required here
for (String variableName : variables) {
List<String> funcs = this.variables.get(variableName).getInputsForOp();
if (funcs == null) {
funcs = new ArrayList<>();
this.variables.get(variableName).setInputsForOp(funcs);
}
if(!funcs.contains(function.getOwnName())) //Avoid duplicates for function names.
funcs.add(function.getOwnName());
}
}
/**
* Adds incoming arguments for the specified differential function to the graph
*
* @param variables variables that are arguments (inputs) to the specified function
* @param function Function
*/
public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
String[] varNames = new String[variables.length];
for (int i = 0; i < varNames.length; i++) {
if (variables[i] == null)
throw new ND4JIllegalStateException("Found null variable at index " + i);
varNames[i] = variables[i].getVarName();
}
addArgsFor(varNames, function);
}
/**
* Get the differential function (if any) that this variable is the output for
*
* @param variableName Name of the variable
* @return The differential function that this variable is an output of, or null if it is not the output of a function
*/
public DifferentialFunction getVariableOutputFunction(String variableName) {
Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
if(variables.get(variableName).getOutputOfOp() == null)
return null;
return ops.get(variables.get(variableName).getOutputOfOp()).getOp();
}
/**
* Returns true if this function already has defined arguments
*
* @param function the function to check
* @return true if the function has args, false otherwise
*/
public boolean hasArgs(DifferentialFunction function) {
List<String> vertexIdArgs = ops.get(function.getOwnName()).getInputsToOp();
return vertexIdArgs != null && vertexIdArgs.size() > 0;
}
/**
* Get an array of differential functions that have been defined for this SameDiff instance
* @return Array of differential functions
*/
public DifferentialFunction[] functions() {
List<DifferentialFunction> out = new ArrayList<>(ops.size());
for(SameDiffOp op : ops.values()){
out.add(op.getOp());
}
return out.toArray(new DifferentialFunction[out.size()]);
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + (variables != null ? variables.hashCode() : 0);
return result;
}
/**
* Create a new SameDiff instance from an existing instance.
* Note that state (variables and functions) is shared between the two SameDiff instances
*
* @param originalSameDiff Original SameDiff instance
* @return Copy
*/
public static SameDiff create(SameDiff originalSameDiff) {
SameDiff ret = SameDiff.builder()
.sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances)
.build();
ret.variables.putAll(originalSameDiff.variables);
//ensuring proper sameDiff reference
DifferentialFunctionFactory differentialFunctionFactory = new DifferentialFunctionFactory(ret);
ret.functionFactory = differentialFunctionFactory;
return ret;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SameDiff sameDiff = (SameDiff) o;
if (variables != null ? !variables.equals(sameDiff.variables) : sameDiff.variables != null)
return false;
if (sameDiffFunctionDefinitionMap != null ? !sameDiffFunctionDefinitionMap.equals(sameDiff.sameDiffFunctionDefinitionMap) : sameDiff.sameDiffFunctionDefinitionMap != null)
return false;
return sameDiffFunctionInstances != null ? sameDiffFunctionInstances.equals(sameDiff.sameDiffFunctionInstances) : sameDiff.sameDiffFunctionInstances == null;
}
/**
* Create a new (empty) SameDiff instance without any functions or variables
* @return New SameDiff instance
*/
public static SameDiff create() {
return new SameDiff();
}
/**
* Clone/duplicate the SameDiff instance, including arrays etc. The returned SameDiff instance should have no
* shared state with the original instance
* @return The cloned SameDiff instance
*/
public SameDiff dup() {
Cloner cloner = newCloner();
SameDiff clone = cloner.deepClone(this);
//TODO don't clone sessions in the first place!
clone.sessions.clear();
return clone;
}
/**
* Count the number of elements in all arrays, according to {@link SDVariable#getShape()}
* @return Number of array elements for all variables
*/
public long numElements() {
long ret = 0;
for (SDVariable variable : variables()) {
long[] shape = variable.getShape();
if(shape != null) {
ret += ArrayUtil.prod(shape);
}
}
return ret;
}
/**
* Returns the inputs (placeholders)
* for the samediff graph
* @return the inputs for this graph
*/
public List<String> inputs() {
List<String> out = new ArrayList<>();
for(String s : variables.keySet()){
if(isPlaceHolder(s))
out.add(s);
}
return out;
}
/**
* Outputs are those variables (not placeholders, constants, etc.) that are the output of a function and are not the
* input to any other ops.
* Usually these are the output of the last function(s) in the SameDiff instance.
* @return The (inferred) outputs of the SameDiff instance, in no particular order
*/
public List<String> outputs(){
List<String> out = new ArrayList<>();
for(Variable v : variables.values()){
if(v.getVariable().isConstant() || v.getVariable().isPlaceHolder() || //Exclude constants and placeholders
(v.getInputsForOp() != null && !v.getInputsForOp().isEmpty()) || //Exclude variables that are inputs to ops
(v.getControlDepsForOp() != null && !v.getControlDepsForOp().isEmpty()) || //Exclude variables that are control dependency inputs to ops
(v.getControlDepsForVar() != null && !v.getControlDepsForVar().isEmpty())) { //Exclude variables that are control dependency inputs to other variables (mainly for import of cond etc ops)
continue;
}
//Also exclude assert etc ops - doesn't make sense to return these "outputs" to user
if(v.getOutputOfOp() != null){
String opName = v.getOutputOfOp();
SameDiffOp o = ops.get(opName);
if(o.getOp() instanceof Assert){
continue;
}
//A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed
// by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here
if(o.getOp() instanceof Switch){
continue;
}
}
out.add(v.getName());
}
return out;
}
/**
* The list of all variables in the graph
*
* @return All variables in the graph
*/
public List<SDVariable> variables() {
return new ArrayList<>(variableMap().values());
}
/**
* Get the names of variables (if any) that have been marked as loss variables to be minimized.
* Variables can be marked as loss variables in a few different ways:
* (a) Losses are automatically added when creating loss functions via {@link #sd()}
* (b) Via {@link #setLossVariables(String...)}, {@link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}
* (c) Via {@link TrainingConfig#setLossVariables(List)}
*/
public List<String> getLossVariables(){
return Collections.unmodifiableList(this.lossVariables);
}
/**
* Clear/remove any existing loss variables, and set the loss variables to the specified variable names.
* See {@link #addLossVariable(String)} for more details
* @param lossVariableNames Names of variables to be loss function variables
*/
public void setLossVariables(String... lossVariableNames){
this.lossVariables.clear();
for(String s : lossVariableNames){
addLossVariable(s);
}
//After changing loss function variables, we (probably) need to recreate gradient function - as gradient
// function is defined with respect to specific loss function variables
sameDiffFunctionInstances.remove("grad");
}
/**
* Mark the specified variable as a loss function variable. This means that this variable will be minimized via backprop during training.
* This will add the variable as a loss to any others - i.e., if multiple variables are marked as losses, their values will be summed
* to give the total network loss.
* Note that only floating point (Float16/32/64) variables may be marked as a loss.
* Note also that only ARRAY type SDVariables can be marked as losses to be minimized. That is, we cannot mark the value
* of a constant, variable or placeholder to be minimized as doing so would not make sense.
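*
* A brief sketch (assumes {@code sd} is this SameDiff instance and an ARRAY type variable named "loss" exists in the graph):
* <pre>{@code
* sd.addLossVariable("loss");
* //Equivalently: lossVar.markAsLoss();
* }</pre>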
*/
public void addLossVariable(@NonNull String variableName){
Preconditions.checkState(hasVariable(variableName), "No variable with name \"%s\" exists", variableName);
SDVariable v = getVariable(variableName);
Preconditions.checkState(v.dataType().isFPType(), "Only floating point type variables can be marked as losses to be minimized." +
" SDVariable \"%s\" has datatype %s", variableName, v.dataType());
Preconditions.checkState(v.getVariableType() == VariableType.ARRAY, "Only ARRAY type SDVariables can be marked as losses to be minimized." +
" SDVariable \"%s\" has variable type %s", variableName, v.getVariableType());
if(!lossVariables.contains(variableName)){
lossVariables.add(variableName);
}
}
/**
* Set the training configuration ({@link TrainingConfig}) for the SameDiff instance.
* A TrainingConfig must be set before the SameDiff instance can be trained via the fit methods
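*
* A configuration sketch (the builder method names used here - {@code updater}, {@code dataSetFeatureMapping},
* {@code dataSetLabelMapping} - are assumptions; see {@link TrainingConfig} for the actual options):
* <pre>{@code
* TrainingConfig config = new TrainingConfig.Builder()
*         .updater(new Adam(1e-3))             //How parameters are updated from gradients
*         .dataSetFeatureMapping("input")      //Placeholder(s) to which DataSet features are assigned
*         .dataSetLabelMapping("label")        //Placeholder(s) to which DataSet labels are assigned
*         .build();
* sd.setTrainingConfig(config);
* }</pre>
*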
* @param trainingConfig Training configuration
*/
public void setTrainingConfig(TrainingConfig trainingConfig){
this.trainingConfig = trainingConfig;
}
/**
* Fit the SameDiff instance based on a single DataSet (i.e., a single minibatch for one iteration).
* This method can only be used for single input, single output SameDiff instances as DataSet only supports a
* single input and a single output.
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can
* be performed.
*
* @param dataSet The DataSet (single minibatch) to perform training on
*/
public void fit(DataSet dataSet){
fit(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), 1, false);
}
/**
* Fit the SameDiff instance based on a single MultiDataSet (i.e., a single minibatch for one iteration).
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can
* be performed.
*
* @param dataSet The DataSet (single minibatch) to perform training on
*/
public void fit(MultiDataSet dataSet){
fit(new SingletonMultiDataSetIterator(dataSet), 1, false);
}
/**
* Fit the SameDiff instance based on DataSetIterator for the specified number of epochs.
* This method can only be used for single input, single output SameDiff instances as DataSet only supports a
* single input and a single output.
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can
* be performed.
*
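* A minimal training sketch (assumes a {@link TrainingConfig} has been set via {@link #setTrainingConfig(TrainingConfig)}
* and that the graph has a single feature placeholder and a single label placeholder):
* <pre>{@code
* DataSetIterator iter = ...;   //Any single-input, single-output DataSetIterator
* sd.fit(iter, 5);              //Train for 5 epochs
* }</pre>
*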
* @param iter The iterator to train the SameDiff instance with
* @param numEpochs The number of epochs for training. Must be > 0
*/
public void fit(DataSetIterator iter, int numEpochs) {
fit(new MultiDataSetIteratorAdapter(iter), numEpochs, true);
}
/**
* Fit the SameDiff instance based on MultiDataSetIterator for the specified number of epochs.
* This method can be used with both single input, single output and multi-input, multi-output SameDiff instances
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can
* be performed.
*
* @param iter The iterator to train the SameDiff instance with
* @param numEpochs The number of epochs for training. Must be > 0
*/
public void fit(MultiDataSetIterator iter, int numEpochs){
fit(iter, numEpochs, true);
}
//Synchronized for thread safety
protected synchronized void fit(MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount){
Preconditions.checkNotNull(iter, "Iterator must not be null");
Preconditions.checkState(numEpochs > 0, "Number of training epochs must be a positive number. Got: %s", numEpochs);
Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " +
"be set before training. Use setTrainingConfig(TrainingConfig)");
Preconditions.checkState(numEpochs == 1 || iter.resetSupported(), "Cannot train for multiple epochs on an iterator that" +
" does not support resetting");
if(!iter.hasNext() && iter.resetSupported())
iter.reset();
boolean performedValidation = false;
for(int i = 0; i < numEpochs; i++) {
while (iter.hasNext()) {
org.nd4j.linalg.dataset.api.MultiDataSet ds = iter.next();
if(!performedValidation){
Preconditions.checkState(trainingConfig.getDataSetFeatureMapping().size() == ds.numFeatureArrays(),
"The number of dataset feature mapping variables set in the training configuration (%s) must match" +
" the number of dataset feature arrays (%s)", trainingConfig.getDataSetFeatureMapping().size(), ds.numFeatureArrays());
List<String> labelMapping = trainingConfig.getDataSetLabelMapping();
int lblSize = labelMapping == null ? 0 : labelMapping.size();
Preconditions.checkState(lblSize == ds.numLabelsArrays(),
"The number of dataset label mapping variables set in the training configuration (%s) must match" +
" the number of dataset label arrays (%s)", lblSize, ds.numLabelsArrays());
performedValidation = true;
}
//Create placeholder variable map
Map<String,INDArray> placeholders = toPlaceholderMap(ds);
Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training");
resolveVariablesWith(placeholders);
//Calculate gradients:
execBackwards(placeholders);
//Apply updater:
if (!initializedTraining)
initializeTraining();
int iteration = trainingConfig.getIterationCount();
int e = trainingConfig.getEpochCount();
for (String s : trainingConfig.getTrainableParams()) {
//TODO fix using inference session
INDArray param = variables.get(s).getVariable().getArr();
SDVariable gradVar = variables.get(s).getVariable().getGradient();
if(gradVar == null){
//Not all trainable parameters have gradients defined.
//Consider graph: in1->loss1; in2->loss2, where we optimize only loss1.
//No gradient will be present for in2, because in2 doesn't impact loss1 at all
continue;
}
INDArray grad = gradVar.getArr();
//Note: don't need to divide by minibatch - that should be handled in loss function and hence loss function gradients,
// which should flow through to here
//Pre-apply regularization (L1, L2)
List<Regularization> r = trainingConfig.getRegularization();
int iterCount = trainingConfig.getIterationCount();
int epochCount = trainingConfig.getEpochCount();
double lr = trainingConfig.getUpdater().hasLearningRate() ? trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0;
if(r != null && r.size() > 0){
for(Regularization reg : r){
if(reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER){
reg.apply(param, grad, lr, iterCount, epochCount);
}
}
}
//Apply updater. Note that we need to reshape to [1,length] for updater
INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f'); //TODO make sure we always reshape in same order!
Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", s);
GradientUpdater u = updaterMap.get(s);
try {
u.applyUpdater(reshapedView, iteration, e);
} catch (Throwable t) {
throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + s
+ "\": either parameter size is inconsistent between iterations, or \"" + s + "\" should not be a trainable parameter?", t);
}
//Post-apply regularization (weight decay)
if(r != null && r.size() > 0){
for(Regularization reg : r){
if(reg.applyStep() == Regularization.ApplyStep.POST_UPDATER){
reg.apply(param, grad, lr, iterCount, epochCount);
}
}
}
if (trainingConfig.isMinimize()) {
param.subi(grad);
} else {
param.addi(grad);
}
}
trainingConfig.incrementIterationCount();
}
if(i < numEpochs - 1) {
iter.reset();
}
if(incrementEpochCount)
trainingConfig.incrementEpochCount();
}
}
/**
* Calculate the regularization (L1, L2 and/or WeightDecay) component of the loss function for the current parameters.
* Note that the training configuration must be set (via {@link #setTrainingConfig(TrainingConfig)}) before this
* method can be called
*
* @return The regularization component of the score/loss function
*/
public double calcRegularizationScore() {
Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " +
"be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)");
if(trainingConfig.getRegularization() == null || trainingConfig.getRegularization().isEmpty()){
return 0.0;
}
if(trainingConfig.getTrainableParams() == null || trainingConfig.getTrainableParams().isEmpty())
initializeTraining();
List<Regularization> l = trainingConfig.getRegularization();
double loss = 0.0;
for (String s : trainingConfig.getTrainableParams()) {
for(Regularization r : l){
INDArray arr = getVariable(s).getArr();
loss += r.score(arr, trainingConfig.getIterationCount(), trainingConfig.getEpochCount());
}
}
return loss;
}
/**
* Perform setup for training. Does the following:
* 1. Infer the set of trainable parameters - unless specified manually by the user
* 2. Set up the updaters
*/
protected void initializeTraining(){
if(!initializedTraining) {
if(trainingConfig == null) {
throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig");
}
//First: infer the variables to be optimized if required
if(trainingConfig.getTrainableParams() == null || trainingConfig.getTrainableParams().size() == 0) {
//Variable is trainable if it's not the output of some function
//TODO also - should be floating point type
List<String> trainVarList = new ArrayList<>();
for(Variable var : variables.values()){
SDVariable v = var.getVariable();
String n = v.getVarName();
if(variables.get(n).getOutputOfOp() == null && //Is a leaf (not the output of a function)
!isPlaceHolder(n) && //and not a placeholder
!variables.get(n).getVariable().isConstant() && //and not a constant
(trainingConfig.getDataSetFeatureMapping() == null || !trainingConfig.getDataSetFeatureMapping().contains(n)) && //and not an input (this really should be a placeholder, but we can't guarantee that...)
(trainingConfig.getDataSetLabelMapping() == null || !trainingConfig.getDataSetLabelMapping().contains(n)) && //and not a label (this really should be a placeholder, but we can't guarantee that...)
(trainingConfig.getDataSetFeatureMaskMapping() == null || !trainingConfig.getDataSetFeatureMaskMapping().contains(n)) && //and not a feature mask (this really should be a placeholder, but we can't guarantee that...)
(trainingConfig.getDataSetLabelMaskMapping() == null || !trainingConfig.getDataSetLabelMaskMapping().contains(n))){ //and not a label mask (this really should be a placeholder, but we can't guarantee that...)
trainVarList.add(n);
}
}
trainingConfig.setTrainableParams(trainVarList);
log.info("Inferred trainable variables: {}", trainVarList);
}
//Allocate updater state
long numTrainableParams = 0;
DataType dt = null; //TODO support mixed precision variables - https://github.com/deeplearning4j/deeplearning4j/issues/6992
for(String s : trainingConfig.getTrainableParams()) {
SDVariable v = variables.get(s).getVariable();
Preconditions.checkState(v != null, "No variable found for trainable parameter name \"%s\"", s);
INDArray arr = v.getArr();
Preconditions.checkState(arr != null, "No array found for trainable parameter \"%s\"", s);
numTrainableParams += arr.length();
if(dt == null)
dt = arr.dataType();
}
long updaterStateSize = trainingConfig.getUpdater().stateSize(numTrainableParams);
if(updaterStateSize > 0) {
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
updaterState = Nd4j.createUninitialized(dt, 1, updaterStateSize);
}
}
long viewSoFar = 0;
updaterViews = new HashMap<>();
updaterMap = new HashMap<>();
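//Carve the flat updater state array into one contiguous view per trainable parameter, in iteration order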
for(String s : trainingConfig.getTrainableParams()) {
long thisSize = trainingConfig.getUpdater().stateSize(variables.get(s).getVariable().getArr().length());
INDArray view = (updaterStateSize == 0 || thisSize == 0 ? null :
updaterState.get(NDArrayIndex.interval(0, 1), NDArrayIndex.interval(viewSoFar, viewSoFar + thisSize)));
updaterViews.put(s, view);
updaterMap.put(s, trainingConfig.getUpdater().instantiate(view, true));
viewSoFar += thisSize;
}
initializedTraining = true;
}
}
/**
* Convert the MultiDataSet to a {@code Map<String,INDArray>} based on the TrainingConfig settings.
* The key is the placeholder/variable that the value INDArray should be associated with.
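*
* For example, if the training config has {@code dataSetFeatureMapping = ["input"]} and
* {@code dataSetLabelMapping = ["label"]} (illustrative names), the returned map contains
* {@code "input" -> ds.getFeatures(0)} and {@code "label" -> ds.getLabels(0)}.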
*
* @param ds MultiDataSet - source of the features/labels
* @return MultiDataSet converted to a Map, based on TrainingConfig
*/
private Map<String,INDArray> toPlaceholderMap(org.nd4j.linalg.dataset.api.MultiDataSet ds) {
Map<String,INDArray> placeholders = new HashMap<>();
int count = 0;
for(String s : trainingConfig.getDataSetFeatureMapping()){
placeholders.put(s, ds.getFeatures(count++));
}
count = 0;
if(trainingConfig.getDataSetLabelMapping() != null) {
//Labels may be null in some models (unsupervised etc)
for (String s : trainingConfig.getDataSetLabelMapping()) {
placeholders.put(s, ds.getLabels(count++));
}
}
if(trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().size() > 0){
count = 0;
for(String s : trainingConfig.getDataSetFeatureMaskMapping()){
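//A null entry means there is no feature mask placeholder for this index - skip it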
if(s == null) {
count++;
continue;
}
placeholders.put(s, ds.getFeaturesMaskArray(count++));
}
}
if(trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().size() > 0){
count = 0;
for(String s : trainingConfig.getDataSetLabelMaskMapping()){
if(s == null) {
count++;
continue;
}
placeholders.put(s, ds.getLabelsMaskArray(count++));
}
}
return placeholders;
}
/**
* Evaluate the performance of a single variable's prediction.
* For example, if the variable to evaluate was called "softmax" you would use:
*
* {@code Evaluation e = new Evaluation();
* sameDiff.evaluate(iterator, "softmax", e);}
*
*
* @param iterator Iterator as source of data to evaluate
* @param outputVariable The variable to evaluate
* @param evaluations The evaluations to perform
*/
public void evaluate(DataSetIterator iterator, String outputVariable, IEvaluation... evaluations) {
Preconditions.checkArgument(evaluations != null && evaluations.length > 0, "No evaluations were passed to the evaluate method");
evaluate(new MultiDataSetIteratorAdapter(iterator), Collections.singletonMap(outputVariable, Arrays.asList(evaluations)),
Collections.singletonMap(outputVariable, 0));
}
/**
* Evaluation for multiple-output networks.
* See {@link #evaluate(MultiDataSetIterator, Map, Map)}
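*
* A usage sketch (illustrative variable name "softmax"):
* {@code
* Map<String,IEvaluation> evals = Collections.singletonMap("softmax", new Evaluation());
* sameDiff.evaluate(iterator, evals);
* }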
*/
public void evaluate(DataSetIterator iterator, Map<String,IEvaluation> variableEvals){
Map<String,Integer> map = new HashMap<>();
Map<String,List<IEvaluation>> variableEvalsList = new HashMap<>();
for(String s : variableEvals.keySet()){
map.put(s, 0); //Only 1 possible output here with DataSetIterator
variableEvalsList.put(s, Collections.singletonList(variableEvals.get(s)));
}
evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvalsList, map);
}
/**
* Evaluation for multiple output networks - one or more evaluations per output variable.
* See {@link #evaluate(MultiDataSetIterator, Map, Map)}
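*
* A usage sketch (illustrative variable name "softmax"):
* {@code
* Map<String,List<IEvaluation>> evals = Collections.singletonMap("softmax", Arrays.<IEvaluation>asList(new Evaluation()));
* sameDiff.evaluateMultiple(iterator, evals);
* }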
*/
public void evaluateMultiple(DataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals){
Map<String,Integer> map = new HashMap<>();
for(String s : variableEvals.keySet()){
map.put(s, 0); //Only 1 possible output here with DataSetIterator
}
evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvals, map);
}
/**
* Evaluate the performance of a single variable's prediction.
* For example, if the variable to evaluate was called "softmax" you would use:
*
* {@code Evaluation e = new Evaluation();
* sameDiff.evaluate(iterator, "softmax", e);}
*
*
* @param iterator Iterator as source of data to evaluate
* @param outputVariable The variable to evaluate
* @param labelIndex The index of the target variable's labels in the iterator
* @param evaluations The evaluations to perform
*/
public void evaluate(MultiDataSetIterator iterator, String outputVariable, int labelIndex, IEvaluation... evaluations) {
Preconditions.checkArgument(evaluations != null && evaluations.length > 0, "No evaluations were passed to the evaluate method");
evaluate(iterator, Collections.singletonMap(outputVariable, Arrays.asList(evaluations)),
Collections.singletonMap(outputVariable, labelIndex));
}
/**
* Perform evaluation using classes such as {@link org.nd4j.evaluation.classification.Evaluation} for classifier outputs
* and {@link org.nd4j.evaluation.regression.RegressionEvaluation} for regression outputs.
*
* Example: classifier evaluation
* Predictions variable name: "softmaxOutput"
* Evaluations to perform: {@link org.nd4j.evaluation.classification.Evaluation}
* Data: single input, single output MultiDataSets
* Code:
*
* {@code
* MultiDataSetIterator data = ...
* Map<String,List<IEvaluation>> evals = Collections.singletonMap("softmaxOutput",Collections.singletonList(new Evaluation()));
* Map<String,Integer> labelMapping = Collections.singletonMap("softmaxOutput",0); //Compare: "softmaxOutput" vs. MultiDataSet.getLabels(0)
*
* sameDiff.evaluate(data, evals, labelMapping);
* }
*
*
* @param iterator The iterator - the source of the data for evaluation
* @param variableEvals The evaluations to perform. Key: the name of the variable. Value: the evaluations to perform
* @param predictionLabelMapping The output/label mapping. Key: the name of the variable. Value: the index of the corresponding label array in the MultiDataSet
*/
public void evaluate(MultiDataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals, Map<String,Integer> predictionLabelMapping){
Preconditions.checkState(trainingConfig != null, "Training config has not been set");
Preconditions.checkState(variableEvals.keySet().equals(predictionLabelMapping.keySet()), "Keysets for variable evaluations" +
" and for the prediction label mapping must be equal. Keys for variables to evaluate: %s vs. keys for label mapping: %s", variableEvals.keySet(), predictionLabelMapping.keySet());
if(!iterator.hasNext() && iterator.resetSupported())
iterator.reset();
List<String> reqVars = new ArrayList<>(variableEvals.keySet());
while(iterator.hasNext()){
MultiDataSet ds = iterator.next();
Map<String,INDArray> placeholderMap = toPlaceholderMap(ds);
Map<String,INDArray> m = exec(placeholderMap, reqVars);
for(Map.Entry<String,List<IEvaluation>> e : variableEvals.entrySet()){
INDArray prediction = m.get(e.getKey());
for(IEvaluation eval : e.getValue()){
//TODO masking, time series, etc
INDArray label = ds.getLabels(predictionLabelMapping.get(e.getKey()));
eval.eval(label, prediction);
}
}
}
}
/**
* Do inference on a network with a single input.
* For example, if the variable to infer was called "softmax" you would use:
*
* {@code
* sameDiff.output(iterator, "softmax");}
*
*
* @param dataSet The data to evaluate
* @param outputs The variables to evaluate
*/
public Map<String,INDArray> output(DataSet dataSet, String... outputs){
return output(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0);
}
/**
* Do inference on a network with a single input.
* For example, if the variable to infer was called "softmax" you would use:
*
* {@code
* sameDiff.output(iterator, "softmax");}
*
*
* @param iterator Iterator as source of data to evaluate
* @param outputs The variables to evaluate
*/
public List