
// org.nd4j.autodiff.functions.mmul.Mmul (source listing header: Maven / Gradle / Ivy)
package org.nd4j.autodiff.functions.mmul;
import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.graph.Graph;
import org.nd4j.autodiff.opstate.NDArrayInformation;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SDGraph;
import org.nd4j.linalg.api.shape.Shape;
/**
* Specialized matrix multiply operations.
* Many people know this as "gemm"
*
*
*/
public class Mmul> extends TensorMmul {
public Mmul(SDGraph graph,
DifferentialFunction i_v1,
DifferentialFunction i_v2,
DifferentialFunctionFactory differentialFunctionFactory,
int argNum) {
super(graph,
i_v1,
i_v2,
differentialFunctionFactory,new int[][] {
{1},{0}
},argNum);
}
@Override
protected void addEdges(Graph graph,
DifferentialFunction i_v1,
DifferentialFunction i_v2,
String opName) {
if(i_v1.getValue(true) instanceof ArrayField) {
ArrayField arrayField = (ArrayField) i_v1.getValue(true);
ArrayField secondVal = (ArrayField) i_v2.getValue(true);
//skip empty dimensions
addEdges(graph,i_v1,i_v2,opName,
OpState.OpType.ACCUMULATION,
Shape.getMatrixMultiplyShape(arrayField.getInput().getShape(),secondVal.getInput().getShape()));
}
else
throw new UnsupportedOperationException("Only supporting array fields");
}
/**
* Get the value of this function
*
* @return
*/
@Override
protected X doGetValue() {
return differentialFunctionFactory.getMFactory().mmul(larg(),rarg());
}
@Override
public String functionName() {
return "mmul";
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy