// org.nd4j.autodiff.samediff.SameDiff — artifact listing header from the Maven/Gradle/Ivy repository page this source was extracted from
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import com.rits.cloning.Cloner;
import com.rits.cloning.IFastCloner;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.bytedeco.javacpp.BytePointer;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.functions.FunctionProperties;
import org.nd4j.autodiff.samediff.flow.FlowPath;
import org.nd4j.autodiff.util.cloner.DataBufferFastCloner;
import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.*;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.distances.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.accum.distances.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration;
import org.nd4j.linalg.api.ops.impl.shape.Eye;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayV3;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.temp.ExternalErrorsFunction;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.collection.IntArrayKeyMap;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.lossfunctions.impl.*;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.list.compat.TensorList;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import java.io.*;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* SameDiff is the
* entrypoint for
* nd4j's autodiff.
*
* You define a graph symbolically.
*
* That graph accumulates operations.
*
* In order to execute the graph, you run
* {@link #exec()} to get all the operations
* {@link #exec(List)} for an already created set of ops
* {@link #execAndEndResult()} for the end result only
* {@link #execAndEndResult(List)} for a cached set of ops
*/
@AllArgsConstructor
@Builder
@Slf4j
public class SameDiff {
// NOTE(review): generic type parameters appear to have been stripped from these
// declarations by the source extraction (e.g. "Map" where "Map<String, String[]>"
// is clearly intended from the usage and trailing comments, and "Map>" artifacts).
// Confirm against the original repository source before editing any types.
private Map incomingArgsReverse; //Key: DifferentialFunction.getOwnName(). Value: name of SDVariables as inputs to that function
private Map outgoingArgsReverse; //Key: DifferentialFunction.getOwnName(). Value: name of SDVariables as outputs from that function
private Map permuteOrder; //Permutation ordering; key/value semantics not evident from this chunk - TODO confirm
private boolean shouldBootStrap = true; //Usage not visible in this chunk - TODO confirm
private Set importedVarName; //Names of variables brought in via model import - TODO confirm
//map a function's instance id to a base name, used for propagating variable names
//for output during import
private Map baseNameForFunctionInstanceId;
private DifferentialFunctionFactory functionFactory; //Factory returned by f(); bound to this instance in the private constructor
private Map variableMap; //Key: SDVariable name. Value: SDVariable
private Map variableNameToShape; //Key: SDVariable name. Value: shape for that variable
//gradient information
private Map gradients; //Key: variable name. Value: gradient for that variable - original comment was truncated, TODO confirm
private Map forwardVarForGrad; //Reverse mapping from gradient back to forward variable - TODO confirm
private Map variableNameToArr; //Key: name of SDVariable. Value: Array for that variable
//individual index for variable names
private Map> functionsArgsFor; //Key: SDVariable name. Value: all DifferentialFunctions it is an input to
private Map> functionOutputFor; //Key: SDVariable name. Value: DifferentialFunctions this variable is an output for (TODO: Why is this a list? Isn't it *always* length 1?)
private Map lists = new HashMap<>(); // Key - node name; Value - TensorList
// this entity holds runtime information for Switch/Merge/NextIteration etc stuff
private transient ThreadLocal localFlowPath = new ThreadLocal();
// here we save String -> Integer conversion to variables
private transient Map reverseMap = null;
// counter for auto-naming variables
private int variableId = 0;
/**
 * For import, many times we have variables
 * that map to properties. Most common
 * we will have an input to a function that is mapped to an ndarray.
 * That ndarray is usually a scalar shape.
 *
 * That array with a scalar shape can be something like an axis.
 *
 * We often don't know that array's value till run time.
 * This map stores variable names that we should resolve
 * from samediff. We use the value of that array
 * to update the properties.
 */
private Map> propertiesToResolve;
/**
 * A map of own name to
 * the properties of the function (things like execution axes etc)
 * The valid values can be:
 * int
 * long
 * INDArray
 */
private Map> propertiesForFunction;
private Map> placeHolderMap; //Keyed by placeholder variable name - value semantics not evident here, TODO confirm
private Map placeHolderOriginalShapes; //Shapes containing dims < 1, recorded when a variable is marked as a placeholder (see putShapeForVarName)
private Set placeHolderVarNames; //Names of all variables currently treated as placeholders
private MemoryWorkspace workspace; //Optional memory workspace - usage not visible in this chunk
private Map sameDiffFunctionDefinitionMap; //Key: function name. Value: function definition - TODO confirm
private Map sameDiffFunctionInstances; //Key: function name (e.g. the gradient "grad" function). Value: nested SameDiff instance
private Set placeHolderFunctions; //Functions with at least one placeholder input - TODO confirm
private static Cloner cloner = newCloner(); //Shared deep-cloner; see newCloner() for the INDArray/DataBuffer special-casing
private static Map opMethods; //Key: method name. Value: SDVariable-returning Method of this class; populated in the static initializer below, used by invoke(Op, ...)
@Getter
private Map functionInstancesById; //Key: DifferentialFunction.getOwnName(). Value: the function instance
private Table fieldVariableResolutionMapping; //Row: function own name. Column: field name. Value: variable name (see addVariableMappingForField)
// flag, shows if graph was already registered with libnd4j
private transient AtomicBoolean wasRegistered = new AtomicBoolean(false);
//debug mode variables
@Getter
private boolean debugMode;
private Map opsForResult; //Initialized as an IntArrayKeyMap in the private constructor
private boolean resolvedVariables = false; //Set once import-time variables/properties are resolved - TODO confirm
@Getter
@Setter
boolean logExecution = true;
@Getter
private SameDiff parent; //Parent instance when this is a nested graph - TODO confirm
@Getter
private SameDiff child;
static {
// Cache every SDVariable-returning method of this class by name, for the
// reflective dispatch performed in invoke(Op, SDVariable, SDVariable).
opMethods = new HashMap<>();
Method[] methods = SameDiff.class.getDeclaredMethods();
for (Method method : methods) {
if (method.getReturnType().equals(SDVariable.class)) {
opMethods.put(method.getName(), method);
}
}
}
/**
 * Build a {@link Cloner} configured for ND4J types.
 *
 * @return New cloner object. NOTE: INTENDED FOR DEVELOPER USE ONLY
 */
public static Cloner newCloner() {
    Cloner cloner = new Cloner();
    // INDArrays are off-heap (pointer-backed), which the cloning library mishandles by
    // default, so a custom fast cloner is installed. The library does not accept
    // interfaces, hence the lookup of the backend's concrete NDArray class.
    cloner.registerFastCloner(Nd4j.getBackend().getNDArrayClass(), new INDArrayFastCloner());
    // DataBuffers are likewise off-heap; register a fast cloner for every concrete
    // buffer class the current backend exposes (null entries are skipped by doReg).
    IFastCloner bufferCloner = new DataBufferFastCloner();
    DataBufferFactory bufferFactory = Nd4j.getDataBufferFactory();
    Class[] bufferClasses = {
            bufferFactory.intBufferClass(),
            bufferFactory.longBufferClass(),
            bufferFactory.halfBufferClass(),
            bufferFactory.floatBufferClass(),
            bufferFactory.doubleBufferClass(),
            CompressedDataBuffer.class
    };
    for (Class bufferClass : bufferClasses) {
        doReg(cloner, bufferCloner, bufferClass);
    }
    return cloner;
}
/**
 * Register {@code fc} as the fast cloner for {@code c}, silently skipping null
 * classes (a backend may not provide every concrete buffer type).
 */
private static void doReg(Cloner cl, IFastCloner fc, Class c) {
    if (c == null) {
        return;
    }
    cl.registerFastCloner(c, fc);
}
/**
 * Update the opName for the variable
 * with the given vertex id
 *
 * Renames the variable everywhere it is tracked: the variable map, the per-function
 * input/output name arrays, the array/shape/gradient/placeholder maps, and the
 * x/y/z vertex ids of any {@link BaseOp} that referenced the old name.
 *
 * @param varName the vertex id to update
 * @param withName thew new opName
 */
public void updateVariableName(String varName, String withName) {
    SDVariable oldVarNameRef = getVariable(varName);
    variableMap.remove(oldVarNameRef.getVarName());
    oldVarNameRef.setVarName(withName);
    variableMap.put(withName, oldVarNameRef);
    // Rewrite the old name anywhere it appears as a function input or output.
    renameInArgNameArrays(outgoingArgsReverse, varName, withName);
    renameInArgNameArrays(incomingArgsReverse, varName, withName);
    // Re-key every per-variable map from the old name to the new one.
    moveMapKey(variableNameToArr, varName, withName);
    moveMapKey(variableNameToShape, varName, withName);
    moveMapKey(gradients, varName, withName);
    moveMapKey(forwardVarForGrad, varName, withName);
    moveMapKey(placeHolderMap, varName, withName);
    if (functionsArgsFor.containsKey(varName)) {
        List funcs = (List) functionsArgsFor.remove(varName);
        renameBaseOpVertexIds(funcs, varName, withName);
        functionsArgsFor.put(withName, funcs);
    }
    if (functionOutputFor.containsKey(varName)) {
        List funcs = (List) functionOutputFor.remove(varName);
        renameBaseOpVertexIds(funcs, varName, withName);
        functionOutputFor.put(withName, funcs);
    }
    // Defensive: original code removed the old name again at the end; kept for parity.
    variableMap.remove(varName);
}

/**
 * Replace {@code oldName} with {@code newName} inside every String[] value of the
 * given (function own-name -> argument names) map. Arrays are mutated in place.
 */
private static void renameInArgNameArrays(Map argsReverse, String oldName, String newName) {
    for (Object value : argsReverse.values()) {
        String[] names = (String[]) value;
        for (int i = 0; i < names.length; i++) {
            if (names[i].equals(oldName)) {
                names[i] = newName;
            }
        }
    }
}

/**
 * Move the entry under {@code oldKey} (if any) to {@code newKey}, preserving its value.
 */
private static void moveMapKey(Map map, String oldKey, String newKey) {
    if (map.containsKey(oldKey)) {
        Object value = map.remove(oldKey);
        map.put(newKey, value);
    }
}

/**
 * For every {@link BaseOp} in the list, rename any x/y/z vertex id equal to
 * {@code oldName} to {@code newName}. Non-BaseOp functions are left untouched.
 */
private static void renameBaseOpVertexIds(List funcs, String oldName, String newName) {
    for (Object func : funcs) {
        if (func instanceof BaseOp) {
            BaseOp baseOp = (BaseOp) func;
            if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldName)) {
                baseOp.setXVertexId(newName);
            }
            if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldName)) {
                baseOp.setYVertexId(newName);
            }
            if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldName)) {
                baseOp.setZVertexId(newName);
            }
        }
    }
}
/**
 * Clears debugging state and disables debug mode.
 *
 * @return this SameDiff instance, to allow method chaining
 */
public SameDiff disableDebugging() {
debugMode = false;
return this;
}
/**
 * Enables tracing of graphs automatically.
 *
 * @return this SameDiff instance, to allow method chaining
 */
public SameDiff enableDebugMode() {
debugMode = true;
return this;
}
/**
 * Returns this samediff instance's {@link DifferentialFunctionFactory}
 *
 * @return the factory used to create differential functions for this graph
 */
public DifferentialFunctionFactory f() {
return functionFactory;
}
/**
 * Deep-copies this graph's variables and functions onto the supplied SameDiff instance.
 *
 * @param sameDiff the target instance to copy this graph onto
 * @return the last variable of the target instance after copying
 */
public SDVariable invokeGraphOn(SameDiff sameDiff) {
//map the new vertices on to the old ones
Map thisVertexIdToNew = new HashMap<>();
int idx = 1;
for (val var : variables()) {
// Clone each variable (sharing instances per deepCloneDontCloneInstances) and
// register it, carrying over any associated array.
val clone = cloner.deepCloneDontCloneInstances(var, var.getSameDiff());
val newVar = sameDiff.var(clone);
if (var.getArr() != null) {
sameDiff.associateArrayWithVariable(var.getArr(), newVar);
}
// NOTE(review): maps idx to itself and is never read in this method - appears vestigial.
thisVertexIdToNew.put(idx, idx);
clone.setSameDiff(sameDiff);
idx++;
}
val newFunctions = new LinkedHashMap();
for (DifferentialFunction function : functionInstancesById.values()) {
if (function instanceof SDVariable) {
continue;
}
DifferentialFunction clone = cloner.deepCloneDontCloneInstances(
function,
function.getSameDiff());
clone.setSameDiff(sameDiff);
clone.setOwnName(function.getOwnName());
// NOTE(review): putFunctionForId throws when the id already exists, yet it is only
// invoked here when functionExists(...) is true - as written this branch can only
// throw. Also the ORIGINAL function (not the clone) is passed. Confirm intent
// against the upstream source.
if (sameDiff.functionExists(function.getOwnName()))
sameDiff.putFunctionForId(function.getOwnName(), function);
newFunctions.put(function.getOwnName(), clone);
val argsForFunction = function.args();
val outputsForFunction = function.outputVariables();
//note that these have the same variable names
sameDiff.addArgsFor(argsForFunction, clone);
// NOTE(review): args are registered against the clone but outputs against the
// original function - asymmetry worth confirming.
sameDiff.addOutgoingFor(outputsForFunction, function);
for (val arg : clone.args()) {
arg.setSameDiff(sameDiff);
}
for (val output : clone.outputVariables()) {
output.setSameDiff(sameDiff);
}
// NOTE(review): stores the original function, not the clone - confirm intent.
sameDiff.functionInstancesById.put(function.getOwnName(), function);
}
return sameDiff.variables().get(sameDiff.variables().size() - 1);
}
/**
 * Returns true if the given function id exists
 *
 * @param id the function id to test for ({@link DifferentialFunction#getOwnName()})
 * @return true if the function id exists, false otherwise
 */
public boolean functionExists(String id) {
return functionInstancesById.containsKey(id);
}
/**
 * Get the function by the {@link DifferentialFunction#getOwnName()}
 *
 * @param id the id of the function
 * @return the function for the given id if it exists
 * @throws ND4JIllegalStateException if no function is registered under {@code id}
 */
public DifferentialFunction getFunctionById(String id) {
    if (functionInstancesById.containsKey(id)) {
        return (DifferentialFunction) functionInstancesById.get(id);
    }
    throw new ND4JIllegalStateException("No function with id " + id + " found!");
}
/**
 * Put the function for the given id
 *
 * @param id the id of the function
 * @param function the function
 * @throws ND4JIllegalStateException if the id is already registered, or the
 * function is an {@link SDVariable} (variables are tracked separately)
 */
public void putFunctionForId(String id, DifferentialFunction function) {
    // Guard clauses: reject duplicates first (matching the original check order),
    // then reject variables masquerading as functions.
    if (functionInstancesById.containsKey(id)) {
        throw new ND4JIllegalStateException("Function by id already exists!");
    }
    if (function instanceof SDVariable) {
        throw new ND4JIllegalStateException("Function must not be a variable!");
    }
    functionInstancesById.put(id, function);
}
/**
 * Returns the name(s) of the inputs for the given function
 *
 * @param function the function to get the inputs for
 * @return the input ids for a given function
 * @throws ND4JIllegalStateException if the function's own name is not registered
 */
public String[] getInputsForFunction(DifferentialFunction function) {
    String ownName = function.getOwnName();
    if (!incomingArgsReverse.containsKey(ownName)) {
        throw new ND4JIllegalStateException("Illegal function instance id found " + ownName);
    }
    return (String[]) incomingArgsReverse.get(ownName);
}
/**
 * Returns the name(s) of the outputs for the given function
 *
 * @param function the function to get the outputs for
 * @return the outputs ids for a given function, or null if none are registered
 * (callers such as getOutputVariablesForFunction null-check this result)
 */
public String[] getOutputsForFunction(DifferentialFunction function) {
return outgoingArgsReverse.get(function.getOwnName());
}
/**
 * Get the output variable(s) for the specified differential function
 *
 * @param function the function reference to get the output variable(s) for
 * @return the output variables for the given function
 * @throws ND4JIllegalStateException if no outputs are recorded for the function
 */
public SDVariable[] getOutputVariablesForFunction(DifferentialFunction function) {
    String[] outputs = getOutputsForFunction(function);
    if (outputs == null) {
        // Fixed: previous message said "No inputs found" for missing *outputs*,
        // and the local was misleadingly named "inputs".
        throw new ND4JIllegalStateException("No outputs found for function " + function);
    }
    SDVariable[] vars = new SDVariable[outputs.length];
    for (int i = 0; i < outputs.length; i++) {
        vars[i] = getVariable(outputs[i]);
    }
    return vars;
}
/**
 * Get the input variable(s) for the specified differential function
 *
 * @param function the function reference to get the input variable(s) for
 * @return the input variables for the given function
 * @throws ND4JIllegalStateException if inputs are missing or any input name
 * resolves to a null variable
 */
public SDVariable[] getInputVariablesForFunction(DifferentialFunction function) {
    String[] inputNames = getInputsForFunction(function);
    if (inputNames == null) {
        throw new ND4JIllegalStateException("No inputs found for function " + function);
    }
    SDVariable[] resolved = new SDVariable[inputNames.length];
    int i = 0;
    for (String inputName : inputNames) {
        SDVariable variable = getVariable(inputName);
        if (variable == null) {
            throw new ND4JIllegalStateException("Found null variable at index " + i);
        }
        resolved[i++] = variable;
    }
    return resolved;
}
/**
 * Update the INDArray for the given variable. Note that the array must exist to use this method.
 *
 * @param varName Name of the variable to update the array for
 * @param arr Array to update
 * @throws ND4JIllegalStateException when the array does not exist.
 * @see #putArrayForVarName(String, INDArray)
 * @see #putOrUpdateShapeForVarName(String, long[], boolean)
 */
public void updateArrayForVarName(String varName, INDArray arr) {
    if (!variableNameToArr.containsKey(varName)) {
        // Fixed: message previously referred to a non-existent "putArrayForVertexId";
        // the actual method is putArrayForVarName.
        throw new ND4JIllegalStateException("Array for " + varName + " does not exist. Please use putArrayForVarName instead.");
    }
    variableNameToArr.put(varName, arr);
}
/**
 * Adds an INDArray for a given variable name.
 * Use {@link #updateArrayForVarName(String, INDArray)} if the array already exists.
 *
 * @param varName the vertex id to add
 * @param arr the array to add
 * @throws ND4JIllegalStateException when the name is null or the array already exists.
 * @see #putOrUpdateShapeForVarName(String, long[], boolean)
 */
public void putArrayForVarName(String varName, INDArray arr) {
    // Guard clauses, in the original check order.
    if (varName == null) {
        throw new ND4JIllegalStateException("No null names allowed!");
    }
    if (variableNameToArr.containsKey(varName)) {
        throw new ND4JIllegalStateException("Array for " + varName + " already exists!");
    }
    variableNameToArr.put(varName, arr);
}
/**
 * Put the array if it does not exist for the given variable name, or update it if it does
 *
 * @param varName Variable name
 * @param arr Array
 */
public void putOrUpdateArrayForVarName(@NonNull String varName, INDArray arr) {
    boolean alreadyPresent = variableNameToArr.containsKey(varName);
    if (alreadyPresent) {
        updateArrayForVarName(varName, arr);
    } else {
        putArrayForVarName(varName, arr);
    }
}
/**
 * Get the shape for the given vertex id.
 * Note that if an array is defined, it will use the shape of the array instead.
 *
 * A shape *and* an array should not be defined at the same time.
 * This wastes memory. The internal map used for tracking shapes for particular
 * vertex ids should also delete redundant shapes stored to avoid redundant sources of information.
 *
 * @param varName the vertex id to get the shape for
 * @return the shape for the given vertex if any.
 */
public long[] getShapeForVarName(String varName) {
    // An associated array is authoritative: its shape wins over the stored shape.
    if (variableNameToArr.containsKey(varName)) {
        INDArray arr = (INDArray) variableNameToArr.get(varName);
        return arr.shape();
    }
    return (long[]) variableNameToShape.get(varName);
}
/**
 * Update a vertex id with the given shape.
 * Note that you should use {@link #putShapeForVarName(String, long[])} if you want to add a new shape.
 * Update is meant to be an in place replacement of the shape for the vertex id *only*.
 *
 * Delegates to {@link #updateShapeForVarName(String, long[], boolean)} with
 * clearArrayOnShapeMismatch = false, i.e. a mismatch with an existing array throws.
 *
 * @param varName the vertex id to associate
 * @param shape the shape to associate with
 * @see #putShapeForVarName(String, long[])
 * @see #putOrUpdateShapeForVarName(String, long[], boolean)
 */
public void updateShapeForVarName(String varName, long[] shape) {
updateShapeForVarName(varName, shape, false);
}
/**
 * Update a vertex id with the given shape.
 * Note that you should use {@link #putShapeForVarName(String, long[])} if you want to add a new shape.
 * Update is meant to be an in place replacement of the shape for the vertex id *only*.
 *
 * @param varName the vertex id to associate
 * @param shape the shape to associate with
 * @param clearArrayOnShapeMismatch boolean to indicate whether to clear the variable on shape mismatch
 * @throws ND4JIllegalStateException if shape is null, or an existing array has a
 * different shape and clearArrayOnShapeMismatch is false
 * @see #putShapeForVarName(String, long[])
 * @see #putOrUpdateShapeForVarName(String, long[], boolean)
 */
public void updateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch) {
if (shape == null) {
throw new ND4JIllegalStateException("Null shapes not allowed!");
}
// If an array already exists with a different shape: either drop it (to be
// re-generated later) or fail fast, per clearArrayOnShapeMismatch.
if (variableNameToArr.containsKey(varName) && !Arrays.equals(variableNameToArr.get(varName).shape(), shape)) {
if(clearArrayOnShapeMismatch){
if(log.isTraceEnabled()){
log.trace("Clearing array for variable {}: array shape {}, new shape {}", varName,
Arrays.toString(variableNameToArr.get(varName).shape()), Arrays.toString(shape));
}
variableNameToArr.remove(varName);
} else {
throw new ND4JIllegalStateException("Already found an existing array for variable \"" + varName
+ "\" with shape " + Arrays.toString(variableNameToArr.get(varName).shape())
+ " - attempting to put new array shape " + Arrays.toString(shape));
}
}
// Any dimension < 1 marks an unresolved shape: register the variable as a
// placeholder and record the original shape instead of storing it as concrete.
for (int i = 0; i < shape.length; i++) {
if (shape[i] < 1) {
addAsPlaceHolder(varName);
placeHolderOriginalShapes.put(varName, shape);
return;
}
}
if(log.isTraceEnabled()){
long[] pShape = variableNameToShape.get(varName);
log.trace("Updated shape for variable \"{}\": previous shape {}, new shape {}", varName,
(pShape == null ? "" : Arrays.toString(pShape)), Arrays.toString(shape));
}
variableNameToShape.put(varName, shape);
}
/**
 * Associate a vertex id with the given shape.
 *
 * @param varName the vertex id to associate
 * @param shape the shape to associate with
 * @throws ND4JIllegalStateException if shape is null or a shape already exists for the name
 * @see #putShapeForVarName(String, long[])
 * @see #putOrUpdateShapeForVarName(String, long[], boolean)
 */
public void putShapeForVarName(String varName, long[] shape) {
    if (shape == null) {
        throw new ND4JIllegalStateException("Shape must not be null!");
    }
    if (variableNameToShape.containsKey(varName)) {
        throw new ND4JIllegalStateException("Shape for " + varName + " already exists!");
    }
    // Any dimension < 1 marks an unresolved shape: register the variable as a
    // placeholder and record the original shape instead of storing it as concrete.
    for (long dim : shape) {
        if (dim < 1) {
            addAsPlaceHolder(varName);
            placeHolderOriginalShapes.put(varName, shape);
            return;
        }
    }
    variableNameToShape.put(varName, shape);
}
/**
 * Put or update the shape for the given variable name. Optionally supports clearing the specified variable's
 * INDArray if it's shape does not match the new shape
 *
 * @param varName Variable name
 * @param shape Shape to put
 * @param clearArrayOnShapeMismatch If false: no change to arrays. If true: if an INDArray is defined for the specified
 *                                  variable name, it will be removed from the graph (to be later re-generated) if
 *                                  its shape does not match the specified shape
 */
public void putOrUpdateShapeForVarName(String varName, @NonNull long[] shape, boolean clearArrayOnShapeMismatch) {
    if (!variableNameToShape.containsKey(varName)) {
        putShapeForVarName(varName, shape);
    } else {
        updateShapeForVarName(varName, shape, clearArrayOnShapeMismatch);
    }
}
/**
 * Returns true if the given vertex id and shape already exist.
 *
 * @param varName the vertex id
 * @return true if a shape is stored for the variable, or an array (whose shape wins) exists for it
 */
public boolean shapeAlreadyExistsForVarName(String varName) {
return variableNameToShape.containsKey(varName) || arrayAlreadyExistsForVarName(varName);
}
/**
 * Returns true if the given vertex id and {@link INDArray} already exist.
 *
 * @param varName the vertex id
 * @return true if an INDArray is associated with the given variable name
 */
public boolean arrayAlreadyExistsForVarName(String varName) {
return variableNameToArr.containsKey(varName);
}
/**
 * Get an {@link INDArray} for a given vertex id, or null if none exists
 *
 * @param varName Variable name to get the array for
 * @return Array, or null if none exists
 */
public INDArray getArrForVarName(String varName) {
return variableNameToArr.get(varName);
}
/**
 * Associate the array with the given variable.
 * Convenience overload that resolves the variable by name first.
 *
 * @param arr the array to get the variable for
 * @param variable the name of the variable to associate the array with
 */
public void associateArrayWithVariable(INDArray arr, @NonNull String variable) {
associateArrayWithVariable(arr, this.getVariable(variable));
}
/**
 * Associate the array with the given variable.
 * Also updates the stored shape (clearing a mismatched existing array), invalidates
 * the execution cache, and propagates the array into any nested SameDiff instances
 * (such as the gradient function) that track the same variable name.
 *
 * @param arr the array to get the variable for
 * @param variable the variable to associate the array with
 * @throws ND4JIllegalArgumentException if either argument is null
 */
public void associateArrayWithVariable(INDArray arr, SDVariable variable) {
if (variable == null) {
throw new ND4JIllegalArgumentException("Variable must not be null!");
}
if (arr == null) {
throw new ND4JIllegalArgumentException("Array must not be null");
}
variableNameToArr.put(variable.getVarName(), arr);
putOrUpdateShapeForVarName(variable.getVarName(), arr.shape(), true);
// invalidate exec cache
// NOTE(review): exec_cache is not declared in this chunk - presumably a field
// declared elsewhere in this class; confirm.
exec_cache = null;
//Also update nested SameDiff instances (such as gradient function)
if(sameDiffFunctionInstances != null && sameDiffFunctionInstances.size() > 0){
for(Map.Entry e : sameDiffFunctionInstances.entrySet()){
SameDiff sd = e.getValue();
if(sd.variableNameToArr != null && sd.variableNameToArr.containsKey(variable.getVarName())){
sd.associateArrayWithVariable(arr, variable);
}
}
}
}
/**
 * Associate a {@link SameDiff} namespace as a sub function.
 * Re-registering the same instance under the same name is a no-op; registering a
 * different instance under an existing name is rejected.
 *
 * @param name the opName of the function
 * @param nameSpace the namespace
 * @throws ND4JIllegalStateException if a different namespace is already registered under the name
 */
public void putSubFunction(String name, SameDiff nameSpace) {
    boolean conflicting = sameDiffFunctionInstances.containsKey(name)
            && sameDiffFunctionInstances.get(name) != nameSpace;
    if (conflicting) {
        throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
    }
    sameDiffFunctionInstances.put(name, nameSpace);
}
/**
 * Return the internal variable map
 *
 * Note: this returns the live internal map, not a defensive copy - mutations by
 * the caller will affect this SameDiff instance.
 *
 * @return Map of variables by name
 */
public Map variableMap() {
return variableMap;
}
/**
 * Invoke an op by opName, via the cached SDVariable-returning methods of this class
 * (see the static initializer populating opMethods).
 *
 * @param op the op
 * @param x the first input
 * @param y the second input (may be null for unary ops)
 * @return the result variable
 * @throws ND4JIllegalStateException if no method matches the op name, or the
 * reflective invocation fails (the underlying cause is attached)
 */
public SDVariable invoke(Op op, SDVariable x, SDVariable y) {
    Method target = (Method) opMethods.get(op.opName());
    if (target == null) {
        throw new ND4JIllegalStateException("Illegal method opName " + op.opName());
    }
    try {
        if (x != null && y != null) {
            return (SDVariable) target.invoke(this, x, y);
        }
        return (SDVariable) target.invoke(this, x);
    } catch (Exception e) {
        // Previously the reflection failure was silently swallowed and a bare
        // "Illegal method opName" was thrown; keep the exception type and message
        // but preserve the underlying cause for diagnosis.
        ND4JIllegalStateException wrapped = new ND4JIllegalStateException("Illegal method opName " + op.opName());
        wrapped.initCause(e);
        throw wrapped;
    }
}
/**
 * The set of defined SameDiff function names. SameDiff function instances should not be confused
 * with DifferentialFunction ops; an example of a SameDiff function instance is the gradient "grad" function
 *
 * @return Set of defined SameDiff function instance names
 */
public Collection definedFunctionNames() {
return this.sameDiffFunctionInstances.keySet();
}
/**
 * Returns the number of bytes for the graph. Calculated as sum_i prod(shapeOf(variable[i]))
 * multiplied by the byte width of the current global data type.
 *
 * @return Bytes for all of the arrays in the graph for the current variable shapes
 */
public long memoryForGraph() {
// numElements() is defined elsewhere in this class (not visible in this chunk)
return numElements() * DataTypeUtil.lengthForDtype(Nd4j.dataType());
}
/**
 * Invoke an op by opName
 * Convenience overload for unary ops: delegates with y = null.
 *
 * @param op the op
 * @param x the first input
 * @return the result variable
 */
public SDVariable invoke(Op op, SDVariable x) {
return invoke(op, x, null);
}
// Private constructor: instances are obtained through the builder/static factories.
// Initializes the function factory bound to this instance and every internal
// collection (insertion-ordered Linked* implementations throughout; opsForResult
// uses an int[]-keyed map).
private SameDiff() {
functionFactory = new DifferentialFunctionFactory(this);
variableMap = new LinkedHashMap<>();
sameDiffFunctionDefinitionMap = new LinkedHashMap<>();
sameDiffFunctionInstances = new LinkedHashMap<>();
gradients = new LinkedHashMap<>();
forwardVarForGrad = new LinkedHashMap<>();
opsForResult = new IntArrayKeyMap<>();
variableNameToArr = new LinkedHashMap<>();
variableNameToShape = new LinkedHashMap<>();
placeHolderMap = new LinkedHashMap<>();
placeHolderVarNames = new LinkedHashSet<>();
placeHolderOriginalShapes = new LinkedHashMap<>();
incomingArgsReverse = new LinkedHashMap<>();
outgoingArgsReverse = new LinkedHashMap<>();
functionInstancesById = new LinkedHashMap<>();
placeHolderFunctions = new LinkedHashSet<>();
functionsArgsFor = new LinkedHashMap<>();
functionOutputFor = new LinkedHashMap<>();
baseNameForFunctionInstanceId = new LinkedHashMap<>();
importedVarName = new LinkedHashSet<>();
permuteOrder = new LinkedHashMap<>();
propertiesToResolve = new LinkedHashMap<>();
propertiesForFunction = new LinkedHashMap<>();
fieldVariableResolutionMapping = HashBasedTable.create();
}
/**
 * Adds a property that needs to be resolve for later.
 * These variables are typically values that are arrays
 * that are named but have an unknown value till execution time.
 *
 * This is very common for model import.
 *
 * @param forFunction the function to add the property to resolve for
 * @param arrayName the array name
 */
public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName) {
    // Idiom: replace the manual containsKey/get-or-create branching with computeIfAbsent.
    List props = (List) propertiesToResolve.computeIfAbsent(forFunction.getOwnName(), k -> new ArrayList<>());
    props.add(arrayName);
}
/**
 * Return the properties to resolve for the given function.
 * This is typically used right before execution in model import in
 * {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()}
 *
 * @param function the function get the properties to resolve for
 * @return the properties to resolve for the given function, or an immutable empty
 * list if the function has none
 */
public List propertiesToResolveForFunction(DifferentialFunction function) {
if (!propertiesToResolve.containsKey(function.getOwnName()))
return Collections.emptyList();
return propertiesToResolve.get(function.getOwnName());
}
/**
 * Returns true if the given function has ndarray properties to resolve.
 *
 * @param function the function to check
 * @return true if the function has yet to be resolved properties
 */
public boolean hasPropertiesToResolve(DifferentialFunction function) {
return propertiesToResolve.containsKey(function.getOwnName());
}
/**
 * Get the property for a given function
 *
 * @param functionInstance the function to get the
 * property for
 * @param propertyName the name of the property to get
 * @param <T> the inferred return type
 * @return the property for the given function, or null if the function has no
 * recorded properties (or no value for the property name)
 */
@SuppressWarnings("unchecked")
public <T> T getPropertyForFunction(DifferentialFunction functionInstance, String propertyName) {
    // Fixed: the type variable T was used but never declared on the method
    // (the <T> declaration appears to have been lost in extraction); restored.
    if (!propertiesForFunction.containsKey(functionInstance.getOwnName())) {
        return null;
    }
    Map map = (Map) propertiesForFunction.get(functionInstance.getOwnName());
    return (T) map.get(propertyName);
}
/**
 * Add a property for the given function
 * (INDArray-valued convenience overload; see the private Object overload for semantics)
 *
 * @param functionFor the function add a property for
 * @param propertyName the property name
 * @param property the property value
 */
public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, INDArray property) {
addPropertyForFunction(functionFor, propertyName, (Object) property);
}
/**
 * Add a property for the given function
 * (long-valued convenience overload; the primitive is boxed before storage)
 *
 * @param functionFor the function to add the property for
 * @param propertyName the name of the property to add the value for
 * @param property the property value to add
 */
public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, long property) {
addPropertyForFunction(functionFor, propertyName, (Object) property);
}
// Shared implementation for the public addPropertyForFunction overloads.
// Properties are write-once per (function, name); attempting to overwrite throws.
private void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, Object propertyValue) {
    val ownName = functionFor.getOwnName();
    Map fieldMap = propertiesForFunction.get(ownName);
    if (fieldMap == null) {
        // First property for this function: LinkedHashMap preserves insertion order
        fieldMap = new LinkedHashMap<>();
        propertiesForFunction.put(ownName, fieldMap);
    } else if (fieldMap.containsKey(propertyName)) {
        throw new ND4JIllegalStateException("Attempting to override property " + propertyName);
    }
    fieldMap.put(propertyName, propertyValue);
}
/**
 * Adds a field name -> variable name mapping for a given function.
 * This is used for model import where there is an unresolved variable at the time of calling any
 * {@link org.nd4j.imports.graphmapper.GraphMapper#importGraph(File)}
 * .
 *
 * This data structure is typically accessed during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()}
 *
 * When a function attempts to resolve variables right before execution, there needs to be a way of knowing
 * which variable in a samediff graph should map to a function's particular field name
 *
 * @param function  the function to map
 * @param fieldName the field name for the function to map
 * @param varName   the variable name of the array to get from samediff
 */
public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName) {
    fieldVariableResolutionMapping.put(function.getOwnName(), fieldName, varName);
}
/**
 * Get the variable name to use
 * for resolving a given field
 * for a given function during import time.
 * This method is used during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()}
 *
 * @param function  the function to get the variable name for
 * @param fieldName the field name to resolve for
 * @return the resolved variable name if any (null when no mapping was registered)
 */
public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName) {
    return fieldVariableResolutionMapping.get(function.getOwnName(), fieldName);
}
/**
 * Returns true if the variable name is imported
 *
 * @param variableName the imported variable name
 * @return true if the name is imported, false otherwise
 */
public boolean isImportVariable(String variableName) {
    return importedVarName.contains(variableName);
}
/**
 * Marks a variable name as imported.
 * This is used in conjunction with model
 * import to ensure immutability
 * when referencing graph variables
 * mapped from an external source.
 *
 * @param varName the var name to add.
 */
public void addVarNameForImport(String varName) {
    importedVarName.add(varName);
}
/**
 * Sets a base name for the function id.
 * This is used for when calling {@link #generateOutputVariableForOp(DifferentialFunction, String)}
 * for ensuring original names for model import map to current samediff names
 * when names are generated.
 *
 * @param baseName the base name to add
 * @param function the function to declare a base name for.
 */
public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function) {
    baseNameForFunctionInstanceId.put(function.getOwnName(), baseName);
}
/**
 * Returns the base name for the given function
 * if any (may return null)
 *
 * @param function the function to get the base name for
 * @return the base name for the given function (if any) based
 * on the function's instance id.
 */
public String getBaseNameForFunction(DifferentialFunction function) {
    return baseNameForFunctionInstanceId.get(function.getOwnName());
}
/**
 * Attempts to insert the {@link DifferentialFunction} reference in to this {@link SameDiff} instance.
 * If the given array field with the given index already exists, it will do a reference check to ensure that the 2
 * array fields are the same. If not, an exception is thrown.
 * If the instances are the same (by semantics, not reference) then it will just return the original instance.
 * This is to ensure that instances that are created are unique and reference checked.
 *
 * @param function the array field to attempt to create
 * @return Original instance
 */
public X setupFunction(X function) {
    Preconditions.checkNotNull(function, "Passed in function must not be null!");
    // SDVariables must be bound to this SameDiff instance; all other function types pass through untouched
    if (function instanceof SDVariable && function.getSameDiff() != this) {
        function.setSameDiff(this);
    }
    return function;
}
/**
 * Adds outgoing arguments to the graph for the specified DifferentialFunction
 * Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared.
 *
 * @param variables Variables - arguments for the specified differential function
 * @param function  Differential function
 */
public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) {
    // Collect the variable names, then delegate to the String[] overload
    final String[] names = new String[variables.length];
    int idx = 0;
    for (SDVariable variable : variables) {
        names[idx++] = variable.getVarName();
    }
    addOutgoingFor(names, function);
}
/**
 * Adds outgoing arguments to the graph for the specified DifferentialFunction
 * Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared.
 *
 * @param varNames Name of the variables that are outputs of the specified differential function
 * @param function Differential function
 */
public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
    if (function.getOwnName() == null)
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    if (outgoingArgsReverse.containsKey(function.getOwnName())) {
        throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
    }
    if (varNames == null)
        throw new ND4JIllegalStateException("Var names can not be null!");
    for (String varName : varNames) {
        if (varName == null)
            throw new ND4JIllegalStateException("Variable name elements can not be null!");
    }
    outgoingArgsReverse.put(function.getOwnName(), varNames);
    // Maintain the reverse index: variable name -> list of functions that produce it
    for (String resultName : varNames) {
        List producers = functionOutputFor.get(resultName);
        if (producers == null) {
            producers = new ArrayList<>();
            functionOutputFor.put(resultName, producers);
        }
        producers.add(function);
    }
}
/**
 * Adds incoming arguments for the specified differential function to the graph
 *
 * @param variables Name of the variables that are arguments (inputs) to the specified function
 * @param function  Function
 */
public void addArgsFor(String[] variables, DifferentialFunction function) {
    if (function.getOwnName() == null)
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    // Functions with placeholder inputs need placeholder resolution before execution
    for (String varName : variables) {
        if (isPlaceHolder(varName)) {
            placeHolderFunctions.add(function.getOwnName());
        }
    }
    incomingArgsReverse.put(function.getOwnName(), variables);
    // Maintain the reverse index: variable name -> list of functions consuming it
    for (String variableName : variables) {
        List consumers = functionsArgsFor.get(variableName);
        if (consumers == null) {
            consumers = new ArrayList<>();
            functionsArgsFor.put(variableName, consumers);
        }
        consumers.add(function);
    }
}
/**
 * Adds incoming arguments for the specified differential function to the graph
 *
 * @param variables variables that are arguments (inputs) to the specified function
 * @param function  Function
 */
public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
    // Collect the variable names (rejecting nulls), then delegate to the String[] overload
    String[] names = new String[variables.length];
    for (int i = 0; i < names.length; i++) {
        if (variables[i] == null)
            throw new ND4JIllegalStateException("Found null variable at index " + i);
        names[i] = variables[i].getVarName();
    }
    addArgsFor(names, function);
}
/**
 * Get the differential function (if any) that this variable is the output for
 *
 * @param variableName Name of the variable
 * @return The differential function that this variable is an output of, or null if it is not the output of a function
 */
public DifferentialFunction getVariableOutputFunction(String variableName) {
    List producers = functionOutputFor.get(variableName);
    // A variable has at most one producing function; return the first entry if present
    return producers == null ? null : (DifferentialFunction) producers.get(0);
}
/**
 * Return a list of differential functions (if any) that this variable is the input argument for
 *
 * @param variableName Name of the variable
 * @return The differential functions that this variable is an input argument for, or null if it is not the input to any function
 */
public List getVariableArgOfFunctions(String variableName) {
    return functionsArgsFor.get(variableName);
}
/**
 * Returns true if this function already has defined arguments
 *
 * @param function the function to check
 * @return true if the function has args, false otherwise
 */
public boolean hasArgs(DifferentialFunction function) {
    String[] args = incomingArgsReverse.get(function.getOwnName());
    // Both a missing entry and an empty argument array count as "no args"
    return args != null && args.length > 0;
}
/**
 * Get an array of differential functions that have been defined for this SameDiff instance
 * @return Array of differential functions
 */
public DifferentialFunction[] functions() {
    val allFunctions = functionInstancesById.values();
    return allFunctions.toArray(new DifferentialFunction[allFunctions.size()]);
}
@Override
public int hashCode() {
    // Combine the superclass hash with the variable map's hash (31 is the conventional multiplier)
    int hash = super.hashCode();
    hash = 31 * hash + (variableMap == null ? 0 : variableMap.hashCode());
    return hash;
}
/**
 * Create a new SameDiff instance from an existing instance.
 * Note that state (variables and functions) is shared between the two SameDiff instance
 *
 * @param originalSameDiff Original SameDiff instance
 * @return Copy
 */
public static SameDiff create(SameDiff originalSameDiff) {
    SameDiff ret = SameDiff.builder()
            .variableMap(originalSameDiff.variableMap)
            .sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances)
            .build();
    //ensuring proper sameDiff reference: the function factory must point at the new instance, not the original
    DifferentialFunctionFactory differentialFunctionFactory = new DifferentialFunctionFactory(ret);
    ret.functionFactory = differentialFunctionFactory;
    return ret;
}
@Override
public boolean equals(Object o) {
    if (this == o)
        return true;
    if (o == null || getClass() != o.getClass())
        return false;
    SameDiff that = (SameDiff) o;
    // Equality is based on the variable map, the function definition map, and the function instances
    if (variableMap == null ? that.variableMap != null : !variableMap.equals(that.variableMap))
        return false;
    if (sameDiffFunctionDefinitionMap == null ? that.sameDiffFunctionDefinitionMap != null
            : !sameDiffFunctionDefinitionMap.equals(that.sameDiffFunctionDefinitionMap))
        return false;
    if (sameDiffFunctionInstances == null)
        return that.sameDiffFunctionInstances == null;
    return sameDiffFunctionInstances.equals(that.sameDiffFunctionInstances);
}
/**
 * Create a new (empty) SameDiff instance without any functions or variables
 * @return New SameDiff instance
 */
public static SameDiff create() {
    return new SameDiff();
}
/**
 * Evaluate the given inputs based on the current graph
 *
 * @param inputs the inputs to evaluate
 * @return one output array per executed op, in execution order (the first output variable of each op)
 */
public INDArray[] eval(Map inputs) {
    // Execute on a duplicate so this instance's state is not mutated by execution
    SameDiff execPipeline = dup();
    List opExecAction = execPipeline.exec().getRight();
    if (opExecAction.isEmpty())
        throw new IllegalStateException("No ops found to execute.");
    INDArray[] ret = new INDArray[opExecAction.size()];
    for (int i = 0; i < ret.length; i++) {
        // NOTE(review): only the first output variable of each op is collected here
        val varName = opExecAction.get(i).outputVariables()[0].getVarName();
        ret[i] = execPipeline.getArrForVarName(varName);
    }
    return ret;
}
/**
 * Clone/duplicate the SameDiff instance, including arrays etc. The returned SameDiff instance should have no
 * shared state with the original instance
 * @return The cloned SameDiff instance
 */
public SameDiff dup() {
    // Deep clone so variables, functions and arrays are all copied rather than shared
    return newCloner().deepClone(this);
}
/**
 * Count the number of elements in all arrays, according to {@link SDVariable#getShape()}
 * @return Number of array elements for all variables
 */
public long numElements() {
    long count = 0;
    for (SDVariable v : variables()) {
        // Product of the shape dimensions = element count for this variable
        count += ArrayUtil.prod(v.getShape());
    }
    return count;
}
// Lazily creates the memory workspace used for variable arrays and binds it to the current thread.
// Called by the var(...) methods the first time an array-backed variable is created.
private void initWorkspace() {
    workspace = Nd4j.getWorkspaceManager().createNewWorkspace(
            WorkspaceConfiguration.builder()
                    .initialSize(memoryForGraph())
                    .policyAllocation(AllocationPolicy.OVERALLOCATE)
                    .policyLearning(LearningPolicy.FIRST_LOOP)
                    .build());
    Nd4j.getWorkspaceManager().setWorkspaceForCurrentThread(workspace);
}
/**
 * The list of all variables in the graph
 *
 * @return All variables in the graph (a defensive copy; mutating the list does not affect the graph)
 */
public List variables() {
    return new ArrayList<>(variableMap.values());
}
/**
 * Create a new variable with the specified shape, with all values initialized to 1.0
 *
 * @param name  the name of the variable to create
 * @param shape the shape of the array to be created
 * @return the created variable
 */
public SDVariable one(String name, int[] shape) {
    return var(name, ArrayUtil.toLongArray(shape), new ConstantInitScheme('f', 1.0));
}
/**
 * Create a new variable with the specified shape, with all values initialized to 1.0
 *
 * @param name  the name of the variable to create
 * @param shape the shape of the array to be created
 * @return the created variable
 */
public SDVariable one(String name, long[] shape) {
    return var(name, shape, new ConstantInitScheme('f', 1.0));
}
/**
 * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
 * if the input shape changes in later execution, the returned variable's shape will also be updated
 *
 * @param input Input SDVariable
 * @return A new SDVariable with the same (dynamic) shape as the input
 */
public SDVariable onesLike(SDVariable input) {
    return onesLike(null, input);
}
/**
 * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
 * if the input shape changes in later execution, the returned variable's shape will also be updated
 *
 * @param name  Name of the new SDVariable
 * @param input Input SDVariable
 * @return A new SDVariable with the same (dynamic) shape as the input
 */
public SDVariable onesLike(String name, SDVariable input) {
    SDVariable ret = f().onesLike(name, input);
    return updateVariableNameAndReference(ret, name);
}
/**
 * Create a new variable with the specified shape, with all values initialized to 0
 *
 * @param name  the name of the variable to create
 * @param shape the shape of the array to be created
 * @return the created variable
 */
public SDVariable zero(String name, long[] shape) {
    return var(name, shape, new ZeroInitScheme());
}
/**
 * Create a new variable with the specified shape, with all values initialized to 0
 *
 * @param name  the name of the variable to create
 * @param shape the shape of the array to be created
 * @return the created variable
 */
public SDVariable zero(String name, int[] shape) {
    return var(name, ArrayUtil.toLongArray(shape), new ZeroInitScheme());
}
/**
 * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
 * if the input shape changes in later execution, the returned variable's shape will also be updated
 *
 * @param input Input SDVariable
 * @return A new SDVariable with the same (dynamic) shape as the input
 */
public SDVariable zerosLike(SDVariable input) {
    return zerosLike(null, input);
}
/**
 * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
 * if the input shape changes in later execution, the returned variable's shape will also be updated
 *
 * @param name  Name of the new SDVariable
 * @param input Input SDVariable
 * @return A new SDVariable with the same (dynamic) shape as the input
 */
public SDVariable zerosLike(String name, SDVariable input) {
    SDVariable ret = f().zerosLike(name, input);
    return updateVariableNameAndReference(ret, name);
}
/**
 * Return a variable of given shape in which all values have a given constant value.
 *
 * @param value constant to set for each value
 * @param shape shape of the variable as long array
 * @return A new SDVariable of provided shape with constant value.
 */
public SDVariable constant(SDVariable value, long... shape) {
    return constant(null, value, shape);
}
/**
 * Return a variable of given shape in which all values have a given constant value.
 *
 * @param name  Name of the new SDVariable
 * @param value constant to set for each value
 * @param shape shape of the variable as long array
 * @return A new SDVariable of provided shape with constant value.
 */
public SDVariable constant(String name, SDVariable value, long... shape) {
    SDVariable ret = f().constant(value, shape);
    return updateVariableNameAndReference(ret, name);
}
/**
 * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
 * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
 *
 * @param start  Start value
 * @param stop   Stop value
 * @param number Number of values to generate
 * @return SDVariable with linearly spaced elements
 */
public SDVariable linspace(double start, double stop, long number) {
    return linspace(null, start, stop, number);
}
/**
 * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
 * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
 *
 * @param name   Name of the new variable
 * @param start  Start value
 * @param stop   Stop value
 * @param number Number of values to generate
 * @return SDVariable with linearly spaced elements
 */
public SDVariable linspace(String name, double start, double stop, long number) {
    SDVariable ret = f().linspace(start, stop, number);
    return updateVariableNameAndReference(ret, name);
}
/**
 * Create a new variable with a 1d array, where the values start at {@code from} and increment by {@code step}
 * up to (but not including) limit.
 * For example, {@code range(1.0, 3.0, 0.5)} will return {@code [1.0, 1.5, 2.0, 2.5]}
 * @param from Initial/smallest value
 * @param to   Largest value (exclusive)
 * @param step Step size
 * @return 1D SDVariable with the specified values
 */
public SDVariable range(double from, double to, double step){
    return range(null, from, to, step);
}
/**
 * Create a new variable with a 1d array, where the values start at {@code from} and increment by {@code step}
 * up to (but not including) limit.
 * For example, {@code range(1.0, 3.0, 0.5)} will return {@code [1.0, 1.5, 2.0, 2.5]}
 * @param name Name of the new variable
 * @param from Initial/smallest value
 * @param to   Largest value (exclusive)
 * @param step Step size
 * @return 1D SDVariable with the specified values
 */
public SDVariable range(String name, double from, double to, double step){
    SDVariable ret = f().range(from, to, step);
    return updateVariableNameAndReference(ret, name);
}
/**
 * @see #meshgrid(List, SDVariable...)
 */
public SDVariable[] meshgrid(SDVariable... inputs){
    return meshgrid(null, inputs);
}
/**
 * Broadcast the 1D input variables onto an n-dimensional grid.
 * The resulting variable can be used for example for evaluating functions at all locations on a grid.
 * Example:
 *
 * {@code input1 = [1, 2, 3]
 * input2 = [4, 5, 6]
 * SDVariable[] out = meshgrid(input1, input2)
 * out[0]:
 * [ 1, 2, 3]
 * [ 1, 2, 3]
 * [ 1, 2, 3]
 *
 * out[1]:
 * [ 4, 4, 4]
 * [ 5, 5, 5]
 * [ 6, 6, 6]}
 *
 *
 * @param names  List of names for the output variables. Must have exactly N names for N input arrays
 * @param inputs N x 1D input variables
 * @return an array of exactly N SDVariables (for N inputs), of rank N
 */
public SDVariable[] meshgrid(List names, SDVariable... inputs){
    // Delegates with cartesian indexing enabled (the conventional meshgrid behavior)
    return meshgrid(names, true, inputs);
}
/**
* @see #meshgrid(List, SDVariable...)
*/
public SDVariable[] meshgrid(List names, boolean cartesian, SDVariable... inputs){
Preconditions.checkState(names == null || names.size() == inputs.length,
"Got %s names but %s inputs", (names == null ? 0 : names.size()), inputs.length);
SDVariable[] ret = f().meshgrid(cartesian, inputs);
for( int i=0; i
* Any array will be generated with all zeros for the values
*
* @param name the name of the variable
* @param shape the shape of the variable
* @return the created variable
*/
public SDVariable var(String name, long... shape) {
    // Fix: the previous code called checkNotNull(shape != null, ...), which boxes the comparison
    // result into a non-null Boolean and therefore NEVER fails - check the array itself instead
    Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null");
    return var(name, shape, new ZeroInitScheme());
}
/**
 * Creates a {@link SDVariable} with the given shape and name
 * Any array will be generated with all zeros for the values
 *
 * @param name  the name of the variable
 * @param shape the shape of the variable
 * @return the created variable
 */
public SDVariable var(String name, int... shape) {
    // Fix: the previous code called checkNotNull(shape != null, ...), which boxes the comparison
    // result into a non-null Boolean and therefore NEVER fails - check the array itself instead
    Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null");
    return var(name, ArrayUtil.toLongArray(shape), new ZeroInitScheme());
}
/**
 * Initialize a {@link SDVariable} reference tying this variable to this samediff instance.
 *
 * {@link NDArraySupplierInitScheme} is used to ensure that if the array is allocated anywhere
 * and {@link SameDiff} instance to exist as a copy of the variable.
 *
 * @param arr the variable to register with this instance; must be non-null and have a non-empty name
 * @return the already-registered variable if one with the same name (and an array) exists,
 *         otherwise a new SDVariable bound to this instance
 */
public SDVariable var(final SDVariable arr) {
    // Fix: the null check must run before any dereference of arr. Previously it was performed
    // after arr.getVarName() had already been called (and its message itself dereferenced arr),
    // so it could never fire - a null argument produced an NPE instead.
    if (arr == null)
        throw new IllegalArgumentException("Variable to create must not be null");
    if (variableMap.containsKey(arr.getVarName()) && variableMap.get(arr.getVarName()).getArr() != null)
        return variableMap.get(arr.getVarName());
    if (arr.getVarName() == null || arr.getVarName().length() < 1)
        throw new IllegalArgumentException("Name for variable must be defined");
    if (workspace == null)
        initWorkspace();
    final SDVariable ret = SDVariable.builder()
            .sameDiff(this)
            .shape(arr.getShape())
            .varName(arr.getVarName())
            .weightInitScheme(new NDArraySupplierInitScheme(new NDArraySupplierInitScheme.NDArraySupplier() {
                @Override
                /**
                 * Pre allocate the array if it doesn't already exist.
                 * The reason we do this is to avoid race conditions with
                 * {@link #allocate()}
                 */
                public INDArray getArr() {
                    if (arr.getArr() == null) {
                        INDArray retArr = arr.getWeightInitScheme().create(arr.getShape());
                        associateArrayWithVariable(retArr, arr);
                    }
                    return arr.getArr();
                }
            }))
            .build();
    variableMap.put(arr.getVarName(), ret);
    return ret;
}
// Generates a fresh, unused variable name of the form "sd_var_<n>".
// variableId is advanced past any names that already exist in the variable map.
private String getNewVarName() {
    String candidate = "sd_var_" + variableId;
    while (variableMap.containsKey(candidate)) {
        variableId++;
        candidate = "sd_var_" + variableId;
    }
    return candidate;
}
/**
 * Creates a {@link SDVariable} with the specified shape and a generated name
 * Any array will be generated with all zeros for the values
 *
 * @param shape the shape of the variable
 * @return the created variable
 */
public SDVariable var(int... shape) {
    return var(getNewVarName(), shape);
}
/**
 * Creates a {@link SDVariable} with the specified shape and a generated name
 * Any array will be generated with all zeros for the values
 *
 * @param shape the shape of the variable
 * @return the created variable
 */
public SDVariable var(long... shape) {
    return var(getNewVarName(), shape);
}
/**
 * Creates a {@link SDVariable} with the specified shape and a generated name. The associated array will
 * then be generated using the specified weight initialization scheme
 *
 * @param weightInitScheme The weight initialization scheme to use when generating an INDArray
 * @param shape            the shape of the variable
 * @return the created variable
 */
public SDVariable var(WeightInitScheme weightInitScheme, long... shape) {
    return var(getNewVarName(), shape, weightInitScheme);
}
/**
 * Create an {@link SDVariable} with a generated name, and associate the specified array with it
 * @param arr Array to associate with the new variable
 * @return New SDVariable
 * @see #var(String, INDArray)
 */
public SDVariable var(INDArray arr) {
    return var(getNewVarName(), arr);
}
/**
 * Create an {@link SDVariable} with the specified name, and associate the specified array with it
 * @param name Name of the new variable; if null or empty a name is generated
 * @param arr  Array to associate with the new variable (must not be null)
 * @return New SDVariable with the specified name and array
 */
public SDVariable var(String name, INDArray arr) {
    // Fix: generate a default name for null/empty input *before* the duplicate-name check;
    // previously the duplicate check ran first, against a possibly-null name
    if (name == null || name.length() < 1)
        name = getNewVarName();
    if (variableMap.containsKey(name) && variableMap.get(name).getArr() != null)
        throw new IllegalArgumentException("Another variable with the name " + name +
                " already exists.");
    if (arr == null)
        throw new IllegalArgumentException("Array for " + name + " must not be null");
    if (workspace == null)
        initWorkspace();
    // Migrate the array into this instance's workspace; the supplier below always returns this reference
    val arrRef = arr.migrate();
    SDVariable ret = SDVariable.builder()
            .sameDiff(this)
            .shape(arr.shape())
            .varName(name)
            .weightInitScheme(new NDArraySupplierInitScheme(new NDArraySupplierInitScheme.NDArraySupplier() {
                @Override
                /**
                 * Return array
                 */
                public INDArray getArr() {
                    return arrRef;
                }
            }))
            .build();
    associateArrayWithVariable(arr, ret);
    if (ArrayUtil.prod(arr.shape()) == 1)
        ret.setScalarValue(arr.getDouble(0));
    addVariable(ret);
    if (getShapeForVarName(name) == null)
        putShapeForVarName(name, arr.shape());
    //ensure there is a reference to the array in the integer index
    //this is used later for op creation
    variableMap.put(name, ret);
    return ret;
}
/**
 * Generate a square identity matrix with the specified number of rows.
 *
 * @param rows Number of rows (and columns)
 * @return SDVariable with an identity matrix array
 */
public SDVariable eye(int rows) {
    return eye(rows, rows);
}
/**
 * Generate an identity matrix with the specified number of rows and columns.
 *
 * @param name Name of the new SDVariable
 * @param rows Number of rows (and columns)
 * @return SDVariable with an identity matrix array
 */
public SDVariable eye(String name, int rows) {
    return eye(name, rows, rows);
}
/**
 * @see #eye(String, int, int)
 */
public SDVariable eye(int rows, int cols) {
    return eye(null, rows, cols);
}
/**
 * Generate an identity matrix with the specified number of rows and columns
 * Example:
 *
 * {@code SDVariable eye = eye(3,2)
 * eye:
 * [ 1, 0]
 * [ 0, 1]
 * [ 0, 0]}
 *
 *
 * @param name Name of the new SDVariable
 * @param rows Number of rows
 * @param cols Number of columns
 * @return SDVariable identity matrix
 */
public SDVariable eye(String name, int rows, int cols) {
    return eye(name, rows, cols, null);
}
/**
 * see {@link #eye(String, int, int, int...)}
 */
public SDVariable eye(int rows, int cols, int... batchDimension) {
    return eye(null, rows, cols, batchDimension);
}
/**
 * Generate an identity matrix with the specified number of rows and columns, with optional leading dims
 * Example:
 * batchShape: [3,3]
 * numRows: 2
 * numCols: 4
 * returns a tensor of shape (3, 3, 2, 4) that consists of 3 * 3 batches of (2,4)-shaped identity matrices:
 * 1 0 0 0
 * 0 1 0 0
 *
 * @param name           Name of the new SDVariable
 * @param rows           Number of rows
 * @param cols           Number of columns
 * @param batchDimension Batch dimensions. May be null
 * @return SDVariable identity matrix (optionally batched)
 */
public SDVariable eye(String name, int rows, int cols, int... batchDimension) {
    SDVariable eye = new Eye(this, rows, cols, batchDimension).outputVariables()[0];
    return updateVariableNameAndReference(eye, name);
}
/**
 * As per {@link #eye(String, int, int, int...)} but with the number of rows/columns specified as scalar SDVariables,
 * and the batch dimension specified as a 1D SDVariable
 */
public SDVariable eye(String name, SDVariable rows, SDVariable cols, SDVariable batchDimension){
    SDVariable eye = new Eye(this, rows, cols, batchDimension).outputVariable();
    return updateVariableNameAndReference(eye, name);
}
/**
 * As per {@link #eye(int, int, int...)} but with the number of rows/columns specified as scalar SDVariables,
 * and the batch dimension specified as a 1D SDVariable
 */
public SDVariable eye(SDVariable rows, SDVariable cols, SDVariable batchDimension){
    return eye(null, rows, cols, batchDimension);
}
/**
 * As per {@link #eye(String, int, int)} but with the number of rows/columns specified as scalar SDVariables
 */
public SDVariable eye(String name, SDVariable rows, SDVariable cols){
    SDVariable eye = new Eye(this, rows, cols).outputVariables()[0];
    return updateVariableNameAndReference(eye, name);
}
/**
 * As per {@link #eye(int, int)} but with the number of rows/columns specified as scalar SDVariables
 */
public SDVariable eye(SDVariable rows, SDVariable cols){
    // Consistency fix: delegate to the named overload like the other unnamed eye(...) variants,
    // instead of duplicating its construction logic. The cast disambiguates the overloads.
    return eye((String) null, rows, cols);
}
/**
 * As per {@link #eye(String, int)} but with the number of rows specified as a scalar SDVariable
 */
public SDVariable eye(String name, SDVariable rows){
    SDVariable eye = new Eye(this, rows).outputVariables()[0];
    return updateVariableNameAndReference(eye, name);
}
/**
 * As per {@link #eye(int)} but with the number of rows specified as a scalar SDVariable
 */
public SDVariable eye(SDVariable rows){
    // Consistency fix: delegate to the named overload like the other unnamed eye(...) variants,
    // instead of duplicating its construction logic. The cast disambiguates the overloads.
    return eye((String) null, rows);
}
/**
 * Remove an argument for a function. Note that if this function does not contain the argument, it will just be a no op.
 *
 * @param varName  the variable name to remove
 * @param function the function to remove the argument from
 */
public void removeArgFromFunction(String varName, DifferentialFunction function) {
    val args = function.args();
    for (int i = 0; i < args.length; i++) {
        if (args[i].getVarName().equals(varName)) {
            /**
             * Since we are removing the variable reference
             * from the arguments we need to update both
             * the reverse and forward arguments.
             */
            val reverseArgs = incomingArgsReverse.get(function.getOwnName());
            incomingArgsReverse.remove(function.getOwnName());
            // Rebuild the reverse-args array without the removed variable name.
            // NOTE(review): the filter iterates reverseArgs but is bounded by args.length -
            // presumably the two always have the same length; verify this invariant
            val newArgs = new ArrayList(args.length - 1);
            for (int arg = 0; arg < args.length; arg++) {
                if (!reverseArgs[arg].equals(varName)) {
                    newArgs.add(reverseArgs[arg]);
                }
            }
            val newArgsArr = newArgs.toArray(new String[newArgs.size()]);
            incomingArgsReverse.put(function.getOwnName(), newArgsArr);
            //no further need to scan
            break;
        }
    }
}
/**
 * Get the variable based on the name
 *
 * @param name the name of the variable
 * @return the variable instance if there is one, or null otherwise
 */
public SDVariable getVariable(String name) {
    return variableMap.get(name);
}
/**
 * Get the gradient for the given vertex id
 *
 * @param varName the vertex id
 * @return the gradient for this variable or null
 */
public SDVariable getGradForVariable(String varName) {
    //TODO 2018/06/26 - Review this?
    //Gradients are being placed in the inner "grad" function SameDiff instance, but not the outer one
    // should they be synced and we just use the map in this instance?
    if (gradients.containsKey(varName)) {
        return gradients.get(varName);
    } else if(sameDiffFunctionInstances.containsKey("grad") && sameDiffFunctionInstances.get("grad").gradients.containsKey(varName)){
        // Fall back to the inner "grad" function instance, where execBackwards records gradients
        return sameDiffFunctionInstances.get("grad").gradients.get(varName);
    }
    return null;
}
/**
 * Assign a SDVariable to represent the gradient of the SDVariable with the specified name
 *
 * @param variableName the variable name to assign the gradient variable for
 * @param variable     the gradient variable
 */
public void setGradientForVariableName(String variableName, SDVariable variable) {
    if (variable == null) {
        throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + variableName);
    }
    gradients.put(variableName, variable);
}
/**
 * Get the forward variable for gradient based on the gradient's vertex id
 *
 * @param vertexId the vertex id
 * @return the forward variable for the given gradient vertex id, or null
 */
public SDVariable getForwardVariableForVertexId(int vertexId) {
    return forwardVarForGrad.get(vertexId);
}
/**
 * Record the forward variable corresponding to a gradient variable name
 *
 * @param varName         the gradient variable name
 * @param forwardVariable the forward variable it corresponds to
 */
public void setForwardVariableForVarName(String varName, SDVariable forwardVariable) {
    forwardVarForGrad.put(varName, forwardVariable);
}
/**
 * Get the gradient for the variable with the specified variable name.
 * Note that in order to run this function, {@link #execBackwards()} must be executed first.
 * All gradient functions are obtained from the results of the execBackwards call.
 *
 * @param varName the variable name to get the gradient variable for.
 * @return The gradient variable for the specified variable
 */
public SDVariable grad(String varName) {
    if (!sameDiffFunctionInstances.containsKey("grad")) {
        throw new IllegalStateException("Unable to obtain gradient. Please run execBackwards() first.");
    }
    SameDiff grad = getFunction("grad");
    SDVariable var = grad.getVariable(varName);
    // NOTE(review): if varName is unknown in the "grad" function, var is null and the next line
    // throws an NPE - consider an explicit check with a clearer message
    return getFunction("grad").getGradForVariable(var.getVarName());
}
/**
 * @see #randomUniform(String, double, double, SDVariable)
 */
public SDVariable randomUniform(double min, double max, SDVariable shape){
    String noName = null;
    return randomUniform(noName, min, max, shape);
}
/**
 * Create a new random variable whose values are drawn from the uniform distribution U(min, max).
 * See {@link #randomUniform(String, double, double, long...)} for the variant taking the shape
 * as a long[] rather than as a variable.
 *
 * @param name  Name of the new SDVariable
 * @param min   Minimum value
 * @param max   Maximum value. Must satisfy max &gt;= min
 * @param shape Shape of the new random SDVariable, as a 1D array
 * @return New SDVariable
 */
public SDVariable randomUniform(String name, double min, double max, SDVariable shape){
    return updateVariableNameAndReference(f().randomUniform(min, max, shape), name);
}
/**
 * @see #randomUniform(String, double, double, long...)
 */
public SDVariable randomUniform(double min, double max, long... shape){
    String noName = null;
    return randomUniform(noName, min, max, shape);
}
/**
 * Create a new random variable whose values are drawn from the uniform distribution U(min, max).
 * See {@link #randomUniform(String, double, double, SDVariable)} for the variant taking the shape
 * as an SDVariable rather than as a long[].
 *
 * @param name  Name of the new SDVariable
 * @param min   Minimum value
 * @param max   Maximum value. Must satisfy max &gt;= min
 * @param shape Shape of the new random SDVariable
 * @return New SDVariable
 */
public SDVariable randomUniform(String name, double min, double max, long... shape){
    return updateVariableNameAndReference(f().randomUniform(min, max, shape), name);
}
/**
 * @see #randomNormal(String, double, double, SDVariable)
 */
public SDVariable randomNormal(double mean, double stddev, SDVariable shape){
    String noName = null;
    return randomNormal(noName, mean, stddev, shape);
}
/**
 * Create a new random variable whose values are drawn from the Gaussian (normal) distribution
 * N(mean, stddev).
 * See {@link #randomNormal(String, double, double, long...)} for the variant taking the shape
 * as a long[] rather than as a variable.
 *
 * @param name   Name of the new SDVariable
 * @param mean   Mean value for the random array
 * @param stddev Standard deviation for the random array
 * @param shape  Shape of the new random SDVariable, as a 1D array
 * @return New SDVariable
 */
public SDVariable randomNormal(String name, double mean, double stddev, SDVariable shape){
    return updateVariableNameAndReference(f().randomNormal(mean, stddev, shape), name);
}
/**
 * @see #randomNormal(String, double, double, long...)
 */
public SDVariable randomNormal(double mean, double stddev, long... shape){
    String noName = null;
    return randomNormal(noName, mean, stddev, shape);
}
/**
 * Create a new random variable whose values are drawn from the Gaussian (normal) distribution
 * N(mean, stddev).
 * See {@link #randomNormal(String, double, double, SDVariable)} for the variant taking the shape
 * as an SDVariable rather than as a long[].
 *
 * @param name   Name of the new SDVariable
 * @param mean   Mean value for the random array
 * @param stddev Standard deviation for the random array
 * @param shape  Shape of the new random SDVariable
 * @return New SDVariable
 */
public SDVariable randomNormal(String name, double mean, double stddev, long... shape){
    return updateVariableNameAndReference(f().randomNormal(mean, stddev, shape), name);
}
/**
 * @see #randomLogNormal(String, double, double, long...)
 */
public SDVariable randomLogNormal(double mean, double stddev, long... shape){
    String noName = null;
    return randomLogNormal(noName, mean, stddev, shape);
}
/**
 * Create a new random variable whose values are drawn from a log-normal distribution:
 * {@code log(x) ~ N(mean, stddev)}.
 *
 * @param name   Name of the new SDVariable
 * @param mean   Mean value for the random array
 * @param stddev Standard deviation for the random array
 * @param shape  Shape of the new random SDVariable
 * @return New SDVariable
 */
public SDVariable randomLogNormal(String name, double mean, double stddev, long... shape){
    return updateVariableNameAndReference(f().randomLogNormal(mean, stddev, shape), name);
}
/**
 * @see #randomNormalTruncated(String, double, double, long...)
 */
public SDVariable randomNormalTruncated(double mean, double stddev, long... shape){
    String noName = null;
    return randomNormalTruncated(noName, mean, stddev, shape);
}
/**
 * Create a new random variable whose values are drawn from a truncated Gaussian distribution
 * N(mean, stddev): values more than 1 standard deviation from the mean are dropped and re-sampled.
 *
 * @param name   Name of the new SDVariable
 * @param mean   Mean value for the random array
 * @param stddev Standard deviation for the random array
 * @param shape  Shape of the new random SDVariable
 * @return New SDVariable
 */
public SDVariable randomNormalTruncated(String name, double mean, double stddev, long... shape){
    return updateVariableNameAndReference(f().randomNormalTruncated(mean, stddev, shape), name);
}
/**
 * @see #randomBernoulli(String, double, SDVariable)
 */
public SDVariable randomBernoulli(double p, SDVariable shape){
    String noName = null;
    return randomBernoulli(noName, p, shape);
}
/**
 * Create a new random variable whose values are drawn from a Bernoulli distribution: each
 * element is 1 with probability p and 0 with probability 1-p.
 * See {@link #randomBernoulli(String, double, long...)} for the variant taking the shape
 * as a long[] rather than as a variable.
 *
 * @param name  Name of the new SDVariable
 * @param p     Probability of value 1
 * @param shape Shape of the new random SDVariable, as a 1D array
 * @return New SDVariable
 */
public SDVariable randomBernoulli(String name, double p, SDVariable shape){
    return updateVariableNameAndReference(f().randomBernoulli(p, shape), name);
}
/**
 * @see #randomBernoulli(String, double, long...)
 */
public SDVariable randomBernoulli(double p, long... shape){
    String noName = null;
    return randomBernoulli(noName, p, shape);
}
/**
 * Create a new random variable whose values are drawn from a Bernoulli distribution: each
 * element is 1 with probability p and 0 with probability 1-p.
 * See {@link #randomBernoulli(String, double, SDVariable)} for the variant taking the shape
 * as an SDVariable rather than as a long[].
 *
 * @param name  Name of the new SDVariable
 * @param p     Probability of value 1
 * @param shape Shape of the new random SDVariable
 * @return New SDVariable
 */
public SDVariable randomBernoulli(String name, double p, long... shape){
    return updateVariableNameAndReference(f().randomBernoulli(p, shape), name);
}
/**
 * Create a new random variable whose values are drawn from a Binomial distribution with the
 * specified number of trials and success probability.
 *
 * @param nTrials Number of trials parameter for the binomial distribution
 * @param p       Probability of success for each trial
 * @param shape   Shape of the new random SDVariable, as a 1D array
 * @return New SDVariable
 */
public SDVariable randomBinomial(int nTrials, double p, long... shape){
    String noName = null;
    return randomBinomial(noName, nTrials, p, shape);
}
/**
 * Create a new random variable whose values are drawn from a Binomial distribution with the
 * specified number of trials and success probability.
 *
 * @param name    Name of the new SDVariable
 * @param nTrials Number of trials parameter for the binomial distribution
 * @param p       Probability of success for each trial
 * @param shape   Shape of the new random SDVariable, as a 1D array
 * @return New SDVariable
 */
public SDVariable randomBinomial(String name, int nTrials, double p, long... shape){
    return updateVariableNameAndReference(f().randomBinomial(nTrials, p, shape), name);
}
/**
 * Create a new random variable whose values are drawn from an exponential distribution:
 * P(x) = lambda * exp(-lambda * x)
 *
 * @param lambda Rate parameter; must be &gt; 0
 * @param shape  Shape of the output
 * @return new SDVariable
 */
public SDVariable randomExponential(double lambda, SDVariable shape) {
    String noName = null;
    return randomExponential(noName, lambda, shape);
}
/**
 * Create a new random variable whose values are drawn from an exponential distribution:
 * P(x) = lambda * exp(-lambda * x)
 *
 * @param name   Name of the output variable
 * @param lambda Rate parameter; must be &gt; 0
 * @param shape  Shape of the new variable
 * @return new SDVariable
 */
public SDVariable randomExponential(String name, double lambda, SDVariable shape) {
    return updateVariableNameAndReference(f().randomExponential(lambda, shape), name);
}
/**
 * 2D Convolution layer operation - upsampling 2d with the same scale for both spatial
 * dimensions. NCHW input format is assumed.
 *
 * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 * @param scale Scale to upsample in both H and W dimensions
 * @return Upsampled input
 */
public SDVariable upsampling2d(SDVariable input, int scale) {
    String noName = null;
    return upsampling2d(noName, input, true, scale, scale);
}
/**
 * 2D Convolution layer operation - upsampling 2d with the same scale for both spatial
 * dimensions. NCHW input format is assumed.
 *
 * @param name  name of the operation in SameDiff
 * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 * @param scale Scale to upsample in both H and W dimensions
 * @return Upsampled input
 */
public SDVariable upsampling2d(String name, SDVariable input, int scale) {
    // nchw == true: the fixed-scale overloads always assume NCHW layout
    return upsampling2d(name, input, true, scale, scale);
}
/**
 * 2D Convolution layer operation - upsampling 2d
 *
 * @param input  Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *               or NHWC format (shape [minibatch, height, width, channels])
 * @param nchw   If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format
 * @param scaleH Scale to upsample in the height dimension
 * @param scaleW Scale to upsample in the width dimension
 * @return Upsampled input
 */
public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) {
    String noName = null;
    return upsampling2d(noName, input, nchw, scaleH, scaleW);
}
/**
 * 2D Convolution layer operation - upsampling 2d
 *
 * @param name   name of the operation in SameDiff
 * @param input  Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *               or NHWC format (shape [minibatch, height, width, channels])
 * @param nchw   If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format
 * @param scaleH Scale to upsample in the height dimension
 * @param scaleW Scale to upsample in the width dimension
 * @return Upsampled input
 */
public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) {
    return updateVariableNameAndReference(f().upsampling2d(input, nchw, scaleH, scaleW), name);
}
/**
 * 2D Convolution layer operation - average pooling 2d
 *
 * @param input           the input to the average pooling 2d operation - 4d CNN (image) activations in NCHW format
 *                        (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
 * @param pooling2DConfig the configuration for the pooling operation
 * @return Result after applying average pooling on the input
 */
public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
    String noName = null;
    return avgPooling2d(noName, input, pooling2DConfig);
}
/**
 * 2D Convolution layer operation - average pooling 2d
 *
 * @param name            name of the operation in SameDiff
 * @param input           the input to the average pooling 2d operation - 4d CNN (image) activations in NCHW format
 *                        (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
 * @param pooling2DConfig the configuration for the pooling operation
 * @return Result after applying average pooling on the input
 */
public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
    return updateVariableNameAndReference(f().avgPooling2d(input, pooling2DConfig), name);
}
/**
 * 2D Convolution layer operation - max pooling 2d
 *
 * @param input           the input to the max pooling 2d operation - 4d CNN (image) activations in NCHW format
 *                        (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
 * @param pooling2DConfig the configuration for the pooling operation
 * @return Result after applying max pooling on the input
 */
public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
    String noName = null;
    return maxPooling2d(noName, input, pooling2DConfig);
}
/**
 * 2D Convolution layer operation - max pooling 2d
 *
 * @param name            name of the operation in SameDiff
 * @param input           the input to the max pooling 2d operation - 4d CNN (image) activations in NCHW format
 *                        (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
 * @param pooling2DConfig the configuration for the pooling operation
 * @return Result after applying max pooling on the input
 */
public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) {
    return updateVariableNameAndReference(f().maxPooling2d(input, pooling2DConfig), name);
}
/**
 * 3D convolution layer operation - average pooling 3d
 *
 * @param input           the input to the average pooling 3d operation - 5d activations in NCDHW format
 *                        (shape [minibatch, channels, depth, height, width]) or NDHWC format
 *                        (shape [minibatch, depth, height, width, channels])
 * @param pooling3DConfig the configuration for the pooling operation
 * @return Result after applying average pooling on the input
 */
public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
    String noName = null;
    return avgPooling3d(noName, input, pooling3DConfig);
}
/**
 * 3D convolution layer operation - average pooling 3d
 *
 * @param name            name of the operation in SameDiff
 * @param input           the input to the average pooling 3d operation - 5d activations in NCDHW format
 *                        (shape [minibatch, channels, depth, height, width]) or NDHWC format
 *                        (shape [minibatch, depth, height, width, channels])
 * @param pooling3DConfig the configuration for the pooling operation
 * @return Result after applying average pooling on the input
 */
public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
    return updateVariableNameAndReference(f().avgPooling3d(input, pooling3DConfig), name);
}
/**
 * 3D convolution layer operation - max pooling 3d
 *
 * @param input           the input to the max pooling 3d operation - 5d activations in NCDHW format
 *                        (shape [minibatch, channels, depth, height, width]) or NDHWC format
 *                        (shape [minibatch, depth, height, width, channels])
 * @param pooling3DConfig the configuration for the pooling operation
 * @return Result after applying max pooling on the input
 */
public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
    String noName = null;
    return maxPooling3d(noName, input, pooling3DConfig);
}
/**
 * 3D convolution layer operation - max pooling 3d
 *
 * @param name            name of the operation in SameDiff
 * @param input           the input to the max pooling 3d operation - 5d activations in NCDHW format
 *                        (shape [minibatch, channels, depth, height, width]) or NDHWC format
 *                        (shape [minibatch, depth, height, width, channels])
 * @param pooling3DConfig the configuration for the pooling operation
 * @return Result after applying max pooling on the input
 */
public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) {
    return updateVariableNameAndReference(f().maxPooling3d(input, pooling3DConfig), name);
}
/**
 * 1D Convolution layer operation - Conv1d
 *
 * @param input        the input array/activations for the conv1d op
 * @param weights      weights for the conv1d op
 * @param conv1DConfig the configuration
 * @return result of the conv1d operation
 */
public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
    String noName = null;
    return conv1d(noName, input, weights, conv1DConfig);
}
/**
 * 1D Convolution layer operation - Conv1d
 *
 * @param name         name of the operation in SameDiff
 * @param input        the input array/activations for the conv1d op
 * @param weights      weights for the conv1d op
 * @param conv1DConfig the configuration
 * @return result of the conv1d operation
 */
public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
    return updateVariableNameAndReference(f().conv1d(input, weights, conv1DConfig), name);
}
/**
 * 2D convolution layer operation - local response normalization
 *
 * @param inputs    the input to the LRN operation
 * @param lrnConfig the configuration
 * @return result of the local response normalization
 */
public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) {
    String noName = null;
    return localResponseNormalization(noName, inputs, lrnConfig);
}
/**
 * 2D convolution layer operation - local response normalization
 *
 * @param name      name of the operation in SameDiff
 * @param input     the input to the LRN operation
 * @param lrnConfig the configuration
 * @return result of the local response normalization
 */
public SDVariable localResponseNormalization(String name, SDVariable input,
                                             LocalResponseNormalizationConfig lrnConfig) {
    return updateVariableNameAndReference(f().localResponseNormalization(input, lrnConfig), name);
}
/**
 * 2D Convolution operation (without bias)
 *
 * @param layerInput 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *                   or NHWC format (shape [minibatch, height, width, channels])
 * @param weights    Weights for the convolution operation. 4 dimensions.
 *                   If layer input data is in NCHW format, weights should have format [outputChannels, inputChannels, kernelHeight, kernelWidth].
 *                   If layer input data is in NHWC format, weights should have format [kernelHeight, kernelWidth, inputChannels, outputChannels]
 * @param config     Conv2DConfig configuration
 * @return result of conv2d op
 */
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) {
    // Delegate to the optional-bias variant with no bias
    return conv2d(layerInput, weights, null, config);
}
/**
 * 2D Convolution operation with optional bias
 *
 * @param layerInput 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *                   or NHWC format (shape [minibatch, height, width, channels])
 * @param weights    Weights for the convolution operation. 4 dimensions.
 *                   If layer input data is in NCHW format, weights should have format [outputChannels, inputChannels, kernelHeight, kernelWidth].
 *                   If layer input data is in NHWC format, weights should have format [kernelHeight, kernelWidth, inputChannels, outputChannels]
 * @param bias       Optional 1D bias array with shape [outputChannels]. May be null.
 * @param config     Conv2DConfig configuration
 * @return result of conv2d op
 */
public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) {
    SDVariable[] inputs;
    if (bias == null) {
        inputs = new SDVariable[]{layerInput, weights};
    } else {
        inputs = new SDVariable[]{layerInput, weights, bias};
    }
    return conv2d(inputs, config);
}
/**
 * 2D Convolution operation with optional bias
 *
 * @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as
 *               described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
 * @param config Conv2DConfig configuration
 * @return result of convolution 2d operation
 */
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) {
    String noName = null;
    return conv2d(noName, inputs, config);
}
/**
 * 2D Convolution operation with optional bias
 *
 * @param name   Name of the output SDVariable
 * @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as
 *               described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
 * @param config Conv2DConfig configuration
 * @return result of convolution 2d operation
 */
public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) {
    return updateVariableNameAndReference(f().conv2d(inputs, config), name);
}
/**
 * Depth-wise 2D convolution operation without bias
 *
 * @param layerInput   4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *                     or NHWC format (shape [minibatch, height, width, channels])
 * @param depthWeights depth-wise conv 2D weights. 4 dimensions.
 *                     If layer input data is in NCHW format, weights should have format [outputChannels, inputChannels, kernelHeight, kernelWidth].
 *                     If layer input data is in NHWC format, weights should have format [kernelHeight, kernelWidth, inputChannels, outputChannels]
 * @param config       Conv2DConfig configuration
 * @return result of depthwise conv2d op
 */
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) {
    // Delegate to the optional-bias variant with no bias
    return depthWiseConv2d(layerInput, depthWeights, null, config);
}
/**
 * Depth-wise 2D convolution operation with optional bias
 *
 * @param layerInput   4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *                     or NHWC format (shape [minibatch, height, width, channels])
 * @param depthWeights depth-wise conv 2D weights. 4 dimensions.
 *                     If layer input data is in NCHW format, weights should have format [outputChannels, inputChannels, kernelHeight, kernelWidth].
 *                     If layer input data is in NHWC format, weights should have format [kernelHeight, kernelWidth, inputChannels, outputChannels]
 * @param bias         Optional 1D bias array with shape [outputChannels]. May be null.
 * @param config       Conv2DConfig configuration
 * @return result of depthwise conv2d op
 */
public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) {
    SDVariable[] inputs;
    if (bias == null) {
        inputs = new SDVariable[]{layerInput, depthWeights};
    } else {
        inputs = new SDVariable[]{layerInput, depthWeights, bias};
    }
    return depthWiseConv2d(inputs, config);
}
/**
 * Depth-wise convolution 2D operation.
 *
 * @param inputs            the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights)
 *                          or 3 elements (layerInput, depthWeights, bias) as described in
 *                          {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
 * @param depthConv2DConfig the configuration
 * @return result of depthwise conv2d op
 */
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
    String noName = null;
    return depthWiseConv2d(noName, inputs, depthConv2DConfig);
}
/**
 * Depth-wise convolution 2D operation.
 *
 * @param name              name of the output variable
 * @param inputs            the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights)
 *                          or 3 elements (layerInput, depthWeights, bias) as described in
 *                          {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)}
 * @param depthConv2DConfig the configuration
 * @return result of depthwise conv2d op
 */
public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
    return updateVariableNameAndReference(f().depthWiseConv2d(inputs, depthConv2DConfig), name);
}
/**
 * Separable 2D convolution operation without bias
 *
 * @param layerInput   4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *                     or NHWC format (shape [minibatch, height, width, channels])
 * @param depthWeights Depth weights, rank 4.
 *                     If layer input is in NCHW format, depth weights should have format [outputChannels, depthMultiplier, kernelHeight, kernelWidth].
 *                     If layer input is in NHWC format, depth weights should have format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
 * @param pointWeights Point weights, rank 4.
 *                     If layer input is in NCHW format, point weights should have format [outputChannels, inputChannels*depthMultiplier, 1, 1].
 *                     If layer input is in NHWC format, point weights should have format [1, 1, inputChannels*depthMultiplier, outputChannels]
 * @param config       Conv2DConfig configuration
 * @return result of separable convolution 2d operation
 */
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
                                  Conv2DConfig config) {
    // Delegate to the optional-bias variant with no bias
    return separableConv2d(layerInput, depthWeights, pointWeights, null, config);
}
/**
 * Separable 2D convolution operation with optional bias
 *
 * @param layerInput   4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *                     or NHWC format (shape [minibatch, height, width, channels])
 * @param depthWeights Depth weights, rank 4.
 *                     If layer input is in NCHW format, depth weights should have format [outputChannels, depthMultiplier, kernelHeight, kernelWidth].
 *                     If layer input is in NHWC format, depth weights should have format [kernelHeight, kernelWidth, inputChannels, depthMultiplier]
 * @param pointWeights Point weights, rank 4.
 *                     If layer input is in NCHW format, point weights should have format [outputChannels, inputChannels*depthMultiplier, 1, 1].
 *                     If layer input is in NHWC format, point weights should have format [1, 1, inputChannels*depthMultiplier, outputChannels]
 * @param bias         Optional bias, rank 1 with shape [outputChannels]. May be null.
 * @param config       Conv2DConfig configuration
 * @return result of separable convolution 2d operation
 */
public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights,
                                  SDVariable bias, Conv2DConfig config) {
    SDVariable[] inputs;
    if (bias == null) {
        inputs = new SDVariable[]{layerInput, depthWeights, pointWeights};
    } else {
        inputs = new SDVariable[]{layerInput, depthWeights, pointWeights, bias};
    }
    return sconv2d(inputs, config);
}
/**
 * Separable 2D convolution operation with/without optional bias
 *
 * @param inputs       the inputs to the separable conv2d operation. Should be length 3 (layerInput, depthWeights, pointWeights)
 *                     or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}
 * @param conv2DConfig the configuration
 * @return result of separable convolution 2d operation
 */
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
    String noName = null;
    return sconv2d(noName, inputs, conv2DConfig);
}
/**
 * Separable 2D convolution operation with/without optional bias
 *
 * @param name         name of the output variable
 * @param inputs       the inputs to the separable conv2d operation. Should be length 3 (layerInput, depthWeights, pointWeights)
 *                     or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}
 * @param conv2DConfig the configuration
 * @return result of separable convolution 2d operation
 */
public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) {
    return updateVariableNameAndReference(f().sconv2d(inputs, conv2DConfig), name);
}
/**
 * 2D deconvolution operation without bias
 *
 * @param layerInput     4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *                       or NHWC format (shape [minibatch, height, width, channels])
 * @param weights        Weights for the 2d deconvolution operation. 4 dimensions.
 *                       If layer input data is in NCHW format, weights should have format [inputChannels, outputChannels, kernelHeight, kernelWidth].
 *                       If layer input data is in NHWC format, weights should have format [kernelHeight, kernelWidth, outputChannels, inputChannels]
 * @param deconv2DConfig DeConv2DConfig configuration
 * @return result of deconv2d op
 */
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) {
    // Delegate to the optional-bias variant with no bias
    return deconv2d(layerInput, weights, null, deconv2DConfig);
}
/**
 * 2D deconvolution operation with optional bias
 *
 * @param layerInput     4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width])
 *                       or NHWC format (shape [minibatch, height, width, channels])
 * @param weights        Weights for the 2d deconvolution operation. 4 dimensions.
 *                       If layer input data is in NCHW format, weights should have format [inputChannels, outputChannels, kernelHeight, kernelWidth].
 *                       If layer input data is in NHWC format, weights should have format [kernelHeight, kernelWidth, outputChannels, inputChannels]
 * @param bias           Optional 1D bias array with shape [outputChannels]. May be null.
 * @param deconv2DConfig DeConv2DConfig configuration
 * @return result of deconv2d op
 */
public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) {
    SDVariable[] inputs;
    if (bias == null) {
        inputs = new SDVariable[]{layerInput, weights};
    } else {
        inputs = new SDVariable[]{layerInput, weights, bias};
    }
    return deconv2d(inputs, deconv2DConfig);
}
/**
 * 2D deconvolution operation with or without optional bias
 *
 * @param inputs         Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights)
 *                       or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable, SDVariable, SDVariable, DeConv2DConfig)}
 * @param deconv2DConfig the configuration
 * @return result of deconv2d op
 */
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
    String noName = null;
    return deconv2d(noName, inputs, deconv2DConfig);
}
/**
 * 2D deconvolution operation with or without optional bias
 *
 * @param name           Name of the output variable
 * @param inputs         Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights)
 *                       or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable, SDVariable, SDVariable, DeConv2DConfig)}
 * @param deconv2DConfig the configuration
 * @return result of deconv2d op
 */
public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
    return updateVariableNameAndReference(f().deconv2d(inputs, deconv2DConfig), name);
}
/**
 * Convolution 3D operation without bias
 *
 * @param input        5d activations in NCDHW format (shape [minibatch, channels, depth, height, width])
 *                     or NDHWC format (shape [minibatch, depth, height, width, channels])
 * @param weights      Weights for conv3d. Rank 5.
 *                     If input data is in NCDHW format, weights should have shape [outputChannels, inputChannels, kernelDepth, kernelHeight, kernelWidth].
 *                     If input data is in NDHWC format, weights should have shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
 * @param conv3DConfig the configuration
 * @return Conv3d output variable
 */
public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
    // Unnamed, no bias
    return conv3d(null, input, weights, null, conv3DConfig);
}
/**
 * Convolution 3D operation with optional bias
 *
 * @param input        5d activations in NCDHW format (shape [minibatch, channels, depth, height, width])
 *                     or NDHWC format (shape [minibatch, depth, height, width, channels])
 * @param weights      Weights for conv3d. Rank 5.
 *                     If input data is in NCDHW format, weights should have shape [outputChannels, inputChannels, kernelDepth, kernelHeight, kernelWidth].
 *                     If input data is in NDHWC format, weights should have shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
 * @param bias         Optional 1D bias array with shape [outputChannels]. May be null.
 * @param conv3DConfig the configuration
 * @return Conv3d output variable
 */
public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
    // Unnamed, with optional bias
    return conv3d(null, input, weights, bias, conv3DConfig);
}
/**
 * Convolution 3D operation without bias
 *
 * @param name         Name of the output variable
 * @param input        5d activations in NCDHW format (shape [minibatch, channels, depth, height, width])
 *                     or NDHWC format (shape [minibatch, depth, height, width, channels])
 * @param weights      Weights for conv3d. Rank 5.
 *                     If input data is in NCDHW format, weights should have shape [outputChannels, inputChannels, kernelDepth, kernelHeight, kernelWidth].
 *                     If input data is in NDHWC format, weights should have shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
 * @param conv3DConfig the configuration
 * @return Conv3d output variable
 */
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) {
    // Named, no bias
    return conv3d(name, input, weights, null, conv3DConfig);
}
/**
 * Convolution 3D operation with optional bias
 *
 * @param name         Name of the output variable
 * @param input        the input to conv3d operation - 5d activations in NCDHW format
 *                     (shape [minibatch, channels, depth, height, width]) or NDHWC format
 *                     (shape [minibatch, depth, height, width, channels])
 * @param weights      Weights for conv3d. Rank 5.
 *                     If input data is in NCDHW format, weights should have shape [outputChannels, inputChannels, kernelDepth, kernelHeight, kernelWidth].
 *                     If input data is in NDHWC format, weights should have shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels].
 * @param bias         Optional 1D bias array with shape [outputChannels]. May be null.
 * @param conv3DConfig the configuration
 * @return Conv3d output variable
 */
public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) {
    // Bias is optional: include it in the op's argument list only when present
    SDVariable[] args = (bias == null)
            ? new SDVariable[]{input, weights}
            : new SDVariable[]{input, weights, bias};
    SDVariable out = f().conv3d(args, conv3DConfig);
    return updateVariableNameAndReference(out, name);
}
/**
 * Batch norm operation.
 *
 * @param input      Input variable to be normalized
 * @param mean       Mean variable used for normalization
 * @param variance   Variance variable used for normalization
 * @param gamma      Scale variable (used when applyGamma is true)
 * @param beta       Offset variable (used when applyBeta is true)
 * @param applyGamma Whether to apply the gamma (scale) term
 * @param applyBeta  Whether to apply the beta (offset) term
 * @param epsilon    Small constant for numerical stability
 * @return Output variable of the batch norm operation
 */
public SDVariable batchNorm(SDVariable input, SDVariable mean,
SDVariable variance, SDVariable gamma,
SDVariable beta,
boolean applyGamma, boolean applyBeta, double epsilon) {
    // Delegate to the named overload with no output name
    return batchNorm(null, input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon);
}
/**
 * Batch norm operation.
 *
 * @param name       Name of the output variable
 * @param input      Input variable to be normalized
 * @param mean       Mean variable used for normalization
 * @param variance   Variance variable used for normalization
 * @param gamma      Scale variable (used when applyGamma is true)
 * @param beta       Offset variable (used when applyBeta is true)
 * @param applyGamma Whether to apply the gamma (scale) term
 * @param applyBeta  Whether to apply the beta (offset) term
 * @param epsilon    Small constant for numerical stability
 * @return Output variable of the batch norm operation
 */
public SDVariable batchNorm(String name, SDVariable input, SDVariable mean,
                            SDVariable variance, SDVariable gamma,
                            SDVariable beta,
                            boolean applyGamma, boolean applyBeta, double epsilon) {
    SDVariable bn = f().batchNorm(input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon);
    return updateVariableNameAndReference(bn, name);
}
/**
 * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
 * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
 *
 * @param in     Input - rank 4 input with shape [minibatch, inputChannels, height, width]
 * @param config Convolution configuration for the im2col operation
 * @return Im2Col output variable
 */
public SDVariable im2Col(SDVariable in, Conv2DConfig config) {
    // Delegate to the named overload with no output name
    return im2Col(null, in, config);
}
/**
 * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
 * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
 *
 * @param name   Name of the output variable
 * @param in     Input - rank 4 input with shape [minibatch, inputChannels, height, width]
 * @param config Convolution configuration for the im2col operation
 * @return Im2Col output variable
 */
public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) {
    return updateVariableNameAndReference(f().im2Col(in, config), name);
}
/**
 * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
 * [minibatch, inputChannels, height, width]
 *
 * @param in     Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
 * @param config Convolution configuration for the col2im operation
 * @return Col2Im output variable
 */
public SDVariable col2Im(SDVariable in, Conv2DConfig config) {
    // Delegate to the named overload with no output name
    return col2Im(null, in, config);
}
/**
 * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
 * [minibatch, inputChannels, height, width]
 *
 * @param name   Name of the output variable
 * @param in     Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
 * @param config Convolution configuration for the col2im operation
 * @return Col2Im output variable
 */
public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) {
    return updateVariableNameAndReference(f().col2Im(in, config), name);
}
/**
 * Create a new scalar (rank 0) SDVariable with the specified value
 *
 * @param name  Name of the SDVariable
 * @param value Value to initialize the variable with
 * @return SDVariable
 */
public SDVariable scalar(String name, double value) {
    // Wrap the scalar in an INDArray and register it as a named variable
    return var(name, Nd4j.scalar(value));
}
/**
 * Greater than or equals operation: elementwise x >= y
 * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
 * value 0 otherwise
 *
 * @param x Input array
 * @param y Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable gte(SDVariable x, double y) {
    // Delegate to the named overload with no output name
    return gte(null, x, y);
}
/**
 * Greater than or equals operation: elementwise x >= y
 * The output has the same shape as the input, holding 1 where the condition is
 * satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input array
 * @param y    Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable gte(String name, SDVariable x, double y) {
    return updateVariableNameAndReference(functionFactory.gte(x, y), name);
}
/**
 * Less than or equals operation: elementwise x <= y
 * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
 * value 0 otherwise
 *
 * @param x Input array
 * @param y Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable lte(SDVariable x, double y) {
    // Delegate to the named overload with no output name
    return lte(null, x, y);
}
/**
 * Less than or equals operation: elementwise x <= y
 * The output has the same shape as the input, holding 1 where the condition is
 * satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input array
 * @param y    Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable lte(String name, SDVariable x, double y) {
    return updateVariableNameAndReference(functionFactory.lte(x, y), name);
}
/**
 * Greater than operation: elementwise x > y
 * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
 * value 0 otherwise
 *
 * @param x Input array
 * @param y Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable gt(SDVariable x, double y) {
    // Delegate to the named overload with no output name
    return gt(null, x, y);
}
/**
 * Greater than operation: elementwise x > y
 * The output has the same shape as the input, holding 1 where the condition is
 * satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input array
 * @param y    Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable gt(String name, SDVariable x, double y) {
    return updateVariableNameAndReference(functionFactory.gt(x, y), name);
}
/**
 * Less than operation: elementwise x < y
 * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
 * value 0 otherwise
 *
 * @param x Input array
 * @param y Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable lt(SDVariable x, double y) {
    // Delegate to the named overload with no output name
    return lt(null, x, y);
}
/**
 * Less than operation: elementwise x < y
 * The output has the same shape as the input, holding 1 where the condition is
 * satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input array
 * @param y    Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable lt(String name, SDVariable x, double y) {
    return updateVariableNameAndReference(functionFactory.lt(x, y), name);
}
/**
 * Not equals operation: elementwise x != y
 * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
 * value 0 otherwise
 *
 * @param x Input array
 * @param y Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable neq(SDVariable x, double y) {
    // Delegate to the named overload with no output name
    return neq(null, x, y);
}
/**
 * Not equals operation: elementwise x != y
 * The output has the same shape as the input, holding 1 where the condition is
 * satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input array
 * @param y    Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable neq(String name, SDVariable x, double y) {
    return updateVariableNameAndReference(functionFactory.neq(x, y), name);
}
/**
 * Equals operation: elementwise x == y
 * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
 * value 0 otherwise
 *
 * @param x Input array
 * @param y Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable eq(SDVariable x, double y) {
    // Delegate to the named overload with no output name
    return eq(null, x, y);
}
/**
 * Equals operation: elementwise x == y
 * The output has the same shape as the input, holding 1 where the condition is
 * satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input array
 * @param y    Double value argument to use in operation
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable eq(String name, SDVariable x, double y) {
    return updateVariableNameAndReference(functionFactory.eq(x, y), name);
}
/**
 * Greater than or equal to operation: elementwise x >= y
 * If x and y arrays have equal shape, the output shape is the same as these inputs.
 * Note: supports broadcasting if x and y have different shapes and are broadcastable.
 * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
 *
 * @param x Input 1
 * @param y Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable gte(SDVariable x, SDVariable y) {
    // Delegate to the named overload with no output name
    return gte(null, x, y);
}
/**
 * Greater than or equal to operation: elementwise x >= y
 * When x and y share a shape the output has that shape; otherwise broadcasting
 * applies if the shapes are broadcastable.
 * The output holds 1 where the condition is satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input 1
 * @param y    Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable gte(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(functionFactory.gte(x, y), name);
}
/**
 * Less than or equal to operation: elementwise x <= y
 * If x and y arrays have equal shape, the output shape is the same as these inputs.
 * Note: supports broadcasting if x and y have different shapes and are broadcastable.
 * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
 *
 * @param x Input 1
 * @param y Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable lte(SDVariable x, SDVariable y) {
    // Delegate to the named overload with no output name
    return lte(null, x, y);
}
/**
 * Less than or equal to operation: elementwise x <= y
 * When x and y share a shape the output has that shape; otherwise broadcasting
 * applies if the shapes are broadcastable.
 * The output holds 1 where the condition is satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input 1
 * @param y    Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable lte(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(functionFactory.lte(x, y), name);
}
/**
 * Greater than operation: elementwise x > y
 * If x and y arrays have equal shape, the output shape is the same as these inputs.
 * Note: supports broadcasting if x and y have different shapes and are broadcastable.
 * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
 *
 * @param x Input 1
 * @param y Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable gt(SDVariable x, SDVariable y) {
    // Delegate to the named overload with no output name
    return gt(null, x, y);
}
/**
 * Greater than operation: elementwise x > y
 * When x and y share a shape the output has that shape; otherwise broadcasting
 * applies if the shapes are broadcastable.
 * The output holds 1 where the condition is satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input 1
 * @param y    Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable gt(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(functionFactory.gt(x, y), name);
}
/**
 * Less than operation: elementwise x < y
 * If x and y arrays have equal shape, the output shape is the same as these inputs.
 * Note: supports broadcasting if x and y have different shapes and are broadcastable.
 * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
 *
 * @param x Input 1
 * @param y Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable lt(SDVariable x, SDVariable y) {
    // Delegate to the named overload with no output name
    return lt(null, x, y);
}
/**
 * Less than operation: elementwise x < y
 * When x and y share a shape the output has that shape; otherwise broadcasting
 * applies if the shapes are broadcastable.
 * The output holds 1 where the condition is satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input 1
 * @param y    Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable lt(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(functionFactory.lt(x, y), name);
}
/**
 * Not equal to operation: elementwise x != y
 * If x and y arrays have equal shape, the output shape is the same as these inputs.
 * Note: supports broadcasting if x and y have different shapes and are broadcastable.
 * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
 *
 * @param x Input 1
 * @param y Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable neq(SDVariable x, SDVariable y) {
    // Delegate to the named overload with no output name
    return neq(null, x, y);
}
/**
 * Not equal to operation: elementwise x != y
 * When x and y share a shape the output has that shape; otherwise broadcasting
 * applies if the shapes are broadcastable.
 * The output holds 1 where the condition is satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input 1
 * @param y    Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable neq(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(functionFactory.neq(x, y), name);
}
/**
 * Equal to operation: elementwise x == y
 * If x and y arrays have equal shape, the output shape is the same as these inputs.
 * Note: supports broadcasting if x and y have different shapes and are broadcastable.
 * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
 *
 * @param x Input 1
 * @param y Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable eq(SDVariable x, SDVariable y) {
    // Delegate to the named overload with no output name
    return eq(null, x, y);
}
/**
 * Equal to operation: elementwise x == y
 * When x and y share a shape the output has that shape; otherwise broadcasting
 * applies if the shapes are broadcastable.
 * The output holds 1 where the condition is satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input 1
 * @param y    Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable eq(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(functionFactory.eq(x, y), name);
}
/**
 * Boolean OR operation: elementwise (x != 0) || (y != 0)
 * If x and y arrays have equal shape, the output shape is the same as these inputs.
 * Note: supports broadcasting if x and y have different shapes and are broadcastable.
 * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
 *
 * @param x Input 1
 * @param y Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable or(SDVariable x, SDVariable y) {
    // Delegate to the named overload with no output name
    return or(null, x, y);
}
/**
 * Boolean OR operation: elementwise (x != 0) || (y != 0)
 * When x and y share a shape the output has that shape; otherwise broadcasting
 * applies if the shapes are broadcastable.
 * The output holds 1 where the condition is satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input 1
 * @param y    Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable or(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(functionFactory.or(x, y), name);
}
/**
 * Boolean AND operation: elementwise (x != 0) && (y != 0)
 * If x and y arrays have equal shape, the output shape is the same as these inputs.
 * Note: supports broadcasting if x and y have different shapes and are broadcastable.
 * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
 *
 * @param x Input 1
 * @param y Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable and(SDVariable x, SDVariable y) {
    // Delegate to the named overload with no output name
    return and(null, x, y);
}
/**
 * Boolean AND operation: elementwise (x != 0) && (y != 0)
 * When x and y share a shape the output has that shape; otherwise broadcasting
 * applies if the shapes are broadcastable.
 * The output holds 1 where the condition is satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input 1
 * @param y    Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable and(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(f().and(x, y), name);
}
/**
 * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
 * If x and y arrays have equal shape, the output shape is the same as these inputs.
 * Note: supports broadcasting if x and y have different shapes and are broadcastable.
 * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
 *
 * @param x Input 1
 * @param y Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable xor(SDVariable x, SDVariable y) {
    // Delegate to the named overload with no output name
    return xor(null, x, y);
}
/**
 * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
 * When x and y share a shape the output has that shape; otherwise broadcasting
 * applies if the shapes are broadcastable.
 * The output holds 1 where the condition is satisfied and 0 elsewhere.
 *
 * @param name Name of the output variable
 * @param x    Input 1
 * @param y    Input 2
 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied
 */
public SDVariable xor(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(f().xor(x, y), name);
}
/**
 * Elementwise absolute value operation: out = abs(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable abs(SDVariable x) {
    // Delegate to the named overload with no output name
    return abs(null, x);
}
/**
 * Elementwise absolute value operation: out = abs(x)
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable abs(String name, SDVariable x) {
    return updateVariableNameAndReference(f().abs(x), name);
}
/**
 * Elementwise negative operation: out = -x
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable neg(SDVariable x) {
    // Delegate to the named overload with no output name
    return neg(null, x);
}
/**
 * Elementwise negative operation: out = -x
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable neg(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.neg(x), name);
}
/**
 * Elementwise cosine operation: out = cos(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable cos(SDVariable x) {
    // Delegate to the named overload with no output name
    return cos(null, x);
}
/**
 * Elementwise cosine operation: out = cos(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable cos(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.cos(x), name);
}
/**
 * Elementwise sine operation: out = sin(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable sin(SDVariable x) {
    // Delegate to the named overload with no output name
    return sin(null, x);
}
/**
 * Elementwise sine operation: out = sin(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable sin(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.sin(x), name);
}
/**
 * Elementwise tangent operation: out = tan(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable tan(SDVariable x) {
    // Delegate to the named overload with no output name
    return tan(null, x);
}
/**
 * Elementwise tangent operation: out = tan(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable tan(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.tan(x), name);
}
/**
 * Elementwise identity operation: out = x
 *
 * @param input Input variable
 * @return Output variable
 */
public SDVariable identity(SDVariable input) {
    // Delegate to the named overload with no output name
    return identity(null, input);
}
/**
 * Elementwise identity operation: out = x
 *
 * @param name  name of the output variable
 * @param input Input variable
 * @return Output variable
 */
public SDVariable identity(String name, SDVariable input) {
    return updateVariableNameAndReference(f().identity(input), name);
}
/**
 * Compute the inverse permutation indices for a permutation operation
 * Example: if input is [2, 0, 1] then output is [1, 2, 0]
 * The idea is that x.permute(input).permute(invertPermutation(input)) == x
 *
 * @param input 1D indices for permutation
 * @return 1D inverted permutation
 */
public SDVariable invertPermutation(SDVariable input) {
    // Delegate to the named overload with no output name
    return invertPermutation(null, input);
}
/**
 * Compute the inverse permutation indices for a permutation operation
 * Example: if input is [2, 0, 1] then output is [1, 2, 0]
 * The idea is that x.permute(input).permute(invertPermutation(input)) == x
 *
 * @param name  name of the output variable
 * @param input 1D indices for permutation
 * @return 1D inverted permutation
 */
public SDVariable invertPermutation(String name, SDVariable input) {
    // Second argument presumably selects in-place execution; false here — TODO confirm
    SDVariable inverted = f().invertPermutation(input, false);
    return updateVariableNameAndReference(inverted, name);
}
/**
 * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable acos(SDVariable x) {
    // Delegate to the named overload with no output name
    return acos(null, x);
}
/**
 * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable acos(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.acos(x), name);
}
/**
 * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable asin(SDVariable x) {
    // Delegate to the named overload with no output name
    return asin(null, x);
}
/**
 * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable asin(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.asin(x), name);
}
/**
 * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable atan(SDVariable x) {
    // Delegate to the named overload with no output name
    return atan(null, x);
}
/**
 * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable atan(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.atan(x), name);
}
/**
 * Elementwise atan2 operation: out = atan2(x,y).
 * Similar to atan(y/x) but the signs of x and y are used to determine the quadrant of the result
 *
 * @param y Input Y variable
 * @param x Input X variable
 * @return Output variable
 */
public SDVariable atan2(SDVariable y, SDVariable x) {
    // Delegate to the named overload with no output name
    return atan2(null, y, x);
}
/**
 * Elementwise atan2 operation: out = atan2(x,y).
 * Similar to atan(y/x) but the signs of x and y are used to determine the quadrant of the result
 *
 * @param name Name of the output variable
 * @param y    Input Y variable
 * @param x    Input X variable
 * @return Output variable
 */
public SDVariable atan2(String name, SDVariable y, SDVariable x) {
    // Note argument order: y first, then x — matches the atan2(y, x) convention
    return updateVariableNameAndReference(f().atan2(y, x), name);
}
/**
 * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable cosh(SDVariable x) {
    // Delegate to the named overload with no output name
    return cosh(null, x);
}
/**
 * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable cosh(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.cosh(x), name);
}
/**
 * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable sinh(SDVariable x) {
    // Delegate to the named overload with no output name
    return sinh(null, x);
}
/**
 * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable sinh(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.sinh(x), name);
}
/**
 * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable tanh(SDVariable x) {
    // Delegate to the named overload with no output name
    return tanh(null, x);
}
/**
 * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable tanh(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.tanh(x), name);
}
/**
 * Elementwise step function:
 * out(x) = 1 if x >= cutoff
 * out(x) = 0 otherwise
 *
 * @param in     Input variable
 * @param cutoff Cutoff value for step function
 * @return Output variable
 */
public SDVariable step(SDVariable in, double cutoff) {
    // Delegate to the named overload with no output name
    return step(null, in, cutoff);
}
/**
 * Elementwise step function:
 * out(x) = 1 if x >= cutoff
 * out(x) = 0 otherwise
 *
 * @param name   Name of the output variable
 * @param in     Input variable
 * @param cutoff Cutoff value for step function
 * @return Output variable
 */
public SDVariable step(String name, SDVariable in, double cutoff) {
    return updateVariableNameAndReference(f().step(in, cutoff), name);
}
/**
 * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable acosh(SDVariable x) {
    // Delegate to the named overload with no output name
    return acosh(null, x);
}
/**
 * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable acosh(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.acosh(x), name);
}
/**
 * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable asinh(SDVariable x) {
    // Delegate to the named overload with no output name
    return asinh(null, x);
}
/**
 * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable asinh(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.asinh(x), name);
}
/**
 * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable atanh(SDVariable x) {
    // Delegate to the named overload with no output name
    return atanh(null, x);
}
/**
 * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable atanh(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.atanh(x), name);
}
/**
 * Elementwise exponent function: out = exp(x) = 2.71828...^x
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable exp(SDVariable x) {
    // Delegate to the named overload with no output name
    return exp(null, x);
}
/**
 * Elementwise exponent function: out = exp(x) = 2.71828...^x
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable exp(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.exp(x), name);
}
/**
 * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable rsqrt(SDVariable x) {
    // Delegate to the named overload with no output name
    return rsqrt(null, x);
}
/**
 * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable rsqrt(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.rsqrt(x), name);
}
/**
 * Elementwise exp(x) - 1 function: out = exp(x) - 1 = 2.71828...^x - 1
 * (Note: previous doc stated "1.0 - exp(x)", which does not match the standard
 * expm1 definition — expm1(x) = exp(x) - 1.)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable expm1(SDVariable x) {
    // Delegate to the named overload with no output name
    return expm1(null, x);
}
/**
 * Elementwise exp(x) - 1 function: out = exp(x) - 1 = 2.71828...^x - 1
 * (Standard expm1 semantics: expm1(x) = exp(x) - 1.)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable expm1(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.expm1(x), name);
}
/**
 * Natural logarithm of (1 + x), applied element-wise: out = log_e(1 + x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable log1p(SDVariable x) {
    return log1p(null, x);
}

/**
 * Natural logarithm of (1 + x), applied element-wise: out = log_e(1 + x)
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable log1p(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.log1p(x), name);
}
/**
 * Rounding function, applied element-wise: out = round(x).
 * Each value is rounded (up or down, depending on the value) to the nearest integer.
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable round(SDVariable x) {
    return round(null, x);
}

/**
 * Rounding function, applied element-wise: out = round(x).
 * Each value is rounded (up or down, depending on the value) to the nearest integer.
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable round(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.round(x), name);
}
/**
 * Element-wise "is infinite" check: isInfinite(x).
 * The output has the same shape as the input, containing 1 where the condition holds
 * and 0 elsewhere.
 *
 * @param x Input array
 * @return Output SDVariable of 0s and 1s reflecting where the condition is satisfied
 */
public SDVariable isInfinite(SDVariable x) {
    return isInfinite(null, x);
}

/**
 * Element-wise "is infinite" check: isInfinite(x).
 * The output has the same shape as the input, containing 1 where the condition holds
 * and 0 elsewhere.
 *
 * @param name Name for the output variable
 * @param x    Input array
 * @return Output SDVariable of 0s and 1s reflecting where the condition is satisfied
 */
public SDVariable isInfinite(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.isInfinite(x), name);
}
/**
 * Element-wise "is not-a-number" check: isNaN(x).
 * The output has the same shape as the input, containing 1 where the condition holds
 * and 0 elsewhere.
 *
 * @param x Input array
 * @return Output SDVariable of 0s and 1s reflecting where the condition is satisfied
 */
public SDVariable isNaN(SDVariable x) {
    return isNaN(null, x);
}

/**
 * Element-wise "is not-a-number" check: isNaN(x).
 * The output has the same shape as the input, containing 1 where the condition holds
 * and 0 elsewhere.
 *
 * @param name Name for the output variable
 * @param x    Input array
 * @return Output SDVariable of 0s and 1s reflecting where the condition is satisfied
 */
public SDVariable isNaN(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.isNaN(x), name);
}
/**
 * Element-wise "is finite" check: isFinite(x).
 * The output has the same shape as the input, containing 1 where the condition holds
 * and 0 elsewhere.
 *
 * @param x Input array
 * @return Output SDVariable of 0s and 1s reflecting where the condition is satisfied
 */
public SDVariable isFinite(SDVariable x) {
    return isFinite(null, x);
}

/**
 * Element-wise "is finite" check: isFinite(x).
 * The output has the same shape as the input, containing 1 where the condition holds
 * and 0 elsewhere.
 *
 * @param name Name for the output variable
 * @param x    Input array
 * @return Output SDVariable of 0s and 1s reflecting where the condition is satisfied
 */
public SDVariable isFinite(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.isFinite(x), name);
}
/**
 * Element-wise "is maximum" check: x == max(x).
 * The output has the same shape as the input, containing 1 where the condition holds
 * and 0 elsewhere.
 *
 * @param x Input array
 * @return Output SDVariable of 0s and 1s reflecting where the condition is satisfied
 */
public SDVariable isMax(SDVariable x) {
    return isMax(null, x);
}

/**
 * Element-wise "is maximum" check: x == max(x).
 * The output has the same shape as the input, containing 1 where the condition holds
 * and 0 elsewhere.
 *
 * @param name Name for the output variable
 * @param x    Input array
 * @return Output SDVariable of 0s and 1s reflecting where the condition is satisfied
 */
public SDVariable isMax(String name, SDVariable x) {
    return updateVariableNameAndReference(f().isMax(x), name);
}
/**
 * Checks whether the array is non-decreasing.
 * An array is non-decreasing when x[i] &lt;= x[i+1] for every valid i. For rank 2+ arrays,
 * values are compared in 'c' (row major) order.
 *
 * @param x Input variable
 * @return Scalar variable: 1 if non-decreasing, 0 otherwise
 */
public SDVariable isNonDecreasing(SDVariable x) {
    return isNonDecreasing(null, x);
}

/**
 * Checks whether the array is non-decreasing.
 * An array is non-decreasing when x[i] &lt;= x[i+1] for every valid i. For rank 2+ arrays,
 * values are compared in 'c' (row major) order.
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Scalar variable: 1 if non-decreasing, 0 otherwise
 */
public SDVariable isNonDecreasing(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.isNonDecreasing(x), name);
}
/**
 * Checks whether the array is strictly increasing.
 * An array is strictly increasing when x[i] &lt; x[i+1] for every valid i. For rank 2+ arrays,
 * values are compared in 'c' (row major) order.
 *
 * @param x Input variable
 * @return Scalar variable: 1 if strictly increasing, 0 otherwise
 */
public SDVariable isStrictlyIncreasing(SDVariable x) {
    return isStrictlyIncreasing(null, x);
}

/**
 * Checks whether the array is strictly increasing.
 * An array is strictly increasing when x[i] &lt; x[i+1] for every valid i. For rank 2+ arrays,
 * values are compared in 'c' (row major) order.
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Scalar variable: 1 if strictly increasing, 0 otherwise
 */
public SDVariable isStrictlyIncreasing(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.isStrictlyIncreasing(x), name);
}
/**
 * Is the input variable a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1
 *
 * @param x Input variable
 * @return Scalar variable with value 1
 */
public SDVariable isNumericTensor(SDVariable x) {
return isNumericTensor(null, x);
}
/**
 * Is the input variable a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1
 *
 * @param name Output variable name
 * @param x Input variable
 * @return Scalar variable with value 1
 */
public SDVariable isNumericTensor(String name, SDVariable x) {
SDVariable result = functionFactory.isNumericTensor(x);
return updateVariableNameAndReference(result, name);
}
/**
 * Element-wise conditional replacement:
 * out[i] = from[i] when condition(update[i]) holds, otherwise
 * out[i] = update[i]
 *
 * @param update    Source array
 * @param from      Replacement values array (used conditionally). Must be same shape as 'update' array
 * @param condition Condition to check against the update array elements
 * @return New array with values replaced where the condition is satisfied
 */
public SDVariable replaceWhere(SDVariable update, SDVariable from, Condition condition) {
    return replaceWhere(null, update, from, condition);
}

/**
 * Element-wise conditional replacement:
 * out[i] = from[i] when condition(update[i]) holds, otherwise
 * out[i] = update[i]
 *
 * @param name      Name for the output variable
 * @param update    Source array
 * @param from      Replacement values array (used conditionally). Must be same shape as 'update' array
 * @param condition Condition to check against the update array elements
 * @return New array with values replaced where the condition is satisfied
 */
public SDVariable replaceWhere(String name, SDVariable update, SDVariable from, Condition condition) {
    return updateVariableNameAndReference(f().replaceWhere(update, from, condition), name);
}
/**
 * Element-wise conditional replacement with a scalar:
 * out[i] = value when condition(update[i]) holds, otherwise
 * out[i] = update[i]
 *
 * @param update    Source array
 * @param value     Value to write to the output where the condition is satisfied
 * @param condition Condition to check against the update array elements
 * @return New array with values replaced where the condition is satisfied
 */
public SDVariable replaceWhere(SDVariable update, Number value, Condition condition) {
    return replaceWhere(null, update, value, condition);
}

/**
 * Element-wise conditional replacement with a scalar:
 * out[i] = value when condition(update[i]) holds, otherwise
 * out[i] = update[i]
 *
 * @param name      Name for the output variable
 * @param update    Source array
 * @param value     Value to write to the output where the condition is satisfied
 * @param condition Condition to check against the update array elements
 * @return New array with values replaced where the condition is satisfied
 */
public SDVariable replaceWhere(String name, SDVariable update, Number value, Condition condition) {
    return updateVariableNameAndReference(f().replaceWhere(update, value, condition), name);
}
/**
 * Natural (base e) logarithm, applied element-wise: out = log(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable log(SDVariable x) {
    return log(null, x);
}

/**
 * Natural (base e) logarithm, applied element-wise: out = log(x)
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable log(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.log(x), name);
}
/**
 * Logarithm with a specified base, applied element-wise: out = log_{base}(x)
 *
 * @param in   Input variable
 * @param base Logarithm base
 * @return Output variable
 */
public SDVariable log(SDVariable in, double base) {
    return log(null, in, base);
}

/**
 * Logarithm with a specified base, applied element-wise: out = log_{base}(x)
 *
 * @param name Name for the output variable
 * @param in   Input variable
 * @param base Logarithm base
 * @return Output variable
 */
public SDVariable log(String name, SDVariable in, double base) {
    return updateVariableNameAndReference(f().log(in, base), name);
}
/**
 * Log-sum-exp reduction, optionally along the given dimension(s).
 * Computes log(sum(exp(x)))
 *
 * @param input      Input variable
 * @param dimensions Optional dimensions to reduce along
 * @return Output variable
 */
public SDVariable logSumExp(SDVariable input, int... dimensions) {
    return logSumExp(null, input, dimensions);
}

/**
 * Log-sum-exp reduction, optionally along the given dimension(s).
 * Computes log(sum(exp(x)))
 *
 * @param name       Name for the output variable
 * @param input      Input variable
 * @param dimensions Optional dimensions to reduce along
 * @return Output variable
 */
public SDVariable logSumExp(String name, SDVariable input, int... dimensions) {
    return updateVariableNameAndReference(f().logSumExp(input, dimensions), name);
}
/**
 * Cube function, applied element-wise: out = x^3
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable cube(SDVariable x) {
    return cube(null, x);
}

/**
 * Cube function, applied element-wise: out = x^3
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable cube(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.cube(x), name);
}
/**
 * Power function, applied element-wise: out = x^value
 *
 * @param x     Input variable
 * @param value Power to raise each element to
 * @return Output variable
 */
public SDVariable pow(SDVariable x, double value) {
    return pow(null, x, value);
}

/**
 * Power function, applied element-wise: out = x^value
 *
 * @param name  Name for the output variable
 * @param x     Input variable
 * @param value Power to raise each element to
 * @return Output variable
 */
public SDVariable pow(String name, SDVariable x, double value) {
    return updateVariableNameAndReference(functionFactory.pow(x, value), name);
}
/**
 * Square root, applied element-wise: out = sqrt(x)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable sqrt(SDVariable x) {
    return sqrt(null, x);
}

/**
 * Square root, applied element-wise: out = sqrt(x)
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable sqrt(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.sqrt(x), name);
}
/**
 * Square function, applied element-wise: out = x^2
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable square(SDVariable x) {
    return square(null, x);
}

/**
 * Square function, applied element-wise: out = x^2
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable square(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.square(x), name);
}
/**
 * Floor function, applied element-wise: out = floor(x).
 * Rounds each value down to the nearest integer (if not already an integer).
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable floor(SDVariable x) {
    return floor(null, x);
}

/**
 * Floor function, applied element-wise: out = floor(x).
 * Rounds each value down to the nearest integer (if not already an integer).
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable floor(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.floor(x), name);
}
/**
 * Ceiling function, applied element-wise: out = ceil(x).
 * Rounds each value up to the nearest integer (if not already an integer).
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable ceil(SDVariable x) {
    return ceil(null, x);
}

/**
 * Ceiling function, applied element-wise: out = ceil(x).
 * Rounds each value up to the nearest integer (if not already an integer).
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable ceil(String name, SDVariable x) {
    return updateVariableNameAndReference(f().ceil(x), name);
}
/**
 * Element-wise value clipping:
 * out[i] = in[i] when clipValueMin &lt;= in[i] &lt;= clipValueMax
 * out[i] = clipValueMin when in[i] &lt; clipValueMin
 * out[i] = clipValueMax when in[i] &gt; clipValueMax
 *
 * @param x            Input variable
 * @param clipValueMin Minimum value for clipping
 * @param clipValueMax Maximum value for clipping
 * @return Output variable
 */
public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) {
    return clipByValue(null, x, clipValueMin, clipValueMax);
}

/**
 * Element-wise value clipping:
 * out[i] = in[i] when clipValueMin &lt;= in[i] &lt;= clipValueMax
 * out[i] = clipValueMin when in[i] &lt; clipValueMin
 * out[i] = clipValueMax when in[i] &gt; clipValueMax
 *
 * @param name         Name for the output variable
 * @param x            Input variable
 * @param clipValueMin Minimum value for clipping
 * @param clipValueMax Maximum value for clipping
 * @return Output variable
 */
public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, double clipValueMax) {
    return updateVariableNameAndReference(f().clipByValue(x, clipValueMin, clipValueMax), name);
}
/**
 * Clipping by L2 norm:
 * if l2Norm(x) &lt; clipValue the input is returned unmodified;
 * otherwise out[i] = in[i] * clipValue / l2Norm(in)
 *
 * @param x         Input variable
 * @param clipValue Clipping value (maximum l2 norm)
 * @return Output variable
 */
public SDVariable clipByNorm(SDVariable x, double clipValue) {
    return clipByNorm(null, x, clipValue);
}

/**
 * Clipping by L2 norm:
 * if l2Norm(x) &lt; clipValue the input is returned unmodified;
 * otherwise out[i] = in[i] * clipValue / l2Norm(in)
 *
 * @param name      Name for the output variable
 * @param x         Input variable
 * @param clipValue Clipping value (maximum l2 norm)
 * @return Output variable
 */
public SDVariable clipByNorm(String name, SDVariable x, double clipValue) {
    return updateVariableNameAndReference(f().clipByNorm(x, clipValue), name);
}
/**
 * Clipping by L2 norm, optionally along the given dimension(s):
 * if l2Norm(x,dimension) &lt; clipValue the input is returned unmodified;
 * otherwise out[i] = in[i] * clipValue / l2Norm(in, dimensions), where each value is
 * clipped according to the corresponding l2Norm along the specified dimensions.
 *
 * @param x          Input variable
 * @param clipValue  Clipping value (maximum l2 norm)
 * @param dimensions If not specified, all dimensions are used
 * @return Output variable
 */
public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) {
    return clipByNorm(null, x, clipValue, dimensions);
}

/**
 * Clipping by L2 norm, optionally along the given dimension(s):
 * if l2Norm(x,dimension) &lt; clipValue the input is returned unmodified;
 * otherwise out[i] = in[i] * clipValue / l2Norm(in, dimensions), where each value is
 * clipped according to the corresponding l2Norm along the specified dimensions.
 *
 * @param name       Name for the output variable
 * @param x          Input variable
 * @param clipValue  Clipping value (maximum l2 norm)
 * @param dimensions If not specified, all dimensions are used
 * @return Output variable
 */
public SDVariable clipByNorm(String name, SDVariable x, double clipValue, int... dimensions) {
    return updateVariableNameAndReference(f().clipByNorm(x, clipValue, dimensions), name);
}
/**
 * Rectified linear unit with a configurable cutoff, applied element-wise:
 * out[i] = in[i] when in[i] &gt;= cutoff, 0 otherwise
 *
 * @param x      Input variable
 * @param cutoff Cutoff value. Usually 0
 * @return Output variable
 */
public SDVariable relu(SDVariable x, double cutoff) {
    return relu(null, x, cutoff);
}

/**
 * Rectified linear unit with a configurable cutoff, applied element-wise:
 * out[i] = in[i] when in[i] &gt;= cutoff, 0 otherwise
 *
 * @param name   Name for the output variable
 * @param x      Input variable
 * @param cutoff Cutoff value. Usually 0
 * @return Output variable
 */
public SDVariable relu(String name, SDVariable x, double cutoff) {
    return updateVariableNameAndReference(functionFactory.relu(x, cutoff), name);
}
/**
 * "Rectified linear 6" with a configurable cutoff, applied element-wise:
 * out[i] = min(max(in, cutoff), 6)
 *
 * @param x      Input variable
 * @param cutoff Cutoff value. Usually 0
 * @return Output variable
 */
public SDVariable relu6(SDVariable x, double cutoff) {
    return relu6(null, x, cutoff);
}

/**
 * "Rectified linear 6" with a configurable cutoff, applied element-wise:
 * out[i] = min(max(in, cutoff), 6)
 *
 * @param name   Name for the output variable
 * @param x      Input variable
 * @param cutoff Cutoff value. Usually 0
 * @return Output variable
 */
public SDVariable relu6(String name, SDVariable x, double cutoff) {
    return updateVariableNameAndReference(functionFactory.relu6(x, cutoff), name);
}
/**
 * Softmax activation
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable softmax(SDVariable x) {
return softmax(null, x);
}
/**
 * Softmax activation
 *
 * @param name Output variable name
 * @param x Input variable
 * @return Output variable
 */
public SDVariable softmax(String name, SDVariable x) {
SDVariable result = functionFactory.softmax(x);
return updateVariableNameAndReference(result, name);
}
/**
 * Log-softmax activation
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable logSoftmax(SDVariable x) {
    return logSoftmax(null, x);
}

/**
 * Log-softmax activation
 *
 * @param name Name for the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable logSoftmax(String name, SDVariable x) {
    return updateVariableNameAndReference(f().logSoftmax(x), name);
}
/**
 * Element-wise SeLU function - Scaled exponential Linear Unit: see <a href="https://arxiv.org/abs/1706.02515">Self-Normalizing Neural Networks</a>
 * <p>
 * out[i] = scale * in[i] if in[i] &gt; 0, or scale * alpha * (exp(in[i])-1) if in[i] &lt;= 0
 * Uses default scale and alpha values.
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable selu(SDVariable x) {
return selu(null, x);
}
/**
 * Element-wise SeLU function - Scaled exponential Linear Unit: see <a href="https://arxiv.org/abs/1706.02515">Self-Normalizing Neural Networks</a>
 * <p>
 * out[i] = scale * in[i] if in[i] &gt; 0, or scale * alpha * (exp(in[i])-1) if in[i] &lt;= 0
 * Uses default scale and alpha values.
 *
 * @param name Name of the output variable
 * @param x Input variable
 * @return Output variable
 */
public SDVariable selu(String name, SDVariable x) {
SDVariable ret = f().selu(x);
return updateVariableNameAndReference(ret, name);
}
/**
 * Merge-add: combines any number of identically-shaped arrays via element-wise addition:
 * out = sum_i in[i]
 *
 * @param x Input variables
 * @return Output variable
 */
public SDVariable mergeAdd(SDVariable... x) {
    return mergeAdd(null, x);
}

/**
 * Merge-add: combines any number of identically-shaped arrays via element-wise addition:
 * out = sum_i in[i]
 *
 * @param name   Name for the output variable
 * @param inputs Input variables
 * @return Output variable
 */
public SDVariable mergeAdd(String name, SDVariable... inputs) {
    return updateVariableNameAndReference(f().mergeAdd(inputs), name);
}
/**
 * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
 * out = max_i in[i]
 *
 * @param x Input variables
 * @return Output variable
 */
public SDVariable mergeMax(SDVariable... x) {
return mergeMax(null, x);
}
/**
 * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
 * out = max_i in[i]
 *
 * @param name Name of the output variable
 * @param inputs Input variables
 * @return Output variable
 */
public SDVariable mergeMax(String name, SDVariable... inputs) {
SDVariable ret = f().mergeMax(inputs);
return updateVariableNameAndReference(ret, name);
}
/**
 * Merge-average: combines any number of identically-shaped arrays via element-wise mean:
 * out = mean_i in[i]
 *
 * @param inputs Input variables
 * @return Output variable
 */
public SDVariable mergeAvg(SDVariable... inputs) {
    return mergeAvg(null, inputs);
}

/**
 * Merge-average: combines any number of identically-shaped arrays via element-wise mean:
 * out = mean_i in[i]
 *
 * @param name   Name for the output variable
 * @param inputs Input variables
 * @return Output variable
 */
public SDVariable mergeAvg(String name, SDVariable... inputs) {
    return updateVariableNameAndReference(f().mergeAvg(inputs), name);
}
/**
 * @see #batchToSpace(String, SDVariable, int[], int[][])
 */
public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) {
    return batchToSpace(null, x, blocks, crops);
}

/**
 * Convolution 2d layer batch to space operation on 4d input.
 * Reduces the input batch dimension by rearranging data into larger spatial dimensions.
 *
 * @param name   Name for the output variable
 * @param x      Input variable. 4d input
 * @param blocks Block size, in the height/width dimension
 * @param crops  Optional 2d int[] array: values [[crop top, crop bottom], [crop left, crop right]]
 * @return Output variable
 * @see #spaceToBatch(String, SDVariable, int[], int[][])
 */
public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) {
    return updateVariableNameAndReference(f().batchToSpace(x, blocks, crops), name);
}
/**
 * Convolution 2d layer depth to space operation on 4d input.
 * Reduces the input channels dimension by rearranging data into larger spatial dimensions.
 * Example: input of shape [mb, 8, 2, 2] with block size 2 gives output of size
 * [mb, 8/(2*2), 2*2, 2*2] = [mb, 2, 4, 4]
 *
 * @param x          the input to the depth to space 2d operation - 4d activations in NCHW format
 *                   (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
 * @param blockSize  Block size, in the height/width dimension
 * @param dataFormat Data format: "NCHW" or "NHWC"
 * @return Output variable
 */
public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) {
    return depthToSpace(null, x, blockSize, dataFormat);
}

/**
 * Convolution 2d layer depth to space operation on 4d input.
 * Reduces the input channels dimension by rearranging data into larger spatial dimensions.
 * Example: input of shape [mb, 8, 2, 2] with block size 2 gives output of size
 * [mb, 8/(2*2), 2*2, 2*2] = [mb, 2, 4, 4]
 *
 * @param name       Name for the output variable
 * @param x          the input to the depth to space 2d operation - 4d activations in NCHW format
 *                   (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
 * @param blockSize  Block size, in the height/width dimension
 * @param dataFormat Data format: "NCHW" or "NHWC"
 * @return Output variable
 * @see #spaceToDepth(String, SDVariable, int, String)
 */
public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) {
    return updateVariableNameAndReference(f().depthToSpace(x, blockSize, dataFormat), name);
}
/**
 * @see #spaceToBatch(String, SDVariable, int[], int[][])
 */
public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) {
    return spaceToBatch(null, x, blocks, padding);
}

/**
 * Convolution 2d layer space to batch operation on 4d input.
 * Increases the input batch dimension by rearranging data from the spatial dimensions into the batch dimension.
 *
 * @param name    Name for the output variable
 * @param x       Input variable. 4d input
 * @param blocks  Block size, in the height/width dimension
 * @param padding Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]]
 * @return Output variable
 * @see #batchToSpace(String, SDVariable, int[], int[][])
 */
public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) {
    return updateVariableNameAndReference(f().spaceToBatch(x, blocks, padding), name);
}
/**
 * @see #spaceToDepth(String, SDVariable, int, String)
 */
public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) {
return spaceToDepth(null, x, blockSize, dataFormat);
}
/**
 * Convolution 2d layer space to depth operation on 4d input.
 * Increases input channels (reduced spatial dimensions) by rearranging data into a larger channels dimension
 * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 2*(2*2), 4/2, 4/2]
 * = [mb, 8, 2, 2]
 * (NOTE(review): the previous example output "[mb, 2, 4, 4]" appeared to be copy-pasted from
 * depthToSpace; corrected per the inverse relationship to that op - confirm against the op implementation)
 *
 * @param name Output variable name
 * @param x the input to the space to depth pooling 2d operation - 4d activations in NCHW format
 * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels])
 * @param blockSize Block size, in the height/width dimension
 * @param dataFormat Data format: "NCHW" or "NHWC"
 * @return Output variable
 * @see #depthToSpace(String, SDVariable, int, String)
 */
public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) {
SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat);
return updateVariableNameAndReference(ret, name);
}
/**
 * @see #dynamicPartition(String[], SDVariable, SDVariable, int)
 */
public SDVariable[] dynamicPartition(SDVariable x, SDVariable partitions, int numPartitions) {
    return dynamicPartition(null, x, partitions, numPartitions);
}

/**
 * Dynamically partition the input variable values into the specified number of partitions, using the indices.
 * Example:
 *
 * {@code input = [1,2,3,4,5]
 * numPartitions = 2
 * partitions = [1,0,0,1,0]
 * out[0] = [2,3,5]
 * out[1] = [1,4] }
 *
 * @param name          Names for the output variables. Length must be equal to numPartitions
 * @param x             Input variable
 * @param partitions    1D input with values 0 to numPartitions-1
 * @param numPartitions Number of partitions, &gt;= 1
 * @return Output variables (equal in number to numPartitions)
 */
public SDVariable[] dynamicPartition(String[] name, SDVariable x, SDVariable partitions, int numPartitions) {
    return updateVariableNamesAndReferences(f().dynamicPartition(x, partitions, numPartitions), name);
}
/**
 * @see #dynamicStitch(String, SDVariable[], SDVariable[])
 */
public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] x) {
    return dynamicStitch(null, indices, x);
}

/**
 * Dynamically merge the given input arrays into a single array, using the specified indices.
 *
 * @param name    Name for the output variable
 * @param indices Indices to use when merging. Must be &gt;= 1, same length as input variables
 * @param x       Input variables.
 * @return Merged output variable
 */
public SDVariable dynamicStitch(String name, SDVariable[] indices, SDVariable[] x) {
    return updateVariableNameAndReference(f().dynamicStitch(indices, x), name);
}
/**
 * @see #segmentMax(String, SDVariable, SDVariable)
 */
public SDVariable segmentMax(SDVariable data, SDVariable segmentIds){
    return segmentMax(null, data, segmentIds);
}

/**
 * Segment max operation.
 * Given data = [3, 6, 1, 4, 9, 2, 8] and
 * segmentIds = [0, 0, 1, 1, 1, 2, 2],
 * the output is [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
 *
 * @param name       Name for the output variable. May be null
 * @param data       Data to perform the segment max on
 * @param segmentIds Variable holding the segment IDs
 * @return Segment max output
 */
public SDVariable segmentMax(String name, SDVariable data, SDVariable segmentIds){
    return updateVariableNameAndReference(f().segmentMax(data, segmentIds), name);
}
/**
 * @see #segmentMin(String, SDVariable, SDVariable)
 */
public SDVariable segmentMin(SDVariable data, SDVariable segmentIds){
    return segmentMin(null, data, segmentIds);
}

/**
 * Segment min operation.
 * Given data = [3, 6, 1, 4, 9, 2, 8] and
 * segmentIds = [0, 0, 1, 1, 1, 2, 2],
 * the output is [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
 *
 * @param name       Name for the output variable. May be null
 * @param data       Data to perform the segment min on
 * @param segmentIds Variable holding the segment IDs
 * @return Segment min output
 */
public SDVariable segmentMin(String name, SDVariable data, SDVariable segmentIds){
    return updateVariableNameAndReference(f().segmentMin(data, segmentIds), name);
}
/**
 * @see #segmentMean(String, SDVariable, SDVariable)
 */
public SDVariable segmentMean(SDVariable data, SDVariable segmentIds){
    return segmentMean(null, data, segmentIds);
}

/**
 * Segment mean operation.
 * Given data = [3, 6, 1, 4, 9, 2, 8] and
 * segmentIds = [0, 0, 1, 1, 1, 2, 2],
 * the output is [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
 *
 * @param name       Name for the output variable. May be null
 * @param data       Data to perform the segment mean on
 * @param segmentIds Variable holding the segment IDs
 * @return Segment mean output
 */
public SDVariable segmentMean(String name, SDVariable data, SDVariable segmentIds){
    return updateVariableNameAndReference(f().segmentMean(data, segmentIds), name);
}
/**
 * @see #segmentProd(String, SDVariable, SDVariable)
 */
public SDVariable segmentProd(SDVariable data, SDVariable segmentIds){
    return segmentProd(null, data, segmentIds);
}

/**
 * Segment product operation.
 * Given data = [3, 6, 1, 4, 9, 2, 8] and
 * segmentIds = [0, 0, 1, 1, 1, 2, 2],
 * the output is [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8)]
 *
 * @param name       Name for the output variable. May be null
 * @param data       Data to perform the segment product on
 * @param segmentIds Variable holding the segment IDs
 * @return Segment product output
 */
public SDVariable segmentProd(String name, SDVariable data, SDVariable segmentIds){
    return updateVariableNameAndReference(f().segmentProd(data, segmentIds), name);
}
/**
 * @see #segmentSum(String, SDVariable, SDVariable)
 */
public SDVariable segmentSum(SDVariable data, SDVariable segmentIds){
    return segmentSum(null, data, segmentIds);
}

/**
 * Segment sum operation.
 * Given data = [3, 6, 1, 4, 9, 2, 8] and
 * segmentIds = [0, 0, 1, 1, 1, 2, 2],
 * the output is [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
 *
 * @param name       Name for the output variable. May be null
 * @param data       Data to perform the segment sum on
 * @param segmentIds Variable holding the segment IDs
 * @return Segment sum output
 */
public SDVariable segmentSum(String name, SDVariable data, SDVariable segmentIds){
    return updateVariableNameAndReference(f().segmentSum(data, segmentIds), name);
}
/**
 * 2D dilation operation.
 * NOTE(review): exact semantics (morphological image dilation vs. other) are defined by the
 * underlying op returned by the factory - confirm against the op implementation.
 *
 * @param df         Input variable
 * @param weights    Weights (filter) for the dilation
 * @param strides    Stride values
 * @param rates      Dilation rate values
 * @param isSameMode Presumably: true for "SAME" padding, false for "VALID" - TODO confirm
 * @return Output variable
 * @see #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)
 */
public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides,
                             int[] rates, boolean isSameMode) {
    return dilation2D(null, df, weights, strides, rates, isSameMode);
}

/**
 * 2D dilation operation.
 * NOTE(review): exact semantics are defined by the underlying op returned by the factory -
 * confirm against the op implementation.
 *
 * @param name       Name of the output variable. May be null
 * @param df         Input variable
 * @param weights    Weights (filter) for the dilation
 * @param strides    Stride values
 * @param rates      Dilation rate values
 * @param isSameMode Presumably: true for "SAME" padding, false for "VALID" - TODO confirm
 * @return Output variable
 */
public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides,
                             int[] rates, boolean isSameMode) {
    SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode);
    return updateVariableNameAndReference(ret, name);
}
/**
 * Returns the shape of the given variable as a rank 1 (vector) variable.
 *
 * @param input Input variable
 * @return 1D variable whose contents are the shape of the input
 */
public SDVariable shape(SDVariable input) {
    return shape(null, input);
}

/**
 * Returns the shape of the given variable as a rank 1 (vector) variable.
 *
 * @param name  Name of the output variable
 * @param input Input variable
 * @return 1D variable whose contents are the shape of the input
 */
public SDVariable shape(String name, SDVariable input) {
    return updateVariableNameAndReference(f().shape(input), name);
}
/**
 * Returns the total number of elements (i.e., prod(shape)) of a variable, as a 0D scalar variable.
 *
 * @param in Input variable
 * @return Scalar (0D) variable holding the element count of the input
 */
public SDVariable size(SDVariable in){
    return size(null, in);
}

/**
 * Returns the total number of elements (i.e., prod(shape)) of a variable, as a 0D scalar variable.
 *
 * @param name Name of the output variable
 * @param in   Input variable
 * @return Scalar (0D) variable holding the element count of the input
 */
public SDVariable size(String name, SDVariable in){
    return updateVariableNameAndReference(f().size(in), name);
}
/**
 * Returns the rank (number of dimensions, i.e., length(shape)) of a variable as a 0D scalar variable.
 *
 * @param in Input variable
 * @return Scalar (0D) variable holding the rank of the input
 */
public SDVariable rank(SDVariable in) {
    return rank(null, in);
}

/**
 * Returns the rank (number of dimensions, i.e., length(shape)) of a variable as a 0D scalar variable.
 *
 * @param name Name of the output variable
 * @param in   Input variable
 * @return Scalar (0D) variable holding the rank of the input
 */
public SDVariable rank(String name, SDVariable in) {
    return updateVariableNameAndReference(f().rank(in), name);
}
/**
 * @see #sizeAt(String, SDVariable, int)
 */
public SDVariable sizeAt(SDVariable in, int dimension){
    return sizeAt(null, in, dimension);
}

/**
 * Returns a rank 0 (scalar) variable for the size of the given dimension.
 * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30.
 *
 * @param name      Name of the output variable
 * @param in        Input variable
 * @param dimension Dimension whose size is returned
 * @return Scalar SDVariable holding the size of the requested dimension
 */
public SDVariable sizeAt(String name, SDVariable in, int dimension){
    return updateVariableNameAndReference(f().sizeAt(in, dimension), name);
}
/**
 * @see #cross(String, SDVariable, SDVariable)
 */
public SDVariable cross(SDVariable a, SDVariable b) {
    return cross(null, a, b);
}

/**
 * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
 * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3
 *
 * @param name Name of the output variable. May be null
 * @param a    First input
 * @param b    Second input
 * @return Element-wise cross product
 */
public SDVariable cross(String name, SDVariable a, SDVariable b) {
    SDVariable ret = f().cross(a, b);
    return updateVariableNameAndReference(ret, name);
}
/**
 * @see #gather(String, SDVariable, int[], int)
 */
public SDVariable gather(SDVariable df, int[] indices, int axis) {
    return gather(null, df, indices, axis);
}

/**
 * Gather slices from the input variable, with the indices given as fixed int[] values.
 * Output shape matches the input shape except along the given axis, whose size equals indices.length.
 *
 * @param name    Name of the output variable
 * @param df      Input variable
 * @param indices Indices of the slices to gather
 * @param axis    Axis the indices refer to
 * @return Output variable containing the gathered slices
 */
public SDVariable gather(String name, SDVariable df, int[] indices, int axis) {
    return updateVariableNameAndReference(f().gather(df, indices, axis), name);
}
/**
 * @see #gather(String, SDVariable, SDVariable, int)
 */
public SDVariable gather(SDVariable df, SDVariable indices, int axis) {
    return gather(null, df, indices, axis);
}

/**
 * Gather slices from the input variable, with the indices given as a dynamic SDVariable.
 * Output shape matches the input shape except along the given axis, whose size equals the number of indices.
 *
 * @param name    Name of the output variable
 * @param df      Input variable
 * @param indices Indices of the slices to gather. Rank 0 or 1 input
 * @param axis    Axis the indices refer to
 * @return Output variable containing the gathered slices
 */
public SDVariable gather(String name, SDVariable df, SDVariable indices, int axis) {
    return updateVariableNameAndReference(f().gather(df, indices, axis), name);
}
/**
 * Gather using full index tuples held in the indices variable.
 * NOTE(review): presumably analogous to TensorFlow's gather_nd (indices select elements/slices by
 * multi-dimensional coordinates) - confirm against the underlying op implementation.
 *
 * @param df      Input variable
 * @param indices Index variable
 * @return Output variable
 * @see #gatherNd(String, SDVariable, SDVariable)
 */
public SDVariable gatherNd(SDVariable df, SDVariable indices) {
    return gatherNd(null, df, indices);
}

/**
 * Gather using full index tuples held in the indices variable.
 * NOTE(review): presumably analogous to TensorFlow's gather_nd - confirm against the underlying op.
 *
 * @param name    Name of the output variable. May be null
 * @param df      Input variable
 * @param indices Index variable
 * @return Output variable
 */
public SDVariable gatherNd(String name, SDVariable df, SDVariable indices) {
    SDVariable ret = f().gatherNd(df, indices);
    return updateVariableNameAndReference(ret, name);
}
/**
 * @see #repeat(String, SDVariable, int)
 */
public SDVariable repeat(SDVariable df, int axis) {
    return repeat(null, df, axis);
}

/**
 * Repeat the input variable along the specified axis.
 * NOTE(review): the repeat count and resulting shape are determined by the underlying op
 * returned by the factory - confirm against the op implementation.
 *
 * @param name Name of the output variable. May be null
 * @param df   Input variable
 * @param axis Axis to repeat along
 * @return Output variable
 */
public SDVariable repeat(String name, SDVariable df, int axis) {
    SDVariable ret = f().repeat(df, axis);
    return updateVariableNameAndReference(ret, name);
}
/**
 * @see #stack(String, int, SDVariable...)
 */
public SDVariable stack(int axis, SDVariable... values) {
    return stack(null, axis, values);
}

/**
 * Stack N rank-X variables into a single rank X+1 variable.
 * With inputs of shape [a,b,c] the output shape is:
 * axis = 0: [N,a,b,c]
 * axis = 1: [a,N,b,c]
 * axis = 2: [a,b,N,c]
 * axis = 3: [a,b,c,N]
 *
 * @param name   Name of the output variable
 * @param axis   Axis to stack along
 * @param values Variables to stack; all must have the same shape
 * @return Output variable
 * @see #unstack(String[], SDVariable, int, int)
 */
public SDVariable stack(String name, int axis, SDVariable... values) {
    return updateVariableNameAndReference(f().stack(values, axis), name);
}
/**
 * @see #stack(String, int, SDVariable...)
 */
public SDVariable parallel_stack(SDVariable[] values) {
    return parallel_stack(null, values);
}

/**
 * Parallel version of stack - see {@link #stack(String, int, SDVariable...)}.
 *
 * @param name   Name of the output variable
 * @param values Variables to stack; all must have the same shape
 * @return Output variable
 */
public SDVariable parallel_stack(String name, SDVariable[] values) {
    return updateVariableNameAndReference(f().parallel_stack(values), name);
}
/**
 * @see #unstack(String[], SDVariable, int, int)
 */
public SDVariable[] unstack(SDVariable value, int axis) {
    return unstack(null, value, axis);
}

/**
 * Unstack along an axis without an explicit output count.
 *
 * @param names Names for the output variables. May be null
 * @param value Variable to unstack
 * @param axis  Axis to unstack along
 * @return Output variables
 * @see #unstack(String[], SDVariable, int, int)
 */
public SDVariable[] unstack(String[] names, SDVariable value, int axis) {
    return updateVariableNamesAndReferences(f().unstack(value, axis), names);
}
/**
 * @see #unstack(String[], SDVariable, int, int)
 */
public SDVariable[] unstack(SDVariable value, int axis, int num) {
    return unstack(null, value, axis, num);
}

/**
 * Unstack a rank-X variable into N rank X-1 variables, taking slices along the given axis.
 * With an input of shape [a,b,c] each output has shape:
 * axis = 0: [b,c]
 * axis = 1: [a,c]
 * axis = 2: [a,b]
 *
 * @param names Names for the output variables. May be null
 * @param value Variable to unstack
 * @param axis  Axis to unstack along
 * @param num   Number of output variables
 * @return Output variables
 * @see #stack(String, int, SDVariable...)
 */
public SDVariable[] unstack(String[] names, SDVariable value, int axis, int num) {
    return updateVariableNamesAndReferences(f().unstack(value, axis, num), names);
}
/**
 * Element-wise Gaussian error function: out = erf(in)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable erf(SDVariable x) {
    return erf(null, x);
}

/**
 * Element-wise Gaussian error function: out = erf(in)
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable erf(String name, SDVariable x) {
    return updateVariableNameAndReference(f().erf(x), name);
}
/**
 * Element-wise complementary Gaussian error function: out = erfc(in) = 1 - erf(in)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable erfc(SDVariable x) {
    return erfc(null, x);
}

/**
 * Element-wise complementary Gaussian error function: out = erfc(in) = 1 - erf(in)
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable erfc(String name, SDVariable x) {
    return updateVariableNameAndReference(f().erfc(x), name);
}
/**
 * @see #diag(String, SDVariable)
 */
public SDVariable diag(SDVariable x) {
    return diag(null, x);
}

/**
 * Returns a variable whose diagonal holds the input values, with zeros everywhere else.
 * For example, for input = [1,2,3] the output is:
 * [ 1, 0, 0]
 * [ 0, 2, 0]
 * [ 0, 0, 3]
 *
 * Higher input ranks are also supported: for an input of shape [a,...,R-1],
 * output[i,...,k,i,...,k] = input[i,...,k] - i.e., rank R input gives rank 2R output.
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable diag(String name, SDVariable x) {
    return updateVariableNameAndReference(f().diag(x), name);
}
/**
 * @see #diagPart(String, SDVariable)
 */
public SDVariable diagPart(SDVariable x) {
    return diagPart(null, x);
}

/**
 * Extract the diagonal part from the input array.
 * If input is
 * [ 1, 0, 0]
 * [ 0, 2, 0]
 * [ 0, 0, 3]
 * then output is [1, 2, 3].
 * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
 *
 * @param name Name of the output variable. May be null
 * @param x    Input variable
 * @return Diagonal part of the input
 * @see #diag(String, SDVariable)
 */
public SDVariable diagPart(String name, SDVariable x) {
    SDVariable ret = f().diagPart(x);
    return updateVariableNameAndReference(ret, name);
}
/**
 * @see #setDiag(String, SDVariable, SDVariable)
 */
public SDVariable setDiag(SDVariable in, SDVariable diag) {
    return setDiag(null, in, diag);
}

/**
 * Replace the diagonal of the input with the given values.
 * If input is
 * [ a, b, c]
 * [ d, e, f]
 * [ g, h, i]
 * and diag = [ 1, 2, 3] then output is
 * [ 1, b, c]
 * [ d, 2, f]
 * [ g, h, 3]
 *
 * @param name Name of the output variable
 * @param in   Input variable
 * @param diag Values to place on the diagonal
 * @return Output variable
 */
public SDVariable setDiag(String name, SDVariable in, SDVariable diag) {
    return updateVariableNameAndReference(f().setDiag(in, diag), name);
}
/**
 * @see #oneHot(String, SDVariable, int)
 */
public SDVariable oneHot(SDVariable indices, int depth) {
    return oneHot(null, indices, depth, -1, 1.00, 0.00);
}

/**
 * @see #oneHot(String, SDVariable, int, int, double, double)
 */
public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) {
    return oneHot(null, indices, depth, axis, on, off);
}

/**
 * Convert the array to a one-hot array with values 0 and 1 for each entry
 * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
 * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0
 *
 * @param name    Output variable name
 * @param indices Indices - value 0 to depth-1
 * @param depth   Number of classes
 * @return Output variable
 * @see #oneHot(SDVariable, int, int, double, double)
 */
public SDVariable oneHot(String name, SDVariable indices, int depth) {
    return oneHot(name, indices, depth, -1, 1.00, 0.00);
}

/**
 * Convert the array to a one-hot array with values {@code on} and {@code off} for each entry
 * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
 * with {@code out[i, ..., j, in[i,...,j]] = on} with other values being set to {@code off}
 *
 * @param name    Output variable name
 * @param indices Indices - value 0 to depth-1
 * @param depth   Number of classes
 * @param axis    Axis along which the one-hot dimension is added
 * @param on      Value to set for the "hot" entries
 * @param off     Value to set for all other entries
 * @return Output variable
 */
public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off) {
    SDVariable ret = f().onehot(indices, depth, axis, on, off);
    return updateVariableNameAndReference(ret, name);
}
/**
 * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
 *
 * @param a Input variable
 * @return Output variable
 */
public SDVariable reciprocal(SDVariable a) {
    return reciprocal(null, a);
}

/**
 * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
 *
 * @param name Name of the output variable
 * @param a    Input variable
 * @return Output variable
 */
public SDVariable reciprocal(String name, SDVariable a) {
    return updateVariableNameAndReference(f().reciprocal(a), name);
}
/**
 * Intended for internal/developer use
 */
public SDVariable gradientBackwardsMarker(SDVariable x) {
    // Auto-generate a variable name from the op name when none is supplied
    return gradientBackwardsMarker(generateNewVarName(new GradientBackwardsMarker().opName(), 0), x);
}

/**
 * Intended for internal/developer use
 */
public SDVariable gradientBackwardsMarker(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.gradientBackwardsMarker(x), name);
}
/**
 * Element-wise hard tanh function:
 * out[i] = -1 if in[i] <= -1
 * out[i] = in[i] if -1 < in[i] < 1
 * out[i] = 1 if in[i] >= 1
 *
 * @param in Input variable
 * @return Output variable
 */
public SDVariable hardTanh(SDVariable in) {
    return hardTanh(null, in);
}

/**
 * Element-wise hard tanh function:
 * out[i] = -1 if in[i] <= -1
 * out[i] = in[i] if -1 < in[i] < 1
 * out[i] = 1 if in[i] >= 1
 *
 * @param name Output variable name
 * @param in   Input variable
 * @return Output variable
 */
public SDVariable hardTanh(String name, SDVariable in) {
    SDVariable result = functionFactory.hardTanh(in);
    return updateVariableNameAndReference(result, name);
}
/**
 * Element-wise hard sigmoid function:
 * out[i] = 0 if in[i] <= -2.5
 * out[i] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
 * out[i] = 1 if in[i] >= 2.5
 *
 * @param in Input variable
 * @return Output variable
 */
public SDVariable hardSigmoid(SDVariable in) {
    return hardSigmoid(null, in);
}

/**
 * Element-wise hard sigmoid function:
 * out[i] = 0 if in[i] <= -2.5
 * out[i] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
 * out[i] = 1 if in[i] >= 2.5
 *
 * @param name Name of the output variable
 * @param in   Input variable
 * @return Output variable
 */
public SDVariable hardSigmoid(String name, SDVariable in) {
    SDVariable ret = f().hardSigmoid(in);
    return updateVariableNameAndReference(ret, name);
}
/**
 * Derivative (dOut/dIn) of the element-wise hard tanh function - {@link #hardTanh(SDVariable)}
 *
 * @param x Input
 * @return Output variable
 */
public SDVariable hardTanhDerivative(SDVariable x) {
    return hardTanhDerivative(null, x);
}

/**
 * Derivative (dOut/dIn) of the element-wise hard tanh function - {@link #hardTanh(SDVariable)}
 *
 * @param name Output variable name
 * @param x    Input
 * @return Output variable
 */
public SDVariable hardTanhDerivative(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.hardTanhDerivative(x), name);
}
/**
 * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable sigmoid(SDVariable x) {
    return sigmoid(null, x);
}

/**
 * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable sigmoid(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.sigmoid(x), name);
}
/**
 * Element-wise sigmoid function derivative: dL/dIn given the input and dL/dOut
 *
 * @param x   Input variable
 * @param wrt Gradient at the output - dL/dOut. Must have the same shape as the input
 * @return Output variable
 */
public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) {
    return sigmoidDerivative(null, x, wrt);
}

/**
 * Element-wise sigmoid function derivative: dL/dIn given the input and dL/dOut
 *
 * @param name Output variable name
 * @param x    Input variable
 * @param wrt  Gradient at the output - dL/dOut. Must have the same shape as the input
 * @return Output variable
 */
public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) {
    return updateVariableNameAndReference(functionFactory.sigmoidDerivative(x, wrt), name);
}
/**
 * Element-wise log-sigmoid function: out[i] = log(sigmoid(in[i]))
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable logSigmoid(SDVariable x) {
    return logSigmoid(null, x);
}

/**
 * Element-wise log-sigmoid function: out[i] = log(sigmoid(in[i]))
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable logSigmoid(String name, SDVariable x) {
    return updateVariableNameAndReference(f().logSigmoid(x), name);
}
/**
 * Element-wise sign (signum) function:
 * out = -1 if in < 0
 * out = 0 if in = 0
 * out = 1 if in > 0
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable sign(SDVariable x) {
    return sign(null, x);
}

/**
 * Element-wise sign (signum) function:
 * out = -1 if in < 0
 * out = 0 if in = 0
 * out = 1 if in > 0
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable sign(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.sign(x), name);
}
/**
 * Element-wise softsign function: out = x / (abs(x) + 1)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable softsign(SDVariable x) {
    return softsign(null, x);
}

/**
 * Element-wise softsign function: out = x / (abs(x) + 1)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable softsign(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.softsign(x), name);
}
/**
 * Element-wise derivative (dOut/dIn) of the softsign function {@link #softsign(SDVariable)}
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable softsignDerivative(SDVariable x) {
    return softsignDerivative(null, x);
}

/**
 * Element-wise derivative (dOut/dIn) of the softsign function {@link #softsign(SDVariable)}
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable softsignDerivative(String name, SDVariable x) {
    SDVariable result = functionFactory.softsignDerivative(x);
    return updateVariableNameAndReference(result, name);
}
/**
 * Element-wise softplus function: out = log(exp(x) + 1)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable softplus(SDVariable x) {
    return softplus(null, x);
}

/**
 * Element-wise softplus function: out = log(exp(x) + 1)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable softplus(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.softplus(x), name);
}
/**
 * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
 * See: https://arxiv.org/abs/1710.05941
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable swish(SDVariable x) {
    return swish(null, x);
}

/**
 * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
 * See: https://arxiv.org/abs/1710.05941
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable swish(String name, SDVariable x) {
    return updateVariableNameAndReference(f().swish(x), name);
}
/**
 * Element-wise exponential linear unit (ELU) function:
 * out = x if x > 0
 * out = a * (exp(x) - 1) if x <= 0
 * with constant a = 1.0
 *
 * See: http://arxiv.org/abs/1511.07289
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable elu(SDVariable x) {
    return elu(null, x);
}

/**
 * Element-wise exponential linear unit (ELU) function:
 * out = x if x > 0
 * out = a * (exp(x) - 1) if x <= 0
 * with constant a = 1.0
 *
 * See: http://arxiv.org/abs/1511.07289
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable elu(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.elu(x), name);
}
/**
 * Element-wise derivative of the exponential linear unit (ELU) function, dOut/dIn given input.
 * {@link #elu(SDVariable)}
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable eluDerivative(SDVariable x) {
    return eluDerivative(null, x);
}

/**
 * Element-wise derivative of the exponential linear unit (ELU) function, dOut/dIn given input.
 * {@link #elu(SDVariable)}
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable eluDerivative(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.eluDerivative(x), name);
}
/**
 * Element-wise leaky ReLU function:
 * out = x if x >= 0.0
 * out = alpha * x if x < 0.0
 * Alpha value is most commonly set to 0.01
 *
 * @param x     Input variable
 * @param alpha Scaling factor (slope) for negative input values - usually a small value such as 0.01
 * @return Output variable
 */
public SDVariable leakyRelu(SDVariable x, double alpha) {
    return leakyRelu(null, x, alpha);
}

/**
 * Element-wise leaky ReLU function:
 * out = x if x >= 0.0
 * out = alpha * x if x < 0.0
 * Alpha value is most commonly set to 0.01
 *
 * @param name  Output variable name
 * @param x     Input variable
 * @param alpha Scaling factor (slope) for negative input values - usually a small value such as 0.01
 * @return Output variable
 */
public SDVariable leakyRelu(String name, SDVariable x, double alpha) {
    SDVariable result = functionFactory.leakyRelu(x, alpha);
    return updateVariableNameAndReference(result, name);
}
/**
 * Leaky ReLU derivative: dOut/dIn given input.
 * See {@link #leakyRelu(String, SDVariable, double)}
 *
 * @param name  Output variable name
 * @param x     Input variable
 * @param alpha Alpha value
 * @return Output variable
 */
public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) {
    SDVariable result = functionFactory.leakyReluDerivative(x, alpha);
    return updateVariableNameAndReference(result, name);
}
/**
 * Full array mean reduction operation
 *
 * @param x Input variable
 * @return Output variable - scalar
 */
public SDVariable mean(SDVariable x) {
    return mean(null, x);
}

/**
 * Mean (average) array reduction operation, optionally along specified dimensions
 *
 * @param x         Input variable
 * @param dimension Dimensions to reduce over. If no dimensions are given, a full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable mean(SDVariable x, int... dimension) {
    return mean(null, x, dimension);
}

/**
 * Mean (average) array reduction operation, optionally along specified dimensions
 *
 * @param name      Output variable name
 * @param x         Input variable
 * @param dimension Dimensions to reduce over. If no dimensions are given, a full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable mean(String name, SDVariable x, int... dimension) {
    return mean(name, x, false, dimension);
}

/**
 * Mean (average) array reduction operation, optionally along specified dimensions.
 * If keepDims = true, the output has the same rank as the input, with the reduced dimensions
 * having size 1 - useful for later broadcast operations (such as subtracting the mean along a dimension).
 * Example: for input shape [a,b,c] and dimensions=[1], the output shape is:
 * keepDims = true: [a,1,c]
 * keepDims = false: [a,c]
 *
 * @param name      Output variable name
 * @param x         Input variable
 * @param keepDims  If true: keep the reduced dimensions (as size 1). If false: remove them
 * @param dimension Dimensions to reduce over. If no dimensions are given, a full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable mean(String name, SDVariable x, boolean keepDims, int... dimension) {
    return updateVariableNameAndReference(functionFactory.mean(x, keepDims, dimension), name);
}
/**
 * @see #standardDeviation(String, SDVariable, boolean, int...)
 */
public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, int... dimensions) {
    return standardDeviation(null, x, biasCorrected, dimensions);
}

/**
 * Standard deviation array reduction operation, optionally along specified dimensions
 *
 * @param name          Output variable name
 * @param x             Input variable
 * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev)
 * @param dimensions    Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, int... dimensions) {
    return standardDeviation(name, x, biasCorrected, false, dimensions);
}

/**
 * Standard deviation array reduction operation, optionally along specified dimensions
 * Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
 * the mean along a dimension).
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
 * keepDims = true: [a,1,c]
 * keepDims = false: [a,c]
 *
 * @param name          Output variable name
 * @param x             Input variable
 * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev)
 * @param keepDims      If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
 * @param dimensions    Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, boolean keepDims, int... dimensions) {
    SDVariable result = functionFactory.std(x, biasCorrected, keepDims, dimensions);
    return updateVariableNameAndReference(result, name);
}
/**
 * @see #variance(String, SDVariable, boolean, int...)
 */
public SDVariable variance(SDVariable x, boolean biasCorrected, int... dimensions) {
    return variance(null, x, biasCorrected, dimensions);
}

/**
 * Variance array reduction operation, optionally along specified dimensions
 *
 * @param name          Output variable name
 * @param x             Input variable
 * @param biasCorrected If true: divide by (N-1) (i.e., sample variance). If false: divide by N (population variance)
 * @param dimensions    Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable variance(String name, SDVariable x, boolean biasCorrected, int... dimensions) {
    return variance(name, x, biasCorrected, false, dimensions);
}

/**
 * Variance array reduction operation, optionally along specified dimensions
 * Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
 * the mean along a dimension).
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
 * keepDims = true: [a,1,c]
 * keepDims = false: [a,c]
 *
 * @param name          Output variable name
 * @param x             Input variable
 * @param biasCorrected If true: divide by (N-1) (i.e., sample variance). If false: divide by N (population variance)
 * @param keepDims      If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
 * @param dimensions    Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable variance(String name, SDVariable x, boolean biasCorrected, boolean keepDims, int... dimensions) {
    SDVariable result = functionFactory.variance(x, biasCorrected, keepDims, dimensions);
    return updateVariableNameAndReference(result, name);
}
/**
 * Entropy reduction: -sum(x * log(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce on (null/empty for full array)
 * @return Output variable
 */
public SDVariable entropy(SDVariable in, int... dimensions) {
    return entropy(null, in, dimensions);
}

/**
 * Entropy reduction: -sum(x * log(x))
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce on (null/empty for full array)
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable entropy(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().entropy(in, dimensions), name);
}
/**
 * Log entropy reduction: log(-sum(x * log(x)))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce on (null for full array)
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable logEntropy(SDVariable in, int... dimensions) {
    return logEntropy(null, in, dimensions);
}

/**
 * Log entropy reduction: log(-sum(x * log(x)))
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce on (null for full array)
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable logEntropy(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().logEntropy(in, dimensions), name);
}
/**
 * Shannon entropy reduction: -sum(x * log2(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce on (null/empty for full array)
 * @return Output variable
 */
public SDVariable shannonEntropy(SDVariable in, int... dimensions) {
    return shannonEntropy(null, in, dimensions);
}

/**
 * Shannon entropy reduction: -sum(x * log2(x))
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce on (null/empty for full array)
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable shannonEntropy(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().shannonEntropy(in, dimensions), name);
}
/**
 * Sum array reduction operation, optionally along the specified dimensions.
 *
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable sum(SDVariable x, int... dimensions) {
    return sum(null, x, dimensions);
}
/**
 * Sum array reduction operation, optionally along the specified dimensions.
 *
 * @param name       Name for the output variable
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable sum(String name, SDVariable x, int... dimensions) {
    return sum(name, x, false, dimensions);
}
/**
 * Sum array reduction operation, optionally along the specified dimensions.
 * If keepDims is true the reduced dimensions are retained with size 1, so the output has the same
 * rank as the input; this is useful for later broadcast operations (such as subtracting the mean
 * along a dimension). For an input of shape [a,b,c] reduced over dimensions=[1]:
 * keepDims = true gives shape [a,1,c]; keepDims = false gives shape [a,c].
 *
 * @param name       Name for the output variable
 * @param x          Input variable
 * @param keepDims   If true, retain the reduced dimensions with size 1; if false, remove them
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false,
 * or of rank (input rank) if keepDims = true
 */
public SDVariable sum(String name, SDVariable x, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.sum(x, keepDims, dimensions), name);
}
/**
 * Sum array reduction with no explicit output name.
 *
 * @see #sum(String, SDVariable, boolean, int...)
 */
public SDVariable sum(SDVariable x, boolean keepDims, int... dimensions) {
    return sum(null, x, keepDims, dimensions);
}
/**
 * Product array reduction operation, optionally along the specified dimensions.
 *
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable prod(SDVariable x, int... dimensions) {
    return prod(null, x, dimensions);
}
/**
 * Product array reduction operation, optionally along the specified dimensions.
 *
 * @param name       Name for the output variable
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable prod(String name, SDVariable x, int... dimensions) {
    return prod(name, x, false, dimensions);
}
/**
 * Product array reduction operation, optionally along the specified dimensions.
 * If keepDims is true the reduced dimensions are retained with size 1, so the output has the same
 * rank as the input; this is useful for later broadcast operations. For an input of shape [a,b,c]
 * reduced over dimensions=[1]: keepDims = true gives shape [a,1,c]; keepDims = false gives [a,c].
 *
 * @param name       Name for the output variable
 * @param x          Input variable
 * @param keepDims   If true, retain the reduced dimensions with size 1; if false, remove them
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable prod(String name, SDVariable x, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.prod(x, keepDims, dimensions), name);
}
/**
 * Element-wise scalar maximum operation: out = max(in, value)
 *
 * @param in    Input variable
 * @param value Scalar value to compare against
 * @return Output variable
 */
public SDVariable scalarMax(SDVariable in, Number value) {
    return scalarMax(null, in, value);
}
/**
 * Element-wise scalar maximum operation: out = max(in, value)
 *
 * @param name  Name for the output variable
 * @param in    Input variable
 * @param value Scalar value to compare against
 * @return Output variable
 */
public SDVariable scalarMax(String name, SDVariable in, Number value) {
    return updateVariableNameAndReference(f().scalarMax(in, value), name);
}
/**
 * Element-wise scalar minimum operation: out = min(in, value)
 *
 * @param in    Input variable
 * @param value Scalar value to compare against
 * @return Output variable
 */
public SDVariable scalarMin(SDVariable in, Number value) {
    return scalarMin(null, in, value);
}
/**
 * Element-wise scalar minimum operation: out = min(in, value)
 *
 * @param name  Name for the output variable
 * @param in    Input variable
 * @param value Scalar value to compare against
 * @return Output variable
 */
public SDVariable scalarMin(String name, SDVariable in, Number value) {
    return updateVariableNameAndReference(f().scalarMin(in, value), name);
}
/**
 * Element-wise scalar floor modulus operation: out = floorMod(in, value),
 * i.e., the remainder after division by 'value'.
 *
 * @param in    Input variable
 * @param value Scalar divisor
 * @return Output variable
 */
public SDVariable scalarFloorMod(SDVariable in, Number value) {
    return scalarFloorMod(null, in, value);
}
/**
 * Element-wise scalar floor modulus operation: out = floorMod(in, value),
 * i.e., the remainder after division by 'value'.
 *
 * @param name  Name for the output variable
 * @param in    Input variable
 * @param value Scalar divisor
 * @return Output variable
 */
public SDVariable scalarFloorMod(String name, SDVariable in, Number value) {
    return updateVariableNameAndReference(f().scalarFloorMod(in, value), name);
}
/**
 * Return an array with the same shape as the input, but with every element set to value 'set'.
 *
 * @param in  Input variable (provides the output shape)
 * @param set Value to set every element to
 * @return Output variable
 */
public SDVariable scalarSet(SDVariable in, Number set) {
    return scalarSet(null, in, set);
}
/**
 * Return an array with the same shape as the input, but with every element set to value 'set'.
 *
 * @param name Name for the output variable
 * @param in   Input variable (provides the output shape)
 * @param set  Value to set every element to
 * @return Output variable
 */
public SDVariable scalarSet(String name, SDVariable in, Number set) {
    return updateVariableNameAndReference(f().scalarSet(in, set), name);
}
/**
 * Max array reduction operation, optionally along the specified dimensions.
 *
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable max(SDVariable x, int... dimensions) {
    return max(null, x, dimensions);
}
/**
 * Max array reduction operation, optionally along the specified dimensions.
 *
 * @param name       Name for the output variable
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable max(String name, SDVariable x, int... dimensions) {
    return max(name, x, false, dimensions);
}
/**
 * Max array reduction operation, optionally along the specified dimensions.
 * If keepDims is true the reduced dimensions are retained with size 1, so the output has the same
 * rank as the input; this is useful for later broadcast operations. For an input of shape [a,b,c]
 * reduced over dimensions=[1]: keepDims = true gives shape [a,1,c]; keepDims = false gives [a,c].
 *
 * @param name       Name for the output variable
 * @param x          Input variable
 * @param keepDims   If true, retain the reduced dimensions with size 1; if false, remove them
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable max(String name, SDVariable x, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.max(x, keepDims, dimensions), name);
}
/**
 * Element-wise maximum operation: out[i] = max(first[i], second[i]).
 * Supports broadcasting.
 *
 * @param first  First input array
 * @param second Second input array
 * @return Output variable
 */
public SDVariable max(SDVariable first, SDVariable second) {
    return max(null, first, second);
}
/**
 * Element-wise maximum operation: out[i] = max(first[i], second[i]).
 * Supports broadcasting.
 *
 * @param name   Name for the output variable
 * @param first  First input array
 * @param second Second input array
 * @return Output variable
 */
public SDVariable max(String name, SDVariable first, SDVariable second) {
    return updateVariableNameAndReference(f().max(first, second), name);
}
/**
 * Absolute max array reduction operation, optionally along the specified dimensions: out = max(abs(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amax(SDVariable in, int... dimensions) {
    return amax(null, in, dimensions);
}
/**
 * Absolute max array reduction operation, optionally along the specified dimensions: out = max(abs(x))
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amax(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().amax(in, dimensions), name);
}
/**
 * Absolute min array reduction operation, optionally along the specified dimensions: out = min(abs(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amin(SDVariable in, int... dimensions) {
    return amin(null, in, dimensions);
}
/**
 * Absolute min array reduction operation, optionally along the specified dimensions: out = min(abs(x))
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amin(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().amin(in, dimensions), name);
}
/**
 * Absolute mean array reduction operation, optionally along the specified dimensions: out = mean(abs(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amean(SDVariable in, int... dimensions) {
    return amean(null, in, dimensions);
}
/**
 * Absolute mean array reduction operation, optionally along the specified dimensions: out = mean(abs(x))
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amean(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().amean(in, dimensions), name);
}
/**
 * Absolute sum array reduction operation, optionally along the specified dimensions: out = sum(abs(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable asum(SDVariable in, int... dimensions) {
    return asum(null, in, dimensions);
}
/**
 * Absolute sum array reduction operation, optionally along the specified dimensions: out = sum(abs(x))
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable asum(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().asum(in, dimensions), name);
}
/**
 * Count-zero array reduction operation, optionally along the specified dimensions: out = count(x == 0)
 *
 * @param input      Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable countZero(SDVariable input, int... dimensions) {
    return countZero(null, input, dimensions);
}
/**
 * Count-zero array reduction operation, optionally along the specified dimensions: out = count(x == 0)
 *
 * @param name       Name for the output variable
 * @param input      Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable countZero(String name, SDVariable input, int... dimensions) {
    return updateVariableNameAndReference(f().countZero(input, dimensions), name);
}
/**
 * Full-array zero fraction reduction operation: out = count(x == 0) / length(x)
 *
 * @param input Input variable
 * @return Reduced array of rank 0 (scalar)
 */
public SDVariable zeroFraction(SDVariable input) {
    return zeroFraction(null, input);
}
/**
 * Full-array zero fraction reduction operation: out = count(x == 0) / length(x)
 *
 * @param name  Name for the output variable
 * @param input Input variable
 * @return Reduced array of rank 0 (scalar)
 */
public SDVariable zeroFraction(String name, SDVariable input) {
    return updateVariableNameAndReference(f().zeroFraction(input), name);
}
/**
 * Count-nonzero array reduction operation, optionally along the specified dimensions: out = count(x != 0)
 *
 * @param input      Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable countNonZero(SDVariable input, int... dimensions) {
    return countNonZero(null, input, dimensions);
}
/**
 * Count-nonzero array reduction operation, optionally along the specified dimensions: out = count(x != 0)
 *
 * @param name       Name for the output variable
 * @param input      Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable countNonZero(String name, SDVariable input, int... dimensions) {
    return updateVariableNameAndReference(f().countNonZero(input, dimensions), name);
}
/**
 * Minimum array reduction operation, optionally along the specified dimensions: out = min(in)
 *
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable min(SDVariable x, int... dimensions) {
    return min(null, x, dimensions);
}
/**
 * Minimum array reduction operation, optionally along the specified dimensions: out = min(in)
 *
 * @param name       Name for the output variable
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable min(String name, SDVariable x, int... dimensions) {
    return min(name, x, false, dimensions);
}
/**
 * Minimum array reduction operation, optionally along the specified dimensions: out = min(in)
 * If keepDims is true the reduced dimensions are retained with size 1, so the output has the same
 * rank as the input; this is useful for later broadcast operations. For an input of shape [a,b,c]
 * reduced over dimensions=[1]: keepDims = true gives shape [a,1,c]; keepDims = false gives [a,c].
 *
 * @param name       Name for the output variable
 * @param x          Input variable
 * @param keepDims   If true, retain the reduced dimensions with size 1; if false, remove them
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable min(String name, SDVariable x, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.min(x, keepDims, dimensions), name);
}
/**
 * Element-wise minimum operation: out[i] = min(first[i], second[i]).
 * Supports broadcasting.
 *
 * @param first  First input array
 * @param second Second input array
 * @return Output variable
 */
public SDVariable min(SDVariable first, SDVariable second) {
    return min(null, first, second);
}
/**
 * Element-wise minimum operation: out[i] = min(first[i], second[i]).
 * Supports broadcasting.
 *
 * @param name   Name for the output variable
 * @param first  First input array
 * @param second Second input array
 * @return Output variable
 */
public SDVariable min(String name, SDVariable first, SDVariable second) {
    return updateVariableNameAndReference(f().min(first, second), name);
}
/**
 * Argmax array reduction operation, optionally along the specified dimensions.
 * Output values are the index of the maximum value of each slice along the specified dimension.
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable argmax(SDVariable in, int... dimensions) {
    return argmax(null, in, false, dimensions);
}
/**
 * Argmax array reduction with no explicit output name.
 *
 * @see #argmax(String, SDVariable, boolean, int...)
 */
public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) {
    return argmax(null, in, keepDims, dimensions);
}
/**
 * Argmax array reduction operation, optionally along the specified dimensions.
 * Output values are the index of the maximum value of each slice along the specified dimension.
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable argmax(String name, SDVariable in, int... dimensions) {
    return argmax(name, in, false, dimensions);
}
/**
 * Argmax array reduction operation, optionally along the specified dimensions.
 * Output values are the index of the maximum value of each slice along the specified dimension.
 * If keepDims is true the reduced dimensions are retained with size 1, so the output has the same
 * rank as the input; this is useful for later broadcast operations. For an input of shape [a,b,c]
 * reduced over dimensions=[1]: keepDims = true gives shape [a,1,c]; keepDims = false gives [a,c].
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param keepDims   If true, retain the reduced dimensions with size 1; if false, remove them
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false,
 * or of rank (input rank) if keepDims = true
 */
public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().argmax(in, keepDims, dimensions), name);
}
/**
 * Argmin array reduction operation, optionally along the specified dimensions.
 * Output values are the index of the minimum value of each slice along the specified dimension.
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable argmin(SDVariable in, int... dimensions) {
    return argmin(null, in, dimensions);
}
/**
 * Argmin array reduction with no explicit output name.
 *
 * @see #argmin(String, SDVariable, boolean, int...)
 */
public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) {
    return argmin(null, in, keepDims, dimensions);
}
/**
 * Argmin array reduction operation, optionally along the specified dimensions.
 * Output values are the index of the minimum value of each slice along the specified dimension.
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable argmin(String name, SDVariable in, int... dimensions) {
    return argmin(name, in, false, dimensions);
}
/**
 * Argmin array reduction operation, optionally along the specified dimensions.
 * Output values are the index of the minimum value of each slice along the specified dimension.
 * If keepDims is true the reduced dimensions are retained with size 1, so the output has the same
 * rank as the input; this is useful for later broadcast operations. For an input of shape [a,b,c]
 * reduced over dimensions=[1]: keepDims = true gives shape [a,1,c]; keepDims = false gives [a,c].
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param keepDims   If true, retain the reduced dimensions with size 1; if false, remove them
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false,
 * or of rank (input rank) if keepDims = true
 */
public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().argmin(in, keepDims, dimensions), name);
}
/**
 * Index of the max absolute value: argmax(abs(in))
 *
 * @see #argmax(SDVariable, int...)
 */
public SDVariable iamax(SDVariable in, int... dimensions) {
    return iamax(null, in, dimensions);
}
/**
 * Index of the max absolute value: argmax(abs(in))
 *
 * @see #argmax(String, SDVariable, boolean, int...)
 */
public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) {
    return iamax(null, in, keepDims, dimensions);
}
/**
 * Index of the max absolute value: argmax(abs(in))
 *
 * @see #argmax(String, SDVariable, boolean, int...)
 */
public SDVariable iamax(String name, SDVariable in, int... dimensions) {
    return iamax(name, in, false, dimensions);
}
/**
 * Index of the max absolute value: argmax(abs(in))
 *
 * @see #argmax(String, SDVariable, boolean, int...)
 */
public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().iamax(in, keepDims, dimensions), name);
}
/**
 * Index of the min absolute value: argmin(abs(in))
 *
 * @see #argmin(String, SDVariable, boolean, int...)
 */
public SDVariable iamin(SDVariable in, int... dimensions) {
    return iamin(null, in, dimensions);
}
/**
 * Index of the min absolute value: argmin(abs(in))
 *
 * @see #argmin(String, SDVariable, boolean, int...)
 */
public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) {
    return iamin(null, in, keepDims, dimensions);
}
/**
 * Index of the min absolute value: argmin(abs(in))
 *
 * @see #argmin(String, SDVariable, boolean, int...)
 */
public SDVariable iamin(String name, SDVariable in, int... dimensions) {
    return iamin(name, in, false, dimensions);
}
/**
 * Index of the min absolute value: argmin(abs(in))
 *
 * @see #argmin(String, SDVariable, boolean, int...)
 */
public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().iamin(in, keepDims, dimensions), name);
}
/**
 * First-index reduction with no explicit output name.
 *
 * @see #firstIndex(String, SDVariable, Condition, int...)
 */
public SDVariable firstIndex(SDVariable in, Condition condition, int... dimensions) {
    return firstIndex(null, in, condition, dimensions);
}
/**
 * First-index reduction with no explicit output name.
 *
 * @see #firstIndex(String, SDVariable, Condition, boolean, int...)
 */
public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) {
    return firstIndex(null, in, condition, keepDims, dimensions);
}
/**
 * First-index reduction operation.
 * Returns a variable containing the index of the first element that matches the specified
 * condition, for each slice along the specified dimensions.
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param condition  Condition to check each input element against
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable firstIndex(String name, SDVariable in, Condition condition, int... dimensions) {
    return firstIndex(name, in, condition, false, dimensions);
}
/**
 * First-index reduction operation.
 * Returns a variable containing the index of the first element that matches the specified
 * condition, for each slice along the specified dimensions.
 * If keepDims is true the reduced dimensions are retained with size 1, so the output has the same
 * rank as the input; this is useful for later broadcast operations. For an input of shape [a,b,c]
 * reduced over dimensions=[1]: keepDims = true gives shape [a,1,c]; keepDims = false gives [a,c].
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param condition  Condition to check each input element against
 * @param keepDims   If true, retain the reduced dimensions with size 1; if false, remove them
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable firstIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().firstIndex(in, condition, keepDims, dimensions), name);
}
/**
 * Last-index reduction with no explicit output name.
 *
 * @see #lastIndex(String, SDVariable, Condition, int...)
 */
public SDVariable lastIndex(SDVariable in, Condition condition, int... dimensions) {
    return lastIndex(null, in, condition, dimensions);
}
/**
 * Last-index reduction with no explicit output name.
 *
 * @see #lastIndex(String, SDVariable, Condition, boolean, int...)
 */
public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) {
    return lastIndex(null, in, condition, keepDims, dimensions);
}
/**
 * Last-index reduction operation.
 * Returns a variable containing the index of the last element that matches the specified
 * condition, for each slice along the specified dimensions.
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param condition  Condition to check each input element against
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable lastIndex(String name, SDVariable in, Condition condition, int... dimensions) {
    return lastIndex(name, in, condition, false, dimensions);
}
/**
 * Last-index reduction operation.
 * Returns a variable containing the index of the last element that matches the specified
 * condition, for each slice along the specified dimensions.
 * If keepDims is true the reduced dimensions are retained with size 1, so the output has the same
 * rank as the input; this is useful for later broadcast operations. For an input of shape [a,b,c]
 * reduced over dimensions=[1]: keepDims = true gives shape [a,1,c]; keepDims = false gives [a,c].
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param condition  Condition to check each input element against
 * @param keepDims   If true, retain the reduced dimensions with size 1; if false, remove them
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable lastIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().lastIndex(in, condition, keepDims, dimensions), name);
}
/**
 * Returns a count of the number of elements that satisfy the condition.
 *
 * @param in        Input variable
 * @param condition Condition to check each input element against
 * @return Number of elements for which the condition is satisfied
 */
public SDVariable matchConditionCount(SDVariable in, Condition condition) {
    return matchConditionCount(null, in, condition);
}
/**
 * Returns a count of the number of elements that satisfy the condition.
 *
 * @param name      Name for the output variable
 * @param in        Input variable
 * @param condition Condition to check each input element against
 * @return Number of elements for which the condition is satisfied
 */
public SDVariable matchConditionCount(String name, SDVariable in, Condition condition) {
    return matchConditionCount(name, in, condition, false);
}
/**
 * Returns a count of the number of elements that satisfy the condition, for each slice along the
 * specified dimensions.
 * If keepDim is true the reduced dimensions are retained with size 1, so the output has the same
 * rank as the input; this is useful for later broadcast operations. For an input of shape [a,b,c]
 * reduced over dimensions=[1]: keepDim = true gives shape [a,1,c]; keepDim = false gives [a,c].
 *
 * @param name       Name for the output variable
 * @param in         Input variable
 * @param condition  Condition to check each input element against
 * @param keepDim    If true, retain the reduced dimensions with size 1; if false, remove them
 * @param dimensions Dimensions to reduce over; if none are given, the whole array is reduced
 * @return Number of elements for which the condition is satisfied
 */
public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, boolean keepDim, int... dimensions) {
    return updateVariableNameAndReference(f().matchConditionCount(in, condition, keepDim, dimensions), name);
}
/**
 * Returns a boolean mask with the same shape as the input: value 1 where the condition is
 * satisfied, 0 otherwise.
 *
 * @param in        Input variable
 * @param condition Condition to check each input element against
 * @return Boolean mask variable
 */
public SDVariable matchCondition(SDVariable in, Condition condition) {
    return matchCondition(null, in, condition);
}
/**
 * Returns a boolean mask with the same shape as the input: value 1 where the condition is
 * satisfied, 0 otherwise.
 *
 * @param name      Name for the output variable
 * @param in        Input variable
 * @param condition Condition to check each input element against
 * @return Boolean mask variable
 */
public SDVariable matchCondition(String name, SDVariable in, Condition condition) {
    return updateVariableNameAndReference(f().matchCondition(in, condition), name);
}
/**
 * @see #cumsum(String, SDVariable, boolean, boolean, int...)
 */
public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... axis) {
    return cumsum(null, in, exclusive, reverse, axis);
}

/**
 * Cumulative sum operation.
 * For input [a, b, c] the output is:
 * exclusive=false, reverse=false: [a, a+b, a+b+c]
 * exclusive=true,  reverse=false: [0, a, a+b]
 * exclusive=false, reverse=true:  [a+b+c, b+c, c]
 * exclusive=true,  reverse=true:  [b+c, c, 0]
 *
 * @param name      Name of the output variable
 * @param in        Input variable
 * @param exclusive If true: exclude the first value
 * @param reverse   If true: reverse the direction of the accumulation
 * @param axis      Scalar axis argument for the dimension to perform the cumulative sum along
 * @return Output variable
 */
public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) {
    return updateVariableNameAndReference(f().cumsum(in, exclusive, reverse, axis), name);
}
/**
 * @see #cumprod(String, SDVariable, boolean, boolean, int...)
 */
public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... axis) {
    return cumprod(null, in, exclusive, reverse, axis);
}

/**
 * Cumulative product operation.
 * For input [a, b, c] the output is:
 * exclusive=false, reverse=false: [a, a*b, a*b*c]
 * exclusive=true,  reverse=false: [1, a, a*b]
 * exclusive=false, reverse=true:  [a*b*c, b*c, c]
 * exclusive=true,  reverse=true:  [b*c, c, 1]
 * (NOTE(review): the exclusive examples previously showed 0; an exclusive cumulative product
 * starts at the multiplicative identity, 1 — confirm against the op implementation.)
 *
 * @param name      Name of the output variable
 * @param in        Input variable
 * @param exclusive If true: exclude the first value
 * @param reverse   If true: reverse the direction of the accumulation
 * @param axis      Scalar axis argument for the dimension to perform the cumulative product along
 * @return Output variable
 */
public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) {
    return updateVariableNameAndReference(f().cumprod(in, exclusive, reverse, axis), name);
}
/**
 * @see #biasAdd(String, SDVariable, SDVariable)
 */
public SDVariable biasAdd(SDVariable input, SDVariable bias) {
    return biasAdd(null, input, bias);
}

/**
 * Bias addition operation: a special case of addition, typically used with CNN 4D activations
 * and a 1D bias vector.
 *
 * @param name  Name of the output variable
 * @param input 4d input variable
 * @param bias  1d bias
 * @return Output variable
 */
public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) {
    return updateVariableNameAndReference(f().biasAdd(input, bias), name);
}
/**
 * Reshape the input variable to the specified (fixed) shape. The output variable will have the
 * same values as the input, but with the specified shape.
 * Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param x     Input variable
 * @param shape New shape for variable
 * @return Output variable
 * @see #reshape(SDVariable, SDVariable)
 */
public SDVariable reshape(SDVariable x, long... shape) {
    return reshape(null, x, shape);
}

/**
 * Reshape the input variable to the specified (fixed) shape. The output variable will have the
 * same values as the input, but with the specified shape.
 * Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param name  Output variable name
 * @param x     Input variable
 * @param shape New shape for variable
 * @return Output variable
 * @see #reshape(SDVariable, SDVariable)
 */
public SDVariable reshape(String name, SDVariable x, long... shape) {
    return updateVariableNameAndReference(functionFactory.reshape(x, shape), name);
}

/**
 * Reshape the input variable to the specified (fixed) shape. The output variable will have the
 * same values as the input, but with the specified shape.
 * Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param x     Input variable
 * @param shape New shape for variable
 * @return Output variable
 * @see #reshape(SDVariable, SDVariable)
 */
public SDVariable reshape(SDVariable x, int... shape) {
    return reshape(null, x, shape);
}

/**
 * Reshape the input variable to the specified (fixed) shape. The output variable will have the
 * same values as the input, but with the specified shape.
 * Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param name  Output variable name
 * @param x     Input variable
 * @param shape New shape for variable
 * @return Output variable
 * @see #reshape(SDVariable, SDVariable)
 */
public SDVariable reshape(String name, SDVariable x, int... shape) {
    return updateVariableNameAndReference(functionFactory.reshape(x, shape), name);
}

/**
 * Reshape the input variable to the specified (dynamic) shape. The output variable will have the
 * same values as the input, but with the specified shape.
 * Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param x     Input variable
 * @param shape New shape for variable, as a 1D array
 * @return Output variable
 * @see #reshape(SDVariable, int[])
 */
public SDVariable reshape(SDVariable x, SDVariable shape) {
    return reshape(null, x, shape);
}

/**
 * Reshape the input variable to the specified (dynamic) shape. The output variable will have the
 * same values as the input, but with the specified shape.
 * Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param name  Output variable name
 * @param x     Input variable
 * @param shape New shape for variable, as a 1D array
 * @return Output variable
 * @see #reshape(SDVariable, int[])
 */
public SDVariable reshape(String name, SDVariable x, SDVariable shape) {
    return updateVariableNameAndReference(functionFactory.reshape(x, shape), name);
}
/**
 * @see #reverse(String, SDVariable, int...)
 */
public SDVariable reverse(SDVariable x, int... dimensions) {
    return reverse(null, x, dimensions);
}

/**
 * Reverse the values of an array for the specified dimensions.
 * If input is:
 * [ 1, 2, 3]
 * [ 4, 5, 6]
 * then
 * reverse(in, 0):
 * [4, 5, 6]
 * [1, 2, 3]
 * reverse(in, 1):
 * [3, 2, 1]
 * [6, 5, 4]
 * (NOTE(review): the original examples both said dimension 0; the dimension labels above are the
 * corrected pairing — confirm against the op implementation.)
 *
 * @param name       Name of the output variable
 * @param x          Input variable
 * @param dimensions Dimensions to reverse along
 * @return Output variable
 */
public SDVariable reverse(String name, SDVariable x, int... dimensions) {
    return updateVariableNameAndReference(f().reverse(x, dimensions), name);
}
/**
 * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values
 * are reversed.
 *
 * @param name        Name of the output variable
 * @param x           Input variable
 * @param seq_lengths Length of the sequences
 * @param seqDim      Sequence dimension
 * @param batchDim    Batch dimension
 * @return Reversed sequences
 */
public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) {
    return updateVariableNameAndReference(f().reverseSequence(x, seq_lengths, seqDim, batchDim), name);
}

/**
 * @see #reverseSequence(String, SDVariable, SDVariable, int, int)
 */
public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths) {
    return updateVariableNameAndReference(f().reverseSequence(x, seq_lengths), name);
}

/**
 * @see #reverseSequence(String, SDVariable, SDVariable, int, int)
 */
public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) {
    return reverseSequence(null, x, seq_lengths, seqDim, batchDim);
}

/**
 * @see #reverseSequence(String, SDVariable, SDVariable, int, int)
 */
public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) {
    return reverseSequence(null, x, seq_lengths);
}
/**
 * Generate a sequence mask (with values 0 or 1) based on the specified lengths.
 * Specifically, out[i, ..., k, j] = (j &lt; lengths[i, ..., k] ? 1.0 : 0.0)
 *
 * @param name    Name of the output variable
 * @param lengths Lengths of the sequences
 * @param maxLen  Maximum sequence length
 * @return Output variable
 */
public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen) {
    SDVariable ret = f().sequenceMask(lengths, maxLen);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #sequenceMask(String, SDVariable, SDVariable)
 */
public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen) {
    return sequenceMask(null, lengths, maxLen);
}

/**
 * @see #sequenceMask(String, SDVariable, SDVariable)
 */
public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen) {
    SDVariable ret = f().sequenceMask(lengths, maxLen);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #sequenceMask(String, SDVariable, SDVariable)
 */
public SDVariable sequenceMask(SDVariable lengths, int maxLen) {
    return sequenceMask(null, lengths, maxLen);
}

/**
 * @see #sequenceMask(String, SDVariable, SDVariable)
 */
public SDVariable sequenceMask(String name, SDVariable lengths) {
    SDVariable ret = f().sequenceMask(lengths);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #sequenceMask(String, SDVariable, SDVariable)
 */
public SDVariable sequenceMask(SDVariable lengths) {
    // Delegate to the named variant for consistency with the other no-name overloads
    // (previously this duplicated the implementation with a hard-coded null name).
    return sequenceMask(null, lengths);
}
/**
 * @see #expandDims(String, SDVariable, int)
 */
public SDVariable expandDims(SDVariable x, int axis) {
    return expandDims(null, x, axis);
}

/**
 * Reshape the input by adding a 1 at the specified location.
 * For example, if input has shape [a, b], then output shape is:
 * axis = 0: [1, a, b]
 * axis = 1: [a, 1, b]
 * axis = 2: [a, b, 1]
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @param axis Axis to expand
 * @return Output variable
 * @see #squeeze(String, SDVariable, int)
 */
public SDVariable expandDims(String name, SDVariable x, int axis) {
    return updateVariableNameAndReference(f().expandDims(x, axis), name);
}
/**
 * @see #squeeze(String, SDVariable, int)
 */
public SDVariable squeeze(SDVariable x, int axis) {
    return squeeze(null, x, axis);
}

/**
 * Remove a single dimension of size 1.
 * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c].
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @param axis Size 1 dimension to remove
 * @return Output variable
 */
public SDVariable squeeze(String name, SDVariable x, int axis) {
    return updateVariableNameAndReference(f().squeeze(x, axis), name);
}
/**
 * Assign/copy op: out = x.assign(y). Supports broadcasting.
 *
 * @param x Input variable x
 * @param y Input variable y
 * @return Output variable
 */
public SDVariable assign(SDVariable x, SDVariable y) {
    return assign(null, x, y);
}

/**
 * Assign/copy op: out = x.assign(y). Supports broadcasting.
 *
 * @param name Name of the output variable
 * @param x    Input variable x
 * @param y    Input variable y
 * @return Output variable
 */
public SDVariable assign(String name, SDVariable x, SDVariable y) {
    return updateVariableNameAndReference(f().assign(x, y), name);
}

/**
 * Return an array with equal shape to the input, but all elements set to 'value'.
 *
 * @param in    Input variable
 * @param value Value to set
 * @return Output variable
 */
public SDVariable assign(SDVariable in, Number value) {
    return assign(null, in, value);
}

/**
 * Return an array with equal shape to the input, but all elements set to 'value'.
 *
 * @param name  Name of the output variable
 * @param in    Input variable
 * @param value Value to set
 * @return Output variable
 */
public SDVariable assign(String name, SDVariable in, Number value) {
    return updateVariableNameAndReference(f().assign(in, value), name);
}
/**
 * Matrix transpose operation: If input has shape [a,b] output has shape [b,a].
 *
 * @param x Input variable
 * @return Output variable (transposed input)
 */
public SDVariable transpose(SDVariable x) {
    return transpose(null, x);
}

/**
 * Matrix transpose operation: If input has shape [a,b] output has shape [b,a].
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable (transposed input)
 */
public SDVariable transpose(String name, SDVariable x) {
    return updateVariableNameAndReference(functionFactory.transpose(x), name);
}
/**
 * Array permutation operation: permute the dimensions according to the specified permutation indices.
 * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
 *
 * @param x          Input variable
 * @param dimensions Permutation indices
 * @return Output variable (permuted input)
 */
public SDVariable permute(SDVariable x, int... dimensions) {
    return permute(null, x, dimensions);
}

/**
 * Array permutation operation: permute the dimensions according to the specified permutation indices.
 * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param dimensions Permutation indices
 * @return Output variable (permuted input)
 */
public SDVariable permute(String name, SDVariable x, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.permute(x, dimensions), name);
}
/**
 * Roll the specified axis of the input array.
 * NOTE(review): exact semantics are delegated to {@code functionFactory.rollAxis} — confirm the
 * resulting dimension ordering against that implementation.
 *
 * @param x    Input variable
 * @param axis Axis to roll
 * @return Output variable
 */
public SDVariable rollAxis(SDVariable x, int axis) {
    return rollAxis(null, x, axis);
}

/**
 * Roll the specified axis of the input array.
 * NOTE(review): exact semantics are delegated to {@code functionFactory.rollAxis} — confirm the
 * resulting dimension ordering against that implementation.
 *
 * @param name Output variable name
 * @param x    Input variable
 * @param axis Axis to roll
 * @return Output variable
 */
public SDVariable rollAxis(String name, SDVariable x, int axis) {
    return updateVariableNameAndReference(functionFactory.rollAxis(x, axis), name);
}
/**
 * @see #concat(String, int, SDVariable...)
 */
public SDVariable concat(int dimension, SDVariable... inputs) {
    return concat(null, dimension, inputs);
}

/**
 * Concatenate a set of inputs along the specified dimension.
 * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
 * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output
 * has shape [a, x+y, c].
 *
 * @param name      Name of the output variable
 * @param dimension Dimension to concatenate on
 * @param inputs    Input variables
 * @return Output variable
 * @see #stack(String, int, SDVariable...)
 */
public SDVariable concat(String name, int dimension, SDVariable... inputs) {
    return updateVariableNameAndReference(functionFactory.concat(dimension, inputs), name);
}
/**
 * @see #moments(String[], SDVariable, int...)
 */
public SDVariable[] moments(SDVariable input, int... axes) {
    return moments(null, input, axes);
}

/**
 * Calculate the mean and (population) variance for the input variable, for the specified axis.
 *
 * @param name  Names of the output variables. Can be null; if non-null, must be length 2
 * @param input Input to calculate moments for
 * @param axes  Dimensions to perform calculation over
 * @return Mean and variance variables
 */
public SDVariable[] moments(String[] name, SDVariable input, int... axes) {
    return updateVariableNamesAndReferences(f().moments(input, axes), name);
}
/**
 * @see #normalizeMoments(String[], SDVariable, SDVariable, SDVariable, double)
 */
public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) {
    return normalizeMoments(null, counts, means, variances, shift);
}

/**
 * Calculate the mean and variance from the sufficient statistics.
 *
 * @param name      Names of the output variables. Can be null; if non-null, must be length 2
 * @param counts    Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics
 * @param means     Mean-value sufficient statistics: this is the SUM of all data values
 * @param variances Variance sufficient statistics: this is the squared sum of all data values
 * @param shift     Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability)
 * @return Output variables: mean and population variance
 */
public SDVariable[] normalizeMoments(String[] name, SDVariable counts, SDVariable means, SDVariable variances,
                                     double shift) {
    return updateVariableNamesAndReferences(f().normalizeMoments(counts, means, variances, shift), name);
}
/**
 * @see #matrixDeterminant(String, SDVariable)
 */
public SDVariable matrixDeterminant(SDVariable in) {
    return matrixDeterminant(null, in);
}

/**
 * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
 * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for
 * each shape [m,m] sub-matrix.
 *
 * @param name Name of the output variable
 * @param in   Input
 * @return Matrix determinant variable
 */
public SDVariable matrixDeterminant(String name, SDVariable in) {
    return updateVariableNameAndReference(f().matrixDeterminant(in), name);
}
/**
 * @see #matrixInverse(String, SDVariable)
 */
public SDVariable matrixInverse(SDVariable in) {
    return matrixInverse(null, in);
}

/**
 * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
 * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
 * shape [m,m] sub-matrix.
 *
 * @param name Name of the output variable
 * @param in   Input
 * @return Matrix inverse variable
 */
public SDVariable matrixInverse(String name, SDVariable in) {
    return updateVariableNameAndReference(f().matrixInverse(in), name);
}
/**
 * @see #confusionMatrix(String, SDVariable, SDVariable)
 */
public SDVariable confusionMatrix(SDVariable labels, SDVariable predictions) {
    // Cast disambiguates between the String-name and SDVariable-weights overloads
    return confusionMatrix((String) null, labels, predictions);
}

/**
 * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and
 * predictions, both of which are represented as integer values. This version assumes the number
 * of classes is 1 + max(max(labels), max(pred)).
 * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
 * [1, 0, 0]
 * [0, 1, 1]
 * [0, 0, 0]
 *
 * @param name   Name of the output variable
 * @param labels Labels - 1D array of integer values representing label values
 * @param pred   Predictions - 1D array of integer values representing predictions. Same length as labels
 * @return Output variable (2D, shape [numClasses, numClasses])
 */
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred) {
    return updateVariableNameAndReference(f().confusionMatrix(labels, pred), name);
}

/**
 * @see #confusionMatrix(String, SDVariable, SDVariable, Integer)
 */
public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses) {
    return confusionMatrix(null, labels, pred, numClasses);
}

/**
 * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and
 * predictions, both of which are represented as integer values.
 * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
 * [1, 0, 0, 0]
 * [0, 1, 1, 0]
 * [0, 0, 0, 0]
 * [0, 0, 0, 0]
 *
 * @param name       Name of the output variable
 * @param labels     Labels - 1D array of integer values representing label values
 * @param pred       Predictions - 1D array of integer values representing predictions. Same length as labels
 * @param numClasses Number of classes
 * @return Output variable (2D, shape [numClasses, numClasses])
 */
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses) {
    return updateVariableNameAndReference(f().confusionMatrix(labels, pred, numClasses), name);
}

/**
 * @see #confusionMatrix(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) {
    return confusionMatrix(null, labels, pred, weights);
}

/**
 * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and
 * predictions, both of which are represented as integer values. This version assumes the number
 * of classes is 1 + max(max(labels), max(pred)).
 * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]:
 * [1, 0, 0]
 * [0, 3, 2]
 * [0, 0, 0]
 *
 * @param name    Name of the output variable
 * @param labels  Labels - 1D array of integer values representing label values
 * @param pred    Predictions - 1D array of integer values representing predictions. Same length as labels
 * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of
 *                each prediction. Must be same length as both labels and predictions arrays
 * @return Output variable (2D, shape [numClasses, numClasses])
 */
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, SDVariable weights) {
    return updateVariableNameAndReference(f().confusionMatrix(labels, pred, weights), name);
}

/**
 * @see #confusionMatrix(String, SDVariable, SDVariable, Integer, SDVariable)
 */
public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) {
    return confusionMatrix(null, labels, pred, numClasses, weights);
}

/**
 * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and
 * predictions, both of which are represented as integer values.
 * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]:
 * [1, 0, 0, 0]
 * [0, 3, 2, 0]
 * [0, 0, 0, 0]
 * [0, 0, 0, 0]
 *
 * @param name       Name of the output variable
 * @param labels     Labels - 1D array of integer values representing label values
 * @param pred       Predictions - 1D array of integer values representing predictions. Same length as labels
 * @param numClasses Number of classes
 * @param weights    Weights - 1D array of values (may be real/decimal) representing the weight/contribution of
 *                   each prediction. Must be same length as both labels and predictions arrays
 * @return Output variable (2D, shape [numClasses, numClasses])
 */
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) {
    return updateVariableNameAndReference(f().confusionMatrix(labels, pred, numClasses, weights), name);
}
/**
 * @see #tile(String, SDVariable, int[])
 */
public SDVariable tile(SDVariable x, int[] repeat) {
    return tile(null, x, repeat);
}

/**
 * Repeat (tile) the input tensor the specified number of times.
 * For example, if input is
 * [1, 2]
 * [3, 4]
 * and repeat is [2, 3]
 * then output is
 * [1, 2, 1, 2, 1, 2]
 * [3, 4, 3, 4, 3, 4]
 * [1, 2, 1, 2, 1, 2]
 * [3, 4, 3, 4, 3, 4]
 *
 * @param name   Output variable name
 * @param x      Input variable
 * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array
 * @return Output variable
 */
public SDVariable tile(String name, SDVariable x, int[] repeat) {
    return updateVariableNameAndReference(functionFactory.tile(x, repeat), name);
}
/**
 * Generate an output variable with the specified (dynamic) shape with all elements set to the
 * specified value.
 *
 * @param shape Shape: must be a 1D array/variable
 * @param value Value to set all elements to
 * @return Output variable
 */
public SDVariable fill(SDVariable shape, double value) {
    return fill(null, shape, value);
}

/**
 * Generate an output variable with the specified (dynamic) shape with all elements set to the
 * specified value.
 *
 * @param name  Name of the output variable
 * @param shape Shape: must be a 1D array/variable
 * @param value Value to set all elements to
 * @return Output variable
 */
public SDVariable fill(String name, SDVariable shape, double value) {
    return updateVariableNameAndReference(functionFactory.fill(shape, value), name);
}
/**
 * Dropout operation.
 *
 * @param input                  Input
 * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p)
 * @return Output variable
 */
public SDVariable dropout(SDVariable input, double inputRetainProbability) {
    return dropout(null, input, inputRetainProbability);
}

/**
 * Dropout operation.
 *
 * @param name                   Name of the output variable
 * @param input                  Input
 * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p)
 * @return Output variable
 */
public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) {
    return updateVariableNameAndReference(f().dropout(input, inputRetainProbability), name);
}
/**
 * @see #linear(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) {
    return linear(null, input, weights, bias);
}

/**
 * Linear layer operation: out = mmul(in,w) + bias
 * Note that bias array is optional.
 *
 * @param name    Name of the output variable
 * @param input   Input data
 * @param weights Weights variable
 * @param bias    Optional bias variable (may be null)
 * @return Output variable
 */
public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) {
    return updateVariableNameAndReference(f().xwPlusB(input, weights, bias), name);
}
/**
 * @see #reluLayer(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) {
    return reluLayer(null, input, weights, bias);
}

/**
 * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
 * Note that bias array is optional.
 *
 * @param name    Name of the output variable
 * @param input   Input data
 * @param weights Weights variable
 * @param bias    Optional bias variable (may be null)
 * @return Output variable
 */
public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) {
    return updateVariableNameAndReference(f().reluLayer(input, weights, bias), name);
}
/**
 * Matrix multiplication: out = mmul(x,y)
 * Supports specifying a {@link MMulTranspose} argument to perform operation such as mmul(a^T, b), etc.
 *
 * @param x         First input variable
 * @param y         Second input variable
 * @param transpose Transpose arguments
 * @return Output variable
 */
public SDVariable mmul(SDVariable x, SDVariable y, MMulTranspose transpose) {
    return mmul(null, x, y, transpose);
}

/**
 * Matrix multiplication: out = mmul(x,y)
 *
 * @param x First input variable
 * @param y Second input variable
 * @return Output variable
 */
public SDVariable mmul(SDVariable x, SDVariable y) {
    return mmul(null, x, y);
}

/**
 * Matrix multiplication: out = mmul(x,y)
 * Supports specifying a {@link MMulTranspose} argument to perform operation such as mmul(a^T, b), etc.
 *
 * @param name      Output variable name
 * @param x         First input variable
 * @param y         Second input variable
 * @param transpose Transpose arguments
 * @return Output variable
 */
public SDVariable mmul(String name, SDVariable x, SDVariable y, MMulTranspose transpose) {
    return updateVariableNameAndReference(functionFactory.mmul(x, y, transpose), name);
}

/**
 * Matrix multiplication: out = mmul(x,y)
 *
 * @param name Output variable name
 * @param x    First input variable
 * @param y    Second input variable
 * @return Output variable
 */
public SDVariable mmul(String name, SDVariable x, SDVariable y) {
    return mmul(name, x, y, MMulTranspose.allFalse());
}
/**
 * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
 * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
 * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
 * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
 * The result of this operation will be a batch of multiplied matrices: it has the same length
 * as both input batches and each output matrix is of shape (M, K).
 *
 * @param names      Names for all provided SDVariables
 * @param matricesA  First array of input matrices, all of shape (M, N) or (N, M)
 * @param matricesB  Second array of input matrices, all of shape (N, K) or (K, N)
 * @param transposeA Whether the first batch of matrices is transposed
 * @param transposeB Whether the second batch of matrices is transposed
 * @return Array of multiplied SDVariables of shape (M, K)
 */
public SDVariable[] batchMmul(String[] names, SDVariable[] matricesA, SDVariable[] matricesB,
                              boolean transposeA, boolean transposeB) {
    return updateVariableNamesAndReferences(
            functionFactory.batchMmul(matricesA, matricesB, transposeA, transposeB), names);
}

/**
 * @see #batchMmul(String[], SDVariable[], SDVariable[], boolean, boolean)
 */
public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB,
                              boolean transposeA, boolean transposeB) {
    return batchMmul(null, matricesA, matricesB, transposeA, transposeB);
}

/**
 * Matrix multiply a batch of matrices without transposition. Each pair taken from matricesA and
 * matricesB has to have dimensions (M, N) and (N, K), respectively; each output matrix is of
 * shape (M, K).
 *
 * @param matricesA First array of input matrices, all of shape (M, N)
 * @param matricesB Second array of input matrices, all of shape (N, K)
 * @return Array of multiplied SDVariables of shape (M, K)
 * @see #batchMmul(String[], SDVariable[], SDVariable[], boolean, boolean)
 */
public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB) {
    return batchMmul(null, matricesA, matricesB, false, false);
}
/**
 * Tensor matrix multiplication (tensor contraction) of two variables along the specified
 * dimension pairs.
 * NOTE(review): exact semantics are delegated to the named overload / op implementation —
 * presumably dimensions[0] are the contraction axes of x and dimensions[1] those of y; confirm.
 *
 * @param x          First input variable
 * @param y          Second input variable
 * @param dimensions Dimensions to contract over, one int[] per input
 * @return Output variable
 */
public SDVariable tensorMmul(SDVariable x,
SDVariable y,
int[][] dimensions) {
return tensorMmul(null, x, y, dimensions);
}
/**
 * Dot product reduction of two variables along the specified dimensions.
 * NOTE(review): exact semantics are delegated to {@code f().dot} — confirm against the op
 * implementation.
 *
 * @param x          First input variable
 * @param y          Second input variable
 * @param dimensions Dimensions to reduce over
 * @return Output variable
 */
public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) {
    return dot(null, x, y, dimensions);
}

/**
 * Dot product reduction of two variables along the specified dimensions.
 * NOTE(review): exact semantics are delegated to {@code f().dot} — confirm against the op
 * implementation.
 *
 * @param name       Name of the output variable
 * @param x          First input variable
 * @param y          Second input variable
 * @param dimensions Dimensions to reduce over
 * @return Output variable
 */
public SDVariable dot(String name, SDVariable x, SDVariable y, int... dimensions) {
    return updateVariableNameAndReference(f().dot(x, y, dimensions), name);
}
/**
 * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset
 * along the specified dimensions:
 * out = sum_i abs(x[i])
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable norm1(String name, SDVariable x, int... dimensions) {
    return norm1(name, x, false, dimensions);
}

/**
 * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset
 * along the specified dimensions:
 * out = sum_i abs(x[i])
 * Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1. This can be useful for later broadcast operations
 * (such as subtracting the mean along a dimension).
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
 * keepDims = true: [a,1,c]
 * keepDims = false: [a,c]
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable norm1(String name, SDVariable x, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().norm1(x, keepDims, dimensions), name);
}
/**
 * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
 * out = sqrt(sum_i x[i]^2)
 * Reduced dimensions are removed from the output shape; see
 * {@link #norm2(String, SDVariable, boolean, int...)} to retain them as size 1.
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable norm2(String name, SDVariable x, int... dimensions) {
    // keepDims defaults to false for this overload
    return norm2(name, x, false, dimensions);
}
/**
 * Norm2 (L2 norm) reduction: out = sqrt(sum_i x[i]^2) along the given dimensions.
 * <p>
 * With keepDims = true the output keeps the input's rank, with each reduced dimension of size 1 —
 * convenient for later broadcast operations (such as subtracting the mean along a dimension).
 * Example for input shape [a,b,c] and dimensions=[1]: keepDims=true gives [a,1,c]; keepDims=false gives [a,c].
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the reduced dimensions (as size 1). If false: remove them
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable norm2(String name, SDVariable x, boolean keepDims, int... dimensions) {
    final SDVariable reduced = f().norm2(x, keepDims, dimensions);
    return updateVariableNameAndReference(reduced, name);
}
/**
 * Squared L2 norm: see {@link #norm2(String, SDVariable, int...)}
 */
public SDVariable squaredNorm(SDVariable x, int... dimensions) {
    // auto-generated name, keepDims = false
    return squaredNorm(null, x, false, dimensions);
}

/**
 * Squared L2 norm: see {@link #norm2(String, SDVariable, int...)}
 */
public SDVariable squaredNorm(String name, SDVariable x, int... dimensions) {
    // keepDims defaults to false for this overload
    return squaredNorm(name, x, false, dimensions);
}

/**
 * Squared L2 norm: see {@link #norm2(String, SDVariable, boolean, int...)}
 */
public SDVariable squaredNorm(SDVariable x, boolean keepDims, int... dimensions) {
    // auto-generated output name
    return squaredNorm(null, x, keepDims, dimensions);
}
/**
 * Squared L2 norm: out = sum_i x[i]^2 along the given dimensions.
 * See {@link #norm2(String, SDVariable, boolean, int...)} for the keepDims semantics.
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the reduced dimensions (as size 1). If false: remove them
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable squaredNorm(String name, SDVariable x, boolean keepDims, int... dimensions) {
    final SDVariable reduced = f().squaredNorm(x, keepDims, dimensions);
    return updateVariableNameAndReference(reduced, name);
}
/**
 * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
 * specified dimensions. Reduced dimensions are removed from the output shape; see
 * {@link #normmax(String, SDVariable, boolean, int...)} to retain them as size 1.
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable normmax(String name, SDVariable x, int... dimensions) {
    // keepDims defaults to false for this overload
    return normmax(name, x, false, dimensions);
}
/**
 * Max norm (infinity norm) reduction: out = max(abs(x[i])) along the given dimensions.
 * <p>
 * With keepDims = true the output keeps the input's rank, with each reduced dimension of size 1 —
 * convenient for later broadcast operations (such as subtracting the mean along a dimension).
 * Example for input shape [a,b,c] and dimensions=[1]: keepDims=true gives [a,1,c]; keepDims=false gives [a,c].
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the reduced dimensions (as size 1). If false: remove them
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable normmax(String name, SDVariable x, boolean keepDims, int... dimensions) {
    final SDVariable reduced = f().normmax(x, keepDims, dimensions);
    return updateVariableNameAndReference(reduced, name);
}
/**
 * @see #cosineSimilarity(String, SDVariable, SDVariable, int...)
 */
public SDVariable cosineSimilarity(SDVariable x, SDVariable y, int... dimensions) {
    // Name is generated from the op name rather than passed as null, unlike most sibling wrappers
    return cosineSimilarity(generateNewVarName(CosineSimilarity.OP_NAME, 0), x, y, dimensions);
}

/**
 * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
 * along the specified dimensions:
 * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2) )
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate cosine similarity over
 * @return Output variable
 */
public SDVariable cosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable cosim = functionFactory.cosineSimilarity(x, y, dimensions);
    return updateVariableNameAndReference(cosim, name);
}
/**
 * @see #euclideanDistance(String, SDVariable, SDVariable, int...)
 */
public SDVariable euclideanDistance(SDVariable x, SDVariable y, int... dimensions) {
    return euclideanDistance(generateNewVarName(EuclideanDistance.OP_NAME, 0), x, y, dimensions);
}

/**
 * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
 * tensor/subset along the specified dimensions:
 * out = sqrt( sum_i (x[i] - y[i])^2 )
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate Euclidean distance over
 * @return Output variable
 */
public SDVariable euclideanDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.euclideanDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}
/**
 * @see #manhattanDistance(String, SDVariable, SDVariable, int...)
 */
public SDVariable manhattanDistance(SDVariable x, SDVariable y, int... dimensions) {
    return manhattanDistance(generateNewVarName(ManhattanDistance.OP_NAME, 0), x, y, dimensions);
}

/**
 * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
 * tensor/subset along the specified dimensions:
 * out = sum_i abs(x[i]-y[i])
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate Manhattan distance over
 * @return Output variable
 */
public SDVariable manhattanDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.manhattanDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}
/**
 * @see #cosineDistance(String, SDVariable, SDVariable, int...)
 */
public SDVariable cosineDistance(SDVariable x, SDVariable y, int... dimensions) {
    return cosineDistance(null, x, y, dimensions);
}

/**
 * Cosine distance reduction operation. The output contains the cosine distance for each
 * tensor/subset along the specified dimensions:
 * out = 1.0 - cosineSimilarity(x,y)
 * See {@link #cosineSimilarity(String, SDVariable, SDVariable, int...)}
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate cosine distance over
 * @return Output variable
 */
public SDVariable cosineDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.cosineDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}
/**
 * @see #hammingDistance(String, SDVariable, SDVariable, int...)
 */
public SDVariable hammingDistance(SDVariable x, SDVariable y, int... dimensions) {
    return hammingDistance(null, x, y, dimensions);
}

/**
 * Hamming distance reduction operation. The output contains the Hamming distance for each
 * tensor/subset along the specified dimensions:
 * out = count( x[i] != y[i] )
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate Hamming distance over
 * @return Output variable
 */
public SDVariable hammingDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.hammingDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}
/**
 * Jaccard distance reduction operation. The output contains the Jaccard distance for each
 * tensor along the specified dimensions.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate Jaccard distance over
 * @return Output variable
 */
public SDVariable jaccardDistance(SDVariable x, SDVariable y, int... dimensions) {
    return jaccardDistance(null, x, y, dimensions);
}

/**
 * Jaccard distance reduction operation. The output contains the Jaccard distance for each
 * tensor along the specified dimensions.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate Jaccard distance over
 * @return Output variable
 */
public SDVariable jaccardDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.jaccardDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}
/**
 * Binary cross entropy loss, with an auto-generated output name.
 * See {@link #lossBinaryXENT(String, SDVariable, SDVariable, int...)}
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossBinaryXENT(SDVariable x, SDVariable y, int... dimensions) {
    return lossBinaryXENT(generateNewVarName(new LossBinaryXENT().opName(), 0), x, y, dimensions);
}

/**
 * Cosine similarity (cosine proximity) loss, with an auto-generated output name.
 * See {@link #lossCosineSimilarity(String, SDVariable, SDVariable, int...)}
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossCosineSimilarity(SDVariable x, SDVariable y, int... dimensions) {
    // Note: the generated variable name comes from the LossCosineProximity op
    return lossCosineSimilarity(generateNewVarName(new LossCosineProximity().opName(), 0), x, y, dimensions);
}
// All of the following loss wrappers generate an output name from the corresponding op
// and delegate to the (String name, ...) overload of the same loss.

/**
 * Hinge loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossHinge(SDVariable x, SDVariable y, int... dimensions) {
    return lossHinge(generateNewVarName(new LossHinge().opName(), 0), x, y, dimensions);
}

/**
 * Kullback-Leibler divergence loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossKLD(SDVariable x, SDVariable y, int... dimensions) {
    return lossKLD(generateNewVarName(new LossKLD().opName(), 0), x, y, dimensions);
}

/**
 * L1 loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossL1(SDVariable x, SDVariable y, int... dimensions) {
    return lossL1(generateNewVarName(new LossL1().opName(), 0), x, y, dimensions);
}

/**
 * L2 loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossL2(SDVariable x, SDVariable y, int... dimensions) {
    return lossL2(generateNewVarName(new LossL2().opName(), 0), x, y, dimensions);
}

/**
 * Mean absolute error loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossMAE(SDVariable x, SDVariable y, int... dimensions) {
    return lossMAE(generateNewVarName(new LossMAE().opName(), 0), x, y, dimensions);
}

/**
 * Mean squared error loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossMSE(SDVariable x, SDVariable y, int... dimensions) {
    return lossMSE(generateNewVarName(new LossMSE().opName(), 0), x, y, dimensions);
}

/**
 * Multi-class cross entropy loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossMCXENT(SDVariable x, SDVariable y, int... dimensions) {
    return lossMCXENT(generateNewVarName(new LossMCXENT().opName(), 0), x, y, dimensions);
}

/**
 * Mean squared logarithmic error loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossMSLE(SDVariable x, SDVariable y, int... dimensions) {
    return lossMSLE(generateNewVarName(new LossMSLE().opName(), 0), x, y, dimensions);
}

/**
 * Negative log likelihood loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossNegativeLogLikelihood(SDVariable x, SDVariable y, int... dimensions) {
    return lossNegativeLogLikelihood(generateNewVarName(new LossNegativeLogLikelihood().opName(), 0),
            x, y, dimensions);
}

/**
 * Poisson loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossPoisson(SDVariable x, SDVariable y, int... dimensions) {
    return lossPoisson(generateNewVarName(new LossPoisson().opName(), 0), x, y, dimensions);
}

/**
 * Squared hinge loss, with an auto-generated output name.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossSquaredHinge(SDVariable x, SDVariable y, int... dimensions) {
    return lossSquaredHinge(generateNewVarName(new LossSquaredHinge().opName(), 0), x, y, dimensions);
}
/**
 * Softmax derivative of x with respect to wrt, over the default dimension.
 * See {@link #softmaxDerivative(String, SDVariable, SDVariable, Integer)}
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @param wrt  Variable the derivative is taken with respect to
 * @return Output variable
 */
public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt) {
    // null dimension: the underlying op chooses its default — TODO confirm which dimension that is
    return softmaxDerivative(name, x, wrt, null);
}

/**
 * Softmax derivative of x with respect to wrt along the given dimension.
 *
 * @param name      Name of the output variable
 * @param x         Input variable
 * @param wrt       Variable the derivative is taken with respect to
 * @param dimension Dimension to apply softmax over; may be null (op default)
 * @return Output variable
 */
public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, Integer dimension) {
    SDVariable result = functionFactory.softmaxDerivative(x, wrt, dimension);
    return updateVariableNameAndReference(result, name);
}
/**
 * Tensor matrix multiplication (generalized tensor contraction) of two variables,
 * contracting over the specified dimension pairs.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions contraction dimensions — presumably dimensions[0] lists the axes of x and
 *                   dimensions[1] the axes of y to contract over; TODO confirm against the tensorMmul op
 * @return Output variable holding the contraction result
 */
public SDVariable tensorMmul(String name,
                             SDVariable x,
                             SDVariable y,
                             int[][] dimensions) {
    SDVariable result = functionFactory.tensorMmul(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}
/**
 * Sigmoid cross entropy loss computed on logits (pre-sigmoid activations), with an auto-generated output name.
 * See {@link #sigmoidCrossEntropyWithLogits(String, SDVariable, SDVariable, SDVariable, int, double)}
 *
 * @param logits         Logits (pre-sigmoid) input
 * @param weights        Per-example/per-element weights
 * @param labels         Target labels
 * @param reductionMode  Reduction mode flag passed to the underlying op — TODO document valid values
 * @param labelSmoothing Label smoothing factor
 * @return Output variable
 */
public SDVariable sigmoidCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels,
                                                int reductionMode, double labelSmoothing) {
    return sigmoidCrossEntropyWithLogits(null, logits, weights, labels, reductionMode, labelSmoothing);
}

/**
 * Sigmoid cross entropy loss computed on logits (pre-sigmoid activations).
 *
 * @param name           Name of the output variable
 * @param logits         Logits (pre-sigmoid) input
 * @param weights        Per-example/per-element weights
 * @param labels         Target labels
 * @param reductionMode  Reduction mode flag passed to the underlying op — TODO document valid values
 * @param labelSmoothing Label smoothing factor
 * @return Output variable
 */
public SDVariable sigmoidCrossEntropyWithLogits(String name, SDVariable logits, SDVariable weights, SDVariable labels,
                                                int reductionMode, double labelSmoothing) {
    SDVariable res = f().sigmoidCrossEntropyWithLogits(logits, weights, labels, reductionMode, labelSmoothing);
    return updateVariableNameAndReference(res, name);
}
/**
 * Softmax cross entropy loss computed on logits (pre-softmax activations), with an auto-generated output name.
 * See {@link #softmaxCrossEntropyWithLogits(String, SDVariable, SDVariable, SDVariable, int, double)}
 *
 * @param logits         Logits (pre-softmax) input
 * @param weights        Per-example/per-element weights
 * @param labels         Target labels
 * @param reductionMode  Reduction mode flag passed to the underlying op — TODO document valid values
 * @param labelSmoothing Label smoothing factor
 * @return Output variable
 */
public SDVariable softmaxCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels,
                                                int reductionMode, double labelSmoothing) {
    return softmaxCrossEntropyWithLogits(null, logits, weights, labels, reductionMode, labelSmoothing);
}

/**
 * Softmax cross entropy loss computed on logits (pre-softmax activations).
 *
 * @param name           Name of the output variable
 * @param logits         Logits (pre-softmax) input
 * @param weights        Per-example/per-element weights
 * @param labels         Target labels
 * @param reductionMode  Reduction mode flag passed to the underlying op — TODO document valid values
 * @param labelSmoothing Label smoothing factor
 * @return Output variable
 */
public SDVariable softmaxCrossEntropyWithLogits(String name, SDVariable logits, SDVariable weights, SDVariable labels,
                                                int reductionMode, double labelSmoothing) {
    SDVariable res = f().softmaxCrossEntropyWithLogits(logits, weights, labels, reductionMode, labelSmoothing);
    return updateVariableNameAndReference(res, name);
}
/**
 * Weighted cross entropy loss computed on logits, with an auto-generated output name.
 * See {@link #weightedCrossEntropyWithLogits(String, SDVariable, SDVariable, SDVariable)}
 *
 * @param targets Target values
 * @param inputs  Logits input
 * @param weights Weights applied to the loss — TODO confirm per-class vs per-example semantics
 * @return Output variable
 */
public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs,
                                                 SDVariable weights) {
    return weightedCrossEntropyWithLogits(null, targets, inputs, weights);
}

/**
 * Weighted cross entropy loss computed on logits.
 *
 * @param name    Name of the output variable
 * @param targets Target values
 * @param inputs  Logits input
 * @param weights Weights applied to the loss — TODO confirm per-class vs per-example semantics
 * @return Output variable
 */
public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, SDVariable inputs,
                                                 SDVariable weights) {
    SDVariable res = f().weightedCrossEntropyWithLogits(targets, inputs, weights);
    return updateVariableNameAndReference(res, name);
}
/**
 * Binary cross entropy loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossBinaryXENT(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossBinaryXENT(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Cosine similarity (cosine proximity) loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossCosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossCosineSimilarity(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Hinge loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossHinge(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossHinge(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Kullback-Leibler divergence loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossKLD(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossKLD(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * L1 loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossL1(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossL1(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * L2 loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossL2(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossL2(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Mean absolute error loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossMAE(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossMAE(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Mean squared error loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossMSE(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossMSE(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Multi-class cross entropy loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossMCXENT(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossMCXENT(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Mean squared logarithmic error loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossMSLE(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossMSLE(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Negative log likelihood loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossNegativeLogLikelihood(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossNegativeLogLikelihood(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Poisson loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossPoisson(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossPoisson(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Squared hinge loss.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossSquaredHinge(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.lossSquaredHinge(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}
/**
 * Add the specified variable to this SameDiff instance.
 * <p>
 * Note: validation is performed by variable name rather than by vertex id, because more than
 * one input can share a vertex id; the name incorporates both the function name and the input ids.
 *
 * @param variable Variable to add; must belong to this SameDiff instance
 * @throws IllegalArgumentException if a different variable is already registered under the same name
 * @throws IllegalStateException    if the variable belongs to a different SameDiff instance
 */
public void addVariable(SDVariable variable) {
    if (variableMap == null)
        variableMap = new HashMap<>();
    // Fix: the original performed this exact same-instance check twice (before and after the
    // name-collision check); once is sufficient and behavior is unchanged.
    Preconditions.checkState(variable.getSameDiff() == this, "Samediff instance must be the same.");
    if (variableMap.containsKey(variable.getVarName()) && !variableMap.get(variable.getVarName()).equals(variable)) {
        throw new IllegalArgumentException("Variable already found with variable opName " + variable.getVarName());
    }
    variableMap.put(variable.getVarName(), variable);
}
/**
 * Generate a new variable name based on the uniqueness of the base name and arg index.
 * For example, if baseName = "X" will return:
 * "X" if "X" does not already exist, or "X:argIndex" if argIndex > 0
 * "X_1" if "X" already exists, or "X_1:argIndex" if argIndex > 0
 * "X_2" if "X" and "X_1" already exists, or "X_2:argIndex" if argIndex > 0
 * And so on, until an unused name is found.
 *
 * @param baseName the base name to use (use function.opName() where function is a {@link DifferentialFunction}
 * @param argIndex the arg index
 * @return the new generated name
 */
public String generateNewVarName(String baseName, int argIndex) {
    if (getVariable(baseName) == null && argIndex == 0) {
        return baseName;
    }
    // Fix: removed dead code from the original — the "count == 0" ternary always produced "",
    // and the post-loop "converged on already generated variable" throw was unreachable
    // (the loop only exits when getVariable(name) == null). Behavior is unchanged.
    String suffix = argIndex > 0 ? ":" + argIndex : "";
    String name = baseName + suffix;
    int count = 0;
    while (getVariable(name) != null) {
        name = baseName + "_" + (++count) + suffix;
    }
    return name;
}
/**
 * LSTM unit.
 *
 * @param baseName      the base name for outputs
 * @param configuration the configuration to use
 * @return the first output variable of the LSTM cell
 */
public SDVariable lstm(String baseName, LSTMCellConfiguration configuration) {
    // outputVariables(baseName) names all outputs from baseName; only the first is returned here
    return new LSTMCell(this, configuration).outputVariables(baseName)[0];
}
/**
 * An SRU cell.
 *
 * @param configuration the configuration for the SRU cell
 * @return the first output variable of the SRU cell
 */
public SDVariable sruCell(SRUCellConfiguration configuration) {
    return new SRUCell(this, configuration).outputVariables()[0];
}

/**
 * Simple recurrent unit.
 *
 * @param configuration the configuration for the SRU
 * @return the first output variable of the SRU
 */
public SDVariable sru(SRUConfiguration configuration) {
    return new SRU(this, configuration).outputVariables()[0];
}

/**
 * The GRU cell.
 *
 * @param configuration the configuration to use
 * @return the first output variable of the GRU cell
 */
public SDVariable gru(GRUCellConfiguration configuration) {
    return new GRUCell(this, configuration).outputVariables()[0];
}

/**
 * An SRU cell.
 *
 * @param baseName      the base name to use for the output variables
 * @param configuration the configuration for the SRU cell
 * @return the first output variable of the SRU cell
 */
public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) {
    return new SRUCell(this, configuration).outputVariables(baseName)[0];
}

/**
 * Simple recurrent unit.
 *
 * @param baseName      the base name to use for output variables
 * @param configuration the configuration for the SRU
 * @return the first output variable of the SRU
 */
public SDVariable sru(String baseName, SRUConfiguration configuration) {
    return new SRU(this, configuration).outputVariables(baseName)[0];
}

/**
 * The GRU cell.
 *
 * @param baseName      the base name for the GRU cell
 * @param configuration the configuration to use
 * @return the first output variable of the GRU cell
 */
public SDVariable gru(String baseName, GRUCellConfiguration configuration) {
    return new GRUCell(this, configuration).outputVariables(baseName)[0];
}
/**
 * @see #slice(String, SDVariable, int[], int[])
 */
public SDVariable slice(SDVariable input, int[] begin, int[] size) {
    // auto-generated output name
    return slice(null, input, begin, size);
}
/**
 * Get a subset of the specified input, by specifying the first element and the size of the array.
 * For example, if input is:
 * [a, b, c]
 * [d, e, f]
 * then slice(input, begin=[0,1], size=[2,1]) will return:
 * [b]
 * [e]
 * <p>
 * Note that for each dimension i, begin[i] + size[i] &lt;= input.size(i)
 *
 * @param name  Output variable name
 * @param input Variable to get subset of
 * @param begin Beginning index. Must be same length as rank of input array
 * @param size  Size of the output array. Must be same length as rank of input array
 * @return Subset of the input
 */
public SDVariable slice(String name, SDVariable input, int[] begin, int[] size) {
    final SDVariable sliced = f().slice(input, begin, size);
    return updateVariableNameAndReference(sliced, name);
}
/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[])
 */
public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] strides) {
    // auto-generated output name
    return stridedSlice(null, input, begin, end, strides);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[])
 */
public SDVariable stridedSlice(String name, SDVariable input, int[] begin, int[] end, int[] strides) {
    // all five masks zero: plain begin/end/stride semantics
    return stridedSlice(name, input, begin, end, strides, 0, 0, 0, 0, 0);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[])
 */
public SDVariable stridedSlice(SDVariable input, long[] begin, long[] end, long[] strides) {
    // auto-generated output name
    return stridedSlice(null, input, begin, end, strides);
}

/**
 * Get a subset of the specified input, by specifying the first element, last element, and the strides.
 * For example, if input is:
 * [a, b, c]
 * [d, e, f]
 * [g, h, i]
 * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1]) will return:
 * [b, c]
 * [h, i]
 *
 * @param name    Output variable name
 * @param input   Variable to get subset of
 * @param begin   Beginning index. Must be same length as rank of input array
 * @param end     End index. Must be same length as the rank of the array
 * @param strides Stride ("step size") for each dimension. Must be same length as the rank of the array. For example,
 *                stride of 2 means take every second element.
 * @return Subset of the input
 */
public SDVariable stridedSlice(String name, SDVariable input, long[] begin, long[] end, long[] strides) {
    // all five masks zero: plain begin/end/stride semantics
    return stridedSlice(name, input, begin, end, strides, 0, 0, 0, 0, 0);
}

/**
 * Get a subset of the specified input, by specifying the first element, last element, and the strides.
 * Operates as described in {@link #stridedSlice(SDVariable, long[], long[], long[])} with some extra mask arrays
 * as described below.
 *
 * @param name           Output variable name
 * @param in             Variable to get subset of
 * @param begin          Beginning index
 * @param end            End index
 * @param strides        Stride ("step size") for each dimension. For example,
 *                       stride of 2 means take every second element.
 * @param beginMask      Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored,
 *                       and a value of 0 is used instead for the beginning index for that dimension
 * @param endMask        Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored,
 *                       and a value of size(i)-1 is used instead for the end index for that dimension
 * @param ellipsisMask   Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other
 *                       dimensions are inserted as required at the specified position
 * @param newAxisMask    Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and
 *                       a size 1 dimension is inserted at this point
 * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and
 *                       a size 1 dimension is removed at this point. Note that begin/end/stride values must
 *                       result in a size 1 output for these dimensions
 * @return A subset of the input array
 */
public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, long[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    SDVariable ret = f().stridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int)
 */
public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    // auto-generated output name
    return stridedSlice(null, in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
}

/**
 * int[] variant of {@link #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int)};
 * see that overload for mask semantics.
 */
public SDVariable stridedSlice(String name, SDVariable in, int[] begin, int[] end, int[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    SDVariable ret = f().stridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int)
 */
public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    // auto-generated output name
    return stridedSlice(null, in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
}
/**
 * @see #scatterAdd(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) {
    // auto-generated output name
    return scatterAdd(null, ref, indices, updates);
}
/**
 * Scatter addition operation.
 * <ul>
 *   <li>indices rank 0 (scalar): out[index, ...] += updates[...]</li>
 *   <li>indices rank 1 (vector): for each position i, out[indices[i], ...] += updates[i, ...]</li>
 *   <li>indices rank 2+: for each position (i,...,k), out[indices[i], ..., indices[k], ...] += updates[i, ..., k, ...]</li>
 * </ul>
 * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to add to the initial/source array
 * @return The updated variable
 */
public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    final SDVariable scattered = f().scatterAdd(ref, indices, updates);
    return updateVariableNameAndReference(scattered, name);
}
/**
 * @see #scatterMul(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) {
    // auto-generated output name
    return scatterMul(null, ref, indices, updates);
}

/**
 * Scatter multiplication operation.
 * If indices is rank 0 (a scalar), then out[index, ...] *= updates[...]
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] *= updates[i, ...]
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] *= updates[i, ..., k, ...]
 * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to multiply into the initial/source array
 * @return The updated variable
 */
public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    SDVariable ret = f().scatterMul(ref, indices, updates);
    return updateVariableNameAndReference(ret, name);
}
/**
 * @see #scatterSub(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) {
    // auto-generated output name
    return scatterSub(null, ref, indices, updates);
}

/**
 * Scatter subtraction operation.
 * If indices is rank 0 (a scalar), then out[index, ...] -= updates[...]
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] -= updates[i, ...]
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] -= updates[i, ..., k, ...]
 * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to subtract from the initial/source array
 * @return The updated variable
 */
public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    SDVariable ret = f().scatterSub(ref, indices, updates);
    return updateVariableNameAndReference(ret, name);
}
/**
 * Scatter division operation, with an auto-generated name for the output variable.
 *
 * @see #scatterDiv(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterDiv(null, ref, indices, updates);
}
/**
 * Scatter division operation.
 * If indices is rank 0 (a scalar), then out[index, ...] /= updates[...]
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] /= updates[i, ...]
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] /= updates[i, ..., k, ...]
 * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to divide the initial/source array entries by
 * @return The updated variable
 */
public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
SDVariable ret = f().scatterDiv(ref, indices, updates);
return updateVariableNameAndReference(ret, name);
}
/**
 * Scatter max operation, with an auto-generated name for the output variable.
 *
 * @see #scatterMax(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterMax(null, ref, indices, updates);
}
/**
 * Scatter max operation.
 * If indices is rank 0 (a scalar), then out[index, ...] = max(updates[...], in[index,...])
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = max(updates[i,...], in[indices[i],...])
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = max(updates[i, ..., k, ...], in[indices[i], ..., indices[k], ...]
 * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to take the element-wise maximum against the initial/source array
 * @return The updated variable
 */
public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
SDVariable ret = f().scatterMax(ref, indices, updates);
return updateVariableNameAndReference(ret, name);
}
/**
 * Scatter min operation, with an auto-generated name for the output variable.
 *
 * @see #scatterMin(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterMin(null, ref, indices, updates);
}
/**
 * Scatter min operation.
 * If indices is rank 0 (a scalar), then out[index, ...] = min(updates[...], in[index,...])
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = min(updates[i,...], in[indices[i],...])
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = min(updates[i, ..., k, ...], in[indices[i], ..., indices[k], ...]
 * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to take the element-wise minimum against the initial/source array
 * @return The updated variable
 */
public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
SDVariable ret = f().scatterMin(ref, indices, updates);
return updateVariableNameAndReference(ret, name);
}
/**
 * Scatter update operation, with an auto-generated name for the output variable.
 *
 * @see #scatterUpdate(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterUpdate(null, ref, indices, updates);
}
/**
 * Scatter update operation: overwrites the referenced slices with the update values.
 * If indices is rank 0 (a scalar), then out[index, ...] = updates[...]
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = updates[i, ...]
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = updates[i, ..., k, ...]
 * Note that if multiple indices refer to the same location, the output at those locations is undefined - different
 * updates may occur in different orders
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Values to write into the initial/source array at the specified indices
 * @return The updated variable
 */
public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
SDVariable ret = f().scatterUpdate(ref, indices, updates);
return updateVariableNameAndReference(ret, name);
}
/**
 * Matrix trace operation, with an auto-generated name for the output variable.
 *
 * @see #trace(String, SDVariable)
 */
public SDVariable trace(SDVariable in){
    return trace(null, in);
}
/**
 * Matrix trace operation
 * For rank 2 matrices, the output is a scalar with the trace - i.e., sum of the main diagonal.
 * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
 *
 * @param name Name of the output variable. May be null.
 * @param in   Input variable
 * @return Trace
 */
public SDVariable trace(String name, SDVariable in){
SDVariable ret = f().trace(in);
return updateVariableNameAndReference(ret, name);
}
/**
 * Generate (or resolve) the output variables for the given op, register them as the op's
 * outgoing variables, and return them.
 * <p>
 * Three paths are taken, based on what {@code function.calculateOutputShape()} returns:
 * (1) shape unknown + CustomOp: create one placeholder variable per declared output;
 * (2) shape unknown + BaseOp: legacy xyz ops always have exactly one output;
 * (3) shapes known: create/update one variable per output shape.
 *
 * @param function the function to generate the output variables for
 * @param baseName base name for the output variables; may be null/empty, in which case the
 *                 function's registered base name or op name is used
 * @return the output variables, one per output of the function
 */
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName) {
//xyz ops only have 1 output
//if there is already a base name defined, use that
// NOTE(review): '&&' binds tighter than '||', so this parses as
//   baseName == null || (baseName.isEmpty() && getBaseNameForFunction(function) != null)
// When baseName == null and no base name is registered, the assignment below sets baseName
// back to null; the following null-check then falls back to opName(). Confirm this
// precedence is intentional rather than (null || empty) && registered-name-exists.
if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null)
baseName = getBaseNameForFunction(function);
if (baseName == null)
baseName = function.opName();
val outputShape = function.calculateOutputShape();
if (outputShape == null || outputShape.isEmpty()) {
if (function instanceof CustomOp) {
CustomOp customOp = (CustomOp) function;
//can't guess number of outputs, variable
int num_outputs = function.getNumOutputs(); //Use this in preference - if set. Descriptor might specify 2, but it can sometimes be 2+
if (num_outputs <= 0) {
// Fall back to the op descriptor's declared output count
val descriptor = customOp.getDescriptor();
if (descriptor != null) {
num_outputs = descriptor.getNumOutputs();
}
if (num_outputs <= 0) {
throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op "
+ function.getOwnName() + " - " + function.getClass().getSimpleName() + ". Ops can override" +
" getNumOutputs() to specify number of outputs if required");
}
}
// Default to 'c' ordering unless the first input argument has an array to copy ordering from
char ordering = 'c';
SDVariable[] args = function.args();
if (args != null && args.length > 0 && args[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye
ordering = function.args()[0].getArr().ordering();
}
SDVariable[] ret = new SDVariable[num_outputs];
//dynamic shapes
//When importing from TF: convention seem to be names like "unstack", "unstack:1", "unstack:2", ...
//TODO validate this!
for (int i = 0; i < ret.length; i++) {
// Reuse an existing variable named baseName (or baseName:i for i > 0) if present
SDVariable var = (i == 0 ? getVariable(baseName) : getVariable(baseName + ":" + i));
if (var == null) {
//Generate new variable name if one with the specified name doesn't exist
var = var(generateNewVarName(baseName, i), null, new ZeroInitScheme(ordering));
}
var.setOutputIndex(i);
var.setCreator(function);
ret[i] = var;
}
//Update the internal state: outgoing variables for function
if (getOutputsForFunction(function) == null)
addOutgoingFor(ret, function);
return ret;
}
//this is for unresolved shapes, we know xyz is always 1 output
// NOTE(review): if outputShape is null (not just empty) and function is a BaseOp,
// outputShape.isEmpty() here throws a NullPointerException. The isEmpty() re-check is
// redundant for the empty case since the enclosing branch already established it.
else if (function instanceof BaseOp && outputShape.isEmpty()) {
SDVariable[] ret = new SDVariable[1];
SDVariable checkGet = getVariable(baseName);
char ordering = 'c';
SDVariable[] args = function.args();
if (args != null && args.length > 0 && function.args()[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye
ordering = function.args()[0].getArr().ordering();
}
if (checkGet == null) {
checkGet = var(baseName, null, new ZeroInitScheme(ordering));
} else if (!importedVarName.contains(baseName)) {
//need to find a new name
String newName = generateNewVarName(baseName, 0);
checkGet = var(newName, null, new ZeroInitScheme(ordering));
}
// NOTE(review): appears unreachable - checkGet is assigned a non-null value above unless
// var(...) itself can return null; kept as a defensive fallback.
if (checkGet == null) {
checkGet = var(baseName, null, new ZeroInitScheme(ordering));
}
checkGet.setOutputIndex(0);
checkGet.setCreator(function);
ret[0] = checkGet;
//Update the internal state: outgoing variables for function
if (getOutputsForFunction(function) == null)
addOutgoingFor(ret, function);
return ret;
}
}
// Path (3): output shapes are known - create or update one variable per output shape
char ordering = 'c';
if (function.args() != null && function.args().length > 0 && function.args()[0].getArr() != null) {
ordering = function.args()[0].getArr().ordering();
}
SDVariable[] ret = new SDVariable[outputShape.size()];
// ownName/baseName will be used to get variables names
val ownName = function.getOwnName();
val rootName = baseName;
for (int i = 0; i < ret.length; i++) {
val shape = outputShape.get(i);
// it should be: rootName:index. i.e.: split:1, split:2, split:3, split:4 etc
baseName = rootName + (i > 0 ? ":" + i : "");
SDVariable checkGet = getVariable(baseName);
if (checkGet == null) {
// obviously - there's no such var, just add it
checkGet = var(baseName, shape, new ZeroInitScheme(ordering));
} else if (shape != null && !shapeAlreadyExistsForVarName(checkGet.getVarName())) {
// var exists, let's update its shape
putShapeForVarName(checkGet.getVarName(), shape);
} else if (shape != null && shapeAlreadyExistsForVarName(checkGet.getVarName())) {
// no-op.
// TODO: maybe we should check shapes equality here?
// it's either var that already exist, or something bad happening
} else if (!importedVarName.contains(baseName)) {
// FIXME: dead end. it's impossible to get here with null as shape
//need to find a new name
int count = 1;
String name = baseName + "_" + count + (i > 0 ? ":" + i : "");
while (getVariable(name) != null) {
count++;
name = baseName + "_" + count + (i > 0 ? ":" + i : "");
}
// NOTE(review): the while-loop above only exits when getVariable(name) == null,
// so this throw is unreachable; kept as a defensive invariant check.
if (getVariable(name) != null) {
throw new ND4JIllegalStateException("Converged on already generated variable!");
}
checkGet = var(name, shape, new ZeroInitScheme(ordering));
}
if (checkGet == null) {
checkGet = var(baseName + (i > 0 ? ":" + i : ""), shape, new ZeroInitScheme(ordering));
}
checkGet.setOutputIndex(i);
checkGet.setCreator(function);
ret[i] = checkGet;
}
return ret;
}
/**
 * Generate the output variables for the given op, using the op's own name as the base
 * name for the generated variables.
 *
 * @param function the function to generate the output variables for
 * @return the output variables, one per output of the function
 * @see #generateOutputVariableForOp(DifferentialFunction, String)
 */
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) {
    return generateOutputVariableForOp(function, function.opName());
}
/**
 * Look up a previously defined SameDiff function (sub-graph) by name.
 *
 * @param functionName the name the function was registered under
 * @return the SameDiff instance for that name, or null if no function with that name exists
 */
public SameDiff getFunction(String functionName) {
    SameDiff fn = sameDiffFunctionInstances.get(functionName);
    return fn;
}
/**
 * Execute the specified ops in order and return the output array of the last one.
 *
 * @param ops Ops to execute
 * @return the z (result) array of the final op in the list, after execution
 */
public INDArray execAndEndResult(List ops) {
    List executed = exec(ops);
    Op lastOp = (Op) executed.get(executed.size() - 1);
    return lastOp.z();
}
/**
 * Execute the graph using the current arrays/state and return the array for the final variable
 * in the graph. After execution, the arrays for other variables can be obtained using
 * {@link #getArrForVarName(String)} or {@link SDVariable#getArr()}.
 *
 * Note: if the final operation has multiple output variables, use {@link #execAndEndResults()}
 * instead - this method throws in that case.
 *
 * @return the output of the final operation in the graph after execution
 */
public INDArray execAndEndResult() {
    List executed = exec().getRight();
    val lastOp = executed.get(executed.size() - 1);
    val outVars = lastOp.outputVariables();
    if (outVars.length > 1) {
        throw new ND4JIllegalStateException(lastOp.opName() + " has multiple outputs. Use execAndEndResults instead.");
    }
    return outVars[0].getArr();
}
/**
 * Execute the graph using the current arrays/state and return the array(s) for the final
 * operation in the graph. After execution, the arrays for other variables can be obtained
 * using {@link #getArrForVarName(String)} or {@link SDVariable#getArr()}.
 *
 * @return the outputs of the final operation in the graph, after execution
 */
public INDArray[] execAndEndResults() {
    List executed = exec().getRight();
    val lastOp = executed.get(executed.size() - 1);
    val outVars = lastOp.outputVariables();
    INDArray[] results = new INDArray[outVars.length];
    for (int j = 0; j < results.length; j++) {
        results[j] = outVars[j].getArr();
    }
    return results;
}
/**
 * Execute the graph using the current arrays/state and return the array for one specific
 * output of the final operation in the graph. After execution, the arrays for other variables
 * can be obtained using {@link #getArrForVarName(String)} or {@link SDVariable#getArr()}.
 *
 * @param outputIndex index of the output variable of the final op to return
 * @return the array for the requested output of the final operation in the graph
 */
public INDArray execAndEndResult(int outputIndex) {
    List executed = exec().getRight();
    val outVar = executed.get(executed.size() - 1).outputVariables()[outputIndex];
    return outVar.getArr();
}
/**
 * Execute this graph via the native graph executioner, returning the last array in the result map.
 * On first use, the graph is serialized to flatbuffers and registered with the executioner;
 * registration happens at most once (double-checked locking on the wasRegistered flag,
 * synchronized on this instance).
 *
 * @param inputs map of input values; keys are looked up in variableMap and re-keyed to the
 *               graph-internal variable names before execution
 * @return the last value of the executioner's result map
 *         NOTE(review): iteration order of that map determines which value is "last" -
 *         confirm the executioner returns an order-preserving map.
 * @throws ND4JIllegalStateException if execution produces no results
 */
public INDArray yetAnotherExecMethod(@NonNull Map inputs) {
if (!wasRegistered.get()) {
synchronized (this) {
if (!wasRegistered.get()) {
// First call: serialize the graph and register it with the native executioner exactly once
val bb = asFlatBuffers();
val ptr = new BytePointer(bb);
Nd4j.getExecutioner().registerGraph(this.hashCode(), ptr);
wasRegistered.set(true);
}
}
}
// Re-key the caller-supplied inputs by the graph-internal variable names
val newMap = new LinkedHashMap();
val keySet = inputs.keySet();
for (val key : keySet) {
val vx = variableMap.get(key);
newMap.put(vx.getVarName(), inputs.get(key));
}
val result = Nd4j.getExecutioner().executeGraph(this.hashCode(), newMap, this.reverseMap);
if (result.size() == 0)
throw new ND4JIllegalStateException("Execution failed");
val list = new ArrayList(result.values());
return list.get(list.size() - 1);
}
/**
 * Executes the given list of already-created operations in order.
 * This exec method only invokes operations; it does not create them.
 *
 * @param ops the list of already created ops
 * @return the same list that was passed in
 */
public List exec(List ops) {
    for (Object op : ops) {
        Nd4j.getExecutioner().exec((Op) op);
    }
    return ops;
}
/**
 * Look up a TensorList by name.
 *
 * @param name name of the list; may not be null
 * @return the list registered under that name, or null if none exists
 */
public TensorList getListByName(@NonNull String name) {
    TensorList list = lists.get(name);
    return list;
}
/**
 * Register a TensorList under the given name, replacing any previous list with that name.
 *
 * @param name name to register the list under; may not be null
 * @param list the list to register
 */
public void putListByName(@NonNull String name, TensorList list) {
    lists.put(name, list);
}
/**
 * Creates a while statement (loop construct) in this graph, with a randomly generated block name.
 *
 * @param sameDiffConditional predicate that decides whether the loop continues
 * @param conditionBody       sub-graph definition that computes the loop condition
 * @param loopBody            sub-graph definition executed on each iteration (the loop body)
 * @param inputVars           variables fed into the loop
 * @return the constructed While block
 */
public While whileStatement(SameDiffConditional sameDiffConditional,
                            SameDiffFunctionDefinition conditionBody,
                            SameDiffFunctionDefinition loopBody
        , SDVariable[] inputVars) {
    return While.builder()
            .parent(this)
            .predicate(sameDiffConditional)
            .condition(conditionBody)
            .trueBody(loopBody)
            .inputVars(inputVars)
            .blockName("while-" + UUID.randomUUID().toString())
            .build();
}
/**
 * Creates an if statement (conditional construct) in this graph, with a randomly generated block name.
 *
 * @param conditional   predicate that selects between the true and false branches
 * @param conditionBody sub-graph definition that computes the condition
 * @param trueBody      sub-graph executed when the condition holds
 * @param falseBody     sub-graph executed when the condition does not hold
 * @param inputVars     variables fed into the conditional
 * @return the constructed If block
 */
public If ifStatement(SameDiffConditional conditional,
                      SameDiffFunctionDefinition conditionBody,
                      SameDiffFunctionDefinition trueBody,
                      SameDiffFunctionDefinition falseBody
        , SDVariable[] inputVars) {
    return If.builder()
            .parent(this)
            .predicate(conditional)
            .conditionBody(conditionBody)
            .trueBody(trueBody)
            .falseBody(falseBody)
            .inputVars(inputVars)
            .blockName("if-" + UUID.randomUUID().toString())
            .build();
}
/**
 * Create a new TensorArray attached to this SameDiff instance.
 *
 * @return a fresh TensorArrayV3 backed by this graph
 */
public TensorArrayV3 tensorArray() {
    TensorArrayV3 tensorArray = new TensorArrayV3(this);
    return tensorArray;
}
/**
 * Invoke the named SameDiff function's graph on the given SameDiff instance.
 *
 * @param functionName name of a function previously registered via defineFunction
 * @param with         the SameDiff instance to invoke the function's graph on
 * @return the variable produced by invoking the function's graph
 */
public SDVariable invokeFunctionOn(String functionName, SameDiff with) {
    SameDiff fn = sameDiffFunctionInstances.get(functionName);
    return fn.invokeGraphOn(with);
}
/**
 * Define (or fetch) a named SameDiff function (sub-graph) with explicit input variables.
 * If a function with this name already exists, it is returned unchanged; otherwise a new
 * sub-graph is created, the given variables are mirrored into it as its inputs, and the
 * definition is evaluated to populate its outputs.
 *
 * Side effect: this.child is set to the new sub-graph while it is being built, and cleared
 * (set to null) before returning.
 *
 * @param function           name to register the function under
 * @param functionDefinition definition used to populate the sub-graph
 * @param variables          variables to mirror into the sub-graph as inputs
 * @return the SameDiff instance registered under the given name
 */
public SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables) {
    if (!sameDiffFunctionInstances.containsKey(function)) {
        SameDiff sub = SameDiff.create();
        sub.workspace = (workspace);
        this.child = sub;
        sub.parent = this;
        //setup subgraph
        //re execute to populate subgraph
        SDVariable[] subInputs = new SDVariable[variables.length];
        for (int j = 0; j < subInputs.length; j++) {
            subInputs[j] = sub.var(variables[j]);
        }
        sub.inputs = subInputs;
        sub.outputs = functionDefinition.define(sub, null, subInputs);
        sameDiffFunctionInstances.put(function, sub);
    }
    this.child = null;
    return sameDiffFunctionInstances.get(function);
}
/**
 * Define a named SameDiff function (sub-graph) with no explicit inputs.
 *
 * @param function           name to register the function under
 * @param functionDefinition definition used to populate the sub-graph
 * @see #defineFunction(String, SameDiffFunctionDefinition, Map)
 */
public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition) {
    defineFunction(function, functionDefinition, new LinkedHashMap());
}
/**
 * Define a named SameDiff function (sub-graph) with a map of inputs.
 * If a function with this name already exists, this is a no-op - the existing
 * definition is NOT overwritten.
 *
 * @param function           name to register the function under
 * @param functionDefinition definition used to populate the sub-graph
 * @param inputs             inputs passed through to the definition
 */
public void defineFunction(String function,
                           SameDiffFunctionDefinition functionDefinition,
                           Map inputs) {
    if (sameDiffFunctionInstances.containsKey(function)) {
        return; // already defined under this name
    }
    SameDiff sub = SameDiff.create();
    sub.workspace = (workspace);
    //setup subgraph
    //re execute to populate subgraph
    functionDefinition.define(sub, inputs, null);
    sameDiffFunctionInstances.put(function, sub);
}
/**
 * Execute the named SameDiff function instance and return its final output.
 *
 * @param functionName the name of the SameDiff function instance to invoke
 * @return output of the final variable of that function after execution
 */
public INDArray execAndEndResult(String functionName) {
    SameDiff fn = sameDiffFunctionInstances.get(functionName);
    return fn.execAndEndResult();
}
/**
* Execute the specified SameDiff function instance
*
* @param functionName the name of the SameDiff function instance to invoke
* @return
*/
public Pair