org.nd4j.autodiff.functions.Constant Maven / Gradle / Ivy
package org.nd4j.autodiff.functions;
import java.util.List;
import lombok.Data;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.samediff.SameDiff;
@Data
public class Constant> extends DifferentialFunction {
protected X m_x;
protected int[] shape;
protected Constant(SameDiff sameDiff,
X i_v,
int[] shape,
boolean inPlace) {
super(sameDiff,new Object[]{i_v,inPlace});
this.shape = shape;
if (i_v != null) {
m_x = i_v;
} else {
throw new IllegalArgumentException("Input not null value.");
}
if(i_v instanceof ArrayField) {
ArrayField arrayField = (ArrayField) i_v;
this.vertexId = arrayField.getVertex().vertexID();
if(sameDiff.getGraph().getVertex(this.vertexId) == null)
sameDiff.getGraph().addVertex(arrayField.getVertex());
}
}
protected Constant(SameDiff sameDiff,
X i_v,
int[] shape) {
this(sameDiff,i_v,shape,false);
}
public Constant(SameDiff sameDiff, ArrayField one) {
}
/**
* Get the result shape for this function
*
* @return
*/
@Override
public int[] getResultShape() {
return shape;
}
@Override
public boolean isConstant() {
return true;
}
@Override
public X doGetValue() {
return m_x;
}
@Override
public double getReal() {
return m_x.getReal();
}
@Override
public DifferentialFunction[] args() {
return new DifferentialFunction[] {this};
}
@Override
public DifferentialFunction arg() {
return this;
}
@Override
public DifferentialFunction diff(DifferentialFunction i_v) {
return new Zero<>(sameDiff,shape);
}
@Override
public String toString() {
return getValue(true).toString();
}
@Override
public String doGetFormula(List> variables) {
return getValue(true).toString();
}
@Override
public String functionName() {
return "constant";
}
@Override
public Constant inverse() {
Constant ret = new Constant<>(sameDiff, m_x.inverse(),shape);
return ret;
}
@Override
public Constant negate() {
Constant ret = new Constant<>(sameDiff, m_x.negate(),shape);
return ret;
}
@Override
public DifferentialFunction dup() {
return new Constant<>(sameDiff,m_x,shape);
}
// This class must be immutable.
// set and assign must not be implemented.
@SuppressWarnings("unused")
private final void set(X i_x) {
}
@SuppressWarnings("unused")
private final void assign(X i_x) {
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy