All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.context.CudaContext Maven / Gradle / Ivy

The newest version!
package org.nd4j.linalg.jcublas.context;

import lombok.Data;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.garbage.GarbageResourceReference;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;

import java.util.concurrent.atomic.AtomicBoolean;

/**
 * A higher level class for handling
 * the different primitives around the cuda apis
 * This being:
 * streams (both old and new) as well as
 * the cublas handles.
 *
 *
 */
/**
 * Holder for the CUDA primitives used by the jcublas backend: the "old" execution
 * stream, a companion "special" stream, optional cuBLAS and cuSolver streams, the
 * cuBLAS / cuSolver handles, and a few scratch buffers.
 *
 * <p>Most members are plain Lombok-generated properties ({@code @Data}). Of the native
 * resources, only the old and special streams are created by this class itself, in
 * {@link #initOldStream()}. The handles and the remaining streams are presumably
 * assigned from elsewhere through the generated setters — confirm.
 *
 * <p>Several methods are retained no-ops from an earlier JCuda-based implementation:
 * <ul>
 *   <li>{@link #syncStream()}</li>
 *   <li>{@link #initStream()}</li>
 *   <li>{@link #initHandle()}</li>
 *   <li>{@link #associateHandle()}</li>
 *   <li>{@link #destroy()}</li>
 *   <li>{@link #finishBlasOperation()}</li>
 * </ul>
 *
 * <p>NOTE(review): apart from the (empty) synchronized {@link #associateHandle()}, this
 * class does no locking. Thread-safety is presumably the caller's responsibility — confirm.
 */
@Data
public class CudaContext {
    /** Primary execution stream; created by {@link #initOldStream()} when not already set. */
    private cudaStream_t oldStream;

    /** Stream synchronized by {@link #syncCublasStream()}; may be null. */
    private cudaStream_t cublasStream;
    /** Stream synchronized by {@link #syncSolverStream()}; may be null. */
    private cudaStream_t solverStream;

    /** Secondary stream, created together with the old stream in {@link #initOldStream()}. */
    private cudaStream_t specialStream;

    private cublasHandle_t handle;
    private cusolverDnHandle_t solverHandle;
    private CublasPointer resultPointer;
    private AtomicBoolean oldStreamReturned = new AtomicBoolean(false);
    private AtomicBoolean handleReturned = new AtomicBoolean(false);
    private AtomicBoolean streamReturned = new AtomicBoolean(false);
    private boolean streamFromPool = true;
    private boolean handleFromPool = true;
    /** Cleared by {@link #initOldStream()} when it creates the streams itself. */
    private boolean oldStreamFromPool = true;
    private boolean free = true;
    private boolean oldEventDestroyed = true;
    private boolean eventDestroyed = true;

    private Pointer bufferReduction;
    private Pointer bufferAllocation;
    private Pointer bufferScalar;
    private Pointer bufferSpecial;
    private GarbageResourceReference reference;

    private int deviceId = -1;

    private int laneId = 0;

    /**
     * Native bindings used to create and synchronize streams.
     * The field is static, so Lombok generates no accessor for it; {@code final} is therefore safe.
     */
    private static final NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();


    /**
     * Creates a context with an explicit {@code free} flag.
     *
     * @param free initial value of the {@code free} property
     */
    public CudaContext(boolean free) {
        this();
        this.free = free;
    }

    /**
     * Stores the given reference on this context, replacing any previous one.
     * NOTE(review): presumably used to track release of native resources — confirm.
     *
     * @param ref the reference to attach
     */
    public void attachReference(GarbageResourceReference ref) {
        reference = ref;
    }


    /** Creates a context with default settings. No native resources are allocated here. */
    public CudaContext() {
        // intentionally empty
    }

    /**
     * No-op, retained for API compatibility. Use {@link #syncOldStream()} to wait for
     * the primary stream.
     */
    public void syncStream() {
        // intentionally empty
    }

    /**
     * Blocks until the old stream has drained. Equivalent to {@code syncOldStream(false)}.
     *
     * @throws ND4JIllegalStateException if the native synchronization call reports failure
     */
    public void syncOldStream() {
        syncOldStream(false);
    }

    /**
     * Blocks until the special stream has drained.
     *
     * @throws ND4JIllegalStateException if the native synchronization call reports failure
     */
    public void syncSpecialStream() {
        synchronizeOrThrow(specialStream, "CUDA special stream synchronization failed");
    }

    /**
     * Blocks until the old stream has drained, and optionally the cuBLAS stream afterwards.
     *
     * @param syncCuBlas whether to also call {@link #syncCublasStream()}
     * @throws ND4JIllegalStateException if a native synchronization call reports failure
     * @throws IllegalStateException if {@code syncCuBlas} is true and no cuBLAS stream is set
     */
    public void syncOldStream(boolean syncCuBlas) {
        synchronizeOrThrow(oldStream, "CUDA stream synchronization failed");

        if (syncCuBlas) {
            syncCublasStream();
        }
    }

    /**
     * Blocks until the cuBLAS stream has drained.
     *
     * @throws IllegalStateException if no cuBLAS stream has been set
     * @throws ND4JIllegalStateException if the native synchronization call reports failure
     */
    public void syncCublasStream() {
        if (cublasStream == null) {
            throw new IllegalStateException("cuBLAS stream isnt set");
        }
        synchronizeOrThrow(cublasStream, "CUDA stream synchronization failed");
    }


    /**
     * Blocks until the cuSolver stream has drained.
     *
     * @throws IllegalStateException if no cuSolver stream has been set
     * @throws ND4JIllegalStateException if the native synchronization call reports failure
     */
    public void syncSolverStream() {
        if (solverStream == null) {
            // Previously reported "cuBLAS stream isnt set", which named the wrong stream.
            throw new IllegalStateException("cuSolver stream isn't set");
        }
        synchronizeOrThrow(solverStream, "CUDA stream synchronization failed");
    }

    /**
     * Synchronizes {@code stream} through the native bindings. A return value of 0 from
     * {@code streamSynchronize} is treated as failure.
     *
     * @param stream the stream to wait on (passed through to native code unchanged)
     * @param failureMessage message for the exception raised on failure
     * @throws ND4JIllegalStateException if synchronization reports failure
     */
    private static void synchronizeOrThrow(cudaStream_t stream, String failureMessage) {
        if (nativeOps.streamSynchronize(stream) == 0) {
            throw new ND4JIllegalStateException(failureMessage);
        }
    }

    /**
     * No-op, retained for API compatibility. An earlier implementation bound the cuBLAS
     * handle to the old stream here.
     */
    public synchronized void associateHandle() {
        // intentionally empty
    }



    /** No-op, retained for API compatibility. Use {@link #initOldStream()} to create streams. */
    public void initStream() {
        // intentionally empty
    }

    /**
     * Creates the old stream and the special stream if the old stream is still null, and
     * clears {@code oldStreamFromPool} to record that this context created them itself.
     * If the old stream was already set, nothing happens, and the special stream is left as-is.
     */
    public void initOldStream() {
        if (oldStream != null) {
            return;
        }

        oldStreamFromPool = false;
        oldStream = new cudaStream_t(nativeOps.createStream());
        specialStream = new cudaStream_t(nativeOps.createStream());
    }



    /**
     * No-op, retained for API compatibility. cuBLAS handles are no longer created here.
     */
    public void initHandle() {
        // intentionally empty: handles are not created by this class anymore
    }

    /**
     * Does nothing.
     *
     * @param resultPointer ignored
     * @param freeIfNotEqual ignored
     * @deprecated no-op retained only for backward compatibility; there is nothing to release here
     */
    @Deprecated
    public void destroy(CublasPointer resultPointer, boolean freeIfNotEqual) {}


    /**
     * Does nothing.
     *
     * @deprecated no-op retained only for backward compatibility; there is nothing to release here
     */
    @Deprecated
    public void destroy() {
        // intentionally empty
    }


    /** No-op, retained for API compatibility; this context is not destroyed here. */
    public void finishBlasOperation() {
        // intentionally empty
    }

    /**
     * Returns the context held by the allocator's current device context, cast to
     * {@code CudaContext}. No stream synchronization is performed.
     *
     * @return the cuda context as set up for cuBLAS usage
     * @throws ClassCastException if the allocator's context is not a {@code CudaContext}
     */
    public static CudaContext getBlasContext() {
        return (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy