All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.blas.JcublasLapack Maven / Gradle / Ivy

The newest version!
package org.nd4j.linalg.jcublas.blas;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.linalg.api.blas.BlasException;
import org.nd4j.linalg.api.blas.impl.BaseLapack;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;

import static org.bytedeco.javacpp.cuda.CUstream_st;
import static org.bytedeco.javacpp.cusolver.*;
import static org.bytedeco.javacpp.cublas.* ;


/**
 * CUDA-backed LAPACK routines (e.g. LU and QR factorizations) implemented on top of cuSOLVER.
 *
 * @author Adam Gibson
 * @author Richard Corbishley
 */
@Slf4j
public class JcublasLapack extends BaseLapack {

    // Device-side native ops, obtained from the shared NativeOpsHolder singleton.
    // NOTE(review): not referenced by the getrf methods below; confirm it is used elsewhere in the class.
    private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    // Shared allocator: supplies the current device context, device pointers for INDArrays,
    // and is notified (registerAction) after the solver has been given the arrays.
    private Allocator allocator = AtomicAllocator.getInstance();

    /**
     * Runs a single-precision LU factorization on the GPU through cuSOLVER
     * ({@code cusolverDnSgetrf}).
     *
     * <p>If {@code A} is stored in 'c' ordering, the factorization is performed on an
     * 'f'-ordered copy and the result is assigned back into {@code A} afterwards.
     *
     * @param M    first matrix dimension handed to cuSOLVER; also used as the leading dimension
     * @param N    second matrix dimension handed to cuSOLVER
     * @param A    matrix to factorize; receives the factorization result
     * @param IPIV array whose device buffer receives the pivot output of cuSOLVER
     * @param INFO array whose device buffer receives the status output of cuSOLVER
     * @throws BlasException if binding the stream, querying the workspace size, or the
     *                       factorization call itself reports a non-success status
     */
    @Override
    public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
        if (Nd4j.dataType() != DataBuffer.Type.FLOAT) {
            log.warn("FLOAT getrf called in DOUBLE environment");
        }

        // cuSOLVER works on column-major data, so operate on an 'f'-ordered copy when needed
        final INDArray a = (A.ordering() == 'c') ? A.dup('f') : A;

        // make sure any queued grid operations have been dispatched before touching the buffers
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }

        // Get context for current thread
        CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext();

        // setup the solver handles for cuSolver calls
        cusolverDnHandle_t handle = ctx.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(handle);

        // synchronized on the solver
        synchronized (handle) {
            int result = cusolverDnSetStream(solverDn, new CUstream_st(ctx.getOldStream()));
            if (result != CUSOLVER_STATUS_SUCCESS) {
                throw new BlasException("solverSetStream failed", result);
            }

            // transfer the INDArray into GPU memory
            CublasPointer xAPointer = new CublasPointer(a, ctx);

            // this output - indicates how much memory we'll need for the real operation
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);

            int stat = cusolverDnSgetrf_bufferSize(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M,
                            (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
            );

            if (stat != CUSOLVER_STATUS_SUCCESS) {
                throw new BlasException("cusolverDnSgetrf_bufferSize failed", stat);
            }

            int worksize = worksizeBuffer.getInt(0);

            // The workspace is consumed as float elements, so size it by the float width
            // rather than by the globally configured data type (which may be narrower).
            final int floatBytes = Float.SIZE / Byte.SIZE;
            Pointer workspace = new Workspace(worksize * floatBytes);

            // Do the actual LU decomp
            stat = cusolverDnSgetrf(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M,
                            new CudaPointer(workspace).asFloatPointer(),
                            new CudaPointer(allocator.getPointer(IPIV, ctx)).asIntPointer(),
                            new CudaPointer(allocator.getPointer(INFO, ctx)).asIntPointer());

            // NOTE: no explicit stream synchronization is performed here (ctx.syncOldStream() is disabled)

            if (stat != CUSOLVER_STATUS_SUCCESS) {
                throw new BlasException("cusolverDnSgetrf failed", stat);
            }
        }

        // tell the allocator these arrays were touched by the solver call
        allocator.registerAction(ctx, a);
        allocator.registerAction(ctx, INFO);
        allocator.registerAction(ctx, IPIV);

        // copy the result back if we worked on a reordered duplicate
        if (a != A) {
            A.assign(a);
        }
    }



    /**
     * Runs a double-precision LU factorization on the GPU through cuSOLVER
     * ({@code cusolverDnDgetrf}).
     *
     * <p>If {@code A} is stored in 'c' ordering, the factorization is performed on an
     * 'f'-ordered copy and the result is assigned back into {@code A} afterwards.
     *
     * @param M    first matrix dimension handed to cuSOLVER; also used as the leading dimension
     * @param N    second matrix dimension handed to cuSOLVER
     * @param A    matrix to factorize; receives the factorization result
     * @param IPIV array whose device buffer receives the pivot output of cuSOLVER
     * @param INFO array whose device buffer receives the status output of cuSOLVER
     * @throws BlasException if binding the stream, querying the workspace size, or the
     *                       factorization call itself reports a non-success status
     */
    @Override
    public void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
        if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
            log.warn("DOUBLE getrf called in FLOAT environment");
        }

        // cuSOLVER works on column-major data, so operate on an 'f'-ordered copy when needed
        final INDArray a = (A.ordering() == 'c') ? A.dup('f') : A;

        // make sure any queued grid operations have been dispatched before touching the buffers
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }

        // Get context for current thread
        CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext();

        // setup the solver handles for cuSolver calls
        cusolverDnHandle_t handle = ctx.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(handle);

        // synchronized on the solver
        synchronized (handle) {
            int result = cusolverDnSetStream(solverDn, new CUstream_st(ctx.getOldStream()));
            if (result != CUSOLVER_STATUS_SUCCESS) {
                throw new BlasException("solverSetStream failed", result);
            }

            // transfer the INDArray into GPU memory
            CublasPointer xAPointer = new CublasPointer(a, ctx);

            // this output - indicates how much memory we'll need for the real operation
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);

            int stat = cusolverDnDgetrf_bufferSize(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
                            (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
            );

            if (stat != CUSOLVER_STATUS_SUCCESS) {
                throw new BlasException("cusolverDnDgetrf_bufferSize failed", stat);
            }

            int worksize = worksizeBuffer.getInt(0);

            // The workspace is consumed as double elements, so size it by the double width.
            // Using the globally configured data type would allocate too little memory
            // when this method is called in a FLOAT environment.
            final int doubleBytes = Double.SIZE / Byte.SIZE;
            Pointer workspace = new Workspace(worksize * doubleBytes);

            // Do the actual LU decomp
            stat = cusolverDnDgetrf(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
                            new CudaPointer(workspace).asDoublePointer(),
                            new CudaPointer(allocator.getPointer(IPIV, ctx)).asIntPointer(),
                            new CudaPointer(allocator.getPointer(INFO, ctx)).asIntPointer());

            if (stat != CUSOLVER_STATUS_SUCCESS) {
                throw new BlasException("cusolverDnDgetrf failed", stat);
            }
        }

        // tell the allocator these arrays were touched by the solver call
        allocator.registerAction(ctx, a);
        allocator.registerAction(ctx, INFO);
        allocator.registerAction(ctx, IPIV);

        // copy the result back if we worked on a reordered duplicate
        if (a != A) {
            A.assign(a);
        }
    }


//=========================    
// Q R DECOMP
    @Override
    public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
        INDArray a = A;
        INDArray r = R;

        if (Nd4j.dataType() != DataBuffer.Type.FLOAT)
            log.warn("FLOAT getrf called in DOUBLE environment");

        if (A.ordering() == 'c') 
            a = A.dup('f');
        if ( R!=null && R.ordering() == 'c')
            r = R.dup('f');

        INDArray tau = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createFloat(N),
                Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, N}).getFirst());

        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        // Get context for current thread
        CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext();

        // setup the solver handles for cuSolver calls
        cusolverDnHandle_t handle = ctx.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(handle);

        // synchronized on the solver
        synchronized (handle) {
            int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
            if (result != 0)
                throw new IllegalStateException("solverSetStream failed");

            // transfer the INDArray into GPU memory
            CublasPointer xAPointer = new CublasPointer(a, ctx);
            CublasPointer xTauPointer = new CublasPointer(tau, ctx);

            // this output - indicates how much memory we'll need for the real operation
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);

            int stat = cusolverDnSgeqrf_bufferSize(solverDn, M, N, 
                        (FloatPointer) xAPointer.getDevicePointer(), M,
                        (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
            );


            if (stat != CUSOLVER_STATUS_SUCCESS) {
                throw new BlasException("cusolverDnSgeqrf_bufferSize failed", stat);
            }
            int worksize = worksizeBuffer.getInt(0);
            // Now allocate memory for the workspace, the permutation matrix and a return code
            Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType());

            // Do the actual QR decomp
            stat = cusolverDnSgeqrf(solverDn, M, N, 
                            (FloatPointer) xAPointer.getDevicePointer(), M,
                            (FloatPointer) xTauPointer.getDevicePointer(),
                            new CudaPointer(workspace).asFloatPointer(),
                            worksize,
                            new CudaPointer(allocator.getPointer(INFO, ctx)).asIntPointer()
                            );
            if (stat != CUSOLVER_STATUS_SUCCESS) {
                throw new BlasException("cusolverDnSgeqrf failed", stat);
            }
            
            allocator.registerAction(ctx, a);
            //allocator.registerAction(ctx, tau);
            allocator.registerAction(ctx, INFO);
            if (INFO.getInt(0) != 0 ) {
                throw new BlasException("cusolverDnSgeqrf failed on INFO", INFO.getInt(0));
            }

            // Copy R ( upper part of Q ) into result
            if( r != null ) {
                r.assign( a.get( NDArrayIndex.interval( 0, a.columns() ), NDArrayIndex.all() ) ) ; 
	            
                INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
                for( int i=1 ; i




© 2015 - 2025 Weber Informatics LLC | Privacy Policy