
// org.nd4j.autodiff.samediff.SameDiff (source listing: Maven / Gradle / Ivy)
package org.nd4j.autodiff.samediff;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import com.rits.cloning.Cloner;
import com.rits.cloning.IFastCloner;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.BytePointer;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.functions.FunctionProperties;
import org.nd4j.autodiff.samediff.flow.FlowPath;
import org.nd4j.autodiff.util.cloner.DataBufferFastCloner;
import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.*;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.distances.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.accum.distances.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.collection.IntArrayKeyMap;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.*;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import java.io.*;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* SameDiff is the
* entrypoint for
* nd4j's autodiff.
*
* You define a graph symbolically.
*
* That graph accumulates operations.
*
* To execute the graph, call one of:
* <ul>
* <li>{@link #exec()} to run the graph and get all the operations</li>
* <li>{@link #exec(List)} to run an already created set of ops</li>
* <li>{@link #execAndEndResult()} to get the end result only</li>
* <li>{@link #execAndEndResult(List)} to get the end result for a cached set of ops</li>
* </ul>
*/
@AllArgsConstructor
@Builder
@Slf4j
public class SameDiff {
// NOTE(review): several declarations below are raw (`Map`, `Set`) or read `Map> name`;
// the generic type arguments appear to have been stripped from this listing and should be
// restored from the original file before compiling.
// argument-name array -> function; filled by addArgsFor
private Map incomingArgs;
// output-name array -> function; filled by addOutgoingFor
private Map outgoingArgs;
// function own name -> argument names; filled by addArgsFor
private Map incomingArgsReverse;
// function own name -> output names; filled by addOutgoingFor
private Map outgoingArgsReverse;
private Map permuteOrder;
private boolean shouldBootStrap = true;
// variable names marked through addVarNameForImport
private Set importedVarName;
//map a function's instance id to a base name, used for propagating variable names
//for output during import
private Map baseNameForFunctionInstanceId;
private DifferentialFunctionFactory functionFactory;
// variable name -> variable
private Map variableMap;
// variable name -> shape (see putShapeForVarName / updateShapeForVarName)
private Map variableNameToShape;
//gradient information
private Map gradients;
private Map forwardVarForGrad;
// variable name -> array (see putArrayForVarName / associateArrayWithVariable)
private Map variableNameToArr;
//individual index for variable names
// variable name -> functions that take the variable as an argument
private Map> functionsArgsFor;
// variable name -> functions that declare the variable as an output
private Map> functionOutputFor;
// this entity holds runtime information for Switch/Merge/NextIteration etc stuff
private ThreadLocal localFlowPath = new ThreadLocal();
/**
* For import, many times we have variables
* that map to properties. Most common
* we will have an input to a function that is mapped to an ndarray.
* That ndarray is usually a scalar shape.
*
* That array with a scalar shape can be something like an axis.
*
* We often don't know that array's value till run time.
* This map stores variable names that we should resolve
* from samediff. We use the value of that array
* to update the properties.
*/
private Map> propertiesToResolve;
/**
* A map of own name to
* the properties of the function (things like execution axes etc)
* The valid values can be:
* int
* long
* INDArray
*/
private Map> propertiesForFunction;
private Map> placeHolderMap;
// variable name -> shape that contained a dimension < 1 when it was registered
private Map placeHolderOriginalShapes;
private Set placeHolderVarNames;
// array (compared by reference, this is an IdentityHashMap) -> variable
private IdentityHashMap reverseArrayLookup;
private MemoryWorkspace workspace;
private Map sameDiffFunctionDefinitionMap;
// sub-function name -> SameDiff namespace (see putSubFunction)
private Map sameDiffFunctionInstances;
// own names of functions that have at least one placeholder argument (see addArgsFor)
private Set placeHolderFunctions;
private static Cloner cloner = newCloner();
// method name -> SameDiff method returning SDVariable; filled by the static initializer
private static Map opMethods;
// function own name -> function
private Map functionInstancesById;
// (function own name, field name) -> variable name (see addVariableMappingForField)
private Table fieldVariableResolutionMapping;
// flag, shows if graph was already registered with libnd4j
private transient AtomicBoolean wasRegistered = new AtomicBoolean(false);
//debug mode variables
@Getter
private boolean debugMode;
private Map opsForResult;
private boolean resolvedVariables = false;
@Getter
@Setter
boolean logExecution = true;
static {
opMethods = new HashMap<>();
Method[] methods = SameDiff.class.getDeclaredMethods();
for (Method method : methods) {
if (method.getReturnType().equals(SDVariable.class)) {
opMethods.put(method.getName(), method);
}
}
}
/**
 * Builds a {@link Cloner} with fast cloners registered for the backend's ndarray classes,
 * the data buffer factory's buffer classes and {@link CompressedDataBuffer}.
 *
 * @return the configured cloner
 */
public static Cloner newCloner() {
    Cloner result = new Cloner();
    // INDArrays get custom cloning (the default can have problems with off-heap memory and pointers).
    // The cloner library does NOT support registering against an interface, so
    // result.registerFastCloner(INDArray.class, ...) does not work; the concrete classes are used instead.
    IFastCloner arrayCloner = new INDArrayFastCloner();
    result.registerFastCloner(Nd4j.getBackend().getNDArrayClass(), arrayCloner);
    result.registerFastCloner(Nd4j.getBackend().getComplexNDArrayClass(), arrayCloner);
    // Same for DataBuffers: the cloner library chokes on off-heap buffers, and again the concrete
    // buffer classes are needed rather than the interface.
    IFastCloner bufferCloner = new DataBufferFastCloner();
    DataBufferFactory bufferFactory = Nd4j.getDataBufferFactory();
    doReg(result, bufferCloner, bufferFactory.intBufferClass());
    doReg(result, bufferCloner, bufferFactory.longBufferClass());
    doReg(result, bufferCloner, bufferFactory.halfBufferClass());
    doReg(result, bufferCloner, bufferFactory.floatBufferClass());
    doReg(result, bufferCloner, bufferFactory.doubleBufferClass());
    doReg(result, bufferCloner, CompressedDataBuffer.class);
    return result;
}
private static void doReg(Cloner cl, IFastCloner fc, Class> c) {
if (c != null)
cl.registerFastCloner(c, fc);
}
/**
 * Renames a variable and rewrites the name-keyed bookkeeping that refers to it:
 * the variable map, the incoming/outgoing argument name arrays, the array, shape,
 * gradient, forward-variable and placeholder maps, and the x/y/z vertex ids of any
 * {@link BaseOp} registered as using or producing the variable.
 *
 * @param varName  the current name of the variable
 * @param withName the new name
 * @throws ND4JIllegalStateException if {@link #getVariable(String)} returns null for {@code varName}
 */
public void updateVariableName(String varName, String withName) {
    SDVariable oldVarNameRef = getVariable(varName);
    if (oldVarNameRef == null) {
        // previously this surfaced as a bare NullPointerException on the next line
        throw new ND4JIllegalStateException("No variable found with name " + varName);
    }
    // Remove both possible old keys BEFORE re-inserting. The second removal used to run at the
    // very end of the method, which deleted the variable again when varName equals withName.
    variableMap.remove(oldVarNameRef.getVarName());
    variableMap.remove(varName);
    oldVarNameRef.setVarName(withName);
    variableMap.put(withName, oldVarNameRef);

    // the name arrays are edited in place
    for (String[] names : outgoingArgsReverse.values()) {
        renameInArray(names, varName, withName);
    }
    for (String[] names : incomingArgsReverse.values()) {
        renameInArray(names, varName, withName);
    }

    moveKey(variableNameToArr, varName, withName);
    moveKey(variableNameToShape, varName, withName);
    moveKey(gradients, varName, withName);
    moveKey(forwardVarForGrad, varName, withName);
    moveKey(placeHolderMap, varName, withName);

    if (functionsArgsFor.containsKey(varName)) {
        val funcs = functionsArgsFor.remove(varName);
        renameVertexIds(funcs, varName, withName);
        functionsArgsFor.put(withName, funcs);
    }
    if (functionOutputFor.containsKey(varName)) {
        val funcs = functionOutputFor.remove(varName);
        renameVertexIds(funcs, varName, withName);
        functionOutputFor.put(withName, funcs);
    }
}

/** Replaces every element of {@code names} equal to {@code oldName} with {@code newName}, in place. */
private static void renameInArray(String[] names, String oldName, String newName) {
    for (int i = 0; i < names.length; i++) {
        if (names[i].equals(oldName)) {
            names[i] = newName;
        }
    }
}

/** Moves the entry stored under {@code from} (if the key is present) to the key {@code to}. */
private static <V> void moveKey(Map<String, V> map, String from, String to) {
    if (map.containsKey(from)) {
        V value = map.remove(from);
        map.put(to, value);
    }
}

/** Points the x/y/z vertex ids of every {@link BaseOp} in {@code funcs} that equal {@code oldName} at {@code newName}. */
private static void renameVertexIds(Iterable<? extends DifferentialFunction> funcs, String oldName, String newName) {
    for (DifferentialFunction func : funcs) {
        if (!(func instanceof BaseOp)) {
            continue;
        }
        BaseOp baseOp = (BaseOp) func;
        if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldName)) {
            baseOp.setXVertexId(newName);
        }
        if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldName)) {
            baseOp.setYVertexId(newName);
        }
        if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldName)) {
            baseOp.setZVertexId(newName);
        }
    }
}
/**
 * Switches debug mode off.
 *
 * @return this instance, for chaining
 */
public SameDiff disableDebugging() {
    this.debugMode = false;
    return this;
}
/**
 * Switches debug mode on.
 *
 * @return this instance, for chaining
 */
public SameDiff enableDebugMode() {
    this.debugMode = true;
    return this;
}
/**
 * Accessor for the {@link DifferentialFunctionFactory} held by this instance.
 *
 * @return the function factory of this samediff instance
 */
public DifferentialFunctionFactory f() {
    return this.functionFactory;
}
/**
 * Copies this graph's variables, functions and array reverse-lookup entries into
 * {@code sameDiff}.
 *
 * Each variable is cloned and added to the target (together with its array, when it has one).
 * Each non-variable function is cloned, its arguments and outputs are declared on the target,
 * and it is registered in the target's function map under its own name.
 *
 * @param sameDiff the target instance to copy this graph into
 * @return the last entry of {@code sameDiff.variables()} after the copy
 */
public SDVariable invokeGraphOn(SameDiff sameDiff) {
    // copy the variables (the unused vertex-id map and counter of the old code are gone)
    for (val variable : variables()) {
        val clone = cloner.deepCloneDontCloneInstances(variable, variable.getSameDiff());
        val newVar = sameDiff.var(clone);
        if (variable.getArr() != null) {
            sameDiff.associateArrayWithVariable(variable.getArr(), newVar);
        }
        clone.setSameDiff(sameDiff);
    }

    for (DifferentialFunction function : functionInstancesById.values()) {
        if (function instanceof SDVariable) {
            continue;
        }
        DifferentialFunction clone = cloner.deepCloneDontCloneInstances(
                function,
                function.getSameDiff());
        clone.setSameDiff(sameDiff);
        clone.setOwnName(function.getOwnName());
        // putFunctionForId throws when the id is already registered, so this call only ever
        // acts as a duplicate-id guard for the target graph.
        if (sameDiff.functionExists(function.getOwnName())) {
            sameDiff.putFunctionForId(function.getOwnName(), function);
        }
        val argsForFunction = function.args();
        val outputsForFunction = function.outputVariables();
        //note that these have the same variable names
        sameDiff.addArgsFor(argsForFunction, clone);
        // NOTE(review): outputs and the function map below are registered with the ORIGINAL
        // function while the arguments use the clone — confirm which one is intended.
        sameDiff.addOutgoingFor(outputsForFunction, function);
        for (val arg : clone.args()) {
            arg.setSameDiff(sameDiff);
        }
        for (val output : clone.outputVariables()) {
            output.setSameDiff(sameDiff);
        }
        sameDiff.functionInstancesById.put(function.getOwnName(), function);
    }

    // re-point every known array at the target's variable of the same name
    for (val reverseArrayEntry : reverseArrayLookup.entrySet()) {
        sameDiff.reverseArrayLookup.put(reverseArrayEntry.getKey(), sameDiff.getVariable(reverseArrayEntry.getValue().getVarName()));
    }
    return sameDiff.variables().get(sameDiff.variables().size() - 1);
}
/**
 * Checks whether a function is registered under the given id.
 *
 * @param id the function id to look for
 * @return true when the function map has an entry for {@code id}
 */
public boolean functionExists(String id) {
    final boolean registered = functionInstancesById.containsKey(id);
    return registered;
}
/**
 * Looks up a function by its {@link DifferentialFunction#getOwnName()}.
 *
 * @param id the id of the function
 * @return the function registered under {@code id}
 * @throws ND4JIllegalStateException when nothing is registered under {@code id}
 */
public DifferentialFunction getFunctionById(String id) {
    if (functionInstancesById.containsKey(id)) {
        return functionInstancesById.get(id);
    }
    throw new ND4JIllegalStateException("No function with id " + id + " found!");
}
/**
 * Registers a function under an id that is not in use yet.
 *
 * @param id       the id to register under
 * @param function the function to register; must not be an {@link SDVariable}
 * @throws ND4JIllegalStateException when the id is already taken or the function is a variable
 */
public void putFunctionForId(String id, DifferentialFunction function) {
    // the duplicate-id check comes first, as before
    if (functionInstancesById.containsKey(id)) {
        throw new ND4JIllegalStateException("Function by id already exists!");
    }
    if (function instanceof SDVariable) {
        throw new ND4JIllegalStateException("Function must not be a variable!");
    }
    functionInstancesById.put(id, function);
}
/**
 * Returns the argument names declared for a function.
 *
 * @param function the function to get the inputs for
 * @return the input names registered under the function's own name
 * @throws ND4JIllegalStateException when no arguments were declared for the function
 */
public String[] getInputsForFunction(DifferentialFunction function) {
    if (incomingArgsReverse.containsKey(function.getOwnName())) {
        return incomingArgsReverse.get(function.getOwnName());
    }
    throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
}
/**
 * Returns the output names declared for a function.
 *
 * @param function the function to get the outputs for
 * @return the output names registered under the function's own name,
 *         or null when none were declared
 */
public String[] getOutputsForFunction(DifferentialFunction function) {
    final String ownName = function.getOwnName();
    return outgoingArgsReverse.get(ownName);
}
/**
 * Resolves the output names from {@link #getOutputsForFunction(DifferentialFunction)}
 * to variables.
 *
 * @param function the function to get the output variables for
 * @return the output variables of the function, in declaration order; an element is
 *         whatever {@link #getVariable(String)} returns for that name
 * @throws ND4JIllegalStateException when no outputs were declared for the function
 */
public SDVariable[] getOutputVariablesForFunction(DifferentialFunction function) {
    String[] outputs = getOutputsForFunction(function);
    if (outputs == null) {
        // the message used to say "No inputs found", which pointed at the wrong side of the function
        throw new ND4JIllegalStateException("No outputs found for function " + function);
    }
    SDVariable[] vars = new SDVariable[outputs.length];
    for (int i = 0; i < outputs.length; i++) {
        vars[i] = getVariable(outputs[i]);
    }
    return vars;
}
/**
 * Resolves the input names from {@link #getInputsForFunction(DifferentialFunction)}
 * to variables.
 *
 * @param function the function to get the input variables for
 * @return the input variables of the function, in declaration order
 * @throws ND4JIllegalStateException when the inputs are null or one of the names
 *                                   does not resolve to a variable
 */
public SDVariable[] getInputVariablesForFunction(DifferentialFunction function) {
    String[] names = getInputsForFunction(function);
    if (names == null) {
        throw new ND4JIllegalStateException("No inputs found for function " + function);
    }
    SDVariable[] resolved = new SDVariable[names.length];
    int index = 0;
    while (index < names.length) {
        SDVariable current = getVariable(names[index]);
        resolved[index] = current;
        if (current == null) {
            throw new ND4JIllegalStateException("Found null variable at index " + index);
        }
        index++;
    }
    return resolved;
}
/**
 * Replaces the array stored for an existing variable name and records the new array
 * in the array-to-variable reverse lookup.
 * Use {@link #putArrayForVarName(String, INDArray)} to add an array for a name that has none.
 *
 * @param varName the variable name to update
 * @param arr     the new array
 * @throws ND4JIllegalStateException when no array is stored for {@code varName} yet
 */
public void updateArrayForVarName(String varName, INDArray arr) {
    if (!variableNameToArr.containsKey(varName)) {
        // the message used to refer to putArrayForVertexId, which is not a method of this class
        throw new ND4JIllegalStateException("Array for " + varName + " does not exist. Please use putArrayForVarName instead.");
    }
    variableNameToArr.put(varName, arr);
    reverseArrayLookup.put(arr, getVariable(varName));
}
/**
 * Stores an array for a variable name that has no array yet.
 * Use {@link #updateArrayForVarName(String, INDArray)} when an array is already stored.
 *
 * @param varName the variable name to add the array for
 * @param arr     the array to add
 * @throws ND4JIllegalStateException when {@code varName} is null or already has an array
 */
public void putArrayForVarName(String varName, INDArray arr) {
    if (varName == null) {
        throw new ND4JIllegalStateException("No null names allowed!");
    }
    final boolean alreadyStored = variableNameToArr.containsKey(varName);
    if (alreadyStored) {
        throw new ND4JIllegalStateException("Array for " + varName + " already exists!");
    }
    variableNameToArr.put(varName, arr);
}
/**
 * Returns the shape known for a variable name.
 * When an array is stored for the name, the array's shape is returned;
 * otherwise the entry of the shape map is returned.
 *
 * @param varName the variable name to get the shape for
 * @return the shape for the name, or null when the shape map has no entry for it
 */
public int[] getShapeForVarName(String varName) {
    final boolean hasArray = variableNameToArr.containsKey(varName);
    return hasArray
            ? variableNameToArr.get(varName).shape()
            : variableNameToShape.get(varName);
}
/**
 * Replaces the shape stored for a variable name.
 * Use {@link #putShapeForVarName(String, int[])} to add a shape for a new name.
 *
 * A shape with any dimension below 1 is not stored in the shape map: the name is
 * registered as a placeholder instead and the shape is kept as its original shape.
 *
 * @param varName the variable name to update
 * @param shape   the new shape
 * @throws ND4JIllegalStateException when {@code shape} is null, or an array is stored
 *                                   for the name whose shape differs from {@code shape}
 */
public void updateShapeForVarName(String varName, int[] shape) {
    if (shape == null) {
        throw new ND4JIllegalStateException("Null shapes not allowed!");
    }
    final boolean hasArray = variableNameToArr.containsKey(varName);
    if (hasArray && !Arrays.equals(variableNameToArr.get(varName).shape(), shape)) {
        throw new ND4JIllegalStateException("Already found an existing array!");
    }
    for (int dimension : shape) {
        if (dimension < 1) {
            addAsPlaceHolder(varName);
            placeHolderOriginalShapes.put(varName, shape);
            return;
        }
    }
    variableNameToShape.put(varName, shape);
}
/**
 * Stores a shape for a variable name that has no entry in the shape map yet.
 *
 * A shape with any dimension below 1 is not stored in the shape map: the name is
 * registered as a placeholder instead and the shape is kept as its original shape.
 *
 * @param varName the variable name to associate
 * @param shape   the shape to associate with the name
 * @throws ND4JIllegalStateException when {@code shape} is null or the shape map
 *                                   already has an entry for the name
 */
public void putShapeForVarName(String varName, int[] shape) {
    if (shape == null) {
        throw new ND4JIllegalStateException("Shape must not be null!");
    }
    final boolean alreadyStored = variableNameToShape.containsKey(varName);
    if (alreadyStored) {
        throw new ND4JIllegalStateException("Shape for " + varName + " already exists!");
    }
    for (int dimension : shape) {
        if (dimension < 1) {
            addAsPlaceHolder(varName);
            placeHolderOriginalShapes.put(varName, shape);
            return;
        }
    }
    variableNameToShape.put(varName, shape);
}
/**
 * Checks whether a shape is known for a variable name, either through the shape map
 * or through a stored array.
 *
 * @param varName the variable name
 * @return true when the shape map or the array map has an entry for the name
 */
public boolean shapeAlreadyExistsForVarName(String varName) {
    if (variableNameToShape.containsKey(varName)) {
        return true;
    }
    return arrayAlreadyExistsForVarName(varName);
}
/**
 * Checks whether an {@link INDArray} is stored for a variable name.
 *
 * @param varName the variable name
 * @return true when the array map has an entry for the name
 */
public boolean arrayAlreadyExistsForVarName(String varName) {
    final boolean stored = variableNameToArr.containsKey(varName);
    return stored;
}
/**
 * Returns the {@link INDArray} stored for a variable name.
 *
 * @param varName the variable name
 * @return the stored array, or null when the array map has no entry for the name
 */
public INDArray getArrForVarName(String varName) {
    return this.variableNameToArr.get(varName);
}
/**
 * Associates an array with a variable: records the array in the reverse lookup and the
 * array map, and stores or updates the shape for the variable's name.
 *
 * @param arr      the array to associate
 * @param variable the variable to associate the array with
 * @throws ND4JIllegalArgumentException when {@code variable} or {@code arr} is null
 */
public void associateArrayWithVariable(INDArray arr, SDVariable variable) {
    if (variable == null) {
        throw new ND4JIllegalArgumentException("Variable must not be null!");
    }
    if (arr == null) {
        throw new ND4JIllegalArgumentException("Array must not be null");
    }
    final String varName = variable.getVarName();
    // Decide put-vs-update BEFORE the array is registered. The check used to run after the
    // array was stored, which made it always true and left putShapeForVarName unreachable.
    final boolean shapeKnown = shapeAlreadyExistsForVarName(varName);
    reverseArrayLookup.put(arr, variable);
    variableNameToArr.put(varName, arr);
    if (shapeKnown) {
        updateShapeForVarName(varName, arr.shape());
    } else {
        putShapeForVarName(varName, arr.shape());
    }
}
/**
 * Registers a {@link SameDiff} namespace as a sub function under a name.
 * Registering the same instance again under the same name is allowed.
 *
 * @param name      the name of the sub function
 * @param nameSpace the namespace to register
 * @throws ND4JIllegalStateException when a different instance is already registered under the name
 */
public void putSubFunction(String name, SameDiff nameSpace) {
    final boolean nameTaken = sameDiffFunctionInstances.containsKey(name);
    if (nameTaken && sameDiffFunctionInstances.get(name) != nameSpace) {
        throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
    }
    sameDiffFunctionInstances.put(name, nameSpace);
}
/**
 * Exposes the internal variable map itself (not a copy).
 *
 * @return the map of variable name to variable
 */
public Map variableMap() {
    return this.variableMap;
}
/**
 * Invokes the SameDiff method whose name equals {@code op.opName()} reflectively.
 * The two-argument form is used when both {@code x} and {@code y} are non-null,
 * otherwise the method is invoked with {@code x} only.
 *
 * @param op the op whose name selects the method
 * @param x  the first input
 * @param y  the second input (may be null)
 * @return the result variable
 * @throws ND4JIllegalStateException when no method is indexed under the op name, or the
 *                                   reflective call fails (the failure is attached as the cause)
 */
public SDVariable invoke(Op op, SDVariable x, SDVariable y) {
    if (!opMethods.containsKey(op.opName())) {
        throw new ND4JIllegalStateException("Illegal method opName " + op.opName());
    }
    try {
        if (x != null && y != null) {
            return (SDVariable) opMethods.get(op.opName()).invoke(this, x, y);
        }
        return (SDVariable) opMethods.get(op.opName()).invoke(this, x);
    } catch (Exception e) {
        // The failure used to be swallowed by an empty catch block. Same exception type and
        // message as before, but the underlying error now travels along as the cause.
        // initCause is used because only the message constructor of the exception is visible here.
        ND4JIllegalStateException failure = new ND4JIllegalStateException("Illegal method opName " + op.opName());
        failure.initCause(e);
        throw failure;
    }
}
/**
 * Returns the {@link SDVariable} recorded for an array reference.
 * The lookup is by reference, not by array contents.
 *
 * @param arr the array reference
 * @return the variable recorded for the array, or null when there is none
 */
public SDVariable getVariableForArray(INDArray arr) {
    return this.reverseArrayLookup.get(arr);
}
/**
 * Returns the names under which sub functions are registered.
 *
 * @return the key set of the sub function map
 */
public Collection definedFunctionNames() {
    return sameDiffFunctionInstances.keySet();
}
/**
* Returns the number of bytes
* for the graph, computed as {@code numElements()} multiplied by the
* element length of the current {@code Nd4j.dataType()}.
*
* @return the product of the element count and the per-element length
*/
public long memoryForGraph() {
// NOTE(review): if numElements() returns int, this product is computed in int before widening — confirm
return numElements() * DataTypeUtil.lengthForDtype(Nd4j.dataType());
}
/**
 * Single-input form of {@link #invoke(Op, SDVariable, SDVariable)}:
 * delegates with a null second input.
 *
 * @param op the op whose name selects the method
 * @param x  the only input
 * @return the result variable
 */
public SDVariable invoke(Op op, SDVariable x) {
    final SDVariable noSecondInput = null;
    return invoke(op, x, noSecondInput);
}
/**
 * Creates an empty graph: a function factory bound to this instance and empty,
 * insertion-ordered collections for all bookkeeping.
 */
private SameDiff() {
    // created first, exactly as before, because the factory receives this instance
    this.functionFactory = new DifferentialFunctionFactory(this);

    // variables, arrays, shapes and gradients
    this.variableMap = new LinkedHashMap<>();
    this.variableNameToArr = new LinkedHashMap<>();
    this.variableNameToShape = new LinkedHashMap<>();
    this.reverseArrayLookup = new IdentityHashMap<>();
    this.gradients = new LinkedHashMap<>();
    this.forwardVarForGrad = new LinkedHashMap<>();
    this.importedVarName = new LinkedHashSet<>();
    this.permuteOrder = new LinkedHashMap<>();

    // placeholders
    this.placeHolderMap = new LinkedHashMap<>();
    this.placeHolderVarNames = new LinkedHashSet<>();
    this.placeHolderOriginalShapes = new LinkedHashMap<>();
    this.placeHolderFunctions = new LinkedHashSet<>();

    // functions and their arguments / outputs
    this.functionInstancesById = new LinkedHashMap<>();
    this.incomingArgs = new LinkedHashMap<>();
    this.outgoingArgs = new LinkedHashMap<>();
    this.incomingArgsReverse = new LinkedHashMap<>();
    this.outgoingArgsReverse = new LinkedHashMap<>();
    this.functionsArgsFor = new LinkedHashMap<>();
    this.functionOutputFor = new LinkedHashMap<>();
    this.baseNameForFunctionInstanceId = new LinkedHashMap<>();
    this.opsForResult = new IntArrayKeyMap<>();

    // sub functions
    this.sameDiffFunctionDefinitionMap = new LinkedHashMap<>();
    this.sameDiffFunctionInstances = new LinkedHashMap<>();

    // function properties and import-time resolution
    this.propertiesToResolve = new LinkedHashMap<>();
    this.propertiesForFunction = new LinkedHashMap<>();
    this.fieldVariableResolutionMapping = HashBasedTable.create();
}
/**
 * Records an array name whose value must be resolved later for a function.
 * The names are collected per function own name, in insertion order.
 *
 * This is very common for model import.
 *
 * @param forFunction the function to add the property to resolve for
 * @param arrayName   the array name
 */
public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName) {
    if (propertiesToResolve.containsKey(forFunction.getOwnName())) {
        propertiesToResolve.get(forFunction.getOwnName()).add(arrayName);
        return;
    }
    List<String> pending = new ArrayList<>();
    pending.add(arrayName);
    propertiesToResolve.put(forFunction.getOwnName(), pending);
}
/**
 * Returns the array names recorded through
 * {@link #addPropertyToResolve(DifferentialFunction, String)} for a function.
 *
 * @param function the function to get the properties to resolve for
 * @return the recorded names, or an empty list when nothing was recorded for the function
 */
public List propertiesToResolveForFunction(DifferentialFunction function) {
    if (propertiesToResolve.containsKey(function.getOwnName())) {
        return propertiesToResolve.get(function.getOwnName());
    }
    return Collections.emptyList();
}
/**
 * Checks whether any property to resolve was recorded for a function.
 *
 * @param function the function to check
 * @return true when an entry exists for the function's own name
 */
public boolean hasPropertiesToResolve(DifferentialFunction function) {
    final String ownName = function.getOwnName();
    return propertiesToResolve.containsKey(ownName);
}
/**
 * Returns a property recorded for a function.
 *
 * @param functionInstance the function to get the property for
 * @param propertyName     the name of the property to get
 * @param <T>              the inferred return type; the stored value is cast to it unchecked
 * @return the stored value, or null when the function has no properties or no value
 *         is stored under {@code propertyName}
 */
public <T> T getPropertyForFunction(DifferentialFunction functionInstance, String propertyName) {
    // the `<T>` declaration was missing from the signature in this listing
    if (!propertiesForFunction.containsKey(functionInstance.getOwnName())) {
        return null;
    }
    val map = propertiesForFunction.get(functionInstance.getOwnName());
    @SuppressWarnings("unchecked") // the caller picks T; the value is handed back as stored
    T value = (T) map.get(propertyName);
    return value;
}
/**
 * Records an {@link INDArray} property for a function.
 *
 * @param functionFor  the function to add the property for
 * @param propertyName the property name
 * @param property     the property value
 */
public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, INDArray property) {
    final Object value = property;
    addPropertyForFunction(functionFor, propertyName, value);
}
/**
 * Records a {@code long} property for a function; the value is stored boxed.
 *
 * @param functionFor  the function to add the property for
 * @param propertyName the name of the property to add the value for
 * @param property     the property value to add
 */
public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, long property) {
    final Object boxed = (Object) property;
    addPropertyForFunction(functionFor, propertyName, boxed);
}
/**
 * Stores a property value under the function's own name and the property name.
 * A property can be set once; a second value for the same name is rejected.
 *
 * @throws ND4JIllegalStateException when the property is already set for the function
 */
private void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, Object propertyValue) {
    if (propertiesForFunction.containsKey(functionFor.getOwnName())) {
        val existing = propertiesForFunction.get(functionFor.getOwnName());
        if (existing.containsKey(propertyName)) {
            throw new ND4JIllegalStateException("Attempting to override property " + propertyName);
        }
        existing.put(propertyName, propertyValue);
        return;
    }
    // first property for this function
    Map<String, Object> created = new LinkedHashMap<>();
    created.put(propertyName, propertyValue);
    propertiesForFunction.put(functionFor.getOwnName(), created);
}
/**
 * Records which variable name a field of a function should be resolved from.
 * The mapping is stored under the function's own name and the field name.
 *
 * This is used for model import, where a variable is still unresolved at the time
 * {@link org.nd4j.imports.graphmapper.GraphMapper#importGraph(File)} is called, and is
 * typically read during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()}.
 *
 * @param function  the function to map
 * @param fieldName the field name of the function to map
 * @param varName   the variable name of the array to get from samediff
 */
public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName) {
    final String ownName = function.getOwnName();
    fieldVariableResolutionMapping.put(ownName, fieldName, varName);
}
/**
 * Returns the variable name recorded through
 * {@link #addVariableMappingForField(DifferentialFunction, String, String)}
 * for a function and field.
 * This method is used during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()}.
 *
 * @param function  the function to get the variable name for
 * @param fieldName the field name to resolve
 * @return the recorded variable name, if any
 */
public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName) {
    final String ownName = function.getOwnName();
    return fieldVariableResolutionMapping.get(ownName, fieldName);
}
/**
 * Checks whether a variable name was marked as imported.
 *
 * @param variableName the variable name to check
 * @return true when the name was added through {@link #addVarNameForImport(String)}
 */
public boolean isImportVariable(String variableName) {
    final boolean imported = importedVarName.contains(variableName);
    return imported;
}
/**
 * Marks a variable name as imported.
 * Used together with model import to ensure immutability when referencing
 * graph variables mapped from an external source.
 *
 * @param varName the variable name to mark
 */
public void addVarNameForImport(String varName) {
    this.importedVarName.add(varName);
}
/**
 * Records a base name for a function, keyed by the function's own name.
 * Used when calling {@link #generateOutputVariableForOp(DifferentialFunction, String)}
 * so that original names from model import map to the generated samediff names.
 *
 * @param baseName the base name to record
 * @param function the function to record the base name for
 */
public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function) {
    final String ownName = function.getOwnName();
    baseNameForFunctionInstanceId.put(ownName, baseName);
}
/**
 * Returns the base name recorded for a function.
 *
 * @param function the function to get the base name for
 * @return the base name stored under the function's own name, or null when none was recorded
 */
public String getBaseNameForFunction(DifferentialFunction function) {
    final String ownName = function.getOwnName();
    return baseNameForFunctionInstanceId.get(ownName);
}
/**
 * Binds the passed instance to this {@link SameDiff} and returns it.
 * When the instance is an {@link SDVariable} whose samediff reference is not this
 * instance, the reference is switched to this instance; nothing else is changed.
 *
 * @param function the instance to bind; must not be null
 * @param <X>      the concrete type passed in, returned unchanged
 * @return the same instance that was passed in
 */
public <X extends SDVariable> X setupFunction(X function) {
    // NOTE(review): the type-parameter declaration was missing from the signature in this
    // listing; `<X extends SDVariable>` is a reconstruction — confirm against the original file.
    Preconditions.checkNotNull(function, "Passed in function must not be null!");
    // both branches of the old code returned the argument, so only the rebinding is conditional
    if (function instanceof SDVariable && function.getSameDiff() != this) {
        function.setSameDiff(this);
    }
    return function;
}
/**
 * Declares the outputs of a function from variables: collects the variable names
 * and delegates to {@link #addOutgoingFor(String[], DifferentialFunction)}.
 *
 * @param variables the output variables
 * @param function  the function that produces them
 */
public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) {
    final String[] names = new String[variables.length];
    int index = 0;
    for (SDVariable variable : variables) {
        names[index++] = variable.getVarName();
    }
    addOutgoingFor(names, function);
}
/**
 * Declares the output names of a function.
 * The names are recorded under the function's own name, the function is recorded under
 * the name array, and the function is appended to the producers of every name.
 *
 * @param varNames the output names; neither the array nor its elements may be null
 * @param function the function that produces them
 * @throws ND4JIllegalStateException when the function has no own name, outputs were already
 *                                   declared for it, or a name (or the array) is null
 */
public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
    if (function.getOwnName() == null) {
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    }
    if (outgoingArgsReverse.containsKey(function.getOwnName())) {
        throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
    }
    if (varNames == null) {
        throw new ND4JIllegalStateException("Var names can not be null!");
    }
    for (String candidate : varNames) {
        if (candidate == null) {
            throw new ND4JIllegalStateException("Variable name elements can not be null!");
        }
    }

    outgoingArgsReverse.put(function.getOwnName(), varNames);
    outgoingArgs.put(varNames, function);

    for (String resultName : varNames) {
        List<DifferentialFunction> producers = functionOutputFor.get(resultName);
        if (producers == null) {
            producers = new ArrayList<>();
            functionOutputFor.put(resultName, producers);
        }
        producers.add(function);
    }
}
/**
 * Declares the argument names of a function.
 * A function with at least one placeholder argument is recorded as a placeholder function.
 * The function is recorded under the name array, the names under the function's own name,
 * and the function is appended to the consumers of every name.
 *
 * @param variables the argument names
 * @param function  the function that takes them
 * @throws ND4JIllegalStateException when the function has no own name
 */
public void addArgsFor(String[] variables, DifferentialFunction function) {
    if (function.getOwnName() == null) {
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    }

    // first pass: flag the function when any argument is a placeholder
    for (String candidate : variables) {
        if (isPlaceHolder(candidate)) {
            placeHolderFunctions.add(function.getOwnName());
        }
    }

    incomingArgs.put(variables, function);
    incomingArgsReverse.put(function.getOwnName(), variables);

    // second pass: index the function under each argument name
    for (String variableName : variables) {
        List<DifferentialFunction> consumers = functionsArgsFor.get(variableName);
        if (consumers == null) {
            consumers = new ArrayList<>();
            functionsArgsFor.put(variableName, consumers);
        }
        consumers.add(function);
    }
}
/**
 * Registers the given variables as the incoming (input) arguments of a function.
 * Variable names are collected in order and the call is forwarded to the
 * {@code String[]} overload.
 *
 * @param variables the input variables; no element may be null
 * @param function  the function consuming the inputs
 * @throws ND4JIllegalStateException if an element of {@code variables} is null
 */
public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
    final String[] names = new String[variables.length];
    for (int position = 0; position < names.length; position++) {
        final SDVariable current = variables[position];
        if (current == null) {
            throw new ND4JIllegalStateException("Found null variable at index " + position);
        }
        names[position] = current.getVarName();
    }
    addArgsFor(names, function);
}
/**
 * Reports whether {@code incomingArgs} contains the given key.
 * <p>
 * NOTE(review): the registration methods in this class store {@code String[]}
 * keys in {@code incomingArgs}, so a lookup with an {@code int[]} looks
 * unlikely to ever match — confirm whether this overload is still needed.
 *
 * @param function the key to look up
 * @return the result of {@code incomingArgs.containsKey(function)}
 */
public boolean hasArgs(int[] function) {
return incomingArgs.containsKey(function);
}
/**
 * Reports whether incoming arguments have been registered for a function.
 * The function's own name is resolved through {@code incomingArgsReverse}
 * and the resulting key must in turn resolve in {@code incomingArgs}.
 *
 * @param function the function to check
 * @return true if both lookups yield a non-null value, false otherwise
 */
public boolean hasArgs(DifferentialFunction function) {
    val registeredNames = incomingArgsReverse.get(function.getOwnName());
    if (registeredNames == null) {
        return false;
    }
    return incomingArgs.get(registeredNames) != null;
}
/**
 * Returns the values of {@code functionInstancesById} as an array.
 *
 * @return a new array holding every registered function instance
 */
public DifferentialFunction[] functions() {
    val instances = functionInstancesById.values();
    final DifferentialFunction[] target = new DifferentialFunction[instances.size()];
    return instances.toArray(target);
}
/**
 * Hash code that agrees with {@link #equals(Object)}.
 * <p>
 * {@code equals} compares instances purely by field values and never consults
 * the superclass, so mixing {@code super.hashCode()} into the result could give
 * two equal instances different hash codes. Only {@code variableMap}, which
 * {@code equals} also compares, contributes here.
 *
 * @return the hash of {@code variableMap}, or 0 when it is null
 */
@Override
public int hashCode() {
    return variableMap != null ? variableMap.hashCode() : 0;
}
/**
 * Builds a new instance from an existing one. The builder is handed the
 * original's {@code variableMap} and {@code sameDiffFunctionInstances}, and
 * the new instance receives its own {@link DifferentialFunctionFactory}.
 *
 * @param originalSameDiff the instance to take the variable map and function instances from
 * @return the newly built instance
 */
public static SameDiff create(SameDiff originalSameDiff) {
    final SameDiff created = SameDiff.builder()
            .variableMap(originalSameDiff.variableMap)
            .sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances)
            .build();
    // the factory has to reference the new instance, not the original one
    created.functionFactory = new DifferentialFunctionFactory(created);
    return created;
}
/**
 * Value equality: two instances of exactly the same class are equal when their
 * {@code variableMap}, {@code sameDiffFunctionDefinitionMap} and
 * {@code sameDiffFunctionInstances} are pairwise equal (null matches only null).
 *
 * @param o the object to compare against
 * @return true if {@code o} is equal to this instance
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }

    final SameDiff other = (SameDiff) o;

    if (variableMap == null) {
        if (other.variableMap != null) {
            return false;
        }
    } else if (!variableMap.equals(other.variableMap)) {
        return false;
    }

    if (sameDiffFunctionDefinitionMap == null) {
        if (other.sameDiffFunctionDefinitionMap != null) {
            return false;
        }
    } else if (!sameDiffFunctionDefinitionMap.equals(other.sameDiffFunctionDefinitionMap)) {
        return false;
    }

    if (sameDiffFunctionInstances == null) {
        return other.sameDiffFunctionInstances == null;
    }
    return sameDiffFunctionInstances.equals(other.sameDiffFunctionInstances);
}
/**
 * Static factory that creates an instance through the no-argument constructor.
 *
 * @return a new {@link SameDiff} instance
 */
public static SameDiff create() {
return new SameDiff();
}
/**
 * Executes the graph on a deep clone of this instance and collects one
 * output array per executed op.
 * <p>
 * NOTE(review): {@code inputs} is never read in this method body, so the
 * supplied values do not influence the result as written — confirm whether
 * they were meant to be bound to the clone before execution.
 *
 * @param inputs the inputs to evaluate (not used by the current implementation)
 * @return for every op in the executed list, the array the clone holds for
 *         the name of that op's first output variable
 * @throws IllegalStateException if the list of executed ops is empty
 */
public INDArray[] eval(Map inputs) {
// run on a clone so that this instance is not the one being executed
SameDiff execPipeline = dup();
// the right-hand element of exec()'s result is used as the list of executed ops
List opExecAction = execPipeline.exec().getRight();
if (opExecAction.isEmpty())
throw new IllegalStateException("No ops found to execute.");
INDArray[] ret = new INDArray[opExecAction.size()];
for (int i = 0; i < ret.length; i++) {
// only the first output variable of each op is collected
val varName = opExecAction.get(i).outputVariables()[0].getVarName();
ret[i] = execPipeline.getArrForVarName(varName);
}
return ret;
}
/**
 * Duplicates this instance by deep-cloning it with the cloner obtained
 * from {@code newCloner()}.
 *
 * @return the deep clone produced by the cloner
 */
public SameDiff dup() {
    return newCloner().deepClone(this);
}
/**
 * Sums, over every variable returned by {@link #variables()}, the product
 * of that variable's shape as computed by {@code ArrayUtil.prod}.
 *
 * @return the accumulated total
 */
public long numElements() {
    long total = 0;
    for (SDVariable current : variables()) {
        total += ArrayUtil.prod(current.getShape());
    }
    return total;
}
/**
 * Creates the workspace for this instance and makes it the workspace of the
 * current thread. The workspace is configured with an initial size of
 * {@code memoryForGraph()}, the {@code OVERALLOCATE} allocation policy and
 * the {@code FIRST_LOOP} learning policy.
 */
private void initWorkspace() {
    final WorkspaceConfiguration configuration = WorkspaceConfiguration.builder()
            .initialSize(memoryForGraph())
            .policyAllocation(AllocationPolicy.OVERALLOCATE)
            .policyLearning(LearningPolicy.FIRST_LOOP)
            .build();
    workspace = Nd4j.getWorkspaceManager().createNewWorkspace(configuration);
    Nd4j.getWorkspaceManager().setWorkspaceForCurrentThread(workspace);
}
/**
 * Returns the variables currently held in {@code variableMap}.
 *
 * @return a new {@link ArrayList} containing the map's values; changes to the
 *         returned list do not affect the map
 */
public List variables() {
return new ArrayList<>(variableMap.values());
}
/**
 * Variable initialization with 1.0: forwards to
 * {@code var(name, shape, new ConstantInitScheme('f', 1.0))}.
 *
 * @param name  the name of the variable
 * @param shape the shape of the array to be created
 * @return the variable returned by {@code var}
 */
public SDVariable one(String name, int[] shape) {
return var(name, shape, new ConstantInitScheme('f', 1.0));
}
/**
 * Return a variable of all 1s, with the same shape as the input.
 * Forwards to the named overload with a {@code null} name.
 *
 * @param input the variable forwarded to the named overload
 * @return the value returned by the named overload
 */
public SDVariable onesLike(SDVariable input) {
return onesLike(null, input);
}
/**
 * Return a variable of all 1s, with the same shape as the input.
 * Forwards directly to {@code f().onesLike(name, input)}.
 *
 * @param name  the name handed to the function factory
 * @param input the variable handed to the function factory
 * @return the variable produced by the function factory
 */
public SDVariable onesLike(String name, SDVariable input) {
return f().onesLike(name, input);
}
/**
 * Variable initialization with 0.0: forwards to
 * {@code var(name, shape, new ZeroInitScheme())}.
 *
 * @param name  the name of the variable
 * @param shape the shape of the array to be created
 * @return the variable returned by {@code var}
 */
public SDVariable zero(String name, int[] shape) {
return var(name, shape, new ZeroInitScheme());
}
/**
 * Return a variable of all 0s, with the same shape as the input.
 * Forwards to the named overload with a {@code null} name.
 *
 * @param input the variable forwarded to the named overload
 * @return the value returned by the named overload
 */
public SDVariable zerosLike(SDVariable input) {
return zerosLike(null, input);
}
/**
 * Return a variable of all 0s, with the same shape as the input.
 * Forwards directly to {@code f().zerosLike(name, input)}.
 *
 * @param name  the name handed to the function factory
 * @param input the variable handed to the function factory
 * @return the variable produced by the function factory
 */
public SDVariable zerosLike(String name, SDVariable input) {
return f().zerosLike(name, input);
}
/**
 * Variable initialization with a specified {@link WeightInitScheme}.
 * <p>
 * If a variable with this name is already registered and already has an
 * array, that variable is returned and the other arguments are ignored.
 * Otherwise a new variable is built, added via {@code addVariable} and stored
 * in {@code variableMap}. The workspace is initialized first if it has not
 * been created yet.
 *
 * @param name             the name of the variable
 * @param shape            the shape of the array to be created
 * @param weightInitScheme the weight init scheme
 * @return the existing or newly created variable
 * @throws IllegalArgumentException if {@code name} is null or empty
 */
public SDVariable var(String name, int[] shape, WeightInitScheme weightInitScheme) {
    // reuse a registered variable, but only once it is backed by an array
    if (variableMap.containsKey(name)) {
        final SDVariable existing = variableMap.get(name);
        if (existing.getArr() != null) {
            return existing;
        }
    }

    final boolean nameMissing = name == null || name.length() < 1;
    if (nameMissing) {
        throw new IllegalArgumentException("Name for variable must be defined");
    }

    if (workspace == null) {
        initWorkspace();
    }

    final SDVariable created = SDVariable.builder()
            .sameDiff(this)
            .shape(shape)
            .weightInitScheme(weightInitScheme)
            .varName(name)
            .build();

    addVariable(created);
    variableMap.put(name, created);
    return created;
}
/**
 * Creates a {@link SDVariable} with the given shape, using a
 * {@code ZeroInitScheme} as the weight init scheme. Forwards to the
 * three-argument {@code var} overload.
 *
 * @param name  the name of the variable
 * @param shape the shape of the variable
 * @return the variable returned by the three-argument overload
 */
public SDVariable var(String name, int[] shape) {
return var(name, shape, new ZeroInitScheme());
}
/**
 * Initialize a {@link SDVariable} reference tying the given variable to this
 * samediff instance.
 * <p>
 * If a variable with the same name is already registered and already has an
 * array, that variable is returned. Otherwise a new variable with the same
 * name and shape is built; its {@link NDArraySupplierInitScheme} lazily
 * creates the array of {@code arr} through {@code arr}'s own weight init
 * scheme when none exists yet, and then supplies {@code arr}'s array.
 * The new variable is stored in {@code variableMap}.
 *
 * @param arr the variable to mirror in this instance; must not be null
 * @return the existing or newly created variable
 * @throws IllegalArgumentException if {@code arr} is null or its name is null or empty
 */
public SDVariable var(final SDVariable arr) {
    // Validate before the first dereference: previously this check came after
    // arr.getVarName() had already been called, so it could never trigger.
    if (arr == null) {
        throw new IllegalArgumentException("Variable to register must not be null");
    }

    final String name = arr.getVarName();
    if (variableMap.containsKey(name) && variableMap.get(name).getArr() != null) {
        return variableMap.get(name);
    }

    if (name == null || name.length() < 1) {
        throw new IllegalArgumentException("Name for variable must be defined");
    }

    if (workspace == null) {
        initWorkspace();
    }

    final SDVariable ret = SDVariable.builder()
            .sameDiff(this)
            .shape(arr.getShape())
            .varName(name)
            .weightInitScheme(new NDArraySupplierInitScheme(new NDArraySupplierInitScheme.NDArraySupplier() {
                @Override
                public INDArray getArr() {
                    /*
                     * Pre allocate the array if it doesn't already exist.
                     * The reason we do this is to avoid race conditions with
                     * allocate().
                     */
                    if (arr.getArr() == null) {
                        INDArray retArr = arr.getWeightInitScheme().create(arr.getShape());
                        associateArrayWithVariable(retArr, arr);
                    }
                    return arr.getArr();
                }
            }))
            .build();

    variableMap.put(name, ret);
    return ret;
}
/**
 * Remove an argument for a function. Note that if this function
 * does not contain the argument, it will just be a no op.
 * <p>
 * When the argument is found, the function's registered input names are
 * replaced (in both {@code incomingArgs} and {@code incomingArgsReverse}) by
 * a copy that omits every entry equal to {@code varName}.
 *
 * @param varName  the variable name to remove
 * @param function the function to remove the argument from
 */
public void removeArgFromFunction(String varName, DifferentialFunction function) {
    val args = function.args();

    for (int i = 0; i < args.length; i++) {
        if (args[i].getVarName().equals(varName)) {
            /*
             * Since we are removing the variable reference
             * from the arguments we need to update both
             * the reverse and forward arguments.
             */
            val reverseArgs = incomingArgsReverse.get(function.getOwnName());
            incomingArgs.remove(reverseArgs);
            incomingArgsReverse.remove(function.getOwnName());

            // Iterate over the registered names themselves rather than indexing
            // them with the length of a different array.
            List<String> remaining = new ArrayList<>(args.length - 1);
            for (String registeredName : reverseArgs) {
                if (!registeredName.equals(varName)) {
                    remaining.add(registeredName);
                }
            }

            String[] newArgsArr = remaining.toArray(new String[remaining.size()]);
            incomingArgs.put(newArgsArr, function);
            incomingArgsReverse.put(function.getOwnName(), newArgsArr);
            //no further need to scan
            break;
        }
    }
}
/**
 * Creates a variable backed by the given array and registers it with this instance.
 * <p>
 * If a variable with this name is already registered and already has an
 * array, that variable is returned and {@code arr} is ignored.
 *
 * @param name the name of the variable; must not be null or empty
 * @param arr  the array for the variable; must not be null
 * @return the existing or newly created variable
 * @throws IllegalArgumentException if {@code name} is null or empty, or {@code arr} is null
 */
public SDVariable var(String name, INDArray arr) {
if (variableMap.containsKey(name) && variableMap.get(name).getArr() != null)
return variableMap.get(name);
if (name == null || name.length() < 1)
throw new IllegalArgumentException("Name for variable must be defined");
if (arr == null)
throw new IllegalArgumentException("Array for " + name + " must not be null");
if (workspace == null)
initWorkspace();
// the supplier below hands out the result of migrate(), while the
// bookkeeping calls further down keep using the original arr reference
val arrRef = arr.migrate();
SDVariable ret = SDVariable.builder()
.sameDiff(this)
.shape(arr.shape())
.varName(name)
.weightInitScheme(new NDArraySupplierInitScheme(new NDArraySupplierInitScheme.NDArraySupplier() {
@Override
public INDArray getArr() {
return arrRef;
}
}))
.build();
associateArrayWithVariable(arr, ret);
// a shape whose product is 1 is treated as a scalar: record its single value
if (ArrayUtil.prod(arr.shape()) == 1)
ret.setScalarValue(arr.getDouble(0));
addVariable(ret);
// only record the shape when none is registered for this name yet
if (getShapeForVarName(name) == null)
putShapeForVarName(name, arr.shape());
//ensure there is a reference to the array in the integer index
//this is used later for op creation
reverseArrayLookup.put(arr, ret);
variableMap.put(name, ret);
return ret;
}
/**
 * Looks up a variable by name in {@code variableMap}.
 *
 * @param name the name of the variable
 * @return the value stored in the map for {@code name}
 */
public SDVariable getVariable(String name) {
return variableMap.get(name);
}
/**
 * Looks up the gradient registered for a variable name in {@code gradients}.
 *
 * @param varName the name of the variable whose gradient is requested
 * @return the value stored in {@code gradients} for {@code varName}
 */
public SDVariable getGradForVariable(String varName) {
return gradients.get(varName);
}
/**
 * Stores the gradient variable for a variable name in {@code gradients}.
 *
 * @param variableName the name of the variable the gradient belongs to
 * @param variable     the gradient variable; must not be null
 * @throws ND4JIllegalStateException if {@code variable} is null
 */
public void setGradientForVariableName(String variableName, SDVariable variable) {
    if (variable != null) {
        gradients.put(variableName, variable);
        return;
    }
    throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + variableName);
}
/**
 * Looks up {@code forwardVarForGrad} using the given vertex id as the key.
 * <p>
 * NOTE(review): {@code setForwardVariableForVarName} fills the same map with
 * {@code String} keys, so an integer key looks unlikely to match any entry —
 * confirm whether this accessor is still valid.
 *
 * @param vertexId the vertex id used as the lookup key
 * @return the value stored in {@code forwardVarForGrad} for that key
 */
public SDVariable getForwardVariableForVertexId(int vertexId) {
return forwardVarForGrad.get(vertexId);
}
/**
 * Stores a forward variable in {@code forwardVarForGrad} under the given name.
 *
 * @param varName         the key to store the variable under
 * @param forwardVariable the variable to store
 */
public void setForwardVariableForVarName(String varName, SDVariable forwardVariable) {
forwardVarForGrad.put(varName, forwardVariable);
}
/**
 * Gradient with respect to the given variable name.
 * Note that in order to run this function,
 * {@link #execBackwards()} must be executed first.
 * All gradient functions are obtained within that time.
 * <p>
 * The variable is resolved inside the function registered under the name
 * {@code "grad"}, and the gradient is read from that same function.
 *
 * @param varName the variable name to get the gradient for
 * @return the gradient registered for the variable in the {@code "grad"} function
 * @throws IllegalStateException    if no {@code "grad"} function has been registered yet
 * @throws IllegalArgumentException if the {@code "grad"} function has no variable named {@code varName}
 */
public SDVariable grad(String varName) {
    if (!sameDiffFunctionInstances.containsKey("grad")) {
        throw new IllegalStateException("Unable to obtain gradient. Please run execBackwards() first.");
    }

    SameDiff grad = getFunction("grad");
    SDVariable var = grad.getVariable(varName);
    // fail with a descriptive message instead of a bare NullPointerException
    if (var == null) {
        throw new IllegalArgumentException("No variable named " + varName + " found in the gradient function");
    }
    return grad.getGradForVariable(var.getVarName());
}
/**
 * Average pooling 2d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs          the inputs to average pooling 2d
 * @param pooling2DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable avgPooling2d(SDVariable[] inputs, Pooling2DConfig pooling2DConfig) {
    return avgPooling2d(null, inputs, pooling2DConfig);
}

/**
 * Average pooling 2d operation. The op is created through the function factory
 * and its result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name            name of the operation in SameDiff
 * @param inputs          the inputs to average pooling 2d
 * @param pooling2DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable avgPooling2d(String name, SDVariable[] inputs, Pooling2DConfig pooling2DConfig) {
    return updateVariableNameAndReference(f().avgPooling2d(inputs, pooling2DConfig), name);
}
/**
 * Max pooling 2d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs          the inputs to max pooling 2d
 * @param pooling2DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable maxPooling2d(SDVariable[] inputs, Pooling2DConfig pooling2DConfig) {
    return maxPooling2d(null, inputs, pooling2DConfig);
}

/**
 * Max pooling 2d operation. The op is created through the function factory
 * and its result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name            name of the operation in SameDiff
 * @param inputs          the inputs to max pooling 2d
 * @param pooling2DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable maxPooling2d(String name, SDVariable[] inputs, Pooling2DConfig pooling2DConfig) {
    return updateVariableNameAndReference(f().maxPooling2d(inputs, pooling2DConfig), name);
}
/**
 * Average pooling 3d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs          the inputs to average pooling 3d
 * @param pooling3DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable avgPooling3d(SDVariable[] inputs, Pooling3DConfig pooling3DConfig) {
    return avgPooling3d(null, inputs, pooling3DConfig);
}

/**
 * Average pooling 3d operation. The op is created through the function factory
 * and its result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name            name of the operation in SameDiff
 * @param inputs          the inputs to average pooling 3d
 * @param pooling3DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable avgPooling3d(String name, SDVariable[] inputs, Pooling3DConfig pooling3DConfig) {
    return updateVariableNameAndReference(f().avgPooling3d(inputs, pooling3DConfig), name);
}
/**
 * Max pooling 3d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs          the inputs to max pooling 3d
 * @param pooling3DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable maxPooling3d(SDVariable[] inputs, Pooling3DConfig pooling3DConfig) {
    return maxPooling3d(null, inputs, pooling3DConfig);
}

/**
 * Max pooling 3d operation. The op is created through the function factory
 * and its result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name            name of the operation in SameDiff
 * @param inputs          the inputs to max pooling 3d
 * @param pooling3DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable maxPooling3d(String name, SDVariable[] inputs, Pooling3DConfig pooling3DConfig) {
    return updateVariableNameAndReference(f().maxPooling3d(inputs, pooling3DConfig), name);
}
/**
 * Conv1d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs       the inputs to conv1d
 * @param conv1DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable conv1d(SDVariable[] inputs, Conv1DConfig conv1DConfig) {
    return conv1d(null, inputs, conv1DConfig);
}

/**
 * Conv1d operation. The op is created through the function factory and its
 * result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name         name of the operation in SameDiff
 * @param inputs       the inputs to conv1d
 * @param conv1DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable conv1d(String name, SDVariable[] inputs, Conv1DConfig conv1DConfig) {
    return updateVariableNameAndReference(f().conv1d(inputs, conv1DConfig), name);
}
/**
 * Local response normalization operation; forwards to the named overload
 * with a {@code null} name.
 *
 * @param inputs    the inputs to lrn
 * @param lrnConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) {
    return localResponseNormalization(null, inputs, lrnConfig);
}

/**
 * Local response normalization operation. The op is created through the
 * function factory and its result is passed, with {@code name}, to
 * {@code updateVariableNameAndReference}.
 *
 * @param name      name of the operation in SameDiff
 * @param inputs    the inputs to lrn
 * @param lrnConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable localResponseNormalization(String name, SDVariable inputs,
                                             LocalResponseNormalizationConfig lrnConfig) {
    return updateVariableNameAndReference(f().localResponseNormalization(inputs, lrnConfig), name);
}
/**
 * Conv2d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs       the inputs to conv2d
 * @param conv2DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
    return conv2d(null, inputs, conv2DConfig);
}

/**
 * Conv2d operation. The op is created through the function factory and its
 * result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name         name of the operation in SameDiff
 * @param inputs       the inputs to conv2d
 * @param conv2DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) {
    return updateVariableNameAndReference(f().conv2d(inputs, conv2DConfig), name);
}
/**
 * Depth-wise Conv2d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs            the inputs to the depth-wise conv2d
 * @param depthConv2DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
    return depthWiseConv2d(null, inputs, depthConv2DConfig);
}

/**
 * Depth-wise Conv2d operation. The op is created through the function factory
 * and its result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name              name of the operation in SameDiff
 * @param inputs            the inputs to the depth-wise conv2d
 * @param depthConv2DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
    return updateVariableNameAndReference(f().depthWiseConv2d(inputs, depthConv2DConfig), name);
}
/**
 * Separable Conv2d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs       the inputs to sconv2d
 * @param conv2DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
    return sconv2d(null, inputs, conv2DConfig);
}

/**
 * Separable Conv2d operation. The op is created through the function factory
 * and its result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name         name of the operation in SameDiff
 * @param inputs       the inputs to sconv2d
 * @param conv2DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) {
    return updateVariableNameAndReference(f().sconv2d(inputs, conv2DConfig), name);
}
/**
 * Deconv2d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs         the inputs to deconv2d
 * @param deconv2DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
    return deconv2d(null, inputs, deconv2DConfig);
}

/**
 * Deconv2d operation. The op is created through the function factory and its
 * result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name           name of the operation in SameDiff
 * @param inputs         the inputs to deconv2d
 * @param deconv2DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
    return updateVariableNameAndReference(f().deconv2d(inputs, deconv2DConfig), name);
}
/**
 * Conv3d operation; forwards to the named overload with a {@code null} name.
 *
 * @param inputs       the inputs to conv3d
 * @param conv3DConfig the configuration
 * @return the value returned by the named overload
 */
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
    return conv3d(null, inputs, conv3DConfig);
}

/**
 * Conv3d operation. The op is created through the function factory and its
 * result is passed, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name         name of the operation in SameDiff
 * @param inputs       the inputs to conv3d
 * @param conv3DConfig the configuration
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable conv3d(String name, SDVariable[] inputs, Conv3DConfig conv3DConfig) {
    return updateVariableNameAndReference(f().conv3d(inputs, conv3DConfig), name);
}
/**
 * Batch norm operation; forwards every argument to the named overload with a
 * {@code null} name.
 *
 * @return the value returned by the named overload
 */
public SDVariable batchNorm(SDVariable input, SDVariable mean,
                            SDVariable variance, SDVariable gamma,
                            SDVariable beta,
                            boolean applyGamma, boolean applyBeta, double epsilon) {
    return batchNorm(null, input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon);
}

/**
 * Batch norm operation. All arguments except {@code name} are handed to the
 * function factory's {@code batchNorm}; the result is passed, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable batchNorm(String name, SDVariable input, SDVariable mean,
                            SDVariable variance, SDVariable gamma,
                            SDVariable beta,
                            boolean applyGamma, boolean applyBeta, double epsilon) {
    return updateVariableNameAndReference(
            f().batchNorm(input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon), name);
}
/**
 * Registers a variable for a single value: the value is wrapped with
 * {@code Nd4j.scalar(value)} and handed to {@code var} together with the name.
 *
 * @param name  the name of the variable
 * @param value the value to wrap
 * @return the variable returned by {@code var}
 */
public SDVariable scalar(String name, double value) {
return var(name, Nd4j.scalar(value));
}
/**
 * Overload of {@code gte} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable gte(SDVariable iX, double iy) {
return gte(null, iX, iy);
}
/**
 * Overload of {@code lte} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable lte(SDVariable iX, double iy) {
return lte(null, iX, iy);
}
/**
 * Overload of {@code gt} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable gt(SDVariable iX, double iy) {
return gt(null, iX, iy);
}
/**
 * Overload of {@code lt} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable lt(SDVariable iX, double iy) {
return lt(null, iX, iy);
}
/**
 * Overload of {@code neq} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable neq(SDVariable iX, double iy) {
return neq(null, iX, iy);
}
/**
 * Overload of {@code eq} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable eq(SDVariable iX, double iy) {
return eq(null, iX, iy);
}
/**
 * Overload of {@code gte} for two variables that takes no variable name;
 * forwards to the three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable gte(SDVariable iX, SDVariable iy) {
return gte(null, iX, iy);
}
/**
 * Overload of {@code lte} for two variables that takes no variable name;
 * forwards to the three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable lte(SDVariable iX, SDVariable iy) {
return lte(null, iX, iy);
}
/**
 * Overload of {@code gt} for two variables that takes no variable name;
 * forwards to the three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable gt(SDVariable iX, SDVariable iy) {
return gt(null, iX, iy);
}
/**
 * Overload of {@code lt} for two variables that takes no variable name;
 * forwards to the three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable lt(SDVariable iX, SDVariable iy) {
return lt(null, iX, iy);
}
/**
 * Overload of {@code neq} for two variables that takes no variable name;
 * forwards to the three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable neq(SDVariable iX, SDVariable iy) {
return neq(null, iX, iy);
}
/**
 * Overload of {@code eq} for two variables that takes no variable name;
 * forwards to the three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable eq(SDVariable iX, SDVariable iy) {
return eq(null, iX, iy);
}
/**
 * Overload of {@code or} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @param iy forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable or(SDVariable iX, SDVariable iy) {
return or(null, iX, iy);
}
/**
 * Overload of {@code and} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param iX first operand, forwarded unchanged
 * @param iY second operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable and(SDVariable iX, SDVariable iY) {
    return and(null, iX, iY);
}

/**
 * Creates the op via {@code f().and(ix, iy)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param ix   first operand
 * @param iy   second operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable and(String name, SDVariable ix, SDVariable iy) {
    return updateVariableNameAndReference(f().and(ix, iy), name);
}
/**
 * Overload of {@code xor} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param ix first operand, forwarded unchanged
 * @param iy second operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable xor(SDVariable ix, SDVariable iy) {
    return xor(null, ix, iy);
}

/**
 * Creates the op via {@code f().xor(ix, iy)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param ix   first operand
 * @param iy   second operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable xor(String name, SDVariable ix, SDVariable iy) {
    return updateVariableNameAndReference(f().xor(ix, iy), name);
}
/**
 * Overload of {@code abs} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param ix operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable abs(SDVariable ix) {
    return abs(null, ix);
}

/**
 * Creates the op via {@code f().abs(ix)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param ix   operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable abs(String name, SDVariable ix) {
    return updateVariableNameAndReference(f().abs(ix), name);
}
/**
 * Overload of {@code neg} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable neg(SDVariable iX) {
return neg(null, iX);
}
/**
 * Overload of {@code cos} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable cos(SDVariable iX) {
return cos(null, iX);
}
/**
 * Overload of {@code sin} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable sin(SDVariable iX) {
return sin(null, iX);
}
/**
 * Overload of {@code tan} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable tan(SDVariable iX) {
return tan(null, iX);
}
/**
 * Overload of {@code invertPermutation} without a name; forwards to the
 * named overload with a {@code null} name.
 *
 * @param input operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable invertPermutation(SDVariable input) {
    return invertPermutation(null, input);
}

/**
 * Creates the op via {@code f().invertPermutation(input, false)} and passes
 * its result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name  name handed to {@code updateVariableNameAndReference}
 * @param input operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable invertPermutation(String name, SDVariable input) {
    return updateVariableNameAndReference(f().invertPermutation(input, false), name);
}
/**
 * Overload of {@code acos} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable acos(SDVariable iX) {
return acos(null, iX);
}
/**
 * Overload of {@code asin} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable asin(SDVariable iX) {
return asin(null, iX);
}
/**
 * Overload of {@code atan} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable atan(SDVariable iX) {
return atan(null, iX);
}
/**
 * Overload of {@code atan2} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param y first operand, forwarded unchanged
 * @param x second operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable atan2(SDVariable y, SDVariable x) {
    return atan2(null, y, x);
}

/**
 * Creates the op via {@code f().atan2(y, x)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param y    first operand
 * @param x    second operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable atan2(String name, SDVariable y, SDVariable x) {
    return updateVariableNameAndReference(f().atan2(y, x), name);
}
/**
 * Overload of {@code cosh} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable cosh(SDVariable iX) {
return cosh(null, iX);
}
/**
 * Overload of {@code sinh} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable sinh(SDVariable iX) {
return sinh(null, iX);
}
/**
 * Overload of {@code tanh} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable tanh(SDVariable iX) {
return tanh(null, iX);
}
/**
 * Overload of {@code acosh} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable acosh(SDVariable iX) {
return acosh(null, iX);
}
/**
 * Overload of {@code asinh} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable asinh(SDVariable iX) {
return asinh(null, iX);
}
/**
 * Overload of {@code atanh} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable atanh(SDVariable iX) {
return atanh(null, iX);
}
/**
 * Overload of {@code exp} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable exp(SDVariable iX) {
return exp(null, iX);
}
/**
 * Overload of {@code rsqrt} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable rsqrt(SDVariable iX) {
return rsqrt(null, iX);
}
/**
 * Overload of {@code expm1} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable expm1(SDVariable iX) {
return expm1(null, iX);
}
/**
 * Overload of {@code log1p} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable log1p(SDVariable iX) {
return log1p(null, iX);
}
/**
 * Overload of {@code isInfinite} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable isInfinite(SDVariable iX) {
return isInfinite(null, iX);
}
/**
 * Overload of {@code isNaN} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable isNaN(SDVariable iX) {
return isNaN(null, iX);
}
/**
 * Overload of {@code round} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable round(SDVariable iX) {
return round(null, iX);
}
/**
 * Overload of {@code isFinite} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable isFinite(SDVariable iX) {
return isFinite(null, iX);
}
/**
 * Overload of {@code log} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable log(SDVariable iX) {
return log(null, iX);
}
/**
 * Overload of {@code cube} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable cube(SDVariable iX) {
return cube(null, iX);
}
/**
 * Overload of {@code pow} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX    forwarded unchanged as the second argument
 * @param value forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable pow(SDVariable iX, double value) {
return pow(null, iX, value);
}
/**
 * Overload of {@code sqrt} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable sqrt(SDVariable iX) {
return sqrt(null, iX);
}
/**
 * Overload of {@code square} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable square(SDVariable iX) {
return square(null, iX);
}
/**
 * Overload of {@code floor} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable floor(SDVariable iX) {
return floor(null, iX);
}
/**
 * Overload of {@code ceil} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param x operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable ceil(SDVariable x) {
    return ceil(null, x);
}

/**
 * Creates the op via {@code f().ceil(x)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param x    operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable ceil(String name, SDVariable x) {
    return updateVariableNameAndReference(f().ceil(x), name);
}
/**
 * Overload of {@code clipByValue} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param x            operand, forwarded unchanged
 * @param clipValueMin forwarded unchanged
 * @param clipValueMax forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) {
    return clipByValue(null, x, clipValueMin, clipValueMax);
}

/**
 * Creates the op via {@code f().clipByValue(x, clipValueMin, clipValueMax)}
 * and passes its result, with {@code name}, to
 * {@code updateVariableNameAndReference}.
 *
 * @param name         name handed to {@code updateVariableNameAndReference}
 * @param x            operand
 * @param clipValueMin handed to the function factory
 * @param clipValueMax handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, double clipValueMax) {
    return updateVariableNameAndReference(f().clipByValue(x, clipValueMin, clipValueMax), name);
}
/**
 * Overload of {@code clipByNorm} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param x         operand, forwarded unchanged
 * @param clipValue forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable clipByNorm(SDVariable x, double clipValue) {
    return clipByNorm(null, x, clipValue);
}

/**
 * Creates the op via {@code f().clipByNorm(x, clipValue)} and passes its
 * result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name      name handed to {@code updateVariableNameAndReference}
 * @param x         operand
 * @param clipValue handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable clipByNorm(String name, SDVariable x, double clipValue) {
    return updateVariableNameAndReference(f().clipByNorm(x, clipValue), name);
}
/**
 * Overload of {@code relu} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX     forwarded unchanged as the second argument
 * @param cutoff forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable relu(SDVariable iX, double cutoff) {
return relu(null, iX, cutoff);
}
/**
 * Overload of {@code relu6} that takes no variable name; forwards to the
 * three-argument overload with {@code null} as its first argument.
 *
 * @param iX     forwarded unchanged as the second argument
 * @param cutoff forwarded unchanged as the third argument
 * @return the value returned by the three-argument overload
 */
public SDVariable relu6(SDVariable iX, double cutoff) {
return relu6(null, iX, cutoff);
}
/**
 * Overload of {@code softmax} that takes no variable name; forwards to the
 * two-argument overload with {@code null} as its first argument.
 *
 * @param iX forwarded unchanged as the second argument
 * @return the value returned by the two-argument overload
 */
public SDVariable softmax(SDVariable iX) {
return softmax(null, iX);
}
/**
 * Overload of {@code logSoftmax} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param iX operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable logSoftmax(SDVariable iX) {
    return logSoftmax(null, iX);
}

/**
 * Creates the op via {@code f().logSoftmax(iX)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param iX   operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable logSoftmax(String name, SDVariable iX) {
    return updateVariableNameAndReference(f().logSoftmax(iX), name);
}
/**
 * Overload of {@code selu} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param iX operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable selu(SDVariable iX) {
    return selu(null, iX);
}

/**
 * Creates the op via {@code f().selu(iX)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param iX   operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable selu(String name, SDVariable iX) {
    return updateVariableNameAndReference(f().selu(iX), name);
}
/**
 * Varargs overload of {@code mergeAdd} without a name; forwards the argument
 * array to the named overload with a {@code null} name.
 *
 * @param iX operands, forwarded unchanged as an array
 * @return the value returned by the named overload
 */
public SDVariable mergeAdd(SDVariable... iX) {
    return mergeAdd(null, iX);
}

/**
 * Creates the op via {@code f().mergeadd(iX)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param iX   operands
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable mergeAdd(String name, SDVariable[] iX) {
    return updateVariableNameAndReference(f().mergeadd(iX), name);
}
/**
 * Overload of {@code batchToSpace} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param iX     operand, forwarded unchanged
 * @param blocks forwarded unchanged
 * @param crops  forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable batchToSpace(SDVariable iX, int[] blocks, int[][] crops) {
    return batchToSpace(null, iX, blocks, crops);
}

/**
 * Creates the op via {@code f().batchToSpace(iX, blocks, crops)} and passes
 * its result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name   name handed to {@code updateVariableNameAndReference}
 * @param iX     operand
 * @param blocks handed to the function factory
 * @param crops  handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable batchToSpace(String name, SDVariable iX, int[] blocks, int[][] crops) {
    return updateVariableNameAndReference(f().batchToSpace(iX, blocks, crops), name);
}
/**
 * Overload of {@code depthToSpace} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param iX         operand, forwarded unchanged
 * @param blockSize  forwarded unchanged
 * @param dataFormat forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable depthToSpace(SDVariable iX, int blockSize, String dataFormat) {
    return depthToSpace(null, iX, blockSize, dataFormat);
}

/**
 * Creates the op via {@code f().depthToSpace(iX, blockSize, dataFormat)} and
 * passes its result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed to {@code updateVariableNameAndReference}
 * @param iX         operand
 * @param blockSize  handed to the function factory
 * @param dataFormat handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable depthToSpace(String name, SDVariable iX, int blockSize, String dataFormat) {
    return updateVariableNameAndReference(f().depthToSpace(iX, blockSize, dataFormat), name);
}
/**
 * Overload of {@code spaceToBatch} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param iX      operand, forwarded unchanged
 * @param blocks  forwarded unchanged
 * @param padding forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable spaceToBatch(SDVariable iX, int[] blocks, int[][] padding) {
    return spaceToBatch(null, iX, blocks, padding);
}

/**
 * Creates the op via {@code f().spaceToBatch(iX, blocks, padding)} and passes
 * its result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name    name handed to {@code updateVariableNameAndReference}
 * @param iX      operand
 * @param blocks  handed to the function factory
 * @param padding handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable spaceToBatch(String name, SDVariable iX, int[] blocks, int[][] padding) {
    return updateVariableNameAndReference(f().spaceToBatch(iX, blocks, padding), name);
}
/**
 * Overload of {@code spaceToDepth} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param iX         operand, forwarded unchanged
 * @param blockSize  forwarded unchanged
 * @param dataFormat forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable spaceToDepth(SDVariable iX, int blockSize, String dataFormat) {
    return spaceToDepth(null, iX, blockSize, dataFormat);
}

/**
 * Creates the op via {@code f().spaceToDepth(iX, blockSize, dataFormat)} and
 * passes its result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed to {@code updateVariableNameAndReference}
 * @param iX         operand
 * @param blockSize  handed to the function factory
 * @param dataFormat handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable spaceToDepth(String name, SDVariable iX, int blockSize, String dataFormat) {
    return updateVariableNameAndReference(f().spaceToDepth(iX, blockSize, dataFormat), name);
}
/**
 * Overload of {@code dynamicPartition} without names; forwards to the named
 * overload with {@code null} in place of the name array.
 *
 * @param iX            operand, forwarded unchanged
 * @param partitions    forwarded unchanged
 * @param numPartitions forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable[] dynamicPartition(SDVariable iX, SDVariable partitions, int numPartitions) {
    return dynamicPartition(null, iX, partitions, numPartitions);
}

/**
 * Creates the op via {@code f().dynamicPartition(iX, partitions, numPartitions)}
 * and passes the resulting variables, with {@code name}, to
 * {@code updateVariableNamesAndReferences}.
 *
 * @param name          names handed to {@code updateVariableNamesAndReferences}
 * @param iX            operand
 * @param partitions    handed to the function factory
 * @param numPartitions handed to the function factory
 * @return the variables returned by {@code updateVariableNamesAndReferences}
 */
public SDVariable[] dynamicPartition(String[] name, SDVariable iX, SDVariable partitions, int numPartitions) {
    return updateVariableNamesAndReferences(f().dynamicPartition(iX, partitions, numPartitions), name);
}
/**
 * Overload of {@code dynamicStitch} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param indices forwarded unchanged
 * @param iX      forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] iX) {
    return dynamicStitch(null, indices, iX);
}

/**
 * Creates the op via {@code f().dynamicStitch(indices, iX)} and passes its
 * result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name    name handed to {@code updateVariableNameAndReference}
 * @param indices handed to the function factory
 * @param iX      handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable dynamicStitch(String name, SDVariable[] indices, SDVariable[] iX) {
    return updateVariableNameAndReference(f().dynamicStitch(indices, iX), name);
}
/**
 * Overload of {@code dilation2D} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param df         forwarded unchanged
 * @param weights    forwarded unchanged
 * @param strides    forwarded unchanged
 * @param rates      forwarded unchanged
 * @param isSameMode forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides,
                             int[] rates, boolean isSameMode) {
    return dilation2D(null, df, weights, strides, rates, isSameMode);
}

/**
 * Creates the op via {@code f().dilation2D(df, weights, strides, rates, isSameMode)}
 * and passes its result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed to {@code updateVariableNameAndReference}
 * @param df         handed to the function factory
 * @param weights    handed to the function factory
 * @param strides    handed to the function factory
 * @param rates      handed to the function factory
 * @param isSameMode handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides,
                             int[] rates, boolean isSameMode) {
    return updateVariableNameAndReference(f().dilation2D(df, weights, strides, rates, isSameMode), name);
}
/**
 * Overload of {@code shape} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param df operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable shape(SDVariable df) {
    return shape(null, df);
}

/**
 * Creates the op via {@code f().shape(df)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param df   operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable shape(String name, SDVariable df) {
    return updateVariableNameAndReference(f().shape(df), name);
}
/**
 * Overload of {@code cross} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param a first operand, forwarded unchanged
 * @param b second operand, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable cross(SDVariable a, SDVariable b) {
    return cross(null, a, b);
}

/**
 * Creates the op via {@code f().cross(a, b)} and passes its result, with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param a    first operand
 * @param b    second operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable cross(String name, SDVariable a, SDVariable b) {
    return updateVariableNameAndReference(f().cross(a, b), name);
}
/**
 * Overload of {@code gather} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param df        operand, forwarded unchanged
 * @param axis      forwarded unchanged
 * @param broadcast forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable gather(SDVariable df, int axis, int[] broadcast) {
    return gather(null, df, axis, broadcast);
}

/**
 * Creates the op via {@code f().gather(df, axis, broadcast)} and passes its
 * result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name      name handed to {@code updateVariableNameAndReference}
 * @param df        operand
 * @param axis      handed to the function factory
 * @param broadcast handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable gather(String name, SDVariable df, int axis, int[] broadcast) {
    return updateVariableNameAndReference(f().gather(df, axis, broadcast), name);
}
/**
 * Overload of {@code gatherNd} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param df      operand, forwarded unchanged
 * @param indices forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable gatherNd(SDVariable df, SDVariable indices) {
    return gatherNd(null, df, indices);
}

/**
 * Creates the op via {@code f().gatherNd(df, indices)} and passes its result,
 * with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name    name handed to {@code updateVariableNameAndReference}
 * @param df      operand
 * @param indices handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable gatherNd(String name, SDVariable df, SDVariable indices) {
    return updateVariableNameAndReference(f().gatherNd(df, indices), name);
}
/**
 * Overload of {@code repeat} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param df   operand, forwarded unchanged
 * @param axis forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable repeat(SDVariable df, int axis) {
    return repeat(null, df, axis);
}

/**
 * Creates the op via {@code f().repeat(df, axis)} and passes its result,
 * with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed to {@code updateVariableNameAndReference}
 * @param df   operand
 * @param axis handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable repeat(String name, SDVariable df, int axis) {
    return updateVariableNameAndReference(f().repeat(df, axis), name);
}
/**
 * Overload of {@code stack} without a name; forwards to the named overload
 * with a {@code null} name.
 *
 * @param values operands, forwarded unchanged
 * @param axis   forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable stack(SDVariable[] values, int axis) {
    return stack(null, values, axis);
}

/**
 * Creates the op via {@code f().stack(values, axis)} and passes its result,
 * with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name   name handed to {@code updateVariableNameAndReference}
 * @param values operands
 * @param axis   handed to the function factory
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable stack(String name, SDVariable[] values, int axis) {
    return updateVariableNameAndReference(f().stack(values, axis), name);
}
/**
 * Overload of {@code parallel_stack} without a name; forwards to the named
 * overload with a {@code null} name.
 *
 * @param values operands, forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable parallel_stack(SDVariable[] values) {
    return parallel_stack(null, values);
}

/**
 * Creates the op via {@code f().parallel_stack(values)} and passes its
 * result, with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name   name handed to {@code updateVariableNameAndReference}
 * @param values operands
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable parallel_stack(String name, SDVariable[] values) {
    return updateVariableNameAndReference(f().parallel_stack(values), name);
}
/**
 * Overload of {@code unstack} without names; forwards to the named overload
 * with {@code null} in place of the name array.
 *
 * @param value operand, forwarded unchanged
 * @param axis  forwarded unchanged
 * @return the value returned by the named overload
 */
public SDVariable[] unstack(SDVariable value, int axis) {
    return unstack(null, value, axis);
}

/**
 * Creates the op via {@code f().unstack(value, axis)} and passes the
 * resulting variables, with {@code names}, to
 * {@code updateVariableNamesAndReferences}.
 *
 * @param names names handed to {@code updateVariableNamesAndReferences}
 * @param value operand
 * @param axis  handed to the function factory
 * @return the variables returned by {@code updateVariableNamesAndReferences}
 */
public SDVariable[] unstack(String[] names, SDVariable value, int axis) {
    return updateVariableNamesAndReferences(f().unstack(value, axis), names);
}
public SDVariable erf(SDVariable iX) {
return erf(null, iX);
}
/**
 * Calls {@code f().erf(iX)} and names the result via {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable erf(String name, SDVariable iX) {
    return updateVariableNameAndReference(f().erf(iX), name);
}
/**
 * Unnamed variant: calls the named {@code erfc} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable erfc(SDVariable iX) {
return erfc(null, iX);
}
/**
 * Calls {@code f().erfc(iX)} and names the result via {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable erfc(String name, SDVariable iX) {
    return updateVariableNameAndReference(f().erfc(iX), name);
}
/**
 * Unnamed variant: calls the named {@code diag} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable diag(SDVariable iX) {
return diag(null, iX);
}
/**
 * Calls {@code f().diag(iX)} and names the result via {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable diag(String name, SDVariable iX) {
    return updateVariableNameAndReference(f().diag(iX), name);
}
/**
 * Unnamed variant: calls the named {@code diagPart} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable diagPart(SDVariable iX) {
return diagPart(null, iX);
}
/**
 * Calls {@code f().diagPart(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable diagPart(String name, SDVariable iX) {
    return updateVariableNameAndReference(f().diagPart(iX), name);
}
/**
 * Calls the fully-specified {@code oneHot} overload with a {@code null} name and the defaults
 * {@code axis = -1}, {@code on = 1.00}, {@code off = 0.00}.
 *
 * @param indices indices variable
 * @param depth depth argument, forwarded unchanged
 * @return result of the fully-specified overload
 */
public SDVariable oneHot(SDVariable indices, int depth) {
return oneHot(null, indices, depth, -1, 1.00, 0.00);
}
/**
 * Unnamed variant: calls the named {@code oneHot} overload with a {@code null} name.
 *
 * @param indices indices variable
 * @param depth depth argument, forwarded unchanged
 * @param axis axis argument, forwarded unchanged
 * @param on "on" value, forwarded unchanged
 * @param off "off" value, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) {
return oneHot(null, indices, depth, axis, on, off);
}
/**
 * Calls the fully-specified {@code oneHot} overload with the defaults {@code axis = -1},
 * {@code on = 1.00}, {@code off = 0.00}.
 *
 * @param name name for the result variable
 * @param indices indices variable
 * @param depth depth argument, forwarded unchanged
 * @return result of the fully-specified overload
 */
public SDVariable oneHot(String name, SDVariable indices, int depth) {
return oneHot(name, indices, depth, -1, 1.00, 0.00);
}
/**
 * Calls {@code f().onehot(indices, depth, axis, on, off)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param indices indices variable
 * @param depth depth argument passed to the factory
 * @param axis axis argument passed to the factory
 * @param on "on" value passed to the factory
 * @param off "off" value passed to the factory
 * @return the resulting variable
 */
public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off) {
    return updateVariableNameAndReference(f().onehot(indices, depth, axis, on, off), name);
}
/**
 * Unnamed variant: calls the named {@code reciprocal} overload with a {@code null} name.
 *
 * @param a input variable
 * @return result of the named overload
 */
public SDVariable reciprocal(SDVariable a) {
return reciprocal(null,a);
}
/**
 * Calls {@code f().reciprocal(a)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param a input variable
 * @return the resulting variable
 */
public SDVariable reciprocal(String name, SDVariable a) {
    return updateVariableNameAndReference(f().reciprocal(a), name);
}
/**
 * Calls the named {@code gradientBackwardsMarker} overload with a name obtained from
 * {@code generateNewVarName(new GradientBackwardsMarker().opName(), 0)}.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable gradientBackwardsMarker(SDVariable iX) {
    String generatedName = generateNewVarName(new GradientBackwardsMarker().opName(), 0);
    return gradientBackwardsMarker(generatedName, iX);
}
/**
 * Unnamed variant: calls the named {@code hardTanh} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable hardTanh(SDVariable iX) {
return hardTanh(null, iX);
}
/**
 * Unnamed variant: calls the named {@code hardTanhDerivative} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable hardTanhDerivative(SDVariable iX) {
return hardTanhDerivative(null, iX);
}
/**
 * Unnamed variant: calls the named {@code sigmoid} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable sigmoid(SDVariable iX) {
return sigmoid(null, iX);
}
/**
 * Unnamed variant: calls the named {@code sigmoidDerivative} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param wrt second variable, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable sigmoidDerivative(SDVariable iX, SDVariable wrt) {
return sigmoidDerivative(null, iX, wrt);
}
/**
 * Unnamed variant: calls the named {@code logSigmoid} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable logSigmoid(SDVariable iX) {
return logSigmoid(null, iX);
}
/**
 * Calls {@code f().logSigmoid(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable logSigmoid(String name, SDVariable iX) {
    return updateVariableNameAndReference(f().logSigmoid(iX), name);
}
/**
 * Unnamed variant: calls the named {@code sign} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable sign(SDVariable iX) {
return sign(null, iX);
}
/**
 * Unnamed variant: calls the named {@code softsign} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable softsign(SDVariable iX) {
return softsign(null, iX);
}
/**
 * Unnamed variant: calls the named {@code softsignDerivative} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable softsignDerivative(SDVariable iX) {
return softsignDerivative(null, iX);
}
/**
 * Unnamed variant: calls the named {@code softplus} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable softplus(SDVariable iX) {
return softplus(null, iX);
}
/**
 * Unnamed variant: calls the named {@code swish} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable swish(SDVariable iX) {
return swish(null, iX);
}
/**
 * Calls {@code f().swish(iX)} and names the result via {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable swish(String name, SDVariable iX) {
    return updateVariableNameAndReference(f().swish(iX), name);
}
/**
 * Unnamed variant: calls the named {@code elu} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable elu(SDVariable iX) {
return elu(null, iX);
}
/**
 * Unnamed variant: calls the named {@code eluDerivative} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable eluDerivative(SDVariable iX) {
return eluDerivative(null, iX);
}
/**
 * Unnamed variant: calls the named {@code leakyRelu} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param cutoff forwarded as the named overload's third argument (named {@code alpha} there)
 * @return result of the named overload
 */
public SDVariable leakyRelu(SDVariable iX, double cutoff) {
return leakyRelu(null, iX, cutoff);
}
/**
 * Unnamed variant: calls the named {@code mean} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable mean(SDVariable iX) {
return mean(null, iX);
}
/**
 * Unnamed variant: calls the named {@code mean} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param dimension dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable mean(SDVariable iX, int... dimension) {
return mean(null, iX, dimension);
}
/**
 * Unnamed variant: calls the named {@code standardDeviation} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param biasCorrected bias-correction flag, forwarded unchanged
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable standardDeviation(SDVariable iX,
boolean biasCorrected,
int... dimensions) {
return standardDeviation(null, iX, biasCorrected, dimensions);
}
/**
 * Unnamed variant: calls the named {@code variance} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param biasCorrected bias-correction flag, forwarded unchanged
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable variance(SDVariable iX,
boolean biasCorrected,
int... dimensions) {
return variance(null, iX, biasCorrected, dimensions);
}
/**
 * Unnamed variant: calls the named {@code sum} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable sum(SDVariable iX,
int... dimensions) {
return sum(null, iX, dimensions);
}
/**
 * Unnamed variant: calls the named {@code prod} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable prod(SDVariable iX,
int... dimensions) {
return prod(null, iX, dimensions);
}
/**
 * Unnamed variant: calls the named {@code max(String, SDVariable, int...)} overload with a
 * {@code null} name.
 *
 * @param iX input variable
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable max(SDVariable iX, int... dimensions) {
return max(null, iX, dimensions);
}
/**
 * Unnamed two-variable variant: calls the named {@code max} overload with a {@code null} name.
 *
 * @param first first input
 * @param second second input
 * @return result of the named overload
 */
public SDVariable max(SDVariable first, SDVariable second) {
return max(null, first, second);
}
/**
 * Calls {@code f().max(first, second)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param first first input
 * @param second second input
 * @return the resulting variable
 */
public SDVariable max(String name, SDVariable first, SDVariable second) {
    return updateVariableNameAndReference(f().max(first, second), name);
}
/**
 * Unnamed variant: calls the named {@code countZero} overload with a {@code null} name.
 *
 * @param input input variable
 * @return result of the named overload
 */
public SDVariable countZero(SDVariable input) {
return countZero(null, input);
}
/**
 * Calls {@code f().countZero(input)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param input input variable
 * @return the resulting variable
 */
public SDVariable countZero(String name, SDVariable input) {
    return updateVariableNameAndReference(f().countZero(input), name);
}
/**
 * Unnamed variant: calls the named {@code zeroFraction} overload with a {@code null} name.
 *
 * @param input input variable
 * @return result of the named overload
 */
public SDVariable zeroFraction(SDVariable input) {
return zeroFraction(null, input);
}
/**
 * Calls {@code f().zeroFraction(input)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param input input variable
 * @return the resulting variable
 */
public SDVariable zeroFraction(String name, SDVariable input) {
    return updateVariableNameAndReference(f().zeroFraction(input), name);
}
/**
 * Unnamed variant: calls the named {@code countNonZero} overload with a {@code null} name.
 *
 * @param input input variable
 * @return result of the named overload
 */
public SDVariable countNonZero(SDVariable input) {
return countNonZero(null, input);
}
/**
 * Calls {@code f().countNonZero(input)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param input input variable
 * @return the resulting variable
 */
public SDVariable countNonZero(String name, SDVariable input) {
    return updateVariableNameAndReference(f().countNonZero(input), name);
}
/**
 * Unnamed variant: calls the named {@code min(String, SDVariable, int...)} overload with a
 * {@code null} name.
 *
 * @param iX input variable
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable min(SDVariable iX, int... dimensions) {
return min(null, iX, dimensions);
}
/**
 * Unnamed two-variable variant: calls the named {@code min} overload with a {@code null} name.
 *
 * @param first first input
 * @param second second input
 * @return result of the named overload
 */
public SDVariable min(SDVariable first, SDVariable second) {
return min(null, first, second);
}
/**
 * Calls {@code f().min(first, second)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param first first input
 * @param second second input
 * @return the resulting variable
 */
public SDVariable min(String name, SDVariable first, SDVariable second) {
    return updateVariableNameAndReference(f().min(first, second), name);
}
/**
 * Unnamed variant: calls the named {@code argmax} overload with a {@code null} name.
 *
 * @param in input variable
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable argmax(SDVariable in, int... dimensions) {
return argmax(null, in, dimensions);
}
/**
 * Calls {@code f().argmax(in, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param in input variable
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable argmax(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().argmax(in, dimensions), name);
}
/**
 * Unnamed variant: calls the named {@code argmin} overload with a {@code null} name.
 *
 * @param in input variable
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable argmin(SDVariable in, int... dimensions) {
return argmin(null, in, dimensions);
}
/**
 * Calls {@code f().argmin(in, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param in input variable
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable argmin(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().argmin(in, dimensions), name);
}
/**
 * Unnamed variant: calls the named {@code cumsum} overload with a {@code null} name.
 *
 * @param in input variable
 * @param exclusive flag, forwarded unchanged
 * @param reverse flag, forwarded unchanged
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... dimensions) {
return cumsum(null, in, exclusive, reverse, dimensions);
}
/**
 * Calls {@code f().cumsum(in, exclusive, reverse, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param in input variable
 * @param exclusive flag passed to the factory
 * @param reverse flag passed to the factory
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, int... dimensions) {
    return updateVariableNameAndReference(f().cumsum(in, exclusive, reverse, dimensions), name);
}
/**
 * Unnamed variant: calls the named {@code cumprod} overload with a {@code null} name.
 *
 * @param in input variable
 * @param exclusive flag, forwarded unchanged
 * @param reverse flag, forwarded unchanged
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... dimensions) {
return cumprod(null, in, exclusive, reverse, dimensions);
}
/**
 * Calls {@code f().cumprod(in, exclusive, reverse, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param in input variable
 * @param exclusive flag passed to the factory
 * @param reverse flag passed to the factory
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, int... dimensions) {
    return updateVariableNameAndReference(f().cumprod(in, exclusive, reverse, dimensions), name);
}
/**
 * Unnamed variant: calls the named {@code biasAdd} overload with a {@code null} name.
 *
 * @param input input variable
 * @param bias bias variable
 * @return result of the named overload
 */
public SDVariable biasAdd(SDVariable input, SDVariable bias) {
return biasAdd(null, input, bias);
}
/**
 * Calls {@code f().biasAdd(input, bias)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param input input variable
 * @param bias bias variable
 * @return the resulting variable
 */
public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) {
    return updateVariableNameAndReference(f().biasAdd(input, bias), name);
}
/**
 * Unnamed variant: calls the named {@code reshape} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param shape shape values, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable reshape(SDVariable iX,
int... shape) {
return reshape(null, iX, shape);
}
/**
 * Unnamed variant: calls the named {@code reverse} overload with a {@code null} name.
 *
 * @param x input variable
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable reverse(SDVariable x, int... dimensions){
return reverse(null, x, dimensions);
}
/**
 * Calls {@code f().reverse(x, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param x input variable
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable reverse(String name, SDVariable x, int... dimensions) {
    return updateVariableNameAndReference(f().reverse(x, dimensions), name);
}
/**
 * Calls {@code f().reverseSequence(x, seq_lengths, seqDim, batchDim)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param x input variable
 * @param seq_lengths sequence-lengths variable
 * @param seqDim sequence dimension passed to the factory
 * @param batchDim batch dimension passed to the factory
 * @return the resulting variable
 */
public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) {
    return updateVariableNameAndReference(f().reverseSequence(x, seq_lengths, seqDim, batchDim), name);
}
/**
 * Calls {@code f().reverseSequence(x, seq_lengths)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param x input variable
 * @param seq_lengths sequence-lengths variable
 * @return the resulting variable
 */
public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths) {
    return updateVariableNameAndReference(f().reverseSequence(x, seq_lengths), name);
}
/**
 * Unnamed variant: calls the named five-argument {@code reverseSequence} overload with a
 * {@code null} name.
 *
 * @param x input variable
 * @param seq_lengths sequence-lengths variable
 * @param seqDim sequence dimension, forwarded unchanged
 * @param batchDim batch dimension, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) {
return reverseSequence(null, x, seq_lengths, seqDim, batchDim);
}
/**
 * Unnamed variant: calls the named three-argument {@code reverseSequence} overload with a
 * {@code null} name.
 *
 * @param x input variable
 * @param seq_lengths sequence-lengths variable
 * @return result of the named overload
 */
public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths){
return reverseSequence(null, x, seq_lengths);
}
/**
 * Calls {@code f().sequenceMask(lengths, maxLen)} with a variable {@code maxLen} and names the
 * result via {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param lengths lengths variable
 * @param maxLen maximum-length variable
 * @return the resulting variable
 */
public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen) {
    return updateVariableNameAndReference(f().sequenceMask(lengths, maxLen), name);
}
/**
 * Unnamed variant: calls the named {@code sequenceMask} overload (variable {@code maxLen}) with
 * a {@code null} name.
 *
 * @param lengths lengths variable
 * @param maxLen maximum-length variable
 * @return result of the named overload
 */
public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen) {
return sequenceMask(null, lengths, maxLen);
}
/**
 * Calls {@code f().sequenceMask(lengths, maxLen)} with an {@code int} {@code maxLen} and names
 * the result via {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param lengths lengths variable
 * @param maxLen maximum length passed to the factory
 * @return the resulting variable
 */
public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen) {
    return updateVariableNameAndReference(f().sequenceMask(lengths, maxLen), name);
}
/**
 * Unnamed variant: calls the named {@code sequenceMask} overload ({@code int} {@code maxLen})
 * with a {@code null} name.
 *
 * @param lengths lengths variable
 * @param maxLen maximum length, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable sequenceMask(SDVariable lengths, int maxLen) {
return sequenceMask(null, lengths, maxLen);
}
/**
 * Calls {@code f().sequenceMask(lengths)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param lengths lengths variable
 * @return the resulting variable
 */
public SDVariable sequenceMask(String name, SDVariable lengths) {
    return updateVariableNameAndReference(f().sequenceMask(lengths), name);
}
/**
 * Calls {@code f().sequenceMask(lengths)} and passes the result to
 * {@code updateVariableNameAndReference} with a {@code null} name.
 *
 * <p>Unlike its sibling unnamed overloads this one cannot delegate via
 * {@code sequenceMask(null, lengths)}: that call would be ambiguous between the
 * {@code (String, SDVariable)} and {@code (SDVariable, SDVariable)} overloads.
 *
 * @param lengths lengths variable
 * @return the resulting variable
 */
public SDVariable sequenceMask(SDVariable lengths) {
    return updateVariableNameAndReference(f().sequenceMask(lengths), null);
}
/**
 * Unnamed variant: calls the named {@code assign} overload with a {@code null} name.
 *
 * @param x first input
 * @param y second input
 * @return result of the named overload
 */
public SDVariable assign(SDVariable x, SDVariable y){
return assign(null, x, y);
}
/**
 * Calls {@code f().assign(x, y)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param x first input
 * @param y second input
 * @return the resulting variable
 */
public SDVariable assign(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(f().assign(x, y), name);
}
/**
 * Unnamed variant: calls the named {@code transpose} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable transpose(SDVariable iX) {
return transpose(null, iX);
}
/**
 * Unnamed variant: calls the named {@code permute} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable permute(SDVariable iX, int... dimensions) {
return permute(null, iX, dimensions);
}
/**
 * Unnamed variant: calls the named {@code rollAxis} overload with a {@code null} name.
 *
 * @param x input variable
 * @param axis axis argument, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable rollAxis(SDVariable x, int axis) {
return rollAxis(null, x, axis);
}
/**
 * Unnamed variant: calls the named {@code concat} overload with a {@code null} name.
 *
 * @param dimension dimension argument, forwarded unchanged
 * @param inputs input variables
 * @return result of the named overload
 */
public SDVariable concat(int dimension, SDVariable... inputs) {
return concat(null, dimension, inputs);
}
/**
 * Unnamed variant: calls the named {@code moments} overload with {@code null} in place of the
 * names array.
 *
 * @param input input variable
 * @param axes axes, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable[] moments(SDVariable input, int... axes) {
return moments(null, input, axes);
}
/**
 * Calls {@code f().moments(input, axes)} and names the results via
 * {@code updateVariableNamesAndReferences}.
 *
 * @param name names for the result variables
 * @param input input variable
 * @param axes axes passed to the factory
 * @return the resulting variables
 */
public SDVariable[] moments(String[] name, SDVariable input, int... axes) {
    return updateVariableNamesAndReferences(f().moments(input, axes), name);
}
/**
 * Unnamed variant: calls the named {@code normalizeMoments} overload with {@code null} in place
 * of the names array.
 *
 * @param counts counts variable
 * @param means means variable
 * @param variances variances variable
 * @param shift shift value, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) {
return normalizeMoments(null, counts, means, variances, shift);
}
/**
 * Calls {@code f().normalizeMoments(counts, means, variances, shift)} and names the results via
 * {@code updateVariableNamesAndReferences}.
 *
 * @param name names for the result variables
 * @param counts counts variable
 * @param means means variable
 * @param variances variances variable
 * @param shift shift value passed to the factory
 * @return the resulting variables
 */
public SDVariable[] normalizeMoments(String[] name, SDVariable counts, SDVariable means,
                                     SDVariable variances, double shift) {
    return updateVariableNamesAndReferences(f().normalizeMoments(counts, means, variances, shift), name);
}
/**
 * Unnamed variant: calls the named {@code tile} overload with a {@code null} name.
 *
 * @param iX input variable
 * @param repeat repeat array, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable tile(SDVariable iX, int[] repeat) { return tile(null, iX, repeat);}
/**
 * Unnamed variant: calls the named {@code fill} overload with a {@code null} name.
 *
 * @param shape shape variable
 * @param value fill value, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable fill(SDVariable shape, double value) {
return fill(null, shape, value);
}
/**
 * Unnamed variant: calls the named {@code dropout} overload with a {@code null} name.
 *
 * @param input input variable
 * @param p dropout parameter, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable dropout(SDVariable input, double p) {
return dropout(null, input, p);
}
/**
 * Calls {@code f().dropout(input, p)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param input input variable
 * @param p dropout parameter passed to the factory
 * @return the resulting variable
 */
public SDVariable dropout(String name, SDVariable input, double p) {
    return updateVariableNameAndReference(f().dropout(input, p), name);
}
/**
 * Unnamed variant: calls the named {@code xwPlusB} overload with a {@code null} name.
 *
 * @param input input variable
 * @param weights weights variable
 * @param bias bias variable
 * @return result of the named overload
 */
public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) {
return xwPlusB(null, input, weights, bias);
}
/**
 * Calls {@code f().xwPlusB(input, weights, bias)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param input input variable
 * @param weights weights variable
 * @param bias bias variable
 * @return the resulting variable
 */
public SDVariable xwPlusB(String name, SDVariable input, SDVariable weights, SDVariable bias) {
    return updateVariableNameAndReference(f().xwPlusB(input, weights, bias), name);
}
/**
 * Unnamed variant: calls the named {@code reluLayer} overload with a {@code null} name.
 *
 * @param input input variable
 * @param weights weights variable
 * @param bias bias variable
 * @return result of the named overload
 */
public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) {
return reluLayer(null, input, weights, bias);
}
/**
 * Calls {@code f().reluLayer(input, weights, bias)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param input input variable
 * @param weights weights variable
 * @param bias bias variable
 * @return the resulting variable
 */
public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) {
    return updateVariableNameAndReference(f().reluLayer(input, weights, bias), name);
}
/**
 * Unnamed variant: calls the named {@code mmul} overload with a {@code null} name.
 *
 * @param x first input
 * @param y second input
 * @param transpose transpose configuration, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable mmul(SDVariable x, SDVariable y, MMulTranspose transpose) {
return mmul(null, x, y, transpose);
}
/**
 * Unnamed variant: calls the named three-argument {@code mmul} overload with a {@code null} name.
 *
 * @param x first input
 * @param y second input
 * @return result of the named overload
 */
public SDVariable mmul(SDVariable x, SDVariable y) {
return mmul(null, x, y);
}
/**
 * Unnamed variant: calls the named {@code tensorMmul} overload with a {@code null} name.
 *
 * @param x first input
 * @param y second input
 * @param dimensions two-dimensional int array, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable tensorMmul(SDVariable x,
SDVariable y,
int[][] dimensions) {
return tensorMmul(null, x, y, dimensions);
}
/**
 * Calls the named {@code cosineSimilarity} overload with a name obtained from
 * {@code generateNewVarName(CosineSimilarity.OP_NAME, 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable cosineSimilarity(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(CosineSimilarity.OP_NAME, 0);
    return cosineSimilarity(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code euclideanDistance} overload with a name obtained from
 * {@code generateNewVarName(EuclideanDistance.OP_NAME, 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable euclideanDistance(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(EuclideanDistance.OP_NAME, 0);
    return euclideanDistance(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code manhattanDistance} overload with a name obtained from
 * {@code generateNewVarName(ManhattanDistance.OP_NAME, 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable manhattanDistance(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(ManhattanDistance.OP_NAME, 0);
    return manhattanDistance(generatedName, iX, i_y, dimensions);
}
/**
 * Unnamed variant: calls the named {@code cosineDistance} overload with a {@code null} name.
 *
 * @param ix first input
 * @param iy second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable cosineDistance(SDVariable ix, SDVariable iy, int... dimensions) {
return cosineDistance(null, ix, iy, dimensions);
}
/**
 * Calls {@code functionFactory.cosineDistance(ix, iy, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param ix first input
 * @param iy second input
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable cosineDistance(String name, SDVariable ix, SDVariable iy, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.cosineDistance(ix, iy, dimensions), name);
}
/**
 * Unnamed variant: calls the named {@code hammingDistance} overload with a {@code null} name.
 *
 * @param ix first input
 * @param iy second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable hammingDistance(SDVariable ix, SDVariable iy, int... dimensions) {
return hammingDistance(null, ix, iy, dimensions);
}
/**
 * Calls {@code functionFactory.hammingDistance(ix, iy, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param ix first input
 * @param iy second input
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable hammingDistance(String name, SDVariable ix, SDVariable iy, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.hammingDistance(ix, iy, dimensions), name);
}
/**
 * Unnamed variant: calls the named {@code jaccardDistance} overload with a {@code null} name.
 *
 * @param ix first input
 * @param iy second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable jaccardDistance(SDVariable ix, SDVariable iy, int... dimensions) {
return jaccardDistance(null, ix, iy, dimensions);
}
/**
 * Calls {@code functionFactory.jaccardDistance(ix, iy, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param ix first input
 * @param iy second input
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable jaccardDistance(String name, SDVariable ix, SDVariable iy, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.jaccardDistance(ix, iy, dimensions), name);
}
/**
 * Calls the named {@code lossBinaryXENT} overload with a name obtained from
 * {@code generateNewVarName(new LossBinaryXENT().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossBinaryXENT(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossBinaryXENT().opName(), 0);
    return lossBinaryXENT(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossCosineSimilarity} overload with a name obtained from
 * {@code generateNewVarName(new LossCosineProximity().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossCosineSimilarity(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossCosineProximity().opName(), 0);
    return lossCosineSimilarity(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossHinge} overload with a name obtained from
 * {@code generateNewVarName(new LossHinge().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossHinge(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossHinge().opName(), 0);
    return lossHinge(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossKLD} overload with a name obtained from
 * {@code generateNewVarName(new LossKLD().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossKLD(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossKLD().opName(), 0);
    return lossKLD(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossL1} overload with a name obtained from
 * {@code generateNewVarName(new LossL1().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossL1(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossL1().opName(), 0);
    return lossL1(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossL2} overload with a name obtained from
 * {@code generateNewVarName(new LossL2().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossL2(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossL2().opName(), 0);
    return lossL2(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossMAE} overload with a name obtained from
 * {@code generateNewVarName(new LossMAE().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossMAE(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossMAE().opName(), 0);
    return lossMAE(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossMSE} overload with a name obtained from
 * {@code generateNewVarName(new LossMSE().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossMSE(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossMSE().opName(), 0);
    return lossMSE(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossMCXENT} overload with a name obtained from
 * {@code generateNewVarName(new LossMCXENT().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossMCXENT(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossMCXENT().opName(), 0);
    return lossMCXENT(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossMSLE} overload with a name obtained from
 * {@code generateNewVarName(new LossMSLE().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossMSLE(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossMSLE().opName(), 0);
    return lossMSLE(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossNegativeLogLikelihood} overload with a name obtained from
 * {@code generateNewVarName(new LossNegativeLogLikelihood().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossNegativeLogLikelihood(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossNegativeLogLikelihood().opName(), 0);
    return lossNegativeLogLikelihood(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossPoisson} overload with a name obtained from
 * {@code generateNewVarName(new LossPoisson().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossPoisson(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossPoisson().opName(), 0);
    return lossPoisson(generatedName, iX, i_y, dimensions);
}
/**
 * Calls the named {@code lossSquaredHinge} overload with a name obtained from
 * {@code generateNewVarName(new LossSquaredHinge().opName(), 0)}.
 *
 * @param iX first input
 * @param i_y second input
 * @param dimensions dimensions, forwarded unchanged
 * @return result of the named overload
 */
public SDVariable lossSquaredHinge(SDVariable iX, SDVariable i_y, int... dimensions) {
    String generatedName = generateNewVarName(new LossSquaredHinge().opName(), 0);
    return lossSquaredHinge(generatedName, iX, i_y, dimensions);
}
/**
 * Calls {@code functionFactory.gradientBackwardsMarker(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable gradientBackwardsMarker(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.gradientBackwardsMarker(iX), name);
}
/**
 * Calls {@code functionFactory.neq(iX, iy)} with a scalar {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param iy scalar second argument
 * @return the resulting variable
 */
public SDVariable neq(String name, SDVariable iX, double iy) {
    return updateVariableNameAndReference(functionFactory.neq(iX, iy), name);
}
/**
 * Calls {@code functionFactory.eq(iX, iy)} with a scalar {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param iy scalar second argument
 * @return the resulting variable
 */
public SDVariable eq(String name, SDVariable iX, double iy) {
    return updateVariableNameAndReference(functionFactory.eq(iX, iy), name);
}
/**
 * Calls {@code functionFactory.gte(iX, iy)} with a scalar {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param iy scalar second argument
 * @return the resulting variable
 */
public SDVariable gte(String name, SDVariable iX, double iy) {
    return updateVariableNameAndReference(functionFactory.gte(iX, iy), name);
}
/**
 * Calls {@code functionFactory.lte(iX, iy)} with a scalar {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param iy scalar second argument
 * @return the resulting variable
 */
public SDVariable lte(String name, SDVariable iX, double iy) {
    return updateVariableNameAndReference(functionFactory.lte(iX, iy), name);
}
/**
 * Calls {@code functionFactory.gt(iX, iy)} with a scalar {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param iy scalar second argument
 * @return the resulting variable
 */
public SDVariable gt(String name, SDVariable iX, double iy) {
    return updateVariableNameAndReference(functionFactory.gt(iX, iy), name);
}
/**
 * Calls {@code functionFactory.lt(iX, iy)} with a scalar {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param iy scalar second argument
 * @return the resulting variable
 */
public SDVariable lt(String name, SDVariable iX, double iy) {
    return updateVariableNameAndReference(functionFactory.lt(iX, iy), name);
}
/**
 * Calls {@code functionFactory.neq(iX, iy)} with a variable {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX first input
 * @param iy second input
 * @return the resulting variable
 */
public SDVariable neq(String name, SDVariable iX, SDVariable iy) {
    return updateVariableNameAndReference(functionFactory.neq(iX, iy), name);
}
/**
 * Calls {@code functionFactory.eq(iX, iy)} with a variable {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX first input
 * @param iy second input
 * @return the resulting variable
 */
public SDVariable eq(String name, SDVariable iX, SDVariable iy) {
    return updateVariableNameAndReference(functionFactory.eq(iX, iy), name);
}
/**
 * Calls {@code functionFactory.gte(iX, iy)} with a variable {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX first input
 * @param iy second input
 * @return the resulting variable
 */
public SDVariable gte(String name, SDVariable iX, SDVariable iy) {
    return updateVariableNameAndReference(functionFactory.gte(iX, iy), name);
}
/**
 * Calls {@code functionFactory.lte(iX, iy)} with a variable {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX first input
 * @param iy second input
 * @return the resulting variable
 */
public SDVariable lte(String name, SDVariable iX, SDVariable iy) {
    return updateVariableNameAndReference(functionFactory.lte(iX, iy), name);
}
/**
 * Calls {@code functionFactory.gt(iX, iy)} with a variable {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX first input
 * @param iy second input
 * @return the resulting variable
 */
public SDVariable gt(String name, SDVariable iX, SDVariable iy) {
    return updateVariableNameAndReference(functionFactory.gt(iX, iy), name);
}
/**
 * Calls {@code functionFactory.lt(iX, iy)} with a variable {@code iy} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX first input
 * @param iy second input
 * @return the resulting variable
 */
public SDVariable lt(String name, SDVariable iX, SDVariable iy) {
    return updateVariableNameAndReference(functionFactory.lt(iX, iy), name);
}
/**
 * Calls {@code functionFactory.or(iX, iy)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX first input
 * @param iy second input
 * @return the resulting variable
 */
public SDVariable or(String name, SDVariable iX, SDVariable iy) {
    return updateVariableNameAndReference(functionFactory.or(iX, iy), name);
}
/**
 * Calls {@code functionFactory.neg(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable neg(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.neg(iX), name);
}
/**
 * Unnamed variant: calls the named {@code isNonDecreasing} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable isNonDecreasing(SDVariable iX) {
return isNonDecreasing(null, iX);
}
/**
 * Calls {@code functionFactory.isNonDecreasing(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable isNonDecreasing(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.isNonDecreasing(iX), name);
}
/**
 * Unnamed variant: calls the named {@code isStrictlyIncreasing} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable isStrictlyIncreasing(SDVariable iX) {
return isStrictlyIncreasing(null, iX);
}
/**
 * Calls {@code functionFactory.isStrictlyIncreasing(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable isStrictlyIncreasing(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.isStrictlyIncreasing(iX), name);
}
/**
 * Unnamed variant: calls the named {@code isNumericTensor} overload with a {@code null} name.
 *
 * @param iX input variable
 * @return result of the named overload
 */
public SDVariable isNumericTensor(SDVariable iX) {
return isNumericTensor(null, iX);
}
/**
 * Calls {@code functionFactory.isNumericTensor(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable isNumericTensor(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.isNumericTensor(iX), name);
}
/**
 * Calls {@code functionFactory.cos(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable cos(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.cos(iX), name);
}
/**
 * Calls {@code functionFactory.sin(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable sin(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.sin(iX), name);
}
/**
 * Calls {@code functionFactory.tan(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable tan(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.tan(iX), name);
}
/**
 * Calls {@code functionFactory.acos(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable acos(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.acos(iX), name);
}
/**
 * Calls {@code functionFactory.asin(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable asin(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.asin(iX), name);
}
/**
 * Calls {@code functionFactory.atan(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable atan(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.atan(iX), name);
}
/**
 * Calls {@code functionFactory.cosh(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable cosh(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.cosh(iX), name);
}
/**
 * Calls {@code functionFactory.sinh(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable sinh(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.sinh(iX), name);
}
/**
 * Calls {@code functionFactory.tanh(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable tanh(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.tanh(iX), name);
}
/**
 * Calls {@code functionFactory.acosh(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable acosh(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.acosh(iX), name);
}
/**
 * Calls {@code functionFactory.asinh(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable asinh(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.asinh(iX), name);
}
/**
 * Calls {@code functionFactory.atanh(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable atanh(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.atanh(iX), name);
}
/**
 * Calls {@code functionFactory.exp(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable exp(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.exp(iX), name);
}
/**
 * Calls {@code functionFactory.expm1(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable expm1(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.expm1(iX), name);
}
/**
 * Calls {@code functionFactory.rsqrt(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable rsqrt(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.rsqrt(iX), name);
}
/**
 * Calls {@code functionFactory.log(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable log(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.log(iX), name);
}
/**
 * Calls {@code functionFactory.log1p(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable log1p(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.log1p(iX), name);
}
/**
 * Calls {@code functionFactory.isFinite(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable isFinite(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.isFinite(iX), name);
}
/**
 * Calls {@code functionFactory.isInfinite(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable isInfinite(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.isInfinite(iX), name);
}
/**
 * Calls {@code functionFactory.isNaN(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable isNaN(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.isNaN(iX), name);
}
/**
 * Calls {@code functionFactory.round(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable round(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.round(iX), name);
}
/**
 * Calls {@code functionFactory.pow(iX, value)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param value scalar second argument passed to the factory
 * @return the resulting variable
 */
public SDVariable pow(String name, SDVariable iX, double value) {
    return updateVariableNameAndReference(functionFactory.pow(iX, value), name);
}
/**
 * Calls {@code functionFactory.cube(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable cube(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.cube(iX), name);
}
/**
 * Calls {@code functionFactory.sqrt(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable sqrt(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.sqrt(iX), name);
}
/**
 * Calls {@code functionFactory.square(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable square(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.square(iX), name);
}
/**
 * Calls {@code functionFactory.floor(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable floor(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.floor(iX), name);
}
/**
 * Calls {@code functionFactory.relu(iX, cutoff)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param cutoff cutoff value passed to the factory
 * @return the resulting variable
 */
public SDVariable relu(String name, SDVariable iX, double cutoff) {
    return updateVariableNameAndReference(functionFactory.relu(iX, cutoff), name);
}
/**
 * Calls {@code functionFactory.relu6(iX, cutoff)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param cutoff cutoff value passed to the factory
 * @return the resulting variable
 */
public SDVariable relu6(String name, SDVariable iX, double cutoff) {
    return updateVariableNameAndReference(functionFactory.relu6(iX, cutoff), name);
}
/**
 * Calls {@code functionFactory.softmax(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable softmax(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.softmax(iX), name);
}
/**
 * Calls {@code functionFactory.softmaxDerivative(iX, wrt)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param wrt second variable passed to the factory
 * @return the resulting variable
 */
public SDVariable softmaxDerivative(String name, SDVariable iX, SDVariable wrt) {
    return updateVariableNameAndReference(functionFactory.softmaxDerivative(iX, wrt), name);
}
/**
 * Calls {@code functionFactory.hardTanh(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable hardTanh(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.hardTanh(iX), name);
}
/**
 * Calls {@code functionFactory.hardTanhDerivative(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable hardTanhDerivative(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.hardTanhDerivative(iX), name);
}
/**
 * Calls {@code functionFactory.sigmoid(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable sigmoid(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.sigmoid(iX), name);
}
/**
 * Calls {@code functionFactory.sigmoidDerivative(iX, wrt)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param wrt second variable passed to the factory
 * @return the resulting variable
 */
public SDVariable sigmoidDerivative(String name, SDVariable iX, SDVariable wrt) {
    return updateVariableNameAndReference(functionFactory.sigmoidDerivative(iX, wrt), name);
}
/**
 * Calls {@code functionFactory.sign(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable sign(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.sign(iX), name);
}
/**
 * Calls {@code functionFactory.softsign(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable softsign(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.softsign(iX), name);
}
/**
 * Calls {@code functionFactory.softsignDerivative(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable softsignDerivative(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.softsignDerivative(iX), name);
}
/**
 * Calls {@code functionFactory.softplus(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable softplus(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.softplus(iX), name);
}
/**
 * Calls {@code functionFactory.elu(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable elu(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.elu(iX), name);
}
/**
 * Calls {@code functionFactory.eluDerivative(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable eluDerivative(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.eluDerivative(iX), name);
}
/**
 * Calls {@code functionFactory.leakyRelu(iX, alpha)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param alpha scalar passed to the factory
 * @return the resulting variable
 */
public SDVariable leakyRelu(String name, SDVariable iX, double alpha) {
    return updateVariableNameAndReference(functionFactory.leakyRelu(iX, alpha), name);
}
/**
 * Calls {@code functionFactory.leakyReluDerivative(iX, alpha)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param alpha scalar passed to the factory
 * @return the resulting variable
 */
public SDVariable leakyReluDerivative(String name, SDVariable iX, double alpha) {
    return updateVariableNameAndReference(functionFactory.leakyReluDerivative(iX, alpha), name);
}
/**
 * Calls {@code functionFactory.mean(iX)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @return the resulting variable
 */
public SDVariable mean(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.mean(iX), name);
}
/**
 * Calls {@code functionFactory.mean(iX, dimension)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param dimension dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable mean(String name, SDVariable iX, int... dimension) {
    return updateVariableNameAndReference(functionFactory.mean(iX, dimension), name);
}
/**
 * Calls {@code functionFactory.std(iX, biasCorrected, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param biasCorrected bias-correction flag passed to the factory
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable standardDeviation(String name, SDVariable iX, boolean biasCorrected, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.std(iX, biasCorrected, dimensions), name);
}
/**
 * Calls {@code functionFactory.variance(iX, biasCorrected, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param biasCorrected bias-correction flag passed to the factory
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable variance(String name, SDVariable iX, boolean biasCorrected, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.variance(iX, biasCorrected, dimensions), name);
}
/**
 * Calls {@code functionFactory.sum(iX, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable sum(String name, SDVariable iX, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.sum(iX, dimensions), name);
}
/**
 * Calls {@code functionFactory.prod(iX, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable prod(String name, SDVariable iX, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.prod(iX, dimensions), name);
}
/**
 * Calls {@code functionFactory.max(iX, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable max(String name, SDVariable iX, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.max(iX, dimensions), name);
}
/**
 * Calls {@code functionFactory.min(iX, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param iX input variable
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable min(String name, SDVariable iX, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.min(iX, dimensions), name);
}
/**
 * Calls {@code f().norm1(ix, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param ix input variable
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable norm1(String name, SDVariable ix, int... dimensions) {
    return updateVariableNameAndReference(f().norm1(ix, dimensions), name);
}
/**
 * Calls {@code f().norm2(ix, dimensions)} and names the result via
 * {@code updateVariableNameAndReference}.
 *
 * @param name name for the result variable
 * @param ix input variable
 * @param dimensions dimensions passed to the factory
 * @return the resulting variable
 */
public SDVariable norm2(String name, SDVariable ix, int... dimensions) {
    return updateVariableNameAndReference(f().norm2(ix, dimensions), name);
}
/**
 * Builds the {@code normmax} op for {@code ix} via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param ix         input variable
 * @param dimensions forwarded unchanged to {@code f().normmax}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable normmax(String name, SDVariable ix, int... dimensions) {
    return updateVariableNameAndReference(f().normmax(ix, dimensions), name);
}
/**
 * Builds the {@code reshape} op for {@code iX} through the function factory and passes the
 * result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name  name handed on for the resulting variable
 * @param iX    input variable
 * @param shape forwarded unchanged to the factory's {@code reshape}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable reshape(String name, SDVariable iX, int... shape) {
    return updateVariableNameAndReference(functionFactory.reshape(iX, shape), name);
}
/**
 * Builds the {@code transpose} op for {@code iX} through the function factory and passes
 * the result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed on for the resulting variable
 * @param iX   input variable
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable transpose(String name, SDVariable iX) {
    return updateVariableNameAndReference(functionFactory.transpose(iX), name);
}
/**
 * Builds the {@code permute} op for {@code iX} through the function factory and passes the
 * result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         input variable
 * @param dimensions forwarded unchanged to the factory's {@code permute}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable permute(String name, SDVariable iX, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.permute(iX, dimensions), name);
}
/**
 * Builds the {@code rollAxis} op for {@code x} through the function factory and passes the
 * result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed on for the resulting variable
 * @param x    input variable
 * @param axis forwarded unchanged to the factory's {@code rollAxis}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable rollAxis(String name, SDVariable x, int axis) {
    return updateVariableNameAndReference(functionFactory.rollAxis(x, axis), name);
}
/**
 * Builds the {@code fill} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name  name handed on for the resulting variable
 * @param shape variable forwarded as the first argument of the factory's {@code fill}
 * @param value value forwarded as the second argument of the factory's {@code fill}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable fill(String name, SDVariable shape, double value) {
    return updateVariableNameAndReference(functionFactory.fill(shape, value), name);
}
/**
 * Builds the {@code concat} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name      name handed on for the resulting variable
 * @param dimension forwarded unchanged to the factory's {@code concat}
 * @param inputs    input variables, forwarded unchanged to the factory's {@code concat}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable concat(String name, int dimension, SDVariable... inputs) {
    return updateVariableNameAndReference(functionFactory.concat(dimension, inputs), name);
}
/**
 * Builds the {@code tile} op for {@code iX} through the function factory and passes the
 * result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name   name handed on for the resulting variable
 * @param iX     input variable
 * @param repeat forwarded unchanged to the factory's {@code tile}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable tile(String name, SDVariable iX, int[] repeat) {
    return updateVariableNameAndReference(functionFactory.tile(iX, repeat), name);
}
/**
 * Builds the {@code mmul} op for {@code x} and {@code y} through the function factory and
 * passes the result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name      name handed on for the resulting variable
 * @param x         first input variable
 * @param y         second input variable
 * @param transpose transpose settings, forwarded unchanged to the factory's {@code mmul}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable mmul(String name, SDVariable x, SDVariable y, MMulTranspose transpose) {
    return updateVariableNameAndReference(functionFactory.mmul(x, y, transpose), name);
}
/**
 * Convenience overload of {@link #mmul(String, SDVariable, SDVariable, MMulTranspose)}
 * that uses {@code MMulTranspose.allFalse()} as the transpose settings.
 *
 * @param name name handed on for the resulting variable
 * @param x    first input variable
 * @param y    second input variable
 * @return the result of the four-argument overload
 */
public SDVariable mmul(String name, SDVariable x, SDVariable y) {
    MMulTranspose noTranspose = MMulTranspose.allFalse();
    return mmul(name, x, y, noTranspose);
}
/**
 * Builds the {@code tensorMmul} op for {@code x} and {@code y} through the function factory
 * and passes the result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param x          first input variable
 * @param y          second input variable
 * @param dimensions forwarded unchanged to the factory's {@code tensorMmul}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[][] dimensions) {
    return updateVariableNameAndReference(functionFactory.tensorMmul(x, y, dimensions), name);
}
/**
 * Builds the {@code cosineSimilarity} op for {@code iX} and {@code i_y} through the function
 * factory and passes the result, together with {@code name}, to
 * {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code cosineSimilarity}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable cosineSimilarity(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.cosineSimilarity(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code euclideanDistance} op for {@code iX} and {@code i_y} through the function
 * factory and passes the result, together with {@code name}, to
 * {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code euclideanDistance}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable euclideanDistance(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.euclideanDistance(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code manhattanDistance} op for {@code iX} and {@code i_y} through the function
 * factory and passes the result, together with {@code name}, to
 * {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code manhattanDistance}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable manhattanDistance(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.manhattanDistance(iX, i_y, dimensions), name);
}
/**
 * Unnamed variant: delegates to the named overload with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable sigmoidCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels,
                                                int reductionMode, double labelSmoothing) {
    return sigmoidCrossEntropyWithLogits((String) null, logits, weights, labels, reductionMode, labelSmoothing);
}
/**
 * Builds the {@code sigmoidCrossEntropyWithLogits} op via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}. All op arguments
 * are forwarded unchanged, in the order logits, weights, labels, reductionMode, labelSmoothing.
 *
 * @param name name handed on for the resulting variable
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable sigmoidCrossEntropyWithLogits(String name, SDVariable logits, SDVariable weights, SDVariable labels,
                                                int reductionMode, double labelSmoothing) {
    return updateVariableNameAndReference(
            f().sigmoidCrossEntropyWithLogits(logits, weights, labels, reductionMode, labelSmoothing), name);
}
/**
 * Unnamed variant: delegates to the named overload with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable softmaxCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels,
                                                int reductionMode, double labelSmoothing) {
    return softmaxCrossEntropyWithLogits((String) null, logits, weights, labels, reductionMode, labelSmoothing);
}
/**
 * Builds the {@code softmaxCrossEntropyWithLogits} op via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}. All op arguments
 * are forwarded unchanged, in the order logits, weights, labels, reductionMode, labelSmoothing.
 *
 * @param name name handed on for the resulting variable
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable softmaxCrossEntropyWithLogits(String name, SDVariable logits, SDVariable weights, SDVariable labels,
                                                int reductionMode, double labelSmoothing) {
    return updateVariableNameAndReference(
            f().softmaxCrossEntropyWithLogits(logits, weights, labels, reductionMode, labelSmoothing), name);
}
/**
 * Unnamed variant: delegates to the named overload with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, SDVariable weights) {
    return weightedCrossEntropyWithLogits((String) null, targets, inputs, weights);
}
/**
 * Builds the {@code weightedCrossEntropyWithLogits} op via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}. The op arguments
 * are forwarded unchanged, in the order targets, inputs, weights.
 *
 * @param name name handed on for the resulting variable
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, SDVariable inputs,
                                                 SDVariable weights) {
    return updateVariableNameAndReference(f().weightedCrossEntropyWithLogits(targets, inputs, weights), name);
}
/**
 * Builds the {@code lossBinaryXENT} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossBinaryXENT}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossBinaryXENT(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossBinaryXENT(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossCosineSimilarity} op through the function factory and passes the
 * result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossCosineSimilarity}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossCosineSimilarity(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossCosineSimilarity(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossHinge} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossHinge}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossHinge(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossHinge(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossKLD} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossKLD}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossKLD(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossKLD(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossL1} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossL1}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossL1(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossL1(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossL2} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossL2}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossL2(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossL2(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossMAE} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossMAE}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossMAE(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossMAE(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossMSE} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossMSE}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossMSE(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossMSE(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossMCXENT} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossMCXENT}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossMCXENT(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossMCXENT(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossMSLE} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossMSLE}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossMSLE(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossMSLE(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossNegativeLogLikelihood} op through the function factory and passes the
 * result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossNegativeLogLikelihood}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossNegativeLogLikelihood(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossNegativeLogLikelihood(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossPoisson} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossPoisson}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossPoisson(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossPoisson(iX, i_y, dimensions), name);
}
/**
 * Builds the {@code lossSquaredHinge} op through the function factory and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param iX         first input variable
 * @param i_y        second input variable
 * @param dimensions forwarded unchanged to the factory's {@code lossSquaredHinge}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable lossSquaredHinge(String name, SDVariable iX, SDVariable i_y, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.lossSquaredHinge(iX, i_y, dimensions), name);
}
/**
 * Unnamed variant: delegates to {@link #expandDims(String, SDVariable, int)} with a
 * {@code null} name.
 *
 * @param ix   input variable
 * @param axis forwarded unchanged to the named overload
 * @return the result of the named overload
 */
public SDVariable expandDims(SDVariable ix, int axis) {
    return expandDims((String) null, ix, axis);
}
/**
 * Builds the {@code expandDims} op for {@code ix} via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed on for the resulting variable
 * @param ix   input variable
 * @param axis forwarded unchanged to {@code f().expandDims}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable expandDims(String name, SDVariable ix, int axis) {
    return updateVariableNameAndReference(f().expandDims(ix, axis), name);
}
/**
 * Unnamed variant: delegates to {@link #squeeze(String, SDVariable, int)} with a
 * {@code null} name.
 *
 * @param ix   input variable
 * @param axis forwarded unchanged to the named overload
 * @return the result of the named overload
 */
public SDVariable squeeze(SDVariable ix, int axis) {
    return squeeze((String) null, ix, axis);
}
/**
 * Builds the {@code squeeze} op for {@code ix} via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name name handed on for the resulting variable
 * @param ix   input variable
 * @param axis forwarded unchanged to {@code f().squeeze}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable squeeze(String name, SDVariable ix, int axis) {
    return updateVariableNameAndReference(f().squeeze(ix, axis), name);
}
/**
 * Unnamed variant: delegates to {@link #confusionMatrix(String, SDVariable, SDVariable)}
 * with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable confusionMatrix(SDVariable labels, SDVariable predictions) {
    // A typed null keeps overload resolution on the (String, SDVariable, SDVariable) variant.
    final String unnamed = null;
    return confusionMatrix(unnamed, labels, predictions);
}
/**
 * Builds the two-argument {@code confusionMatrix} op via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name   name handed on for the resulting variable
 * @param labels forwarded as the first op argument
 * @param pred   forwarded as the second op argument
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred) {
    return updateVariableNameAndReference(f().confusionMatrix(labels, pred), name);
}
/**
 * Unnamed variant: delegates to
 * {@link #confusionMatrix(String, SDVariable, SDVariable, Integer)} with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses) {
    return confusionMatrix((String) null, labels, pred, numClasses);
}
/**
 * Builds the {@code confusionMatrix} op with a class count via {@code f()} and passes the
 * result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param labels     forwarded as the first op argument
 * @param pred       forwarded as the second op argument
 * @param numClasses forwarded as the third op argument
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses) {
    return updateVariableNameAndReference(f().confusionMatrix(labels, pred, numClasses), name);
}
/**
 * Unnamed variant: delegates to
 * {@link #confusionMatrix(String, SDVariable, SDVariable, SDVariable)} with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) {
    return confusionMatrix((String) null, labels, pred, weights);
}
/**
 * Builds the {@code confusionMatrix} op with weights via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name    name handed on for the resulting variable
 * @param labels  forwarded as the first op argument
 * @param pred    forwarded as the second op argument
 * @param weights forwarded as the third op argument
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, SDVariable weights) {
    return updateVariableNameAndReference(f().confusionMatrix(labels, pred, weights), name);
}
/**
 * Unnamed variant: delegates to
 * {@link #confusionMatrix(String, SDVariable, SDVariable, Integer, SDVariable)} with a
 * {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) {
    return confusionMatrix((String) null, labels, pred, numClasses, weights);
}
/**
 * Builds the {@code confusionMatrix} op with a class count and weights via {@code f()} and
 * passes the result, together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name       name handed on for the resulting variable
 * @param labels     forwarded as the first op argument
 * @param pred       forwarded as the second op argument
 * @param numClasses forwarded as the third op argument
 * @param weights    forwarded as the fourth op argument
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) {
    return updateVariableNameAndReference(f().confusionMatrix(labels, pred, numClasses, weights), name);
}
/**
 * Registers a variable with this instance under its variable name.
 * <p>
 * Uniqueness is validated by variable name rather than by vertex id, because more than one
 * input can share a vertex id; the name also takes function names and input ids into
 * account. Re-adding a variable that {@code equals} the one already stored under the same
 * name is allowed and simply overwrites the entry.
 *
 * @param variable the variable to register; its {@code getSameDiff()} must be this instance
 * @throws IllegalArgumentException if a different variable is already stored under the same name
 */
public void addVariable(SDVariable variable) {
    if (variableMap == null)
        variableMap = new HashMap<>();
    // Checked once; the original repeated this same check right before the put.
    Preconditions.checkState(variable.getSameDiff() == this, "Samediff instance must be the same.");
    final String varName = variable.getVarName();
    // Single lookup instead of containsKey + get.
    val existing = variableMap.get(varName);
    if (existing != null && !existing.equals(variable)) {
        throw new IllegalArgumentException("Variable already found with variable opName " + varName);
    }
    variableMap.put(varName, variable);
}
/**
 * Generates a variable name that is not yet registered, derived from a base name and an
 * argument index.
 * <p>
 * If {@code baseName} itself is free and {@code argIndex} is 0, it is returned unchanged.
 * Otherwise candidates of the form {@code baseName_<count>} (with {@code :<argIndex>}
 * appended when {@code argIndex > 0}) are tried for count = 1, 2, ... until
 * {@code getVariable} reports no variable for the candidate.
 *
 * @param baseName the base name to use (use function.opName() where function is a {@link DifferentialFunction})
 * @param argIndex the arg index
 * @return the new generated name
 */
public String generateNewVarName(String baseName, int argIndex) {
    if (getVariable(baseName) == null && argIndex == 0) {
        return baseName;
    }
    // The suffix does not depend on the counter, so build it once.
    final String argSuffix = argIndex > 0 ? ":" + argIndex : "";
    // The loop only ends on a free name, so no further "already taken" check is needed.
    for (int count = 1; ; count++) {
        String candidate = baseName + "_" + count + argSuffix;
        if (getVariable(candidate) == null) {
            return candidate;
        }
    }
}
/**
 * Creates an {@code LSTMCell} on this instance and returns its first output variable.
 *
 * @param baseName      the base name passed to {@code outputVariables}
 * @param configuration the configuration the cell is constructed with
 * @return element 0 of the cell's output variables
 */
public SDVariable lstm(String baseName, LSTMCellConfiguration configuration) {
    SDVariable[] outputs = new LSTMCell(this, configuration).outputVariables(baseName);
    return outputs[0];
}
/**
 * Creates an {@code SRUCell} on this instance and returns its first output variable.
 *
 * @param configuration the configuration the cell is constructed with
 * @return element 0 of the cell's output variables
 */
public SDVariable sruCell(SRUCellConfiguration configuration) {
    SDVariable[] outputs = new SRUCell(this, configuration).outputVariables();
    return outputs[0];
}
/**
 * Creates an {@code SRU} (simple recurrent unit) on this instance and returns its first
 * output variable.
 *
 * @param configuration the configuration the unit is constructed with
 * @return element 0 of the unit's output variables
 */
public SDVariable sru(SRUConfiguration configuration) {
    SDVariable[] outputs = new SRU(this, configuration).outputVariables();
    return outputs[0];
}
/**
 * Creates a {@code GRUCell} on this instance and returns its first output variable.
 *
 * @param configuration the configuration the cell is constructed with
 * @return element 0 of the cell's output variables
 */
public SDVariable gru(GRUCellConfiguration configuration) {
    SDVariable[] outputs = new GRUCell(this, configuration).outputVariables();
    return outputs[0];
}
/**
 * Creates an {@code SRUCell} on this instance and returns its first output variable,
 * using {@code baseName} for the outputs.
 *
 * @param baseName      the base name passed to {@code outputVariables}
 * @param configuration the configuration the cell is constructed with
 * @return element 0 of the cell's output variables
 */
public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) {
    SDVariable[] outputs = new SRUCell(this, configuration).outputVariables(baseName);
    return outputs[0];
}
/**
 * Creates an {@code SRU} (simple recurrent unit) on this instance and returns its first
 * output variable, using {@code baseName} for the outputs.
 *
 * @param baseName      the base name passed to {@code outputVariables}
 * @param configuration the configuration the unit is constructed with
 * @return element 0 of the unit's output variables
 */
public SDVariable sru(String baseName, SRUConfiguration configuration) {
    SDVariable[] outputs = new SRU(this, configuration).outputVariables(baseName);
    return outputs[0];
}
/**
 * Creates a {@code GRUCell} on this instance and returns its first output variable,
 * using {@code baseName} for the outputs.
 *
 * @param baseName      the base name passed to {@code outputVariables}
 * @param configuration the configuration the cell is constructed with
 * @return element 0 of the cell's output variables
 */
public SDVariable gru(String baseName, GRUCellConfiguration configuration) {
    SDVariable[] outputs = new GRUCell(this, configuration).outputVariables(baseName);
    return outputs[0];
}
/**
 * Unnamed variant: delegates to {@link #slice(String, SDVariable, int[], int[])} with a
 * {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable slice(SDVariable input, int[] begin, int[] size) {
    return slice((String) null, input, begin, size);
}
/**
 * Builds the {@code slice} op for {@code input} via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name  name handed on for the resulting variable
 * @param input input variable
 * @param begin forwarded unchanged to {@code f().slice}
 * @param size  forwarded unchanged to {@code f().slice}
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable slice(String name, SDVariable input, int[] begin, int[] size) {
    return updateVariableNameAndReference(f().slice(input, begin, size), name);
}
/**
 * Unnamed variant: delegates to
 * {@link #stridedSlice(String, SDVariable, int[], int[], int[])} with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] strides) {
    return stridedSlice((String) null, input, begin, end, strides);
}
/**
 * Named strided slice without masks: delegates to the full overload with all five mask
 * arguments (begin, end, ellipsis, new-axis, shrink-axis) set to 0.
 *
 * @param name name handed on for the resulting variable
 * @return the result of the full overload
 */
public SDVariable stridedSlice(String name, SDVariable input, int[] begin, int[] end, int[] strides) {
    final int noMask = 0;
    return stridedSlice(name, input, begin, end, strides, noMask, noMask, noMask, noMask, noMask);
}
/**
 * Unnamed variant of the full strided slice: delegates to the named overload with a
 * {@code null} name, forwarding every other argument unchanged.
 *
 * @return the result of the named overload
 */
public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    return stridedSlice((String) null, in, begin, end, strides,
            beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
}
/**
 * Builds the {@code stridedSlice} op for {@code in} via {@code f()} and passes the result,
 * together with {@code name}, to {@code updateVariableNameAndReference}. The begin/end/strides
 * arrays and the five mask values are forwarded unchanged and in the same order.
 *
 * @param name name handed on for the resulting variable
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable stridedSlice(String name, SDVariable in, int[] begin, int[] end, int[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    return updateVariableNameAndReference(
            f().stridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask),
            name);
}
/**
 * Builds the {@code scatterAdd} op via {@code f()} and passes the result, together with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name    name handed on for the resulting variable
 * @param ref     forwarded as the first op argument
 * @param indices forwarded as the second op argument
 * @param updates forwarded as the third op argument
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    return updateVariableNameAndReference(f().scatterAdd(ref, indices, updates), name);
}
/**
 * Builds the {@code scatterMul} op via {@code f()} and passes the result, together with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name    name handed on for the resulting variable
 * @param ref     forwarded as the first op argument
 * @param indices forwarded as the second op argument
 * @param updates forwarded as the third op argument
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    return updateVariableNameAndReference(f().scatterMul(ref, indices, updates), name);
}
/**
 * Builds the {@code scatterSub} op via {@code f()} and passes the result, together with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name    name handed on for the resulting variable
 * @param ref     forwarded as the first op argument
 * @param indices forwarded as the second op argument
 * @param updates forwarded as the third op argument
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    return updateVariableNameAndReference(f().scatterSub(ref, indices, updates), name);
}
/**
 * Builds the {@code scatterDiv} op via {@code f()} and passes the result, together with
 * {@code name}, to {@code updateVariableNameAndReference}.
 *
 * @param name    name handed on for the resulting variable
 * @param ref     forwarded as the first op argument
 * @param indices forwarded as the second op argument
 * @param updates forwarded as the third op argument
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    return updateVariableNameAndReference(f().scatterDiv(ref, indices, updates), name);
}
/**
 * Unnamed variant: delegates to
 * {@link #scatterAdd(String, SDVariable, SDVariable, SDVariable)} with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterAdd((String) null, ref, indices, updates);
}
/**
 * Unnamed variant: delegates to
 * {@link #scatterMul(String, SDVariable, SDVariable, SDVariable)} with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterMul((String) null, ref, indices, updates);
}
/**
 * Unnamed variant: delegates to
 * {@link #scatterSub(String, SDVariable, SDVariable, SDVariable)} with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterSub((String) null, ref, indices, updates);
}
/**
 * Unnamed variant: delegates to
 * {@link #scatterDiv(String, SDVariable, SDVariable, SDVariable)} with a {@code null} name.
 *
 * @return the result of the named overload
 */
public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterDiv((String) null, ref, indices, updates);
}
/**
 * Creates (or looks up) the output variables for the given op and returns them.
 * <p>
 * Base name resolution: a {@code null} base name is replaced by
 * {@code getBaseNameForFunction(function)}, and by {@code function.opName()} if that is
 * {@code null} too. An empty base name is replaced only when a base name is recorded for
 * the function.
 * <p>
 * When {@code function.calculateOutputShape()} yields no shapes (null or empty):
 * <ul>
 * <li>for a {@link CustomOp}, the number of outputs is taken from its descriptor and one
 * variable with a {@code null} shape is created or looked up per output;</li>
 * <li>for a {@link BaseOp}, exactly one output variable with a {@code null} shape is used.</li>
 * </ul>
 * Otherwise one variable per calculated shape is created or looked up; outputs after the
 * first are named {@code rootName:index}.
 *
 * @param function the function to generate the output variables for
 * @param baseName the base name for the outputs, may be {@code null} or empty
 * @return the output variables of the function, one per output
 * @throws ND4JIllegalStateException if the number of outputs cannot be determined
 */
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName) {
    //xyz ops only have 1 output
    //if there is already a base name defined, use that
    // The parentheses only spell out the grouping Java already applied (&& binds tighter than ||).
    if (baseName == null || (baseName.isEmpty() && getBaseNameForFunction(function) != null))
        baseName = getBaseNameForFunction(function);
    if (baseName == null)
        baseName = function.opName();

    val outputShape = function.calculateOutputShape();
    if (outputShape == null || outputShape.isEmpty()) {
        if (function instanceof CustomOp) {
            CustomOp customOp = (CustomOp) function;
            val descriptor = customOp.getDescriptor();
            //can't guess number of outputs, variable
            if (descriptor == null || descriptor.getNumOutputs() <= 0) {
                throw new ND4JIllegalStateException("No output variables found!");
            }
            char ordering = orderingOfFirstArg(function);
            SDVariable[] ret = new SDVariable[descriptor.getNumOutputs()];
            //dynamic shapes
            for (int i = 0; i < ret.length; i++) {
                SDVariable checkGet = getVariable(baseName);
                if (checkGet == null) {
                    checkGet = var(generateNewVarName(baseName, i), null, new ZeroInitScheme(ordering));
                } else if (i > 0 && !importedVarName.contains(baseName)) {
                    //need to find a new name
                    String newName = generateNewVarName(baseName, i);
                    checkGet = getVariable(newName);
                }
                if (checkGet == null) {
                    String newName = generateNewVarName(baseName, i);
                    checkGet = var(newName, null, new ZeroInitScheme(ordering));
                }
                ret[i] = checkGet;
            }
            return ret;
        }
        //this is for unresolved shapes, we know xyz is always 1 output
        // The enclosing condition already guarantees "null or empty"; the previous extra
        // outputShape.isEmpty() test threw a NullPointerException for a null shape list.
        else if (function instanceof BaseOp) {
            SDVariable[] ret = new SDVariable[1];
            SDVariable checkGet = getVariable(baseName);
            char ordering = orderingOfFirstArg(function);
            if (checkGet == null) {
                checkGet = var(baseName, null, new ZeroInitScheme(ordering));
            } else if (!importedVarName.contains(baseName)) {
                //need to find a new name
                String newName = generateNewVarName(baseName, 0);
                checkGet = var(newName, null, new ZeroInitScheme(ordering));
            }
            if (checkGet == null) {
                checkGet = var(baseName, null, new ZeroInitScheme(ordering));
            }
            ret[0] = checkGet;
            return ret;
        }
    }

    // Neither a CustomOp nor a BaseOp and no shape list at all: fail with a descriptive
    // error instead of the NullPointerException the dereference below would raise.
    if (outputShape == null) {
        throw new ND4JIllegalStateException("Unable to calculate output shapes for op: " + function.opName());
    }

    char ordering = orderingOfFirstArg(function);
    SDVariable[] ret = new SDVariable[outputShape.size()];
    final String rootName = baseName;
    for (int i = 0; i < ret.length; i++) {
        val shape = outputShape.get(i);
        final String indexSuffix = i > 0 ? ":" + i : "";
        // it should be: rootName:index. i.e.: split:1, split:2, split:3, split:4 etc
        final String outputName = rootName + indexSuffix;
        SDVariable checkGet = getVariable(outputName);
        if (checkGet == null) {
            // obviously - there's no such var, just add it
            checkGet = var(outputName, shape, new ZeroInitScheme(ordering));
        } else if (shape != null && !shapeAlreadyExistsForVarName(checkGet.getVarName())) {
            // var exists, let's update its shape
            putShapeForVarName(checkGet.getVarName(), shape);
        } else if (shape != null && shapeAlreadyExistsForVarName(checkGet.getVarName())) {
            // no-op.
            // TODO: maybe we should check shapes equality here?
            // it's either var that already exist, or something bad happening
        } else if (!importedVarName.contains(outputName)) {
            // FIXME: dead end. it's impossible to get here with null as shape
            //need to find a new name
            // NOTE(review): outputName already carries ":i", so the generated name gets the
            // index suffix twice. Kept as-is because changing generated names could break
            // callers that look variables up by name - confirm the intended format.
            int count = 1;
            String name = outputName + "_" + count + indexSuffix;
            while (getVariable(name) != null) {
                count++;
                name = outputName + "_" + count + indexSuffix;
            }
            checkGet = var(name, shape, new ZeroInitScheme(ordering));
        }
        if (checkGet == null) {
            // NOTE(review): same doubled suffix as above; only reached if var(...) returned null.
            checkGet = var(outputName + indexSuffix, shape, new ZeroInitScheme(ordering));
        }
        ret[i] = checkGet;
    }
    return ret;
}

/**
 * Returns the ordering of the first argument's array, or {@code 'c'} when that argument
 * has no array yet.
 */
private static char orderingOfFirstArg(DifferentialFunction function) {
    if (function.args()[0].getArr() != null) {
        return function.args()[0].getArr().ordering();
    }
    return 'c';
}
/**
 * Generates the output variables for the given op, using the op's own
 * {@code opName()} as the base name.
 *
 * @param function the function to generate the output variables for
 * @return the output variables produced by
 *         {@link #generateOutputVariableForOp(DifferentialFunction, String)}
 */
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) {
    String defaultBaseName = function.opName();
    return generateOutputVariableForOp(function, defaultBaseName);
}
/**
 * Looks up the {@code SameDiff} instance stored for a function name.
 *
 * @param functionName the name the function was stored under in {@code sameDiffFunctionInstances}
 * @return the instance stored for that name, as returned by the map's {@code get}
 */
public SameDiff getFunction(String functionName) {
    SameDiff defined = sameDiffFunctionInstances.get(functionName);
    return defined;
}
/**
 * Executes the given ops via {@link #exec(List)} and returns the {@code z()} array of the
 * last op in the returned list.
 *
 * @param ops the already created ops to execute; the last element is cast to {@link Op}
 * @return {@code z()} of the last executed op
 */
public INDArray execAndEndResult(List ops) {
    List executed = exec(ops);
    int lastIndex = executed.size() - 1;
    return ((Op) executed.get(lastIndex)).z();
}
/**
 * Executes this graph and returns the array of its final output.
 * <p>
 * Variables are first resolved against an empty map, then {@code exec()} is run; the
 * last element of the right-hand side of its result is taken and the array of that
 * element's first output variable is returned.
 *
 * @return the array ({@code getArr()}) of the first output variable of the last executed function
 */
public INDArray execAndEndResult() {
resolveVariablesWith(Collections.emptyMap());
List exec = exec().getRight();
// last executed function -> its first output variable
val output = exec.get(exec.size() - 1).outputVariables()[0];
return output.getArr();
}
/**
 * Executes this graph through the executioner's registered-graph calls and returns the
 * last array of the result.
 * <p>
 * On first use the graph is serialized with {@code asFlatBuffers()} and registered with the
 * executioner under {@code this.hashCode()}. The {@code wasRegistered} flag is checked both
 * before and inside the {@code synchronized} block, and set after registration.
 *
 * @param inputs input values keyed by variable name; every key must name a variable known to
 *               this instance
 * @return the last value of the executioner's result
 * @throws ND4JIllegalStateException if a key does not match a known variable, or if the
 *                                   executioner returns an empty result
 */
public INDArray yetAnotherExecMethod(@NonNull Map inputs) {
    if (!wasRegistered.get()) {
        synchronized (this) {
            if (!wasRegistered.get()) {
                val bb = asFlatBuffers();
                val ptr = new BytePointer(bb);
                Nd4j.getExecutioner().registerGraph(this.hashCode(), ptr);
                wasRegistered.set(true);
            }
        }
    }
    val newMap = new LinkedHashMap();
    val keySet = inputs.keySet();
    for (val key : keySet) {
        val vx = variableMap.get(key);
        // Previously an unknown key surfaced as a bare NullPointerException on vx.
        if (vx == null)
            throw new ND4JIllegalStateException("No variable found for input name: " + key);
        newMap.put(vx.getVarName(), inputs.get(key));
    }
    val result = Nd4j.getExecutioner().executeGraph(this.hashCode(), newMap);
    if (result.size() == 0)
        throw new ND4JIllegalStateException("Execution failed");
    val list = new ArrayList(result.values());
    return list.get(list.size() - 1);
}
/**
 * Executes the list of operations.
 * This exec method only invokes operations that already exist; it does not create them.
 * Each element is cast to {@link Op} and handed to the executioner in list order.
 *
 * @param ops the list of already created ops
 * @return the passed-in list
 */
public List exec(List ops) {
    for (Object candidate : ops) {
        Nd4j.getExecutioner().exec((Op) candidate);
    }
    return ops;
}
/**
 * An interface for representing a conditional statement
 */
public interface SameDiffConditional {
/**
 * Evaluates the conditional.
 *
 * @param context   the {@code SameDiff} instance the conditional is evaluated against
 * @param body      the function definition of the conditional
 * @param inputVars the input variables handed to the conditional
 * @return the variable produced by the evaluation
 */
SDVariable eval(SameDiff context, SameDiffFunctionDefinition body, SDVariable[] inputVars);
}
/**
 * Default conditional: defines {@code body} on the context under the function name
 * {@code "eval"}, invokes it on the context, and returns the first output variable of the
 * last entry in the context's {@code functionInstancesById}.
 */
public static class DefaultSameDiffConditional implements SameDiffConditional {
    @Override
    public SDVariable eval(SameDiff context, SameDiff.SameDiffFunctionDefinition body, SDVariable[] inputVars) {
        context.defineFunction("eval", body, inputVars);
        context.invokeFunctionOn("eval", context);
        val registered = context.functionInstancesById;
        // Snapshot the values first, then pick the last one by index.
        return new ArrayList<>(registered.values()).get(registered.size() - 1).outputVariables()[0];
    }
}
/**
 * Creates a while statement whose parent is this instance. The block is named
 * {@code "while-"} followed by a random UUID.
 *
 * @param sameDiffConditional the predicate of the loop
 * @param conditionBody       the function definition used as the loop condition
 * @param loopBody            the function definition used as the loop (true) body
 * @param inputVars           the input variables of the loop
 * @return the built {@link While}
 */
public While whileStatement(SameDiffConditional sameDiffConditional,
                            SameDiffFunctionDefinition conditionBody,
                            SameDiff.SameDiffFunctionDefinition loopBody,
                            SDVariable[] inputVars) {
    val builder = While.builder()
            .inputVars(inputVars)
            .condition(conditionBody)
            .predicate(sameDiffConditional)
            .trueBody(loopBody)
            .parent(this);
    final String uniqueBlockName = "while-" + UUID.randomUUID().toString();
    return builder.blockName(uniqueBlockName).build();
}
/**
 * Creates an if statement whose parent is this instance. The block is named
 * {@code "if-"} followed by a random UUID.
 *
 * @param conditional   the predicate of the statement
 * @param conditionBody the function definition used as the condition
 * @param trueBody      the function definition used as the true branch
 * @param falseBody     the function definition used as the false branch
 * @param inputVars     the input variables of the statement
 * @return the built {@link If}
 */
public If ifStatement(SameDiffConditional conditional,
                      SameDiffFunctionDefinition conditionBody,
                      SameDiffFunctionDefinition trueBody,
                      SameDiffFunctionDefinition falseBody,
                      SDVariable[] inputVars) {
    val builder = If.builder()
            .conditionBody(conditionBody)
            .falseBody(falseBody)
            .trueBody(trueBody)
            .predicate(conditional)
            .inputVars(inputVars)
            .parent(this);
    final String uniqueBlockName = "if-" + UUID.randomUUID().toString();
    return builder.blockName(uniqueBlockName).build();
}
/**
 * A function definition for
 * samediff
 */
public interface SameDiffFunctionDefinition {
/**
 * Defines the function on the given {@code SameDiff} instance.
 *
 * @param sameDiff       the instance to define the function on
 * @param inputs         map of inputs; {@code defineFunction} in the enclosing class passes
 *                       {@code null} here when variable inputs are supplied instead
 * @param variableInputs input variables; {@code defineFunction} in the enclosing class passes
 *                       {@code null} here when a map of inputs is supplied instead
 * @return the variables produced by the definition
 */
SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs);
}
/**
 * Looks up the function instance stored under {@code functionName} and invokes its graph
 * on the given {@code SameDiff} instance.
 *
 * @param functionName the name the function was defined under
 * @param with         the instance passed to {@code invokeGraphOn}
 * @return the variable returned by {@code invokeGraphOn}
 */
public SDVariable invokeFunctionOn(String functionName, SameDiff with) {
    return sameDiffFunctionInstances.get(functionName).invokeGraphOn(with);
}
/**
 * Defines a function from input variables, unless one is already stored under the same name.
 * <p>
 * For a new name, a fresh {@code SameDiff} sub-instance sharing this instance's workspace is
 * created, each input variable is re-created on it via {@code var}, the definition is run
 * against it (with a {@code null} input map), and the sub-instance is stored under the name.
 * If the name is already taken, the definition is not run again.
 *
 * @param function           the name to store the function under
 * @param functionDefinition the definition to run on the sub-instance
 * @param variables          the input variables to mirror on the sub-instance
 * @return the instance stored under {@code function}
 */
public SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables) {
    if (sameDiffFunctionInstances.containsKey(function)) {
        return sameDiffFunctionInstances.get(function);
    }
    SameDiff subGraph = SameDiff.create();
    subGraph.workspace = workspace;
    // Mirror the inputs on the subgraph, then run the definition to populate it.
    SDVariable[] subInputs = new SDVariable[variables.length];
    for (int idx = 0; idx < subInputs.length; idx++) {
        subInputs[idx] = subGraph.var(variables[idx]);
    }
    functionDefinition.define(subGraph, null, subInputs);
    sameDiffFunctionInstances.put(function, subGraph);
    return sameDiffFunctionInstances.get(function);
}
/**
 * Defines a function with no inputs: delegates to the map-based overload with a new, empty
 * {@code LinkedHashMap}.
 *
 * @param function           the name to store the function under
 * @param functionDefinition the definition to run
 */
public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition) {
    Map noInputs = new LinkedHashMap();
    defineFunction(function, functionDefinition, noInputs);
}
/**
 * Defines a function from a map of inputs, unless one is already stored under the same name.
 * <p>
 * For a new name, a fresh {@code SameDiff} sub-instance sharing this instance's workspace is
 * created, the definition is run against it (with {@code null} variable inputs), and the
 * sub-instance is stored under the name. If the name is already taken, nothing happens.
 *
 * @param function           the name to store the function under
 * @param functionDefinition the definition to run on the sub-instance
 * @param inputs             the inputs handed to the definition
 */
public void defineFunction(String function,
                           SameDiffFunctionDefinition functionDefinition,
                           Map inputs) {
    if (sameDiffFunctionInstances.containsKey(function)) {
        return;
    }
    SameDiff subGraph = SameDiff.create();
    subGraph.workspace = workspace;
    // Run the definition to populate the subgraph before storing it.
    functionDefinition.define(subGraph, inputs, null);
    sameDiffFunctionInstances.put(function, subGraph);
}
/**
 * Executes the function stored under the given name and returns its end result.
 *
 * @param functionName the name of the function to invoke
 * @return the result of {@code execAndEndResult()} on the stored function instance
 */
public INDArray execAndEndResult(String functionName) {
    SameDiff target = sameDiffFunctionInstances.get(functionName);
    return target.execAndEndResult();
}
/**
* Exec a given function
*
* @param functionName the opName of the function
* to invoke
* @return
*/
public Pair