All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.cpu.nativecpu.ops;


import lombok.Getter;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.cpu.nativecpu.CpuTADManager;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;

import java.util.Arrays;


/**
 * Op executioner that runs operations through the native (C++) backend
 * exposed by {@link NativeOps}.
 *
 * @author Adam Gibson
 */

public class NativeOpExecutioner extends DefaultOpExecutioner {
    // Handle to the native op implementations; every loop.exec* call in this class goes through it.
    private NativeOps loop = new NativeOps();
    // Source of constant buffers; used in this class to obtain the buffer backing a dimension array.
    private ConstantHandler constantHandler = Nd4j.getConstantHandler();
    // Supplies the TAD shape-info / offsets buffer pair for dimension-wise ops; exposed through a Lombok getter.
    @Getter private CpuTADManager tadManager = new CpuTADManager();

    /**
     * Creates the executioner and initialises its TAD manager with this instance's
     * native ops handle and constant handler.
     */
    public NativeOpExecutioner() {
        this.tadManager.init(this.loop, this.constantHandler);
    }

    /**
     * Dispatches an op to the matching type-specific {@code exec} overload.
     *
     * <p>The type checks run in a fixed order (scalar, transform, accumulation,
     * index accumulation, broadcast) and the first match wins. An op matching none
     * of these types is returned untouched.
     *
     * @param op the op to execute
     * @return the same {@code op} instance that was passed in
     */
    @Override
    public Op exec(Op op) {
        if (op instanceof ScalarOp) {
            exec((ScalarOp) op);
            return op;
        }
        if (op instanceof TransformOp) {
            exec((TransformOp) op);
            return op;
        }
        if (op instanceof Accumulation) {
            exec((Accumulation) op);
            return op;
        }
        if (op instanceof IndexAccumulation) {
            exec((IndexAccumulation) op);
            return op;
        }
        if (op instanceof BroadcastOp) {
            BroadcastOp broadcast = (BroadcastOp) op;
            exec(broadcast, broadcast.getDimension());
        }
        return op;
    }


    /**
     * Executes an index accumulation along the given dimension(s) through the native backend.
     *
     * <p>Negative entries of {@code dimension} are resolved against the rank of {@code op.x()}
     * <em>before</em> the indices are sorted, so mixed input such as {@code {-1, 0}} ends up
     * in ascending order. The work is done on a copy; the caller's array is left untouched.
     * When as many dimensions are given as {@code op.x()} has, the op is run over the whole
     * array, which is signalled with a single {@code Integer.MAX_VALUE} dimension.
     *
     * @param op        the index accumulation to run; its {@code z} is replaced by a newly
     *                  allocated result array
     * @param dimension the dimension(s) to reduce along
     * @return {@code op.x()} itself when it is a vector whose length already equals the
     *         length of the result shape, otherwise the result array ({@code op.z()})
     */
    @Override
    public INDArray exec(IndexAccumulation op, int... dimension) {
        final INDArray input = op.x();

        // Normalise first, sort second: sorting before the negative indices are resolved
        // can leave the final array out of order.
        int[] dims = dimension.clone();
        for (int i = 0; i < dims.length; i++) {
            if (dims[i] < 0) {
                dims[i] += input.rank();
            }
        }
        Arrays.sort(dims);

        // do op along all dimensions
        if (dims.length == input.rank()) {
            dims = new int[] {Integer.MAX_VALUE};
        }

        int[] retShape = Shape.wholeArrayDimension(dims)
                        ? new int[] {1, 1}
                        : ArrayUtil.removeIndex(input.shape(), dims);
        if (input.isVector() && input.length() == ArrayUtil.prod(retShape)) {
            return input;
        }

        // ensure vector is proper shape
        if (retShape.length == 1) {
            retShape = dims[0] == 0 ? new int[] {1, retShape[0]} : new int[] {retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[] {1, 1};
        }

        INDArray ret = Nd4j.valueArrayOf(retShape, op.zeroDouble());
        op.setZ(ret);

        Pointer dimensionAddress = constantHandler.getConstantBuffer(dims).addressPointer();

        Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(input, dims);
        Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer();
        // The offsets buffer may be absent, in which case a null pointer is handed to native code.
        DataBuffer offsets = tadBuffers.getSecond();
        Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer();

        PointerPointer dummy = new PointerPointer(hostTadShapeInfo, hostTadOffsets);

        Pointer x = input.data().addressPointer();
        Pointer xShapeInfo = input.shapeInfoDataBuffer().addressPointer();
        Pointer extraArgs = getPointerForExtraArgs(op);
        Pointer z = op.z().data().addressPointer();
        Pointer zShapeInfo = op.z().shapeInfoDataBuffer().addressPointer();

        if (input.data().dataType() == DataBuffer.Type.DOUBLE) {
            loop.execIndexReduceDouble(
                    dummy,
                    op.opNum(),
                    x,
                    xShapeInfo,
                    extraArgs,
                    z,
                    zShapeInfo,
                    dimensionAddress, dims.length);
        } else {
            loop.execIndexReduceFloat(
                    dummy,
                    op.opNum(),
                    x,
                    xShapeInfo,
                    extraArgs,
                    z,
                    zShapeInfo,
                    dimensionAddress, dims.length);
        }
        return op.z();
    }



    /**
     * Executes an accumulation (reduction) along the given dimension(s) through the native backend.
     *
     * <p>Negative entries of {@code dimension} are resolved against the rank of {@code op.x()}
     * <em>before</em> the indices are sorted, and the work is done on a copy so the caller's
     * array is left untouched. When as many dimensions are given as {@code op.x()} has, the
     * reduction runs over the whole array (signalled with a single {@code Integer.MAX_VALUE}
     * dimension).
     *
     * <p>Three native families are used, chosen in this order: summary statistics for
     * {@link Variance} ops, the two-input reduce when {@code op.y()} is set, and the plain
     * reduce otherwise. Each has a scalar-returning variant that is used when the result
     * array is a scalar. The bias-correction flag of a {@link Variance} op is honoured for
     * both double and float data.
     *
     * @param op        the accumulation to run; its {@code z} is replaced by a newly allocated
     *                  result array unless the no-op shortcut is taken
     * @param dimension the dimension(s) to reduce along
     * @return {@code op.noOp()} when {@code op.x()} is a vector whose length already equals the
     *         length of the result shape, otherwise the result array
     */
    @Override
    public INDArray exec(Accumulation op, int... dimension) {
        final INDArray input = op.x();

        // Normalise first, sort second: sorting before the negative indices are resolved
        // can leave the final array out of order.
        int[] dims = dimension.clone();
        for (int i = 0; i < dims.length; i++) {
            if (dims[i] < 0) {
                dims[i] += input.rank();
            }
        }
        Arrays.sort(dims);

        // do op along all dimensions
        if (dims.length == input.rank()) {
            dims = new int[] {Integer.MAX_VALUE};
        }

        int[] retShape = Shape.wholeArrayDimension(dims)
                        ? new int[] {1, 1}
                        : ArrayUtil.removeIndex(input.shape(), dims);
        // ensure vector is proper shape
        if (retShape.length == 1) {
            retShape = dims[0] == 0 ? new int[] {1, retShape[0]} : new int[] {retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[] {1, 1};
        }

        if (input.isVector() && input.length() == ArrayUtil.prod(retShape)) {
            return op.noOp();
        }

        INDArray ret = Nd4j.valueArrayOf(retShape, op.zeroDouble());
        op.setZ(ret);

        Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(input, dims);
        Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer();
        // The offsets buffer may be absent, in which case a null pointer is handed to native code.
        DataBuffer offsets = tadBuffers.getSecond();
        Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer();

        PointerPointer dummy = new PointerPointer(hostTadShapeInfo, hostTadOffsets);

        Pointer dimensionAddress = constantHandler.getConstantBuffer(dims).addressPointer();

        final boolean isDouble = input.data().dataType() == DataBuffer.Type.DOUBLE;
        final boolean scalarResult = ret.isScalar();

        Pointer x = input.data().addressPointer();
        Pointer xShapeInfo = input.shapeInfoDataBuffer().addressPointer();
        Pointer extraArgs = getPointerForExtraArgs(op);
        Pointer z = op.z().data().addressPointer();
        Pointer zShapeInfo = op.z().shapeInfoDataBuffer().addressPointer();

        if (op instanceof Variance) {
            // Use the op's own flag for every data type / result shape combination.
            final boolean biasCorrected = ((Variance) op).isBiasCorrected();
            if (isDouble) {
                if (scalarResult) {
                    ret.putScalar(0, loop.execSummaryStatsScalarDouble(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs, biasCorrected));
                } else {
                    loop.execSummaryStatsDouble(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs,
                            z, zShapeInfo, dimensionAddress, dims.length, biasCorrected);
                }
            } else {
                if (scalarResult) {
                    ret.putScalar(0, loop.execSummaryStatsScalarFloat(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs, biasCorrected));
                } else {
                    loop.execSummaryStatsFloat(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs,
                            z, zShapeInfo, dimensionAddress, dims.length, biasCorrected);
                }
            }
        } else if (op.y() != null) {
            Pointer y = op.y().data().addressPointer();
            Pointer yShapeInfo = op.y().shapeInfoDataBuffer().addressPointer();
            if (isDouble) {
                if (scalarResult) {
                    ret.putScalar(0, loop.execReduce3ScalarDouble(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo));
                } else {
                    loop.execReduce3Double(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo,
                            z, zShapeInfo, dimensionAddress, dims.length);
                }
            } else {
                if (scalarResult) {
                    ret.putScalar(0, loop.execReduce3ScalarFloat(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo));
                } else {
                    loop.execReduce3Float(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo,
                            z, zShapeInfo, dimensionAddress, dims.length);
                }
            }
        } else {
            if (isDouble) {
                if (scalarResult) {
                    ret.putScalar(0, loop.execReduceScalarDouble(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs));
                } else {
                    loop.execReduceDouble(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs,
                            z, zShapeInfo, dimensionAddress, dims.length);
                }
            } else {
                if (scalarResult) {
                    ret.putScalar(0, loop.execReduceScalarFloat(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs));
                } else {
                    loop.execReduceFloat(
                            dummy, op.opNum(), x, xShapeInfo, extraArgs,
                            z, zShapeInfo, dimensionAddress, dims.length);
                }
            }
        }

        return ret;
    }

    /**
     * Runs a scalar op.
     *
     * <p>Complex inputs, or an executioner in {@code JAVA} execution mode, are handed to the
     * superclass. Otherwise the op is executed natively: the strided overload is used when
     * both {@code x} and {@code z} have an element-wise stride of at least 1 and the op is not
     * flagged as exec-special; in every other case the shape-info overload is used.
     *
     * @param op the scalar op to execute
     */
    private void exec(ScalarOp op) {
        if (op.x() instanceof IComplexNDArray || executionMode() == ExecutionMode.JAVA) {
            super.exec(op);
            return;
        }

        final PointerPointer dummy = new PointerPointer(new Pointer[] {null});
        final boolean doublePrecision = op.x().data().dataType() == DataBuffer.Type.DOUBLE;
        final boolean strided = op.x().elementWiseStride() >= 1
                        && !op.isExecSpecial()
                        && op.z().elementWiseStride() >= 1;

        if (doublePrecision && strided) {
            loop.execScalarDouble(
                    dummy,
                    op.opNum(),
                    op.x().data().addressPointer(),
                    op.x().elementWiseStride(),
                    op.z().data().addressPointer(),
                    op.z().elementWiseStride(),
                    op.scalar().doubleValue(),
                    getPointerForExtraArgs(op),
                    op.n());
        } else if (doublePrecision) {
            loop.execScalarDouble(
                    dummy,
                    op.opNum(),
                    op.x().data().addressPointer(),
                    op.x().shapeInfoDataBuffer().addressPointer(),
                    op.z().data().addressPointer(),
                    op.z().shapeInfoDataBuffer().addressPointer(),
                    op.scalar().doubleValue(),
                    getPointerForExtraArgs(op));
        } else if (strided) {
            loop.execScalarFloat(
                    dummy,
                    op.opNum(),
                    op.x().data().addressPointer(),
                    op.x().elementWiseStride(),
                    op.z().data().addressPointer(),
                    op.z().elementWiseStride(),
                    op.scalar().floatValue(),
                    getPointerForExtraArgs(op),
                    op.n());
        } else {
            loop.execScalarFloat(
                    dummy,
                    op.opNum(),
                    op.x().data().addressPointer(),
                    op.x().shapeInfoDataBuffer().addressPointer(),
                    op.z().data().addressPointer(),
                    op.z().shapeInfoDataBuffer().addressPointer(),
                    op.scalar().floatValue(),
                    getPointerForExtraArgs(op));
        }
    }

    /**
     * Returns the address pointer of the op's extra-arguments buffer.
     *
     * @param op the op whose extra arguments are wanted
     * @return the pointer, or {@code null} when {@code op.extraArgs()} is {@code null}
     */
    private Pointer getPointerForExtraArgs(Op op) {
        return op.extraArgs() == null ? null : op.extraArgsDataBuff().addressPointer();
    }

    /**
     * Runs a transform op natively, as a pairwise transform when {@code op.y()} is set and as
     * a single-input transform otherwise.
     *
     * <p>The strided overload (element-wise strides plus {@code op.n()}) is used only when all
     * of the following hold, identically for float and double data:
     * <ul>
     *   <li>{@code x} and {@code z} have an element-wise stride of at least 1;</li>
     *   <li>the op is not flagged as exec-special;</li>
     *   <li>{@code x} and {@code z} share the same ordering;</li>
     *   <li>for pairwise ops, {@code y} has the same element-wise stride (at least 1) and the
     *       same ordering as {@code x}.</li>
     * </ul>
     * Anything else goes through the shape-info overload.
     *
     * @param op the transform op to execute
     */
    private void exec(TransformOp op) {
        PointerPointer dummy = new PointerPointer(new Pointer[] {null});

        final INDArray x = op.x();
        final INDArray y = op.y();
        final INDArray z = op.z();

        // One eligibility test shared by every data type, so float and double cannot drift apart.
        // z's stride is checked as well because it is passed straight to the strided overload.
        boolean strided = x.elementWiseStride() >= 1
                        && z.elementWiseStride() >= 1
                        && !op.isExecSpecial()
                        && x.ordering() == z.ordering();
        if (y != null) {
            strided = strided
                            && y.elementWiseStride() >= 1
                            && x.elementWiseStride() == y.elementWiseStride()
                            && x.ordering() == y.ordering();
        }

        final boolean isDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;
        Pointer xData = x.data().addressPointer();
        Pointer zData = z.data().addressPointer();
        Pointer extraArgs = getPointerForExtraArgs(op);

        if (y != null) {
            Pointer yData = y.data().addressPointer();
            if (strided) {
                if (isDouble) {
                    loop.execPairwiseTransformDouble(
                            dummy, op.opNum(),
                            xData, x.elementWiseStride(),
                            yData, y.elementWiseStride(),
                            zData, z.elementWiseStride(),
                            extraArgs, op.n());
                } else {
                    loop.execPairwiseTransformFloat(
                            dummy, op.opNum(),
                            xData, x.elementWiseStride(),
                            yData, y.elementWiseStride(),
                            zData, z.elementWiseStride(),
                            extraArgs, op.n());
                }
            } else {
                Pointer xShapeInfo = x.shapeInfoDataBuffer().addressPointer();
                Pointer yShapeInfo = y.shapeInfoDataBuffer().addressPointer();
                Pointer zShapeInfo = z.shapeInfoDataBuffer().addressPointer();
                if (isDouble) {
                    loop.execPairwiseTransformDouble(
                            dummy, op.opNum(),
                            xData, xShapeInfo,
                            yData, yShapeInfo,
                            zData, zShapeInfo,
                            extraArgs);
                } else {
                    loop.execPairwiseTransformFloat(
                            dummy, op.opNum(),
                            xData, xShapeInfo,
                            yData, yShapeInfo,
                            zData, zShapeInfo,
                            extraArgs);
                }
            }
        } else {
            if (strided) {
                if (isDouble) {
                    loop.execTransformDouble(
                            dummy, op.opNum(),
                            xData, x.elementWiseStride(),
                            zData, z.elementWiseStride(),
                            extraArgs, op.n());
                } else {
                    loop.execTransformFloat(
                            dummy, op.opNum(),
                            xData, x.elementWiseStride(),
                            zData, z.elementWiseStride(),
                            extraArgs, op.n());
                }
            } else {
                Pointer xShapeInfo = x.shapeInfoDataBuffer().addressPointer();
                Pointer zShapeInfo = z.shapeInfoDataBuffer().addressPointer();
                if (isDouble) {
                    loop.execTransformDouble(
                            dummy, op.opNum(),
                            xData, xShapeInfo,
                            zData, zShapeInfo,
                            extraArgs);
                } else {
                    loop.execTransformFloat(
                            dummy, op.opNum(),
                            xData, xShapeInfo,
                            zData, zShapeInfo,
                            extraArgs);
                }
            }
        }
    }

    /**
     * Executes a broadcast op along the given dimension(s) through the native backend.
     *
     * <p>The dimensions are sorted on a copy, so the array passed in (which may be the op's
     * own dimension array) is not reordered.
     *
     * @param op        the broadcast op; {@code x}, {@code y} and {@code z} must all be set
     * @param dimension the dimension(s) to broadcast along
     * @return {@code op.z()}
     */
    @Override
    public INDArray exec(BroadcastOp op, int... dimension) {
        int[] dims = dimension.clone();
        Arrays.sort(dims);

        Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dims);

        Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer();
        // Same null handling as the accumulation paths: a missing offsets buffer becomes a null pointer.
        DataBuffer offsets = tadBuffers.getSecond();
        Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer();

        PointerPointer dummy = new PointerPointer(hostTadShapeInfo, hostTadOffsets);

        Pointer dimensionAddress = constantHandler.getConstantBuffer(dims).addressPointer();

        Pointer x = op.x().data().addressPointer();
        Pointer xShapeInfo = op.x().shapeInfoDataBuffer().addressPointer();
        Pointer y = op.y().data().addressPointer();
        Pointer yShapeInfo = op.y().shapeInfoDataBuffer().addressPointer();
        Pointer z = op.z().data().addressPointer();
        Pointer zShapeInfo = op.z().shapeInfoDataBuffer().addressPointer();

        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            loop.execBroadcastDouble(dummy, op.opNum(),
                    x, xShapeInfo,
                    y, yShapeInfo,
                    z, zShapeInfo,
                    dimensionAddress, dims.length);
        } else {
            loop.execBroadcastFloat(dummy, op.opNum(),
                    x, xShapeInfo,
                    y, yShapeInfo,
                    z, zShapeInfo,
                    dimensionAddress, dims.length);
        }

        return op.z();
    }

    /**
     * Runs an index accumulation over the whole of {@code op.x()} and stores the resulting
     * index on the op via {@code setFinalResult}.
     *
     * <p>Complex inputs, or an executioner in {@code JAVA} execution mode, are handed to the
     * superclass instead of the native backend.
     *
     * @param op the index accumulation to execute
     */
    private void exec(IndexAccumulation op) {
        if (op.x() instanceof IComplexNDArray || executionMode() == ExecutionMode.JAVA) {
            super.exec(op);
            return;
        }

        final PointerPointer dummy = new PointerPointer(new Pointer[] {null});
        final boolean doublePrecision = op.x().data().dataType() == DataBuffer.Type.DOUBLE;

        final int index;
        if (doublePrecision) {
            index = (int) loop.execIndexReduceScalarDouble(
                    dummy,
                    op.opNum(),
                    op.x().data().addressPointer(),
                    op.x().shapeInfoDataBuffer().addressPointer(),
                    getPointerForExtraArgs(op));
        } else {
            index = (int) loop.execIndexReduceScalarFloat(
                    dummy,
                    op.opNum(),
                    op.x().data().addressPointer(),
                    op.x().shapeInfoDataBuffer().addressPointer(),
                    getPointerForExtraArgs(op));
        }
        op.setFinalResult(index);
    }

    /**
     * Runs an accumulation over the whole of {@code op.x()} and stores the scalar result on
     * the op via {@code setFinalResult}.
     *
     * <p>Complex inputs, or an executioner in {@code JAVA} execution mode, are handed to the
     * superclass. Otherwise one of three native scalar reductions is used, chosen in this
     * order: summary statistics for {@link Variance} ops, the two-input reduce when
     * {@code op.y()} is set, and the plain reduce otherwise. The bias-correction flag of a
     * {@link Variance} op is honoured for both double and float data.
     *
     * @param op the accumulation to execute
     */
    private void exec(Accumulation op) {
        if (op.x() instanceof IComplexNDArray || executionMode() == ExecutionMode.JAVA) {
            super.exec(op);
            return;
        }

        PointerPointer dummy = new PointerPointer(new Pointer[] {null});
        final boolean isDouble = op.x().data().dataType() == DataBuffer.Type.DOUBLE;

        Pointer x = op.x().data().addressPointer();
        Pointer xShapeInfo = op.x().shapeInfoDataBuffer().addressPointer();
        Pointer extraArgs = getPointerForExtraArgs(op);

        // The double and float calls stay separate statements so each result keeps its own
        // numeric type when handed to setFinalResult.
        if (op instanceof Variance) {
            final boolean biasCorrected = ((Variance) op).isBiasCorrected();
            if (isDouble) {
                op.setFinalResult(loop.execSummaryStatsScalarDouble(
                        dummy, op.opNum(), x, xShapeInfo, extraArgs, biasCorrected));
            } else {
                op.setFinalResult(loop.execSummaryStatsScalarFloat(
                        dummy, op.opNum(), x, xShapeInfo, extraArgs, biasCorrected));
            }
        } else if (op.y() != null) {
            Pointer y = op.y().data().addressPointer();
            Pointer yShapeInfo = op.y().shapeInfoDataBuffer().addressPointer();
            if (isDouble) {
                op.setFinalResult(loop.execReduce3ScalarDouble(
                        dummy, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo));
            } else {
                op.setFinalResult(loop.execReduce3ScalarFloat(
                        dummy, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo));
            }
        } else {
            if (isDouble) {
                op.setFinalResult(loop.execReduceScalarDouble(
                        dummy, op.opNum(), x, xShapeInfo, extraArgs));
            } else {
                op.setFinalResult(loop.execReduceScalarFloat(
                        dummy, op.opNum(), x, xShapeInfo, extraArgs));
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy