org.nd4j.autodiff.functions.Zero Maven / Gradle / Ivy
package org.nd4j.autodiff.functions;
import org.nd4j.autodiff.AbstractIdentityFactory;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SDGraph;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;
public class Zero> extends Constant {
public Zero(SameDiff sameDiff, int[] shape) {
super(sameDiff, (X) sameDiff.getArrayFactory().zero(shape),shape);
ArrayField arrayField = (ArrayField) m_x;
arrayField.getInput().setScalarValue(0.0);
}
@Override
public DifferentialFunction add(DifferentialFunction i_v) {
addEdge(new AddOp().name(),i_v);
return i_v;
}
@Override
public Zero mul(DifferentialFunction i_v) {
addEdge(new MulOp().name(),i_v);
return this;
}
@Override
public Constant inverse() {
// TODO
throw new UnsupportedOperationException();
}
@Override
public Zero negate() {
addEdge(new org.nd4j.linalg.api.ops.impl.transforms.Negative().name(),this);
return this;
}
private void addEdge(String opName,DifferentialFunction i_v) {
if(i_v.getValue(true) instanceof ArrayField) {
ArrayField x = (ArrayField) i_v.getValue(true);
addEdges(sameDiff,
this,
i_v,
opName,
OpState.OpType.TRANSFORM,
x.getInput().getShape(),
null);
}
}
@Override
public DifferentialFunction dup() {
return new Zero<>(sameDiff,shape);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy