org.nd4j.autodiff.functions.DifferentialFunction Maven / Gradle / Ivy
package org.nd4j.autodiff.functions;
import java.util.List;
import java.util.UUID;
import com.google.common.base.Preconditions;
import lombok.*;
import org.nd4j.autodiff.AbstractIdentityFactory;
import org.nd4j.autodiff.ArrayFactory;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.graph.Graph;
import org.nd4j.autodiff.opstate.NDArrayInformation;
import org.nd4j.autodiff.opstate.NDArrayVertex;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SDGraph;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.util.ArrayUtil;
@AllArgsConstructor
@Data
@NoArgsConstructor
public abstract class DifferentialFunction>
implements Field>,
Differential> {
@Getter
@Setter
protected SameDiff sameDiff;
@Getter
protected OpState opState;
@Getter
@Setter
protected int vertexId;
protected Object[] extraArgs;
/**
 * Creates a function bound to the given {@link SameDiff} instance.
 * Only the two references are stored; {@code opState} and {@code vertexId}
 * are not touched by this constructor.
 *
 * @param sameDiff the SameDiff instance this function is associated with
 * @param extraArgs extra arguments, stored as passed (the array is not copied)
 */
public DifferentialFunction(SameDiff sameDiff, Object[] extraArgs) {
this.sameDiff = sameDiff;
this.extraArgs = extraArgs;
}
/**
 * Returns the shape stored on the result of this function's op state.
 *
 * @return the shape reported by {@code opState.getResult()}
 * @throws IllegalStateException if no op state has been set on this function
 */
public int[] getResultShape() {
    if (opState != null) {
        return opState.getResult().getShape();
    }
    throw new IllegalStateException("Unable to get result shape with null op state");
}
/**
 * Produces the value of this function.
 * Implemented by subclasses and invoked by {@link #getValue(boolean)}.
 *
 * @return the value of this function
 */
public abstract X doGetValue();
/**
 * Gets the value of this function, optionally freezing the graph while
 * the value is computed.
 * <p>
 * When {@code freeze} is {@code true} and the graph is not frozen yet, the graph is frozen
 * around {@link #doGetValue()} and unfrozen again afterwards, also when the computation
 * throws. A graph that was already frozen on entry is left frozen.
 * <p>
 * When {@code freeze} is {@code false} and the value is an {@link ArrayField}, the field is
 * bound to the graph vertex registered under {@link #getVertexId()}: its vertex, ops and
 * input are replaced by those taken from the graph.
 *
 * @param freeze whether to freeze the graph or not,
 *               this means whether to add nodes to the internal
 *               computation graph or not
 * @return the value of this function
 * @throws IllegalStateException if the vertex of this function is missing or has no value
 */
public X getValue(boolean freeze) {
    boolean graphAlreadyFrozen = this.sameDiff.getGraph().isFrozen();
    // if graph is already frozen leave it frozen
    boolean freezeForThisCall = freeze && !graphAlreadyFrozen;
    if (freezeForThisCall) {
        this.sameDiff.getGraph().freeze();
    }

    try {
        X val = doGetValue();
        if (val instanceof ArrayField && !freeze) {
            ArrayField arrayField = (ArrayField) val;
            NDArrayVertex vertex = (NDArrayVertex) getSameDiff().getGraph().getVertex(getVertexId());
            // Validate first so that a failed check leaves the field unmodified.
            Preconditions.checkState(vertex != null, "Vertex %s was null.", getVertexId());
            Preconditions.checkState(vertex.getValue() != null, "Vertex did not have a value set.");
            arrayField.setVertex(vertex);
            arrayField.setOps(this.sameDiff);
            // The scalar value is copied onto the current input before the input is swapped.
            arrayField.getInput().setScalarValue(vertex.getValue().getScalarValue());
            arrayField.setInput(vertex.getValue());
            Preconditions.checkState(sameDiff == arrayField.getOps(), "Passed in array factory != the passed in graph. Unable to instantiate.");
        }
        return val;
    } finally {
        // Undo the freeze even if doGetValue() or one of the checks threw.
        if (freezeForThisCall) {
            this.sameDiff.getGraph().unfreeze();
        }
    }
}
/**
 * Returns a {@code double} for this function; the computation is defined by subclasses.
 *
 * @return the real value
 */
@Override
public abstract double getReal();
/**
 * Builds the formula string of this function while the graph is frozen.
 * <p>
 * The graph is frozen only if it was not frozen on entry, and in that case it is
 * unfrozen again on exit, also when {@code doGetFormula} throws. A graph that was
 * already frozen is left frozen, matching {@code getValue(boolean)}.
 * <p>
 * NOTE(review): the generic type arguments of {@code variables} appear truncated in this
 * copy of the file; the signature is reproduced as found — confirm against the repository.
 *
 * @param variables the variables handed on to {@code doGetFormula}
 * @return the formula produced by {@code doGetFormula}
 */
@Override
public String getFormula(List> variables) {
    boolean graphAlreadyFrozen = sameDiff.getGraph().isFrozen();
    if (!graphAlreadyFrozen) {
        sameDiff.getGraph().freeze();
    }

    try {
        return doGetFormula(variables);
    } finally {
        // Restore the state found on entry instead of unfreezing unconditionally.
        if (!graphAlreadyFrozen) {
            sameDiff.getGraph().unfreeze();
        }
    }
}
/**
 * Builds the formula string of this function.
 * Implemented by subclasses and called by {@code getFormula}.
 *
 * @param variables the variables passed through from {@code getFormula}
 * @return the formula string
 */
public abstract String doGetFormula(List> variables);
/**
 * Returns the name of this function; supplied by each subclass.
 *
 * @return the function name
 */
public abstract String functionName();
/**
 * Returns the string form of this function; subclasses are required to implement it.
 * The validation helpers of this class embed it in their error messages.
 *
 * @return the string representation
 */
@Override
public abstract String toString();
/**
 * Reports whether this function is a constant.
 * This base implementation always answers {@code false}.
 *
 * @return {@code false}
 */
public boolean isConstant() {
return false;
}
/**
 * Reports whether this function is a variable.
 * This base implementation always answers {@code false}.
 *
 * @return {@code false}
 */
public boolean isVariable() {
return false;
}
/**
 * Differentiation hook implemented by subclasses.
 * Presumably returns the derivative of this function with respect to {@code i_v1} —
 * confirm against the {@code Differential} interface.
 *
 * @param i_v1 the function to differentiate against (assumed; see note above)
 * @return the resulting function
 */
@Override
public abstract DifferentialFunction diff(DifferentialFunction i_v1);
/**
 * Verifies that {@code function} is bound to the same {@link SameDiff} instance
 * (compared by reference) as this function.
 *
 * @param function the function to check
 * @throws IllegalStateException if the two functions use different SameDiff instances
 */
private void validateDifferentialFunctionGraph(DifferentialFunction function) {
    // Template arguments are only formatted on failure, so the toString() calls of
    // both functions are skipped on the success path this runs on for every operation.
    Preconditions.checkState(function.getSameDiff() == this.getSameDiff(),
            "Function applications must be contained in same graph. The left %s must match this function %s",
            function, this);
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code rdivi} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction rdivi(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.rdivi(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code rsubi} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction rsubi(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.rsubi(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code addi} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction addi(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.addi(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code muli} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction muli(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.muli(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code subi} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction subi(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.subi(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code divi} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction divi(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.divi(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Wraps this function in an {@link Inverse}, passing {@code true} as the
 * third constructor argument.
 *
 * @return the newly created {@code Inverse}
 */
@Override
public DifferentialFunction inversei() {
    return new Inverse<>(sameDiff, this, true);
}
/**
 * Wraps this function in a {@link Negative}, passing {@code true} as the
 * third constructor argument.
 *
 * @return the newly created {@code Negative}
 */
@Override
public DifferentialFunction negatei() {
    return new Negative<>(sameDiff, this, true);
}
/**
 * Creates a {@link PolynomialTerm} over this function from {@code i_n}, the
 * constant {@code 1} and the flag {@code true}.
 *
 * @param i_n the value passed as second constructor argument of the term
 * @return the newly created {@code PolynomialTerm}
 */
@Override
public DifferentialFunction muli(double i_n) {
    return new PolynomialTerm<>(sameDiff, i_n, this, 1, true);
}
/**
 * Creates a {@link PolynomialTerm} over this function from the constant {@code 1L},
 * {@code i_n} and the flag {@code true}.
 *
 * @param i_n the value passed as fourth constructor argument of the term
 * @return the newly created {@code PolynomialTerm}
 */
@Override
public DifferentialFunction powi(int i_n) {
    return new PolynomialTerm<>(sameDiff, 1L, this, i_n, true);
}
/**
 * Wraps {@code i_v} in a {@link Scalar} (third constructor argument {@code true})
 * and returns the result of calling {@code addi(this)} on that scalar.
 *
 * @param i_v the scalar operand
 * @return the function returned by the scalar's {@code addi}
 */
@Override
public DifferentialFunction addi(double i_v) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, i_v, true);
    return scalarOperand.addi(this);
}
/**
 * Wraps {@code i_v} in a {@link Scalar} (third constructor argument {@code true})
 * and returns the result of calling {@code subi(this)} on that scalar.
 *
 * @param i_v the scalar operand
 * @return the function returned by the scalar's {@code subi}
 */
@Override
public DifferentialFunction subi(double i_v) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, i_v, true);
    return scalarOperand.subi(this);
}
/**
 * Wraps {@code v} in a {@link Scalar} (third constructor argument {@code true})
 * and passes that scalar to {@code divi} of this function.
 *
 * @param v the scalar operand
 * @return the function returned by this function's {@code divi}
 */
@Override
public DifferentialFunction divi(double v) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, v, true);
    return divi(scalarOperand);
}
/**
 * Wraps {@code v} in a {@link Scalar} (third constructor argument {@code true})
 * and passes that scalar to {@code rsubi} of this function.
 *
 * @param v the scalar operand
 * @return the function returned by this function's {@code rsubi}
 */
@Override
public DifferentialFunction rsubi(double v) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, v, true);
    return rsubi(scalarOperand);
}
/**
 * Wraps {@code v} in a {@link Scalar} (third constructor argument {@code true})
 * and passes that scalar to {@code rdivi} of this function.
 *
 * @param v the scalar operand
 * @return the function returned by this function's {@code rdivi}
 */
@Override
public DifferentialFunction rdivi(double v) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, v, true);
    return rdivi(scalarOperand);
}
/**
 * Evaluates both functions with {@code getValue(true)} — this function first — calls
 * {@code set} on this function's value with the value of {@code i_v} as argument, and
 * wraps the result in a {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction set(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X ownValue = getValue(true);
    X argumentValue = i_v.getValue(true);
    X combined = ownValue.set(argumentValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code rdiv} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction rdiv(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.rdiv(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code rsub} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction rsub(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.rsub(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code add} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction add(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.add(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code mul} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction mul(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.mul(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code sub} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction sub(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.sub(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Evaluates both functions with {@code getValue(true)}, calls {@code div} on the value
 * of {@code i_v} with this function's value as argument, and wraps the result in a
 * {@link Constant}.
 *
 * @param i_v the other function; must share this function's {@link SameDiff} instance
 * @return a {@code Constant} holding the result, created with the result shape of {@code i_v}
 */
@Override
public DifferentialFunction div(DifferentialFunction i_v) {
    validateDifferentialFunctionGraph(i_v);
    X argumentValue = i_v.getValue(true);
    X ownValue = getValue(true);
    X combined = argumentValue.div(ownValue);
    return new Constant<>(sameDiff, combined, i_v.getResultShape());
}
/**
 * Wraps the result of {@code this.mul(1.0)} in an {@link Inverse}.
 *
 * @return the newly created {@code Inverse}
 */
@Override
public DifferentialFunction inverse() {
    return new Inverse<>(sameDiff, this.mul(1.0));
}
/**
 * Wraps the result of {@code this.mul(1.0)} in a {@link Negative}.
 *
 * @return the newly created {@code Negative}
 */
@Override
public DifferentialFunction negate() {
    return new Negative<>(sameDiff, this.mul(1.0));
}
/**
 * Wraps {@code i_n} in a {@link Scalar} and passes that scalar to
 * {@code mul} of this function.
 *
 * @param i_n the scalar operand
 * @return the function returned by this function's {@code mul}
 */
@Override
public DifferentialFunction mul(double i_n) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, i_n);
    return mul(scalarOperand);
}
/**
 * Creates a {@link PolynomialTerm} over this function from the constant {@code 1L}
 * and {@code i_n}.
 *
 * @param i_n the value passed as fourth constructor argument of the term
 * @return the newly created {@code PolynomialTerm}
 */
@Override
public DifferentialFunction pow(int i_n) {
    return new PolynomialTerm<>(sameDiff, 1L, this, i_n);
}
/**
 * Wraps {@code i_v} in a {@link Scalar} and returns the result of calling
 * {@code add(this)} on that scalar.
 *
 * @param i_v the scalar operand
 * @return the function returned by the scalar's {@code add}
 */
@Override
public DifferentialFunction add(double i_v) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, i_v);
    return scalarOperand.add(this);
}
/**
 * Wraps {@code i_v} in a {@link Scalar} and returns the result of calling
 * {@code sub(this)} on that scalar.
 *
 * @param i_v the scalar operand
 * @return the function returned by the scalar's {@code sub}
 */
@Override
public DifferentialFunction sub(double i_v) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, i_v);
    return scalarOperand.sub(this);
}
/**
 * Wraps {@code v} in a {@link Scalar} and passes that scalar to
 * {@code rsub} of this function.
 *
 * @param v the scalar operand
 * @return the function returned by this function's {@code rsub}
 */
@Override
public DifferentialFunction rsub(double v) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, v);
    return rsub(scalarOperand);
}
/**
 * Wraps {@code v} in a {@link Scalar} and passes that scalar to
 * {@code rdiv} of this function.
 *
 * @param v the scalar operand
 * @return the function returned by this function's {@code rdiv}
 */
@Override
public DifferentialFunction rdiv(double v) {
    // The typed local is kept: it gives the diamond its target type.
    Scalar scalarOperand = new Scalar<>(sameDiff, v);
    return rdiv(scalarOperand);
}
/**
 * Convenience overload of {@code addEdges} that supplies {@code null} as the
 * extra arguments and forwards everything else unchanged.
 *
 * @param sameDiff the SameDiff instance whose graph is extended
 * @param i_v1 the first input function
 * @param i_v2 the second input function
 * @param opName the name of the operation
 * @param opType the type of the operation
 * @param shape the shape of the result
 */
protected void addEdges(SameDiff sameDiff,
        DifferentialFunction i_v1,
        DifferentialFunction i_v2,
        String opName,
        OpState.OpType opType,
        int[] shape) {
    final Object[] noExtraArgs = null;
    addEdges(sameDiff, i_v1, i_v2, opName, opType, shape, noExtraArgs);
}
/**
 * Returns the arguments of this function.
 * This base implementation returns a new, empty array on every call.
 *
 * @return an empty {@code DifferentialFunction} array
 */
@Override
public DifferentialFunction[] args() {
return new DifferentialFunction[0];
}
// Placeholder implementations: every method in this run returns null in this base class.
// NOTE(review): presumably subclasses override the ones they support; callers of an
// un-overridden method get null rather than an exception — confirm this is intended.
// NOTE(review): div(double) is a placeholder here although divi(double) is implemented — confirm.

/** Placeholder for {@code pow(DifferentialFunction)}: returns {@code null} in this base class. */
@Override
public DifferentialFunction pow(DifferentialFunction a) {
return null;
}
/** Placeholder for {@code floor}: returns {@code null} in this base class. */
@Override
public DifferentialFunction floor() {
return null;
}
/** Placeholder for {@code ceil}: returns {@code null} in this base class. */
@Override
public DifferentialFunction ceil() {
return null;
}
/** Placeholder for {@code round}: returns {@code null} in this base class. */
@Override
public DifferentialFunction round() {
return null;
}
/** Placeholder for {@code abs}: returns {@code null} in this base class. */
@Override
public DifferentialFunction abs() {
return null;
}
/** Placeholder for {@code sqrt}: returns {@code null} in this base class. */
@Override
public DifferentialFunction sqrt() {
return null;
}
/** Placeholder for {@code minus(double)}: returns {@code null} in this base class. */
@Override
public DifferentialFunction minus(double v) {
return null;
}
/** Placeholder for {@code prod(double)}: returns {@code null} in this base class. */
@Override
public DifferentialFunction prod(double v) {
return null;
}
/** Placeholder for {@code div(double)}: returns {@code null} in this base class. */
@Override
public DifferentialFunction div(double v) {
return null;
}
/** Placeholder for {@code pow(double)}: returns {@code null} in this base class. */
@Override
public DifferentialFunction pow(double v) {
return null;
}
// Placeholder implementations: every method in this run returns null in this base class.
// NOTE(review): presumably subclasses override the ones they support; callers of an
// un-overridden method get null rather than an exception — confirm this is intended.

/** Placeholder for {@code cos}: returns {@code null} in this base class. */
@Override
public DifferentialFunction cos() {
return null;
}
/** Placeholder for {@code acos}: returns {@code null} in this base class. */
@Override
public DifferentialFunction acos() {
return null;
}
/** Placeholder for {@code cosh}: returns {@code null} in this base class. */
@Override
public DifferentialFunction cosh() {
return null;
}
/** Placeholder for {@code acosh}: returns {@code null} in this base class. */
@Override
public DifferentialFunction acosh() {
return null;
}
/** Placeholder for {@code sin}: returns {@code null} in this base class. */
@Override
public DifferentialFunction sin() {
return null;
}
/** Placeholder for {@code asin}: returns {@code null} in this base class. */
@Override
public DifferentialFunction asin() {
return null;
}
/** Placeholder for {@code sinh}: returns {@code null} in this base class. */
@Override
public DifferentialFunction sinh() {
return null;
}
/** Placeholder for {@code asinh}: returns {@code null} in this base class. */
@Override
public DifferentialFunction asinh() {
return null;
}
/** Placeholder for {@code tan}: returns {@code null} in this base class. */
@Override
public DifferentialFunction tan() {
return null;
}
/** Placeholder for {@code atan}: returns {@code null} in this base class. */
@Override
public DifferentialFunction atan() {
return null;
}
/** Placeholder for {@code tanh}: returns {@code null} in this base class. */
@Override
public DifferentialFunction tanh() {
return null;
}
/** Placeholder for {@code atanh}: returns {@code null} in this base class. */
@Override
public DifferentialFunction atanh() {
return null;
}
// Placeholder implementations: every method in this run returns null in this base class.
// NOTE(review): presumably subclasses override the ones they support; callers of an
// un-overridden method get null rather than an exception — confirm this is intended.

/** Placeholder for {@code exp}: returns {@code null} in this base class. */
@Override
public DifferentialFunction exp() {
return null;
}
/** Placeholder for {@code log}: returns {@code null} in this base class. */
@Override
public DifferentialFunction log() {
return null;
}
/** Placeholder for {@code log10}: returns {@code null} in this base class. */
@Override
public DifferentialFunction log10() {
return null;
}
/** Placeholder for {@code sgn}: returns {@code null} in this base class. */
@Override
public DifferentialFunction sgn() {
return null;
}
/** Placeholder for {@code pwr}: returns {@code null} in this base class. */
@Override
public DifferentialFunction pwr(DifferentialFunction y) {
return null;
}
/** Placeholder for {@code pwrs}: returns {@code null} in this base class. */
@Override
public DifferentialFunction pwrs(DifferentialFunction y) {
return null;
}
/** Placeholder for {@code square}: returns {@code null} in this base class. */
@Override
public DifferentialFunction square() {
return null;
}
// Placeholder implementations: every method in this run returns null in this base class.
// NOTE(review): presumably subclasses override the ones they support; callers of an
// un-overridden method get null rather than an exception — confirm this is intended.

/** Placeholder for {@code relu}: returns {@code null} in this base class. */
@Override
public DifferentialFunction relu() {
return null;
}
/** Placeholder for {@code hardTanh}: returns {@code null} in this base class. */
@Override
public DifferentialFunction hardTanh() {
return null;
}
/** Placeholder for {@code hardTanhDerivative}: returns {@code null} in this base class. */
@Override
public DifferentialFunction hardTanhDerivative() {
return null;
}
/** Placeholder for {@code leakyRelu()}: returns {@code null} in this base class. */
@Override
public DifferentialFunction leakyRelu() {
return null;
}
/** Placeholder for {@code elu}: returns {@code null} in this base class. */
@Override
public DifferentialFunction elu() {
return null;
}
/** Placeholder for {@code eluDerivative}: returns {@code null} in this base class. */
@Override
public DifferentialFunction eluDerivative() {
return null;
}
/** Placeholder for {@code leakyRelu(double)}: returns {@code null} in this base class. */
@Override
public DifferentialFunction leakyRelu(double cutoff) {
return null;
}
/** Placeholder for {@code leakyReluDerivative()}: returns {@code null} in this base class. */
@Override
public DifferentialFunction leakyReluDerivative() {
return null;
}
/** Placeholder for {@code leakyReluDerivative(double)}: returns {@code null} in this base class. */
@Override
public DifferentialFunction leakyReluDerivative(double cutoff) {
return null;
}
/** Placeholder for {@code sigmoid}: returns {@code null} in this base class. */
@Override
public DifferentialFunction sigmoid() {
return null;
}
/** Placeholder for {@code sigmoidDerivative}: returns {@code null} in this base class. */
@Override
public DifferentialFunction sigmoidDerivative() {
return null;
}
/** Placeholder for {@code step}: returns {@code null} in this base class. */
@Override
public DifferentialFunction step() {
return null;
}
/** Placeholder for {@code softsign}: returns {@code null} in this base class. */
@Override
public DifferentialFunction softsign() {
return null;
}
/** Placeholder for {@code softsignDerivative}: returns {@code null} in this base class. */
@Override
public DifferentialFunction softsignDerivative() {
return null;
}
/** Placeholder for {@code softmax}: returns {@code null} in this base class. */
@Override
public DifferentialFunction softmax() {
return null;
}
/** Placeholder for {@code softplus}: returns {@code null} in this base class. */
@Override
public DifferentialFunction softplus() {
return null;
}
// Placeholder implementations: every method in this run returns null in this base class.
// NOTE(review): presumably subclasses override the ones they support; callers of an
// un-overridden method get null rather than an exception — confirm this is intended.

/** Placeholder for {@code reshape}: returns {@code null} in this base class. */
@Override
public DifferentialFunction reshape(int[] shape) {
return null;
}
/** Placeholder for {@code transpose}: returns {@code null} in this base class. */
@Override
public DifferentialFunction transpose() {
return null;
}
/** Placeholder for {@code permute}: returns {@code null} in this base class. */
@Override
public DifferentialFunction permute(int[] dimensions) {
return null;
}
/** Placeholder for {@code expandDims}: returns {@code null} in this base class. */
@Override
public DifferentialFunction expandDims(int dim) {
return null;
}
// Placeholder implementations: every method in this run returns null in this base class.
// NOTE(review): presumably subclasses override the ones they support; callers of an
// un-overridden method get null rather than an exception — confirm this is intended.

/** Placeholder for {@code sum}: returns {@code null} in this base class. */
@Override
public DifferentialFunction sum(int[] dimensions) {
return null;
}
/** Placeholder for {@code prod(int[])}: returns {@code null} in this base class. */
@Override
public DifferentialFunction prod(int[] dimensions) {
return null;
}
/** Placeholder for {@code mean}: returns {@code null} in this base class. */
@Override
public DifferentialFunction mean(int[] dimensions) {
return null;
}
/** Placeholder for {@code std(int[], boolean)}: returns {@code null} in this base class. */
@Override
public DifferentialFunction std(int[] dimensions, boolean biasCorrected) {
return null;
}
/** Placeholder for {@code variance(int[], boolean)}: returns {@code null} in this base class. */
@Override
public DifferentialFunction variance(int[] dimensions, boolean biasCorrected) {
return null;
}
/** Placeholder for {@code std(int[])}: returns {@code null} in this base class. */
@Override
public DifferentialFunction std(int[] dimensions) {
return null;
}
/** Placeholder for {@code variance(int[])}: returns {@code null} in this base class. */
@Override
public DifferentialFunction variance(int[] dimensions) {
return null;
}
/** Placeholder for {@code max}: returns {@code null} in this base class. */
@Override
public DifferentialFunction max(int[] dimensions) {
return null;
}
/** Placeholder for {@code min}: returns {@code null} in this base class. */
@Override
public DifferentialFunction min(int[] dimensions) {
return null;
}
/** Placeholder for {@code norm1}: returns {@code null} in this base class. */
@Override
public DifferentialFunction norm1(int[] dimensions) {
return null;
}
/** Placeholder for {@code norm2}: returns {@code null} in this base class. */
@Override
public DifferentialFunction norm2(int[] dimensions) {
return null;
}
/** Placeholder for {@code normmax}: returns {@code null} in this base class. */
@Override
public DifferentialFunction normmax(int[] dimensions) {
return null;
}
// Placeholder implementations: every method in this run returns null in this base class.
// NOTE(review): presumably subclasses override the ones they support; callers of an
// un-overridden method get null rather than an exception — confirm this is intended.

/** Placeholder for {@code valueArrayOf}: returns {@code null} in this base class. */
@Override
public DifferentialFunction valueArrayOf(int[] shape) {
return null;
}
/** Placeholder for {@code tile}: returns {@code null} in this base class. */
@Override
public DifferentialFunction tile(int[] repeat) {
return null;
}
/** Placeholder for {@code repeat}: returns {@code null} in this base class. */
@Override
public DifferentialFunction repeat(int axis) {
return null;
}
/** Placeholder for {@code broadcast}: returns {@code null} in this base class. */
@Override
public DifferentialFunction broadcast(int[] shape) {
return null;
}
/** Placeholder for {@code eq}: returns {@code null} in this base class. */
@Override
public DifferentialFunction eq(DifferentialFunction i_y) {
return null;
}
/** Placeholder for {@code neq}: returns {@code null} in this base class. */
@Override
public DifferentialFunction neq(DifferentialFunction i_y) {
return null;
}
/** Placeholder for {@code or}: returns {@code null} in this base class. */
@Override
public DifferentialFunction or(DifferentialFunction i_y) {
return null;
}
/** Placeholder for {@code rollAxis}: returns {@code null} in this base class. */
@Override
public DifferentialFunction rollAxis(int axis) {
return null;
}
// Placeholder implementations: every method in this run returns null in this base class.
// NOTE(review): presumably subclasses override the ones they support; callers of an
// un-overridden method get null rather than an exception — confirm this is intended.

/** Placeholder for {@code cosineSimilarity}: returns {@code null} in this base class. */
@Override
public DifferentialFunction cosineSimilarity(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code euclideanDistance}: returns {@code null} in this base class. */
@Override
public DifferentialFunction euclideanDistance(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code manhattanDistance}: returns {@code null} in this base class. */
@Override
public DifferentialFunction manhattanDistance(DifferentialFunction i_y, int... dimensions) {
return null;
}
// Placeholder implementations: every loss method in this run returns null in this base class.
// NOTE(review): presumably subclasses override the ones they support; callers of an
// un-overridden method get null rather than an exception — confirm this is intended.

/** Placeholder for {@code lossBinaryXENT}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossBinaryXENT(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossCosineSimilarity}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossCosineSimilarity(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossHinge}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossHinge(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossKLD}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossKLD(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossL1}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossL1(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossL2}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossL2(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossMAE}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossMAE(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossMAPE}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossMAPE(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossMSE}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossMSE(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossMCXENT}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossMCXENT(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossMSLE}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossMSLE(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossNegativeLogLikelihood}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossNegativeLogLikelihood(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossPoisson}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossPoisson(DifferentialFunction i_y, int... dimensions) {
return null;
}
/** Placeholder for {@code lossSquaredHinge}: returns {@code null} in this base class. */
@Override
public DifferentialFunction lossSquaredHinge(DifferentialFunction i_y, int... dimensions) {
return null;
}
/**
 * Registers the result of a two-input operation in the graph of {@code sameDiff}.
 * <p>
 * If the value of {@code i_v1} is not an {@link ArrayField} the method returns without
 * doing anything. Otherwise it creates a result vertex, adds one edge from each input
 * vertex to it, and stores the new vertex id and an op state on this function.
 * When both input values are equal, a duplicate vertex is created and used as the
 * source of the second edge.
 *
 * @param sameDiff the SameDiff instance whose graph receives the vertex and edges
 * @param i_v1 the first input function
 * @param i_v2 the second input function
 * @param opName the operation name, used in generated ids and in the op states
 * @param opType the operation type stored in the op states
 * @param shape the shape of the result; its product is stored as {@code n}
 * @param extraArgs extra arguments stored in the op states; may be {@code null}
 * @throws ND4JIllegalStateException if the id of the new vertex equals the id of
 *         either input vertex
 */
protected void addEdges(SameDiff sameDiff,
DifferentialFunction i_v1,
DifferentialFunction i_v2,
String opName,
OpState.OpType opType,
int[] shape, Object[] extraArgs) {
// Non-ArrayField values are silently ignored (no else branch).
if(i_v1.getValue(true) instanceof ArrayField) {
validateDifferentialFunctionGraph(i_v1);
validateDifferentialFunctionGraph(i_v2);
/*
 * getValue() generates invalid vertex ids
 * need to look at a way of getting the proper vertex
 * metadata
 *
 * Should be looking at a way to derive the vertex id
 * for each of these equations.
 */
// Vertex ids are therefore taken from resultVertexId(), not from the values.
ArrayField v1 = (ArrayField) i_v1.getValue(true);
int v1VertexId = i_v1.resultVertexId();
ArrayField v2 = (ArrayField) i_v2.getValue(true);
int v2VertexId = i_v2.resultVertexId();
// Metadata of the result array: random array id, readable id built from both input ids.
NDArrayInformation arrInfo = NDArrayInformation.builder()
.arrId(UUID.randomUUID().toString())
.id(opName +"(" + v1.getInput().getId() + "," + v2.getInput().getId() + ")")
.shape(shape).build();
//result
NDArrayVertex newVertex = new NDArrayVertex(sameDiff.getGraph().nextVertexId(), arrInfo);
// The result vertex must not reuse the id of one of its inputs.
if(newVertex.vertexID() == v2VertexId || newVertex.vertexID() == v1VertexId)
throw new ND4JIllegalStateException("Illegal vertex id specified in new vertex." +
" Perhaps a mismatched graph call? Another likely cause is applyGraph");
// This function now points at the result vertex.
this.vertexId = newVertex.vertexID();
//add the result vertex
sameDiff.getGraph().addVertex(newVertex);
// opState belongs to the edge coming from the second input, opState2 to the first.
OpState opState,opState2;
//ensure there's 2 vertices for when the 2 inputs are the same
if(v1.equals(v2)) {
// Duplicate vertex carrying the shape and id of the first input.
NDArrayVertex dupVertex = new NDArrayVertex(sameDiff.getGraph().nextVertexId(),
NDArrayInformation.builder()
.shape(v1.getInput().getShape())
.id(v1.getInput().getId()).build());
//update vertex id
v2VertexId = dupVertex.vertexID();
sameDiff.getGraph().addVertex(dupVertex);
opState = OpState.builder()
.opType(opType)
.opName(opName)
.id(opName + "(" + dupVertex.getValue().getId() + " -> " + newVertex.getValue().getId() + ")")
.vertexIds(new String[]{String.valueOf(v2VertexId),String.valueOf(newVertex.vertexID())})
.n(ArrayUtil.prod(shape))
.extraArgs(extraArgs)
.result(arrInfo)
.build();
}
else {
// NOTE(review): the id below is built from v1's vertex although vertexIds refers to
// the second input (v2VertexId) — possibly a copy/paste slip; confirm before changing.
opState = OpState.builder()
.opType(opType)
.opName(opName)
.id(opName + "(" + v1.getVertex().getValue().getId() + " -> " + newVertex.getValue().getId() + ")")
.vertexIds(new String[]{String.valueOf(v2VertexId),String.valueOf(newVertex.vertexID())})
.n(ArrayUtil.prod(shape))
.extraArgs(extraArgs)
.result(arrInfo)
.build();
}
// NOTE(review): result(arrInfo) is set twice on this builder, with the same value.
opState2 = OpState.builder()
.opType(opType)
.opName(opName).result(arrInfo)
.id(opName + "(" + v1.getVertex().getValue().getId() + " -> " + newVertex.getValue().getId() + ")")
.vertexIds(new String[]{String.valueOf(v1VertexId),String.valueOf(newVertex.vertexID())})
.n(ArrayUtil.prod(shape))
.extraArgs(extraArgs)
.result(arrInfo)
.build();
//add the first vertex no matter what as normal
// The meaning of the trailing 'true' argument is not visible here — TODO confirm.
sameDiff.getGraph().addEdge(v1VertexId,
newVertex.vertexID(),
opState2,true);
sameDiff.getGraph().addEdge(v2VertexId,
newVertex.vertexID(),
opState
,true);
// The result vertex and its array info are owned by the first-input op state,
// while this function keeps the second-input op state.
newVertex.setOpState(opState2);
arrInfo.setOwner(opState2);
this.opState = opState;
}
}
/**
 * Adds the edges for a two-input operation using {@code OpState.OpType.TRANSFORM}
 * and the input shape of the first function's value as result shape.
 *
 * @param sameDiff the SameDiff instance whose graph is extended
 * @param i_v1 the first input function; its value must be an {@link ArrayField}
 * @param i_v2 the second input function
 * @param opName the name of the operation
 * @throws IllegalStateException if either function uses a different SameDiff instance
 * @throws UnsupportedOperationException if the value of {@code i_v1} is not an {@code ArrayField}
 */
protected void addEdges(SameDiff sameDiff,
        DifferentialFunction i_v1,
        DifferentialFunction i_v2,
        String opName) {
    validateDifferentialFunctionGraph(i_v1);
    validateDifferentialFunctionGraph(i_v2);

    // Guard clause: anything other than an ArrayField value is rejected.
    if (!(i_v1.getValue(true) instanceof ArrayField)) {
        throw new UnsupportedOperationException("Only supporting array fields");
    }

    // The value is requested a second time here, as in the type check above.
    ArrayField firstValue = (ArrayField) i_v1.getValue(true);
    int[] resultShape = firstValue.getInput().getShape();
    addEdges(sameDiff, i_v1, i_v2, opName, OpState.OpType.TRANSFORM, resultShape);
}
/**
 * Scans the extra arguments for a {@link Boolean} and returns the first one found.
 *
 * @param extraArgs the arguments to scan; may be {@code null}
 * @return the value of the first {@code Boolean} element, or {@code false} when
 *         {@code extraArgs} is {@code null} or holds no {@code Boolean}
 */
protected boolean getInPlace(Object[] extraArgs) {
    if (extraArgs != null) {
        for (Object candidate : extraArgs) {
            if (candidate instanceof Boolean) {
                return (Boolean) candidate;
            }
        }
    }
    return false;
}
/**
 * Duplication hook implemented by subclasses.
 * Presumably returns a copy of this function — confirm copy depth against the subclasses.
 *
 * @return the duplicate
 */
public abstract DifferentialFunction dup();
/**
 * Returns the id of the vertex holding this function's result.
 * This base implementation returns the {@code vertexId} field.
 *
 * @return the current {@code vertexId}
 */
public int resultVertexId() {
return vertexId;
}
/**
 * Verifies that {@code function} is non-null and bound to the same {@link SameDiff}
 * instance (compared by reference) as this function.
 *
 * @param function the function to check
 * @throws IllegalStateException if {@code function} is {@code null} or uses a
 *         different SameDiff instance
 */
protected void validateDifferentialFunctionsameDiff(
        DifferentialFunction function) {
    Preconditions.checkState(function != null, "Passed in function was null.");
    // Template arguments are only formatted on failure, so neither toString() is
    // evaluated when the check passes. Comparing this.sameDiff with its own getter
    // cannot fail, so only the cross-function comparison is performed.
    Preconditions.checkState(function.getSameDiff() == this.getSameDiff(),
            "Function applications must be contained in same sameDiff. The left %s must match this function %s",
            function, this);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy