org.nd4j.autodiff.functions.AbstractUnaryFunction Maven / Gradle / Ivy
package org.nd4j.autodiff.functions;
import com.rits.cloning.Cloner;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.opstate.NDArrayInformation;
import org.nd4j.autodiff.opstate.NDArrayVertex;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SDGraph;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.List;
import java.util.UUID;
@Data
@NoArgsConstructor
public abstract class AbstractUnaryFunction> extends DifferentialFunction {
protected DifferentialFunction m_x;
protected int[] shape;
protected OpState.OpType opType;
public AbstractUnaryFunction(SameDiff sameDiff,
DifferentialFunction i_v,
int[] shape,
OpState.OpType opType,
Object[] extraArgs) {
super(sameDiff,extraArgs);
this.opType = opType;
if (i_v != null) {
m_x = i_v;
validateDifferentialFunctionsameDiff(i_v);
addEdges(sameDiff,m_x,functionName(),shape);
} else {
throw new IllegalArgumentException("Input not null variable.");
}
}
public AbstractUnaryFunction(SameDiff sameDiff,
DifferentialFunction i_v,
int[] shape,
Object[] extraArgs) {
this(sameDiff,i_v,shape, OpState.OpType.TRANSFORM,extraArgs);
}
public AbstractUnaryFunction(SameDiff sameDiff,
DifferentialFunction i_v,
Object[] extraArgs) {
super(sameDiff,extraArgs);
if (i_v != null) {
m_x = i_v;
validateDifferentialFunctionsameDiff(i_v);
addEdges(sameDiff,m_x,functionName());
} else {
throw new IllegalArgumentException("Input not null variable.");
}
}
@Override
public String doGetFormula(List> variables) {
return functionName() + "(" + arg().doGetFormula(variables) + ")";
}
@Override
public String toString() {
return functionName() + "(" + arg().toString() + ")";
}
/**
* Add nodes to the graph
* @param sameDiff
* @param i_v1
* @param opName
*/
protected void addEdges(SameDiff sameDiff,
DifferentialFunction i_v1,
String opName,
int...shape) {
if(i_v1.getValue(true) instanceof ArrayField) {
ArrayField v1 = (ArrayField) i_v1.getValue(true);
NDArrayInformation information = NDArrayInformation.builder()
.arrId(UUID.randomUUID().toString())
.id(opName + "(" + v1.getInput().getId() + " -> " +
v1.getInput().getId() + ")")
.shape(shape).build();
//result
NDArrayVertex newVertex = new NDArrayVertex(sameDiff.getGraph().nextVertexId(), information);
this.vertexId = newVertex.vertexID();
sameDiff.getGraph().addVertex(newVertex);
OpState owner = OpState.builder()
.opType(opType)
.opName(opName).extraArgs(extraArgs)
.id(opName + "(" + v1.getInput().getId() + " -> " + newVertex.getValue().getId() + ")")
.vertexIds(new String[]{String.valueOf(v1.getVertex().vertexID()),String.valueOf(newVertex.vertexID())})
.n(ArrayUtil.prod(shape)).result(information)
.build();
sameDiff.getGraph().addEdge(v1.getVertex().vertexID(),newVertex.vertexID(),owner,true);
newVertex.setOpState(owner);
information.setOwner(owner);
owner.setResult(information);
if(owner.isInPlace()) {
information.setArrId(v1.getInput().getArrId());
}
this.opState = owner;
}
}
/**
* Add nodes to the graph
* @param sameDiff
* @param i_v1
* @param opName
*/
protected void addEdges(SameDiff sameDiff,
DifferentialFunction i_v1,
String opName) {
if(i_v1.getValue(true) instanceof ArrayField) {
this.opType = OpState.OpType.TRANSFORM;
ArrayField arrayField = (ArrayField) i_v1.getValue(true);
addEdges(sameDiff,
i_v1,
opName,
arrayField.getInput().getShape());
}
}
@Override
public DifferentialFunction[] args() {
return new DifferentialFunction[] {arg()};
}
@Override
public DifferentialFunction arg() {
return m_x;
}
@Override
public DifferentialFunction dup() {
Cloner cloner = new Cloner();
return cloner.deepClone(this);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy