All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.functions.mmul.Mmul Maven / Gradle / Ivy

package org.nd4j.autodiff.functions.mmul;

import org.nd4j.autodiff.ArrayField;
import org.nd4j.autodiff.Field;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.graph.Graph;
import org.nd4j.autodiff.opstate.NDArrayInformation;
import org.nd4j.autodiff.opstate.OpState;
import org.nd4j.autodiff.samediff.SDGraph;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.shape.Shape;

import java.lang.reflect.Array;

/**
 * Specialized matrix multiply operation.
 * Many people know this as "gemm".
 */

public class Mmul extends TensorMmul {


    /**
     * Creates a matrix multiply of {@code i_v1} and {@code i_v2}.
     *
     * <p>Delegates to the {@code TensorMmul} constructor with the fixed axis
     * specification {@code {{1},{0}}}. NOTE(review): presumably this contracts
     * axis 1 of the left operand with axis 0 of the right operand (an ordinary
     * matrix product) - confirm against {@code TensorMmul}.
     *
     * @param sameDiff the {@code SameDiff} instance, passed through to the superclass
     * @param i_v1     the left operand
     * @param i_v2     the right operand
     * @param argNum   passed through unchanged to the superclass
     */
    public Mmul(SameDiff sameDiff,
                DifferentialFunction i_v1,
                DifferentialFunction i_v2,
                int argNum) {
        super(sameDiff,
                i_v1,
                i_v2, new int[][] {
                {1},{0}
        },argNum);
    }



    /**
     * Registers the graph edges for this operation as an
     * {@code OpState.OpType.ACCUMULATION} op whose shape is computed by
     * {@code Shape.getMatrixMultiplyShape} from the shapes of the two operands' inputs.
     *
     * <p>Both operands must evaluate to an {@link ArrayField}. The right operand
     * is validated as well as the left one, so an unsupported right-hand value is
     * reported with the same {@link UnsupportedOperationException} rather than
     * surfacing later as a {@code NullPointerException} or {@code ClassCastException}.
     *
     * @param sameDiff the {@code SameDiff} instance the edges are added to
     * @param i_v1     the left operand
     * @param i_v2     the right operand
     * @param opName   the name of the operation
     * @throws UnsupportedOperationException if either operand's value is not an {@link ArrayField}
     */
    @Override
    protected void addEdges(SameDiff sameDiff,
                            DifferentialFunction i_v1,
                            DifferentialFunction i_v2,
                            String opName) {
        if (!(i_v1.getValue(true) instanceof ArrayField)) {
            throw new UnsupportedOperationException("Only supporting array fields");
        }

        // getValue(true) is deliberately invoked again here (not cached from the
        // check above) so the number and order of calls stay as they always were;
        // whether the call has side effects is not visible from this class.
        ArrayField leftField = i_v1.getValue(true);

        // Held as Object first so the type test runs before any narrowing.
        // instanceof is false for null, so a missing value is rejected too.
        Object rightValue = i_v2.getValue(true);
        if (!(rightValue instanceof ArrayField)) {
            throw new UnsupportedOperationException("Only supporting array fields");
        }
        ArrayField rightField = (ArrayField) rightValue;

        addEdges(sameDiff, i_v1, i_v2, opName,
                OpState.OpType.ACCUMULATION,
                Shape.getMatrixMultiplyShape(
                        leftField.getInput().getShape(),
                        rightField.getInput().getShape()));
    }



    /**
     * Get the value of this function.
     *
     * @return the result of calling {@code mmul} on the array factory of
     *         {@code sameDiff} with {@code larg()} and {@code rarg()}
     *         (presumably the left and right arguments of this function)
     */
    @Override
    public ArrayField doGetValue() {
        return sameDiff.getArrayFactory().mmul(larg(),rarg());
    }


    /**
     * Returns the name of this function.
     *
     * @return the constant string {@code "mmul"}
     */
    @Override
    public String functionName() {
        return "mmul";
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy