All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.jcublas.context.CudaContext Maven / Gradle / Ivy

The newest version!
package org.nd4j.linalg.jcublas.context;

import jcuda.driver.CUevent;
import jcuda.driver.CUstream;
import jcuda.driver.CUstream_flags;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaEvent_t;
import jcuda.runtime.cudaStream_t;
import lombok.Data;
import org.nd4j.linalg.jcublas.CublasPointer;

import java.util.concurrent.atomic.AtomicBoolean;

/**
 * A higher-level holder for the CUDA primitives used by an operation:
 * <ul>
 *     <li>a driver-API ("new") stream ({@link CUstream}) plus an event ({@link CUevent}),</li>
 *     <li>a runtime-API ("old") stream ({@link cudaStream_t}) plus an event ({@link cudaEvent_t}),</li>
 *     <li>a cuBLAS handle ({@link cublasHandle}).</li>
 * </ul>
 * Streams and handles are borrowed from the pools exposed by {@link ContextHolder} when possible.
 * If borrowing fails, a fresh resource is created instead. Such a resource is destroyed, rather
 * than returned to a pool, when this context is destroyed.
 */
@Data
public class CudaContext implements AutoCloseable {
    private CUstream stream;
    private CUevent cUevent;
    private cudaStream_t oldStream;
    private cudaEvent_t oldEvent;
    private cublasHandle handle;
    private CublasPointer resultPointer;
    private AtomicBoolean oldStreamReturned = new AtomicBoolean(false);
    private AtomicBoolean handleReturned = new AtomicBoolean(false);
    private AtomicBoolean streamReturned = new AtomicBoolean(false);
    private boolean streamFromPool = true;
    private boolean handleFromPool = true;
    private boolean oldStreamFromPool = true;
    private boolean free = true;
    private boolean oldEventDestroyed = true;
    private boolean eventDestroyed = true;


    /**
     * Creates a context and sets the current CUDA context via {@link ContextHolder}.
     *
     * @param free value stored in the {@code free} property
     */
    public CudaContext(boolean free) {
        this();
        this.free = free;
    }


    /**
     * Creates a context and sets the current CUDA context via {@link ContextHolder}.
     * No streams, events or handles are acquired until the corresponding {@code init*} method is called.
     */
    public CudaContext() {
        ContextHolder.getInstance().setContext();
    }

    /**
     * Blocks until the event recorded on the new (driver-API) stream has completed, then destroys
     * that event. Does nothing if the event was never created or has already been destroyed.
     */
    public void syncStream() {
        if(eventDestroyed) {
            return;
        }

        JCudaDriver.cuEventSynchronize(cUevent);
        JCudaDriver.cuEventDestroy(cUevent);
        eventDestroyed = true;
    }

    /**
     * Makes the old (runtime-API) stream wait on {@code oldEvent}, then destroys that event.
     * Does nothing if the event was never created or has already been destroyed.
     * <p>
     * NOTE(review): per the CUDA runtime docs, {@code cudaStreamWaitEvent} orders work on the
     * device and does not block the calling host thread. Confirm that callers do not expect
     * host-side synchronization here.
     */
    public void syncOldStream() {
        if(!oldEventDestroyed) {
            JCuda.cudaStreamWaitEvent(oldStream,oldEvent,0);
            JCuda.cudaEventDestroy(oldEvent);
            oldEventDestroyed = true;
        }
    }

    /**
     * Synchronizes the cuBLAS handle's work by delegating to {@link #syncOldStream()}.
     * The handle is bound to the old stream by {@link #associateHandle()}.
     */
    public void syncHandle() {
        syncOldStream();
    }

    /**
     * Returns the result pointer held by this context.
     *
     * @return the result pointer, possibly {@code null}
     */
    public CublasPointer getResultPointer() {
        return resultPointer;
    }

    /**
     * Sets the current CUDA context, then binds the cuBLAS handle to the old (runtime-API) stream
     * via {@code cublasSetStream}.
     */
    public void associateHandle() {
        ContextHolder.getInstance().setContext();
        JCublas2.cublasSetStream(handle, oldStream);
    }

    /**
     * Records {@code oldEvent} on the old stream, marking the point at which an operation starts.
     * {@link #initOldStream()} must have been called first so that the event exists.
     */
    public void startOldEvent() {
        JCuda.cudaEventRecord(oldEvent, oldStream);
    }

    /**
     * Records {@code cUevent} on the new (driver-API) stream, marking the point at which an
     * operation starts. {@link #initStream()} must have been called first so that the event exists.
     */
    public void startNewEvent() {
        JCudaDriver.cuEventRecord(cUevent,stream);
    }


    /**
     * Acquires the new (driver-API) stream and creates its event. Does nothing if a stream is
     * already set.
     * <p>
     * The stream is borrowed from the pool when possible. Otherwise a non-blocking stream is
     * created and flagged so that {@link #destroy()} destroys it instead of returning it.
     */
    public void initStream() {
        if(stream == null) {
            try {
                stream = ContextHolder.getInstance().getStreamPool().borrowObject();
            } catch (Exception e) {
                stream = new CUstream();
                JCudaDriver.cuStreamCreate(stream, CUstream_flags.CU_STREAM_NON_BLOCKING);
                streamFromPool = false;
            }

            cUevent = new CUevent();
            JCudaDriver.cuEventCreate(cUevent,0);
            eventDestroyed = false;
        }
    }

    /**
     * Acquires the old (runtime-API) stream and creates its event. Does nothing if an old stream
     * is already set.
     * <p>
     * The stream is borrowed from the pool when possible. Otherwise a new stream is created and
     * flagged so that {@link #destroy()} destroys it instead of returning it.
     */
    public void initOldStream() {
        if(oldStream == null)  {
            try {
                oldStream = ContextHolder.getInstance().getOldStreamPool().borrowObject();
            } catch (Exception e) {
                oldStreamFromPool = false;
                oldStream = new cudaStream_t();
                JCuda.cudaStreamCreate(oldStream);

            }

            oldEvent = new cudaEvent_t();
            JCuda.cudaEventCreate(oldEvent);
            oldEventDestroyed = false;
        }

    }




    /**
     * Acquires a cuBLAS handle and binds it to the old stream via {@link #associateHandle()}.
     * Does nothing if a handle is already set.
     * <p>
     * {@link #initOldStream()} should be called first, so that the old stream is set before the
     * handle is bound to it.
     */
    public void initHandle() {
        if(handle == null) {
            try {
                handle = ContextHolder.getInstance().getHandlePool().borrowObject();
            } catch (Exception e) {
                handle = new cublasHandle();
                JCublas2.cublasCreate(handle);
                handleFromPool = false;
            }
            associateHandle();
        }

    }

    /**
     * Releases every resource held by this context.
     * <p>
     * The handle and both streams are returned to their pools, or destroyed if they were created
     * locally. Both events are then destroyed if still alive. Each resource release is best-effort:
     * a failure is reported and the remaining resources are still released.
     *
     * @param resultPointer  pointer to copy back to the host and close; may be {@code null}
     * @param freeIfNotEqual when {@code false}, {@code resultPointer} is left untouched
     */
    public void destroy(CublasPointer resultPointer,boolean freeIfNotEqual) {
        releaseHandle();
        releaseStream();
        releaseOldStream();

        // The original condition tested freeIfNotEqual twice; a single test is equivalent.
        if(resultPointer != null && freeIfNotEqual) {
            copyBackAndClose(resultPointer);
        }

        destroyEvents();
    }


    /**
     * Releases every resource held by this context, including this context's own result pointer.
     * That pointer, if set, is copied back to the host and closed.
     * <p>
     * This delegates to {@link #destroy(CublasPointer, boolean)}, so the CUDA events are destroyed
     * here as well. Previously this overload skipped them and leaked the events.
     */
    public void destroy() {
        destroy(resultPointer, true);
    }

    /** Returns the cuBLAS handle to its pool, or destroys it if it was created locally. */
    private void releaseHandle() {
        if(handle == null || handleReturned.get()) {
            return;
        }
        try {
            if(handleFromPool) {
                ContextHolder.getInstance().getHandlePool().returnObject(handle);
            } else {
                JCublas2.cublasDestroy(handle);
            }
            // Only mark as returned on success, so a failed release can be retried.
            handleReturned.set(true);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Returns the new (driver-API) stream to its pool, or destroys it if it was created locally. */
    private void releaseStream() {
        if(stream == null || streamReturned.get()) {
            return;
        }
        try {
            if(streamFromPool) {
                ContextHolder.getInstance().getStreamPool().returnObject(stream);
            } else {
                JCudaDriver.cuStreamDestroy(stream);
            }
            streamReturned.set(true);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Returns the old (runtime-API) stream to its pool, or destroys it if it was created locally. */
    private void releaseOldStream() {
        if(oldStream == null || oldStreamReturned.get()) {
            return;
        }
        try {
            if(oldStreamFromPool) {
                ContextHolder.getInstance().getOldStreamPool().returnObject(oldStream);
            } else {
                JCuda.cudaStreamDestroy(oldStream);
            }
            oldStreamReturned.set(true);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Copies {@code pointer} back to the host, then closes it, reporting any failure to close. */
    private static void copyBackAndClose(CublasPointer pointer) {
        pointer.copyToHost();
        try {
            pointer.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Destroys whichever of the old and new events are still alive. */
    private void destroyEvents() {
        if(!oldEventDestroyed) {
            JCuda.cudaEventDestroy(oldEvent);
            oldEventDestroyed = true;
        }

        if(!eventDestroyed) {
            JCudaDriver.cuEventDestroy(cUevent);
            eventDestroyed = true;
        }
    }


    /**
     * Finishes a BLAS operation by destroying this context (see {@link #destroy()}).
     */
    public void finishBlasOperation() {

        destroy();
    }

    /**
     * Creates a context set up for cuBLAS usage.
     * <p>
     * The returned context holds an old stream with its event and a cuBLAS handle bound to that
     * stream. The old event has already been recorded.
     *
     * @return the prepared context; the caller is responsible for destroying it
     */
    public static CudaContext getBlasContext() {
        CudaContext ctx = new CudaContext();
        ctx.initOldStream();
        ctx.initHandle();
        ctx.startOldEvent();
        return ctx;
    }

    /**
     * Calls {@code cudaDeviceSynchronize}.
     */
    public void syncDevice() {
        JCuda.cudaDeviceSynchronize();
    }

    /**
     * Destroys this context; equivalent to {@link #destroy()}.
     */
    @Override
    public void close() throws Exception {
        destroy();
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy