All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.jita.memory.impl.CudaDirectProvider Maven / Gradle / Ivy

The newest version!
package org.nd4j.jita.memory.impl;

import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

/**
 * @author [email protected]
 */
public class CudaDirectProvider implements MemoryProvider {

    /**
     * Amount of device memory (50 MiB) that {@link #pingDeviceForFreeMemory(Integer, long)} insists must
     * remain free after a requested allocation.
     */
    protected static final long DEVICE_RESERVED_SPACE = 1024 * 1024 * 50L;
    private static final Logger log = LoggerFactory.getLogger(CudaDirectProvider.class);
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    // NOTE(review): not read anywhere in this class; kept as-is (including the raw type) because it is
    // protected and subclasses may depend on it.
    protected volatile ConcurrentHashMap validator = new ConcurrentHashMap<>();

    /**
     * This method provides PointersPair to memory chunk specified by AllocationShape.
     *
     * <p>HOST: memory is obtained via {@code nativeOps.mallocHost}. A new PointersPair whose host AND device
     * pointers both reference that same address is attached to {@code point}, and the point's status is set
     * to HOST.
     *
     * <p>DEVICE: memory is obtained via {@code nativeOps.mallocDevice} on the allocator's current device id.
     * The device pointer is stored in the point's existing PointersPair. If the point has none, a new pair
     * is created and returned, but it is NOT attached to {@code point} here. The point's status is set to
     * DEVICE and its device id is recorded.
     *
     * <p>Requests whose computed size is below 1 byte are bumped to 1 byte in both cases.
     *
     * @param shape    shape of desired memory chunk
     * @param point    target AllocationPoint structure
     * @param location either HOST or DEVICE
     * @return the pointers describing the new allocation, or {@code null} if the DEVICE allocation failed
     * @throws RuntimeException      if the HOST allocation failed
     * @throws IllegalStateException if {@code location} is neither HOST nor DEVICE
     */
    @Override
    public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
        switch (location) {
            case HOST: {
                long reqMem = AllocationUtils.getRequiredMemory(shape);

                // FIXME: this is WRONG, and directly leads to memleak
                if (reqMem < 1) {
                    reqMem = 1;
                }

                Pointer pointer = nativeOps.mallocHost(reqMem, 0);
                if (pointer == null) {
                    throw new RuntimeException("Can't allocate [HOST] memory: " + reqMem + "; threadId: "
                                    + Thread.currentThread().getId());
                }

                Pointer hostPointer = new CudaPointer(pointer);

                // For a host allocation the same address is exposed as both the host and the device pointer.
                PointersPair devicePointerInfo = new PointersPair();
                devicePointerInfo.setDevicePointer(new CudaPointer(hostPointer, reqMem));
                devicePointerInfo.setHostPointer(new CudaPointer(hostPointer, reqMem));

                point.setPointers(devicePointerInfo);
                point.setAllocationStatus(AllocationStatus.HOST);
                return devicePointerInfo;
            }
            case DEVICE: {
                int deviceId = AtomicAllocator.getInstance().getDeviceId();
                long reqMem = AllocationUtils.getRequiredMemory(shape);

                // FIXME: this is WRONG, and directly leads to memleak
                if (reqMem < 1) {
                    reqMem = 1;
                }

                Pointer pointer = nativeOps.mallocDevice(reqMem, null, 0);

                // Device allocation failure is reported by returning null (not by throwing);
                // the caller is expected to check for it.
                if (pointer == null) {
                    return null;
                }

                Pointer devicePointer = new CudaPointer(pointer);

                // Only the device pointer is (re)assigned, so any host pointer already held by the point's
                // pair is left untouched. A freshly created pair is returned but not attached to the point.
                PointersPair devicePointerInfo = point.getPointers();
                if (devicePointerInfo == null) {
                    devicePointerInfo = new PointersPair();
                }
                devicePointerInfo.setDevicePointer(new CudaPointer(devicePointer, reqMem));

                point.setAllocationStatus(AllocationStatus.DEVICE);
                point.setDeviceId(deviceId);
                return devicePointerInfo;
            }
            default:
                throw new IllegalStateException("Unsupported location for malloc: [" + location + "]");
        }
    }

    /**
     * This method frees specific chunk of memory, described by AllocationPoint passed in.
     *
     * <p>HOST memory is released with {@code freeHost} on the point's host pointer. DEVICE memory is released
     * with {@code freeDevice} on the point's device pointer, unless the point is marked constant, in which
     * case nothing is freed.
     *
     * @param point allocation to release; its status selects the host or device path
     * @throws RuntimeException      if the native free call reports failure (returns 0)
     * @throws IllegalStateException if the point's status is neither HOST nor DEVICE
     */
    @Override
    public void free(AllocationPoint point) {
        switch (point.getAllocationStatus()) {
            case HOST: {
                NativeOps ops = NativeOpsHolder.getInstance().getDeviceNativeOps();

                long result = ops.freeHost(point.getPointers().getHostPointer());
                if (result == 0) {
                    throw new RuntimeException("Can't deallocate [HOST] memory...");
                }
                break;
            }
            case DEVICE: {
                // Constant allocations are never released through this provider.
                if (point.isConstant()) {
                    return;
                }

                NativeOps ops = NativeOpsHolder.getInstance().getDeviceNativeOps();

                long result = ops.freeDevice(point.getPointers().getDevicePointer(), new CudaPointer(0));
                if (result == 0) {
                    throw new RuntimeException("Can't deallocate [DEVICE] memory...");
                }
                break;
            }
            default:
                throw new IllegalStateException("Can't free memory on target [" + point.getAllocationStatus() + "]");
        }
    }

    /**
     * This method checks whether enough device memory is available for an allocation of the given size,
     * keeping {@link #DEVICE_RESERVED_SPACE} bytes in reserve.
     *
     * @param deviceId       requested device; NOTE(review): currently ignored, since the query below always
     *                       passes -1 (presumably "current device" — TODO confirm)
     * @param requiredMemory number of bytes about to be allocated
     * @return {@code true} if at least {@code DEVICE_RESERVED_SPACE} bytes would remain free after the
     *         allocation, {@code false} otherwise
     */
    public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
        long freeMem = nativeOps.getDeviceFreeMemory(new CudaPointer(-1));
        return freeMem - requiredMemory >= DEVICE_RESERVED_SPACE;
    }

    /**
     * Releases host memory via {@code freeHost}. The native result code is not checked.
     *
     * @param pointer host pointer to release
     */
    protected void freeHost(Pointer pointer) {
        NativeOps ops = NativeOpsHolder.getInstance().getDeviceNativeOps();
        ops.freeHost(pointer);
    }

    /**
     * Releases device memory via {@code freeDevice}. The native result code is not checked.
     *
     * @param pointer  device pointer to release
     * @param deviceId currently unused by this implementation
     */
    protected void freeDevice(Pointer pointer, int deviceId) {
        NativeOps ops = NativeOpsHolder.getInstance().getDeviceNativeOps();
        ops.freeDevice(pointer, new CudaPointer(0));
    }

    /** No-op: this provider keeps no cache. */
    @Override
    public void purgeCache() {
        // no-op
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy