Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it as often as you want.
org.nd4j.autodiff.samediff.SDVariable Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff;
import com.google.common.annotations.VisibleForTesting;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.builder.Diff;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.*;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
/**
*
* A variable representing a component within a
* {@link SameDiff} graph.
*
* SDVariable is used for symbolic declaration
* of equations.
*
* @author Adam Gibson
*
*/
@Data
@NoArgsConstructor
@Slf4j
public class SDVariable extends DifferentialFunction implements Serializable {
// Name of this variable; it is the key passed to the sameDiff.*ForVarName(...) lookups in this class.
@Getter
@Setter
private String varName;
// Scheme used to create the backing array when one has to be allocated.
// The builder constructor substitutes a C-ordered ZeroInitScheme when none is supplied.
@Getter
@Setter
protected WeightInitScheme weightInitScheme;
// autogen_tag::sdvars::start
/**
 * Creates a variable and registers its shape information with the given {@link SameDiff} instance.
 * <p>
 * When {@code shape} is {@code null}, or contains a negative dimension, the variable name is
 * registered as a placeholder (in the latter case the shape is additionally recorded through
 * {@code setOriginalPlaceHolderShape}). Otherwise the shape is stored for the variable name.
 *
 * @param varName          name of the variable
 * @param sameDiff         SameDiff instance this variable belongs to
 * @param shape            shape of the variable, or {@code null} when it is not known
 * @param weightInitScheme initialization scheme; a C-ordered {@link ZeroInitScheme} is used when {@code null}
 */
@Builder
private SDVariable(String varName,
                   SameDiff sameDiff,
                   long[] shape,
                   WeightInitScheme weightInitScheme) {
    super(sameDiff, new Object[]{});
    this.varName = varName;
    // C order is the default in all cases where the caller gave no scheme.
    this.weightInitScheme = (weightInitScheme != null) ? weightInitScheme : new ZeroInitScheme('c');

    if (shape == null) {
        sameDiff.addAsPlaceHolder(varName);
    } else {
        boolean hasNegativeDimension = false;
        for (long dimension : shape) {
            if (dimension < 0) {
                hasNegativeDimension = true;
                break;
            }
        }
        if (hasNegativeDimension) {
            sameDiff.addAsPlaceHolder(varName);
            sameDiff.setOriginalPlaceHolderShape(varName, shape);
        } else {
            sameDiff.putShapeForVarName(varName, shape);
        }
    }
    this.sameDiff = sameDiff;
}
/**
 * Asks the owning SameDiff instance whether this variable's name is registered as a placeholder.
 *
 * @return the result of {@code sameDiff.isPlaceHolder(varName)}
 */
public boolean isPlaceHolder() {
    final boolean placeHolder = sameDiff.isPlaceHolder(varName);
    return placeHolder;
}
/**
 * @return the fixed op name {@code "variable"}
 */
@Override
public String opName() {
    final String name = "variable";
    return name;
}
/**
 * @return a new single-element array containing this variable
 */
@Override
public SDVariable[] outputVariables() {
    final SDVariable[] self = {this};
    return self;
}
/**
 * @return this variable itself
 */
@Override
public SDVariable arg() {
    final SDVariable self = this;
    return self;
}
/**
 * @return a new single-element array containing this variable
 */
@Override
public SDVariable[] args() {
    final SDVariable[] self = {this};
    return self;
}
/**
 * @param baseName not used by this implementation
 * @return a new single-element array containing this variable
 */
@Override
public SDVariable[] outputVariables(String baseName) {
    final SDVariable[] self = {this};
    return self;
}
/**
 * Intentionally does nothing in this class: the body is empty and all arguments are ignored.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) {
}
/**
 * Intentionally does nothing in this class: the body is empty and all arguments are ignored.
 */
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) {
}
/**
 * Allocates a new array for this variable using its {@link WeightInitScheme} and associates it
 * with this variable in the owning SameDiff instance.
 * <p>
 * If {@link #getArr()} already yields an array whose shape equals the shape registered for this
 * variable's name, that array is returned and nothing new is allocated.
 *
 * @return the existing array when its shape matches the registered shape, otherwise the newly
 *         created array
 * @throws ND4JIllegalStateException if the variable name is {@code null}, or if no shape is
 *                                   registered for the variable
 */
public INDArray storeAndAllocateNewArray() {
    // Fail fast: every lookup below is keyed on the variable name.
    if (varName == null) {
        throw new ND4JIllegalStateException("Unable to store array for null variable name!");
    }

    val shape = sameDiff.getShapeForVarName(getVarName());
    final INDArray currArr = getArr();
    if (currArr != null && Arrays.equals(currArr.shape(), shape)) {
        // Reuse the array we already fetched instead of looking it up a second time.
        return currArr;
    }

    if (shape == null) {
        throw new ND4JIllegalStateException("Unable to allocate new array. No shape found for variable " + varName);
    }

    val arr = getWeightInitScheme().create(shape);
    sameDiff.associateArrayWithVariable(arr, this);
    if (log.isTraceEnabled()) {
        log.trace("Generated and stored new array for variable \"{}\": old shape: {}, new shape {}", getVarName(),
                (currArr == null ? "null" : Arrays.toString(currArr.shape())), Arrays.toString(arr.shape()));
    }
    return arr;
}
/**
 * Returns the array associated with this {@link SDVariable}, without enforcing that one exists.
 * <p>
 * Equivalent to {@link #getArr(boolean) getArr(false)}: an array is lazily created from the
 * registered shape and the {@link WeightInitScheme} when possible, and {@code null} is returned
 * when the shape is unknown.
 *
 * @return the {@link INDArray} associated with this variable, or {@code null}
 */
public INDArray getArr() {
    final boolean enforceExistence = false;
    return getArr(enforceExistence);
}
// autogen_tag::sdvars::end
/**
 * Returns the array associated with this {@link SDVariable}, creating it lazily if necessary.
 * <p>
 * Resolution order:
 * <ol>
 *   <li>an array already registered for this variable's name is returned as-is;</li>
 *   <li>if a scalar value is set and the shape has exactly one element, an array filled with
 *       that value is created and associated with this variable;</li>
 *   <li>if no shape is registered, either {@code null} is returned
 *       ({@code enforceExistence == false}) or an {@link IllegalStateException} is thrown
 *       ({@code enforceExistence == true});</li>
 *   <li>otherwise an array is created from the registered shape with the
 *       {@link WeightInitScheme} and associated with this variable.</li>
 * </ol>
 * NOTE(review): {@link #getShape()} calls {@link #getArr()} when no shape is registered, so a
 * variable with a scalar value but neither shape nor array appears to recurse between the two
 * methods - confirm whether that state can occur.
 *
 * @param enforceExistence whether a missing, non-computable array is an error
 * @return the {@link INDArray} associated with this variable, or {@code null}
 */
public INDArray getArr(boolean enforceExistence){
    if (sameDiff.arrayAlreadyExistsForVarName(getVarName())) {
        return sameDiff.getArrForVarName(getVarName());
    }

    // Scalar constant (zero or one, typically): materialize it directly from the stored value.
    final boolean scalarConstant = getScalarValue() != null && ArrayUtil.prod(getShape()) == 1;
    if (scalarConstant) {
        INDArray scalarArr = Nd4j.valueArrayOf(getShape(), getScalarValue().doubleValue());
        sameDiff.associateArrayWithVariable(scalarArr, this);
        if (log.isTraceEnabled()) {
            log.trace("getArr() for variable \"{}\" allocated new scalar array: shape {}", getVarName(), Arrays.toString(getShape()));
        }
        return sameDiff.getArrForVarName(getVarName());
    }

    final long[] registeredShape = sameDiff.getShapeForVarName(getVarName());
    if (registeredShape == null) {
        if (enforceExistence) {
            throw new IllegalStateException("Cannot get array for SDVariable \"" + getVarName() + "\": no array has" +
                    " been defined, and array shape cannot be calculated");
        }
        if (log.isTraceEnabled()) {
            log.trace("SDVariable.getArr(): could not get array for variable {}: shape is null", getVarName());
        }
        return null;
    }

    INDArray allocated = getWeightInitScheme().create(registeredShape);
    sameDiff.associateArrayWithVariable(allocated, this);
    if (log.isTraceEnabled()) {
        log.trace("getArr() for variable \"{}\" allocated new array with shape {}", getVarName(), Arrays.toString(getShape()));
    }
    return sameDiff.getArrForVarName(getVarName());
}
/**
 * Shorter alias for {@link #getGradient()}.
 *
 * @return whatever {@link #getGradient()} returns
 */
public SDVariable gradient() {
    final SDVariable grad = getGradient();
    return grad;
}
/**
 * Looks up the gradient variable for this variable in the owning SameDiff instance.
 * The lookup is performed on every call, keyed by the current variable name.
 *
 * @return the result of {@code sameDiff.getGradForVariable(getVarName())}
 */
public SDVariable getGradient() {
    final String name = getVarName();
    return sameDiff.getGradForVariable(name);
}
/**
 * Always fails: this class does not support differentiation of a plain variable.
 *
 * @param f1 ignored
 * @return never returns normally
 * @throws ND4JIllegalStateException always
 */
@Override
public List doDiff(List f1) {
    final String message = "Unable to differentiate a variable! Must be a function.";
    throw new ND4JIllegalStateException(message);
}
/**
 * Returns the shape of this variable.
 * <p>
 * The shape registered with SameDiff for this variable's name takes precedence. When none is
 * registered, the shape of the array returned by {@link #getArr()} is used instead.
 *
 * @return the shape, or {@code null} when neither a registered shape nor an array is available
 */
public long[] getShape() {
    final long[] registered = sameDiff.getShapeForVarName(getVarName());
    if (registered != null) {
        return registered;
    }
    final INDArray backing = getArr();
    return (backing == null) ? null : backing.shape();
}
/**
 * Delegates to {@code sameDiff.var(this)}.
 *
 * @return the variable returned by SameDiff
 */
public SDVariable dup() {
    final SDVariable copy = sameDiff.var(this);
    return copy;
}
/**
 * Delegates to {@code sameDiff.scalarSet(this, value)}.
 *
 * @param value value handed to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable assign(Number value){
    final SDVariable assigned = sameDiff.scalarSet(this, value);
    return assigned;
}
/**
 * Negate op, created through the function factory ({@code f().neg(this)}).
 *
 * @return Negated variable
 */
public SDVariable neg(){
    final SDVariable negated = f().neg(this);
    return negated;
}
/**
 * Negate op, created through {@code sameDiff.neg(name, this)}.
 *
 * @param name name passed through to SameDiff
 * @return Negated variable
 */
public SDVariable neg(String name){
    final SDVariable negated = sameDiff.neg(name, this);
    return negated;
}
/**
 * Convenience overload of {@link #lt(String, double)} that passes {@code null} as the name.
 *
 * @param value scalar operand
 * @return result of the named overload
 */
public SDVariable lt(double value){
    final String unnamed = null;
    return lt(unnamed, value);
}
/**
 * Delegates to {@code sameDiff.lt(name, this, value)}.
 *
 * @param name  name passed through to SameDiff
 * @param value scalar operand
 * @return the variable returned by SameDiff
 */
public SDVariable lt(String name, double value){
    final SDVariable result = sameDiff.lt(name, this, value);
    return result;
}
/**
 * Convenience overload of {@link #lte(String, double)} that passes {@code null} as the name.
 *
 * @param value scalar operand
 * @return result of the named overload
 */
public SDVariable lte(double value){
    final String unnamed = null;
    return lte(unnamed, value);
}
/**
 * Delegates to {@code sameDiff.lte(name, this, value)}.
 *
 * @param name  name passed through to SameDiff
 * @param value scalar operand
 * @return the variable returned by SameDiff
 */
public SDVariable lte(String name, double value){
    final SDVariable result = sameDiff.lte(name, this, value);
    return result;
}
/**
 * Convenience overload of {@link #gt(String, double)} that passes {@code null} as the name.
 *
 * @param value scalar operand
 * @return result of the named overload
 */
public SDVariable gt(double value){
    final String unnamed = null;
    return gt(unnamed, value);
}
/**
 * Delegates to {@code sameDiff.gt(name, this, value)}.
 *
 * @param name  name passed through to SameDiff
 * @param value scalar operand
 * @return the variable returned by SameDiff
 */
public SDVariable gt(String name, double value){
    final SDVariable result = sameDiff.gt(name, this, value);
    return result;
}
/**
 * Convenience overload of {@link #gte(String, double)} that passes {@code null} as the name.
 *
 * @param value scalar operand
 * @return result of the named overload
 */
public SDVariable gte(double value){
    final String unnamed = null;
    return gte(unnamed, value);
}
/**
 * Delegates to {@code sameDiff.gte(name, this, value)}.
 *
 * @param name  name passed through to SameDiff
 * @param value scalar operand
 * @return the variable returned by SameDiff
 */
public SDVariable gte(String name, double value){
    final SDVariable result = sameDiff.gte(name, this, value);
    return result;
}
/**
 * Convenience overload of {@link #eq(String, double)} that passes {@code null} as the name.
 *
 * @param value scalar operand
 * @return result of the named overload
 */
public SDVariable eq(double value){
    final String unnamed = null;
    return eq(unnamed, value);
}
/**
 * Delegates to {@code sameDiff.eq(name, this, value)}.
 *
 * @param name  name passed through to SameDiff
 * @param value scalar operand
 * @return the variable returned by SameDiff
 */
public SDVariable eq(String name, double value){
    final SDVariable result = sameDiff.eq(name, this, value);
    return result;
}
/**
 * Convenience overload of {@link #neq(String, double)} that passes {@code null} as the name.
 *
 * @param value scalar operand
 * @return result of the named overload
 */
public SDVariable neq(double value){
    final String unnamed = null;
    return neq(unnamed, value);
}
/**
 * Delegates to {@code sameDiff.neq(name, this, value)}.
 *
 * @param name  name passed through to SameDiff
 * @param value scalar operand
 * @return the variable returned by SameDiff
 */
public SDVariable neq(String name, double value){
    final SDVariable result = sameDiff.neq(name, this, value);
    return result;
}
/**
 * Convenience overload of {@link #lt(String, SDVariable)} that passes {@code null} as the name.
 *
 * @param other variable operand
 * @return result of the named overload
 */
public SDVariable lt(SDVariable other){
    final String unnamed = null;
    return lt(unnamed, other);
}
/**
 * Delegates to {@code sameDiff.lt(name, this, other)}.
 *
 * @param name  name passed through to SameDiff
 * @param other variable operand
 * @return the variable returned by SameDiff
 */
public SDVariable lt(String name, SDVariable other){
    final SDVariable result = sameDiff.lt(name, this, other);
    return result;
}
/**
 * Convenience overload of {@link #lte(String, SDVariable)} that passes {@code null} as the name.
 *
 * @param other variable operand
 * @return result of the named overload
 */
public SDVariable lte(SDVariable other){
    final String unnamed = null;
    return lte(unnamed, other);
}
/**
 * Delegates to {@code sameDiff.lte(name, this, other)}.
 *
 * @param name  name passed through to SameDiff
 * @param other variable operand
 * @return the variable returned by SameDiff
 */
public SDVariable lte(String name, SDVariable other){
    final SDVariable result = sameDiff.lte(name, this, other);
    return result;
}
/**
 * Convenience overload of {@link #gt(String, SDVariable)} that passes {@code null} as the name.
 *
 * @param other variable operand
 * @return result of the named overload
 */
public SDVariable gt(SDVariable other){
    final String unnamed = null;
    return gt(unnamed, other);
}
/**
 * Delegates to {@code sameDiff.gt(name, this, other)}.
 *
 * @param name  name passed through to SameDiff
 * @param other variable operand
 * @return the variable returned by SameDiff
 */
public SDVariable gt(String name, SDVariable other){
    final SDVariable result = sameDiff.gt(name, this, other);
    return result;
}
/**
 * Convenience overload of {@link #gte(String, SDVariable)} that passes {@code null} as the name.
 *
 * @param other variable operand
 * @return result of the named overload
 */
public SDVariable gte(SDVariable other){
    final String unnamed = null;
    return gte(unnamed, other);
}
/**
 * Delegates to {@code sameDiff.gte(name, this, other)}.
 *
 * @param name  name passed through to SameDiff
 * @param other variable operand
 * @return the variable returned by SameDiff
 */
public SDVariable gte(String name, SDVariable other){
    final SDVariable result = sameDiff.gte(name, this, other);
    return result;
}
/**
 * Convenience overload of {@link #eq(String, SDVariable)} that passes {@code null} as the name.
 *
 * @param other variable operand
 * @return result of the named overload
 */
public SDVariable eq(SDVariable other){
    final String unnamed = null;
    return eq(unnamed, other);
}
/**
 * Delegates to {@code sameDiff.eq(name, this, other)}.
 *
 * @param name  name passed through to SameDiff
 * @param other variable operand
 * @return the variable returned by SameDiff
 */
public SDVariable eq(String name, SDVariable other){
    final SDVariable result = sameDiff.eq(name, this, other);
    return result;
}
/**
 * Convenience overload of {@link #neq(String, SDVariable)} that passes {@code null} as the name.
 *
 * @param other variable operand
 * @return result of the named overload
 */
public SDVariable neq(SDVariable other){
    final String unnamed = null;
    return neq(unnamed, other);
}
/**
 * Delegates to {@code sameDiff.neq(name, this, other)}.
 *
 * @param name  name passed through to SameDiff
 * @param other variable operand
 * @return the variable returned by SameDiff
 */
public SDVariable neq(String name, SDVariable other){
    final SDVariable result = sameDiff.neq(name, this, other);
    return result;
}
/**
 * Delegates to {@code sameDiff.mmul(this, other)}.
 *
 * @param other variable passed as the second argument
 * @return the variable returned by SameDiff
 */
public SDVariable mmul(SDVariable other){
    final SDVariable product = sameDiff.mmul(this, other);
    return product;
}
//scalars
/**
 * Convenience overload of {@link #rsub(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(RSubOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable rsub(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(RSubOp.OP_NAME, 0);
    return rsub(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #rdiv(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(RDivOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable rdiv(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(RDivOp.OP_NAME, 0);
    return rdiv(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #add(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(AddOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable add(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(AddOp.OP_NAME, 0);
    return add(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #sub(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(SubOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable sub(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(SubOp.OP_NAME, 0);
    return sub(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #squaredDifference(String, SDVariable)}; the result name comes
 * from {@code sameDiff.generateNewVarName(SquaredDifferenceOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable squaredDifference(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(SquaredDifferenceOp.OP_NAME, 0);
    return squaredDifference(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #div(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(DivOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable div(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(DivOp.OP_NAME, 0);
    return div(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #mul(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(MulOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable mul(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(MulOp.OP_NAME, 0);
    return mul(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #rsubi(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(RSubOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable rsubi(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(RSubOp.OP_NAME, 0);
    return rsubi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #rdivi(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(RDivOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable rdivi(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(RDivOp.OP_NAME, 0);
    return rdivi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #addi(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(AddOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable addi(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(AddOp.OP_NAME, 0);
    return addi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #subi(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(SubOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable subi(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(SubOp.OP_NAME, 0);
    return subi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #divi(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(DivOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable divi(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(DivOp.OP_NAME, 0);
    return divi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #muli(String, double)}; the result name comes from
 * {@code sameDiff.generateNewVarName(MulOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable scalar operand
 * @return result of the named overload
 */
public SDVariable muli(double sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(MulOp.OP_NAME, 0);
    return muli(generatedName, sameDiffVariable);
}
//end scalars
/**
 * Convenience overload of {@link #rsub(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(RSubOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable rsub(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(RSubOp.OP_NAME, 0);
    return rsub(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #rdiv(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(RDivOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable rdiv(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(RDivOp.OP_NAME, 0);
    return rdiv(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #truncatedDiv(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(TruncateDivOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable truncatedDiv(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(TruncateDivOp.OP_NAME, 0);
    return truncatedDiv(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #add(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(AddOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable add(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(AddOp.OP_NAME, 0);
    return add(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #sub(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(SubOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable sub(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(SubOp.OP_NAME, 0);
    return sub(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #div(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(DivOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable div(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(DivOp.OP_NAME, 0);
    return div(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #mul(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(MulOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable mul(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(MulOp.OP_NAME, 0);
    return mul(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #rsubi(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(RSubOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable rsubi(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(RSubOp.OP_NAME, 0);
    return rsubi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #rdivi(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(RDivOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable rdivi(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(RDivOp.OP_NAME, 0);
    return rdivi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #addi(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(AddOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable addi(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(AddOp.OP_NAME, 0);
    return addi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #subi(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(SubOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable subi(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(SubOp.OP_NAME, 0);
    return subi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #divi(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(DivOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable divi(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(DivOp.OP_NAME, 0);
    return divi(generatedName, sameDiffVariable);
}
/**
 * Convenience overload of {@link #muli(String, SDVariable)}; the result name comes from
 * {@code sameDiff.generateNewVarName(MulOp.OP_NAME, 0)}.
 *
 * @param sameDiffVariable variable operand
 * @return result of the named overload
 */
public SDVariable muli(SDVariable sameDiffVariable) {
    final String generatedName = sameDiff.generateNewVarName(MulOp.OP_NAME, 0);
    return muli(generatedName, sameDiffVariable);
}
/**
 * Creates the {@code rsub} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable rsub(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().rsub(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code rdiv} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable rdiv(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().rdiv(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code truncatedDiv} function through the function factory and renames its result
 * via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable truncatedDiv(String varName, SDVariable sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().truncatedDiv(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code add} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable add(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().add(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code sub} function through the function factory (this variable first, the scalar
 * second) and renames its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable sub(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().sub(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code div} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable div(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().div(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code mul} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName name requested for the result
 * @param value   scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable mul(String varName, double value) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().mul(this, value), varName);
}
/**
 * Convenience overload of {@link #pow(String, double)} that passes {@code null} as the name.
 *
 * @param value scalar operand
 * @return result of the named overload
 */
public SDVariable pow(double value){
    final String unnamed = null;
    return pow(unnamed, value);
}
/**
 * Creates the {@code pow} function through {@code f()} and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName name requested for the result
 * @param value   scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable pow(String varName, double value){
    final SDVariable raised = f().pow(this, value);
    final SDVariable renamed = sameDiff.updateVariableNameAndReference(raised, varName);
    return renamed;
}
/**
 * Creates the {@code rsubi} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable rsubi(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().rsubi(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code rdivi} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable rdivi(String varName, double sameDiffVariable) {
    final SDVariable quotient = sameDiff.f().rdivi(this, sameDiffVariable);
    final SDVariable renamed = sameDiff.updateVariableNameAndReference(quotient, varName);
    return renamed;
}
/**
 * Creates the {@code addi} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable addi(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().addi(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code subi} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable subi(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().subi(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code divi} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable divi(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().divi(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code muli} function through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable scalar operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable muli(String varName, double sameDiffVariable) {
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().muli(this, sameDiffVariable), varName);
}
//end scalars
/**
 * Creates the {@code rsub} function for two variables through the function factory and renames
 * its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable rsub(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().rsub(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code rdiv} function for two variables through the function factory and renames
 * its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable rdiv(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().rdiv(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code add} function for two variables through the function factory and renames
 * its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable add(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().add(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code sub} function (this variable first, the argument second) through the
 * function factory and renames its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable sub(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().sub(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code squaredDifference} function (this variable first, the argument second)
 * through the function factory and renames its result via
 * {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return squared difference between variables
 */
public SDVariable squaredDifference(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().squaredDifference(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code div} function for two variables through the function factory and renames
 * its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable div(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().div(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code mul} function for two variables through the function factory and renames
 * its result via {@code sameDiff.updateVariableNameAndReference}. Both inputs are passed to
 * {@code Preconditions.checkNotNull} before the function is created.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable mul(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    Preconditions.checkNotNull(this, "Left input is null!");
    Preconditions.checkNotNull(sameDiffVariable, "Right input is null!");
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().mul(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code rsubi} function for two variables through the function factory and renames
 * its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable rsubi(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().rsubi(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code rdivi} function for two variables through the function factory and renames
 * its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable rdivi(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().rdivi(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code addi} function for two variables through the function factory and renames
 * its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable addi(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().addi(this, sameDiffVariable), varName);
}
/**
 * @return always {@code Op.Type.RETURN}
 */
@Override
public Op.Type opType() {
    final Op.Type type = Op.Type.RETURN;
    return type;
}
/**
 * Creates the {@code subi} function (this variable first, the argument second) through the
 * function factory and renames its result via {@code sameDiff.updateVariableNameAndReference}.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable subi(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(
            sameDiff.f().subi(this, sameDiffVariable), varName);
}
/**
 * Creates the {@code divi} function for two variables through the function factory and renames
 * its result.
 * <p>
 * The rename goes through {@code sameDiff.updateVariableNameAndReference}, exactly like the
 * other named arithmetic overloads in this class, instead of calling {@code setVarName} on the
 * result directly, so the rename is handed to SameDiff rather than applied only to the field.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable divi(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    val result = sameDiff.f().divi(this, sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(result, varName);
}
/**
 * Creates the {@code muli} function (this variable first, the argument second) through the
 * function factory and renames its result.
 * <p>
 * The rename goes through {@code sameDiff.updateVariableNameAndReference}, exactly like the
 * other named arithmetic overloads in this class, instead of calling {@code setVarName} on the
 * result directly, so the rename is handed to SameDiff rather than applied only to the field.
 *
 * @param varName          name requested for the result
 * @param sameDiffVariable variable operand
 * @return the variable returned by {@code updateVariableNameAndReference}
 */
public SDVariable muli(String varName, SDVariable sameDiffVariable) {
    assertShapeEquals(sameDiffVariable);
    SDVariable result = sameDiff.f().muli(this, sameDiffVariable);
    return sameDiff.updateVariableNameAndReference(result, varName);
}
/**
 * Convenience overload of {@link #sum(String, int...)} that passes {@code null} as the name.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable sum(int... dimensions){
    final String unnamed = null;
    return sum(unnamed, dimensions);
}
/**
 * Convenience overload of {@link #sum(String, boolean, int...)} that passes {@code null} as the name.
 *
 * @param keepDims   flag forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable sum(boolean keepDims, int... dimensions){
    final String unnamed = null;
    return sum(unnamed, keepDims, dimensions);
}
/**
 * Convenience overload of {@link #sum(String, boolean, int...)} with {@code keepDims = false}.
 *
 * @param name       name forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the full overload
 */
public SDVariable sum(String name, int... dimensions){
    final boolean keepDims = false;
    return sum(name, keepDims, dimensions);
}
/**
 * Delegates to {@code sameDiff.sum(name, this, keepDims, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param keepDims   flag passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable sum(String name, boolean keepDims, int... dimensions){
    final SDVariable reduced = sameDiff.sum(name, this, keepDims, dimensions);
    return reduced;
}
/**
 * Convenience overload of {@link #mean(String, boolean, int...)} that passes {@code null} as the name.
 *
 * @param keepDims   flag forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable mean(boolean keepDims, int... dimensions){
    final String unnamed = null;
    return mean(unnamed, keepDims, dimensions);
}
/**
 * Convenience overload of {@link #mean(String, boolean, int...)} with {@code keepDims = false}.
 *
 * @param name       name forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the full overload
 */
public SDVariable mean(String name, int... dimensions){
    final boolean keepDims = false;
    return mean(name, keepDims, dimensions);
}
/**
 * Convenience overload of {@link #mean(String, boolean, int...)} with a {@code null} name and
 * {@code keepDims = false}.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the full overload
 */
public SDVariable mean(int... dimensions){
    final String unnamed = null;
    final boolean keepDims = false;
    return mean(unnamed, keepDims, dimensions);
}
/**
 * Delegates to {@code sameDiff.mean(name, this, keepDims, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param keepDims   flag passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable mean(String name, boolean keepDims, int... dimensions){
    final SDVariable reduced = sameDiff.mean(name, this, keepDims, dimensions);
    return reduced;
}
/**
 * Convenience overload of {@link #std(String, boolean, int...)} that passes {@code null} as the name.
 *
 * @param biasCorrected flag forwarded unchanged
 * @param dimensions    dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable std(boolean biasCorrected, int... dimensions){
    final String unnamed = null;
    return std(unnamed, biasCorrected, dimensions);
}
/**
 * Delegates to {@code sameDiff.standardDeviation(name, this, biasCorrected, dimensions)}.
 *
 * @param name          name passed through to SameDiff
 * @param biasCorrected flag passed through to SameDiff
 * @param dimensions    dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable std(String name, boolean biasCorrected, int... dimensions){
    final SDVariable deviation = sameDiff.standardDeviation(name, this, biasCorrected, dimensions);
    return deviation;
}
/**
 * Delegates to {@code sameDiff.standardDeviation(name, this, biasCorrected, keepDims, dimensions)}.
 *
 * @param name          name passed through to SameDiff
 * @param biasCorrected flag passed through to SameDiff
 * @param keepDims      flag passed through to SameDiff
 * @param dimensions    dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable std(String name, boolean biasCorrected, boolean keepDims, int... dimensions){
    final SDVariable deviation = sameDiff.standardDeviation(name, this, biasCorrected, keepDims, dimensions);
    return deviation;
}
/**
 * Convenience overload of {@link #prod(String, int...)} that passes {@code null} as the name.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable prod(int... dimensions){
    final String unnamed = null;
    return prod(unnamed, dimensions);
}
/**
 * Convenience overload of {@link #prod(String, boolean, int...)} that passes {@code null} as the name.
 *
 * @param keepDims   flag forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable prod(boolean keepDims, int... dimensions){
    final String unnamed = null;
    return prod(unnamed, keepDims, dimensions);
}
/**
 * Delegates to {@code sameDiff.prod(name, this, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable prod(String name, int... dimensions){
    final SDVariable reduced = sameDiff.prod(name, this, dimensions);
    return reduced;
}
/**
 * Delegates to {@code sameDiff.prod(name, this, keepDims, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param keepDims   flag passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable prod(String name, boolean keepDims, int... dimensions){
    final SDVariable reduced = sameDiff.prod(name, this, keepDims, dimensions);
    return reduced;
}
/**
 * Convenience overload of {@link #min(String, int...)} that passes {@code null} as the name.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable min(int... dimensions){
    final String unnamed = null;
    return min(unnamed, dimensions);
}
/**
 * Convenience overload of {@link #min(String, boolean, int...)} that passes {@code null} as the name.
 *
 * @param keepDims   flag forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable min(boolean keepDims, int... dimensions){
    final String unnamed = null;
    return min(unnamed, keepDims, dimensions);
}
/**
 * Convenience overload of {@link #min(String, boolean, int...)} with {@code keepDims = false}.
 *
 * @param name       name forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the full overload
 */
public SDVariable min(String name, int... dimensions){
    final boolean keepDims = false;
    return min(name, keepDims, dimensions);
}
/**
 * Delegates to {@code sameDiff.min(name, this, keepDims, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param keepDims   flag passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable min(String name, boolean keepDims, int... dimensions){
    final SDVariable reduced = sameDiff.min(name, this, keepDims, dimensions);
    return reduced;
}
/**
 * Convenience overload of {@link #max(String, int...)} that passes {@code null} as the name.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable max(int... dimensions){
    final String unnamed = null;
    return max(unnamed, dimensions);
}
/**
 * Convenience overload of {@link #max(String, boolean, int...)} that passes {@code null} as the name.
 *
 * @param keepDims   flag forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable max(boolean keepDims, int... dimensions){
    final String unnamed = null;
    return max(unnamed, keepDims, dimensions);
}
/**
 * Convenience overload of {@link #max(String, boolean, int...)} with {@code keepDims = false}.
 *
 * @param name       name forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the full overload
 */
public SDVariable max(String name, int... dimensions){
    final boolean keepDims = false;
    return max(name, keepDims, dimensions);
}
/**
 * Delegates to {@code sameDiff.max(name, this, keepDims, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param keepDims   flag passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable max(String name, boolean keepDims, int... dimensions){
    final SDVariable reduced = sameDiff.max(name, this, keepDims, dimensions);
    return reduced;
}
/**
 * Convenience overload of {@link #norm1(String, int...)} that passes {@code null} as the name.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable norm1(int... dimensions){
    final String unnamed = null;
    return norm1(unnamed, dimensions);
}
/**
 * Convenience overload of {@link #norm1(String, boolean, int...)} that passes {@code null} as the name.
 *
 * @param keepDims   flag forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable norm1(boolean keepDims, int... dimensions){
    final String unnamed = null;
    return norm1(unnamed, keepDims, dimensions);
}
/**
 * Convenience overload of {@link #norm1(String, boolean, int...)} with {@code keepDims = false}.
 *
 * @param name       name forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the full overload
 */
public SDVariable norm1(String name, int... dimensions){
    final boolean keepDims = false;
    return norm1(name, keepDims, dimensions);
}
/**
 * Delegates to {@code sameDiff.norm1(name, this, keepDims, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param keepDims   flag passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable norm1(String name, boolean keepDims, int... dimensions){
    final SDVariable reduced = sameDiff.norm1(name, this, keepDims, dimensions);
    return reduced;
}
/**
 * Convenience overload of {@link #norm2(String, int...)} that passes {@code null} as the name.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable norm2(int... dimensions){
    final String unnamed = null;
    return norm2(unnamed, dimensions);
}
/**
 * Convenience overload of {@link #norm2(String, boolean, int...)} that passes {@code null} as the name.
 *
 * @param keepDims   flag forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable norm2(boolean keepDims, int... dimensions){
    final String unnamed = null;
    return norm2(unnamed, keepDims, dimensions);
}
/**
 * Convenience overload of {@link #norm2(String, boolean, int...)} with {@code keepDims = false}.
 *
 * @param name       name forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the full overload
 */
public SDVariable norm2(String name, int... dimensions){
    final boolean keepDims = false;
    return norm2(name, keepDims, dimensions);
}
/**
 * Delegates to {@code sameDiff.norm2(name, this, keepDims, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param keepDims   flag passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable norm2(String name, boolean keepDims, int... dimensions){
    final SDVariable reduced = sameDiff.norm2(name, this, keepDims, dimensions);
    return reduced;
}
/**
 * Convenience overload of {@link #normmax(String, int...)} that passes {@code null} as the name.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable normmax(int... dimensions){
    final String unnamed = null;
    return normmax(unnamed, dimensions);
}
/**
 * Convenience overload of {@link #normmax(String, boolean, int...)} that passes {@code null} as the name.
 *
 * @param keepDims   flag forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable normmax(boolean keepDims, int... dimensions){
    final String unnamed = null;
    return normmax(unnamed, keepDims, dimensions);
}
/**
 * Convenience overload of {@link #normmax(String, boolean, int...)} with {@code keepDims = false}.
 *
 * @param name       name forwarded unchanged
 * @param dimensions dimensions forwarded unchanged
 * @return result of the full overload
 */
public SDVariable normmax(String name, int... dimensions){
    final boolean keepDims = false;
    return normmax(name, keepDims, dimensions);
}
/**
 * Delegates to {@code sameDiff.normmax(name, this, keepDims, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param keepDims   flag passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable normmax(String name, boolean keepDims, int... dimensions){
    final SDVariable reduced = sameDiff.normmax(name, this, keepDims, dimensions);
    return reduced;
}
/**
 * Convenience overload of {@link #argmax(String, int...)} that passes {@code null} as the name.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable argmax(int... dimensions){
    final String unnamed = null;
    return argmax(unnamed, dimensions);
}
/**
 * Delegates to {@code sameDiff.argmax(name, this, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable argmax(String name, int... dimensions){
    final SDVariable indices = sameDiff.argmax(name, this, dimensions);
    return indices;
}
/**
 * Convenience overload of {@link #argmin(String, int...)} that passes {@code null} as the name.
 *
 * @param dimensions dimensions forwarded unchanged
 * @return result of the named overload
 */
public SDVariable argmin(int... dimensions){
    final String unnamed = null;
    return argmin(unnamed, dimensions);
}
/**
 * Delegates to {@code sameDiff.argmin(name, this, dimensions)}.
 *
 * @param name       name passed through to SameDiff
 * @param dimensions dimensions passed through to SameDiff
 * @return the variable returned by SameDiff
 */
public SDVariable argmin(String name, int... dimensions){
    final SDVariable indices = sameDiff.argmin(name, this, dimensions);
    return indices;
}
/**
 * Associates the given array with this variable in the owning SameDiff instance.
 *
 * @param array array handed to {@code sameDiff.associateArrayWithVariable}
 * @return this variable, to allow call chaining
 */
public SDVariable setArray(INDArray array){
    final SDVariable self = this;
    sameDiff.associateArrayWithVariable(array, self);
    return self;
}
/**
 * Evaluates this variable: runs {@code sameDiff.exec()} and then returns {@link #getArr()}.
 *
 * @return the array associated with this variable after execution, or {@code null} if
 *         {@link #getArr()} yields none
 */
public INDArray eval() {
    sameDiff.exec();
    final INDArray evaluated = getArr();
    return evaluated;
}
// Neither field is read or written by the code visible in this part of the class.
// NOTE(review): the class is annotated with @Data, so Lombok presumably generates accessors for both - confirm usage elsewhere.
private int outputIndex = 0;
private DifferentialFunction creator;
/**
 * Currently a no-op: the whole shape-compatibility check below is commented out, so callers
 * that invoke this before creating an op get no validation.
 *
 * @param variable the other operand (unused while the check is disabled)
 */
private void assertShapeEquals(SDVariable variable) {
/* val shape = sameDiff.getShapeForVarName(getVarName());
if(shape == null && !variable.isPlaceHolder())
throw new ND4JIllegalStateException("Shape not found for variable " + getVarName());
if(!Arrays.equals(shape,variable.getShape()) && ArrayUtil.prod(variable.getShape()) != 1 && Shape.broadcastOutputShape(shape,variable.getShape()) == null) {
throw new IllegalArgumentException("Input shape must be the same as this shape " + Arrays.toString(shape) + " and shape was " + Arrays.toString(variable.getShape()));
}*/
}
/**
 * @return the variable name (which may be {@code null} if no name has been set)
 */
@Override
public String toString() {
    return this.varName;
}
/**
 * Two variables are equal when they are of exactly the same class, the superclass considers
 * them equal, and both their names and their weight initialization schemes are equal
 * (with {@code null} matching only {@code null}).
 *
 * @param o object to compare with
 * @return {@code true} if {@code o} is equal to this variable
 */
@Override
public boolean equals(Object o) {
    if (o == this) {
        return true;
    }
    if (o == null || o.getClass() != getClass() || !super.equals(o)) {
        return false;
    }

    final SDVariable other = (SDVariable) o;
    final boolean sameName = (varName == null)
            ? other.varName == null
            : varName.equals(other.varName);
    if (!sameName) {
        return false;
    }
    if (weightInitScheme == null) {
        return other.weightInitScheme == null;
    }
    return weightInitScheme.equals(other.weightInitScheme);
}
/**
 * Combines the superclass hash with the hashes of the variable name and the weight
 * initialization scheme (each contributing 0 when {@code null}), using the multiplier 31.
 *
 * @return hash code consistent with {@link #equals(Object)}
 */
@Override
public int hashCode() {
    final int parentHash = super.hashCode();
    final int nameHash = (varName == null) ? 0 : varName.hashCode();
    final int schemeHash = (weightInitScheme == null) ? 0 : weightInitScheme.hashCode();
    return 31 * (31 * parentHash + nameHash) + schemeHash;
}
/**
 * Always fails: this class defines no ONNX op name.
 *
 * @return never returns normally
 * @throws NoOpNameFoundException always
 */
@Override
public String onnxName() {
    final String message = "No onnx op opName found for " + opName();
    throw new NoOpNameFoundException(message);
}
/**
 * Always fails: this class defines no TensorFlow op name.
 *
 * @return never returns normally
 * @throws NoOpNameFoundException always
 */
@Override
public String tensorflowName() {
    final String message = "No tensorflow op opName found for " + opName();
    throw new NoOpNameFoundException(message);
}
public SDVariable get(SDIndex... indices){
int ndims = indices.length;
long[] begin = new long[ndims];
long[] end = new long[ndims];
long[] strides = new long[ndims];
int[] begin_mask_arr = new int[ndims];
int[] end_mask_arr = new int[ndims];
int[] shrink_axis_mask_arr = new int[ndims];
for(int i=0; i