All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.jcublas.context.ContextHolder Maven / Gradle / Ivy

There is a newer version: 0.4-rc3.7
Show newest version
package org.nd4j.linalg.jcublas.context;

import java.util.HashMap;
import java.util.Map;

import com.google.common.collect.*;

import jcuda.CudaException;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUresult;
import jcuda.driver.CUstream;
import jcuda.driver.CUstream_flags;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;

import org.nd4j.linalg.jcublas.SimpleJCublas;

import static jcuda.driver.JCudaDriver.*;

/**
 * A multithreaded version derived from the cuda launcher util
 * by the authors of jcuda.
 *
 * This class handles managing cuda contexts
 * across multiple devices and threads.
 *
 * <p>The holder is a lazily created singleton ({@link #getInstance()}).
 * Contexts for all detected devices are created the first time any
 * context is requested. Streams are created on demand and cached per
 * (context, thread name) pair.
 *
 * <p>{@link #getInstance()}, {@link #getContext()}, {@link #getContext(int)}
 * and {@link #getStream()} are synchronized. {@link #getDevices()} and
 * {@link #getDeviceIDContexts()} return the live internal maps without
 * synchronization.
 *
 * @author Adam Gibson
 */
public class ContextHolder {
    /** Device handles keyed by device ordinal. */
    private final Map<Integer, CUdevice> devices = new HashMap<>();
    /** Contexts keyed by device ordinal. */
    private final Map<Integer, CUcontext> deviceIDContexts = new HashMap<>();
    /** Streams keyed by (context, name of the owning thread). */
    private final Table<CUcontext, String, CUstream> contextStreams = HashBasedTable.create();
    /** Number of devices contexts are created for; always at least 1. */
    private int numDevices = 0;
    private static ContextHolder INSTANCE;

    private ContextHolder() {
        getNumDevices();
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
            @Override
            public void run() {
                // No cleanup is performed yet: the contexts created by this
                // holder are not destroyed when the JVM exits.
            }
        }));
    }

    /**
     * Returns the process-wide holder, creating it on first use.
     * Synchronized so that concurrent first callers cannot each
     * construct (and register a shutdown hook for) their own instance.
     *
     * @return the singleton instance
     */
    public static synchronized ContextHolder getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new ContextHolder();
        }
        return INSTANCE;
    }

    /**
     * Queries the driver for the number of devices and stores it in
     * {@link #numDevices}. Falls back to 1 when the query fails or
     * reports no devices; a genuinely unusable driver is reported later,
     * when a context is first requested.
     */
    private void getNumDevices() {
        int[] count = new int[1];
        // The driver API has to be initialized before any other driver
        // call, otherwise the count query fails and we would always fall
        // back to a single device. Calling cuInit again later is harmless.
        if (cuInit(0) == CUresult.CUDA_SUCCESS) {
            cuDeviceGetCount(count);
        }
        numDevices = Math.max(count[0], 1);
    }

    /**
     * Retrieve the context for the default device (device 0).
     *
     * @return the context registered for device 0
     * @throws CudaException if the contexts could not be set up
     */
    public synchronized CUcontext getContext() {
        return getContext(0);
    }

    /**
     * Retrieve the stream associated with the default context and the
     * calling thread, creating and caching it on first use. Streams are
     * keyed by thread name, so threads sharing a name share a stream.
     *
     * @return the stream for the current thread
     * @throws CudaException if the stream could not be created
     */
    public synchronized CUstream getStream() {
        String threadName = Thread.currentThread().getName();
        CUcontext ctx = getContext(0);
        CUstream stream = contextStreams.get(ctx, threadName);

        if (stream == null) {
            stream = new CUstream();
            int result = JCudaDriver.cuStreamCreate(stream, CUstream_flags.CU_STREAM_DEFAULT);
            if (result != CUresult.CUDA_SUCCESS) {
                throw new CudaException("Failed to create a stream: " + CUresult.stringFor(result));
            }
            contextStreams.put(ctx, threadName, stream);
        }

        return stream;
    }

    /**
     * Retrieve the context registered for the given device. The first
     * call sets up the devices and contexts for every detected device.
     *
     * @param deviceToUse the ordinal of the device to use
     * @return the context registered for that device
     * @throws IllegalArgumentException if {@code deviceToUse} is not a
     *         known device ordinal
     * @throws IllegalStateException if an earlier, failed setup left the
     *         device without a context
     * @throws CudaException if the driver, a device or a context could
     *         not be set up
     */
    public synchronized CUcontext getContext(int deviceToUse) {
        if (deviceToUse < 0 || deviceToUse >= numDevices) {
            throw new IllegalArgumentException(
                    "Invalid device " + deviceToUse + "; number of devices: " + numDevices);
        }

        if (deviceIDContexts.isEmpty()) {
            // NOTE(review): as in the original code, one CUcontext object is
            // shared by all devices, so every map entry refers to whichever
            // context was written into it last. Confirm whether separate
            // per-device contexts are wanted on multi-GPU hosts.
            CUcontext shared = new CUcontext();
            for (int device = 0; device < numDevices; device++) {
                boolean created = initialize(shared, device);
                // If initialize() just created a context on this device,
                // creating another one here would overwrite the handle and
                // leak the first context, so only the device is looked up.
                CUdevice currDevice = created
                        ? lookupDevice(device)
                        : createDevice(shared, device);
                devices.put(device, currDevice);
                deviceIDContexts.put(device, shared);
            }
        }

        CUcontext ctx = deviceIDContexts.get(deviceToUse);
        if (ctx == null) {
            throw new IllegalStateException(
                    "No context was initialized for device " + deviceToUse);
        }
        return ctx;
    }


    /**
     * Initializes the JCuda driver API and then tries to attach to the
     * CUDA context that is current for the calling thread. If no
     * context is current, a new one is created for the given device.
     *
     * @param context receives the current or the newly created context
     * @param deviceNumber the device to create a context for if none is current
     * @return {@code true} if a new context was created by this call,
     *         {@code false} if an existing current context was found
     * @throws CudaException If it is neither possible to
     * attach to an existing context, nor to create a new
     * context.
     */
    private boolean initialize(CUcontext context, int deviceNumber) {
        int result = cuInit(0);
        if (result != CUresult.CUDA_SUCCESS) {
            throw new CudaException(
                    "Failed to initialize the driver: " +
                            CUresult.stringFor(result));
        }

        // Try to obtain the current context
        result = cuCtxGetCurrent(context);
        if (result != CUresult.CUDA_SUCCESS) {
            throw new CudaException(
                    "Failed to obtain the current context: " +
                            CUresult.stringFor(result));
        }

        // If the context is 'null', then a new context
        // has to be created.
        CUcontext nullContext = new CUcontext();
        if (context.equals(nullContext)) {
            createContext(context, deviceNumber);
            return true;
        }
        return false;
    }

    /**
     * Tries to create a context for device 'deviceNumber'.
     *
     * @param context receives the newly created context
     * @param deviceNumber the device to create the context on
     * @throws CudaException If the device can not be
     * accessed or the context can not be created
     */
    private void createContext(CUcontext context, int deviceNumber) {
        createDevice(context, deviceNumber);
    }

    /**
     * Obtains the handle of the given device and creates a new context
     * on it.
     *
     * @param context receives the newly created context
     * @param deviceNumber the device to access
     * @return the handle of the device
     * @throws CudaException If the device can not be
     * accessed or the context can not be created
     */
    public static CUdevice createDevice(CUcontext context, int deviceNumber) {
        CUdevice device = lookupDevice(deviceNumber);

        int result = cuCtxCreate(context, 0, device);
        if (result != CUresult.CUDA_SUCCESS) {
            throw new CudaException(
                    "Failed to create a context: " +
                            CUresult.stringFor(result));
        }

        return device;
    }

    /**
     * Obtains the handle of the given device without creating a context.
     *
     * @param deviceNumber the device to access
     * @return the handle of the device
     * @throws CudaException If the device can not be accessed
     */
    private static CUdevice lookupDevice(int deviceNumber) {
        CUdevice device = new CUdevice();
        int result = cuDeviceGet(device, deviceNumber);
        if (result != CUresult.CUDA_SUCCESS) {
            throw new CudaException(
                    "Failed to obtain a device: " +
                            CUresult.stringFor(result));
        }
        return device;
    }

    /**
     * Returns the known devices keyed by device ordinal.
     * The map is empty until a context has been requested, and it is
     * the live internal map, not a copy.
     *
     * @return the available devices
     */
    public Map<Integer, CUdevice> getDevices() {
        return devices;
    }

    /**
     * Returns the registered contexts keyed by device ordinal.
     * The map is empty until a context has been requested, and it is
     * the live internal map, not a copy.
     *
     * @return the contexts by device
     */
    public Map<Integer, CUcontext> getDeviceIDContexts() {
        return deviceIDContexts;
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy