Please wait. This can take a few minutes...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
org.nd4j.autodiff.samediff.SDVariable Maven / Gradle / Ivy
package org.nd4j.autodiff.samediff;
import lombok.*;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.*;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/**
 * A variable representing a component within a {@link SameDiff} graph.
 *
 * SDVariable is used for symbolic declaration of equations. Every arithmetic
 * method on this class delegates to the function factory returned by
 * {@code sameDiff.f()}; the mathematical meaning of each operation
 * (including the reverse {@code r*} and the {@code *i} variants) is defined
 * by that factory, not by this class.
 *
 * @author Adam Gibson
 */
@Data
@NoArgsConstructor
public class SDVariable extends DifferentialFunction implements Serializable {
    @Getter
    @Setter
    private String varName;
    @Getter
    @Setter
    protected WeightInitScheme weightInitScheme;

    /**
     * Creates a variable and registers its shape (or placeholder status) with {@code sameDiff}.
     *
     * @param varName          the name the variable is registered under
     * @param sameDiff         the owning graph; dereferenced here, so it must not be null
     * @param shape            the shape; {@code null} or any negative dimension marks the variable as a placeholder
     * @param weightInitScheme scheme used to allocate arrays; {@code new ZeroInitScheme('f')} when null
     */
    @Builder
    private SDVariable(String varName,
                       SameDiff sameDiff,
                       int[] shape,
                       WeightInitScheme weightInitScheme) {
        super(sameDiff, new Object[]{});
        this.varName = varName;
        this.weightInitScheme = weightInitScheme != null ? weightInitScheme : new ZeroInitScheme('f');

        if (shape == null) {
            sameDiff.addAsPlaceHolder(varName);
        } else if (hasNegativeDimension(shape)) {
            // A negative dimension means the shape is not fully known yet: register a
            // placeholder and remember the original (partial) shape.
            sameDiff.addAsPlaceHolder(varName);
            sameDiff.setOriginalPlaceHolderShape(varName, shape);
        } else {
            sameDiff.putShapeForVarName(varName, shape);
        }

        this.sameDiff = sameDiff;
    }

    /** Returns true if any entry of {@code shape} is negative. */
    private static boolean hasNegativeDimension(int[] shape) {
        for (int dimension : shape) {
            if (dimension < 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether {@code sameDiff} reports this variable's name as a placeholder.
     *
     * @return the result of {@code sameDiff.isPlaceHolder(varName)}
     */
    public boolean isPlaceHolder() {
        return sameDiff.isPlaceHolder(varName);
    }

    /** @return the fixed op name {@code "variable"} */
    @Override
    public String opName() {
        return "variable";
    }

    /** @return always {@link Op.Type#RETURN} */
    @Override
    public Op.Type opType() {
        return Op.Type.RETURN;
    }

    /** @return a single-element array containing this variable */
    @Override
    public SDVariable[] outputVariables() {
        return new SDVariable[] {this};
    }

    /**
     * @param baseName ignored
     * @return a single-element array containing this variable
     */
    @Override
    public SDVariable[] outputVariables(String baseName) {
        return new SDVariable[] {this};
    }

    /** @return this variable */
    @Override
    public SDVariable arg() {
        return this;
    }

    /** @return a single-element array containing this variable */
    @Override
    public SDVariable[] args() {
        return new SDVariable[] {this};
    }

    /**
     * No-op: a plain variable takes nothing from a TensorFlow node.
     * NOTE(review): the raw {@code Map} parameter looks like it lost its type
     * arguments; confirm against the signature declared in DifferentialFunction.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) {
    }

    /**
     * No-op: a plain variable takes nothing from an ONNX node.
     * NOTE(review): raw {@code Map} parameter, see {@link #initFromTensorFlow}.
     */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) {
    }

    /**
     * Returns the array for this variable, allocating and storing a new one
     * when the current array is missing or its shape differs from the
     * registered shape.
     *
     * @return the existing array if its shape matches the registered shape, otherwise a newly created one
     * @throws ND4JIllegalStateException if the variable name is null or no shape is registered for it
     */
    public INDArray storeAndAllocateNewArray() {
        // Fail fast on a missing name before performing any name-keyed lookup.
        if (varName == null) {
            throw new ND4JIllegalStateException("Unable to store array for null variable name!");
        }

        int[] shape = sameDiff.getShapeForVarName(getVarName());
        INDArray existing = getArr();
        if (existing != null && Arrays.equals(existing.shape(), shape)) {
            return existing;
        }

        if (shape == null) {
            throw new ND4JIllegalStateException("Unable to allocate new array. No shape found for variable " + varName);
        }

        INDArray arr = getWeightInitScheme().create(shape);
        sameDiff.putArrayForVarName(getVarName(), arr);
        return arr;
    }

    /**
     * A getter for the ndarray associated with this {@link SDVariable}.
     *
     * If {@code sameDiff} already holds an array for this name it is returned.
     * Otherwise an array is created and associated with this variable: a
     * constant array filled with the scalar value when a scalar value is set
     * and the registered shape has exactly one element, or an array produced
     * by the {@link WeightInitScheme} for the registered shape.
     *
     * @return the {@link INDArray} associated with this variable, or {@code null} when no shape is registered
     */
    public INDArray getArr() {
        if (sameDiff.arrayAlreadyExistsForVarName(getVarName())) {
            return sameDiff.getArrForVarName(getVarName());
        }

        // Read the shape straight from sameDiff. Going through getShape() here would call
        // back into getArr() when no shape is registered and recurse without bound.
        int[] shape = sameDiff.getShapeForVarName(getVarName());
        if (shape == null) {
            return null;
        }

        //initialize value if it's actually a scalar constant (zero or 1 typically...)
        if (getScalarValue() != null && ArrayUtil.prod(shape) == 1) {
            INDArray constant = Nd4j.valueArrayOf(shape, getScalarValue().doubleValue());
            sameDiff.associateArrayWithVariable(constant, this);
        } else {
            INDArray newAlloc = getWeightInitScheme().create(shape);
            sameDiff.associateArrayWithVariable(newAlloc, this);
        }

        return sameDiff.getArrForVarName(getVarName());
    }

    /**
     * Nicer looking alias for {@link #getGradient()}.
     *
     * @return the value returned by {@link #getGradient()}
     */
    public SDVariable gradient() {
        return getGradient();
    }

    /**
     * Looks up the gradient variable that {@code sameDiff} holds for this variable's name.
     *
     * @return the result of {@code sameDiff.getGradForVariable(varName)}
     */
    public SDVariable getGradient() {
        return sameDiff.getGradForVariable(getVarName());
    }

    /**
     * Always fails: a variable is not a function and cannot be differentiated.
     * NOTE(review): raw {@code List} types here look like stripped generics; confirm against DifferentialFunction.
     *
     * @throws ND4JIllegalStateException always
     */
    @Override
    public List doDiff(List f1) {
        throw new ND4JIllegalStateException("Unable to differentiate a variable! Must be a function.");
    }

    /**
     * Returns the shape of this variable.
     *
     * @return the shape registered in {@code sameDiff}; if none is registered, the shape of
     *         {@link #getArr()} when that is non-null, otherwise {@code null}
     */
    public int[] getShape() {
        int[] registered = sameDiff.getShapeForVarName(getVarName());
        if (registered != null) {
            return registered;
        }
        INDArray arr = getArr();
        return arr != null ? arr.shape() : null;
    }

    /**
     * Duplicates this variable through its graph.
     *
     * @return the result of {@code sameDiff.var(this)}
     */
    public SDVariable dup() {
        return sameDiff.var(this);
    }

    // ---------------------------------------------------------------------
    // Scalar operand, generated result name
    // ---------------------------------------------------------------------

    /**
     * {@link #rsub(String, double)} under a name generated from the {@link RSubOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable rsub(double sameDiffVariable) {
        return rsub(sameDiff.generateNewVarName(new RSubOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #rdiv(String, double)} under a name generated from the {@link RDivOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable rdiv(double sameDiffVariable) {
        return rdiv(sameDiff.generateNewVarName(new RDivOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #add(String, double)} under a name generated from the {@link AddOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable add(double sameDiffVariable) {
        return add(sameDiff.generateNewVarName(new AddOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #sub(String, double)} under a name generated from the {@link SubOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable sub(double sameDiffVariable) {
        return sub(sameDiff.generateNewVarName(new SubOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #div(String, double)} under a name generated from the {@link DivOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable div(double sameDiffVariable) {
        return div(sameDiff.generateNewVarName(new DivOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #mul(String, double)} under a name generated from the {@link MulOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable mul(double sameDiffVariable) {
        return mul(sameDiff.generateNewVarName(new MulOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #rsubi(String, double)} under a name generated from the {@link RSubOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable rsubi(double sameDiffVariable) {
        return rsubi(sameDiff.generateNewVarName(new RSubOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #rdivi(String, double)} under a name generated from the {@link RDivOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable rdivi(double sameDiffVariable) {
        return rdivi(sameDiff.generateNewVarName(new RDivOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #addi(String, double)} under a name generated from the {@link AddOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable addi(double sameDiffVariable) {
        return addi(sameDiff.generateNewVarName(new AddOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #subi(String, double)} under a name generated from the {@link SubOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable subi(double sameDiffVariable) {
        return subi(sameDiff.generateNewVarName(new SubOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #divi(String, double)} under a name generated from the {@link DivOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable divi(double sameDiffVariable) {
        return divi(sameDiff.generateNewVarName(new DivOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #muli(String, double)} under a name generated from the {@link MulOp} op name.
     * @param sameDiffVariable the scalar operand
     * @return the resulting variable
     */
    public SDVariable muli(double sameDiffVariable) {
        return muli(sameDiff.generateNewVarName(new MulOp().opName(), 0), sameDiffVariable);
    }

    // ---------------------------------------------------------------------
    // Variable operand, generated result name
    // ---------------------------------------------------------------------

    /**
     * {@link #rsub(String, SDVariable)} under a name generated from the {@link RSubOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable rsub(SDVariable sameDiffVariable) {
        return rsub(sameDiff.generateNewVarName(new RSubOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #rdiv(String, SDVariable)} under a name generated from the {@link RDivOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable rdiv(SDVariable sameDiffVariable) {
        return rdiv(sameDiff.generateNewVarName(new RDivOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #truncatedDiv(String, SDVariable)} under a name generated from the {@link TruncateDivOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable truncatedDiv(SDVariable sameDiffVariable) {
        return truncatedDiv(sameDiff.generateNewVarName(new TruncateDivOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #squaredDifference(String, SDVariable)} under a name generated from the {@link SquaredDifferenceOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable squaredDifference(SDVariable sameDiffVariable) {
        return squaredDifference(sameDiff.generateNewVarName(new SquaredDifferenceOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #add(String, SDVariable)} under a name generated from the {@link AddOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable add(SDVariable sameDiffVariable) {
        return add(sameDiff.generateNewVarName(new AddOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #sub(String, SDVariable)} under a name generated from the {@link SubOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable sub(SDVariable sameDiffVariable) {
        return sub(sameDiff.generateNewVarName(new SubOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #div(String, SDVariable)} under a name generated from the {@link DivOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable div(SDVariable sameDiffVariable) {
        return div(sameDiff.generateNewVarName(new DivOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #mul(String, SDVariable)} under a name generated from the {@link MulOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable mul(SDVariable sameDiffVariable) {
        return mul(sameDiff.generateNewVarName(new MulOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #rsubi(String, SDVariable)} under a name generated from the {@link RSubOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable rsubi(SDVariable sameDiffVariable) {
        return rsubi(sameDiff.generateNewVarName(new RSubOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #rdivi(String, SDVariable)} under a name generated from the {@link RDivOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable rdivi(SDVariable sameDiffVariable) {
        return rdivi(sameDiff.generateNewVarName(new RDivOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #addi(String, SDVariable)} under a name generated from the {@link AddOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable addi(SDVariable sameDiffVariable) {
        return addi(sameDiff.generateNewVarName(new AddOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #subi(String, SDVariable)} under a name generated from the {@link SubOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable subi(SDVariable sameDiffVariable) {
        return subi(sameDiff.generateNewVarName(new SubOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #divi(String, SDVariable)} under a name generated from the {@link DivOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable divi(SDVariable sameDiffVariable) {
        return divi(sameDiff.generateNewVarName(new DivOp().opName(), 0), sameDiffVariable);
    }

    /**
     * {@link #muli(String, SDVariable)} under a name generated from the {@link MulOp} op name.
     * @param sameDiffVariable the other operand
     * @return the resulting variable
     */
    public SDVariable muli(SDVariable sameDiffVariable) {
        return muli(sameDiff.generateNewVarName(new MulOp().opName(), 0), sameDiffVariable);
    }

    // ---------------------------------------------------------------------
    // Scalar operand, explicit result name
    // ---------------------------------------------------------------------

    /**
     * Builds {@code sameDiff.f().rsub(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rsub(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().rsub(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().rdiv(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rdiv(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().rdiv(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().add(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable add(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().add(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().sub(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sub(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().sub(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().div(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable div(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().div(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().mul(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable mul(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().mul(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().rsubi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rsubi(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().rsubi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().rdivi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rdivi(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().rdivi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().addi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable addi(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().addi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().subi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable subi(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().subi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().divi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable divi(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().divi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().muli(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable muli(String varName, double sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().muli(this, sameDiffVariable), varName);
    }

    // ---------------------------------------------------------------------
    // Variable operand, explicit result name
    // ---------------------------------------------------------------------

    /**
     * Builds {@code sameDiff.f().truncatedDiv(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable truncatedDiv(String varName, SDVariable sameDiffVariable) {
        return sameDiff.updateVariableNameAndReference(sameDiff.f().truncatedDiv(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().rsub(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rsub(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().rsub(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().rdiv(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rdiv(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().rdiv(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().add(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable add(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().add(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().sub(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sub(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().sub(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().squaredDifference(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return squared difference between variables
     */
    public SDVariable squaredDifference(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().squaredDifference(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().div(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable div(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().div(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().mul(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * The right operand is checked with {@code Preconditions.checkNotNull} first.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable mul(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        // Only the right operand can be null; the left operand is "this".
        Preconditions.checkNotNull(sameDiffVariable, "Right input is null!");
        return sameDiff.updateVariableNameAndReference(sameDiff.f().mul(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().rsubi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rsubi(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().rsubi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().rdivi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rdivi(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().rdivi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().addi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable addi(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().addi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().subi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable subi(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        return sameDiff.updateVariableNameAndReference(sameDiff.f().subi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().divi(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable divi(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        // Goes through sameDiff like every sibling operation, instead of only
        // calling setVarName on the returned object.
        return sameDiff.updateVariableNameAndReference(sameDiff.f().divi(this, sameDiffVariable), varName);
    }

    /**
     * Builds {@code sameDiff.f().muli(this, sameDiffVariable)} and hands it to
     * {@code sameDiff.updateVariableNameAndReference} together with {@code varName}.
     * @param varName name for the resulting variable
     * @param sameDiffVariable the other operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable muli(String varName, SDVariable sameDiffVariable) {
        assertShapeEquals(sameDiffVariable);
        // Goes through sameDiff like every sibling operation, instead of only
        // calling setVarName on the returned object.
        return sameDiff.updateVariableNameAndReference(sameDiff.f().muli(this, sameDiffVariable), varName);
    }

    /**
     * Evaluate the result of this variable.
     *
     * Works on a duplicate of {@code sameDiff}: defines a function named
     * {@code "output"} whose only output is this variable, invokes it on the
     * duplicate, and executes the SameDiff instance of the returned variable.
     *
     * @return the result of {@code execAndEndResult()} on that instance
     */
    public INDArray eval() {
        SameDiff exec = sameDiff.dup();
        exec.defineFunction("output", new SameDiff.SameDiffFunctionDefinition() {
            @Override
            public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) {
                return new SDVariable[] {SDVariable.this};
            }
        });
        SDVariable output = exec.invokeFunctionOn("output", exec);
        return output.getSameDiff().execAndEndResult();
    }

    /**
     * Shape-compatibility hook called by the binary operations. Currently a
     * no-op: the validation below is commented out.
     *
     * @param variable the other operand
     */
    private void assertShapeEquals(SDVariable variable) {
        // val shape = sameDiff.getShapeForVarName(getVarName());
        // if(shape == null && !variable.isPlaceHolder())
        //     throw new ND4JIllegalStateException("Shape not found for variable " + getVarName());
        // if(!Arrays.equals(shape,variable.getShape()) && ArrayUtil.prod(variable.getShape()) != 1 && Shape.broadcastOutputShape(shape,variable.getShape()) == null) {
        //     throw new IllegalArgumentException("Input shape must be the same as this shape " + Arrays.toString(shape) + " and shape was " + Arrays.toString(variable.getShape()));
        // }
    }

    /** @return the variable name (may be {@code null} if no name was set) */
    @Override
    public String toString() {
        return varName;
    }

    /**
     * Two variables are equal when they have the same runtime class, the
     * superclass considers them equal, and both {@code varName} and
     * {@code weightInitScheme} are equal (or both null).
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        SDVariable that = (SDVariable) o;
        if (varName != null ? !varName.equals(that.varName) : that.varName != null) {
            return false;
        }
        return weightInitScheme != null
                ? weightInitScheme.equals(that.weightInitScheme)
                : that.weightInitScheme == null;
    }

    /** Combines the superclass hash with {@code varName} and {@code weightInitScheme}, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = 31 * result + (varName != null ? varName.hashCode() : 0);
        result = 31 * result + (weightInitScheme != null ? weightInitScheme.hashCode() : 0);
        return result;
    }

    /**
     * @throws NoOpNameFoundException always; a variable has no ONNX op name
     */
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    /**
     * @throws NoOpNameFoundException always; a variable has no TensorFlow op name
     */
    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }
}