org.nd4j.autodiff.functions.AbstractBinaryReduceFunction Maven / Gradle / Ivy
package org.nd4j.autodiff.functions;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.graph.Graph;
import org.nd4j.autodiff.opstate.NDArrayInformation;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SDGraph;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.shape.Shape;
import java.util.List;
/**
* Base class for two-argument functions that carry a set of reduction dimensions.
* <p>
* When edges are registered, the op is recorded with
* {@code OpState.OpType.ACCUMULATION} and a shape computed by
* {@code Shape.getReducedShape} from the first argument's input shape and
* {@link #dimensions}.
* <p>
* NOTE(review): the stray {@code >} in the class header and in the parameter list
* of {@code doGetFormula} look like remnants of generic type parameters that were
* lost - TODO restore them from the original source.
* <p>
* Created by agibsonccc on 4/12/17.
*/
@NoArgsConstructor
public abstract class AbstractBinaryReduceFunction> extends AbstractBinaryFunction {
//Dimensions handed to Shape.getReducedShape in addEdges; stays null when the
//instance is built through a constructor that does not take dimensions.
protected int[] dimensions;
/**
* Creates the function over two arguments and registers its edges.
*
* @param sameDiff   passed to the superclass constructor and to addEdges
* @param i_v1       first argument; addEdges requires its value to be an ArrayField
* @param i_v2       second argument
* @param dimensions dimensions stored on this instance and used to compute the
*                   reduced shape
*/
public AbstractBinaryReduceFunction(SameDiff sameDiff,
DifferentialFunction i_v1,
DifferentialFunction i_v2,
int...dimensions) {
super(sameDiff, i_v1, i_v2);
this.dimensions = dimensions;
//Register the edges now that dimensions has been assigned.
//Note that the call below won't trigger if dimensions are null (addEdges returns early).
//NOTE(review): presumably the superclass constructor already attempts this before
//dimensions is set, which is why it is repeated here - confirm.
//Please don't remove this.
addEdges(sameDiff,i_v1,
i_v2,functionName());
}
/**
* Creates an instance without arguments or dimensions; no edges are registered
* by this constructor itself.
*
* @param sameDiff passed to the superclass constructor
*/
public AbstractBinaryReduceFunction(SameDiff sameDiff) {
super(sameDiff);
}
/**
* Registers this function's edges as an accumulation whose shape is the first
* argument's input shape reduced along {@link #dimensions}.
* Does nothing when {@link #dimensions} is null.
*
* @param sameDiff forwarded to the six-argument addEdges overload
* @param i_v1     first argument; its value must be an ArrayField
* @param i_v2     second argument
* @param opName   name forwarded to the six-argument addEdges overload
* @throws UnsupportedOperationException if the value of i_v1 is not an ArrayField
*/
@Override
protected void addEdges(SameDiff sameDiff,
DifferentialFunction i_v1,
DifferentialFunction i_v2,
String opName) {
if(i_v1.getValue(true) instanceof ArrayField) {
ArrayField arrayField = i_v1.getValue(true);
//Skip when no dimensions have been assigned yet; the constructor calls this
//method again after setting them.
if(dimensions == null)
return;
//Delegate to the overload that also takes the op type and the result shape
//(declared outside this class).
addEdges(sameDiff,i_v1,i_v2,opName,
OpState.OpType.ACCUMULATION,
Shape.getReducedShape(arrayField.getInput().getShape(),
dimensions));
}
else
throw new UnsupportedOperationException("Only supporting array fields");
}
/**
* Returns the formula text for this function, which is simply {@link #toString()}.
*
* @param variables not used by this implementation
* @return the same string as {@link #toString()}
*/
@Override
public String doGetFormula(List> variables) {
return toString();
}
/**
* Not supported by this class.
*
* @throws UnsupportedOperationException always
*/
@Override
public double getReal() {
throw new UnsupportedOperationException();
}
/**
* Renders the function as {@code functionName(larg,rarg)}.
*
* @return the function name followed by both arguments in parentheses,
*         separated by a comma
*/
@Override
public String toString() {
return functionName() + "(" + larg() + "," + rarg() + ")";
}
/**
* Creates a copy by reflectively invoking a public three-argument constructor of
* the runtime class with this instance's sameDiff, left and right arguments.
* <p>
* NOTE(review): the constructor is looked up with the runtime classes of the
* arguments, while Class.getConstructor matches declared parameter types exactly,
* so the lookup only succeeds if the concrete class declares a constructor with
* precisely those types. {@link #dimensions} is also not forwarded to the copy.
* Verify against the concrete subclasses.
*
* @return the newly constructed instance
* @throws RuntimeException wrapping any exception raised by the lookup or the
*                          instantiation
*/
@Override
public DifferentialFunction dup() {
try {
return getClass().getConstructor(sameDiff.getClass(),larg()
.getClass(),rarg().getClass()).newInstance(sameDiff,larg(),rarg());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy