// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// org.nd4j.autodiff.samediff.SameDiff Maven / Gradle / Ivy
//
// There is a newer version: 1.0.0-M2.1
// Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.autodiff.samediff;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import com.rits.cloning.Cloner;
import com.rits.cloning.IFastCloner;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.bytedeco.javacpp.BytePointer;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.functions.FunctionProperties;
import org.nd4j.autodiff.samediff.flow.FlowPath;
import org.nd4j.autodiff.util.cloner.DataBufferFastCloner;
import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.*;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.distances.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.accum.distances.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration;
import org.nd4j.linalg.api.ops.impl.shape.Eye;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayV3;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.temp.ExternalErrorsFunction;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.collection.IntArrayKeyMap;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.lossfunctions.impl.*;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.list.compat.TensorList;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;

import java.io.*;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * SameDiff is the
 * entrypoint for
 * nd4j's autodiff.
 *
 * You define a graph symbolically.
 *
 * That graph accumulates operations.
 *
* In order to execute the graph, you run:
 * {@link #exec()} to get all the operations,
 * {@link #exec(List)} for an already created set of ops,
 * {@link #execAndEndResult()} for the end result only,
 * {@link #execAndEndResult(List)} for a cached set of ops.
 */
@AllArgsConstructor
@Builder
@Slf4j
public class SameDiff {

    // NOTE(review): the generic type parameters on the Map/Set/List/ThreadLocal/Table
    // fields below were lost when this source was extracted from HTML (e.g. "Map>" was
    // originally a nested generic such as Map<String, List<...>>). The trailing comments
    // record the intended key/value types where stated; the raw types will not compile
    // as-is -- restore the parameters against the original artifact before building.

    private Map incomingArgsReverse; //Key: DifferentialFunction.getOwnName(). Value: name of SDVariables as inputs to that function
    private Map outgoingArgsReverse; //Key: DifferentialFunction.getOwnName(). Value: name of SDVariables as outputs from that function
    private Map permuteOrder;
    private boolean shouldBootStrap = true;
    private Set importedVarName;

    //map a function's instance id to a base name, used for propagating variable names
    //for output during import
    private Map baseNameForFunctionInstanceId;

    private DifferentialFunctionFactory functionFactory;
    private Map variableMap;            //Key: SDVariable name. Value: SDVariable
    private Map variableNameToShape;    //Key: SDVariable name. Value: shape for that variable

    //gradient information
    private Map gradients;              //Key: (original comment truncated -- presumably variable name -> gradient variable; TODO confirm)
    private Map forwardVarForGrad;

    private Map variableNameToArr;      //Key: name of SDVariable. Value: Array for that variable

    //individual index for variable names
    private Map> functionsArgsFor;  //Key: SDVariable name. Value: all DifferentialFunctions it is an input to
    private Map> functionOutputFor; //Key: SDVariable name. Value: DifferentialFunctions this variable is an output for (TODO: Why is this a list? Isn't it *always* length 1?)

    private Map lists = new HashMap<>(); // Key - node name; Value - TensorList

    // this entity holds runtime information for Switch/Merge/NextIteration etc stuff
    private transient ThreadLocal localFlowPath = new ThreadLocal();

    // here we save String -> Integer conversion to variables
    private transient Map reverseMap = null;

    // counter for auto-naming variables
    private int variableId = 0;

    /**
     * For import, many times we have variables
     * that map to properties. Most commonly
     * we will have an input to a function that is mapped to an ndarray.
     * That ndarray is usually a scalar shape.
     *
     * That array with a scalar shape can be something like an axis.
     *
     * We often don't know that array's value till run time.
     * This map stores variable names that we should resolve
     * from samediff. We use the value of that array
     * to update the properties.
     */
    private Map> propertiesToResolve;

    /**
     * A map of own name to
     * the properties of the function (things like execution axes etc)
     * The valid values can be:
     * int
     * long
     * INDArray
     */
    private Map> propertiesForFunction;

    private Map> placeHolderMap;
    private Map placeHolderOriginalShapes;
    private Set placeHolderVarNames;
    private MemoryWorkspace workspace;
    private Map sameDiffFunctionDefinitionMap;
    private Map sameDiffFunctionInstances;
    private Set placeHolderFunctions;
    private static Cloner cloner = newCloner();
    private static Map opMethods;

    @Getter
    private Map functionInstancesById;

    private Table fieldVariableResolutionMapping;

    // flag, shows if graph was already registered with libnd4j
    private transient AtomicBoolean wasRegistered = new AtomicBoolean(false);

    //debug mode variables
    @Getter
    private boolean debugMode;
    private Map opsForResult;
    private boolean resolvedVariables = false;

    @Getter
    @Setter
    boolean logExecution = true;

    @Getter
    private SameDiff parent;

    @Getter
    private SameDiff child;

    // Cache every SDVariable-returning method of this class by name, so that
    // invoke(Op, ...) can dispatch reflectively on Op.opName().
    static {
        opMethods = new HashMap<>();
        Method[] methods = SameDiff.class.getDeclaredMethods();
        for (Method method : methods) {
            if (method.getReturnType().equals(SDVariable.class)) {
                opMethods.put(method.getName(), method);
            }
        }
    }

    /**
     * @return New cloner object.
NOTE: INTENDED FOR DEVELOPER USE ONLY
     */
    public static Cloner newCloner() {
        Cloner cloner = new Cloner();

        //Implement custom cloning for INDArrays (default can have problems with off-heap and pointers)
        //Sadly: the cloner library does NOT support interfaces here, hence we need to use the actual classes
        //cloner.registerFastCloner(INDArray.class, new INDArrayFastCloner());  //Does not work due to interface
        IFastCloner fc = new INDArrayFastCloner();
        cloner.registerFastCloner(Nd4j.getBackend().getNDArrayClass(), fc);

        //Same thing with DataBuffers: off heap -> cloner library chokes on them, but need to know the concrete
        // buffer classes, not just the interface
        IFastCloner fc2 = new DataBufferFastCloner();
        DataBufferFactory d = Nd4j.getDataBufferFactory();
        doReg(cloner, fc2, d.intBufferClass());
        doReg(cloner, fc2, d.longBufferClass());
        doReg(cloner, fc2, d.halfBufferClass());
        doReg(cloner, fc2, d.floatBufferClass());
        doReg(cloner, fc2, d.doubleBufferClass());
        doReg(cloner, fc2, CompressedDataBuffer.class);
        return cloner;
    }

    // Register the fast cloner for the given buffer class; skips null (a backend
    // may not provide every buffer type, e.g. half precision).
    private static void doReg(Cloner cl, IFastCloner fc, Class c) {
        if (c != null)
            cl.registerFastCloner(c, fc);
    }

    /**
     * Rename a variable: replaces the map entry and rewrites the old name in every
     * internal index that stores it (function input/output name arrays, array/shape/
     * gradient/placeholder maps, and legacy BaseOp x/y/z vertex ids).
     *
     * @param varName  the current variable name to rename
     * @param withName the new name
     */
    public void updateVariableName(String varName, String withName) {
        SDVariable oldVarNameRef = getVariable(varName);
        variableMap.remove(oldVarNameRef.getVarName());
        val oldVarName = varName;
        oldVarNameRef.setVarName(withName);
        variableMap.put(withName, oldVarNameRef);

        // Rewrite the name in-place wherever it appears as a function output name...
        for (val reverseValues : outgoingArgsReverse.entrySet()) {
            for (int i = 0; i < reverseValues.getValue().length; i++) {
                if (reverseValues.getValue()[i].equals(oldVarName)) {
                    reverseValues.getValue()[i] = withName;
                }
            }
        }
        // ...and wherever it appears as a function input name
        for (val reverseValues : incomingArgsReverse.entrySet()) {
            for (int i = 0; i < reverseValues.getValue().length; i++) {
                if (reverseValues.getValue()[i].equals(oldVarName)) {
                    reverseValues.getValue()[i] = withName;
                }
            }
        }

        // Re-key every map that uses the variable name as its key
        if (variableNameToArr.containsKey(oldVarName)) {
            val arr = variableNameToArr.remove(oldVarName);
            variableNameToArr.put(withName, arr);
        }

        if (variableNameToShape.containsKey(oldVarName)) {
            val shape = variableNameToShape.remove(oldVarName);
            variableNameToShape.put(withName, shape);
        }

        if (gradients.containsKey(oldVarName)) {
            val grad = gradients.remove(oldVarName);
            gradients.put(withName, grad);
        }

        if (forwardVarForGrad.containsKey(oldVarName)) {
            val forwardGrad = forwardVarForGrad.remove(oldVarName);
            forwardVarForGrad.put(withName, forwardGrad);
        }

        if (placeHolderMap.containsKey(oldVarName)) {
            val placeholders = placeHolderMap.remove(oldVarName);
            placeHolderMap.put(withName, placeholders);
        }

        // Legacy ops embed variable names as x/y/z vertex ids: patch those too,
        // for functions taking this variable as input...
        if (functionsArgsFor.containsKey(oldVarName)) {
            val funcs = functionsArgsFor.remove(oldVarName);
            for (val func : funcs) {
                if (func instanceof BaseOp) {
                    BaseOp baseOp = (BaseOp) func;
                    if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldVarName)) {
                        baseOp.setXVertexId(withName);
                    }
                    if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldVarName)) {
                        baseOp.setYVertexId(withName);
                    }
                    if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldVarName)) {
                        baseOp.setZVertexId(withName);
                    }
                }
            }
            functionsArgsFor.put(withName, funcs);
        }

        // ...and for functions producing this variable as output
        if (functionOutputFor.containsKey(oldVarName)) {
            val funcs = functionOutputFor.remove(oldVarName);
            for (val func : funcs) {
                if (func instanceof BaseOp) {
                    BaseOp baseOp = (BaseOp) func;
                    if (baseOp.getXVertexId() != null && baseOp.getXVertexId().equals(oldVarName)) {
                        baseOp.setXVertexId(withName);
                    }
                    if (baseOp.getYVertexId() != null && baseOp.getYVertexId().equals(oldVarName)) {
                        baseOp.setYVertexId(withName);
                    }
                    if (baseOp.getZVertexId() != null && baseOp.getZVertexId().equals(oldVarName)) {
                        baseOp.setZVertexId(withName);
                    }
                }
            }
            functionOutputFor.put(withName, funcs);
        }

        variableMap.remove(oldVarName);
    }

    /**
     * Clears debugging state and disables debug mode.
*/
    public SameDiff disableDebugging() {
        debugMode = false;
        return this;
    }

    /**
     * Enables tracing of graphs automatically.
     */
    public SameDiff enableDebugMode() {
        debugMode = true;
        return this;
    }

    /**
     * Returns this samediff instance's {@link DifferentialFunctionFactory}
     *
     * @return the function factory for this instance
     */
    public DifferentialFunctionFactory f() {
        return functionFactory;
    }

    /**
     * Clone this graph's variables and functions onto the given SameDiff instance.
     *
     * @param sameDiff the target instance to receive the cloned graph
     * @return the last variable of the target instance after cloning
     */
    public SDVariable invokeGraphOn(SameDiff sameDiff) {
        //map the new vertices on to the old ones
        Map thisVertexIdToNew = new HashMap<>();
        int idx = 1;
        for (val var : variables()) {
            val clone = cloner.deepCloneDontCloneInstances(var, var.getSameDiff());
            val newVar = sameDiff.var(clone);
            if (var.getArr() != null) {
                sameDiff.associateArrayWithVariable(var.getArr(), newVar);
            }
            // NOTE(review): maps idx -> idx, so this table carries no information -- confirm intent
            thisVertexIdToNew.put(idx, idx);
            clone.setSameDiff(sameDiff);
            idx++;
        }

        val newFunctions = new LinkedHashMap();
        for (DifferentialFunction function : functionInstancesById.values()) {
            if (function instanceof SDVariable) {
                continue;
            }

            DifferentialFunction clone = cloner.deepCloneDontCloneInstances(
                    function,
                    function.getSameDiff());
            clone.setSameDiff(sameDiff);
            clone.setOwnName(function.getOwnName());
            // NOTE(review): putFunctionForId throws when the id already exists, yet it is only
            // invoked when functionExists(...) is true -- this guard looks inverted; confirm.
            if (sameDiff.functionExists(function.getOwnName()))
                sameDiff.putFunctionForId(function.getOwnName(), function);
            newFunctions.put(function.getOwnName(), clone);

            val argsForFunction = function.args();
            val outputsForFunction = function.outputVariables();

            //note that these have the same variable names
            sameDiff.addArgsFor(argsForFunction, clone);
            // NOTE(review): registers outputs against the ORIGINAL function, while args above
            // were registered against the clone -- likely should be `clone`; confirm.
            sameDiff.addOutgoingFor(outputsForFunction, function);

            for (val arg : clone.args()) {
                arg.setSameDiff(sameDiff);
            }

            for (val output : clone.outputVariables()) {
                output.setSameDiff(sameDiff);
            }

            // NOTE(review): stores the original `function`, not `clone`; confirm.
            sameDiff.functionInstancesById.put(function.getOwnName(), function);
        }

        return sameDiff.variables().get(sameDiff.variables().size() - 1);
    }

    /**
     * Returns true if the given function id exists
     *
     * @param id the function id to test for
     * @return true if the function id exists, false otherwise
     */
    public boolean functionExists(String id) {
        return functionInstancesById.containsKey(id);
    }

    /**
     * Get the function by the {@link DifferentialFunction#getOwnName()}
     *
     * @param id the id of the function
     * @return the function for the given id if it exists
     * @throws ND4JIllegalStateException if no function with the given id exists
     */
    public DifferentialFunction getFunctionById(String id) {
        if (!functionInstancesById.containsKey(id)) {
            throw new ND4JIllegalStateException("No function with id " + id + " found!");
        }
        return functionInstancesById.get(id);
    }

    /**
     * Put the function for the given id
     *
     * @param id       the id of the function
     * @param function the function
     * @throws ND4JIllegalStateException if the id already exists, or the function is a variable
     */
    public void putFunctionForId(String id, DifferentialFunction function) {
        if (functionInstancesById.containsKey(id)) {
            throw new ND4JIllegalStateException("Function by id already exists!");
        } else if (function instanceof SDVariable) {
            throw new ND4JIllegalStateException("Function must not be a variable!");
        }

        functionInstancesById.put(id, function);
    }

    /**
     * Returns the name(s) of the inputs for the given function
     *
     * @param function the function to get the inputs for
     * @return the input ids for a given function
     * @throws ND4JIllegalStateException if the function has no registered inputs
     */
    public String[] getInputsForFunction(DifferentialFunction function) {
        if (!incomingArgsReverse.containsKey(function.getOwnName()))
            throw new ND4JIllegalStateException("Illegal function instance id found " + function.getOwnName());
        return incomingArgsReverse.get(function.getOwnName());
    }

    /**
     * Returns the name(s) of the outputs for the given function
     *
     * @param function the function to get the outputs for
     * @return the output ids for a given function (null if none are registered)
     */
    public String[] getOutputsForFunction(DifferentialFunction function) {
        return outgoingArgsReverse.get(function.getOwnName());
    }

    /**
     * Get the output variable(s) for the specified differential function
     *
     * @param function the function reference to get the output variable(s) for
     * @return the output variables for the given function
     */
    public SDVariable[] getOutputVariablesForFunction(DifferentialFunction function) {
        val inputs = getOutputsForFunction(function);
        if (inputs == null) {
            // NOTE(review): message says "inputs", but these are the function's outputs
            throw new ND4JIllegalStateException("No inputs found for function " + function);
        }

        val vars = new SDVariable[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            vars[i] = getVariable(inputs[i]);
        }

        return vars;
    }

    /**
     * Get the input variable(s) for the specified differential function
     *
     * @param function the function reference to get the input variable(s) for
     * @return the input variables for the given function
     */
    public SDVariable[] getInputVariablesForFunction(DifferentialFunction function) {
        val inputs = getInputsForFunction(function);
        if (inputs == null) {
            throw new ND4JIllegalStateException("No inputs found for function " + function);
        }

        val vars = new SDVariable[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            vars[i] = getVariable(inputs[i]);
            if (vars[i] == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + i);
            }
        }

        return vars;
    }

    /**
     * Update the INDArray for the given variable. Note that the array must exist to use this method.
     *
     * @param varName Name of the variable to update the array for
     * @param arr     Array to update
     * @throws ND4JIllegalStateException when the array does not exist.
     * @see #putArrayForVarName(String, INDArray)
     * @see #putOrUpdateShapeForVarName(String, long[], boolean)
     */
    public void updateArrayForVarName(String varName, INDArray arr) {
        if (!variableNameToArr.containsKey(varName)) {
            // NOTE(review): message references "putArrayForVertexId"; the actual method is putArrayForVarName
            throw new ND4JIllegalStateException("Array for " + varName + " does not exist. Please use putArrayForVertexId instead.");
        }

        variableNameToArr.put(varName, arr);
    }

    /**
     * Adds an INDArray for a given variable name.
     * Use {@link #updateArrayForVarName(String, INDArray)} if the array already exists.
     *
     * @param varName the vertex id to add
     * @param arr     the array to add
     * @throws ND4JIllegalStateException when the array already exists.
* @see #putOrUpdateShapeForVarName(String, long[], boolean)
     */
    public void putArrayForVarName(String varName, INDArray arr) {
        if (varName == null)
            throw new ND4JIllegalStateException("No null names allowed!");

        if (variableNameToArr.containsKey(varName)) {
            throw new ND4JIllegalStateException("Array for " + varName + " already exists!");
        }

        variableNameToArr.put(varName, arr);
    }

    /**
     * Put the array if it does not exist for the given variable name, or update it if it does
     *
     * @param varName Variable name
     * @param arr     Array
     */
    public void putOrUpdateArrayForVarName(@NonNull String varName, INDArray arr) {
        if (variableNameToArr.containsKey(varName)) {
            updateArrayForVarName(varName, arr);
        } else {
            putArrayForVarName(varName, arr);
        }
    }

    /**
     * Get the shape for the given vertex id.
     * Note that if an array is defined, it will use the shape of the array instead.
     *
     * A shape *and* an array should not be defined at the same time.
     * This wastes memory. The internal map used for tracking shapes for particular
     * vertex ids should also delete redundant shapes stored to avoid redundant sources of information.
     *
     * @param varName the vertex id to get the shape for
     * @return the shape for the given vertex if any (null if neither array nor shape is stored)
     */
    public long[] getShapeForVarName(String varName) {
        // A concrete array, when present, is the authoritative source of shape
        if (variableNameToArr.containsKey(varName)) {
            return variableNameToArr.get(varName).shape();
        }
        return variableNameToShape.get(varName);
    }

    /**
     * Update a vertex id with the given shape.
     * Note that you should use {@link #putShapeForVarName(String, long[])} if you want to add a new shape.
     * Update is meant to be an in place replacement of the shape for the vertex id *only*.
     *
     * @param varName the vertex id to associate
     * @param shape   the shape to associate with
     * @see #putShapeForVarName(String, long[])
     * @see #putOrUpdateShapeForVarName(String, long[], boolean)
     */
    public void updateShapeForVarName(String varName, long[] shape) {
        updateShapeForVarName(varName, shape, false);
    }

    /**
     * Update a vertex id with the given shape.
     * Note that you should use {@link #putShapeForVarName(String, long[])} if you want to add a new shape.
     * Update is meant to be an in place replacement of the shape for the vertex id *only*.
     *
     * @param varName                   the vertex id to associate
     * @param shape                     the shape to associate with
     * @param clearArrayOnShapeMismatch boolean to indicate whether to clear the variable on shape mismatch
     * @see #putShapeForVarName(String, long[])
     * @see #putOrUpdateShapeForVarName(String, long[], boolean)
     */
    public void updateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch) {
        if (shape == null) {
            throw new ND4JIllegalStateException("Null shapes not allowed!");
        }

        if (variableNameToArr.containsKey(varName) && !Arrays.equals(variableNameToArr.get(varName).shape(), shape)) {
            if (clearArrayOnShapeMismatch) {
                if (log.isTraceEnabled()) {
                    log.trace("Clearing array for variable {}: array shape {}, new shape {}", varName,
                            Arrays.toString(variableNameToArr.get(varName).shape()), Arrays.toString(shape));
                }
                variableNameToArr.remove(varName);
            } else {
                throw new ND4JIllegalStateException("Already found an existing array for variable \"" + varName
                        + "\" with shape " + Arrays.toString(variableNameToArr.get(varName).shape())
                        + " - attempting to put new array shape " + Arrays.toString(shape));
            }
        }

        // Any dimension < 1 marks a placeholder shape: record it and defer resolution to exec time
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] < 1) {
                addAsPlaceHolder(varName);
                placeHolderOriginalShapes.put(varName, shape);
                return;
            }
        }

        if (log.isTraceEnabled()) {
            long[] pShape = variableNameToShape.get(varName);
            log.trace("Updated shape for variable \"{}\": previous shape {}, new shape {}", varName,
                    (pShape == null ? "" : Arrays.toString(pShape)), Arrays.toString(shape));
        }
        variableNameToShape.put(varName, shape);
    }

    /**
     * Associate a vertex id with the given shape.
* @param varName the vertex id to associate
     * @param shape   the shape to associate with
     * @see #putShapeForVarName(String, long[])
     * @see #putOrUpdateShapeForVarName(String, long[], boolean)
     */
    public void putShapeForVarName(String varName, long[] shape) {
        if (shape == null) {
            throw new ND4JIllegalStateException("Shape must not be null!");
        }

        if (variableNameToShape.containsKey(varName)) {
            throw new ND4JIllegalStateException("Shape for " + varName + " already exists!");
        }

        // Any dimension < 1 marks a placeholder shape: record it and defer resolution to exec time
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] < 1) {
                addAsPlaceHolder(varName);
                placeHolderOriginalShapes.put(varName, shape);
                return;
            }
        }

        variableNameToShape.put(varName, shape);
    }

    /**
     * Put or update the shape for the given variable name. Optionally supports clearing the specified variable's
     * INDArray if its shape does not match the new shape
     *
     * @param varName                   Variable name
     * @param shape                     Shape to put
     * @param clearArrayOnShapeMismatch If false: no change to arrays. If true: if an INDArray is defined for the specified
     *                                  variable name, it will be removed from the graph (to be later re-generated) if
     *                                  its shape does not match the specified shape
     */
    public void putOrUpdateShapeForVarName(String varName, @NonNull long[] shape, boolean clearArrayOnShapeMismatch) {
        if (variableNameToShape.containsKey(varName)) {
            updateShapeForVarName(varName, shape, clearArrayOnShapeMismatch);
        } else {
            putShapeForVarName(varName, shape);
        }
    }

    /**
     * Returns true if the given vertex id and shape already exist.
     *
     * @param varName the vertex id
     * @return true if either a shape or an ndarray is already stored for this name
     */
    public boolean shapeAlreadyExistsForVarName(String varName) {
        return variableNameToShape.containsKey(varName) || arrayAlreadyExistsForVarName(varName);
    }

    /**
     * Returns true if the given vertex id and {@link INDArray} already exist.
     *
     * @param varName the vertex id
     * @return true if a vertex with the given INDArray exists, and it has an INDArray associated with it
     */
    public boolean arrayAlreadyExistsForVarName(String varName) {
        return variableNameToArr.containsKey(varName);
    }

    /**
     * Get an {@link INDArray} for a given vertex id, or null if none exists
     *
     * @param varName Variable name to get the array for
     * @return Array, or null if none exists
     */
    public INDArray getArrForVarName(String varName) {
        return variableNameToArr.get(varName);
    }

    /**
     * Associate the array with the given variable.
     *
     * @param arr      the array to get the variable for
     * @param variable the name of the variable to associate the array with
     */
    public void associateArrayWithVariable(INDArray arr, @NonNull String variable) {
        associateArrayWithVariable(arr, this.getVariable(variable));
    }

    /**
     * Associate the array with the given variable.
     *
     * @param arr      the array to get the variable for
     * @param variable the variable to associate the array with
     * @throws ND4JIllegalArgumentException if either argument is null
     */
    public void associateArrayWithVariable(INDArray arr, SDVariable variable) {
        if (variable == null) {
            throw new ND4JIllegalArgumentException("Variable must not be null!");
        }
        if (arr == null) {
            throw new ND4JIllegalArgumentException("Array must not be null");
        }

        variableNameToArr.put(variable.getVarName(), arr);
        // Clear any previously stored array whose shape no longer matches
        putOrUpdateShapeForVarName(variable.getVarName(), arr.shape(), true);

        // invalidate exec cache
        // NOTE(review): exec_cache is not declared in the visible portion of this file -- confirm it exists
        exec_cache = null;

        //Also update nested SameDiff instances (such as gradient function)
        if (sameDiffFunctionInstances != null && sameDiffFunctionInstances.size() > 0) {
            for (Map.Entry e : sameDiffFunctionInstances.entrySet()) {
                SameDiff sd = e.getValue();
                if (sd.variableNameToArr != null && sd.variableNameToArr.containsKey(variable.getVarName())) {
                    sd.associateArrayWithVariable(arr, variable);
                }
            }
        }
    }

    /**
     * Associate a {@link SameDiff} namespace as a sub function.
     *
     * @param name      the opName of the function
     * @param nameSpace the namespace
     * @throws ND4JIllegalStateException if a different namespace is already registered under this name
     */
    public void putSubFunction(String name, SameDiff nameSpace) {
        if (sameDiffFunctionInstances.containsKey(name) && sameDiffFunctionInstances.get(name) != nameSpace) {
            throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
        }

        sameDiffFunctionInstances.put(name, nameSpace);
    }

    /**
     * Return the internal variable map
     *
     * @return Map of variables by name
     */
    public Map variableMap() {
        return variableMap;
    }

    /**
     * Invoke an op by opName
     *
     * @param op the op
     * @param x  the first input
     * @param y  the second input
     * @return the result variable
     */
    public SDVariable invoke(Op op, SDVariable x, SDVariable y) {
        if (!opMethods.containsKey(op.opName())) {
            throw new ND4JIllegalStateException("Illegal method opName " + op.opName());
        }

        if (x != null && y != null) {
            try {
                return (SDVariable) opMethods.get(op.opName()).invoke(this, x, y);
            } catch (Exception e) {
                // NOTE(review): reflection failure is swallowed; falls through to the throw
                // below, losing the original cause
            }
        } else {
            try {
                return (SDVariable) opMethods.get(op.opName()).invoke(this, x);
            } catch (Exception e) {
                // NOTE(review): reflection failure is swallowed; falls through to the throw
                // below, losing the original cause
            }
        }

        throw new ND4JIllegalStateException("Illegal method opName " + op.opName());
    }

    /**
     * The set of defined SameDiff function names. SameDiff function instances should not be confused
     * with DifferentialFunction ops; an example of a SameDiff function instance is the gradient "grad" function
     *
     * @return Set of defined SameDiff function instance names
     */
    public Collection definedFunctionNames() {
        return this.sameDiffFunctionInstances.keySet();
    }

    /**
     * Returns the number of bytes for the graph.
Calculated as sum_i prod(shapeOf(variable[i]))
     *
     * @return Bytes for all of the arrays in the graph for the current variable shapes
     */
    public long memoryForGraph() {
        return numElements() * DataTypeUtil.lengthForDtype(Nd4j.dataType());
    }

    /**
     * Invoke an op by opName
     *
     * @param op the op
     * @param x  the first input
     * @return the result variable
     */
    public SDVariable invoke(Op op, SDVariable x) {
        return invoke(op, x, null);
    }

    /**
     * Private constructor: initializes every internal index with insertion-ordered
     * collections so iteration order is deterministic.
     */
    private SameDiff() {
        functionFactory = new DifferentialFunctionFactory(this);
        variableMap = new LinkedHashMap<>();
        sameDiffFunctionDefinitionMap = new LinkedHashMap<>();
        sameDiffFunctionInstances = new LinkedHashMap<>();
        gradients = new LinkedHashMap<>();
        forwardVarForGrad = new LinkedHashMap<>();
        opsForResult = new IntArrayKeyMap<>();
        variableNameToArr = new LinkedHashMap<>();
        variableNameToShape = new LinkedHashMap<>();
        placeHolderMap = new LinkedHashMap<>();
        placeHolderVarNames = new LinkedHashSet<>();
        placeHolderOriginalShapes = new LinkedHashMap<>();
        incomingArgsReverse = new LinkedHashMap<>();
        outgoingArgsReverse = new LinkedHashMap<>();
        functionInstancesById = new LinkedHashMap<>();
        placeHolderFunctions = new LinkedHashSet<>();
        functionsArgsFor = new LinkedHashMap<>();
        functionOutputFor = new LinkedHashMap<>();
        baseNameForFunctionInstanceId = new LinkedHashMap<>();
        importedVarName = new LinkedHashSet<>();
        permuteOrder = new LinkedHashMap<>();
        propertiesToResolve = new LinkedHashMap<>();
        propertiesForFunction = new LinkedHashMap<>();
        fieldVariableResolutionMapping = HashBasedTable.create();
    }

    /**
     * Adds a property that needs to be resolved for later.
     * These variables are typically values that are arrays
     * that are named but have an unknown value till execution time.
     *
     * This is very common for model import.
     *
     * @param forFunction the function to add the property to resolve for
     * @param arrayName   the array name
     */
    public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName) {
        if (!propertiesToResolve.containsKey(forFunction.getOwnName())) {
            List newVal = new ArrayList<>();
            newVal.add(arrayName);
            propertiesToResolve.put(forFunction.getOwnName(), newVal);
        } else {
            // List is already in the map; mutating it in place is sufficient
            List newVal = propertiesToResolve.get(forFunction.getOwnName());
            newVal.add(arrayName);
        }
    }

    /**
     * Return the properties to resolve for the given function.
     * This is typically used right before execution in model import in
     * {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()}
     *
     * @param function the function get the properties to resolve for
     * @return the properties to resolve for the given function (empty list if none)
     */
    public List propertiesToResolveForFunction(DifferentialFunction function) {
        if (!propertiesToResolve.containsKey(function.getOwnName()))
            return Collections.emptyList();

        return propertiesToResolve.get(function.getOwnName());
    }

    /**
     * Returns true if the given function has ndarray properties to resolve.
     *
     * @param function the function to check
     * @return true if the function has yet to be resolved properties
     */
    public boolean hasPropertiesToResolve(DifferentialFunction function) {
        return propertiesToResolve.containsKey(function.getOwnName());
    }

    /**
     * Get the property for a given function
     *
     * @param functionInstance the function to get the
     *                         property for
     * @param propertyName     the name of the property to get
     * @return the property for the given function, or null if the function has no properties
     */
    public T getPropertyForFunction(DifferentialFunction functionInstance, String propertyName) {
        if (!propertiesForFunction.containsKey(functionInstance.getOwnName())) {
            return null;
        } else {
            val map = propertiesForFunction.get(functionInstance.getOwnName());
            return (T) map.get(propertyName);
        }
    }

    /**
     * Add a property for the given function
     *
     * @param functionFor  the function add a property for
     * @param propertyName the property name
     * @param property     the property value
     */
    public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, INDArray property) {
        addPropertyForFunction(functionFor, propertyName, (Object) property);
    }

    /**
     * Add a property for the given function
     *
     * @param functionFor  the function to add the property for
     * @param propertyName the name of the property to add the value for
     * @param property     the property value to add
     */
    public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, long property) {
        addPropertyForFunction(functionFor, propertyName, (Object) property);
    }

    // Shared implementation for the public overloads above; properties are write-once
    // per (function, name) pair.
    private void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, Object propertyValue) {
        if (!propertiesForFunction.containsKey(functionFor.getOwnName())) {
            Map fields = new LinkedHashMap<>();
            fields.put(propertyName, propertyValue);
            propertiesForFunction.put(functionFor.getOwnName(), fields);
        } else {
            val fieldMap = propertiesForFunction.get(functionFor.getOwnName());
            if (fieldMap.containsKey(propertyName)) {
                throw new ND4JIllegalStateException("Attempting to override property " + propertyName);
            }

            fieldMap.put(propertyName, propertyValue);
        }
    }

    /**
     * Adds a field name -> variable name mapping for a given function.
* This is used for model import where there is an unresolved variable at the time of calling any * {@link org.nd4j.imports.graphmapper.GraphMapper#importGraph(File)} * . *

* This data structure is typically accessed during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()} *

* When a function attempts to resolve variables right before execution, there needs to be a way of knowing * which variable in a samediff graph should map to a function's particular field name * * @param function the function to map * @param fieldName the field name for the function to map * @param varName the variable name of the array to get from samediff */ public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName) { fieldVariableResolutionMapping.put(function.getOwnName(), fieldName, varName); } /** * Get the variable name to use * for resolving a given field * for a given function during import time. * This method is u sed during {@link DifferentialFunction#resolvePropertiesFromSameDiffBeforeExecution()} * * @param function the function to get the variable name for * @param fieldName the field name to resolve for * @return the resolve variable name if any */ public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName) { return fieldVariableResolutionMapping.get(function.getOwnName(), fieldName); } /** * Returns true if the variable name is imported * * @param variableName the imported variable name * @return true if the name is imported, false otherwise */ public boolean isImportVariable(String variableName) { return importedVarName.contains(variableName); } /** * Marks a variable name as imported. * This is used in conjunction with model * import to ensure immutability * when referencing graph variables * mapped from an external source. * * @param varName the var name to add. */ public void addVarNameForImport(String varName) { importedVarName.add(varName); } /** * Sets a base name for the function id. * This is used for when calling {@link #generateOutputVariableForOp(DifferentialFunction, String)} * for ensuring original names for model import map to current samediff names * when names are generated. 
*
     * @param baseName the base name to add
     * @param function the function to declare a base name for.
     */
    public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function) {
        baseNameForFunctionInstanceId.put(function.getOwnName(), baseName);
    }

    /**
     * Returns the base name for the given function if any (may return null)
     *
     * @param function the function to get the base name for
     * @return the base name for the given function (if any) based on the function's instance id.
     */
    public String getBaseNameForFunction(DifferentialFunction function) {
        return baseNameForFunctionInstanceId.get(function.getOwnName());
    }

    /**
     * Attempts to insert the {@link DifferentialFunction} reference in to this {@link SameDiff} instance.
     * If the given array field with the given index already exists, it will do a reference check to ensure that the 2
     * array fields are the same. If not, an exception is thrown.<br>
     * If the instances are the same (by semantics, not reference) then it will just return the original instance.
     * This is to ensure that instances that are created are unique and reference checked.
     *
     * @param function the array field to attempt to create
     * @return Original instance
     */
    public <X extends SDVariable> X setupFunction(X function) {
        Preconditions.checkNotNull(function, "Passed in function must not be null!");
        // Ensure the variable is associated with this SameDiff instance before handing it back
        if (function instanceof SDVariable && function.getSameDiff() != this) {
            function.setSameDiff(this);
        }
        return function;
    }

    /**
     * Adds outgoing arguments to the graph for the specified DifferentialFunction
     * Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared.
     *
     * @param variables Variables - arguments for the specified differential function
     * @param function  Differential function
     */
    public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) {
        String[] varNames = new String[variables.length];
        for (int i = 0; i < varNames.length; i++) {
            varNames[i] = variables[i].getVarName();
        }
        addOutgoingFor(varNames, function);
    }

    /**
     * Adds outgoing arguments to the graph for the specified DifferentialFunction
     * Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared.
     *
     * @param varNames Name of the variables that are outputs of the specified differential function
     * @param function Differential function
     */
    public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
        if (function.getOwnName() == null)
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        if (outgoingArgsReverse.containsKey(function.getOwnName())) {
            throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
        }
        if (varNames == null)
            throw new ND4JIllegalStateException("Var names can not be null!");
        for (int i = 0; i < varNames.length; i++) {
            if (varNames[i] == null)
                throw new ND4JIllegalStateException("Variable name elements can not be null!");
        }

        outgoingArgsReverse.put(function.getOwnName(), varNames);
        // Maintain the reverse index: variable name -> functions that produce it
        for (val resultName : varNames) {
            List<DifferentialFunction> funcs = functionOutputFor.get(resultName);
            if (funcs == null) {
                funcs = new ArrayList<>();
                functionOutputFor.put(resultName, funcs);
            }
            funcs.add(function);
        }
    }

    /**
     * Adds incoming arguments for the specified differential function to the graph
     *
     * @param variables Name of the variables that are arguments (inputs) to the specified function
     * @param function  Function
     */
    public void addArgsFor(String[] variables, DifferentialFunction function) {
        if (function.getOwnName() == null)
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");

        // Double check if function contains placeholder args: such functions need
        // special treatment once the placeholders are resolved at execution time
        for (val varName : variables) {
            if (isPlaceHolder(varName)) {
                placeHolderFunctions.add(function.getOwnName());
            }
        }

        incomingArgsReverse.put(function.getOwnName(), variables);
        // Maintain the reverse index: variable name -> functions that consume it
        for (val variableName : variables) {
            List<DifferentialFunction> funcs = functionsArgsFor.get(variableName);
            if (funcs == null) {
                funcs = new ArrayList<>();
                functionsArgsFor.put(variableName, funcs);
            }
            funcs.add(function);
        }
    }

    /**
     * Adds incoming arguments for the specified differential function to the graph
     *
     * @param variables variables that are arguments (inputs) to the specified function
     * @param function  Function
     */
    public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
        String[] varNames = new String[variables.length];
        for (int i = 0; i < varNames.length; i++) {
            if (variables[i] == null)
                throw new ND4JIllegalStateException("Found null variable at index " + i);
            varNames[i] = variables[i].getVarName();
        }
        addArgsFor(varNames, function);
    }

    /**
     * Get the differential function (if any) that this variable is the output for
     *
     * @param variableName Name of the variable
     * @return The differential function that this variable is an output of, or null if it is not the output of a function
     */
    public DifferentialFunction getVariableOutputFunction(String variableName) {
        List<DifferentialFunction> list = functionOutputFor.get(variableName);
        return list == null ? null : list.get(0);
    }

    /**
     * Return a list of differential functions (if any) that this variable is the input argument for
     *
     * @param variableName Name of the variable
     * @return The differential functions that this variable is an input argument for, or null if it is not the input to any function
     */
    public List<DifferentialFunction> getVariableArgOfFunctions(String variableName) {
        return functionsArgsFor.get(variableName);
    }

    /**
     * Returns true if this function already has defined arguments
     *
     * @param function the function to check
     * @return true if the function has args, false otherwise
     */
public boolean hasArgs(DifferentialFunction function) { String[] vertexIdArgs = incomingArgsReverse.get(function.getOwnName()); return vertexIdArgs != null && vertexIdArgs.length > 0; } /** * Get an array of differential functions that have been defined for this SameDiff instance * @return Array of differential functions */ public DifferentialFunction[] functions() { val ret = functionInstancesById.values(); return ret.toArray(new DifferentialFunction[ret.size()]); } @Override public int hashCode() { int result = super.hashCode(); result = 31 * result + (variableMap != null ? variableMap.hashCode() : 0); return result; } /** * Create a new SameDiff instance from an existing instance. * Note that state (variables and functions) is shared between the two SameDiff instance * * @param originalSameDiff Original SameDiff instance * @return Copy */ public static SameDiff create(SameDiff originalSameDiff) { SameDiff ret = SameDiff.builder() .variableMap(originalSameDiff.variableMap) .sameDiffFunctionInstances(originalSameDiff.sameDiffFunctionInstances) .build(); //ensuring proper sameDiff reference DifferentialFunctionFactory differentialFunctionFactory = new DifferentialFunctionFactory(ret); ret.functionFactory = differentialFunctionFactory; return ret; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SameDiff sameDiff = (SameDiff) o; if (variableMap != null ? !variableMap.equals(sameDiff.variableMap) : sameDiff.variableMap != null) return false; if (sameDiffFunctionDefinitionMap != null ? !sameDiffFunctionDefinitionMap.equals(sameDiff.sameDiffFunctionDefinitionMap) : sameDiff.sameDiffFunctionDefinitionMap != null) return false; return sameDiffFunctionInstances != null ? 
sameDiffFunctionInstances.equals(sameDiff.sameDiffFunctionInstances) : sameDiff.sameDiffFunctionInstances == null; } /** * Create a new (empty) SameDiff instance without any functions or variables * @return New SameDiff instance */ public static SameDiff create() { return new SameDiff(); } /** * Evaluate the given inputs based on the current graph * * @param inputs the inputs to evaluate * @return */ public INDArray[] eval(Map inputs) { SameDiff execPipeline = dup(); List opExecAction = execPipeline.exec().getRight(); if (opExecAction.isEmpty()) throw new IllegalStateException("No ops found to execute."); INDArray[] ret = new INDArray[opExecAction.size()]; for (int i = 0; i < ret.length; i++) { val varName = opExecAction.get(i).outputVariables()[0].getVarName(); ret[i] = execPipeline.getArrForVarName(varName); } return ret; } /** * Clone/duplicate the SameDiff instance, including arrays etc. The returned SameDiff instance should have no * shared state with the original instance * @return The cloned SameDiff instance */ public SameDiff dup() { Cloner cloner = newCloner(); val clone = cloner.deepClone(this); return clone; } /** * Count the number of elements in all arrays, according to {@link SDVariable#getShape()} * @return Number of array elements for all variables */ public long numElements() { long ret = 0; for (SDVariable variable : variables()) { ret += ArrayUtil.prod(variable.getShape()); } return ret; } private void initWorkspace() { workspace = Nd4j.getWorkspaceManager().createNewWorkspace( WorkspaceConfiguration.builder() .initialSize(memoryForGraph()) .policyAllocation(AllocationPolicy.OVERALLOCATE) .policyLearning(LearningPolicy.FIRST_LOOP) .build()); Nd4j.getWorkspaceManager().setWorkspaceForCurrentThread(workspace); } /** * The list of all variables in the graph * * @return All variables in the graph */ public List variables() { return new ArrayList<>(variableMap.values()); } /** * Create a new variable with the specified shape, with all values 
initialized to 1.0 * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable one(String name, int[] shape) { return var(name, ArrayUtil.toLongArray(shape), new ConstantInitScheme('f', 1.0)); } /** * Create a new variable with the specified shape, with all values initialized to 1.0 * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable one(String name, long[] shape) { return var(name, shape, new ConstantInitScheme('f', 1.0)); } /** * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic: * if the input shape changes in later execution, the returned variable's shape will also be updated * * @param input Input SDVariable * @return A new SDVariable with the same (dynamic) shape as the input */ public SDVariable onesLike(SDVariable input) { return onesLike(null, input); } /** * Return a variable of all 1s, with the same shape as the input variable. 
Note that this is dynamic: * if the input shape changes in later execution, the returned variable's shape will also be updated * * @param name Name of the new SDVariable * @param input Input SDVariable * @return A new SDVariable with the same (dynamic) shape as the input */ public SDVariable onesLike(String name, SDVariable input) { SDVariable ret = f().onesLike(name, input); return updateVariableNameAndReference(ret, name); } /** * Create a new variable with the specified shape, with all values initialized to 0 * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable zero(String name, long[] shape) { return var(name, shape, new ZeroInitScheme()); } /** * Create a new variable with the specified shape, with all values initialized to 0 * * @param name the name of the variable to create * @param shape the shape of the array to be created * @return the created variable */ public SDVariable zero(String name, int[] shape) { return var(name, ArrayUtil.toLongArray(shape), new ZeroInitScheme()); } /** * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic: * if the input shape changes in later execution, the returned variable's shape will also be updated * * @param input Input SDVariable * @return A new SDVariable with the same (dynamic) shape as the input */ public SDVariable zerosLike(SDVariable input) { return zerosLike(null, input); } /** * Return a variable of all 0s, with the same shape as the input variable. 
Note that this is dynamic: * if the input shape changes in later execution, the returned variable's shape will also be updated * * @param name Name of the new SDVariable * @param input Input SDVariable * @return A new SDVariable with the same (dynamic) shape as the input */ public SDVariable zerosLike(String name, SDVariable input) { SDVariable ret = f().zerosLike(name, input); return updateVariableNameAndReference(ret, name); } /** * Return a variable of given shape in which all values have a given constant value. * * @param value constant to set for each value * @param shape shape of the variable as long array * @return A new SDVariable of provided shape with constant value. */ public SDVariable constant(SDVariable value, long... shape) { return constant(null, value, shape); } /** * Return a variable of given shape in which all values have a given constant value. * * @param name Name of the new SDVariable * @param value constant to set for each value * @param shape shape of the variable as long array * @return A new SDVariable of provided shape with constant value. */ public SDVariable constant(String name, SDVariable value, long... 
shape) { SDVariable ret = f().constant(value, shape); return updateVariableNameAndReference(ret, name); } /** * Create a new 1d array with values evenly spaced between values 'start' and 'stop' * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0] * * @param start Start value * @param stop Stop value * @param number Number of values to generate * @return SDVariable with linearly spaced elements */ public SDVariable linspace(double start, double stop, long number) { return linspace(null, start, stop, number); } /** * Create a new 1d array with values evenly spaced between values 'start' and 'stop' * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0] * * @param name Name of the new variable * @param start Start value * @param stop Stop value * @param number Number of values to generate * @return SDVariable with linearly spaced elements */ public SDVariable linspace(String name, double start, double stop, long number) { SDVariable ret = f().linspace(start, stop, number); return updateVariableNameAndReference(ret, name); } /** * Create a new variable with a 1d array, where the values start at {@code from} and increment by {@code step} * up to (but not including) limit.
* For example, {@code range(1.0, 3.0, 0.5)} will return {@code [1.0, 1.5, 2.0, 2.5]} * @param from Initial/smallest value * @param to Largest value (exclusive) * @param step Step size * @return 1D SDVariable with the specified values */ public SDVariable range(double from, double to, double step){ return range(null, from, to, step); } /** * Create a new variable with a 1d array, where the values start at {@code from} and increment by {@code step} * up to (but not including) limit.
* For example, {@code range(1.0, 3.0, 0.5)} will return {@code [1.0, 1.5, 2.0, 2.5]} * @param name Name of the new variable * @param from Initial/smallest value * @param to Largest value (exclusive) * @param step Step size * @return 1D SDVariable with the specified values */ public SDVariable range(String name, double from, double to, double step){ SDVariable ret = f().range(from, to, step); return updateVariableNameAndReference(ret, name); } /** * @see #meshgrid(List, SDVariable...) */ public SDVariable[] meshgrid(SDVariable... inputs){ return meshgrid(null, inputs); } /** * Broadcast the 1D input variables onto an n-dimensional grid.
* The resulting variable can be used for example for evaluating functions at all locations on a grid.
* Example:
*

     * {@code input1 = [1, 2, 3]
     * input2 = [4, 5, 6]
     * SDVariable[] out = meshgrid(input1, input2)
     * out[0]:
     * [ 1, 2, 3]
     * [ 1, 2, 3]
     * [ 1, 2, 3]
     *
     * out[1]:
     * [ 4, 4, 4]
     * [ 5, 5, 5]
     * [ 6, 6, 6]}
     * 
*
* @param names List of names for the output variables. Must have exactly N names for N input arrays * @param inputs N x 1D input variables * @return an array of exactly N SDVariables (for N inputs), of rank N */ public SDVariable[] meshgrid(List names, SDVariable... inputs){ return meshgrid(names, true, inputs); } /** * @see #meshgrid(List, SDVariable...) */ public SDVariable[] meshgrid(List names, boolean cartesian, SDVariable... inputs){ Preconditions.checkState(names == null || names.size() == inputs.length, "Got %s names but %s inputs", (names == null ? 0 : names.size()), inputs.length); SDVariable[] ret = f().meshgrid(cartesian, inputs); for( int i=0; i * Any array will be generated with all zeros for the values * * @param name the name of the variable * @param shape the shape of the variable * @return the created variable */ public SDVariable var(String name, long... shape) { Preconditions.checkNotNull(shape != null, "Invalid shape: shape may not be null"); return var(name, shape, new ZeroInitScheme()); } /** * Creates a {@link SDVariable} with the given shape and name
* Any array will be generated with all zeros for the values * * @param name the name of the variable * @param shape the shape of the variable * @return the created variable */ public SDVariable var(String name, int... shape) { Preconditions.checkNotNull(shape != null, "Invalid shape: shape may not be null"); return var(name, ArrayUtil.toLongArray(shape), new ZeroInitScheme()); } /** * Initialize a {@link SDVariable} reference tying this variable to this samediff instance. *

* {@link NDArraySupplierInitScheme} is used to ensure that if the array is allocated anywhere * and {@link SameDiff} instance to exist as a copy of the variable. * * @param arr * @return */ public SDVariable var(final SDVariable arr) { if (variableMap.containsKey(arr.getVarName()) && variableMap.get(arr.getVarName()).getArr() != null) return variableMap.get(arr.getVarName()); if (arr.getVarName() == null || arr.getVarName().length() < 1) throw new IllegalArgumentException("Name for variable must be defined"); if (arr == null) throw new IllegalArgumentException("Array for " + arr.getVarName() + " must not be null"); if (workspace == null) initWorkspace(); final SDVariable ret = SDVariable.builder() .sameDiff(this) .shape(arr.getShape()) .varName(arr.getVarName()) .weightInitScheme(new NDArraySupplierInitScheme(new NDArraySupplierInitScheme.NDArraySupplier() { @Override /** * Pre allocate the array if it doesn't already exist. * The reason we do this is to avoid race conditions with * {@link #allocate()} */ public INDArray getArr() { if (arr.getArr() == null) { INDArray retArr = arr.getWeightInitScheme().create(arr.getShape()); associateArrayWithVariable(retArr, arr); } return arr.getArr(); } })) .build(); variableMap.put(arr.getVarName(), ret); return ret; } private String getNewVarName() { String varName = "sd_var_" + String.valueOf(variableId); while (variableMap.containsKey(varName)) { variableId++; varName = "sd_var_" + String.valueOf(variableId); } return varName; } /** * Creates a {@link SDVariable} with the specified shape and a generated name
* Any array will be generated with all zeros for the values * * @param shape the shape of the variable * @return the created variable */ public SDVariable var(int... shape) { return var(getNewVarName(), shape); } /** * Creates a {@link SDVariable} with the specified shape and a generated name
* Any array will be generated with all zeros for the values * * @param shape the shape of the variable * @return the created variable */ public SDVariable var(long... shape) { return var(getNewVarName(), shape); } /** * Creates a {@link SDVariable} with the specified shape and a generated name. The associated array will * then be generated using the specified weight initialization scheme * * @param weightInitScheme The weight initialization scheme to use when generating an INDArray * @param shape the shape of the variable * @return the created variable */ public SDVariable var(WeightInitScheme weightInitScheme, long... shape) { return var(getNewVarName(), shape, weightInitScheme); } /** * Create an {@link SDVariable} with a generated name, and assocate the specified array with it * @param arr Array to associate with the new variable * @return New SDVariable * @see #var(String, INDArray) */ public SDVariable var(INDArray arr) { return var(getNewVarName(), arr); } /** * Create an {@link SDVariable} with the specified name, and assocate the specified array with it * @param arr Array to associate with the new variable * @return New SDVariable with the specified name and array */ public SDVariable var(String name, INDArray arr) { if (variableMap.containsKey(name) && variableMap.get(name).getArr() != null) throw new IllegalArgumentException("Another variable with the name " + name + " already exists."); if (name == null || name.length() < 1) name = getNewVarName(); if (arr == null) throw new IllegalArgumentException("Array for " + name + " must not be null"); if (workspace == null) initWorkspace(); val arrRef = arr.migrate(); SDVariable ret = SDVariable.builder() .sameDiff(this) .shape(arr.shape()) .varName(name) .weightInitScheme(new NDArraySupplierInitScheme(new NDArraySupplierInitScheme.NDArraySupplier() { @Override /** * Return array */ public INDArray getArr() { return arrRef; } })) .build(); associateArrayWithVariable(arr, ret); if (ArrayUtil.prod(arr.shape()) == 
1) ret.setScalarValue(arr.getDouble(0)); addVariable(ret); if (getShapeForVarName(name) == null) putShapeForVarName(name, arr.shape()); //ensure there is a reference to the array in the integer index //this is used later for op creation variableMap.put(name, ret); return ret; } /** * Generate a square identity matrix with the specified number of rows. * * @param rows Number of rows (and columns) * @return SDVariable with an identity matrix array */ public SDVariable eye(int rows) { return eye(rows, rows); } /** * Generate an identity matrix with the specified number of rows and columns. * * @param rows Number of rows */ public SDVariable eye(String name, int rows) { return eye(name, rows, rows); } /** * @see #eye(String, int, int) */ public SDVariable eye(int rows, int cols) { return eye(null, rows, cols); } /** * Generate an identity matrix with the specified number of rows and columns * Example:
*

     * {@code SDVariable eye = eye(3,2)
     * eye:
     * [ 1, 0]
     * [ 0, 1]
     * [ 0, 0]}
     * 
* * @param name Name of the new SDVariable * @param rows Number of rows * @param cols Number of columns * @return SDVaribable identity matrix */ public SDVariable eye(String name, int rows, int cols) { return eye(name, rows, cols, null); } /** * see {@link #eye(String, int, int, int...)} */ public SDVariable eye(int rows, int cols, int... batchDimension) { return eye(null, rows, cols, batchDimension); } /** * Generate an identity matrix with the specified number of rows and columns, with optional leading dims
* Example:
* batchShape: [3,3]
* numRows: 2
* numCols: 4
* returns a tensor of shape (3, 3, 2, 4) that consists of 3 * 3 batches of (2,4)-shaped identity matrices:
* 1 0 0 0
* 0 1 0 0
* * @param rows Number of rows * @param cols Number of columns * @param batchDimension Batch dimensions. May be null */ public SDVariable eye(String name, int rows, int cols, int... batchDimension) { SDVariable eye = new Eye(this, rows, cols, batchDimension).outputVariables()[0]; return updateVariableNameAndReference(eye, name); } /** * As per {@link #eye(String, int, int, int...)} bit with the number of rows/columns specified as scalar SDVariables, * and the batch dimension specified as a 1D SDVariable */ public SDVariable eye(String name, SDVariable rows, SDVariable cols, SDVariable batchDimension){ SDVariable eye = new Eye(this, rows, cols, batchDimension).outputVariable(); return updateVariableNameAndReference(eye, name); } /** * As per {@link #eye(int, int, int...)} bit with the number of rows/columns specified as scalar SDVariables, * and the batch dimension specified as a 1D SDVariable */ public SDVariable eye(SDVariable rows, SDVariable cols, SDVariable batchDimension){ return eye(null, rows, cols, batchDimension); } /** * As per {@link #eye(String, int, int)} bit with the number of rows/columns specified as scalar SDVariables */ public SDVariable eye(String name, SDVariable rows, SDVariable cols){ SDVariable eye = new Eye(this, rows, cols).outputVariables()[0]; return updateVariableNameAndReference(eye, name); } /** * As per {@link #eye(int, int)} bit with the number of rows/columns specified as scalar SDVariables */ public SDVariable eye(SDVariable rows, SDVariable cols){ SDVariable eye = new Eye(this, rows, cols).outputVariables()[0]; return updateVariableNameAndReference(eye, null); } /** * As per {@link #eye(String, int)} but with the number of rows specified as a scalar SDVariable */ public SDVariable eye(String name, SDVariable rows){ SDVariable eye = new Eye(this, rows).outputVariables()[0]; return updateVariableNameAndReference(eye, name); } /** * As per {@link #eye(int)} but with the number of rows specified as a scalar SDVariable */ public 
SDVariable eye(SDVariable rows){ SDVariable eye = new Eye(this, rows).outputVariables()[0]; return updateVariableNameAndReference(eye, null); } /** * Remove an argument for a function. Note that if this function does not contain the argument, it will just be a no op. * * @param varName the variable name to remove * @param function the function to remove the argument from */ public void removeArgFromFunction(String varName, DifferentialFunction function) { val args = function.args(); for (int i = 0; i < args.length; i++) { if (args[i].getVarName().equals(varName)) { /** * Since we are removing the variable reference * from the arguments we need to update both * the reverse and forward arguments. */ val reverseArgs = incomingArgsReverse.get(function.getOwnName()); incomingArgsReverse.remove(function.getOwnName()); val newArgs = new ArrayList(args.length - 1); for (int arg = 0; arg < args.length; arg++) { if (!reverseArgs[arg].equals(varName)) { newArgs.add(reverseArgs[arg]); } } val newArgsArr = newArgs.toArray(new String[newArgs.size()]); incomingArgsReverse.put(function.getOwnName(), newArgsArr); //no further need to scan break; } } } /** * Get the variable based on the opName * * @param name the opName of the variable * @return the variabel instance if there is one */ public SDVariable getVariable(String name) { return variableMap.get(name); } /** * Get the gradient for the given vertex id * * @param varName the vertex id * @return the gradient for this variable or null */ public SDVariable getGradForVariable(String varName) { //TODO 2018/06/26 - Review this? //Gradients are being placed in the inner "grad" function SameDiff instance, but not the outer one // should they be synced and we just use the map in this instance? 
if (gradients.containsKey(varName)) { return gradients.get(varName); } else if(sameDiffFunctionInstances.containsKey("grad") && sameDiffFunctionInstances.get("grad").gradients.containsKey(varName)){ return sameDiffFunctionInstances.get("grad").gradients.get(varName); } return null; } /** * Assign a SDVariable to represent the gradient of the SDVariable with the specified name * * @param variableName the variable name to assign the gradient variable for * @param variable the gradient variable */ public void setGradientForVariableName(String variableName, SDVariable variable) { if (variable == null) { throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + variableName); } gradients.put(variableName, variable); } /** * Get the forward variable for gradient based on the gradient's vertex id * * @param vertexId the vertex id * @return the gradient for the variable or null */ public SDVariable getForwardVariableForVertexId(int vertexId) { return forwardVarForGrad.get(vertexId); } /** * @param varName * @param forwardVariable */ public void setForwardVariableForVarName(String varName, SDVariable forwardVariable) { forwardVarForGrad.put(varName, forwardVariable); } /** * Get the gradient for the variable with the specified variable name. * Note that in order to run this function, {@link #execBackwards()} must be executed first. * All gradient functions are obtained from the results of the execBackwards call. * * @param varName the variable name to get the gradient variable for. * @return The gradient variable for the specified variable */ public SDVariable grad(String varName) { if (!sameDiffFunctionInstances.containsKey("grad")) { throw new IllegalStateException("Unable to obtain gradient. 
Please run execBackwards() first."); } SameDiff grad = getFunction("grad"); SDVariable var = grad.getVariable(varName); return getFunction("grad").getGradForVariable(var.getVarName()); } /** * @see #randomUniform(String, double, double, SDVariable) */ public SDVariable randomUniform(double min, double max, SDVariable shape){ return randomUniform(null, min, max, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution, * U(min,max)
* See {@link #randomUniform(double, double, long...)} for the equivalent function where the shape is * specified as a long[] instead * * @param name Name of the new SDVariable * @param min Minimum value * @param max Maximum value. Must satisfy max >= min * @param shape Shape of the new random SDVariable, as a 1D array * @return New SDVariable */ public SDVariable randomUniform(String name, double min, double max, SDVariable shape){ SDVariable ret = f().randomUniform(min, max, shape); return updateVariableNameAndReference(ret, name); } /** * @see #randomUniform(String, double, double, long...) */ public SDVariable randomUniform(double min, double max, long... shape){ return randomUniform(null, min, max, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution, * U(min,max)
* See {@link #randomUniform(double, double, long...)} for the equivalent function where the shape is * specified as a SDVariable instead * * @param name Name of the new SDVariable * @param min Minimum value * @param max Maximum value. Must satisfy max >= min * @param shape Shape of the new random SDVariable * @return New SDVariable */ public SDVariable randomUniform(String name, double min, double max, long... shape){ SDVariable ret = f().randomUniform(min, max, shape); return updateVariableNameAndReference(ret, name); } /** * @see #randomNormal(String, double, double, SDVariable) */ public SDVariable randomNormal(double mean, double stddev, SDVariable shape){ return randomNormal(null, mean, stddev, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution, * N(mean, stdev)
* See {@link #randomNormal(String, double, double, long...)} for the equivalent function where the shape is * specified as a long[] instead * * @param name Name of the new SDVariable * @param mean Mean value for the random array * @param stddev Standard deviation for the random array * @param shape Shape of the new random SDVariable, as a 1D array * @return New SDVariable */ public SDVariable randomNormal(String name, double mean, double stddev, SDVariable shape){ SDVariable ret = f().randomNormal(mean, stddev, shape); return updateVariableNameAndReference(ret, name); } /** * @see #randomNormal(String, double, double, long...) */ public SDVariable randomNormal(double mean, double stddev, long... shape){ return randomNormal(null, mean, stddev, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution, * N(mean, stdev)
* See {@link #randomNormal(String, double, double, SDVariable)} for the equivalent function where the shape is * specified as a long[] instead * * @param name Name of the new SDVariable * @param mean Mean value for the random array * @param stddev Standard deviation for the random array * @param shape Shape of the new random SDVariable * @return New SDVariable */ public SDVariable randomNormal(String name, double mean, double stddev, long... shape){ SDVariable ret = f().randomNormal(mean, stddev, shape); return updateVariableNameAndReference(ret, name); } /** * @see #randomLogNormal(String, double, double, long...) */ public SDVariable randomLogNormal(double mean, double stddev, long... shape){ return randomLogNormal(null, mean, stddev, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a Log Normal distribution, * i.e., {@code log(x) ~ N(mean, stdev)}
* * @param name Name of the new SDVariable * @param mean Mean value for the random array * @param stddev Standard deviation for the random array * @param shape Shape of the new random SDVariable * @return New SDVariable */ public SDVariable randomLogNormal(String name, double mean, double stddev, long... shape){ SDVariable ret = f().randomLogNormal(mean, stddev, shape); return updateVariableNameAndReference(ret, name); } /** * @see #randomNormalTruncated(String, double, double, long...) */ public SDVariable randomNormalTruncated(double mean, double stddev, long... shape){ return randomNormalTruncated(null, mean, stddev, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution, * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled
* * @param name Name of the new SDVariable * @param mean Mean value for the random array * @param stddev Standard deviation for the random array * @param shape Shape of the new random SDVariable * @return New SDVariable */ public SDVariable randomNormalTruncated(String name, double mean, double stddev, long... shape){ SDVariable ret = f().randomNormalTruncated(mean, stddev, shape); return updateVariableNameAndReference(ret, name); } /** * @see #randomBernoulli(String, double, SDVariable) */ public SDVariable randomBernoulli(double p, SDVariable shape){ return randomBernoulli(null, p, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a Bernoulli distribution, * with the specified probability. Array values will have value 1 with probability P and value 0 with probability * 1-P.
* See {@link #randomBernoulli(String, double, long...)} for the equivalent function where the shape is * specified as a long[] instead * * @param name Name of the new SDVariable * @param p Probability of value 1 * @param shape Shape of the new random SDVariable, as a 1D array * @return New SDVariable */ public SDVariable randomBernoulli(String name, double p, SDVariable shape){ SDVariable ret = f().randomBernoulli(p, shape); return updateVariableNameAndReference(ret, name); } /** * @see #randomBernoulli(String, double, long...) */ public SDVariable randomBernoulli(double p, long... shape){ return randomBernoulli(null, p, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a Bernoulli distribution, * with the specified probability. Array values will have value 1 with probability P and value 0 with probability * 1-P.
* See {@link #randomBernoulli(String, double, SDVariable)} for the equivalent function where the shape is * specified as a SDVarible instead * * @param name Name of the new SDVariable * @param p Probability of value 1 * @param shape Shape of the new random SDVariable, as a 1D array * @return New SDVariable */ public SDVariable randomBernoulli(String name, double p, long... shape){ SDVariable ret = f().randomBernoulli(p, shape); return updateVariableNameAndReference(ret, name); } /** * Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, * with the specified number of trials and probability. * * @param nTrials Number of trials parameter for the binomial distribution * @param p Probability of success for each trial * @param shape Shape of the new random SDVariable, as a 1D array * @return New SDVariable */ public SDVariable randomBinomial(int nTrials, double p, long... shape){ return randomBinomial(null, nTrials, p, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, * with the specified number of trials and probability. * * @param name Name of the new SDVariable * @param nTrials Number of trials parameter for the binomial distribution * @param p Probability of success for each trial * @param shape Shape of the new random SDVariable, as a 1D array * @return New SDVariable */ public SDVariable randomBinomial(String name, int nTrials, double p, long... 
shape){ SDVariable ret = f().randomBinomial(nTrials, p, shape); return updateVariableNameAndReference(ret, name); } /** * Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution: * P(x) = lambda * exp(-lambda * x) * * @param lambda Must be > 0 * @param shape Shape of the output * @return new SDVariable */ public SDVariable randomExponential(double lambda, SDVariable shape) { return randomExponential(null, lambda, shape); } /** * Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution: * P(x) = lambda * exp(-lambda * x) * * @param name Name of the output variable * @param lambda Must be > 0 * @param shape Shape of the new variable * @return new SDVaribale */ public SDVariable randomExponential(String name, double lambda, SDVariable shape) { SDVariable ret = f().randomExponential(lambda, shape); return updateVariableNameAndReference(ret, name); } /** * 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format. * * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) * @param scale Scale to upsample in both H and W dimensions * @return Upsampled input */ public SDVariable upsampling2d(SDVariable input, int scale) { return upsampling2d(null, input, true, scale, scale); } /** * 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format. 
* * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) * @param scale Scale to upsample in both H and W dimensions * @return Upsampled input */ public SDVariable upsampling2d(String name, SDVariable input, int scale) { return upsampling2d(name, input, true, scale, scale); } /** * 2D Convolution layer operation - Upsampling 2d * * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) * or NHWC format (shape [minibatch, height, width, channels]) * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format * @param scaleH Scale to upsample in height dimension * @param scaleW Scale to upsample in width dimension * @return Upsampled input */ public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) { return upsampling2d(null, input, nchw, scaleH, scaleW); } /** * 2D Convolution layer operation - Upsampling 2d * * @param input Input, in NCHW format * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. 
False: NHWC format * @param scaleH Scale to upsample in height dimension * @param scaleW Scale to upsample in width dimension * @return Upsampled input */ public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) { SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW); return updateVariableNameAndReference(ret, name); } /** * 2D Convolution layer operation - average pooling 2d * * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param pooling2DConfig the configuration for * @return Result after applying average pooling on the input */ public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { return avgPooling2d(null, input, pooling2DConfig); } /** * 2D Convolution layer operation - average pooling 2d * * @param name name of the operation in SameDiff * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param pooling2DConfig the configuration * @return Result after applying average pooling on the input */ public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { SDVariable ret = f().avgPooling2d(input, pooling2DConfig); return updateVariableNameAndReference(ret, name); } /** * 2D Convolution layer operation - max pooling 2d * * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param pooling2DConfig the configuration * @return Result after applying max pooling on the input */ public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { return maxPooling2d(null, input, 
pooling2DConfig); } /** * 2D Convolution layer operation - max pooling 2d * * @param name name of the operation in SameDiff * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param pooling2DConfig the configuration * @return Result after applying max pooling on the input */ public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { SDVariable ret = f().maxPooling2d(input, pooling2DConfig); return updateVariableNameAndReference(ret, name); } /** * 3D convolution layer operation - average pooling 3d * * @param input the input to average pooling 3d operation - 5d activations in NCDHW format * (shape [minibatch, channels, depth, height, width]) or NDHWC format * (shape [minibatch, depth, height, width, channels]) * @param pooling3DConfig the configuration * @return Result after applying average pooling on the input */ public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { return avgPooling3d(null, input, pooling3DConfig); } /** * 3D convolution layer operation - average pooling 3d * * @param name name of the operation in SameDiff * @param input the input to average pooling 3d operation - 5d activations in NCDHW format * (shape [minibatch, channels, depth, height, width]) or NDHWC format * (shape [minibatch, depth, height, width, channels]) * @param pooling3DConfig the configuration * @return Result after applying average pooling on the input */ public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { SDVariable ret = f().avgPooling3d(input, pooling3DConfig); return updateVariableNameAndReference(ret, name); } /** * 3D convolution layer operation - max pooling 3d operation. 
* * @param input the input to average pooling 3d operation - 5d activations in NCDHW format * (shape [minibatch, channels, depth, height, width]) or NDHWC format * (shape [minibatch, depth, height, width, channels]) * @param pooling3DConfig the configuration * @return Result after applying max pooling on the input */ public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { return maxPooling3d(null, input, pooling3DConfig); } /** * 3D convolution layer operation - max pooling 3d operation. * * @param name name of the operation in SameDiff * @param input the input to average pooling 3d operation - 5d activations in NCDHW format * (shape [minibatch, channels, depth, height, width]) or NDHWC format * (shape [minibatch, depth, height, width, channels]) * @param pooling3DConfig the configuration * @return Result after applying max pooling on the input */ public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { SDVariable ret = f().maxPooling3d(input, pooling3DConfig); return updateVariableNameAndReference(ret, name); } /** * 1D Convolution layer operation - Conv1d * * @param input the input array/activations for the conv1d op * @param weights weights for conv1d op * @param conv1DConfig the configuration * @return */ public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { return conv1d(null, input, weights, conv1DConfig); } /** * Conv1d operation. 
* * @param name name of the operation in SameDiff * @param input the inputs to conv1d * @param weights weights for conv1d op * @param conv1DConfig the configuration * @return */ public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { SDVariable ret = f().conv1d(input, weights, conv1DConfig); return updateVariableNameAndReference(ret, name); } /** * 2D convolution layer operation - local response normalization * * @param inputs the inputs to lrn * @param lrnConfig the configuration * @return */ public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) { return localResponseNormalization(null, inputs, lrnConfig); } /** * 2D convolution layer operation - local response normalization * * @param name name of the operation in SameDiff * @param input the inputs to lrn * @param lrnConfig the configuration * @return */ public SDVariable localResponseNormalization(String name, SDVariable input, LocalResponseNormalizationConfig lrnConfig) { SDVariable ret = f().localResponseNormalization(input, lrnConfig); return updateVariableNameAndReference(ret, name); } /** * 2D Convolution operation (without bias) * * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param weights Weights for the convolution operation. 4 dimensions. * If layer input data is in NCHW format, weights should have format [outputChannels, inputChannels, kernelHeight, kernelWidth]. 
* If layer input data is in NHWC format, weight should have format [kernelHeight, kernelWidth, inputChannels, outputChannels] * @param config Conv2DConfig configuration * @return result of conv2d op */ public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) { return conv2d(layerInput, weights, null, config); } /** * 2D Convolution operation with optional bias * * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param weights Weights for the convolution operation. 4 dimensions. * If layer input data is in NCHW format, weights should have format [outputChannels, inputChannels, kernelHeight, kernelWidth]. * If layer input data is in NHWC format, weight should have format [kernelHeight, kernelWidth, inputChannels, outputChannels] * @param bias Optional 1D bias array with shape [outputChannels]. May be null. * @param config Conv2DConfig configuration * @return result of conv2d op */ public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) { SDVariable[] arr = new SDVariable[bias == null ? 
2 : 3]; arr[0] = layerInput; arr[1] = weights; if (bias != null) arr[2] = bias; return conv2d(arr, config); } /** * 2D Convolution operation with optional bias * * @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as * described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} * @param config Conv2DConfig configuration * @return result of convolution 2d operation */ public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) { return conv2d(null, inputs, config); } /** * 2D Convolution operation with optional bias * * @param name Name of the output SDVariable * @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as * described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} * @param config Conv2DConfig configuration * @return result of convolution 2d operation */ public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) { SDVariable ret = f().conv2d(inputs, config); return updateVariableNameAndReference(ret, name); } /** * Depth-wise 2D convolution operation without bias * * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param depthWeights depth-wise conv 2D weights. 4 dimensions. * If layer input data is in NCHW format, weights should have format [outputChannels, inputChannels, kernelHeight, kernelWidth]. 
* If layer input data is in NHWC format, weight should have format [kernelHeight, kernelWidth, inputChannels, outputChannels] * @param config Conv2DConfig configuration * @return result of conv2d op */ public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) { return depthWiseConv2d(layerInput, depthWeights, null, config); } /** * Depth-wise 2D convolution operation with optional bias * * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param depthWeights depth-wise conv 2D weights. 4 dimensions. * If layer input data is in NCHW format, weights should have format [outputChannels, inputChannels, kernelHeight, kernelWidth]. * If layer input data is in NHWC format, weight should have format [kernelHeight, kernelWidth, inputChannels, outputChannels] * @param bias Optional 1D bias array with shape [outputChannels]. May be null. * @param config Conv2DConfig configuration * @return result of depthwise conv2d op */ public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) { SDVariable[] arr = new SDVariable[bias == null ? 2 : 3]; arr[0] = layerInput; arr[1] = depthWeights; if (bias != null) arr[2] = bias; return depthWiseConv2d(arr, config); } /** * Depth-wise convolution 2D operation. * * @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights) * or 3 elements (layerInput, depthWeights, bias) as described in * {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} * @param depthConv2DConfig the configuration * @return result of depthwise conv2d op */ public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { return depthWiseConv2d(null, inputs, depthConv2DConfig); } /** * Depth-wise convolution 2D operation. 
* * @param name name of the output variable * @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights) * or 3 elements (layerInput, depthWeights, bias) as described in * {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} * @param depthConv2DConfig the configuration * @return result of depthwise conv2d op */ public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig); return updateVariableNameAndReference(ret, name); } /** * Separable 2D convolution operation without bias * * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param depthWeights Depth weights, rank 4. * If layer input is in NCHW format, depth weights should have format [outputChannels, depthMultiplier, kernelHeight, kernelWidth]. * If layer input is in NHWC format, depth weights should have format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] * @param pointWeights Point weights, rank 4. * If layer input is in NCHW format, point weights should have format [outputChannels, inputChannels*depthMultiplier, 1, 1]. 
* If layer input is in NHWC format, point weights should have format [1, 1, inputChannels*depthMultiplier, outputChannels] * @param config Conv2DConfig configuration * @return result of separable convolution 2d operation */ public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, Conv2DConfig config) { return separableConv2d(layerInput, depthWeights, pointWeights, null, config); } /** * Separable 2D convolution operation with optional bias * * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param depthWeights Depth weights, rank 4. * If layer input is in NCHW format, depth weights should have format [outputChannels, depthMultiplier, kernelHeight, kernelWidth]. * If layer input is in NHWC format, depth weights should have format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] * @param pointWeights Point weights, rank 4. * If layer input is in NCHW format, point weights should have format [outputChannels, inputChannels*depthMultiplier, 1, 1]. * If layer input is in NHWC format, point weights should have format [1, 1, inputChannels*depthMultiplier, outputChannels] * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. * @param config Conv2DConfig configuration * @return result of separable convolution 2d operation */ public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, SDVariable bias, Conv2DConfig config) { SDVariable[] arr = new SDVariable[bias == null ? 3 : 4]; arr[0] = layerInput; arr[1] = depthWeights; arr[2] = pointWeights; if (bias != null) arr[3] = bias; return sconv2d(arr, config); } /** * Separable 2D convolution operation with/without optional bias * * @param inputs the inputs to separable conv2 operation. 
Should be length 3 (layerInput, depthWeights, pointWeights) * or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)} * @param conv2DConfig the configuration * @return result of separable convolution 2d operation */ public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { return sconv2d(null, inputs, conv2DConfig); } /** * Separable 2D convolution operation with/without optional bias * * @param name name of the output variable * @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights) * or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)} * @param conv2DConfig the configuration * @return result of separable convolution 2d operation */ public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) { SDVariable ret = f().sconv2d(inputs, conv2DConfig); return updateVariableNameAndReference(ret, name); } /** * 2D deconvolution operation without bias * * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param weights Weights for the 2d deconvolution operation. 4 dimensions. * If layer input data is in NCHW format, weights should have format [inputChannels, outputChannels, kernelHeight, kernelWidth]. 
* If layer input data is in NHWC format, weight should have format [kernelHeight, kernelWidth, outputChannels, inputChannels] * @param deconv2DConfig DeConv2DConfig configuration * @return result of deconv2d op */ public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) { return deconv2d(layerInput, weights, null, deconv2DConfig); } /** * 2D deconvolution operation with optional bias * * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param weights Weights for the 2d deconvolution operation. 4 dimensions. * If layer input data is in NCHW format, weights should have format [inputChannels, outputChannels, kernelHeight, kernelWidth]. * If layer input data is in NHWC format, weight should have format [kernelHeight, kernelWidth, outputChannels, inputChannels] * @param bias Optional 1D bias array with shape [outputChannels]. May be null. * @param deconv2DConfig DeConv2DConfig configuration * @return result of deconv2d op */ public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) { SDVariable[] arr = new SDVariable[bias == null ? 
2 : 3]; arr[0] = layerInput; arr[1] = weights; if (bias != null) arr[2] = bias; return deconv2d(arr, deconv2DConfig); } /** * 2D deconvolution operation with or without optional bias * * @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights) * or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)} * @param deconv2DConfig the configuration * @return result of deconv2d op */ public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { return deconv2d(null, inputs, deconv2DConfig); } /** * 2D deconvolution operation with or without optional bias * * @param name Name of the output variable * @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights) * or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)} * @param deconv2DConfig the configuration * @return result of deconv2d op */ public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { SDVariable ret = f().deconv2d(inputs, deconv2DConfig); return updateVariableNameAndReference(ret, name); } /** * Convolution 3D operation without bias * * @param input the input to average pooling 3d operation - 5d activations in NCDHW format * (shape [minibatch, channels, depth, height, width]) or NDHWC format * (shape [minibatch, depth, height, width, channels]) * @param weights Weights for conv3d. Rank 5. * If input data is in NCDHW fomat, weights should have shape [outputChannels, inputChannels, kernelDepth, kernelHeight, kernelWidth]. * If input data is in NDHWC fomat, weights should have shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. 
* @param conv3DConfig the configuration * @return Conv3d output variable */ public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) { return conv3d(null, input, weights, null, conv3DConfig); } /** * Convolution 3D operation with optional bias * * @param input the input to average pooling 3d operation - 5d activations in NCDHW format * (shape [minibatch, channels, depth, height, width]) or NDHWC format * (shape [minibatch, depth, height, width, channels]) * @param weights Weights for conv3d. Rank 5. * If input data is in NCDHW fomat, weights should have shape [outputChannels, inputChannels, kernelDepth, kernelHeight, kernelWidth]. * If input data is in NDHWC fomat, weights should have shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. * @param bias Optional 1D bias array with shape [outputChannels]. May be null. * @param conv3DConfig the configuration * @return Conv3d output variable */ public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { return conv3d(null, input, weights, bias, conv3DConfig); } /** * Convolution 3D operation without bias * * @param name Name of the output variable * @param input the input to average pooling 3d operation - 5d activations in NCDHW format * (shape [minibatch, channels, depth, height, width]) or NDHWC format * (shape [minibatch, depth, height, width, channels]) * @param weights Weights for conv3d. Rank 5. * If input data is in NCDHW fomat, weights should have shape [outputChannels, inputChannels, kernelDepth, kernelHeight, kernelWidth]. * If input data is in NDHWC fomat, weights should have shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. 
* @param conv3DConfig the configuration * @return Conv3d output variable */ public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) { return conv3d(name, input, weights, null, conv3DConfig); } /** * Convolution 3D operation with optional bias * * @param name Name of the output variable * @param input the input to average pooling 3d operation - 5d activations in NCDHW format * (shape [minibatch, channels, depth, height, width]) or NDHWC format * (shape [minibatch, depth, height, width, channels]) * @param weights Weights for conv3d. Rank 5. * If input data is in NCDHW fomat, weights should have shape [outputChannels, inputChannels, kernelDepth, kernelHeight, kernelWidth]. * If input data is in NDHWC fomat, weights should have shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. * @param bias Optional 1D bias array with shape [outputChannels]. May be null. * @param conv3DConfig the configuration * @return Conv3d output variable */ public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { SDVariable[] args; if (bias == null) { args = new SDVariable[]{input, weights}; } else { args = new SDVariable[]{input, weights, bias}; } SDVariable ret = f().conv3d(args, conv3DConfig); return updateVariableNameAndReference(ret, name); } /** * Batch norm operation. */ public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon) { return batchNorm(null, input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon); } /** * Batch norm operation. 
*/ public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon) { SDVariable res = f().batchNorm(input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon); return updateVariableNameAndReference(res, name); } /** * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] * * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] * @param config Convolution configuration for the im2col operation * @return Im2Col output variable */ public SDVariable im2Col(SDVariable in, Conv2DConfig config) { return im2Col(null, in, config); } /** * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] * * @param name Name of the output variable * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] * @param config Convolution configuration for the im2col operation * @return Im2Col output variable */ public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) { SDVariable ret = f().im2Col(in, config); return updateVariableNameAndReference(ret, name); } /** * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape * [minibatch, inputChannels, height, width] * * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] * @param config Convolution configuration for the col2im operation * @return Col2Im output variable */ public SDVariable col2Im(SDVariable in, Conv2DConfig config) { return col2Im(null, in, config); } /** * col2im operation for use in 2D convolution operations. 
Outputs a 4d array with shape * [minibatch, inputChannels, height, width] * * @param name Name of the output variable * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] * @param config Convolution configuration for the col2im operation * @return Col2Im output variable */ public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) { SDVariable ret = f().col2Im(in, config); return updateVariableNameAndReference(ret, name); } /** * Create a new scalar (rank 0) SDVariable with the specified value * @param name Name of the SDVariable * @param value Value to initialize the variable with * @return SDVariable */ public SDVariable scalar(String name, double value) { return var(name, Nd4j.scalar(value)); } /** * Greater than or equals operation: elementwise x >= y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gte(SDVariable x, double y) { return gte(null, x, y); } /** * Greater than or equals operation: elementwise x >= y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gte(String name, SDVariable x, double y) { SDVariable result = functionFactory.gte(x, y); return updateVariableNameAndReference(result, name); } /** * Less than or equals operation: elementwise x <= y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lte(SDVariable x, double y) { return lte(null, x, y); } /** * Less than or equals operation: elementwise x <= y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lte(String name, SDVariable x, double y) { SDVariable result = functionFactory.lte(x, y); return updateVariableNameAndReference(result, name); } /** * Greater than operation: elementwise x > y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gt(SDVariable x, double y) { return gt(null, x, y); } /** * Greater than operation: elementwise x > y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gt(String name, SDVariable x, double y) { SDVariable result = functionFactory.gt(x, y); return updateVariableNameAndReference(result, name); } /** * Less than operation: elementwise x < y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lt(SDVariable x, double y) { return lt(null, x, y); } /** * Less than operation: elementwise x < y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lt(String name, SDVariable x, double y) { SDVariable result = functionFactory.lt(x, y); return updateVariableNameAndReference(result, name); } /** * Not equals operation: elementwise x != y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable neq(SDVariable x, double y) { return neq(null, x, y); } /** * Not equals operation: elementwise x != y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable neq(String name, SDVariable x, double y) { SDVariable result = functionFactory.neq(x, y); return updateVariableNameAndReference(result, name); } /** * Equals operation: elementwise x == y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable eq(SDVariable x, double y) { return eq(null, x, y); } /** * Equals operation: elementwise x == y
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param x Input array * @param y Double value argument to use in operation * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable eq(String name, SDVariable x, double y) { SDVariable result = functionFactory.eq(x, y); return updateVariableNameAndReference(result, name); } /** * Greater than or equal to operation: elementwise x >= y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gte(SDVariable x, SDVariable y) { return gte(null, x, y); } /** * Greater than or equal to operation: elementwise x >= y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gte(String name, SDVariable x, SDVariable y) { SDVariable result = functionFactory.gte(x, y); return updateVariableNameAndReference(result, name); } /** * Less than or equal to operation: elementwise x <= y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lte(SDVariable x, SDVariable y) { return lte(null, x, y); } /** * Less than or equal to operation: elementwise x <= y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lte(String name, SDVariable x, SDVariable y) { SDVariable result = functionFactory.lte(x, y); return updateVariableNameAndReference(result, name); } /** * Greater than operation: elementwise x > y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gt(SDVariable x, SDVariable y) { return gt(null, x, y); } /** * Greater than operation: elementwise x > y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable gt(String name, SDVariable x, SDVariable y) { SDVariable result = functionFactory.gt(x, y); return updateVariableNameAndReference(result, name); } /** * Less than operation: elementwise x < y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lt(SDVariable x, SDVariable y) { return lt(null, x, y); } /** * Less than operation: elementwise x < y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable lt(String name, SDVariable x, SDVariable y) { SDVariable result = functionFactory.lt(x, y); return updateVariableNameAndReference(result, name); } /** * Not equal to operation: elementwise x != y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable neq(SDVariable x, SDVariable y) { return neq(null, x, y); } /** * Not equal to operation: elementwise x != y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable neq(String name, SDVariable x, SDVariable y) { SDVariable result = functionFactory.neq(x, y); return updateVariableNameAndReference(result, name); } /** * Equal to operation: elementwise x == y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable eq(SDVariable x, SDVariable y) { return eq(null, x, y); } /** * Equal to operation: elementwise x == y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable eq(String name, SDVariable x, SDVariable y) { SDVariable result = functionFactory.eq(x, y); return updateVariableNameAndReference(result, name); } /** * Boolean OR operation: elementwise (x != 0) || (y != 0)
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable or(SDVariable x, SDVariable y) { return or(null, x, y); } /** * Boolean OR operation: elementwise (x != 0) || (y != 0)
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable or(String name, SDVariable x, SDVariable y) { SDVariable result = functionFactory.or(x, y); return updateVariableNameAndReference(result, name); } /** * Boolean AND operation: elementwise (x != 0) && (y != 0)
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable and(SDVariable x, SDVariable y) { return and(null, x, y); } /** * Boolean AND operation: elementwise (x != 0) && (y != 0)
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable and(String name, SDVariable x, SDVariable y) { SDVariable result = f().and(x, y); return updateVariableNameAndReference(result, name); } /** * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable xor(SDVariable x, SDVariable y) { return xor(null, x, y); } /** * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* Note: supports broadcasting if x and y have different shapes and are broadcastable.
* Returns an array with values 1 where condition is satisfied, or value 0 otherwise. * * @param name Name of the output variable * @param x Input 1 * @param y Input 2 * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable xor(String name, SDVariable x, SDVariable y) { SDVariable result = f().xor(x, y); return updateVariableNameAndReference(result, name); } /** * Elementwise absolute value operation: out = abs(x) * * @param x Input variable * @return Output variable */ public SDVariable abs(SDVariable x) { return abs(null, x); } /** * Elementwise absolute value operation: out = abs(x) * * @param name Name of the output variable * @param x Input variable * @return Output variable */ public SDVariable abs(String name, SDVariable x) { SDVariable result = f().abs(x); return updateVariableNameAndReference(result, name); } /** * Elementwise negative operation: out = -x * * @param x Input variable * @return Output variable */ public SDVariable neg(SDVariable x) { return neg(null, x); } /** * Elementwise negative operation: out = -x * * @param name Name of the output variable * @param x Input variable * @return Output variable */ public SDVariable neg(String name, SDVariable x) { SDVariable result = functionFactory.neg(x); return updateVariableNameAndReference(result, name); } /** * Elementwise cosine operation: out = cos(x) * * @param x Input variable * @return Output variable */ public SDVariable cos(SDVariable x) { return cos(null, x); } /** * Elementwise cosine operation: out = cos(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable cos(String name, SDVariable x) { SDVariable result = functionFactory.cos(x); return updateVariableNameAndReference(result, name); } /** * Elementwise sine operation: out = sin(x) * * @param x Input variable * @return Output variable */ public SDVariable sin(SDVariable x) { return sin(null, x); } /** * Elementwise sine 
operation: out = sin(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable sin(String name, SDVariable x) { SDVariable result = functionFactory.sin(x); return updateVariableNameAndReference(result, name); } /** * Elementwise tangent operation: out = tan(x) * * @param x Input variable * @return Output variable */ public SDVariable tan(SDVariable x) { return tan(null, x); } /** * Elementwise tangent operation: out = tan(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable tan(String name, SDVariable x) { SDVariable result = functionFactory.tan(x); return updateVariableNameAndReference(result, name); } /** * Elementwise identity operation: out = x * * @param input Input variable * @return Output variable */ public SDVariable identity(SDVariable input) { return identity(null, input); } /** * Elementwise identity operation: out = x * * @param name name of the output variable * @param input Input variable * @return Output variable */ public SDVariable identity(String name, SDVariable input) { SDVariable s = f().identity(input); return updateVariableNameAndReference(s, name); } /** * Compute the inverse permutation indices for a permutation operation
* Example: if input is [2, 0, 1] then output is [1, 2, 0]
* The idea is that x.permute(input).permute(invertPermutation(input)) == x * * @param input 1D indices for permutation * @return 1D inverted permutation */ public SDVariable invertPermutation(SDVariable input) { return invertPermutation(null, input); } /** * Compute the inverse permutation indices for a permutation operation
* Example: if input is [2, 0, 1] then output is [1, 2, 0]
* The idea is that x.permute(input).permute(invertPermutation(input)) == x * * @param name name of the output variable * @param input 1D indices for permutation * @return 1D inverted permutation */ public SDVariable invertPermutation(String name, SDVariable input) { SDVariable ret = f().invertPermutation(input, false); return updateVariableNameAndReference(ret, name); } /** * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x) * * @param x Input variable * @return Output variable */ public SDVariable acos(SDVariable x) { return acos(null, x); } /** * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable acos(String name, SDVariable x) { SDVariable result = functionFactory.acos(x); return updateVariableNameAndReference(result, name); } /** * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x) * * @param x Input variable * @return Output variable */ public SDVariable asin(SDVariable x) { return asin(null, x); } /** * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable asin(String name, SDVariable x) { SDVariable result = functionFactory.asin(x); return updateVariableNameAndReference(result, name); } /** * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x) * * @param x Input variable * @return Output variable */ public SDVariable atan(SDVariable x) { return atan(null, x); } /** * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable atan(String name, SDVariable x) { SDVariable result = functionFactory.atan(x); return updateVariableNameAndReference(result, name); } /** * Elementwise atan (arctangent, inverse tangent) operation: 
out = atan2(x,y). * Similar to atan(y/x) but sigts of x and y are used to determine the the location of the result * * @param y Input Y variable * @param x Input X variable * @return Output variable */ public SDVariable atan2(SDVariable y, SDVariable x) { return atan2(null, y, x); } /** * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y). * Similar to atan(y/x) but sigts of x and y are used to determine the the location of the result * * @param name Name of the output variable * @param y Input Y variable * @param x Input X variable * @return Output variable */ public SDVariable atan2(String name, SDVariable y, SDVariable x) { SDVariable ret = f().atan2(y, x); return updateVariableNameAndReference(ret, name); } /** * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x) * * @param x Input variable * @return Output variable */ public SDVariable cosh(SDVariable x) { return cosh(null, x); } /** * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable cosh(String name, SDVariable x) { SDVariable result = functionFactory.cosh(x); return updateVariableNameAndReference(result, name); } /** * Elementwise sinh (hyperbolic sine) operation: out = sinh(x) * * @param x Input variable * @return Output variable */ public SDVariable sinh(SDVariable x) { return sinh(null, x); } /** * Elementwise sinh (hyperbolic sine) operation: out = sinh(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable sinh(String name, SDVariable x) { SDVariable result = functionFactory.sinh(x); return updateVariableNameAndReference(result, name); } /** * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x) * * @param x Input variable * @return Output variable */ public SDVariable tanh(SDVariable x) { return tanh(null, x); } /** * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x) 
* * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable tanh(String name, SDVariable x) { SDVariable result = functionFactory.tanh(x); return updateVariableNameAndReference(result, name); } /** * Elementwise step function:
* out(x) = 1 if x >= cutoff
* out(x) = 0 otherwise
* * @param in Input variable * @param cutoff Cutoff value for step function * @return Output variable */ public SDVariable step(SDVariable in, double cutoff) { return step(null, in, cutoff); } /** * Elementwise step function:
* out(x) = 1 if x >= cutoff
* out(x) = 0 otherwise
* * @param name Name of the output variable * @param in Input variable * @param cutoff Cutoff value for step function * @return Output variable */ public SDVariable step(String name, SDVariable in, double cutoff) { SDVariable ret = f().step(in, cutoff); return updateVariableNameAndReference(ret, name); } /** * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x) * * @param x Input variable * @return Output variable */ public SDVariable acosh(SDVariable x) { return acosh(null, x); } /** * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable acosh(String name, SDVariable x) { SDVariable result = functionFactory.acosh(x); return updateVariableNameAndReference(result, name); } /** * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x) * * @param x Input variable * @return Output variable */ public SDVariable asinh(SDVariable x) { return asinh(null, x); } /** * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable asinh(String name, SDVariable x) { SDVariable result = functionFactory.asinh(x); return updateVariableNameAndReference(result, name); } /** * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x) * * @param x Input variable * @return Output variable */ public SDVariable atanh(SDVariable x) { return atanh(null, x); } /** * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable atanh(String name, SDVariable x) { SDVariable result = functionFactory.atanh(x); return updateVariableNameAndReference(result, name); } /** * Elementwise exponent function: out = exp(x) = 2.71828...^x * * @param x Input variable * @return Output variable */ public 
SDVariable exp(SDVariable x) { return exp(null, x); } /** * Elementwise exponent function: out = exp(x) = 2.71828...^x * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable exp(String name, SDVariable x) { SDVariable result = functionFactory.exp(x); return updateVariableNameAndReference(result, name); } /** * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x) * * @param x Input variable * @return Output variable */ public SDVariable rsqrt(SDVariable x) { return rsqrt(null, x); } /** * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable rsqrt(String name, SDVariable x) { SDVariable result = functionFactory.rsqrt(x); return updateVariableNameAndReference(result, name); } /** * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x * * @param x Input variable * @return Output variable */ public SDVariable expm1(SDVariable x) { return expm1(null, x); } /** * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable expm1(String name, SDVariable x) { SDVariable result = functionFactory.expm1(x); return updateVariableNameAndReference(result, name); } /** * Elementwise natural logarithm function: out = log_e (1 + x) * * @param x Input variable * @return Output variable */ public SDVariable log1p(SDVariable x) { return log1p(null, x); } /** * Elementwise natural logarithm function: out = log_e (1 + x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable log1p(String name, SDVariable x) { SDVariable result = functionFactory.log1p(x); return updateVariableNameAndReference(result, name); } /** * Elementwise round function: out = round(x). 
* Rounds (up or down depending on value) to the nearest integer value. * * @param x Input variable * @return Output variable */ public SDVariable round(SDVariable x) { return round(null, x); } /** * Element-wise round function: out = round(x). * Rounds (up or down depending on value) to the nearest integer value. * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable round(String name, SDVariable x) { SDVariable result = functionFactory.round(x); return updateVariableNameAndReference(result, name); } /** * Is infinite operation: elementwise isInfinite(x)
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isInfinite(SDVariable x) { return isInfinite(null, x); } /** * Is infinite operation: elementwise isInfinite(x)
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Output variable name * @param x Input array * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isInfinite(String name, SDVariable x) { SDVariable result = functionFactory.isInfinite(x); return updateVariableNameAndReference(result, name); } /** * Is Not a Number operation: elementwise isNaN(x)
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isNaN(SDVariable x) { return isNaN(null, x); } /** * Is Not a Number operation: elementwise isNaN(x)
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Output variable name * @param x Input array * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isNaN(String name, SDVariable x) { SDVariable result = functionFactory.isNaN(x); return updateVariableNameAndReference(result, name); } /** * Is finite operation: elementwise isFinite(x)
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isFinite(SDVariable x) { return isFinite(null, x); } /** * Is finite operation: elementwise isFinite(x)
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Output variable name * @param x Input array * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isFinite(String name, SDVariable x) { SDVariable result = functionFactory.isFinite(x); return updateVariableNameAndReference(result, name); } /** * Is maximum operation: elementwise x == max(x)
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param x Input array * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isMax(SDVariable x) { return isMax(null, x); } /** * Is maximum operation: elementwise x == max(x)
* Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or * value 0 otherwise * * @param name Name of the output variable * @param x Input array * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied */ public SDVariable isMax(String name, SDVariable x) { SDVariable ret = f().isMax(x); return updateVariableNameAndReference(ret, name); } /** * Is the array non decreasing?
* An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared * in 'c' (row major) order * * @param x Input variable * @return Scalar variable with value 1 if non-decreasing, or 0 otherwise */ public SDVariable isNonDecreasing(SDVariable x) { return isNonDecreasing(null, x); } /** * Is the array non decreasing?
* An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared * in 'c' (row major) order * * @param name Output name * @param x Input variable * @return Scalar variable with value 1 if non-decreasing, or 0 otherwise */ public SDVariable isNonDecreasing(String name, SDVariable x) { SDVariable result = functionFactory.isNonDecreasing(x); return updateVariableNameAndReference(result, name); } /** * Is the array strictly increasing?
* An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared * in 'c' (row major) order * * @param x Input variable * @return Scalar variable with value 1 if strictly increasing, or 0 otherwise */ public SDVariable isStrictlyIncreasing(SDVariable x) { return isStrictlyIncreasing(null, x); } /** * Is the array strictly increasing?
* An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared * in 'c' (row major) order * * @param name Output variable name * @param x Input variable * @return Scalar variable with value 1 if strictly increasing, or 0 otherwise */ public SDVariable isStrictlyIncreasing(String name, SDVariable x) { SDVariable result = functionFactory.isStrictlyIncreasing(x); return updateVariableNameAndReference(result, name); } /** * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1 * * @param x Input variable * @return Scalar variable with value 1 */ public SDVariable isNumericTensor(SDVariable x) { return isNumericTensor(null, x); } /** * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1 * * @param name Output variable name * @param x Input variable * @return Scalar variable with value 1 */ public SDVariable isNumericTensor(String name, SDVariable x) { SDVariable result = functionFactory.isNumericTensor(x); return updateVariableNameAndReference(result, name); } /** * Element-wise replace where condition:
* out[i] = from[i] if condition(update[i]) is satisfied, or
* out[i] = update[i] if condition(update[i]) is NOT satisfied * * @param update Source array * @param from Replacement values array (used conditionally). Must be same shape as 'update' array * @param condition Condition to check on update array elements * @return New array with values replaced where condition is satisfied */ public SDVariable replaceWhere(SDVariable update, SDVariable from, Condition condition) { return replaceWhere(null, update, from, condition); } /** * Element-wise replace where condition:
* out[i] = from[i] if condition(update[i]) is satisfied, or
* out[i] = update[i] if condition(update[i]) is NOT satisfied * * @param name Name of the output variable * @param update Source array * @param from Replacement values array (used conditionally). Must be same shape as 'update' array * @param condition Condition to check on update array elements * @return New array with values replaced where condition is satisfied */ public SDVariable replaceWhere(String name, SDVariable update, SDVariable from, Condition condition) { SDVariable ret = f().replaceWhere(update, from, condition); return updateVariableNameAndReference(ret, name); } /** * Element-wise replace where condition:
* out[i] = value if condition(update[i]) is satisfied, or
* out[i] = update[i] if condition(update[i]) is NOT satisfied * * @param update Source array * @param value Value to set at the output, if the condition is satisfied * @param condition Condition to check on update array elements * @return New array with values replaced where condition is satisfied */ public SDVariable replaceWhere(SDVariable update, Number value, Condition condition) { return replaceWhere(null, update, value, condition); } /** * Element-wise replace where condition:
* out[i] = value if condition(update[i]) is satisfied, or
* out[i] = update[i] if condition(update[i]) is NOT satisfied * * @param name Name of the output variable * @param update Source array * @param value Value to set at the output, if the condition is satisfied * @param condition Condition to check on update array elements * @return New array with values replaced where condition is satisfied */ public SDVariable replaceWhere(String name, SDVariable update, Number value, Condition condition) { SDVariable ret = f().replaceWhere(update, value, condition); return updateVariableNameAndReference(ret, name); } /** * Element-wise logarithm function (base e - natural logarithm): out = log(x) * * @param x Input variable * @return Output variable */ public SDVariable log(SDVariable x) { return log(null, x); } /** * Element-wise logarithm function (base e - natural logarithm): out = log(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable log(String name, SDVariable x) { SDVariable result = functionFactory.log(x); return updateVariableNameAndReference(result, name); } /** * Element-wise logarithm function (with specified base): out = log_{base}(x) * * @param in Input variable * @param base Logarithm base * @return Output variable */ public SDVariable log(SDVariable in, double base) { return log(null, in, base); } /** * Element-wise logarithm function (with specified base): out = log_{base}(x) * * @param name Name of the output variable * @param in Input variable * @param base Logarithm base * @return Output variable */ public SDVariable log(String name, SDVariable in, double base) { SDVariable ret = f().log(in, base); return updateVariableNameAndReference(ret, name); } /** * Log-sum-exp reduction (optionally along dimension). * Computes log(sum(exp(x)) * * @param input Input variable * @param dimensions Optional dimensions to reduce along * @return Output variable */ public SDVariable logSumExp(SDVariable input, int... 
dimensions) { return logSumExp(null, input, dimensions); } /** * Log-sum-exp reduction (optionally along dimension). * Computes log(sum(exp(x)) * * @param name Name of the output variable * @param input Input variable * @param dimensions Optional dimensions to reduce along * @return Output variable */ public SDVariable logSumExp(String name, SDVariable input, int... dimensions) { SDVariable ret = f().logSumExp(input, dimensions); return updateVariableNameAndReference(ret, name); } /** * Element-wise cube function: out = x^3 * * @param x Input variable * @return Output variable */ public SDVariable cube(SDVariable x) { return cube(null, x); } /** * Element-wise cube function: out = x^3 * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable cube(String name, SDVariable x) { SDVariable result = functionFactory.cube(x); return updateVariableNameAndReference(result, name); } /** * Element-wise power function: out = x^value * * @param x Input variable * @param value Power to raise each element to * @return Output variable */ public SDVariable pow(SDVariable x, double value) { return pow(null, x, value); } /** * Element-wise power function: out = x^value * * @param name Output variable name * @param x Input variable * @param value Power to raise each element to * @return Output variable */ public SDVariable pow(String name, SDVariable x, double value) { SDVariable result = functionFactory.pow(x, value); return updateVariableNameAndReference(result, name); } /** * Element-wise square root function: out = sqrt(x) * * @param x Input variable * @return Output variable */ public SDVariable sqrt(SDVariable x) { return sqrt(null, x); } /** * Element-wise square root function: out = sqrt(x) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable sqrt(String name, SDVariable x) { SDVariable result = functionFactory.sqrt(x); return updateVariableNameAndReference(result, 
name); } /** * Element-wise square function: out = x^2 * * @param x Input variable * @return Output variable */ public SDVariable square(SDVariable x) { return square(null, x); } /** * Element-wise square function: out = x^2 * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable square(String name, SDVariable x) { SDVariable result = functionFactory.square(x); return updateVariableNameAndReference(result, name); } /** * Element-wise floor function: out = floor(x). * Rounds each value down to the nearest integer value (if not already an integer) * * @param x Input variable * @return Output variable */ public SDVariable floor(SDVariable x) { return floor(null, x); } /** * Element-wise floor function: out = floor(x). * Rounds each value down to the nearest integer value (if not already an integer) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable floor(String name, SDVariable x) { SDVariable result = functionFactory.floor(x); return updateVariableNameAndReference(result, name); } /** * Element-wise ceiling function: out = ceil(x). * Rounds each value up to the nearest integer value (if not already an integer) * * @param x Input variable * @return Output variable */ public SDVariable ceil(SDVariable x) { return ceil(null, x); } /** * Element-wise ceiling function: out = ceil(x). * Rounds each value up to the nearest integer value (if not already an integer) * * @param name Name of the output variable * @param x Input variable * @return Output variable */ public SDVariable ceil(String name, SDVariable x) { SDVariable ret = f().ceil(x); return updateVariableNameAndReference(ret, name); } /** * Element-wise clipping function:
* out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
* out[i] = clipValueMin if in[i] < clipValueMin
* out[i] = clipValueMax if in[i] > clipValueMax
* @param x Input variable * @param clipValueMin Minimum value for clipping * @param clipValueMax Maximum value for clipping * @return Output variable */ public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { return clipByValue(null, x, clipValueMin, clipValueMax); } /** * Element-wise clipping function:
* out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
* out[i] = clipValueMin if in[i] < clipValueMin
* out[i] = clipValueMax if in[i] > clipValueMax
* * @param name Name of the output variable * @param x Input variable * @param clipValueMin Minimum value for clipping * @param clipValueMax Maximum value for clipping * @return Output variable */ public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, double clipValueMax) { SDVariable ret = f().clipByValue(x, clipValueMin, clipValueMax); return updateVariableNameAndReference(ret, name); } /** * Clipping by L2 norm
* if l2Norm(x) < clipValue, then input is returned unmodifed
* Otherwise, out[i] = in[i] * clipValue / l2Norm(in) * * @param x Input variable * @param clipValue Clipping value (maximum l2 norm) * @return Output variable */ public SDVariable clipByNorm(SDVariable x, double clipValue) { return clipByNorm(null, x, clipValue); } /** * Clipping by L2 norm
* if l2Norm(x) < clipValue, then input is returned unmodifed
* Otherwise, out[i] = in[i] * clipValue / l2Norm(in) * * @param name Name of the output variable * @param x Input variable * @param clipValue Clipping value (maximum l2 norm) * @return Output variable */ public SDVariable clipByNorm(String name, SDVariable x, double clipValue) { SDVariable ret = f().clipByNorm(x, clipValue); return updateVariableNameAndReference(ret, name); } /** * Clipping by L2 norm, optionally along dimension(s)
* if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
* Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according * to the corresponding l2Norm along the specified dimensions * * @param x Input variable * @param clipValue Clipping value (maximum l2 norm) * @param dimensions If not specified, all dimensions are used * @return Output variable */ public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) { return clipByNorm(null, x, clipValue, dimensions); } /** * Clipping by L2 norm, optionally along dimension(s)
* if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
* Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according * to the corresponding l2Norm along the specified dimensions * * @param name Output variable name * @param x Input variable * @param clipValue Clipping value (maximum l2 norm) * @param dimensions If not specified, all dimensions are used * @return Output variable */ public SDVariable clipByNorm(String name, SDVariable x, double clipValue, int... dimensions) { SDVariable ret = f().clipByNorm(x, clipValue, dimensions); return updateVariableNameAndReference(ret, name); } /** * Element-wise rectified linear function with specified cutoff:
* out[i] = in[i] if in[i] >= cutoff * out[i] = 0 otherwise * * @param x Input variable * @param cutoff Cutoff value. Usually 0 * @return Output variable */ public SDVariable relu(SDVariable x, double cutoff) { return relu(null, x, cutoff); } /** * Element-wise rectified linear function with specified cutoff:
* out[i] = in[i] if in[i] >= cutoff * out[i] = 0 otherwise * * @param name Output variable name * @param x Input variable * @param cutoff Cutoff value. Usually 0 * @return Output variable */ public SDVariable relu(String name, SDVariable x, double cutoff) { SDVariable result = functionFactory.relu(x, cutoff); return updateVariableNameAndReference(result, name); } /** * Element-wise "rectified linear 6" function with specified cutoff:
* out[i] = min(max(in, cutoff), 6) * * @param x Input variable * @param cutoff Cutoff value. Usually 0 * @return Output variable */ public SDVariable relu6(SDVariable x, double cutoff) { return relu6(null, x, cutoff); } /** * Element-wise "rectified linear 6" function with specified cutoff:
* out[i] = min(max(in, cutoff), 6) * * @param name Output variable name * @param x Input variable * @param cutoff Cutoff value. Usually 0 * @return Output variable */ public SDVariable relu6(String name, SDVariable x, double cutoff) { SDVariable result = functionFactory.relu6(x, cutoff); return updateVariableNameAndReference(result, name); } /** * Softmax activation * * @param x Input variable * @return Output variable */ public SDVariable softmax(SDVariable x) { return softmax(null, x); } /** * Softmax activation * * @param x Input variable * @return Output variable */ public SDVariable softmax(String name, SDVariable x) { SDVariable result = functionFactory.softmax(x); return updateVariableNameAndReference(result, name); } /** * Log softmax activation * * @param x Input variable * @return Output variable */ public SDVariable logSoftmax(SDVariable x) { return logSoftmax(null, x); } /** * Log softmax activation * * @param name Variable name * @param x Input variable * @return Output variable */ public SDVariable logSoftmax(String name, SDVariable x) { SDVariable ret = f().logSoftmax(x); return updateVariableNameAndReference(ret, name); } /** * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks *
* out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
* Uses default lcale and alpha values. * * @param x Input variable * @return Output variable */ public SDVariable selu(SDVariable x) { return selu(null, x); } /** * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks *
* out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
* Uses default lcale and alpha values. * * @param name Name of the output variable * @param x Input variable * @return Output variable */ public SDVariable selu(String name, SDVariable x) { SDVariable ret = f().selu(x); return updateVariableNameAndReference(ret, name); } /** * Merge add function: merges an arbitrary number of equal shaped arrays using elementwise addition: * out = sum_i in[i] * * @param x Input variables * @return Output variable */ public SDVariable mergeAdd(SDVariable... x) { return mergeAdd(null, x); } /** * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition: * out = sum_i in[i] * * @param name Name of the output variable * @param inputs Input variables * @return Output variable */ public SDVariable mergeAdd(String name, SDVariable... inputs) { SDVariable ret = f().mergeAdd(inputs); return updateVariableNameAndReference(ret, name); } /** * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation: * out = max_i in[i] * * @param x Input variables * @return Output variable */ public SDVariable mergeMax(SDVariable... x) { return mergeMax(null, x); } /** * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation: * out = max_i in[i] * * @param inputs Input variables * @return Output variable */ public SDVariable mergeMax(String name, SDVariable... inputs) { SDVariable ret = f().mergeMax(inputs); return updateVariableNameAndReference(ret, name); } /** * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation: * out = mean_i in[i] * * @param inputs Input variables * @return Output variable */ public SDVariable mergeAvg(SDVariable... 
inputs) { return mergeAvg(null, inputs); } /** * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation: * out = mean_i in[i] * * @param name Name of the output variable * @param inputs Input variables * @return Output variable */ public SDVariable mergeAvg(String name, SDVariable... inputs) { SDVariable ret = f().mergeAvg(inputs); return updateVariableNameAndReference(ret, name); } /** * @see #batchToSpace(String, SDVariable, int[], int[][]) */ public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) { return batchToSpace(null, x, blocks, crops); } /** * Convolution 2d layer batch to space operation on 4d input. * Reduces input batch dimension by rearranging data into a larger spatial dimensions * * @param name Output variable name * @param x Input variable. 4d input * @param blocks Block size, in the height/width dimension * @param crops Optional 2d int[] array: values [[crop top, crop bottom], [crop left, crop right]] * @return Output variable * @see #spaceToBatch(String, SDVariable, int[], int[][]) */ public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) { SDVariable ret = f().batchToSpace(x, blocks, crops); return updateVariableNameAndReference(ret, name); } /** * Convolution 2d layer batch to space operation on 4d input.
* Reduces input channels dimension by rearranging data into a larger spatial dimensions
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2] * = [mb, 2, 4, 4] * * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param blockSize Block size, in the height/width dimension * @param dataFormat Data format: "NCHW" or "NHWC" * @return Output variable */ public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) { return depthToSpace(null, x, blockSize, dataFormat); } /** * Convolution 2d layer batch to space operation on 4d input.
* Reduces input channels dimension by rearranging data into a larger spatial dimensions
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2] * = [mb, 2, 4, 4] * * @param name Output variable name * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param blockSize Block size, in the height/width dimension * @param dataFormat Data format: "NCHW" or "NHWC" * @return Output variable * @see #depthToSpace(String, SDVariable, int, String) */ public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) { SDVariable ret = f().depthToSpace(x, blockSize, dataFormat); return updateVariableNameAndReference(ret, name); } /** * @see #spaceToBatch(String, SDVariable, int[], int[][]) */ public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) { return spaceToBatch(null, x, blocks, padding); } /** * Convolution 2d layer space to batch operation on 4d input. * Increases input batch dimension by rearranging data from spatial dimensions into batch dimension * * @param name Output variable name * @param x Input variable. 4d input * @param blocks Block size, in the height/width dimension * @param padding Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] * @return Output variable * @see #batchToSpace(String, SDVariable, int[], int[][]) */ public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) { SDVariable ret = f().spaceToBatch(x, blocks, padding); return updateVariableNameAndReference(ret, name); } /** * @see #spaceToDepth(String, SDVariable, int, String) */ public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) { return spaceToDepth(null, x, blockSize, dataFormat); } /** * Convolution 2d layer space to depth operation on 4d input.
* Increases input channels (reduced spatial dimensions) by rearranging data into a larger channels dimension
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2] * = [mb, 2, 4, 4] * * @param name Output variable name * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param blockSize Block size, in the height/width dimension * @param dataFormat Data format: "NCHW" or "NHWC" * @return Output variable * @see #depthToSpace(String, SDVariable, int, String) */ public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) { SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat); return updateVariableNameAndReference(ret, name); } /** * @see #dynamicPartition(String[], SDVariable, SDVariable, int) */ public SDVariable[] dynamicPartition(SDVariable x, SDVariable partitions, int numPartitions) { return dynamicPartition(null, x, partitions, numPartitions); } /** * Dynamically partition the input variable values into the specified number of paritions, using the indices.
* Example:
*
     * {@code input = [1,2,3,4,5]
     * numPartitions = 2
     * partitions = [1,0,0,1,0]
     * out[0] = [2,3,5]
     * out[1] = [1,4] }
     * 
* * @param name Names for the output variables. Length must be equal to numPartitions * @param x Input variable * @param partitions 1D input with values 0 to numPartitions-1 * @param numPartitions Number of partitions, >= 1 * @return Output variables (equal in number to numPartitions) */ public SDVariable[] dynamicPartition(String[] name, SDVariable x, SDVariable partitions, int numPartitions) { SDVariable[] ret = f().dynamicPartition(x, partitions, numPartitions); return updateVariableNamesAndReferences(ret, name); } /** * @see #dynamicStitch(String, SDVariable[], SDVariable[]) */ public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] x) { return dynamicStitch(null, indices, x); } /** * Dynamically merge the specified input arrays into a single array, using the specified indices * * @param name Name of the output variable * @param indices Indices to use when merging. Must be >= 1, same length as input variables * @param x Input variables. * @return Merged output variable */ public SDVariable dynamicStitch(String name, SDVariable[] indices, SDVariable[] x) { SDVariable ret = f().dynamicStitch(indices, x); return updateVariableNameAndReference(ret, name); } /** * @see #segmentMax(String, SDVariable, SDVariable) */ public SDVariable segmentMax(SDVariable data, SDVariable segmentIds){ return segmentMax(null, data, segmentIds); } /** * Segment max operation.
* If data = [3, 6, 1, 4, 9, 2, 8]
* segmentIds = [0, 0, 1, 1, 1, 2, 2]
* then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8) * * @param name Name of the output variable. May be null * @param data Data to perform segment max on * @param segmentIds Variable for the segment IDs * @return Segment max output */ public SDVariable segmentMax(String name, SDVariable data, SDVariable segmentIds){ SDVariable ret = f().segmentMax(data, segmentIds); return updateVariableNameAndReference(ret, name); } /** * @see #segmentMin(String, SDVariable, SDVariable) */ public SDVariable segmentMin(SDVariable data, SDVariable segmentIds){ return segmentMin(null, data, segmentIds); } /** * Segment min operation.
* If data = [3, 6, 1, 4, 9, 2, 8]
* segmentIds = [0, 0, 1, 1, 1, 2, 2]
* then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8) * * @param name Name of the output variable. May be null * @param data Data to perform segment max on * @param segmentIds Variable for the segment IDs * @return Segment min output */ public SDVariable segmentMin(String name, SDVariable data, SDVariable segmentIds){ SDVariable ret = f().segmentMin(data, segmentIds); return updateVariableNameAndReference(ret, name); } /** * @see #segmentMean(String, SDVariable, SDVariable) */ public SDVariable segmentMean(SDVariable data, SDVariable segmentIds){ return segmentMean(null, data, segmentIds); } /** * Segment mean operation.
* If data = [3, 6, 1, 4, 9, 2, 8]
* segmentIds = [0, 0, 1, 1, 1, 2, 2]
* then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8) * * @param name Name of the output variable. May be null * @param data Data to perform segment max on * @param segmentIds Variable for the segment IDs * @return Segment mean output */ public SDVariable segmentMean(String name, SDVariable data, SDVariable segmentIds){ SDVariable ret = f().segmentMean(data, segmentIds); return updateVariableNameAndReference(ret, name); } /** * @see #segmentProd(String, SDVariable, SDVariable) */ public SDVariable segmentProd(SDVariable data, SDVariable segmentIds){ return segmentProd(null, data, segmentIds); } /** * Segment product operation.
* If data = [3, 6, 1, 4, 9, 2, 8]
* segmentIds = [0, 0, 1, 1, 1, 2, 2]
* then output = [18, 36, 16] = [prod(3,6), prod(1,4,9), prod(2,8) * * @param name Name of the output variable. May be null * @param data Data to perform segment max on * @param segmentIds Variable for the segment IDs * @return Segment product output */ public SDVariable segmentProd(String name, SDVariable data, SDVariable segmentIds){ SDVariable ret = f().segmentProd(data, segmentIds); return updateVariableNameAndReference(ret, name); } /** * @see #segmentSum(String, SDVariable, SDVariable) */ public SDVariable segmentSum(SDVariable data, SDVariable segmentIds){ return segmentSum(null, data, segmentIds); } /** * Segment sum operation.
* If data = [3, 6, 1, 4, 9, 2, 8]
* segmentIds = [0, 0, 1, 1, 1, 2, 2]
* then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8) * * @param name Name of the output variable. May be null * @param data Data to perform segment max on * @param segmentIds Variable for the segment IDs * @return Segment sum output */ public SDVariable segmentSum(String name, SDVariable data, SDVariable segmentIds){ SDVariable ret = f().segmentSum(data, segmentIds); return updateVariableNameAndReference(ret, name); } /** * TODO doc string * * @param df * @param weights * @param strides * @param rates * @param isSameMode * @return */ public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, int[] rates, boolean isSameMode) { return dilation2D(null, df, weights, strides, rates, isSameMode); } /** * TODO doc string * * @param name * @param df * @param weights * @param strides * @param rates * @param isSameMode * @return */ public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides, int[] rates, boolean isSameMode) { SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode); return updateVariableNameAndReference(ret, name); } /** * Returns the shape of the specified SDVariable as a 1D SDVariable * * @param input Input variable * @return 1D output variable with contents equal to the shape of the input */ public SDVariable shape(SDVariable input) { return shape(null, input); } /** * Returns the shape of the specified SDVariable as a 1D SDVariable * * @param name Name of the output variable * @param input Input variable * @return 1D output variable with contents equal to the shape of the input */ public SDVariable shape(String name, SDVariable input) { SDVariable ret = f().shape(input); return updateVariableNameAndReference(ret, name); } /** * Returns the size (number of elements, i.e., prod(shape)) of the specified SDVariable as a 0D scalar variable * * @param in Input variable * @return 0D (scalar) output variable with value equal to the number of elements in the specified array */ public 
SDVariable size(SDVariable in){ return size(null, in); } /** * Returns the size (number of elements, i.e., prod(shape)) of the specified SDVariable as a 0D scalar variable * * @param name Name of the output variable * @param in Input variable * @return 0D (scalar) output variable with value equal to the number of elements in the specified array */ public SDVariable size(String name, SDVariable in){ SDVariable ret = f().size(in); return updateVariableNameAndReference(ret, name); } /** * Returns the rank (number of dimensions, i.e., length(shape)) of the specified SDVariable as a 0D scalar variable * * @param in Input variable * @return 0D (scalar) output variable with value equal to the rank of the input variable */ public SDVariable rank(SDVariable in) { return rank(null, in); } /** * Returns the rank (number of dimensions, i.e., length(shape)) of the specified SDVariable as a 0D scalar variable * * @param name Name of the output variable * @param in Input variable * @return 0D (scalar) output variable with value equal to the rank of the input variable */ public SDVariable rank(String name, SDVariable in) { SDVariable ret = f().rank(in); return updateVariableNameAndReference(ret, name); } /** * @see #sizeAt(String, SDVariable, int) */ public SDVariable sizeAt(SDVariable in, int dimension){ return sizeAt(null, in, dimension); } /** * Returns a rank 0 (scalar) variable for the size of the specified dimension. * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. 
Similarly, sizeAt(X,-1)=30 * * @param name Name of the output variable * @param in Input variable * @param dimension Dimension to get size of * @return Scalar SDVariable for size at specified variable */ public SDVariable sizeAt(String name, SDVariable in, int dimension){ SDVariable ret = f().sizeAt(in, dimension); return updateVariableNameAndReference(ret, name); } /** * @see #cross(String, SDVariable, SDVariable) */ public SDVariable cross(SDVariable a, SDVariable b) { return cross(null, a, b); } /** * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta). * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3 * * @param a First input * @param b Second input * @return Element-wise cross product */ public SDVariable cross(String name, SDVariable a, SDVariable b) { SDVariable ret = f().cross(a, b); return updateVariableNameAndReference(ret, name); } /** * @see #gather(String, SDVariable, int[], int) */ public SDVariable gather(SDVariable df, int[] indices, int axis) { return gather(null, df, indices, axis); } /** * Gather slices from the input variable where the indices are specified as fixed int[] values.
* Output shape is same as input shape, except for axis dimension, which has size equal to indices.length. * * @param name name of the output variable * @param df Input variable * @param indices Indices to get * @param axis Axis that the indices refer to * @return Output variable with slices pulled from the specified axis */ public SDVariable gather(String name, SDVariable df, int[] indices, int axis) { SDVariable ret = f().gather(df, indices, axis); return updateVariableNameAndReference(ret, name); } /** * @see #gather(String, SDVariable, SDVariable, int) */ public SDVariable gather(SDVariable df, SDVariable indices, int axis) { return gather(null, df, indices, axis); } /** * Gather slices from the input variable where the indices are specified as dynamic SDVariable values.
* Output shape is same as input shape, except for axis dimension, which has size equal to indices.length. * * @param name name of the output variable * @param df Input variable * @param indices Indices to get slices for. Rank 0 or 1 input * @param axis Axis that the indices refer to * @return Output variable with slices pulled from the specified axis */ public SDVariable gather(String name, SDVariable df, SDVariable indices, int axis) { SDVariable ret = f().gather(df, indices, axis); return updateVariableNameAndReference(ret, name); } /** * TODO doc string * * @param df * @param indices * @return */ public SDVariable gatherNd(SDVariable df, SDVariable indices) { return gatherNd(null, df, indices); } /** * TODO doc string * * @param name * @param df * @param indices * @return */ public SDVariable gatherNd(String name, SDVariable df, SDVariable indices) { SDVariable ret = f().gatherNd(df, indices); return updateVariableNameAndReference(ret, name); } /** * @see #repeat(String, SDVariable, int) */ public SDVariable repeat(SDVariable df, int axis) { return repeat(null, df, axis); } /** * @see #repeat(String, SDVariable, int) */ public SDVariable repeat(String name, SDVariable df, int axis) { SDVariable ret = f().repeat(df, axis); return updateVariableNameAndReference(ret, name); } /** * @see #stack(String, int, SDVariable...) */ public SDVariable stack(int axis, SDVariable... values) { return stack(null, axis, values); } /** * Stack a set of N SDVariables of rank X into one rank X+1 variable. * If inputs have shape [a,b,c] then output has shape:
* axis = 0: [N,a,b,c]
* axis = 1: [a,N,b,c]
* axis = 2: [a,b,N,c]
* axis = 3: [a,b,c,N]
* * @param name Name of the output variable * @param axis Axis to stack on * @param values Input variables to stack. Must have the same shape for all inputs * @return Output variable * @see #unstack(String[], SDVariable, int, int) */ public SDVariable stack(String name, int axis, SDVariable... values) { SDVariable ret = f().stack(values, axis); return updateVariableNameAndReference(ret, name); } /** * @see #stack(String, int, SDVariable...) */ public SDVariable parallel_stack(SDVariable[] values) { return parallel_stack(null, values); } /** * @see #stack(String, int, SDVariable...) */ public SDVariable parallel_stack(String name, SDVariable[] values) { SDVariable ret = f().parallel_stack(values); return updateVariableNameAndReference(ret, name); } /** * @see #unstack(String[], SDVariable, int, int) */ public SDVariable[] unstack(SDVariable value, int axis) { return unstack(null, value, axis); } /** * @see #unstack(String[], SDVariable, int, int) */ public SDVariable[] unstack(String[] names, SDVariable value, int axis) { SDVariable[] ret = f().unstack(value, axis); return updateVariableNamesAndReferences(ret, names); } /** * @see #unstack(String[], SDVariable, int, int) */ public SDVariable[] unstack(SDVariable value, int axis, int num) { return unstack(null, value, axis, num); } /** * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis. * If input has shape [a,b,c] then output has shape: * axis = 0: [b,c]
* axis = 1: [a,c]
* axis = 2: [a,b]
* * @param names Output variable names. May be null * @param value Input variable to unstack * @param axis Axis to unstack on * @param num Number of output variables * @return Output variables * @see #stack(String, int, SDVariable...) */ public SDVariable[] unstack(String[] names, SDVariable value, int axis, int num) { SDVariable[] ret = f().unstack(value, axis, num); return updateVariableNamesAndReferences(ret, names); } /** * Element-wise Gaussian error function - out = erf(in) * * @param x Input variable * @return Output variable */ public SDVariable erf(SDVariable x) { return erf(null, x); } /** * Element-wise Gaussian error function - out = erf(in) * * @param name Output variable name * @param x Input variable * @return Output variable */ public SDVariable erf(String name, SDVariable x) { SDVariable ret = f().erf(x); return updateVariableNameAndReference(ret, name); } /** * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in) * * @param x Input variable * @return Output variable */ public SDVariable erfc(SDVariable x) { return erfc(null, x); } /** * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in) * * @param name Name of the output variable * @param x Input variable * @return Output variable */ public SDVariable erfc(String name, SDVariable x) { SDVariable ret = f().erfc(x); return updateVariableNameAndReference(ret, name); } /** * @see #diag(String, SDVariable) */ public SDVariable diag(SDVariable x) { return diag(null, x); } /** * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
* For example, if input = [1,2,3], then output is given by:
* [ 1, 0, 0]
* [ 0, 2, 0]
* [ 0, 0, 3]
*
* Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k].
* i.e., for input rank R, output has rank 2R
*
* @param name Name of the output variable
* @param x    Input variable
* @return Output variable
*/
public SDVariable diag(String name, SDVariable x) {
    SDVariable ret = f().diag(x);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #diagPart(String, SDVariable)
 */
public SDVariable diagPart(SDVariable x) {
    return diagPart(null, x);
}

/**
 * Extract the diagonal part from the input array.
 * If input is
 * [ 1, 0, 0]
 * [ 0, 2, 0]
 * [ 0, 0, 3]
 * then output is [1, 2, 3].
 * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
 *
 * @param name Name of the output variable. May be null
 * @param x    Input variable
 * @return Diagonal part of the input
 * @see #diag(String, SDVariable)
 */
public SDVariable diagPart(String name, SDVariable x) {
    SDVariable ret = f().diagPart(x);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #setDiag(String, SDVariable, SDVariable)
 */
public SDVariable setDiag(SDVariable in, SDVariable diag) {
    return setDiag(null, in, diag);
}

/**
 * Set the diagonal value to the specified values
 * If input is
 * [ a, b, c]
 * [ d, e, f]
 * [ g, h, i]
 * and diag = [ 1, 2, 3] then output is
 * [ 1, b, c]
 * [ d, 2, f]
 * [ g, h, 3]
 *
 * @param name Name of the output variable
 * @param in   Input variable
 * @param diag Diagonal
 * @return Output variable
 */
public SDVariable setDiag(String name, SDVariable in, SDVariable diag) {
    SDVariable ret = f().setDiag(in, diag);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #oneHot(String, SDVariable, int)
 */
public SDVariable oneHot(SDVariable indices, int depth) {
    return oneHot(null, indices, depth, -1, 1.00, 0.00);
}

/**
 * @see #oneHot(String, SDVariable, int, int, double, double)
 */
public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) {
    return oneHot(null, indices, depth, axis, on, off);
}

/**
 * Convert the array to a one-hot array with values 0 and 1 for each entry
* If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth], * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0 * * @param name Output variable name * @param indices Indices - value 0 to depth-1 * @param depth Number of classes * @return Output variable * @see #oneHot(SDVariable, int, int, double, double) */ public SDVariable oneHot(String name, SDVariable indices, int depth) { return oneHot(name, indices, depth, -1, 1.00, 0.00); } /** * Convert the array to a one-hot array with walues {@code on} and {@code off} for each entry
* If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth], * with {@code out[i, ..., j, in[i,...,j]] = on} with other values being set to {@code off} * * @param name Output variable name * @param indices Indices - value 0 to depth-1 * @param depth Number of classes * @return Output variable */ public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off) { SDVariable ret = f().onehot(indices, depth, axis, on, off); return updateVariableNameAndReference(ret, name); } /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i] * * @param a Input variable * @return Output variable */ public SDVariable reciprocal(SDVariable a) { return reciprocal(null, a); } /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i] * * @param name Name of the output variable * @param a Input variable * @return Output variable */ public SDVariable reciprocal(String name, SDVariable a) { SDVariable ret = f().reciprocal(a); return updateVariableNameAndReference(ret, name); } /** * Intended for internal/developer use */ public SDVariable gradientBackwardsMarker(SDVariable x) { return gradientBackwardsMarker(generateNewVarName(new GradientBackwardsMarker().opName(), 0), x); } /** * Intended for internal/developer use */ public SDVariable gradientBackwardsMarker(String name, SDVariable x) { SDVariable result = functionFactory.gradientBackwardsMarker(x); return updateVariableNameAndReference(result, name); } /** * Element-wise hard tanh function:
* out[i] = -1 if in[i] <= -1
* out[1] = in[i] if -1 < in[i] < 1
* out[i] = 1 if in[i] >= 1
* * @param in Input variable * @return Output variable */ public SDVariable hardTanh(SDVariable in) { return hardTanh(null, in); } /** * Element-wise hard tanh function:
* out[i] = -1 if in[i] <= -1
* out[1] = in[i] if -1 < in[i] < 1
* out[i] = 1 if in[i] >= 1
* * @param name Output variable name * @param in Input variable * @return Output variable */ public SDVariable hardTanh(String name, SDVariable in) { SDVariable result = functionFactory.hardTanh(in); return updateVariableNameAndReference(result, name); } /** * Element-wise hard sigmoid function:
* out[i] = 0 if in[i] <= -2.5
* out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
* out[i] = 1 if in[i] >= 2.5
* * @param in Input variable * @return Output variable */ public SDVariable hardSigmoid(SDVariable in) { return hardSigmoid(null, in); } /** * Element-wise hard sigmoid function:
* out[i] = 0 if in[i] <= -2.5
* out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
* out[i] = 1 if in[i] >= 2.5
* * @param name Name of the output variable * @param in Input variable * @return Output variable */ public SDVariable hardSigmoid(String name, SDVariable in) { SDVariable ret = f().hardSigmoid(in); return updateVariableNameAndReference(ret, name); } /** * Derivative (dOut/dIn) of the element-wise hard Tanh function - {@link #hardTanh(SDVariable)} * * @param x Input * @return Output variable */ public SDVariable hardTanhDerivative(SDVariable x) { return hardTanhDerivative(null, x); } /** * Derivative (dOut/dIn) of the element-wise hard Tanh function - {@link #hardTanh(SDVariable)} * * @param name Output variable name * @param x Input * @return Output variable */ public SDVariable hardTanhDerivative(String name, SDVariable x) { SDVariable result = functionFactory.hardTanhDerivative(x); return updateVariableNameAndReference(result, name); } /** * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i])) * * @param x Input Variable * @return Output variable */ public SDVariable sigmoid(SDVariable x) { return sigmoid(null, x); } /** * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i])) * * @param name Output variable name * @param x Input Variable * @return Output variable */ public SDVariable sigmoid(String name, SDVariable x) { SDVariable result = functionFactory.sigmoid(x); return updateVariableNameAndReference(result, name); } /** * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut * * @param x Input Variable * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input * @return Output variable */ public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) { return sigmoidDerivative(null, x, wrt); } /** * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut * * @param name Output variable name * @param x Input Variable * @param wrt Gradient at the output - dL/dOut. 
*             Must have same shape as the input
* @return Output variable
*/
public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) {
    SDVariable result = functionFactory.sigmoidDerivative(x, wrt);
    return updateVariableNameAndReference(result, name);
}

/**
 * Element-wise log-sigmoid function: out[i] = log(sigmoid(in[i]))
 *
 * @param x Input Variable
 * @return Output variable
 */
public SDVariable logSigmoid(SDVariable x) {
    return logSigmoid(null, x);
}

/**
 * Element-wise log-sigmoid function: out[i] = log(sigmoid(in[i]))
 *
 * @param name Name of the output variable
 * @param x    Input Variable
 * @return Output variable
 */
public SDVariable logSigmoid(String name, SDVariable x) {
    SDVariable ret = f().logSigmoid(x);
    return updateVariableNameAndReference(ret, name);
}

/**
 * Element-wise sign (signum) function:
 * out = -1 if in < 0
 * out = 0 if in = 0
 * out = 1 if in > 0
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable sign(SDVariable x) {
    return sign(null, x);
}

/**
 * Element-wise sign (signum) function:
 * out = -1 if in < 0
 * out = 0 if in = 0
 * out = 1 if in > 0
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable sign(String name, SDVariable x) {
    SDVariable result = functionFactory.sign(x);
    return updateVariableNameAndReference(result, name);
}

/**
 * Element-wise softsign function: out = x / (abs(x) + 1)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable softsign(SDVariable x) {
    return softsign(null, x);
}

/**
 * Element-wise softsign function: out = x / (abs(x) + 1)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable softsign(String name, SDVariable x) {
    SDVariable result = functionFactory.softsign(x);
    return updateVariableNameAndReference(result, name);
}

/**
 * Element-wise derivative (dOut/dIn) of the softsign function {@link #softsign(SDVariable)}
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable softsignDerivative(SDVariable x) {
    return softsignDerivative(null, x);
}

/**
 * Element-wise derivative (dOut/dIn) of the softsign function {@link #softsign(SDVariable)}
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable softsignDerivative(String name, SDVariable x) {
    SDVariable result = functionFactory.softsignDerivative(x);
    return updateVariableNameAndReference(result, name);
}

/**
 * Element-wise softplus function: out = log(exp(x) + 1)
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable softplus(SDVariable x) {
    return softplus(null, x);
}

/**
 * Element-wise softplus function: out = log(exp(x) + 1)
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable softplus(String name, SDVariable x) {
    SDVariable result = functionFactory.softplus(x);
    return updateVariableNameAndReference(result, name);
}

/**
 * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
 * See: https://arxiv.org/abs/1710.05941
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable swish(SDVariable x) {
    return swish(null, x);
}

/**
 * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
 * See: https://arxiv.org/abs/1710.05941
 *
 * @param name Name of the output variable
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable swish(String name, SDVariable x) {
    SDVariable ret = f().swish(x);
    return updateVariableNameAndReference(ret, name);
}

/**
 * Element-wise exponential linear unit (ELU) function:
 * out = x if x > 0
 * out = a * (exp(x) - 1) if x <= 0
 * with constant a = 1.0
 * See: http://arxiv.org/abs/1511.07289
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable elu(SDVariable x) {
    return elu(null, x);
}

/**
 * Element-wise exponential linear unit (ELU) function:
 * out = x if x > 0
 * out = a * (exp(x) - 1) if x <= 0
 * with constant a = 1.0
 * See: http://arxiv.org/abs/1511.07289
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable elu(String name, SDVariable x) {
    SDVariable result = functionFactory.elu(x);
    return updateVariableNameAndReference(result, name);
}

/**
 * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
 * {@link #elu(SDVariable)}
 *
 * @param x Input variable
 * @return Output variable
 */
public SDVariable eluDerivative(SDVariable x) {
    return eluDerivative(null, x);
}

/**
 * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
 * {@link #elu(SDVariable)}
 *
 * @param name Output variable name
 * @param x    Input variable
 * @return Output variable
 */
public SDVariable eluDerivative(String name, SDVariable x) {
    SDVariable result = functionFactory.eluDerivative(x);
    return updateVariableNameAndReference(result, name);
}

/**
 * Element-wise leaky ReLU function:
 * out = x if x >= 0.0
 * out = alpha * x if x < 0.0
 * Alpha value is most commonly set to 0.01
 *
 * @param x     Input variable
 * @param alpha Scaling factor (slope) for negative values - most commonly 0.01
 * @return Output variable
 */
public SDVariable leakyRelu(SDVariable x, double alpha) {
    return leakyRelu(null, x, alpha);
}

/**
 * Element-wise leaky ReLU function:
 * out = x if x >= 0.0
 * out = alpha * x if x < 0.0
 * Alpha value is most commonly set to 0.01
 *
 * @param name  Name of the output variable. May be null
 * @param x     Input variable
 * @param alpha Scaling factor (slope) for negative values - most commonly 0.01
 * @return Output variable
 */
public SDVariable leakyRelu(String name, SDVariable x, double alpha) {
    SDVariable result = functionFactory.leakyRelu(x, alpha);
    return updateVariableNameAndReference(result, name);
}

/**
 * Leaky ReLU derivative: dOut/dIn given input.
* See {@link #leakyRelu(String, SDVariable, double)} * * @param x Input variable * @param alpha Alpha value * @return Output variable */ public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) { SDVariable result = functionFactory.leakyReluDerivative(x, alpha); return updateVariableNameAndReference(result, name); } /** * Full array mean reduction operation * @param x Input variable * @return Output variable - scalar */ public SDVariable mean(SDVariable x) { return mean(null, x); } /** * Mean (average) array reduction operation, optionally along specified dimensions * * @param x Input variable * @param dimension Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable mean(SDVariable x, int... dimension) { return mean(null, x, dimension); } /** * Mean (average) array reduction operation, optionally along specified dimensions * * @param name Output variable name * @param x Input variable * @param dimension Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable mean(String name, SDVariable x, int... dimension) { return mean(name, x, false, dimension); } /** * Mean (average) array reduction operation, optionally along specified dimensions
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param x Input variable * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimension Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable mean(String name, SDVariable x, boolean keepDims, int... dimension) { SDVariable result = functionFactory.mean(x, keepDims, dimension); return updateVariableNameAndReference(result, name); } /** * @see #standardDeviation(String, SDVariable, boolean, int...) */ public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, int... dimensions) { return standardDeviation(null, x, biasCorrected, dimensions); } /** * Stardard deviation array reduction operation, optionally along specified dimensions * * @param name Output variable name * @param x Input variable * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, int... dimensions) { return standardDeviation(name, x, biasCorrected, false, dimensions); } /** * Stardard deviation array reduction operation, optionally along specified dimensions
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param x Input variable * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected, boolean keepDims, int... dimensions) { SDVariable result = functionFactory.std(x, biasCorrected, keepDims, dimensions); return updateVariableNameAndReference(result, name); } /** * @see #variance(String, SDVariable, boolean, int...) */ public SDVariable variance(SDVariable x, boolean biasCorrected, int... dimensions) { return variance(null, x, biasCorrected, dimensions); } /** * Variance array reduction operation, optionally along specified dimensions * * @param name Output variable name * @param x Input variable * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable variance(String name, SDVariable x, boolean biasCorrected, int... dimensions) { return variance(name, x, biasCorrected, false, dimensions); } /** * Variance array reduction operation, optionally along specified dimensions
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param x Input variable * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable variance(String name, SDVariable x, boolean biasCorrected, boolean keepDims, int... dimensions) { SDVariable result = functionFactory.variance(x, biasCorrected, keepDims, dimensions); return updateVariableNameAndReference(result, name); } /** * Entropy reduction: -sum(x * log(x)) * * @param in Input variable * @param dimensions Dimensions to reduce on (null/empty for full array) * @return Output variable */ public SDVariable entropy(SDVariable in, int... dimensions) { return entropy(null, in, dimensions); } /** * Entropy reduction: -sum(x * log(x)) * * @param name Name of the output variable * @param in Input variable * @param dimensions Dimensions to reduce on (null/empty for full array) * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable entropy(String name, SDVariable in, int... dimensions) { SDVariable ret = f().entropy(in, dimensions); return updateVariableNameAndReference(ret, name); } /** * Log entropy reduction: log(-sum(x * log(x))) * * @param in Input variable * @param dimensions Dimensions to reduce on (null for full array) * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable logEntropy(SDVariable in, int... 
dimensions) { return logEntropy(null, in, dimensions); } /** * Log entropy reduction: log(-sum(x * log(x))) * * @param name Name of the output variable * @param in Input variable * @param dimensions Dimensions to reduce on (null for full array) * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable logEntropy(String name, SDVariable in, int... dimensions) { SDVariable ret = f().logEntropy(in, dimensions); return updateVariableNameAndReference(ret, name); } /** * Shannon Entropy reduction: -sum(x * log2(x)) * * @param in Input variable * @param dimensions Dimensions to reduce on (null/empty for full array) * @return Output variable */ public SDVariable shannonEntropy(SDVariable in, int... dimensions) { return shannonEntropy(null, in, dimensions); } /** * Shannon Entropy reduction: -sum(x * log2(x)) * * @param name Name of the output variable * @param in Input variable * @param dimensions Dimensions to reduce on (null/empty for full array) * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable shannonEntropy(String name, SDVariable in, int... dimensions) { SDVariable ret = f().shannonEntropy(in, dimensions); return updateVariableNameAndReference(ret, name); } /** * Sum array reduction operation, optionally along specified dimensions * * @param x Input variable * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) */ public SDVariable sum(SDVariable x, int... dimensions) { return sum(null, x, dimensions); } /** * Sum array reduction operation, optionally along specified dimensions * * @param x Input variable * @param dimensions Dimensions to reduce over. 
*                   If dimensions are not specified, full array reduction is performed
* @return Output variable: reduced array of rank (input rank - num dimensions)
*/
public SDVariable sum(String name, SDVariable x, int... dimensions) {
    return sum(name, x, false, dimensions);
}

/**
 * Sum array reduction operation, optionally along specified dimensions.
 * Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
 * the mean along a dimension).
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
 * keepDims = true: [a,1,c]
 * keepDims = false: [a,c]
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or
 * of rank (input rank) if keepDims = true
 */
public SDVariable sum(String name, SDVariable x, boolean keepDims, int... dimensions) {
    SDVariable result = functionFactory.sum(x, keepDims, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * @see #sum(String, SDVariable, boolean, int...)
 */
public SDVariable sum(SDVariable x, boolean keepDims, int... dimensions) {
    return sum(null, x, keepDims, dimensions);
}

/**
 * Product array reduction operation, optionally along specified dimensions
 *
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable prod(SDVariable x, int... dimensions) {
    return prod(null, x, dimensions);
}

/**
 * Product array reduction operation, optionally along specified dimensions
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable prod(String name, SDVariable x, int... dimensions) {
    return prod(name, x, false, dimensions);
}

/**
 * Product array reduction operation, optionally along specified dimensions
 * Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
 * the mean along a dimension).
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
 * keepDims = true: [a,1,c]
 * keepDims = false: [a,c]
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Output variable: reduced array of rank (input rank - num dimensions)
 */
public SDVariable prod(String name, SDVariable x, boolean keepDims, int... dimensions) {
    SDVariable result = functionFactory.prod(x, keepDims, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Element-wise scalar maximum operation: out = max(in, value)
 *
 * @param in    Input variable
 * @param value Scalar value to compare
 * @return Output variable
 */
public SDVariable scalarMax(SDVariable in, Number value) {
    return scalarMax(null, in, value);
}

/**
 * Element-wise scalar maximum operation: out = max(in, value)
 *
 * @param name  Name of the output variable
 * @param in    Input variable
 * @param value Scalar value to compare
 * @return Output variable
 */
public SDVariable scalarMax(String name, SDVariable in, Number value) {
    SDVariable ret = f().scalarMax(in, value);
    return updateVariableNameAndReference(ret, name);
}

/**
 * Element-wise scalar minimum operation: out = min(in, value)
 *
 * @param in    Input variable
 * @param value Scalar value to compare
 * @return Output variable
 */
public SDVariable scalarMin(SDVariable in, Number value) {
    return scalarMin(null, in, value);
}

/**
 * Element-wise scalar minimum operation: out = min(in, value)
 *
 * @param name  Name of the output variable
 * @param in    Input variable
 * @param value Scalar value to compare
 * @return Output variable
 */
public SDVariable scalarMin(String name, SDVariable in, Number value) {
    SDVariable ret = f().scalarMin(in, value);
    return updateVariableNameAndReference(ret, name);
}

/**
 * Element-wise scalar floor modulus operation: out = floorMod(in, value).
* i.e., returns the remainder after division by 'value' * * @param in Input variable * @param value Scalar value to compare * @return Output variable */ public SDVariable scalarFloorMod(SDVariable in, Number value) { return scalarFloorMod(null, in, value); } /** * Element-wise scalar floor modulus operation: out = floorMod(in, value). * i.e., returns the remainder after division by 'value' * * @param name Name of the output variable * @param in Input variable * @param value Scalar value to compare * @return Output variable */ public SDVariable scalarFloorMod(String name, SDVariable in, Number value) { SDVariable ret = f().scalarFloorMod(in, value); return updateVariableNameAndReference(ret, name); } /** * Return an array with equal shape to the input, but all elements set to value 'set' * * @param in Input variable * @param set Value to set * @return Output variable */ public SDVariable scalarSet(SDVariable in, Number set) { return scalarSet(null, in, set); } /** * Return an arary with equal shape to the input, but all elements set to value 'set' * * @param name Name of the output variable * @param in Input variable * @param set Value to set * @return Output variable */ public SDVariable scalarSet(String name, SDVariable in, Number set) { SDVariable ret = f().scalarSet(in, set); return updateVariableNameAndReference(ret, name); } /** * Max array reduction operation, optionally along specified dimensions * * @param x Input variable * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable max(SDVariable x, int... dimensions) { return max(null, x, dimensions); } /** * Max array reduction operation, optionally along specified dimensions * * @param name Output variable name * @param x Input variable * @param dimensions Dimensions to reduce over. 
If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable max(String name, SDVariable x, int... dimensions) { return max(name, x, false, dimensions); } /** * Max array reduction operation, optionally along specified dimensions
* Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1. This can be useful for later broadcast operations
 * (such as subtracting the mean along a dimension).<br>
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:<br>
 * keepDims = true: [a,1,c]<br>
 * keepDims = false: [a,c]
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable max(String name, SDVariable x, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.max(x, keepDims, dimensions), name);
}

/**
 * Element-wise maximum operation: out[i] = max(first[i], second[i])
* Supports broadcasting
 *
 * @param first  First input array
 * @param second Second input array
 * @return Output variable
 */
public SDVariable max(SDVariable first, SDVariable second) {
    return max(null, first, second);
}

/**
 * Element-wise maximum operation: out[i] = max(first[i], second[i])<br>
 * Supports broadcasting
 *
 * @param name   Name of the output variable
 * @param first  First input array
 * @param second Second input array
 * @return Output variable
 */
public SDVariable max(String name, SDVariable first, SDVariable second) {
    return updateVariableNameAndReference(f().max(first, second), name);
}

/**
 * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amax(SDVariable in, int... dimensions) {
    return amax(null, in, dimensions);
}

/**
 * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amax(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().amax(in, dimensions), name);
}

/**
 * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amin(SDVariable in, int... dimensions) {
    return amin(null, in, dimensions);
}

/**
 * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amin(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().amin(in, dimensions), name);
}

/**
 * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amean(SDVariable in, int... dimensions) {
    return amean(null, in, dimensions);
}

/**
 * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable amean(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().amean(in, dimensions), name);
}

/**
 * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable asum(SDVariable in, int... dimensions) {
    return asum(null, in, dimensions);
}

/**
 * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable asum(String name, SDVariable in, int... dimensions) {
    return updateVariableNameAndReference(f().asum(in, dimensions), name);
}

/**
 * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
 *
 * @param input      Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable countZero(SDVariable input, int... dimensions) {
    return countZero(null, input, dimensions);
}

/**
 * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
 *
 * @param name       Name of the output variable
 * @param input      Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable countZero(String name, SDVariable input, int... dimensions) {
    return updateVariableNameAndReference(f().countZero(input, dimensions), name);
}

/**
 * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
 *
 * @param input Input variable
 * @return Reduced array of rank 0 (scalar)
 */
public SDVariable zeroFraction(SDVariable input) {
    return zeroFraction(null, input);
}

/**
 * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
 *
 * @param name  Name of the output variable
 * @param input Input variable
 * @return Reduced array of rank 0 (scalar)
 */
public SDVariable zeroFraction(String name, SDVariable input) {
    return updateVariableNameAndReference(f().zeroFraction(input), name);
}

/**
 * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
 *
 * @param input      Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable countNonZero(SDVariable input, int... dimensions) {
    return countNonZero(null, input, dimensions);
}

/**
 * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
 *
 * @param name       Name of the output variable
 * @param input      Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable countNonZero(String name, SDVariable input, int... dimensions) {
    return updateVariableNameAndReference(f().countNonZero(input, dimensions), name);
}

/**
 * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
 *
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over.
If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable min(SDVariable x, int... dimensions) {
    return min(null, x, dimensions);
}

/**
 * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable min(String name, SDVariable x, int... dimensions) {
    return min(name, x, false, dimensions);
}

/**
 * Minimum array reduction operation, optionally along specified dimensions. out = min(in)<br>
 * Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1 (useful for later broadcast operations).<br>
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:<br>
 * keepDims = true: [a,1,c]<br>
 * keepDims = false: [a,c]
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable min(String name, SDVariable x, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(functionFactory.min(x, keepDims, dimensions), name);
}

/**
 * Element-wise minimum operation: out[i] = min(first[i], second[i])
* Supports broadcasting
 *
 * @param first  First input array
 * @param second Second input array
 * @return Output variable
 */
public SDVariable min(SDVariable first, SDVariable second) {
    return min(null, first, second);
}

/**
 * Element-wise minimum operation: out[i] = min(first[i], second[i])<br>
 * Supports broadcasting
 *
 * @param name   Name of the output variable
 * @param first  First input array
 * @param second Second input array
 * @return Output variable
 */
public SDVariable min(String name, SDVariable first, SDVariable second) {
    return updateVariableNameAndReference(f().min(first, second), name);
}

/**
 * Argmax array reduction operation, optionally along specified dimensions.
* Output values are the index of the maximum value of each slice along the specified dimension
 *
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable argmax(SDVariable in, int... dimensions) {
    return argmax(null, in, false, dimensions);
}

/**
 * @see #argmax(String, SDVariable, boolean, int...)
 */
public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) {
    return argmax(null, in, keepDims, dimensions);
}

/**
 * Argmax array reduction operation, optionally along specified dimensions.<br>
 * Output values are the index of the maximum value of each slice along the specified dimension
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable argmax(String name, SDVariable in, int... dimensions) {
    return argmax(name, in, false, dimensions);
}

/**
 * Argmax array reduction operation, optionally along specified dimensions.<br>
 * Output values are the index of the maximum value of each slice along the specified dimension.
 * <p>
 * Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1 (useful for later broadcast operations).<br>
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:<br>
 * keepDims = true: [a,1,c]<br>
 * keepDims = false: [a,c]
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param keepDims   If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or
 * of rank (input rank) if keepDims = true
 */
public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().argmax(in, keepDims, dimensions), name);
}

/**
 * Argmin array reduction operation, optionally along specified dimensions.
* Output values are the index of the minimum value of each slice along the specified dimension * * @param in Input variable * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable argmin(SDVariable in, int... dimensions) { return argmin(null, in, dimensions); } /** * @see #argmin(String, SDVariable, boolean, int...) */ public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { return argmin(null, in, keepDims, dimensions); } /** * Argmin array reduction operation, optionally along specified dimensions.
* Output values are the index of the minimum value of each slice along the specified dimension * * @param in Input variable * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable argmin(String name, SDVariable in, int... dimensions) { return argmin(name, in, false, dimensions); } /** * Argmin array reduction operation, optionally along specified dimensions.
* Output values are the index of the minimum value of each slice along the specified dimension.
*
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Name of the output variable * @param in Input variable * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Output variable: reduced array of rank (input rank - num dimensions) if keepDims = false, or * of rank (input rank) if keepdims = true */ public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) { SDVariable ret = f().argmin(in, keepDims, dimensions); return updateVariableNameAndReference(ret, name); } /** * Index of the max absolute value: argmax(abs(in)) * @see #argmax(SDVariable, int...) */ public SDVariable iamax(SDVariable in, int... dimensions) { return iamax(null, in, dimensions); } /** * Index of the max absolute value: argmax(abs(in)) * @see #argmax(String, SDVariable, boolean, int...) */ public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { return iamax(null, in, keepDims, dimensions); } /** * Index of the max absolute value: argmax(abs(in)) * @see #argmax(String, SDVariable, boolean, int...) */ public SDVariable iamax(String name, SDVariable in, int... dimensions) { return iamax(name, in, false, dimensions); } /** * Index of the max absolute value: argmax(abs(in)) * @see #argmax(String, SDVariable, boolean, int...) */ public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) { SDVariable ret = f().iamax(in, keepDims, dimensions); return updateVariableNameAndReference(ret, name); } /** * Index of the min absolute value: argmin(abs(in)) * @see #argmin(String, SDVariable, boolean, int...) */ public SDVariable iamin(SDVariable in, int... dimensions) { return iamin(null, in, dimensions); } /** * Index of the min absolute value: argmin(abs(in)) * @see #argmin(String, SDVariable, boolean, int...) 
*/
public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) {
    return iamin(null, in, keepDims, dimensions);
}

/**
 * Index of the min absolute value: argmin(abs(in))
 * @see #argmin(String, SDVariable, boolean, int...)
 */
public SDVariable iamin(String name, SDVariable in, int... dimensions) {
    return iamin(name, in, false, dimensions);
}

/**
 * Index of the min absolute value: argmin(abs(in))
 * @see #argmin(String, SDVariable, boolean, int...)
 */
public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().iamin(in, keepDims, dimensions), name);
}

/**
 * @see #firstIndex(String, SDVariable, Condition, int...)
 */
public SDVariable firstIndex(SDVariable in, Condition condition, int... dimensions) {
    return firstIndex(null, in, condition, dimensions);
}

/**
 * @see #firstIndex(String, SDVariable, Condition, boolean, int...)
 */
public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) {
    return firstIndex(null, in, condition, keepDims, dimensions);
}

/**
 * First index reduction operation.<br>
 * Returns a variable that contains the index of the first element that matches the specified
 * condition (for each slice along the specified dimensions)
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param condition  Condition to check on input variable
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable firstIndex(String name, SDVariable in, Condition condition, int... dimensions) {
    return firstIndex(name, in, condition, false, dimensions);
}

/**
 * First index reduction operation.<br>
 * Returns a variable that contains the index of the first element that matches the specified
 * condition (for each slice along the specified dimensions)<br>
 * Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1 (useful for later broadcast operations).<br>
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:<br>
 * keepDims = true: [a,1,c]<br>
 * keepDims = false: [a,c]
 *
 * @param name       Name of the output variable
 * @param in         Input variable
 * @param condition  Condition to check on input variable
 * @param keepDims   If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions
 * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 * @return Reduced array of rank (input rank - num dimensions)
 */
public SDVariable firstIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) {
    return updateVariableNameAndReference(f().firstIndex(in, condition, keepDims, dimensions), name);
}

/**
 * @see #lastIndex(String, SDVariable, Condition, int...)
 */
public SDVariable lastIndex(SDVariable in, Condition condition, int... dimensions) {
    return lastIndex(null, in, condition, dimensions);
}

/**
 * @see #lastIndex(String, SDVariable, Condition, boolean, int...)
 */
public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) {
    return lastIndex(null, in, condition, keepDims, dimensions);
}

/**
 * Last index reduction operation.
* Returns a variable that contains the index of the last element that matches the specified condition (for each * slice along the specified dimensions) * * @param name Name of the output variable * @param in Input variable * @param condition Condition to check on input variable * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable lastIndex(String name, SDVariable in, Condition condition, int... dimensions) { return lastIndex(name, in, condition, false, dimensions); } /** * Last index reduction operation.
* Returns a variable that contains the index of the last element that matches the specified condition (for each * slice along the specified dimensions)
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Name of the output variable * @param in Input variable * @param condition Condition to check on input variable * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Reduced array of rank (input rank - num dimensions) */ public SDVariable lastIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions){ SDVariable ret = f().lastIndex(in, condition, keepDims, dimensions); return updateVariableNameAndReference(ret, name); } /** * Returns a count of the number of elements that satisfy the condition * @param in Input * @param condition Condition * @return Number of elements that the condition is satisfied for */ public SDVariable matchConditionCount(SDVariable in, Condition condition) { return matchConditionCount(null, in, condition); } /** * Returns a count of the number of elements that satisfy the condition * @param name Name of the output variable * @param in Input * @param condition Condition * @return Number of elements that the condition is satisfied for */ public SDVariable matchConditionCount(String name, SDVariable in, Condition condition) { return matchConditionCount(name, in, condition, false); } /** * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Name of the output variable * @param in Input variable * @param condition Condition * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed * @return Number of elements that the condition is satisfied for */ public SDVariable matchConditionCount(String name, SDVariable in, Condition condition, boolean keepDim, int... dimensions) { SDVariable ret = f().matchConditionCount(in, condition, keepDim, dimensions); return updateVariableNameAndReference(ret, name); } /** * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise * * @param in Input variable * @param condition Condition * @return Boolean mask mariable */ public SDVariable matchCondition(SDVariable in, Condition condition) { return matchCondition(null, in, condition); } /** * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise * * @param in Input * @param condition Condition * @return Boolean mask */ public SDVariable matchCondition(String name, SDVariable in, Condition condition){ SDVariable ret = f().matchCondition(in, condition); return updateVariableNameAndReference(ret, name); } /** * @see #cumsum(String, SDVariable, SDVariable, boolean, boolean) */ public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... axis) { return cumsum(null, in, exclusive, reverse, axis); } /** * Cumulative sum operation.
* For input: [ a, b, c], output is:
* exclusize=false, reverse=false: [a, a+b, a+b+c]
* exclusive=true, reverse=false, [0, a, a+b]
* exclusive=false, reverse=true: [a+b+c, b+c, c]
* exclusive=true, reverse=true: [b+c, c, 0]

* * @param name Name of the output variable * @param in Input variable * @param axis Scalar axis argument for dimension to perform cumululative sum operations along * @param exclusive If true: exclude the first value * @param reverse If true: reverse the direction of the accumulation * @return Output variable */ public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) { SDVariable ret = f().cumsum(in, exclusive, reverse, axis); return updateVariableNameAndReference(ret, name); } /** * @see #cumprod(String, SDVariable, SDVariable, boolean, boolean) */ public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... axis) { return cumprod(null, in, exclusive, reverse, axis); } /** * Cumulative product operation.
* For input: [ a, b, c], output is:
* exclusize=false, reverse=false: [a, a*b, a*b*c]
* exclusive=true, reverse=false, [0, a, a*b]
* exclusive=false, reverse=true: [a*b*c, b*c, c]
* exclusive=true, reverse=true: [b*c, c, 0]

* * @param name Name of the output variable * @param in Input variable * @param axis Scalar axis argument for dimension to perform cumululative sum operations along * @param exclusive If true: exclude the first value * @param reverse If true: reverse the direction of the accumulation * @return Output variable */ public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, int... axis) { SDVariable ret = f().cumprod(in, exclusive, reverse, axis); return updateVariableNameAndReference(ret, name); } /** * @see #biasAdd(String, SDVariable, SDVariable) */ public SDVariable biasAdd(SDVariable input, SDVariable bias) { return biasAdd(null, input, bias); } /** * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector * @param name Name of the output variable * @param input 4d input variable * @param bias 1d bias * @return Output variable */ public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) { SDVariable ret = f().biasAdd(input, bias); return updateVariableNameAndReference(ret, name); } /** * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the * input, but with the specified shape.
* Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param x     Input variable
 * @param shape New shape for variable
 * @return Output variable
 * @see #reshape(SDVariable, SDVariable)
 */
public SDVariable reshape(SDVariable x, long... shape) {
    return reshape(null, x, shape);
}

/**
 * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
 * input, but with the specified shape.<br>
 * Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param name  Output variable name
 * @param x     Input variable
 * @param shape New shape for variable
 * @return Output variable
 * @see #reshape(SDVariable, SDVariable)
 */
public SDVariable reshape(String name, SDVariable x, long... shape) {
    return updateVariableNameAndReference(functionFactory.reshape(x, shape), name);
}

/**
 * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
 * input, but with the specified shape.<br>
 * Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param x     Input variable
 * @param shape New shape for variable
 * @return Output variable
 * @see #reshape(SDVariable, SDVariable)
 */
public SDVariable reshape(SDVariable x, int... shape) {
    return reshape(null, x, shape);
}

/**
 * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
 * input, but with the specified shape.<br>
 * Note that prod(shape) must match length(input) == prod(input.shape)
 *
 * @param name  Output variable name
 * @param x     Input variable
 * @param shape New shape for variable
 * @return Output variable
 * @see #reshape(SDVariable, SDVariable)
 */
public SDVariable reshape(String name, SDVariable x, int... shape) {
    return updateVariableNameAndReference(functionFactory.reshape(x, shape), name);
}

/**
 * Reshape the input variable to the specified (dynamic) shape. The output variable will have the same values as the
 * input, but with the specified shape.
* Note that prod(shape) must match length(input) == prod(input.shape) * * @param x Input variable * @param shape New shape for variable * @return Output variable * @see #reshape(SDVariable, int[]) */ public SDVariable reshape(SDVariable x, SDVariable shape) { return reshape(null, x, shape); } /** * Reshape the input variable to the specified (dynamic) shape. The output variable will have the same values as the * input, but with the specified shape.
* Note that prod(shape) must match length(input) == prod(input.shape) * * @param name Output variable name * @param x Input variable * @param shape New shape for variable * @return Output variable * @see #reshape(SDVariable, int[]) */ public SDVariable reshape(String name, SDVariable x,SDVariable shape) { SDVariable result = functionFactory.reshape(x, shape); return updateVariableNameAndReference(result, name); } /** * @see #reverse(String, SDVariable, int...) */ public SDVariable reverse(SDVariable x, int... dimensions) { return reverse(null, x, dimensions); } /** * Reverse the values of an array for the specified dimensions
* If input is:
* [ 1, 2, 3]
* [ 4, 5, 6]
* then
* reverse(in, 0):
* [3, 2, 1]
* [6, 5, 4]
* reverse(in, 0):
* [4, 5, 6]
* [1, 2 3]
* * @param x Input variable * @param dimensions Dimensions * @return Output variable */ public SDVariable reverse(String name, SDVariable x, int... dimensions) { SDVariable ret = f().reverse(x, dimensions); return updateVariableNameAndReference(ret, name); } /** * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed * * @param name Name of the output variable * @param x Input variable * @param seq_lengths Length of the sequences * @param seqDim Sequence dimension * @param batchDim Batch dimension * @return Reversed sequences */ public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) { SDVariable ret = f().reverseSequence(x, seq_lengths, seqDim, batchDim); return updateVariableNameAndReference(ret, name); } /** * @see #reverseSequence(String, SDVariable, SDVariable, int, int) */ public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths) { SDVariable ret = f().reverseSequence(x, seq_lengths); return updateVariableNameAndReference(ret, name); } /** * @see #reverseSequence(String, SDVariable, SDVariable, int, int) */ public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim) { return reverseSequence(null, x, seq_lengths, seqDim, batchDim); } /** * @see #reverseSequence(String, SDVariable, SDVariable, int, int) */ public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) { return reverseSequence(null, x, seq_lengths); } /** * Generate a sequence mask (with values 0 or 1) based on the specified lengths
* Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0) * * @param name Name of the output variable * @param lengths Lengths of the sequences * @param maxLen Maximum sequence length * @return Output variable */ public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen) { SDVariable ret = f().sequenceMask(lengths, maxLen); return updateVariableNameAndReference(ret, name); } /** * @see #sequenceMask(String, SDVariable, SDVariable) */ public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen) { return sequenceMask(null, lengths, maxLen); } /** * @see #sequenceMask(String, SDVariable, SDVariable) */ public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen) { SDVariable ret = f().sequenceMask(lengths, maxLen); return updateVariableNameAndReference(ret, name); } /** * @see #sequenceMask(String, SDVariable, SDVariable) */ public SDVariable sequenceMask(SDVariable lengths, int maxLen) { return sequenceMask(null, lengths, maxLen); } /** * @see #sequenceMask(String, SDVariable, SDVariable) */ public SDVariable sequenceMask(String name, SDVariable lengths) { SDVariable ret = f().sequenceMask(lengths); return updateVariableNameAndReference(ret, name); } /** * @see #sequenceMask(String, SDVariable, SDVariable) */ public SDVariable sequenceMask(SDVariable lengths) { SDVariable ret = f().sequenceMask(lengths); return updateVariableNameAndReference(ret, null); } /** * @see #expandDims(String, SDVariable, int) */ public SDVariable expandDims(SDVariable x, int axis) { return expandDims(null, x, axis); } /** * Reshape the input by adding a 1 at the specified location.
* For example, if input has shape [a, b], then output shape is:
* axis = 0: [1, a, b]
* axis = 1: [a, 1, b]
* axis = 2: [a, b, 1]
* * @param name Name of the output variable * @param x Input variable * @param axis Axis to expand * @return Output variable * @see #squeeze(String, SDVariable, int) */ public SDVariable expandDims(String name, SDVariable x, int axis) { SDVariable result = f().expandDims(x, axis); return updateVariableNameAndReference(result, name); } /** * @see #squeeze(String, SDVariable, int) */ public SDVariable squeeze(SDVariable x, int axis) { return squeeze(null, x, axis); } /** * Remove a single dimension of size 1. * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c] * * @param name Name of the output variable * @param x Input variable * @param axis Size 1 dimension to remove * @return Output variable */ public SDVariable squeeze(String name, SDVariable x, int axis) { SDVariable result = f().squeeze(x, axis); return updateVariableNameAndReference(result, name); } /** * Assign/copy op: out = x.assign(y). Supports broadcasting * * @param x Input variable x * @param y Input variable y * @return Output variable */ public SDVariable assign(SDVariable x, SDVariable y) { return assign(null, x, y); } /** * Assign/copy op: out = x.assign(y). 
Supports broadcasting * * @param name Name of the output variable * @param x Input variable x * @param y Input variable y * @return Output variable */ public SDVariable assign(String name, SDVariable x, SDVariable y) { SDVariable ret = f().assign(x, y); return updateVariableNameAndReference(ret, name); } /** * Return an array with equal shape to the input, but all elements set to 'value' * * @param in Input variable * @param value Value to set * @return Output variable */ public SDVariable assign(SDVariable in, Number value) { return assign(null, in, value); } /** * Return an array with equal shape to the input, but all elements set to 'value' * * @param name Name of the output variable * @param in Input variable * @param value Value to set * @return Output variable */ public SDVariable assign(String name, SDVariable in, Number value) { SDVariable ret = f().assign(in, value); return updateVariableNameAndReference(ret, name); } /** * Matrix transpose operation: If input has shape [a,b] output has shape [b,a] * * @param x Input variable * @return Output variable (transposed input) */ public SDVariable transpose(SDVariable x) { return transpose(null, x); } /** * Matrix transpose operation: If input has shape [a,b] output has shape [b,a] * * @param name Output variable name * @param x Input variable * @return Output variable (transposed input) */ public SDVariable transpose(String name, SDVariable x) { SDVariable result = functionFactory.transpose(x); return updateVariableNameAndReference(result, name); } /** * Array permutation operation: permute the dimensions according to the specified permutation indices.
* Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b] * * @param x Input variable * @return Output variable (permuted input) */ public SDVariable permute(SDVariable x, int... dimensions) { return permute(null, x, dimensions); } /** * Array permutation operation: permute the dimensions according to the specified permutation indices.
* Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b] * * @param name Output variable name * @param x Input variable * @return Output variable (permuted input) */ public SDVariable permute(String name, SDVariable x, int... dimensions) { SDVariable result = functionFactory.permute(x, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param axis * @return */ public SDVariable rollAxis(SDVariable x, int axis) { return rollAxis(null, x, axis); } /** * @param x * @param axis * @return */ public SDVariable rollAxis(String name, SDVariable x, int axis) { SDVariable result = functionFactory.rollAxis(x, axis); return updateVariableNameAndReference(result, name); } /** * @see #concat(String, int, SDVariable...) */ public SDVariable concat(int dimension, SDVariable... inputs) { return concat(null, dimension, inputs); } /** * Concatenate a set of inputs along the specified dimension.
* Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
* For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c] * * @param name Name of the output variable * @param dimension Dimension to concatenate on * @param inputs Input variables * @return Output variable * @see #stack(String, int, SDVariable...) */ public SDVariable concat(String name, int dimension, SDVariable... inputs) { SDVariable result = functionFactory.concat(dimension, inputs); return updateVariableNameAndReference(result, name); } /** * @see #moments(String[], SDVariable, int...) */ public SDVariable[] moments(SDVariable input, int... axes) { return moments(null, input, axes); } /** * Calculate the mean and (population) variance for the input variable, for the specified axis * * @param name Name of the output variables. Can be null; if non-null, must be length 2 * @param input Input to calculate moments for * @param axes Dimensions to perform calculation over * @return Mean and variance variables */ public SDVariable[] moments(String[] name, SDVariable input, int... axes) { SDVariable[] res = f().moments(input, axes); return updateVariableNamesAndReferences(res, name); } /** * @see #normalizeMoments(String[], SDVariable, SDVariable, SDVariable, double) */ public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) { return normalizeMoments(null, counts, means, variances, shift); } /** * Calculate the mean and variance from the sufficient statistics * * @param name Name of the output variables. 
Can be null; if non-null, must be length 2 * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics * @param means Mean-value sufficient statistics: this is the SUM of all data values * @param variances Variaance sufficient statistics: this is the squared sum of all data values * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) * @return Output variables: mean and population variance */ public SDVariable[] normalizeMoments(String[] name, SDVariable counts, SDVariable means, SDVariable variances, double shift) { SDVariable[] res = f().normalizeMoments(counts, means, variances, shift); return updateVariableNamesAndReferences(res, name); } /** * @see #matrixDeterminant(String, SDVariable) */ public SDVariable matrixDeterminant(SDVariable in){ return matrixDeterminant(null, in); } /** * Matrix determinant op. For 2D input, this returns the standard matrix determinant. * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each * shape [m,m] sub-matrix. * @param name Name of the output variable * @param in Input * @return Matrix determinant variable */ public SDVariable matrixDeterminant(String name, SDVariable in){ SDVariable ret = f().matrixDeterminant(in); return updateVariableNameAndReference(ret, name); } /** * @see #matrixInverse(String, SDVariable) */ public SDVariable matrixInverse(SDVariable in){ return matrixInverse(null, in); } /** * Matrix inverse op. For 2D input, this returns the standard matrix inverse. * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each * shape [m,m] sub-matrix. 
* @param name Name of the output variable * @param in Input * @return Matrix inverse variable */ public SDVariable matrixInverse(String name, SDVariable in){ SDVariable ret = f().matrixInverse(in); return updateVariableNameAndReference(ret, name); } /** * @see #confusionMatrix(String, SDVariable, SDVariable) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable predictions) { return confusionMatrix((String) null, labels, predictions); } /** * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
* For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
* [1, 0, 0]
* [0, 1, 1]
* [0, 0, 0]
* * @param name Name of the output variable * @param labels Labels - 1D array of integer values representing label values * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels * @return Output variable (2D, shape [numClasses, numClasses}) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred) { SDVariable result = f().confusionMatrix(labels, pred); return updateVariableNameAndReference(result, name); } /** * @see #confusionMatrix(String, SDVariable, SDVariable, Integer) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses) { return confusionMatrix(null, labels, pred, numClasses); } /** * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of * which are represented as integer values.
* For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
* [1, 0, 0, 0]
* [0, 1, 1, 0]
* [0, 0, 0, 0]
* [0, 0, 0, 0]
* * @param name Name of the output variable * @param labels Labels - 1D array of integer values representing label values * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels * @param numClasses Number of classes * @return Output variable (2D, shape [numClasses, numClasses}) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses) { SDVariable result = f().confusionMatrix(labels, pred, numClasses); return updateVariableNameAndReference(result, name); } /** * @see #confusionMatrix(String, SDVariable, SDVariable, SDVariable) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { return confusionMatrix(null, labels, pred, weights); } /** * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
* For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3] * [1, 0, 0]
* [0, 3, 2]
* [0, 0, 0]
* * @param name Name of the output variable * @param labels Labels - 1D array of integer values representing label values * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of * each prediction. Must be same length as both labels and predictions arrays * @return Output variable (2D, shape [numClasses, numClasses}) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, SDVariable weights) { SDVariable result = f().confusionMatrix(labels, pred, weights); return updateVariableNameAndReference(result, name); } /** * @see #confusionMatrix(String, SDVariable, SDVariable, Integer, SDVariable) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { return confusionMatrix(null, labels, pred, numClasses, weights); } /** * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of * which are represented as integer values.
* For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3] * [1, 0, 0, 0]
* [0, 3, 2, 0]
* [0, 0, 0, 0]
* [0, 0, 0, 0]
* * @param name Name of the output variable * @param labels Labels - 1D array of integer values representing label values * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of * each prediction. Must be same length as both labels and predictions arrays * @return Output variable (2D, shape [numClasses, numClasses}) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { SDVariable result = f().confusionMatrix(labels, pred, numClasses, weights); return updateVariableNameAndReference(result, name); } /** * @see #tile(String, SDVariable, int[]) */ public SDVariable tile(SDVariable x, int[] repeat) { return tile(null, x, repeat); } /** * Repeat (tile) the input tensor the specified number of times.
* For example, if input is
* [1, 2]
* [3, 4]
* and repeat is [2, 3]
* then output is
* [1, 2, 1, 2, 1, 2]
* [3, 4, 3, 4, 3, 4]
* [1, 2, 1, 2, 1, 2]
* [3, 4, 3, 4, 3, 4]
*
* * @param name Output variable name * @param x Input variable * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array * @return Output variable */ public SDVariable tile(String name, SDVariable x, int[] repeat) { SDVariable result = functionFactory.tile(x, repeat); return updateVariableNameAndReference(result, name); } /** * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value * * @param shape Shape: must be a 1D array/variable * @param value Value to set all elements to * @return Output variable */ public SDVariable fill(SDVariable shape, double value) { return fill(null, shape, value); } /** * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value * * @param name Name of the output variable * @param shape Shape: must be a 1D array/variable * @param value Value to set all elements to * @return Output variable */ public SDVariable fill(String name, SDVariable shape, double value) { SDVariable result = functionFactory.fill(shape, value); return updateVariableNameAndReference(result, name); } /** * * @param input Input * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) * @return */ public SDVariable dropout(SDVariable input, double inputRetainProbability) { return dropout(null, input, inputRetainProbability); } /** * * @param input Input * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) * @return */ public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) { SDVariable res = f().dropout(input, inputRetainProbability); return updateVariableNameAndReference(res, name); } /** * @see #linear(String, SDVariable, SDVariable, SDVariable) */ public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) { return linear(null, input, weights, bias); } /** * Linear layer 
operation: out = mmul(in,w) + bias
* Note that bias array is optional * * @param name Name of the output variable * @param input Input data * @param weights Weights variable * @param bias Optional bias variable (may be null) * @return Output variable */ public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) { SDVariable res = f().xwPlusB(input, weights, bias); return updateVariableNameAndReference(res, name); } /** * @see #reluLayer(String, SDVariable, SDVariable, SDVariable) */ public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { return reluLayer(null, input, weights, bias); } /** * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
* Note that bias array is optional * * @param name Name of the output variable * @param input Input data * @param weights Weights variable * @param bias Optional bias variable (may be null) * @return Output variable */ public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) { SDVariable res = f().reluLayer(input, weights, bias); return updateVariableNameAndReference(res, name); } /** * Matrix multiplication: out = mmul(x,y)
* Supports specifying a {@link MMulTranspose} argument to perform operation such as mmul(a^T, b), etc. * * @param x First input variable * @param y Second input variable * @param transpose Transpose arguments * @return Output variable */ public SDVariable mmul(SDVariable x, SDVariable y, MMulTranspose transpose) { return mmul(null, x, y, transpose); } /** * Matrix multiplication: out = mmul(x,y) * * @param x First input variable * @param y Second input variable * @return Output variable */ public SDVariable mmul(SDVariable x, SDVariable y) { return mmul(null, x, y); } /** * Matrix multiplication: out = mmul(x,y)
* Supports specifying a {@link MMulTranspose} argument to perform operation such as mmul(a^T, b), etc. * * @param name Output variable name * @param x First input variable * @param y Second input variable * @param transpose Transpose arguments * @return Output variable */ public SDVariable mmul(String name, SDVariable x, SDVariable y, MMulTranspose transpose) { SDVariable result = functionFactory.mmul(x, y, transpose); return updateVariableNameAndReference(result, name); } /** * Matrix multiplication: out = mmul(x,y) * * @param name Output variable name * @param x First input variable * @param y Second input variable * @return Output variable */ public SDVariable mmul(String name, SDVariable x, SDVariable y) { return mmul(name, x, y, MMulTranspose.allFalse()); } /** * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead. * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N). * * * The result of this operation will be a batch of multiplied matrices. The * result has the same length as both input batches and each output matrix is of shape (M, K). * * @param matricesA First array of input matrices, all of shape (M, N) or (N, M) * @param matricesB Second array of input matrices, all of shape (N, K) or (K, N) * @param transposeA whether first batch of matrices is transposed. * @param transposeB whether second batch of matrices is transposed. 
* @param names names for all provided SDVariables * * @return Array of multiplied SDVariables of shape (M, K) */ public SDVariable[] batchMmul(String[] names, SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) { SDVariable[] result = functionFactory.batchMmul(matricesA, matricesB, transposeA, transposeB); return updateVariableNamesAndReferences(result, names); } /** * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead. * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N). * * * The result of this operation will be a batch of multiplied matrices. The * result has the same length as both input batches and each output matrix is of shape (M, K). * * @param matricesA First array of input matrices, all of shape (M, N) or (N, M) * @param matricesB Second array of input matrices, all of shape (N, K) or (K, N) * @param transposeA whether first batch of matrices is transposed. * @param transposeB whether second batch of matrices is transposed. * * @return Array of multiplied SDVariables of shape (M, K) */ public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) { return batchMmul(null, matricesA, matricesB, transposeA, transposeB); } /** * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same * length and each pair taken from these sets has to have dimensions (M, N) and (N, K), * respectively. The result of this operation will be a batch of multiplied matrices. The * result has the same length as both input batches and each output matrix is of shape (M, K). 
* * @param matricesA First array of input matrices, all of shape (M, N) * @param matricesB Second array of input matrices, all of shape (N, K) * @return Array of multiplied SDVariables of shape (M, K) */ public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB) { return batchMmul(null, matricesA, matricesB, false, false); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable tensorMmul(SDVariable x, SDVariable y, int[][] dimensions) { return tensorMmul(null, x, y, dimensions); } /** * TODO doc string * * @param x * @param y * @param dimensions * @return */ public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) { return dot(null, x, y, dimensions); } /** * TODO doc string * * @param name * @param x * @param y * @param dimensions * @return */ public SDVariable dot(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable ret = f().dot(x, y, dimensions); return updateVariableNameAndReference(ret, name); } /** * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
* out = sum_i abs(x[i]) * * @param name Output variable name * @param x Input variable * @param dimensions dimensions to reduce over * @return Output variable */ public SDVariable norm1(String name, SDVariable x, int... dimensions) { return norm1(name, x, false, dimensions); } /** * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
* out = sum_i abs(x[i])
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c] * * @param name Output variable name * @param x Input variable * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions * @param dimensions dimensions to reduce over * @return Output variable */ public SDVariable norm1(String name, SDVariable x, boolean keepDims, int... dimensions) { SDVariable result = f().norm1(x, keepDims, dimensions); return updateVariableNameAndReference(result, name); } /** * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
* out = sqrt(sum_i x[i]^2) * * @param name Output variable name * @param x Input variable * @param dimensions dimensions to reduce over * @return Output variable */ public SDVariable norm2(String name, SDVariable x, int... dimensions) { return norm2(name, x, false, dimensions); } /** * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
* out = sqrt(sum_i x[i]^2)
* Note that if keepDims = true, the output variable has the same rank as the input variable, * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting * the mean along a dimension).
* Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: * keepDims = true: [a,1,c]
* keepDims = false: [a,c]<br>
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable norm2(String name, SDVariable x, boolean keepDims, int... dimensions) {
    SDVariable result = f().norm2(x, keepDims, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Squared L2 norm: see {@link #norm2(String, SDVariable, int...)}
 */
public SDVariable squaredNorm(SDVariable x, int... dimensions) {
    return squaredNorm(null, x, false, dimensions);
}

/**
 * Squared L2 norm: see {@link #norm2(String, SDVariable, int...)}
 */
public SDVariable squaredNorm(String name, SDVariable x, int... dimensions) {
    return squaredNorm(name, x, false, dimensions);
}

/**
 * Squared L2 norm: see {@link #norm2(String, SDVariable, boolean, int...)}
 */
public SDVariable squaredNorm(SDVariable x, boolean keepDims, int... dimensions) {
    return squaredNorm(null, x, keepDims, dimensions);
}

/**
 * Squared L2 norm: see {@link #norm2(String, SDVariable, boolean, int...)}
 */
public SDVariable squaredNorm(String name, SDVariable x, boolean keepDims, int... dimensions) {
    SDVariable result = f().squaredNorm(x, keepDims, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
 * specified dimensions
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable normmax(String name, SDVariable x, int... dimensions) {
    return normmax(name, x, false, dimensions);
}

/**
 * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
 * specified dimensions:<br>
 * out = max(abs(x[i]))<br>
 * Note that if keepDims = true, the output variable has the same rank as the input variable,
 * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
 * the mean along a dimension).<br>
 * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
 * keepDims = true: [a,1,c]<br>
 * keepDims = false: [a,c]<br>
 *
 * @param name       Output variable name
 * @param x          Input variable
 * @param keepDims   If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions
 * @param dimensions dimensions to reduce over
 * @return Output variable
 */
public SDVariable normmax(String name, SDVariable x, boolean keepDims, int... dimensions) {
    SDVariable result = f().normmax(x, keepDims, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * @see #cosineSimilarity(String, SDVariable, SDVariable, int...)
 */
public SDVariable cosineSimilarity(SDVariable x, SDVariable y, int... dimensions) {
    return cosineSimilarity(generateNewVarName(CosineSimilarity.OP_NAME, 0), x, y, dimensions);
}

/**
 * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
 * along the specified dimensions:<br>
 * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2) )
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate cosine similarity over
 * @return Output variable
 */
public SDVariable cosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable cosim = functionFactory.cosineSimilarity(x, y, dimensions);
    return updateVariableNameAndReference(cosim, name);
}

/**
 * @see #euclideanDistance(String, SDVariable, SDVariable, int...)
 */
public SDVariable euclideanDistance(SDVariable x, SDVariable y, int... dimensions) {
    return euclideanDistance(generateNewVarName(EuclideanDistance.OP_NAME, 0), x, y, dimensions);
}

/**
 * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
 * tensor/subset along the specified dimensions:<br>
 * out = sqrt( sum_i (x[i] - y[i])^2 )
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate the distance over
 * @return Output variable
 */
public SDVariable euclideanDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.euclideanDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * @see #manhattanDistance(String, SDVariable, SDVariable, int...)
 */
public SDVariable manhattanDistance(SDVariable x, SDVariable y, int... dimensions) {
    return manhattanDistance(generateNewVarName(ManhattanDistance.OP_NAME, 0), x, y, dimensions);
}

/**
 * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
 * tensor/subset along the specified dimensions:<br>
 * out = sum_i abs(x[i]-y[i])
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate the distance over
 * @return Output variable
 */
public SDVariable manhattanDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.manhattanDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * @see #cosineDistance(String, SDVariable, SDVariable, int...)
 */
public SDVariable cosineDistance(SDVariable x, SDVariable y, int... dimensions) {
    return cosineDistance(null, x, y, dimensions);
}

/**
 * Cosine distance reduction operation. The output contains the cosine distance for each
 * tensor/subset along the specified dimensions:<br>
 * out = 1.0 - cosineSimilarity(x,y)<br>
 * See {@link #cosineSimilarity(String, SDVariable, SDVariable, int...)}
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate cosine distance over
 * @return Output variable
 */
public SDVariable cosineDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.cosineDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * @see #hammingDistance(String, SDVariable, SDVariable, int...)
 */
public SDVariable hammingDistance(SDVariable x, SDVariable y, int... dimensions) {
    return hammingDistance(null, x, y, dimensions);
}

/**
 * Hamming distance reduction operation. The output contains the Hamming distance for each
 * tensor/subset along the specified dimensions:<br>
 * out = count( x[i] != y[i] )
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate Hamming distance over
 * @return Output variable
 */
public SDVariable hammingDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.hammingDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Jaccard similarity reduction operation. The output contains the Jaccard distance for each
 * tensor along the specified dimensions.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate Jaccard similarity over
 * @return Output variable
 */
public SDVariable jaccardDistance(SDVariable x, SDVariable y, int... dimensions) {
    return jaccardDistance(null, x, y, dimensions);
}

/**
 * Jaccard similarity reduction operation. The output contains the Jaccard distance for each
 * tensor along the specified dimensions.
 *
 * @param name       Name of the output variable
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Dimensions to calculate Jaccard similarity over
 * @return Output variable
 */
public SDVariable jaccardDistance(String name, SDVariable x, SDVariable y, int... dimensions) {
    SDVariable result = functionFactory.jaccardDistance(x, y, dimensions);
    return updateVariableNameAndReference(result, name);
}

/**
 * Binary cross entropy loss.
 *
 * @param x          Input variable x
 * @param y          Input variable y
 * @param dimensions Reduction dimensions
 * @return Output variable
 */
public SDVariable lossBinaryXENT(SDVariable x, SDVariable y, int...
dimensions) { return lossBinaryXENT(generateNewVarName(new LossBinaryXENT().opName(), 0), x, y, dimensions); } /** * TODO doc string * * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossCosineSimilarity(SDVariable x, SDVariable y, int... dimensions) { return lossCosineSimilarity(generateNewVarName(new LossCosineProximity().opName(), 0), x, y, dimensions); } // TODO: document all losses /** * Hinge loss * * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossHinge(SDVariable x, SDVariable y, int... dimensions) { return lossHinge(generateNewVarName(new LossHinge().opName(), 0), x, y, dimensions); } /** * Kullback-Leibler divergence loss * * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossKLD(SDVariable x, SDVariable y, int... dimensions) { return lossKLD(generateNewVarName(new LossKLD().opName(), 0), x, y, dimensions); } /** * L1 loss * * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossL1(SDVariable x, SDVariable y, int... dimensions) { return lossL1(generateNewVarName(new LossL1().opName(), 0), x, y, dimensions); } /** * L2 loss * * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossL2(SDVariable x, SDVariable y, int... dimensions) { return lossL2(generateNewVarName(new LossL2().opName(), 0), x, y, dimensions); } /** * Mean absolute error loss * * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossMAE(SDVariable x, SDVariable y, int... 
dimensions) { return lossMAE(generateNewVarName(new LossMAE().opName(), 0), x, y, dimensions); } /** * Mean squared error loss * * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossMSE(SDVariable x, SDVariable y, int... dimensions) { return lossMSE(generateNewVarName(new LossMSE().opName(), 0), x, y, dimensions); } /** * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossMCXENT(SDVariable x, SDVariable y, int... dimensions) { return lossMCXENT(generateNewVarName(new LossMCXENT().opName(), 0), x, y, dimensions); } /** * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossMSLE(SDVariable x, SDVariable y, int... dimensions) { return lossMSLE(generateNewVarName(new LossMSLE().opName(), 0), x, y, dimensions); } /** * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossNegativeLogLikelihood(SDVariable x, SDVariable y, int... dimensions) { return lossNegativeLogLikelihood(generateNewVarName(new LossNegativeLogLikelihood().opName(), 0), x, y, dimensions); } /** * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossPoisson(SDVariable x, SDVariable y, int... dimensions) { return lossPoisson(generateNewVarName(new LossPoisson().opName(), 0), x, y, dimensions); } /** * @param x Input variable x * @param y Input variable y * @param dimensions Reduction dimensions * @return Output variable */ public SDVariable lossSquaredHinge(SDVariable x, SDVariable y, int... 
dimensions) { return lossSquaredHinge(generateNewVarName(new LossSquaredHinge().opName(), 0), x, y, dimensions); } /** * @param x * @return */ public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt) { return softmaxDerivative(name, x, wrt, null); } public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, Integer dimension) { SDVariable result = functionFactory.softmaxDerivative(x, wrt, dimension); return updateVariableNameAndReference(result, name); } /** * @param x Input variable x * @param y Input variable y * @param dimensions dimensions * @return Output variable */ public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[][] dimensions) { SDVariable result = functionFactory.tensorMmul(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * TODO * * @param logits * @param weights * @param labels * @param reductionMode * @param labelSmoothing * @return */ public SDVariable sigmoidCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels, int reductionMode, double labelSmoothing) { return sigmoidCrossEntropyWithLogits(null, logits, weights, labels, reductionMode, labelSmoothing); } /** * TODO * * @param name * @param logits * @param weights * @param labels * @param reductionMode * @param labelSmoothing * @return */ public SDVariable sigmoidCrossEntropyWithLogits(String name, SDVariable logits, SDVariable weights, SDVariable labels, int reductionMode, double labelSmoothing) { SDVariable res = f().sigmoidCrossEntropyWithLogits(logits, weights, labels, reductionMode, labelSmoothing); return updateVariableNameAndReference(res, name); } /** * TODO * * @param logits * @param weights * @param labels * @param reductionMode * @param labelSmoothing * @return */ public SDVariable softmaxCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels, int reductionMode, double labelSmoothing) { return softmaxCrossEntropyWithLogits(null, logits, weights, 
labels, reductionMode, labelSmoothing); } /** * TODO * * @param name * @param logits * @param weights * @param labels * @param reductionMode * @param labelSmoothing * @return */ public SDVariable softmaxCrossEntropyWithLogits(String name, SDVariable logits, SDVariable weights, SDVariable labels, int reductionMode, double labelSmoothing) { SDVariable res = f().softmaxCrossEntropyWithLogits(logits, weights, labels, reductionMode, labelSmoothing); return updateVariableNameAndReference(res, name); } /** * TODO * * @param targets * @param inputs * @param weights * @return */ public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, SDVariable weights) { return weightedCrossEntropyWithLogits(null, targets, inputs, weights); } /** * TODO * * @param name * @param targets * @param inputs * @param weights * @return */ public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, SDVariable inputs, SDVariable weights) { SDVariable res = f().weightedCrossEntropyWithLogits(targets, inputs, weights); return updateVariableNameAndReference(res, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossBinaryXENT(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossBinaryXENT(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossCosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossCosineSimilarity(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossHinge(String name, SDVariable x, SDVariable y, int... 
dimensions) { SDVariable result = functionFactory.lossHinge(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossKLD(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossKLD(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossL1(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossL1(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossL2(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossL2(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossMAE(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossMAE(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossMSE(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossMSE(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossMCXENT(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossMCXENT(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossMSLE(String name, SDVariable x, SDVariable y, int... 
dimensions) { SDVariable result = functionFactory.lossMSLE(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossNegativeLogLikelihood(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossNegativeLogLikelihood(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossPoisson(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossPoisson(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * @param x * @param y * @param dimensions * @return */ public SDVariable lossSquaredHinge(String name, SDVariable x, SDVariable y, int... dimensions) { SDVariable result = functionFactory.lossSquaredHinge(x, y, dimensions); return updateVariableNameAndReference(result, name); } /** * Add the specified variable to this SameDiff instance * @param variable Variable to add */ public void addVariable(SDVariable variable) { if (variableMap == null) variableMap = new HashMap<>(); Preconditions.checkState(variable.getSameDiff() == this, "Samediff instance must be the same."); /** * Of note here: * We don't validate based on vertex id because more than one input can have the same * vertex id as a result. 
* * We validate based on variable opName instead which takes in to account function names as well * as input ids */ if (variableMap.containsKey(variable.getVarName()) && !variableMap.get(variable.getVarName()).equals(variable)) { throw new IllegalArgumentException("Variable already found with variable opName " + variable.getVarName()); } Preconditions.checkState(variable.getSameDiff() == this, "Same diff instance for variable must be the same!"); variableMap.put(variable.getVarName(), variable); } /** * Generate a new variable name based on the uniqueness of the base name and arg index
* For example, if baseName = "X" will return:
* "X" if "X" does not already exist, or "X:argIndex" if argIndex > 0
* "X_1" if "X" already exists, or "X_1:argIndex" if argIndex > 0
* "X_2" if "X" and "X_1" already exists, or "X_2:argIndex" if argIndex > 0
* And so on, until an unused name is found * * @param baseName the base name to use (use function.opName() where function is a {@link DifferentialFunction} * @param argIndex the arg index * @return the new generated name */ public String generateNewVarName(String baseName, int argIndex) { if (getVariable(baseName) == null && argIndex == 0) { return baseName; } //need to find a new name int count = 0; String name = baseName + (count == 0 ? "" : "_" + count) + (argIndex > 0 ? ":" + argIndex : ""); while (getVariable(name) != null) { name = baseName + "_" + (++count) + (argIndex > 0 ? ":" + argIndex : ""); } if (getVariable(name) != null) { throw new ND4JIllegalStateException("Converged on already generated variable!"); } return name; } /** * LSTM unit * * @param baseName the base name for outputs * @param configuration the configuration to use * @return */ public SDVariable lstm(String baseName, LSTMCellConfiguration configuration) { return new LSTMCell(this, configuration).outputVariables(baseName)[0]; } /** * An sru cell * * @param configuration the configuration for the sru cell * @return */ public SDVariable sruCell(SRUCellConfiguration configuration) { return new SRUCell(this, configuration).outputVariables()[0]; } /** * Simple recurrent unit * * @param configuration the configuration for the sru * @return */ public SDVariable sru(SRUConfiguration configuration) { return new SRU(this, configuration).outputVariables()[0]; } /** * The gru cell * * @param configuration teh configuration to use * @return */ public SDVariable gru(GRUCellConfiguration configuration) { return new GRUCell(this, configuration).outputVariables()[0]; } /** * An sru cell * * @param baseName the base name to use for the output variables * @param configuration the configuration for the sru cell * @return */ public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) { return new SRUCell(this, configuration).outputVariables(baseName)[0]; } /** * Simiple recurrent unit * * 
@param baseName      the base name to use for output variables
 * @param configuration the configuration for the SRU
 * @return The first output variable of the SRU
 */
public SDVariable sru(String baseName, SRUConfiguration configuration) {
    return new SRU(this, configuration).outputVariables(baseName)[0];
}

/**
 * The GRU cell
 *
 * @param baseName      the base name for the GRU cell
 * @param configuration the configuration to use
 * @return The first output variable of the GRU cell
 */
public SDVariable gru(String baseName, GRUCellConfiguration configuration) {
    return new GRUCell(this, configuration).outputVariables(baseName)[0];
}

/**
 * @see #slice(String, SDVariable, int[], int[])
 */
public SDVariable slice(SDVariable input, int[] begin, int[] size) {
    return slice(null, input, begin, size);
}

/**
 * Get a subset of the specified input, by specifying the first element and the size of the array.<br>
 * For example, if input is:<br>
 * [a, b, c]<br>
 * [d, e, f]<br>
 * then slice(input, begin=[0,1], size=[2,1]) will return:<br>
 * [b]<br>
 * [e]<br>
 * <br>
 * Note that for each dimension i, begin[i] + size[i] &lt;= input.size(i)
 *
 * @param name  Output variable name
 * @param input Variable to get subset of
 * @param begin Beginning index. Must be same length as rank of input array
 * @param size  Size of the output array. Must be same length as rank of input array
 * @return Subset of the input
 */
public SDVariable slice(String name, SDVariable input, int[] begin, int[] size) {
    SDVariable ret = f().slice(input, begin, size);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[])
 */
public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] strides) {
    return stridedSlice(null, input, begin, end, strides);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[])
 */
public SDVariable stridedSlice(String name, SDVariable input, int[] begin, int[] end, int[] strides) {
    return stridedSlice(name, input, begin, end, strides, 0, 0, 0, 0, 0);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[])
 */
public SDVariable stridedSlice(SDVariable input, long[] begin, long[] end, long[] strides) {
    return stridedSlice(null, input, begin, end, strides);
}

/**
 * Get a subset of the specified input, by specifying the first element, last element, and the strides.<br>
 * For example, if input is:<br>
 * [a, b, c]<br>
 * [d, e, f]<br>
 * [g, h, i]<br>
 * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1]) will return:<br>
 * [b, c]<br>
 * [h, i]<br>
 * <br>
 *
 * @param name    Output variable name
 * @param input   Variable to get subset of
 * @param begin   Beginning index. Must be same length as rank of input array
 * @param end     End index. Must be same length as the rank of the array
 * @param strides Stride ("step size") for each dimension. Must be same length as the rank of the array. For example,
 *                stride of 2 means take every second element.
 * @return Subset of the input
 */
public SDVariable stridedSlice(String name, SDVariable input, long[] begin, long[] end, long[] strides) {
    return stridedSlice(name, input, begin, end, strides, 0, 0, 0, 0, 0);
}

/**
 * Get a subset of the specified input, by specifying the first element, last element, and the strides.<br>
* Operates as described in {@link #stridedSlice(SDVariable, long[], long[], long[])} with some extra mask arrays
 * as described below.
 *
 * @param name           Output variable name
 * @param in             Variable to get subset of
 * @param begin          Beginning index
 * @param end            End index
 * @param strides        Stride ("step size") for each dimension. For example,
 *                       stride of 2 means take every second element.
 * @param beginMask      Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored,
 *                       and a value of 0 is used instead for the beginning index for that dimension
 * @param endMask        Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored,
 *                       and a value of size(i)-1 is used instead for the end index for that dimension
 * @param ellipsisMask   Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other
 *                       dimensions are inserted as required at the specified position
 * @param newAxisMask    Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and
 *                       a size 1 dimension is inserted at this point
 * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and
 *                       a size 1 dimension is removed at this point. Note that begin/end/stride values must
 *                       result in a size 1 output for these dimensions
 * @return A subset of the input array
 */
public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end, long[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    SDVariable ret = f().stridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int)
 */
public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    return stridedSlice(null, in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int)
 */
public SDVariable stridedSlice(String name, SDVariable in, int[] begin, int[] end, int[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    SDVariable ret = f().stridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #stridedSlice(String, SDVariable, long[], long[], long[], int, int, int, int, int)
 */
public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, int beginMask,
                               int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
    return stridedSlice(null, in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
}

/**
 * @see #scatterAdd(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterAdd(null, ref, indices, updates);
}

/**
 * Scatter addition operation.<br>
 * If indices is rank 0 (a scalar), then out[index, ...] += updates[...]<br>
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] += updates[i, ...]<br>
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] += updates[i, ..., k, ...]<br>
 * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to add to the initial/source array
 * @return The updated variable
 */
public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    SDVariable ret = f().scatterAdd(ref, indices, updates);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #scatterMul(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterMul(null, ref, indices, updates);
}

/**
 * Scatter multiplication operation.<br>
 * If indices is rank 0 (a scalar), then out[index, ...] *= updates[...]<br>
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] *= updates[i, ...]<br>
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] *= updates[i, ..., k, ...]<br>
 * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to multiply the initial/source array by
 * @return The updated variable
 */
public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    SDVariable ret = f().scatterMul(ref, indices, updates);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #scatterSub(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterSub(null, ref, indices, updates);
}

/**
 * Scatter subtraction operation.<br>
 * If indices is rank 0 (a scalar), then out[index, ...] -= updates[...]<br>
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] -= updates[i, ...]<br>
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] -= updates[i, ..., k, ...]<br>
 * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to subtract from the initial/source array
 * @return The updated variable
 */
public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    SDVariable ret = f().scatterSub(ref, indices, updates);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #scatterDiv(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterDiv(null, ref, indices, updates);
}

/**
 * Scatter division operation.<br>
 * If indices is rank 0 (a scalar), then out[index, ...] /= updates[...]<br>
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] /= updates[i, ...]<br>
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] /= updates[i, ..., k, ...]<br>
 * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to divide the initial/source array by
 * @return The updated variable
 */
public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    SDVariable ret = f().scatterDiv(ref, indices, updates);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #scatterMax(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterMax(null, ref, indices, updates);
}

/**
 * Scatter max operation.<br>
 * If indices is rank 0 (a scalar), then out[index, ...] = max(updates[...], in[index,...])<br>
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = max(updates[i,...], in[indices[i],...])<br>
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = max(updates[i, ..., k, ...], in[indices[i], ..., indices[k], ...])<br>
 * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to take the element-wise max against the initial/source array
 * @return The updated variable
 */
public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    SDVariable ret = f().scatterMax(ref, indices, updates);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #scatterMin(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterMin(null, ref, indices, updates);
}

/**
 * Scatter min operation.<br>
 * If indices is rank 0 (a scalar), then out[index, ...] = min(updates[...], in[index,...])<br>
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = min(updates[i,...], in[indices[i],...])<br>
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = min(updates[i, ..., k, ...], in[indices[i], ..., indices[k], ...])<br>
 * Note that if multiple indices refer to the same location, the contributions from each are handled correctly.
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to take the element-wise min against the initial/source array
 * @return The updated variable
 */
public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    SDVariable ret = f().scatterMin(ref, indices, updates);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #scatterUpdate(String, SDVariable, SDVariable, SDVariable)
 */
public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) {
    return scatterUpdate(null, ref, indices, updates);
}

/**
 * Scatter update operation.<br>
 * If indices is rank 0 (a scalar), then out[index, ...] = updates[...]<br>
 * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = updates[i, ...]<br>
 * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = updates[i, ..., k, ...]<br>
 * Note that if multiple indices refer to the same location, the output at those locations is undefined - different
 * updates may occur in different orders
 *
 * @param name    Name of the output variable
 * @param ref     Initial/source variable
 * @param indices Indices array
 * @param updates Updates to write into the initial/source array
 * @return The updated variable
 */
public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices, SDVariable updates) {
    SDVariable ret = f().scatterUpdate(ref, indices, updates);
    return updateVariableNameAndReference(ret, name);
}

/**
 * @see #trace(String, SDVariable)
 */
public SDVariable trace(SDVariable in) {
    return trace(null, in);
}

/**
 * Matrix trace operation
 * For rank 2 matrices, the output is a scalar with the trace - i.e., sum of the main diagonal.<br>
* For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:]) * * @param name Name of the output variable. May be null. * @param in Input variable * @return Trace */ public SDVariable trace(String name, SDVariable in){ SDVariable ret = f().trace(in); return updateVariableNameAndReference(ret, name); } /** * Generate the variables based on the given input op and return the output variable names. * * @param function the function to generate the output * variable names for * @return the set of names generated for each output of the function. */ public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName) { //xyz ops only have 1 output //if there is already a base name defined, use that if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null) baseName = getBaseNameForFunction(function); if (baseName == null) baseName = function.opName(); val outputShape = function.calculateOutputShape(); if (outputShape == null || outputShape.isEmpty()) { if (function instanceof CustomOp) { CustomOp customOp = (CustomOp) function; //can't guess number of outputs, variable int num_outputs = function.getNumOutputs(); //Use this in preference - if set. Descriptor might specify 2, but it can sometimes be 2+ if (num_outputs <= 0) { val descriptor = customOp.getDescriptor(); if (descriptor != null) { num_outputs = descriptor.getNumOutputs(); } if (num_outputs <= 0) { throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op " + function.getOwnName() + " - " + function.getClass().getSimpleName() + ". 
Ops can override" + " getNumOutputs() to specify number of outputs if required"); } } char ordering = 'c'; SDVariable[] args = function.args(); if (args != null && args.length > 0 && args[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye ordering = function.args()[0].getArr().ordering(); } SDVariable[] ret = new SDVariable[num_outputs]; //dynamic shapes //When importing from TF: convention seem to be names like "unstack", "unstack:1", "unstack:2", ... //TODO validate this! for (int i = 0; i < ret.length; i++) { SDVariable var = (i == 0 ? getVariable(baseName) : getVariable(baseName + ":" + i)); if (var == null) { //Generate new variable name if one with the specified name doesn't exist var = var(generateNewVarName(baseName, i), null, new ZeroInitScheme(ordering)); } var.setOutputIndex(i); var.setCreator(function); ret[i] = var; } //Update the internal state: outgoing variables for function if (getOutputsForFunction(function) == null) addOutgoingFor(ret, function); return ret; } //this is for unresolved shapes, we know xyz is always 1 output else if (function instanceof BaseOp && outputShape.isEmpty()) { SDVariable[] ret = new SDVariable[1]; SDVariable checkGet = getVariable(baseName); char ordering = 'c'; SDVariable[] args = function.args(); if (args != null && args.length > 0 && function.args()[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye ordering = function.args()[0].getArr().ordering(); } if (checkGet == null) { checkGet = var(baseName, null, new ZeroInitScheme(ordering)); } else if (!importedVarName.contains(baseName)) { //need to find a new name String newName = generateNewVarName(baseName, 0); checkGet = var(newName, null, new ZeroInitScheme(ordering)); } if (checkGet == null) { checkGet = var(baseName, null, new ZeroInitScheme(ordering)); } checkGet.setOutputIndex(0); checkGet.setCreator(function); ret[0] = checkGet; //Update the internal state: outgoing variables for function if 
(getOutputsForFunction(function) == null) addOutgoingFor(ret, function); return ret; } } char ordering = 'c'; if (function.args() != null && function.args().length > 0 && function.args()[0].getArr() != null) { ordering = function.args()[0].getArr().ordering(); } SDVariable[] ret = new SDVariable[outputShape.size()]; // ownName/baseName will be used to get variables names val ownName = function.getOwnName(); val rootName = baseName; for (int i = 0; i < ret.length; i++) { val shape = outputShape.get(i); // it should be: rootName:index. i.e.: split:1, split:2, split:3, split:4 etc baseName = rootName + (i > 0 ? ":" + i : ""); SDVariable checkGet = getVariable(baseName); if (checkGet == null) { // obviously - there's no such var, just add it checkGet = var(baseName, shape, new ZeroInitScheme(ordering)); } else if (shape != null && !shapeAlreadyExistsForVarName(checkGet.getVarName())) { // var exists, let's update its shape putShapeForVarName(checkGet.getVarName(), shape); } else if (shape != null && shapeAlreadyExistsForVarName(checkGet.getVarName())) { // no-op. // TODO: maybe we should check shapes equality here? // it's either var that already exist, or something bad happening } else if (!importedVarName.contains(baseName)) { // FIXME: dead end. it's impossible to get here with null as shape //need to find a new name int count = 1; String name = baseName + "_" + count + (i > 0 ? ":" + i : ""); while (getVariable(name) != null) { count++; name = baseName + "_" + count + (i > 0 ? ":" + i : ""); } if (getVariable(name) != null) { throw new ND4JIllegalStateException("Converged on already generated variable!"); } checkGet = var(name, shape, new ZeroInitScheme(ordering)); } if (checkGet == null) { checkGet = var(baseName + (i > 0 ? 
":" + i : ""), shape, new ZeroInitScheme(ordering)); } checkGet.setOutputIndex(i); checkGet.setCreator(function); ret[i] = checkGet; } return ret; } /** * Generate the variables based on the given input op * and return the output variable names. * * @param function the function to generate the output * variable names for * @return the set of names generated for each output of the function. */ public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) { return generateOutputVariableForOp(function, function.opName()); } /** * Get a SameDiff function instance given the name of the function * * @param functionName the name of the function * @return the same diff function instance defined for the given name */ public SameDiff getFunction(String functionName) { return sameDiffFunctionInstances.get(functionName); } /** * Execute the specified ops and return the output of the last one * * @param ops Ops to execute * @return Output (or first output) of the final op in the list, after execution */ public INDArray execAndEndResult(List ops) { List exec = exec(ops); Op op = (Op) exec.get(exec.size() - 1); return op.z(); } /** * Execute the graph using the current arrays/state and return the array for the final variable in the graph.
* After execution, the arrays for other variables can be obtained using {@link #getArrForVarName(String)} * or {@link SDVariable#getArr()}
*

* Note: If the final operation has multiple output variables, use {@link #execAndEndResults()} instead * * @return The output of the final operation in the graph after execution */ public INDArray execAndEndResult() { List exec = exec().getRight(); val finalOp = exec.get(exec.size() - 1); val output = finalOp.outputVariables(); if (output.length > 1) { throw new ND4JIllegalStateException(finalOp.opName() + " has multiple outputs. Use execAndEndResults instead."); } return output[0].getArr(); } /** * Execute the graph using the current arrays/state and return the array(s) for the final variable in the graph.
* After execution, the arrays for other variables can be obtained using {@link #getArrForVarName(String)} * or {@link SDVariable#getArr()}
* * @return The outputs of the final operation in the graph, after execution */ public INDArray[] execAndEndResults() { List exec = exec().getRight(); val finalOp = exec.get(exec.size() - 1); val output = finalOp.outputVariables(); INDArray outArrays[] = new INDArray[output.length]; for (int i = 0; i < outArrays.length; i++) { outArrays[i] = output[i].getArr(); } return outArrays; } /** * Execute the graph using the current arrays/state and return the (first, and possibly only) array for the specified * variable in the graph.
* After execution, the arrays for other variables can be obtained using {@link #getArrForVarName(String)} * or {@link SDVariable#getArr()} * * @return The output of the final operation in the graph */ public INDArray execAndEndResult(int outputIndex) { List exec = exec().getRight(); val output = exec.get(exec.size() - 1).outputVariables()[outputIndex]; return output.getArr(); } public INDArray yetAnotherExecMethod(@NonNull Map inputs) { if (!wasRegistered.get()) { synchronized (this) { if (!wasRegistered.get()) { val bb = asFlatBuffers(); val ptr = new BytePointer(bb); Nd4j.getExecutioner().registerGraph(this.hashCode(), ptr); wasRegistered.set(true); } } } val newMap = new LinkedHashMap(); val keySet = inputs.keySet(); for (val key : keySet) { val vx = variableMap.get(key); newMap.put(vx.getVarName(), inputs.get(key)); } val result = Nd4j.getExecutioner().executeGraph(this.hashCode(), newMap, this.reverseMap); if (result.size() == 0) throw new ND4JIllegalStateException("Execution failed"); val list = new ArrayList(result.values()); return list.get(list.size() - 1); } /** * Executes the list of operations. 
* This exec method is for only invoking operations rather than creating them * * @param ops the list of already created ops * @return the passes in list */ public List exec(List ops) { for (int i = 0; i < ops.size(); i++) { Op op = (Op) ops.get(i); Nd4j.getExecutioner().exec(op); } return ops; } public TensorList getListByName(@NonNull String name) { return lists.get(name); } public void putListByName(@NonNull String name, TensorList list) { lists.put(name, list); } /** * Creates a while statement * * @param sameDiffConditional * @param loopBody * @return */ public While whileStatement(SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition loopBody , SDVariable[] inputVars) { return While.builder() .inputVars(inputVars) .condition(conditionBody) .predicate(sameDiffConditional) .trueBody(loopBody) .parent(this) .blockName("while-" + UUID.randomUUID().toString()) .build(); } /** * @param conditional * @param trueBody * @param falseBody * @return */ public If ifStatement(SameDiffConditional conditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition trueBody, SameDiffFunctionDefinition falseBody , SDVariable[] inputVars) { return If.builder() .conditionBody(conditionBody) .falseBody(falseBody) .trueBody(trueBody) .predicate(conditional) .inputVars(inputVars) .parent(this) .blockName("if-" + UUID.randomUUID().toString()) .build(); } public TensorArrayV3 tensorArray() { return new TensorArrayV3(this); } /** * @param functionName * @param with */ public SDVariable invokeFunctionOn(String functionName, SameDiff with) { SameDiff instance = sameDiffFunctionInstances.get(functionName); SDVariable ret = instance.invokeGraphOn(with); return ret; } /** * @param function */ public SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables) { if (!sameDiffFunctionInstances.containsKey(function)) { SameDiff sub = SameDiff.create(); sub.workspace = 
(workspace); this.child = sub; sub.parent = this; //setup subgraph //re execute to populate subgraph SDVariable[] ret = new SDVariable[variables.length]; for (int i = 0; i < ret.length; i++) { ret[i] = sub.var(variables[i]); } sub.inputs = ret; sub.outputs = functionDefinition.define(sub, null, ret); sameDiffFunctionInstances.put(function, sub); } this.child = null; return sameDiffFunctionInstances.get(function); } /** * @param function */ public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition) { defineFunction(function, functionDefinition, new LinkedHashMap()); } /** * @param function * @param functionDefinition * @param inputs */ public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition, Map inputs) { if (!sameDiffFunctionInstances.containsKey(function)) { SameDiff sub = SameDiff.create(); sub.workspace = (workspace); //setup subgraph //re execute to populate subgraph functionDefinition.define(sub, inputs, null); sameDiffFunctionInstances.put(function, sub); } } /** * Exec a given SameDiff function instance * * @param functionName the name of the SameDiff function instance to invoke * @return Output of the final variable after execution */ public INDArray execAndEndResult(String functionName) { return sameDiffFunctionInstances.get(functionName).execAndEndResult(); } /** * Execute the specified SameDiff function instance * * @param functionName the name of the SameDiff function instance to invoke * @return */ public Pair, List> exec(String functionName) { Pair, List> ret; if (debugMode) { ret = sameDiffFunctionInstances.get(functionName).enableDebugMode().exec(); } else ret = sameDiffFunctionInstances.get(functionName).exec(); //Ensure all variables are associated with this SameDiff instance after possible execBackwards() etc associateSameDiffWithOpsAndVariables(); return ret; } /** * Exec the given function given the ops * * @param functionName the opName of the function to * exec * @param 
cachedOps the cached operations * @return */ public List exec(String functionName, List cachedOps) { return sameDiffFunctionInstances.get(functionName).exec(cachedOps); } /** * Execute the gradient (backward pass) function on this graph.
* Constructs a backwards graph (differentiating the defined graph) if it does not already exist, and the executes * the operations on that graph, calculating gradients for all variables.
* Note that after execBackwards() has completed, the gradient arrays for a each variable can be accessed using * {@link SDVariable#getGradient()} followed by {@link SDVariable#getArr()} or by using {@link #getGradForVariable(String)} * * @return Result of execution */ public Pair, List> execBackwards() { if (getFunction("grad") == null) { createGradFunction(); } if (log.isTraceEnabled()) { log.trace("About to execute backward function"); } Pair, List> forward = exec("grad"); SameDiff grad = getFunction("grad"); if (grad.isDebugMode()) { //ensure all gradients are present for all variables for (SDVariable sdVariable : grad.variables()) { sdVariable.gradient(); } } return forward; } /** * Create the gradient function (for calculating gradients via {@link #execBackwards()}) if it is not already defined. * Users do not usually need to call this function manually, as it is called as required in the aforementioned method. *

* If the gradient function already exists, this method is a no-op.
* After this method returns, the SameDiff function instance for the gradient can be accessed using {@link #getFunction(String)} * with name "grad" as the argument. */ public void createGradFunction() { if (log.isTraceEnabled()) { log.trace("Defining function \"grad\""); } final SameDiff outer = this; defineFunction("grad", new SameDiffFunctionDefinition() { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { //propagate graph to this samediff instance //which will also contain the backward if (SameDiff.this.debugMode) { sameDiff.enableDebugMode(); } outer.invokeGraphOn(sameDiff); if (debugMode) { //Expect incoming args and outgoing args to be the same Preconditions.checkState(sameDiff.incomingArgsReverse.keySet().equals(incomingArgsReverse.keySet()), "incomingArgsReverse keysets not equal"); Preconditions.checkState(sameDiff.outgoingArgsReverse.keySet().equals(outgoingArgsReverse.keySet()), "outgoingArgsReverse keysets not equal"); } List allFunctions = new ArrayList<>(sameDiff.functionInstancesById.values()); if (allFunctions.isEmpty()) { throw new ND4JIllegalStateException("No ops found!"); } for (val func : allFunctions) { if (func instanceof SDVariable) { continue; } val args = func.args(); for (val arg : args) arg.setSameDiff(sameDiff); val outputs = func.outputVariables(); for (val output : outputs) output.setSameDiff(sameDiff); func.setSameDiff(sameDiff); } val initialOuts = allFunctions.get(allFunctions.size() - 1).outputVariables(); val firstBackward = initialOuts[0]; if (log.isTraceEnabled()) { String[] initialOutputsStr = allFunctions.get(allFunctions.size() - 1).outputVariablesNames(); String s = initialOutputsStr == null ? 
"null" : Arrays.toString(initialOutputsStr); log.trace("Defining backward function: initial outputs {}", s); } //start with scalar backprop SDVariable initialGrad = sameDiff.var("one-var", Nd4j.trueScalar(1.0)); sameDiff.forwardVarForGrad.put(firstBackward.getVarName(), initialGrad); sameDiff.gradients.put(firstBackward.getVarName(), initialGrad); SDVariable gradientBackwardsMarker = sameDiff.gradientBackwardsMarker(firstBackward); //reinitialize list with all declared variables allFunctions = new ArrayList<>(sameDiff.functionInstancesById.values()); Collections.reverse(allFunctions); for (int i = 0; i < allFunctions.size(); i++) { DifferentialFunction action = allFunctions.get(i); if (log.isTraceEnabled()) { log.trace("Defining backward function step {} of {}: {} ({}) - {}", (i + 1), allFunctions.size(), action.opName(), action.getOwnName(), action.getClass().getName()); } if (action instanceof GradientBackwardsMarker) { continue; } DifferentialFunction currFunction = action; Preconditions.checkState(currFunction.getSameDiff() == sameDiff, "Wrong samediff instance found!"); //Preconditions.checkNotNull("Gradient for " + currFunction.opName() + " was null ! 
" + sameDiff.getVariableForVertexId(currFunction.getVertexId()).getGradient()); val args = currFunction.outputVariables(); for (val arg : args) { if (arg.getSameDiff() != sameDiff) { arg.setSameDiff(sameDiff); } } List grads = new ArrayList<>(); for (val varToGrad : args) { val grad = varToGrad.gradient(); if (grad == null) throw new ND4JIllegalStateException("No gradient found for " + varToGrad.getVarName()); grads.add(grad); } List currFnGrads = currFunction.diff(grads); if (log.isTraceEnabled()) { log.trace("Finished Defining backward function step {} of {}: {} ({}) - {}", (i + 1), allFunctions.size(), action.opName(), action.getOwnName(), action.getClass().getName()); } if (debugMode) { //Expect incoming args and outgoing args to be the same Preconditions.checkState(sameDiff.incomingArgsReverse.keySet().equals(sameDiff.outgoingArgsReverse.keySet()), "incomingArgsReverse and outgoingArgsReverse keysets not equal after backprop of function %s of %s: %s (%s)", (i + 1), allFunctions.size(), action.getOwnName(), action.getClass().getName()); } } if (sameDiff.isDebugMode()) { //ensure all gradients are present for all variables for (SDVariable sdVariable : variables()) { sdVariable.gradient(); } } if (log.isTraceEnabled()) { log.trace("Defining backward function complete"); } return new SDVariable[]{sameDiff.var("grad", new int[]{1, 1})}; } }); } /** * Exec a backwards operation and return the end result * * @return */ public INDArray execBackwardAndEndResult() { List backwards = execBackwards().getRight(); DifferentialFunction df = backwards.get(backwards.size() - 1); if (df instanceof Op) { return ((Op) df).z(); } else if (df instanceof DynamicCustomOp) { return ((DynamicCustomOp) df).getOutputArgument(0); } else { return null; } } /** * Creates and executes a list of operations * * @return */ public INDArray execWithPlaceHolderAndEndResult(Map inputs) { resolveVariablesWith(inputs); return execAndEndResult(); } /** * Set the original shape for a given place holder.
* This is used to track original shapes of place holder variables.
* The reason we track original shapes is to validate possible candidate arrays coming in (especially with -1 * as the expected shapes). *

* Note that if {@link #isPlaceHolder(String)} * returns false for the passed in vertex id, * a {@link ND4JIllegalStateException} is thrown. *

* A vertex id must be added first. You can * do this with {@link #addAsPlaceHolder(String)} * * @param variableName the vertex id for the original shape * @param shape the shape of the place holder */ public void setOriginalPlaceHolderShape(String variableName, long[] shape) { if (!isPlaceHolder(variableName)) { throw new ND4JIllegalStateException("Vertex id " + variableName + " does not appear to be a place holder. Did you forget to call addPlaceHolder?"); } if (shape == null) { throw new ND4JIllegalStateException("Null and 0 length shape arrays not allowed"); } if (placeHolderOriginalShapes.containsKey(variableName) && !Arrays.equals(placeHolderOriginalShapes.get(variableName), shape)) { throw new ND4JIllegalStateException("Unable to add a new shape for vertex id " + variableName); } //after validation now only set once placeHolderOriginalShapes.put(variableName, shape); } /** * Get the original shape for the vertex id if one was set (other wise returns null).
* This is mainly for use in validating passed in arrays as arguments to {@link #resolveVariablesWith(Map)} * usually when executing using {@link #execWithPlaceHolder(Map)} * * @param varName the vertex id to get the original shape for. * @return the set vertex */ public long[] getOriginalShapeForPlaceHolder(String varName) { return placeHolderOriginalShapes.get(varName); } /** * Returns true if this vertex id is a place holder variable or not
* A place holder variable is one where the array shape(s) are currently known and can't yet be calculated * * @param varName the vertex id to test * @return True if the variable is a placeholder, false otherwise */ public boolean isPlaceHolder(String varName) { return placeHolderVarNames.contains(varName); } /** * Add this vertex id as a place holder * * @param varName the vertex id to add */ public void addAsPlaceHolder(String varName) { placeHolderVarNames.add(varName); if (getVariable(varName) != null && getVariable(varName).getShape() != null) { placeHolderOriginalShapes.put(varName, getVariable(varName).getShape()); } } /** * Resolve all ndarrays by updating the variables for each array specified in the given map. * An {@link IllegalStateException} will be thrown if not all arrays are specified for resolution. * * @param arrays the arrays to resolve. */ public void resolveVariablesWith(Map arrays) { for (val arrayEntry : arrays.entrySet()) { val varForName = getVariable(arrayEntry.getKey()); if (varForName == null) { throw new ND4JIllegalStateException("No variable name found for " + arrayEntry.getKey()); } if (placeHolderOriginalShapes.containsKey(arrayEntry.getKey())) { val originalShape = placeHolderOriginalShapes.get(arrayEntry.getKey()); if (originalShape.length == arrayEntry.getValue().rank()) { for (int i = 0; i < originalShape.length; i++) { if (originalShape[i] != arrayEntry.getValue().shape()[i] && originalShape[i] >= 1) { throw new ND4JIllegalStateException("Incompatible shape passed for variable. " + Arrays.toString(arrayEntry.getValue().shape())); } } } } } for (val entry : arrays.entrySet()) { if (!placeHolderVarNames.contains(entry.getKey())) { throw new ND4JIllegalStateException("Illegal variable " + entry.getKey() + " passed in. 
Variable found not to be a place holder variable"); } val specifiedShape = getOriginalShapeForPlaceHolder(entry.getKey()); //whole shape was specified: validate whether the input array shape is equal if (!Shape.isPlaceholderShape(specifiedShape)) { if (!Shape.shapeEquals(specifiedShape, entry.getValue().shape())) { throw new ND4JIllegalStateException("Place holder shape specified was " + Arrays.toString(specifiedShape) + " but array shape was " + Arrays.toString(entry.getValue().shape())); } } updateShapeForVarName(entry.getKey(), entry.getValue().shape()); associateArrayWithVariable(entry.getValue(), getVariable(entry.getKey())); updateArrayForVarName(entry.getKey(), entry.getValue()); } for (val funcName : propertiesToResolve.keySet()) { val func = functionInstancesById.get(funcName); if (!functionInstancesById.containsKey(funcName)) { throw new ND4JIllegalStateException("Unable to resolve function name " + funcName); } if (func instanceof CustomOp) { CustomOp customOp = (CustomOp) func; customOp.populateInputsAndOutputsFromSameDiff(); } } //declare resolved resolvedVariables = true; } /** * Returns true if all place holder variables are resolved.
* A place holder variable is resolved when {@link #getVariable(String)} getArr() does not return null and * the shape is properly resolved. * * @return true if all place holder variables are resolved. */ public boolean allPlaceHolderVariablesResolved() { for (val vertexId : placeHolderVarNames) { val var = getVariable(vertexId); if (var.getArr() == null) { return false; } } return true; } /** * Add one or or more place holder variables for the given vertex id. *

* Note that if a vertex id in placeHolderVariables isn't present in this samediff instance anyways, * an {@link ND4JIllegalStateException} is thrown * * @param varName the vertex id to add place holders for * @param placeHolderVariables the place holder variables */ public void putPlaceHolderForVariable(String varName, String... placeHolderVariables) { for (val placeHolderVariable : placeHolderVariables) { if (!variableMap.containsKey(placeHolderVariable)) { throw new ND4JIllegalStateException("No variable found for " + placeHolderVariable); } } List placeHolders = placeHolderMap.get(varName); if (placeHolders == null) { placeHolders = new ArrayList<>(); placeHolderMap.put(varName, placeHolders); } placeHolders.add(placeHolderVariables); } /** * Returns true if the given vertex id has any placeholder variables * * @param vertexId the vertex id to check for * @return true if this vertex has any place holder variables or not */ public boolean hasPlaceHolderVariables(String vertexId) { return placeHolderMap.containsKey(vertexId); } /** * Get the place holders for a given vertex id. May return null. *

* Consider using {@link #hasPlaceHolderVariables(String)} * * @param varName the vertex id to get the place holders for * @return the place holder variables for the given vertex id or null */ public List getPlaceHoldersFor(String varName) { return placeHolderMap.get(varName); } /** * Creates and executes a list of operations based on the given variables passed in.
* {@link #resolveVariablesWith(Map)} is called * * @return */ public Pair, List> execWithPlaceHolder(Map inputs) { resolveVariablesWith(inputs); return exec(); } /** * Get the {@link SDVariable} associated with each function based on the * {@link DifferentialFunction#outputVariables()} ()} * * @param functions the functions to get the variables for * @return the list of variables associated with the given {@link DifferentialFunction} */ public List getVariablesAssociatedWithFunctions(List functions) { List ret = new ArrayList<>(functions.size()); for (DifferentialFunction function : functions) { ret.addAll(Arrays.asList(function.outputVariables())); } return ret; } /** * Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable. *

* Note that if null for the new variable is passed in, it will just return the original input variable. * * @param varToUpdate the variable to update * @param newVarName the new variable name * @return the passed in variable */ public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { if (varToUpdate == null) { throw new NullPointerException("Null input: No variable found for updating!"); } if(newVarName != null && variableMap.containsKey(newVarName) && varToUpdate != variableMap.get(newVarName)){ throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable"); } if (newVarName == null && variableMap.containsKey(varToUpdate.getVarName())) { //Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name // "mean" and consequently a new variable name needs to be generated newVarName = generateNewVarName(varToUpdate.getVarName(), 0); } if (newVarName == null || varToUpdate.getVarName().equals(newVarName)) { return varToUpdate; } val oldVarName = varToUpdate.getVarName(); varToUpdate.setVarName(newVarName); updateVariableName(oldVarName, newVarName); return varToUpdate; } /** * Updates the variable name property on the passed in variables, its reference in samediff, and returns the variable. * * @param variablesToUpdate the variable to update * @param newVariableNames the new variable name * @return the updated, passed in variables */ public SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames) { int numVariables = variablesToUpdate.length; SDVariable[] updatedVariables = new SDVariable[numVariables]; for (int i = 0; i < numVariables; i++) { SDVariable varToUpdate = variablesToUpdate[i]; String name = newVariableNames == null ? 
null : newVariableNames[i]; updatedVariables[i] = updateVariableNameAndReference(varToUpdate, name); } return updatedVariables; } /** * Associate the current SameDiff instance with all ops and variables. * This is necessary to ensure that when dealing with shared state (usually with a SameDiff function such * as "grad" - the backward function) we have the correct SameDiff instance set for all ops/SDVariables.
* If this is not done, arrays and shapes could be fetched from the incorrect SameDiff instance for some methods */ protected void associateSameDiffWithOpsAndVariables(){ for(DifferentialFunction df : functionInstancesById.values()){ df.setSameDiff(this); //TODO: This is ugly but seemingly necessary //Finally, also set the SDVariable for each op //Otherwise: could have an op pointing to this SameDiff instance, but op's SDVariable's sameDiff field pointing // to another SameDiff instance. At which point, they could fetch shapes and arrays from some other instance // (i.e., not from this one that is currently executing) SDVariable[] args = df.args(); if(args != null){ for(SDVariable arg : args){ arg.setSameDiff(this); } } SDVariable[] outputs = df.outputVariables(); if(outputs != null){ for(SDVariable out : outputs){ out.setSameDiff(this); } } } for(SDVariable var : variableMap.values()){ var.setSameDiff(this); } } // required for loops private SDVariable[] outputs; private SDVariable[] inputs; private Pair, List> exec_cache; /** * Clear the execution cache, if it is present */ public void clearExecutionCache(){ exec_cache = null; } /** * Execute the SameDiff instance using the current state
* After execution, the arrays for variables can be obtained using {@link #getArrForVarName(String)} or * {@link SDVariable#getArr()}
* @return Execution results */ public Pair, List> exec() { /* if (exec_cache != null){ return exec_cache; } */ if (log.isTraceEnabled()) { log.trace("Starting execution: {} functions", functionInstancesById.size()); } if (!resolvedVariables) resolveVariablesWith(new LinkedHashMap()); List ops = new ArrayList<>(); // we don't care if this thread had any other FlowPath objects attached. we'll just create new one localFlowPath.set(new FlowPath()); val flowPath = localFlowPath.get(); Map opMap = new HashMap<>(); val funcs = new ArrayList(functionInstancesById.values()); List funcNames = new ArrayList<>(functionInstancesById.keySet()); //LinkedHashMap, so order for both these vars should be identical boolean onBackward = false; // dequeue for Frames (nested, probably) val frames = new ArrayDeque(); // simple flag, set true if within frame boolean inFrame = false; // yet another flag, to remove LastFrame once we really left last frame boolean frameLeft = false; //If true: this execution includes gradient functions... boolean isExecBackwards = functionInstancesById.containsKey(GradientBackwardsMarker.OP_NAME); //Before execution: set the SameDiff instance //This is necessary, because the one op could be shared by both forward and backward samediff instances //If the SameDiff instance isn't set, they might use wrong shapes or arrays as part of their ops //And, set the SameDiff instance on all variables, for exactly the same reason associateSameDiffWithOpsAndVariables(); int i = 0; int exec_counter = 0; for (; i < funcs.size(); i++) { ++exec_counter; if (log.isTraceEnabled()) { val f = funcs.get(i); String[] argNames = f.argNames(); String[] outNames = f.outputVariablesNames(); log.trace("Starting execution of step {} of {}: Function {} (ownName={}) - {}", exec_counter, funcs.size(), f.opName(), f.getOwnName(), f.getClass().getName()); log.trace("Function inputs: {} - Function outputs: {}", (argNames == null ? "(none)" : Arrays.toString(argNames)), (outNames == null ? 
"(none)" : Arrays.toString(outNames))); SDVariable[] args = f.args(); for (int arg = 0; arg < args.length; arg++) { if (args[arg] == null) { log.trace("--> arg {} - {}: argument is null!", arg, argNames[arg]); } else { INDArray arr = args[arg].getArr(); String arrShape = (arr == null ? "" : Arrays.toString(arr.shape())); log.trace("--> arg {} - {}: array shape: {}", arg, argNames[arg], arrShape); } } } val opName = funcs.get(i).opName(); if (!onBackward && GradientBackwardsMarker.OP_NAME.equals(opName)) { onBackward = true; } if (GradientBackwardsMarker.OP_NAME.equals(opName)) continue; DifferentialFunction differentialFunction = funcs.get(i); if((differentialFunction instanceof ExternalErrorsFunction)) { if(isExecBackwards) ((ExternalErrorsFunction) differentialFunction).updateBeforeExecution(); continue; } val ownName = differentialFunction.getOwnName(); // just registering function for this pass flowPath.ensureNodeStateExists(differentialFunction.getOwnName()); if (differentialFunction instanceof SDVariable) { if (log.isTraceEnabled()) { log.trace("Skipping differentialFunction that is instanceof SDVariable: {}", opName); } continue; } val args = getInputsForFunction(differentialFunction); log.debug("Step: {}; Executing op [{}] for node [{}]", exec_counter, opName, ownName); // check if inputs are active nodes. 
skip step otherwise // please note: Exit node can't be skipped, because it's either rewind point or exit loop point boolean shouldSkip = false; if (differentialFunction instanceof Merge) { val arg0 = args[0]; val arg1 = args[1]; if (!flowPath.isActive(arg0) && !flowPath.isActive(arg1)) shouldSkip = true; } else { if (!(differentialFunction instanceof Exit)) { // if we've left Exit nodes, we can finally delete last frame name if (frameLeft) { frameLeft = false; val frame_name = frames.removeLast(); flowPath.activateFrame(frame_name, false); flowPath.forgetFrame(frame_name); } // we must check, if there's inactive nodes used as inputs for this node for (val input : args) { if (!flowPath.isActive(input)) { // propagate inactivity flowPath.markActive(differentialFunction.getOwnName(), false); shouldSkip = true; break; } } } } if (shouldSkip) { if (log.isTraceEnabled()) { log.trace("Skipping function {}: shouldSkip = true", opName); } continue; } differentialFunction.resolvePropertiesFromSameDiffBeforeExecution(); flowPath.markActive(differentialFunction.getOwnName(), true); /** * This set of operations (Enter/Exit/NextIteration/Exit/Switch) are special snowflakes: they modify graph execution order, and basically used here to replicate TF logic. 
* Since SameDiff itself has own logic for loops and conditionals using Scopes */ if (differentialFunction instanceof LoopCond) { if (log.isTraceEnabled()) log.trace("Starting execution of LoopCond op"); // this node just passes single input forward, for future evaluation val inputs = getInputVariablesForFunction(differentialFunction); val array = inputs[0].getArr(); variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering())); flowPath.markExecuted(differentialFunction.getOwnName(), true); if ((int) array.getDouble(0) == 1) { val frameName = frames.getLast(); // incrementing number of cycles for THIS frame, only if LoopCond is true flowPath.incrementNumberOfCycles(frameName); } } else if (differentialFunction instanceof Enter) { if (log.isTraceEnabled()) log.trace("Starting execution of Enter op"); // if (flowPath.wasExecuted(differentialFunction.getOwnName())) // continue; val inputs = getInputVariablesForFunction(differentialFunction); val array = inputs[0].getArr(); val name = inputs[0].getVarName(); if (array != null) variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering())); else { val cleansed = name.replaceAll(":.*",""); val list = lists.get(cleansed); if (list != null) lists.put(ownName, list); } flowPath.markExecuted(differentialFunction.getOwnName(), true); // frame_name MUST be non-null here val frame_name = ((Enter) differentialFunction).getFrameName(); if (!flowPath.isRegisteredFrame(frame_name)) { flowPath.registerFrame(frame_name); frames.addLast(frame_name); inFrame = true; } } else if (differentialFunction instanceof Exit) { if (log.isTraceEnabled()) log.trace("Starting execution of Exit op"); // this is just exit point of graph: it maps own input to own output or rewinds graph to specific position planned at first NextIteration node val frame_name = frames.getLast(); // saving frame_name for backward pass ((Exit) differentialFunction).setFrameName(frame_name); if 
(!flowPath.isFrameActive(frame_name)) { flowPath.markActive(differentialFunction.getOwnName(), false); // if frame is inactive, lets remove it from queue as well frameLeft = true; continue; } // Exit node is called in any way, doesn't matters if body was executed or not // so, we're checking if rewind was planned (so, NextIteration was executed before Exit) // and if it's TRUE - we're setting applying rewind by setting loop idx and calling continue if (flowPath.isRewindPlanned(frame_name)) { // just reset loop flowPath.planRewind(frame_name, false); val currentPosition = i; i = flowPath.getRewindPosition(frame_name); val startPosition = i + 1; flowPath.setRewindPosition(frame_name, -1); continue; } val inputs = getInputVariablesForFunction(differentialFunction); val array = inputs[0].getArr(); val name = inputs[0].getVarName(); if (array != null) variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering())); else { val cleansed = name.replaceAll(":.*",""); val list = lists.get(cleansed); if (list != null) lists.put(ownName, list); } flowPath.markExecuted(differentialFunction.getOwnName(), true); // now it's safe to remove LastFrame frameLeft = true; } else if (differentialFunction instanceof NextIteration) { if (log.isTraceEnabled()) log.trace("Starting execution of NextIteration op"); // this operations merges own input, and schedules rewind to specific Merge node val inputs = getInputVariablesForFunction(differentialFunction); val frame_name = frames.getLast(); val array = inputs[0].getArr(); val name = inputs[0].getVarName(); if (array != null) variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering())); else { val cleansed = name.replaceAll(":.*",""); val list = lists.get(cleansed); if (list != null) lists.put(ownName, list); } flowPath.markExecuted(differentialFunction.getOwnName(), true); // if NextIteration wasn't skipped with inactive branch, we'll plan rewind for this frame. 
obviously, only once if (!flowPath.isRewindPlanned(frame_name)) { flowPath.planRewind(frame_name, true); continue; } } else if (differentialFunction instanceof Merge) { if (log.isTraceEnabled()) log.trace("Starting execution of Merge op"); // merge operation takes two inputs, and saves one of them as own output. // if SDVariable exists for second input - we use it. First input used otherwise val inputs = getInputVariablesForFunction(differentialFunction); val frame_name = frames.size() > 0 ? frames.getLast() : null; if (frame_name != null) flowPath.activateFrame(frame_name, true); // frame_name can be null if this merge node is used for something that's not loop. i.e. switch/merge pair if (frame_name != null) flowPath.setRewindPositionOnce(frame_name, i - 1); // NextIteration can have NO frame_name defined. so let's propagate it if (inputs.length == 2) { val secondArg = functionInstancesById.get(inputs[1].getVarName()); if (secondArg != null && secondArg instanceof NextIteration) { ((NextIteration) secondArg).setFrameName(frame_name); } } // we must check second input first here if (flowPath.wasExecuted(inputs[1].getVarName())) { // propagate second input val array = inputs[1].getArr(); val name = inputs[1].getVarName(); if (array != null) variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering())); else { val cleansed = name.replaceAll(":.*",""); val list = lists.get(cleansed); if (list != null) lists.put(ownName, list); } // nullify executed mark flowPath.markExecuted(inputs[1].getVarName(), false); } else { // propagate first input val array = inputs[0].getArr(); val name = inputs[0].getVarName(); if (array != null) variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering())); else { val cleansed = name.replaceAll(":.*",""); val list = lists.get(cleansed); if (list != null) lists.put(ownName, list); } } flowPath.markExecuted(differentialFunction.getOwnName(), true); } else if (differentialFunction instanceof 
Switch) { if (log.isTraceEnabled()) log.trace("Starting execution of Switch op"); // switch takes 2 inputs: actual input and boolean scalar. If scalar is false, input is saved as output:0, if scalar is true, input is saved as output:1 ((CustomOp) differentialFunction).populateInputsAndOutputsFromSameDiff(); val inputs = getInputVariablesForFunction(differentialFunction); val input = inputs[0].getArr(); val bool = inputs[1].getArr(); val name = inputs[0].getVarName(); // basically we're setting one of the graph branches inactive. branch 0 for false, branch 1 for true if ((int) bool.getDouble(0) == 0) { // false step, we'll propagate output:0 here flowPath.setActiveBranch(differentialFunction.getOwnName(), 0); flowPath.markActive(differentialFunction.getOwnName(), true); flowPath.markActive(differentialFunction.getOwnName() + ":1", false); if (input != null) variableNameToArr.put(differentialFunction.getOwnName(), input.dup(input.ordering())); else { val cleansed = name.replaceAll(":.*",""); val list = lists.get(cleansed); if (list != null) lists.put(ownName, list); } } else { // true step, we'll propagate output:1 here flowPath.setActiveBranch(differentialFunction.getOwnName(), 1); if (input != null) variableNameToArr.put(differentialFunction.getOwnName() + ":1", input.dup(input.ordering())); else { val cleansed = name.replaceAll(":.*",""); val list = lists.get(cleansed); if (list != null) lists.put(ownName, list); } flowPath.markActive(differentialFunction.getOwnName(), false); flowPath.markActive(differentialFunction.getOwnName() + ":1", true); } flowPath.markExecuted(differentialFunction.getOwnName(), true); } else if (differentialFunction instanceof BaseTensorOp) { //if(log.isTraceEnabled()) log.info("Starting execution of Tensor op [{}]", opName); // we just pull actual code out of val list = ((BaseTensorOp) differentialFunction).execute(this); if (!lists.containsKey(list.getName())) lists.put(list.getName(), list); ops.add(differentialFunction); } else if 
(differentialFunction instanceof If) { if (log.isTraceEnabled()) log.trace("Starting execution of If op"); If ifOp = (If) differentialFunction; if (!onBackward) { ifOp.getPredicateExecution().exec(); //depending on the block add the proper graph body to this for persistence //and possible later processing. if (ifOp.getTargetBoolean().getArr().sumNumber().doubleValue() > 0) { ifOp.getLoopBodyExecution().exec(); ifOp.exectedTrueOrFalse(true); } else { ifOp.getFalseBodyExecution().exec(); ifOp.exectedTrueOrFalse(false); } } else { if (ifOp.getTrueBodyExecuted() != null) { Pair, List> execBackwards = null; List variablesForFunctions = null; if (ifOp.getTrueBodyExecuted()) { execBackwards = ifOp.getLoopBodyExecution().execBackwards(); variablesForFunctions = ifOp.getLoopBodyExecution().getVariablesAssociatedWithFunctions(execBackwards.getRight()); } else { execBackwards = ifOp.getFalseBodyExecution().execBackwards(); variablesForFunctions = ifOp.getFalseBodyExecution().getVariablesAssociatedWithFunctions(execBackwards.getRight()); } /** * Maps the variables from the child namespace body to * the parent. This allows access to the underlying ndarray * and returning a valid variable reference for autodiff. */ for (SDVariable variable : variablesForFunctions) { SDVariable proxyVar = var(variable); } } else throw new ND4JIllegalStateException("No body was run."); } flowPath.markExecuted(differentialFunction.getOwnName(), true); ops.add(differentialFunction); } else if (differentialFunction instanceof While) { if (log.isTraceEnabled()) log.trace("Starting execution of While op"); While whileOp = (While) differentialFunction; if (!onBackward) { SameDiff execBody = whileOp.getLoopBodyExecution(); //depending on the block add the proper graph body to this for persistence //and possible later processing. 
//note that we need to update the graph predicate by running the execution whileOp.getPredicateExecution().exec(); if (execBody.outputs == null) { // No explicit inputs/outputs provided. //Op was probably created by tensorflow import. // Non-inplace ops not supported. while (whileOp.getTargetBoolean().getArr().sumNumber().doubleValue() > 0) { //run the body execBody.exec(); whileOp.getPredicateExecution().exec(); whileOp.incrementLoopCounter(); } } else { if (whileOp.getTargetBoolean().getSameDiff().inputs == null) { whileOp.getTargetBoolean().getSameDiff().inputs = new SDVariable[whileOp.getInputVars().length]; for (int e = 0; e < whileOp.getInputVars().length; e++) { whileOp.getTargetBoolean().getSameDiff().inputs[i] = whileOp.getTargetBoolean().getSameDiff().variables().get(i); } } while (whileOp.getTargetBoolean().getArr().sumNumber().doubleValue() > 0) { //run the body execBody.exec(); val outputs = execBody.outputs; int cnt = 0; for (val out : execBody.outputs) { execBody.associateArrayWithVariable(out.getArr(), execBody.inputs[cnt]); whileOp.getTargetBoolean().getSameDiff().associateArrayWithVariable(out.getArr(), whileOp.getTargetBoolean().getSameDiff().inputs[cnt++]); } //update the predicate whileOp.getPredicateExecution().exec(); whileOp.incrementLoopCounter(); } } List outputs = new ArrayList<>(); val outputFuncArgs = new ArrayList<>(execBody.functionInstancesById.values()).get(execBody.functionInstancesById.values().size() - 1).outputVariables(); outputs.addAll(Arrays.asList(outputFuncArgs)); whileOp.setOutputVars(outputs.toArray(new SDVariable[outputs.size()])); ops.add(differentialFunction); } else { /** * Note: Need to accumulate gradients. * Multiply each value by the number of times looped. * This approximates accumulating the gradient * across a number of loop cycles. * We only compute the gradient for the internal loop once * and from that we multiply the gradient by 5. 
* */ Pair, List> mapListPair = whileOp.getLoopBodyExecution().execBackwards(); for (SDVariable variable : mapListPair.getFirst().keySet()) { variable.getArr().muli(whileOp.getNumLooped()); } } flowPath.markExecuted(differentialFunction.getOwnName(), true); } else if (differentialFunction instanceof CustomOp) { if (log.isTraceEnabled()) log.trace("Starting execution of CustomOp op"); DynamicCustomOp customOp = (DynamicCustomOp) differentialFunction; if (customOp.opName().equalsIgnoreCase("identity")) { val cleansed = args[0].replaceAll(":.*",""); val list = lists.get(cleansed); if (list != null) { lists.put(ownName, list); flowPath.markExecuted(differentialFunction.getOwnName(), true); ops.add(customOp); continue; } } try { customOp.populateInputsAndOutputsFromSameDiff(); } catch (Throwable t) { throw new RuntimeException("Error populating inputs and outputs for function \"" + differentialFunction.getOwnName() + "\" of type " + differentialFunction.getClass().getName(), t); } customOp.assertValidForExecution(); Nd4j.getExecutioner().exec(customOp); /* if (customOp instanceof LessThanOrEqual) { log.info("Step: {}; InnerCondition: {} <= {} = {}", exec_counter, customOp.getInputArgument(0), customOp.getInputArgument(1), customOp.getOutputArgument(0)); } else if (customOp instanceof LessThan) { log.info("Step: {}; OuterCondition: {} <= {} = {}", exec_counter, customOp.getInputArgument(0), customOp.getInputArgument(1), customOp.getOutputArgument(0)); } */ flowPath.markExecuted(differentialFunction.getOwnName(), true); ops.add(customOp); } else if (differentialFunction instanceof Op) { if (log.isTraceEnabled()) log.trace("Starting execution of Op op"); val inputs = getInputVariablesForFunction(differentialFunction); Op op = (Op) differentialFunction; String outVarName = ((BaseOp) op).outputVariable().getVarName(); // ops in differential function might have stale NDArrays used. 
we should renew them if(inputs != null && inputs.length > 0) { op.setX(inputs[0].getArr()); if (inputs.length == 2) op.setY(inputs[1].getArr()); } //Check output shape; allocate a new Z if required //For example, if minibatch size has changed since last op execution List outputShape = ((BaseOp)op).calculateOutputShape(); Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass()); //Update shape. DynamicCustomOp does this in populateInputsAndOutputsFromSameDiff(); for legacy ops, we'll do it here putOrUpdateShapeForVarName(outVarName, outputShape.get(0), true); INDArray z = op.z(); Preconditions.checkNotNull(z, "Could not get output array for op: %s", op.getClass()); if(!Arrays.equals(outputShape.get(0), z.shape())){ if(log.isTraceEnabled()){ log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}", op.getClass().getSimpleName(), Arrays.toString(z.shape()), Arrays.toString(outputShape.get(0))); } //Get output variable: String fnName = funcNames.get(i); String outputName = outgoingArgsReverse.get(fnName)[0]; SDVariable outputVar = getVariable(outputName); putOrUpdateShapeForVarName(outputName, outputShape.get(0), true); z = outputVar.storeAndAllocateNewArray(); op.setZ(z); } if(getArrForVarName(outVarName) != z){ //Also handles null case putOrUpdateArrayForVarName(outVarName, z); } if (differentialFunction.getDimensions() == null) Nd4j.getExecutioner().exec(op); else if (op.isExecSpecial()) { op.exec(); } else { int[] axes = differentialFunction.getDimensions(); if (differentialFunction instanceof Accumulation) { Accumulation accumulation = (Accumulation) differentialFunction; Nd4j.getExecutioner().exec(accumulation, axes); if (differentialFunction.outputVariable().getArr() == null) { val var = differentialFunction.outputVariables()[0]; updateVariable(var.getVarName(), accumulation.z()); updateShapeForVarName(var.getVarName(), 
accumulation.z().shape()); } } else if (differentialFunction instanceof BroadcastOp) { BroadcastOp broadcastOp = (BroadcastOp) differentialFunction; Nd4j.getExecutioner().exec(broadcastOp, axes); } else if (differentialFunction instanceof GradientOp) { Nd4j.getExecutioner().exec(op); } else if (differentialFunction instanceof IndexAccumulation) { IndexAccumulation indexAccumulation = (IndexAccumulation) differentialFunction; Nd4j.getExecutioner().exec(indexAccumulation, axes); } else if (differentialFunction instanceof TransformOp) { TransformOp t = (TransformOp) differentialFunction; Nd4j.getExecutioner().exec(t, axes); } } flowPath.markExecuted(differentialFunction.getOwnName(), true); ops.add(differentialFunction); } else { throw new IllegalStateException("Unknown function type: " + differentialFunction.getClass().getName()); } //debug // printFunction(differentialFunction); if (log.isTraceEnabled()) { log.trace("Execution completed for DifferentialFunction {} - {}", opName, differentialFunction.getOwnName()); SDVariable[] outputVars = differentialFunction.outputVariables(); for (int x = 0; x < outputVars.length; x++) { INDArray arr = outputVars[x].getArr(); String arrShape = (arr == null ? 
"" : Arrays.toString(arr.shape())); log.trace("--> output {} - {}: array shape {}", x, outputVars[x].getVarName(), arrShape); } } } if (log.isTraceEnabled()) { log.trace("Execution complete"); } val ret = new Pair<>(opMap, ops); exec_cache = ret; if (parent != null) { parent.exec_cache = exec_cache; } return ret; } /** * Print the given function for debugging (will not print functions) * * @param differentialFunction the function to print */ public void printFunction(DifferentialFunction differentialFunction) { if (!logExecution) return; if (differentialFunction instanceof SDVariable) return; StringBuilder argShapes = new StringBuilder(); for (val arg : differentialFunction.args()) { argShapes.append(" Variable " + arg.getVarName() + " Shape for " + Arrays.toString(arg.getShape())); } for (val func : differentialFunction.outputVariables()) { argShapes.append(" Output variable " + func.getVarName() + " is " + Arrays.toString(func.getShape())); } StringBuilder realShapes = new StringBuilder(); for (val arg : differentialFunction.args()) { realShapes.append(" Input shape for " + arg.getVarName() + " is " + Arrays. toString(getShapeForVarName(arg.getVarName()))); } for (val arg : differentialFunction.outputVariables()) { realShapes.append(" Output shape for " + arg.getVarName() + " is " + Arrays. toString(getShapeForVarName(arg.getVarName()))); } // log.info(realShapes.toString()); } /** * Permute indices for the samediff/dl4j format. * Due to the dl4j format being NCHW, this is a * simple routine for returning permute indices. * This is typically used for model import. 
* * @param dataFormat the data format to permute * @return the permuted indices */ public static int[] permuteDataFormatForSameDiff(String dataFormat, boolean weights) { val dl4jFormat = "NCHW"; dataFormat = dataFormat.toUpperCase(); //TF: filter_height, filter_width, in_channels, out_channels /** * N: filter_height * H: filter_width * W: in_channels * C: out_channels */ /** * * */ //DL4J: filter_height,out_channels,filter_width,in_channels // Weights should be: out channels, in channels, height,width int[] ret = new int[4]; if (weights) { ret[0] = dataFormat.indexOf('W'); ret[1] = dataFormat.indexOf('C'); ret[2] = dataFormat.indexOf('N'); ret[3] = dataFormat.indexOf('H'); return ret; } //NHWC //DL4J: NCHW for (int i = 0; i < dataFormat.length(); i++) { if (dl4jFormat.indexOf(dataFormat.charAt(i)) < 0) { throw new ND4JIllegalStateException("Illegal convolution data format string passed in " + dataFormat + " must be some variant of NCHW"); } } for (int i = 0; i < dl4jFormat.length(); i++) { ret[i] = dl4jFormat.indexOf(dataFormat.charAt(i)); } return ret; } /** * Update the {@link INDArray} * ndarray for the given variable name * * @param variableName the variable to update * @param arr the array to update with */ public void updateVariable(String variableName, INDArray arr) { if (!variableNameToArr.containsKey(variableName)) putArrayForVarName(variableName, arr); else updateArrayForVarName(variableName, arr); } protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) { int scopeName = bufferBuilder.createString(name); int flatNode = FlatNode.createFlatNode(bufferBuilder, scopeName, scopeName, OpType.LOGIC, 10, // hardcoded value 0, 0, 0, (byte) 0, 0, 0, 0, 0, -1, 0.0f, 0, 0); return flatNode; } /** * Note: INTENDED FOR DEVELOPER USE
* This method extract base variable name and output index (if exists) from raw variable name. * I.e: * - if variable name is "Unstack_2", result will be Pair("Unstack_2", 0) * - if variable name is "Unstack_2:12", result will be Pair("Unstack_2", 12) * * @param varName * @return */ public static Pair parseVariable(@NonNull String varName) { if (!varName.contains(":")) { return Pair.pairOf(varName, 0); } else { val split = varName.split(":"); val index = Integer.valueOf(split[split.length - 1]); if (split.length == 2) return Pair.pairOf(split[0], index); else { val builder = new StringBuilder(); for (int e = 0; e < split.length - 1; e++) { builder.append(split[e]); if (e < split.length - 2) builder.append(":"); } return Pair.pairOf(builder.toString(), index); } } } protected int asFlatNode(@NonNull DifferentialFunction node, @NonNull FlatBufferBuilder bufferBuilder, List variables, Map reverseMap, Map forwardMap, Map framesMap, AtomicInteger idCounter) { val opName = node.opName(); val hash = getOpNum(node.opName(), node.opType()); //log.info("Exporting node: [{}:<{}> ; OpType: {}; Hash/opNum: {}]", node.opName(), node.tensorflowName(), node.opType(), hash); double[] extras = node.getExtraArgs() != null ? 
new double[node.getExtraArgs().length] : new double[0]; for (int e = 0; e < extras.length; e++) { extras[e] = ((Number) node.getExtraArgs()[e]).doubleValue(); } long[] extraBits = null; if (node.opType() == Op.Type.CUSTOM) { DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) node; extraBits = dynamicCustomOp.iArgs(); } else if (node instanceof Enter) { // in case of Enter node we'll be storing unique frame reference val frameName = ((Enter) node).getFrameName(); if (!framesMap.containsKey(frameName)) framesMap.put(frameName, idCounter.incrementAndGet()); extraBits = new long[]{framesMap.get(frameName).intValue()}; } else extraBits = new long[]{}; val inPaired = new ArrayList(); int[] outputIds = null; SDVariable[] outputVertexId = null; try { outputVertexId = node.outputVariables(); outputIds = new int[outputVertexId.length]; for (int i = 0; i < outputIds.length; i++) { outputIds[i] = variables.indexOf(outputVertexId[i]); } } catch (ND4UnresolvedOutputVariables e) { outputIds = new int[0]; outputVertexId = null; } catch (Exception e) { throw new ND4JIllegalStateException(e); } val inputs = node.args(); log.trace(""); for (val input : inputs) { //for (int i = 0; i < outputVertexId.length; i++) { val pair = parseVariable(input.getVarName()); if (!reverseMap.containsKey(pair.getFirst())) { if (pair.getFirst().contains("NextIteration")) { // forward declaration: Merge node in case of loop will be referring to NextIteration node, which wasn't announced yet int fwdNodeId = idCounter.incrementAndGet(); forwardMap.put(pair.getFirst(), fwdNodeId); reverseMap.put(pair.getFirst(), fwdNodeId); } else { throw new ND4JIllegalStateException("Unknown variable used in input: [" + pair.getFirst() + "]"); } } int nodeId = reverseMap.get(pair.getFirst()); int outputIndex = pair.getSecond(); inPaired.add(IntPair.createIntPair(bufferBuilder, nodeId, outputIndex)); //} } log.debug("Own Name: {}", node.getOwnName()); int ownId = forwardMap.containsKey(node.getOwnName()) ? 
forwardMap.get(node.getOwnName()) : idCounter.incrementAndGet(); reverseMap.put(node.getOwnName(), ownId); val dims = node.opType() == Op.Type.REDUCE && node.getDimensions() != null ? node.getDimensions() : new int[]{}; // TODO: Adam, just put your props here, instead of empty list, and they will be saved List props = new ArrayList<>(); int properties = FunctionProperties.asFlatProperties(bufferBuilder, props); int nodesIn = FlatNode.createInputVector(bufferBuilder, new int[]{}); int nodesInPaired = FlatNode.createInputPairedVector(bufferBuilder, Ints.toArray(inPaired)); int nodesOut = FlatNode.createOutputVector(bufferBuilder, outputIds); int extraz = FlatNode.createExtraParamsVector(bufferBuilder, extras); int integerArgs = FlatNode.createExtraIntegerVector(bufferBuilder, extraBits); int dimensions = FlatNode.createDimensionsVector(bufferBuilder, dims); int fname = bufferBuilder.createString( outputVertexId == null || outputVertexId.length < 1 || outputVertexId[0] == null ? "" : outputVertexId[0].getVarName()); int scopeName = bufferBuilder.createString(""); if (node.opType() == null) log.warn("Null-op node: {}", node); int flatNode = FlatNode.createFlatNode( bufferBuilder, ownId, fname, getFlatOpType(node.opType()), hash, properties, nodesIn, nodesInPaired, (byte) 0, nodesOut, extraz, integerArgs, dimensions, -1, node.opType() == Op.Type.SCALAR && node.getScalarValue() != null ? 
node.getScalarValue().floatValue() : 0.0f, 0, scopeName); return flatNode; } /** * This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and * all arrays as a ByteBuffer containing the FlatBuffers format data * * @param configuration - ExecutorConfiguration to be embedded into serialized graph * @return a ByteBuffer holding the exported FlatBuffers representation of the graph */ public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration configuration) { Nd4j.getExecutioner().commit(); FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(1024); val idCounter = new AtomicInteger(0); val flatVariables = new ArrayList(); val flatOffsets = new ArrayList(); val flatNodes = new ArrayList(); // first of all we build VariableSpace dump List variableList = new ArrayList<>(variables()); val reverseMap = new LinkedHashMap(); val forwardMap = new LinkedHashMap(); val framesMap = new LinkedHashMap(); int idx = 0; for (val variable : variables()) { log.debug("Exporting variable: [{}]", variable.getVarName()); if (variable.getArr() == null || variable.getShape() == null) { //putArrayForVarName(variable.getVarName(), Nd4j.scalar(1.0)); //addAsPlaceHolder(variable.getVarName()); continue; } val pair = parseVariable(variable.getVarName()); reverseMap.put(pair.getFirst(), idCounter.incrementAndGet()); log.debug("Adding [{}] as [{}]", pair.getFirst(), idCounter.get()); val arr = variable.getArr(); int name = bufferBuilder.createString(variable.getVarName()); int array = arr.toFlatArray(bufferBuilder); int id = IntPair.createIntPair(bufferBuilder, idCounter.get(), 0); int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, 0, array, -1); flatVariables.add(flatVariable); } //add functions for (val func : functionInstancesById.values()) { flatNodes.add(asFlatNode(func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter)); } // we're dumping scopes now for (val scope : 
sameDiffFunctionInstances.entrySet()) { flatNodes.add(asFlatNode(scope.getKey(), scope.getValue(), bufferBuilder)); val currVarList = new ArrayList(scope.getValue().variables()); // converting all ops from node for (val node : scope.getValue().variables()) { INDArray arr = node.getArr(); if (arr == null) { //val otherArr = Nd4j.scalar(1.0); //scope.getValue().putArrayForVarName(node.getVarName(), otherArr); //log.warn("Adding placeholder for export for var name {}", node.getVarName()); //arr = otherArr; continue; } int name = bufferBuilder.createString(node.getVarName()); int array = arr.toFlatArray(bufferBuilder); int id = IntPair.createIntPair(bufferBuilder, ++idx, 0); val pair = parseVariable(node.getVarName()); reverseMap.put(pair.getFirst(), idx); log.debug("Adding [{}] as [{}]", pair.getFirst(), idx); int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, 0, array, -1); flatVariables.add(flatVariable); } //add functions for (val func : scope.getValue().functionInstancesById.values()) { flatNodes.add(asFlatNode(func, bufferBuilder, currVarList, reverseMap, forwardMap, framesMap, idCounter)); } } int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets)); int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables)); int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes)); int fg = FlatGraph.createFlatGraph(bufferBuilder, 119, variablesOffset, nodesOffset, outputsOffset, configuration.getFlatConfiguration(bufferBuilder)); bufferBuilder.finish(fg); synchronized (this) { if (this.reverseMap == null) this.reverseMap = reverseMap; } return bufferBuilder.dataBuffer(); } /** * This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and * all arrays as a ByteBuffer containing the FlatBuffers format data * * @return a ByteBuffer holding the exported FlatBuffers representation of the graph */ public 
ByteBuffer asFlatBuffers() { val configuration = ExecutorConfiguration.builder() .outputMode(OutputMode.VARIABLE_SPACE) .executionMode(org.nd4j.autodiff.execution.conf.ExecutionMode.SEQUENTIAL) .profilingMode(OpExecutioner.ProfilingMode.DISABLED) .gatherTimings(true) .build(); return asFlatBuffers(configuration); } /** * This method just converts enums * * @param val * @return */ public static ByteOrder getOrderFromByte(byte val) { if (val == org.nd4j.graph.ByteOrder.LE) return ByteOrder.LITTLE_ENDIAN; else return ByteOrder.BIG_ENDIAN; } /** * This method returns current byte order for this JVM as libnd4j enum * * @return */ public static byte getOrderAsByte() { if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) return org.nd4j.graph.ByteOrder.BE; else return org.nd4j.graph.ByteOrder.LE; } /** * This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later * * @param file File to save the FlatBuffers serialized graph (including arrays) to */ public void asFlatFile(@NonNull File file) throws IOException { val fb = asFlatBuffers(); val offset = fb.position(); val array = fb.array(); try (val fos = new FileOutputStream(file); val bos = new BufferedOutputStream(fos); val dos = new DataOutputStream(bos)) { dos.write(array, offset, array.length - offset); } } /** * This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later * * @param file File to save the FlatBuffers serialized graph (including arrays) to */ public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration) throws IOException { val fb = asFlatBuffers(configuration); val offset = fb.position(); val array = fb.array(); try (val fos = new FileOutputStream(file); val bos = new BufferedOutputStream(fos); val dos = new DataOutputStream(bos)) { dos.write(array, offset, array.length - offset); } } /** * This method returns a text representation of the "flattened" graph. 
* * @return String representation of the graph * @see #summary() */ public String asFlatPrint() { val sb = new StringBuilder(); val fb = asFlatBuffers(); val graph = FlatGraph.getRootAsFlatGraph(fb); sb.append("\nExternal variables:\n\n"); for (int e = 0; e < graph.variablesLength(); e++) { val var = graph.variables(e); val ndarray = Nd4j.createFromFlatArray(var.ndarray()); sb.append(var.id().first()) .append(":<").append(var.name()).append("> ") .append(Arrays.toString(ndarray.shapeInfoDataBuffer().asInt())).append("; Values: ").append(Arrays.toString(ndarray.data().asFloat())).append(";\n"); } val map = Nd4j.getExecutioner().getCustomOperations(); sb.append("\nOps sequence:\n\n"); for (int e = 0; e < graph.nodesLength(); e++) { val node = graph.nodes(e); log.info("{}:<{}>", node.id(), node.name()); sb.append(node.id()) .append(":<").append(node.name()).append("> ").append(SameDiff.getTypeFromByte(node.opType())); if (SameDiff.getTypeFromByte(node.opType()) != Op.Type.CUSTOM) sb.append(": ").append(node.opNum()); else { val keys = map.keySet(); String opName = null; for (val k : keys) { val d = map.get(k); if (d.getHash() == node.opNum()) opName = k; } if (opName == null) opName = "unknown"; sb.append(": ").append(opName); } sb.append("; Inputs: {"); for (int i = 0; i < node.inputPairedLength(); i++) { val pair = node.inputPaired(i); sb.append("[").append(pair.first()).append(":").append(pair.second()).append("]"); if (i < node.inputPairedLength() - 1) sb.append(", "); } sb.append("};"); sb.append(" OpNum: {").append(node.opNum()).append("};"); sb.append("\n"); } return sb.toString(); } /** * This method converts enums for DataType * * @param val * @return */ public static DataBuffer.Type getDataTypeFromByte(byte val) { if (val == DataType.FLOAT) return DataBuffer.Type.FLOAT; else if (val == DataType.DOUBLE) return DataBuffer.Type.DOUBLE; else if (val == DataType.HALF) return DataBuffer.Type.HALF; throw new UnsupportedOperationException("Unsupported DataType: [" + 
val + "]"); } /** * This method converts enums for DataType * * @param type * @return */ public static byte getDataTypeAsByte(DataBuffer.Type type) { switch (type) { case FLOAT: return DataType.FLOAT; case DOUBLE: return DataType.DOUBLE; case HALF: return DataType.HALF; case INT: return DataType.INT32; case LONG: return DataType.INT64; default: throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]"); } } /** * This method return operation ID for given op name/type pair. * * @param name * @param type * @return */ public static long getOpNum(String name, Op.Type type) { if (type == Op.Type.LOOP) { return 0; } else if (type == Op.Type.RETURN) { return 40; } else if (type == Op.Type.IF) { return 30; } else if (type == Op.Type.CONDITIONAL) { return 10; } else if (type == Op.Type.MERGE) { return 60L; } else if (type == Op.Type.LOOP_COND) { return 70L; } else if (type == Op.Type.NEXT_ITERATION) { return 80L; } else if (type == Op.Type.EXIT) { return 90L; } else if (type == Op.Type.ENTER) { return 100L; } else if (type == Op.Type.CUSTOM) { val name2 = Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase()); if (name2 == null) { val name3 = Nd4j.getExecutioner().getCustomOperations().get(name); if (name3 == null) return 0; else return name3.getHash(); } else return name2.getHash(); //return Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase()).getHash(); } else return (long) Nd4j.getOpFactory().getOpNumByName(name); } /** * This method converts enums for Op.Type * * @param type Byte representing the op type * @return Op type */ public static Op.Type getTypeFromByte(byte type) { switch (type) { case OpType.SCALAR: return Op.Type.SCALAR; case OpType.BROADCAST: return Op.Type.BROADCAST; case OpType.TRANSFORM: return Op.Type.TRANSFORM; case OpType.ACCUMULATION: return Op.Type.REDUCE; case OpType.ACCUMULATION3: return Op.Type.REDUCE3; case OpType.INDEX_ACCUMULATION: return Op.Type.INDEXREDUCE; case OpType.RANDOM: 
return Op.Type.RANDOM; case OpType.LOGIC: return Op.Type.META; case OpType.CUSTOM: return Op.Type.CUSTOM; case OpType.SHAPE: return Op.Type.SHAPE; case OpType.PAIRWISE: return Op.Type.PAIRWISE; case OpType.SUMMARYSTATS: return Op.Type.SUMMARYSTATS; default: throw new UnsupportedOperationException("Unknown op type passed in: " + type); } } /** * This method converts an Op.Type to it's corresponding byte value * * @param type type to convert * @return Byte representing the op type */ public static byte getFlatOpType(Op.Type type) { switch (type) { case SCALAR: return OpType.SCALAR; case BROADCAST: return OpType.BROADCAST; case TRANSFORM: case SPECIAL: return OpType.TRANSFORM; case REDUCE: return OpType.ACCUMULATION; case REDUCE3: return OpType.ACCUMULATION3; case INDEXREDUCE: return OpType.INDEX_ACCUMULATION; case RANDOM: return OpType.RANDOM; case MERGE: case CONDITIONAL: case LOOP: case RETURN: case ENTER: case EXIT: case NEXT_ITERATION: case LOOP_COND: case IF: return OpType.LOGIC; case CUSTOM: return OpType.CUSTOM; case SHAPE: return OpType.SHAPE; case PAIRWISE: return OpType.PAIRWISE; case SUMMARYSTATS: return OpType.SUMMARYSTATS; default: throw new UnsupportedOperationException("Unknown op type passed in: " + type); } } /** * Generate and return a String representation of the current SameDiff instance
* Reports variables, ops, SameDiff function instances, and (where possible) array shapes.
* For ops, the input and output variables are reported.
* For variables, the ops that they are inputs to - or outputs of - are also reported * * @return A String representation of the SameDiff instance */ public String summary() { Map varMap = variableMap(); DifferentialFunction[] functions = functions(); int countVarsWithArrays = 0; for (String s : varMap.keySet()) { if (getArrForVarName(s) != null) { countVarsWithArrays++; } } StringBuilder sb = new StringBuilder(); String format = "%-25s%-20s"; sb.append("--- Summary ---\n"); sb.append(String.format(format, "Variables:", varMap.size())).append(" (").append(countVarsWithArrays).append(" with arrays)").append("\n") .append(String.format(format, "Functions:", functions.length)).append("\n") .append(String.format(format, "SameDiff Function Defs:", sameDiffFunctionInstances.size())) .append("\n\n"); sb.append("--- Variables ---\n"); //Work out which function - if any - this arg is an output of... Map outputOfFn = new HashMap<>(); int maxLengthOutputOf = 22; //Length of "- Output Of Function -" int maxLengthOfName = 8; //Length of "- Name -" for (String s : varMap.keySet()) { String outputOf = null; for (Map.Entry dfToArgs : outgoingArgsReverse.entrySet()) { if (dfToArgs.getValue() != null && ArrayUtils.contains(dfToArgs.getValue(), s)) { outputOf = dfToArgs.getKey(); break; } } if (outputOf == null) { outputOf = ""; } else { DifferentialFunction d = getFunctionById(outputOf); outputOf = d.getOwnName() + "(" + d.opName() + ")"; } outputOfFn.put(s, outputOf); maxLengthOutputOf = Math.max(maxLengthOutputOf, outputOf.length()); maxLengthOfName = Math.max(maxLengthOfName, s.length()); } maxLengthOutputOf += 2; maxLengthOfName += 2; //Create the output for values: format = "%-" + maxLengthOfName + "s%-20s%-" + maxLengthOutputOf + "s%-20s"; sb.append(String.format(format, "- Name -", "- Array Shape -", "- Output Of Function -", "- Inputs To Functions -")).append("\n"); for (String s : varMap.keySet()) { INDArray arr = getArrForVarName(s); String arrayShape = "-"; if (arr != 
null) { arrayShape = Arrays.toString(arr.shape()); } List dfs = functionsArgsFor.get(s); String dfArrStr = ""; if (dfs != null) { String[] dfArr = new String[dfs.size()]; for (int i = 0; i < dfs.size(); i++) { dfArr[i] = dfs.get(i).getOwnName(); } dfArrStr = Arrays.toString(dfArr); } String outputOfStr = outputOfFn.get(s); sb.append(String.format(format, s, arrayShape, outputOfStr, dfArrStr)).append("\n"); } sb.append("\n\n--- Functions ---\n"); //First: work out the amount of space we need for inputs and outputs... List dfInputStr = new ArrayList<>(); List dfOutputStr = new ArrayList<>(); int maxInLength = 10; //Length of "- Inputs -" int maxOutLength = 11; //Length of "- Outputs -" int maxOpNameLength = 17; //Default to min of 17 - length of "- Function Name -" int maxDfClassNameLength = 10; //Default to min of 10 for (DifferentialFunction df : functions) { String[] argNames = df.argNames(); String[] outNames = df.outputVariablesNames(); String argStr = Arrays.toString(argNames); String outStr = Arrays.toString(outNames); maxInLength = Math.max(maxInLength, argStr.length()); maxOutLength = Math.max(maxOutLength, outStr.length()); dfInputStr.add(argStr); dfOutputStr.add(outStr); String name = df.getOwnName() == null ? df.opName() : df.getOwnName(); maxOpNameLength = Math.max(maxOpNameLength, name.length()); maxDfClassNameLength = Math.max(maxDfClassNameLength, df.getClass().getSimpleName().length()); } //Extra padding space maxInLength += 2; maxOutLength += 2; maxOpNameLength += 2; maxDfClassNameLength += 2; format = "%-5s%-" + maxOpNameLength + "s%-" + maxDfClassNameLength + "s%-" + maxInLength + "s%-" + maxOutLength + "s"; sb.append(String.format(format, "", "- Function Name -", "- Op -", "- Inputs -", "- Outputs -")).append("\n"); for (int i = 0; i < functions.length; i++) { DifferentialFunction df = functions[i]; String fnName = df.getOwnName() == null ? 
df.opName() : df.getOwnName(); sb.append(String.format(format, String.valueOf(i), fnName, df.getClass().getSimpleName(), dfInputStr.get(i), dfOutputStr.get(i))).append("\n"); } if (sameDiffFunctionInstances.size() > 0) { sb.append("\n\n--- SameDiff Defined Functions ---\n"); format = "%-20s%-15s%-15s%-15s"; sb.append(String.format(format, "- Name -", "- Variables -", "- Functions -", "- Fn Defs -")).append("\n"); for (Map.Entry e : sameDiffFunctionInstances.entrySet()) { SameDiff sd = e.getValue(); int vars = sd.variableMap().size(); int fns = (sd.functions() == null ? 0 : sd.functions().length); int defFns = sd.definedFunctionNames().size(); sb.append(String.format(format, e.getKey(), String.valueOf(vars), String.valueOf(fns), String.valueOf(defFns))).append("\n"); } } return sb.toString(); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy