
org.nd4j.linalg.jcublas.blas.JcublasLapack Maven / Gradle / Ivy
The newest version!
package org.nd4j.linalg.jcublas.blas;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.linalg.api.blas.BlasException;
import org.nd4j.linalg.api.blas.impl.BaseLapack;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import static org.bytedeco.javacpp.cuda.CUstream_st;
import static org.bytedeco.javacpp.cusolver.*;
import static org.bytedeco.javacpp.cublas.* ;
/**
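* LAPACK routines for the CUDA (JCublas) backend, implemented with NVIDIA cuSOLVER dense-API ({@code cusolverDn*}) calls.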
*
* @author Adam Gibson
* @author Richard Corbishley
*/
@Slf4j
public class JcublasLapack extends BaseLapack {
/** Device-side NativeOps binding, obtained from {@link NativeOpsHolder#getInstance()}. */
private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
/**
 * Allocator from {@link AtomicAllocator#getInstance()}; used to fetch the CUDA context for the
 * calling thread, resolve device pointers of arrays, and register actions on arrays after a
 * solver call.
 */
private Allocator allocator = AtomicAllocator.getInstance();
/**
 * Computes the LU factorization of a general M-by-N single-precision matrix on the GPU using
 * cuSOLVER's {@code cusolverDnSgetrf}.
 *
 * <p>cuSOLVER works on column-major data, so a C-ordered {@code A} is first copied into an
 * F-ordered temporary and the result is assigned back into {@code A} afterwards.
 *
 * @param M    number of rows of {@code A}; also passed to cuSOLVER as the leading dimension
 * @param N    number of columns of {@code A}
 * @param A    matrix to factorize; receives the factorization result
 * @param IPIV device output for the pivot indices produced by cuSOLVER
 * @param INFO device output for cuSOLVER's {@code devInfo}; it is not inspected by this method
 * @throws BlasException if binding the solver to the stream, the workspace-size query, or the
 *                       factorization call reports a non-success status
 */
@Override
public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
    INDArray a = A;
    if (Nd4j.dataType() != DataBuffer.Type.FLOAT) {
        log.warn("FLOAT getrf called in DOUBLE environment");
    }
    // cuSOLVER expects column-major ('f') storage
    if (A.ordering() == 'c') {
        a = A.dup('f');
    }

    // flush any queued grid operations before handing raw pointers to cuSOLVER
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
    }

    // Get context for current thread and its solver handle
    CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext();
    cusolverDnHandle_t handle = ctx.getSolverHandle();
    cusolverDnContext solverDn = new cusolverDnContext(handle);

    // all use of the solver handle is serialized by locking on it
    synchronized (handle) {
        // reuse the same solver context for stream binding and the solver calls
        int result = cusolverDnSetStream(solverDn, new CUstream_st(ctx.getOldStream()));
        if (result != CUSOLVER_STATUS_SUCCESS) {
            throw new BlasException("solverSetStream failed", result);
        }

        // transfer the INDArray into GPU memory
        CublasPointer xAPointer = new CublasPointer(a, ctx);

        // query how large the workspace must be (in elements of float)
        DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
        int stat = cusolverDnSgetrf_bufferSize(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M,
                        (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
        );
        if (stat != CUSOLVER_STATUS_SUCCESS) {
            throw new BlasException("cusolverDnSgetrf_bufferSize failed", stat);
        }
        int worksize = worksizeBuffer.getInt(0);

        // Size the workspace by the routine's element type (float), not by the global
        // Nd4j data type, which may differ from the type this routine operates on.
        Pointer workspace = new Workspace(worksize * Float.BYTES);

        // Do the actual LU decomposition
        stat = cusolverDnSgetrf(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M,
                        new CudaPointer(workspace).asFloatPointer(),
                        new CudaPointer(allocator.getPointer(IPIV, ctx)).asIntPointer(),
                        new CudaPointer(allocator.getPointer(INFO, ctx)).asIntPointer());
        if (stat != CUSOLVER_STATUS_SUCCESS) {
            throw new BlasException("cusolverDnSgetrf failed", stat);
        }
    }

    allocator.registerAction(ctx, a);
    allocator.registerAction(ctx, INFO);
    allocator.registerAction(ctx, IPIV);

    // copy the result back if we worked on an F-ordered temporary
    if (a != A) {
        A.assign(a);
    }
}
/**
 * Computes the LU factorization of a general M-by-N double-precision matrix on the GPU using
 * cuSOLVER's {@code cusolverDnDgetrf}.
 *
 * <p>cuSOLVER works on column-major data, so a C-ordered {@code A} is first copied into an
 * F-ordered temporary and the result is assigned back into {@code A} afterwards.
 *
 * @param M    number of rows of {@code A}; also passed to cuSOLVER as the leading dimension
 * @param N    number of columns of {@code A}
 * @param A    matrix to factorize; receives the factorization result
 * @param IPIV device output for the pivot indices produced by cuSOLVER
 * @param INFO device output for cuSOLVER's {@code devInfo}; it is not inspected by this method
 * @throws BlasException if binding the solver to the stream, the workspace-size query, or the
 *                       factorization call reports a non-success status
 */
@Override
public void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
    INDArray a = A;
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        log.warn("DOUBLE getrf called in FLOAT environment");
    }
    // cuSOLVER expects column-major ('f') storage
    if (A.ordering() == 'c') {
        a = A.dup('f');
    }

    // flush any queued grid operations before handing raw pointers to cuSOLVER
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
    }

    // Get context for current thread and its solver handle
    CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext();
    cusolverDnHandle_t handle = ctx.getSolverHandle();
    cusolverDnContext solverDn = new cusolverDnContext(handle);

    // all use of the solver handle is serialized by locking on it
    synchronized (handle) {
        // reuse the same solver context for stream binding and the solver calls
        int result = cusolverDnSetStream(solverDn, new CUstream_st(ctx.getOldStream()));
        if (result != CUSOLVER_STATUS_SUCCESS) {
            throw new BlasException("solverSetStream failed", result);
        }

        // transfer the INDArray into GPU memory
        CublasPointer xAPointer = new CublasPointer(a, ctx);

        // query how large the workspace must be (in elements of double)
        DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
        int stat = cusolverDnDgetrf_bufferSize(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
                        (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
        );
        if (stat != CUSOLVER_STATUS_SUCCESS) {
            throw new BlasException("cusolverDnDgetrf_bufferSize failed", stat);
        }
        int worksize = worksizeBuffer.getInt(0);

        // Size the workspace by the routine's element type (double). Using the global Nd4j
        // data type would allocate only half the required bytes in a FLOAT environment.
        Pointer workspace = new Workspace(worksize * Double.BYTES);

        // Do the actual LU decomposition
        stat = cusolverDnDgetrf(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
                        new CudaPointer(workspace).asDoublePointer(),
                        new CudaPointer(allocator.getPointer(IPIV, ctx)).asIntPointer(),
                        new CudaPointer(allocator.getPointer(INFO, ctx)).asIntPointer());
        if (stat != CUSOLVER_STATUS_SUCCESS) {
            throw new BlasException("cusolverDnDgetrf failed", stat);
        }
    }

    allocator.registerAction(ctx, a);
    allocator.registerAction(ctx, INFO);
    allocator.registerAction(ctx, IPIV);

    // copy the result back if we worked on an F-ordered temporary
    if (a != A) {
        A.assign(a);
    }
}
//=========================
// Q R DECOMP
@Override
public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
INDArray a = A;
INDArray r = R;
if (Nd4j.dataType() != DataBuffer.Type.FLOAT)
log.warn("FLOAT getrf called in DOUBLE environment");
if (A.ordering() == 'c')
a = A.dup('f');
if ( R!=null && R.ordering() == 'c')
r = R.dup('f');
INDArray tau = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createFloat(N),
Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, N}).getFirst());
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
// Get context for current thread
CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext();
// setup the solver handles for cuSolver calls
cusolverDnHandle_t handle = ctx.getSolverHandle();
cusolverDnContext solverDn = new cusolverDnContext(handle);
// synchronized on the solver
synchronized (handle) {
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
if (result != 0)
throw new IllegalStateException("solverSetStream failed");
// transfer the INDArray into GPU memory
CublasPointer xAPointer = new CublasPointer(a, ctx);
CublasPointer xTauPointer = new CublasPointer(tau, ctx);
// this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
int stat = cusolverDnSgeqrf_bufferSize(solverDn, M, N,
(FloatPointer) xAPointer.getDevicePointer(), M,
(IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
);
if (stat != CUSOLVER_STATUS_SUCCESS) {
throw new BlasException("cusolverDnSgeqrf_bufferSize failed", stat);
}
int worksize = worksizeBuffer.getInt(0);
// Now allocate memory for the workspace, the permutation matrix and a return code
Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType());
// Do the actual QR decomp
stat = cusolverDnSgeqrf(solverDn, M, N,
(FloatPointer) xAPointer.getDevicePointer(), M,
(FloatPointer) xTauPointer.getDevicePointer(),
new CudaPointer(workspace).asFloatPointer(),
worksize,
new CudaPointer(allocator.getPointer(INFO, ctx)).asIntPointer()
);
if (stat != CUSOLVER_STATUS_SUCCESS) {
throw new BlasException("cusolverDnSgeqrf failed", stat);
}
allocator.registerAction(ctx, a);
//allocator.registerAction(ctx, tau);
allocator.registerAction(ctx, INFO);
if (INFO.getInt(0) != 0 ) {
throw new BlasException("cusolverDnSgeqrf failed on INFO", INFO.getInt(0));
}
// Copy R ( upper part of Q ) into result
if( r != null ) {
r.assign( a.get( NDArrayIndex.interval( 0, a.columns() ), NDArrayIndex.all() ) ) ;
INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ;
for( int i=1 ; i
© 2015 - 2025 Weber Informatics LLC | Privacy Policy