org.nd4j.jita.handler.impl.CudaZeroHandler Maven / Gradle / Ivy
The newest version!
package org.nd4j.jita.handler.impl;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import lombok.Getter;
import lombok.NonNull;
import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.concurrency.DeviceAllocationsTracker;
import org.nd4j.jita.allocator.context.ContextPool;
import org.nd4j.jita.allocator.context.ExternalContext;
import org.nd4j.jita.allocator.context.impl.LimitedContextPool;
import org.nd4j.jita.allocator.context.impl.PackedContextPool;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.flow.impl.AsynchronousFlowController;
import org.nd4j.jita.flow.impl.GridFlowController;
import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.jita.memory.impl.CudaCachingZeroProvider;
import org.nd4j.jita.memory.impl.CudaDirectProvider;
import org.nd4j.jita.memory.impl.CudaFullCachingProvider;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
* This MemoryHandler implementation uses the following techniques:
* 1. Unified Memory Architecture
* 2. Zero-Copy Pinned Memory (if available)
* 3. Pageable memory (if zero-copy pinned memory isn't supported by device)
*
*
* @author [email protected]
*/
public class CudaZeroHandler implements MemoryHandler {
private static Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
private static Logger log = LoggerFactory.getLogger(CudaZeroHandler.class);
// simple counter to track allocated host-memory
protected final AtomicLong zeroUseCounter = new AtomicLong(0);
// another simple counter, to track allocated device memory on per-thread per-device basis
protected volatile DeviceAllocationsTracker deviceMemoryTracker;
// tracker for thread->device affinity
protected Map devicesAffinity = new ConcurrentHashMap<>();
private ReentrantReadWriteLock deviceLock = new ReentrantReadWriteLock();
private AtomicInteger devPtr = new AtomicInteger(0);
private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
private final ContextPool contextPool;
@Getter
private final MemoryProvider memoryProvider;
private final FlowController flowController;
private final AllocationStatus INITIAL_LOCATION;
private final AffinityManager affinityManager = Nd4j.getAffinityManager();
/*
table for Thread, Device, Object allocations of device memory. Objects should be used to grab Allocation point from allocationsMap
*/
// TODO: proper thread-safe implementation would be nice to have here :(
// FIXME: CopyOnWriteArrayList is BAD here. Really BAD. B A D.
// Table thread safety is guaranteed by reentrant read/write locks :(
//private Table> deviceAllocations = HashBasedTable.create();
//private final Map> deviceAllocations = new ConcurrentHashMap<>();
private final List> deviceAllocations = new ArrayList<>();
/*
map for Thread, Object allocations in zero memory.
*/
// CopyOnWriteArrayList performance to be investigated in this use case
// Map thread safety is guaranteed by exclusive writeLock in getDeviceId() method, because we can't use putIfAbsent on j7
// FIXME: at j7 -> j8 transition, this one could be changed to ConcurrentHashMap
private final Map> zeroAllocations = new ConcurrentHashMap<>();
private AtomicLong zeroCounter = new AtomicLong(0);
protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
public CudaZeroHandler() {
configuration.setInitialized();
this.INITIAL_LOCATION = configuration.getFirstMemory();
switch (configuration.getExecutionModel()) {
case OPTIMIZED:
case ASYNCHRONOUS: {
this.flowController = new AsynchronousFlowController();
this.contextPool = new PackedContextPool();
}
break;
case SEQUENTIAL: {
this.flowController = new GridFlowController();
this.contextPool = new LimitedContextPool();
}
break;
default:
throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]");
}
switch (configuration.getAllocationModel()) {
case CACHE_ALL:
this.memoryProvider = new CudaFullCachingProvider();
break;
case CACHE_HOST:
this.memoryProvider = new CudaCachingZeroProvider();
break;
case DIRECT:
this.memoryProvider = new CudaDirectProvider();
break;
default:
throw new RuntimeException("Unknown AllocationModel: [" + configuration.getAllocationModel() + "]");
}
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
for (int i = 0; i < numDevices; i++) {
deviceAllocations.add(new ConcurrentHashMap());
}
if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(new CudaPointer(0)) < 3) {
throw new ND4JIllegalStateException("CUDA backend requires compute capatibility of 3.0 and above to run.");
}
}
/**
 * Called by the Allocator while the Allocator/MemoryHandler pair is being initialized.
 *
 * Replaces the shared (static) configuration and builds a fresh device-allocation tracker from it. It then
 * initializes the flow controller with the allocator.
 *
 * @param configuration configuration to use from now on (shared by all handler instances)
 * @param allocator     owning allocator, forwarded to the flow controller
 */
@Override
public void init(@NonNull Configuration configuration, @NonNull Allocator allocator) {
    // 'configuration' is a static field, so qualify the write with the class name rather than 'this'
    CudaZeroHandler.configuration = configuration;
    this.deviceMemoryTracker = new DeviceAllocationsTracker(CudaZeroHandler.configuration);
    this.flowController.init(allocator);
}
/**
 * Registers a freshly allocated host chunk.
 *
 * Adds its byte size to the host-usage counter and assigns the point to a randomly chosen bucket. The bucket
 * index lies in [0, numberOfGcThreads), as produced by RandomUtils.nextInt. The point is then recorded in that
 * bucket's map.
 */
private void pickupHostAllocation(AllocationPoint point) {
    int bucketCount = configuration.getNumberOfGcThreads();
    long bucket = RandomUtils.nextInt(0, bucketCount);

    zeroUseCounter.addAndGet(AllocationUtils.getRequiredMemory(point.getShape()));
    point.setBucketId(bucket);

    // create the bucket lazily; re-check under the lock so only one thread installs it
    if (!zeroAllocations.containsKey(bucket)) {
        log.debug("Creating bucketID: " + bucket);
        synchronized (this) {
            if (!zeroAllocations.containsKey(bucket))
                zeroAllocations.put(bucket, new ConcurrentHashMap<Long, Long>());
        }
    }

    zeroAllocations.get(bucket).put(point.getObjectId(), point.getObjectId());
}
/**
 * Allocates a memory chunk for the given allocation point in the requested location.
 *
 * HOST: while current host usage plus this request would exceed the configured maximum, invokes GC and sleeps.
 * A single request larger than the maximum fails immediately. Memory is then allocated through the
 * MemoryProvider, optionally zero-filled, and registered via pickupHostAllocation().
 *
 * DEVICE: if the point has no host memory yet, host memory is allocated first and installed into the point.
 * Device memory is then requested. On success the point becomes DEVICE and the new device pointer is stored in
 * the point. If malloc returns null, the point stays on HOST and the returned pair's device pointer refers to
 * the point's host memory.
 *
 * @param targetMode valid arguments are HOST and DEVICE
 * @param point      allocation point being backed by memory
 * @param shape      shape that determines the number of bytes requested
 * @param initialize whether the new memory should be zero-filled
 * @return pointers describing the allocation
 * @throws IllegalStateException if a single HOST request exceeds the configured maximum, or targetMode is unsupported
 */
@Override
public PointersPair alloc(AllocationStatus targetMode, AllocationPoint point, AllocationShape shape,
                boolean initialize) {
    long reqMemory = AllocationUtils.getRequiredMemory(shape);
    CudaContext context = getCudaContext();
    switch (targetMode) {
        case HOST: {
            // Strict '>' on purpose. It matches the single-request guard below, so a request of exactly the
            // configured maximum can proceed once usage drops to zero. With '>=' it would wait forever.
            if (zeroUseCounter.get() + reqMemory > configuration.getMaximumZeroAllocation()) {
                if (reqMemory > configuration.getMaximumZeroAllocation()) {
                    throw new IllegalStateException(
                                    "You can't allocate more memory, then allowed with configured value: ["
                                                    + configuration.getMaximumZeroAllocation() + "]");
                }

                while (zeroUseCounter.get() + reqMemory > configuration.getMaximumZeroAllocation()) {
                    log.warn("No available [HOST] memory, sleeping for a while...");
                    if (log.isDebugEnabled()) {
                        // count tracked host objects across all buckets (keys are Long bucket ids)
                        long trackedObjects = 0;
                        for (Map<?, ?> bucket : zeroAllocations.values())
                            trackedObjects += bucket.size();
                        log.debug("Currently used: [" + zeroUseCounter.get() + "], allocated objects: ["
                                        + trackedObjects + "]");
                    }
                    Nd4j.getMemoryManager().invokeGc();
                    try {
                        Thread.sleep(1000);
                    } catch (InterruptedException e) {
                        // preserve the interrupt status for callers further up the stack
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    }
                }
            }

            PointersPair pair = memoryProvider.malloc(shape, point, targetMode);

            if (initialize) {
                org.bytedeco.javacpp.Pointer.memset(pair.getHostPointer(), 0, reqMemory);
                point.tickHostWrite();
            }

            pickupHostAllocation(point);

            return pair;
        }
        case DEVICE: {
            int deviceId = getDeviceId();

            PointersPair returnPair = new PointersPair();
            PointersPair tmpPair = new PointersPair();

            // if the initial memory location is device, there's a chance we don't have zero memory allocated
            if (point.getPointers() == null || point.getPointers().getHostPointer() == null) {
                tmpPair = alloc(AllocationStatus.HOST, point, point.getShape(), initialize);

                returnPair.setDevicePointer(tmpPair.getHostPointer());
                returnPair.setHostPointer(tmpPair.getHostPointer());

                point.setAllocationStatus(AllocationStatus.HOST);
                point.setPointers(tmpPair);
            }

            if (deviceMemoryTracker.reserveAllocationIfPossible(Thread.currentThread().getId(), deviceId,
                            reqMemory)) {
                point.setDeviceId(deviceId);
                PointersPair pair = memoryProvider.malloc(shape, point, targetMode);
                if (pair != null) {
                    returnPair.setDevicePointer(pair.getDevicePointer());

                    point.setAllocationStatus(AllocationStatus.DEVICE);

                    if (point.getPointers() == null)
                        throw new RuntimeException("WTF?");

                    point.getPointers().setDevicePointer(pair.getDevicePointer());

                    deviceAllocations.get(deviceId).put(point.getObjectId(), point.getObjectId());

                    // the object no longer lives in host memory, so drop it from its host bucket
                    val p = point.getBucketId();
                    if (p != null) {
                        val m = zeroAllocations.get(point.getBucketId());

                        // m can be null, if that's point from workspace - just no bucketId for it
                        if (m != null)
                            m.remove(point.getObjectId());
                    }

                    deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId, reqMemory);

                    point.tickHostWrite();

                    if (!initialize) {
                        point.tickDeviceWrite();
                        point.tickHostRead();
                    } else {
                        nativeOps.memsetAsync(pair.getDevicePointer(), 0, reqMemory, 0,
                                        context.getSpecialStream());
                        context.getSpecialStream().synchronize();

                        point.tickDeviceWrite();
                        point.tickHostRead();
                    }
                } else {
                    log.warn("Out of [DEVICE] memory, host memory will be used instead: deviceId: [{}], requested bytes: [{}]",
                                    deviceId, reqMemory);
                    // Device malloc returned null, so keep using host memory. Take the host pointer from the
                    // point itself: tmpPair is empty when host memory already existed before this call.
                    returnPair.setDevicePointer(point.getPointers().getHostPointer());

                    point.setAllocationStatus(AllocationStatus.HOST);
                    Nd4j.getMemoryManager().invokeGc();
                    backOffQuietly(100);
                }
            } else {
                log.warn("Hard limit on [DEVICE] memory hit, please consider tuning memory parameters, deviceId [{}]",
                                deviceId);

                Nd4j.getMemoryManager().invokeGc();
                backOffQuietly(100);
            }

            return returnPair;
        }
        default:
            throw new IllegalStateException("Can't allocate memory on target [" + targetMode + "]");
    }
}

/**
 * Best-effort back-off after a failed device allocation. If interrupted, restores the thread's interrupt flag
 * and returns early instead of silently discarding the interruption.
 */
private static void backOffQuietly(long millis) {
    try {
        Thread.sleep(millis);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
    }
}
/**
 * Asks the underlying MemoryProvider whether the given device can currently satisfy a request.
 *
 * @param deviceId       device to query
 * @param requiredMemory number of bytes that would be requested
 * @return the provider's answer
 */
@Override
public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
    final boolean providerHasRoom = memoryProvider.pingDeviceForFreeMemory(deviceId, requiredMemory);
    return providerHasRoom;
}
/**
 * Moves data for an allocation point between storages.
 *
 * DEVICE -> HOST: only checks that the point still has a buffer attached; no copy is issued by this method.
 * NOTE(review): presumably host synchronization happens elsewhere — confirm.
 *
 * HOST -> DEVICE: skipped for constant points. Otherwise issues an async host-to-device copy of the shape's
 * required bytes on the context's special stream, commits the transfer through the flow controller, and records
 * the transaction with the PerformanceTracker.
 *
 * @param currentStatus where the data currently is
 * @param targetStatus  where the data should be
 * @param point         allocation point whose pointers are used
 * @param shape         shape that determines the number of bytes copied (HOST -> DEVICE)
 * @param context       context whose special stream carries the copy
 * @throws IllegalStateException         if the buffer (D->H) or device pointer (H->D) is missing, or the copy fails
 * @throws UnsupportedOperationException for any other direction
 */
@Override
public void relocate(AllocationStatus currentStatus, AllocationStatus targetStatus, AllocationPoint point,
                AllocationShape shape, CudaContext context) {
    if (currentStatus == AllocationStatus.DEVICE && targetStatus == AllocationStatus.HOST) {
        // DEVICE -> HOST: validation only
        DataBuffer targetBuffer = point.getBuffer();
        if (targetBuffer == null) {
            throw new IllegalStateException("Target buffer is NULL!");
        }
    } else if (currentStatus == AllocationStatus.HOST && targetStatus == AllocationStatus.DEVICE) {
        // HOST -> DEVICE
        // TODO: this probably should be removed
        if (point.isConstant()) {
            return;
        }

        if (point.getPointers().getDevicePointer() == null) {
            throw new IllegalStateException("devicePointer is NULL!");
        }

        val profD = PerformanceTracker.getInstance().helperStartTransaction();

        if (nativeOps.memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(),
                        AllocationUtils.getRequiredMemory(shape), CudaConstants.cudaMemcpyHostToDevice,
                        context.getSpecialStream()) == 0) {
            throw new IllegalStateException("MemcpyAsync relocate H2D failed: [" + point.getHostPointer().address()
                            + "] -> [" + point.getDevicePointer().address() + "]");
        }

        flowController.commitTransfer(context.getSpecialStream());

        PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
    } else {
        throw new UnsupportedOperationException("Can't relocate data in requested direction: [" + currentStatus
                        + "] -> [" + targetStatus + "]");
    }
}
/**
 * Formerly copied device memory back to host. It is now disabled and always throws.
 *
 * @param point ignored
 * @param shape ignored
 * @throws UnsupportedOperationException always
 */
@Override
@Deprecated
public void copyback(AllocationPoint point, AllocationShape shape) {
    // This was conceptually relocate(point.getAllocationStatus(), HOST, ...); the entry point is retired.
    throw new UnsupportedOperationException("Deprecated call");
}
/**
 * Formerly copied host memory forward to the device. It now logs the call and always throws.
 *
 * @param point allocation point (only used for the log message)
 * @param shape ignored
 * @throws UnsupportedOperationException always
 */
@Override
@Deprecated
public void copyforward(AllocationPoint point, AllocationShape shape) {
    // This was conceptually relocate(HOST, point.getAllocationStatus(), ...); the entry point is retired.
    String message = "copyforward() called on tp[" + point.getObjectId() + "], shape: " + point.getShape();
    log.info(message);
    throw new UnsupportedOperationException("Deprecated call");
}
/**
 * Fallback from device to zero-copy memory is not supported; always throws.
 *
 * @param point allocation point whose status is reported in the exception
 * @param shape ignored
 * @throws IllegalStateException always
 */
@Override
@Deprecated
public void fallback(AllocationPoint point, AllocationShape shape) {
    AllocationStatus status = point.getAllocationStatus();
    throw new IllegalStateException("Can't fallback from [" + status + "]");
}
/**
 * Releases the memory held by an allocation point via the MemoryProvider.
 *
 * For DEVICE points, the point's required bytes are first subtracted from the device tracker for the current
 * thread and the point's device. The {@code target} argument is not consulted.
 *
 * @param point  allocation point to free
 * @param target unused
 */
@Override
public void free(AllocationPoint point, AllocationStatus target) {
    boolean onDevice = point.getAllocationStatus() == AllocationStatus.DEVICE;
    if (onDevice) {
        long releasedBytes = AllocationUtils.getRequiredMemory(point.getShape());
        deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), point.getDeviceId(), releasedBytes);
    }
    memoryProvider.free(point);
}
/**
 * Returns the initial allocation location, taken from the configuration's first memory at construction time.
 *
 * @return HOST, or DEVICE if the configuration requests it
 */
@Override
public AllocationStatus getInitialLocation() {
    return this.INITIAL_LOCATION;
}
/**
 * Per-thread device initialization hook. This implementation intentionally does nothing.
 *
 * @param threadId ignored
 * @param deviceId ignored
 */
@Override
public void initializeDevice(Long threadId, Integer deviceId) {
    // Intentionally empty: no per-thread device/context setup is performed here.
}
/**
* Asynchronous version of memcpy.
*
* Copies {@code length} bytes from {@code srcPointer} into the host memory of {@code dstBuffer}, starting at byte
* offset {@code dstOffset}. If the buffer's allocation point is in DEVICE status, the same bytes are additionally
* copied from that host region to device memory at the same offset.
*
* Constant buffers use a plain blocking memcpy for the host copy. Other buffers use an async HostToHost copy on
* the special stream of the context obtained from the flow controller. The point is ticked as device-written at
* the end in every case.
*
* NOTE(review): the original note said a blocking call is used when async copies are unsupported; no such
* fallback is visible in this method.
*
* @param dstBuffer destination buffer; cast unconditionally to BaseCudaDataBuffer
* @param srcPointer source memory; used as the source of a HostToHost copy (presumably host memory — confirm)
* @param length number of bytes to copy
* @param dstOffset byte offset into the destination's host (and, if DEVICE, device) memory
*/
@Override
public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
// we update host memory regardless.
// destination inside the host allocation: base address + byte offset
Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset);
// context used for the async copies below; stays null on the constant path until the DEVICE mirror needs one
CudaContext tContext = null;
if (dstBuffer.isConstant()) {
// constant buffers: synchronous host-side memcpy, no stream involved
org.bytedeco.javacpp.Pointer dstPointer =
new CudaPointer(point.getPointers().getHostPointer().address() + dstOffset, 0L);
org.bytedeco.javacpp.Pointer srcPointerJ = new CudaPointer(srcPointer, length);
val profD = PerformanceTracker.getInstance().helperStartTransaction();
org.bytedeco.javacpp.Pointer.memcpy(dstPointer, srcPointerJ, length);
// NOTE(review): the tracker is given the whole buffer size (getNumberOfBytes), not 'length' — confirm intent
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
point.tickHostRead();
} else {
// regular buffers: async HostToHost copy on the special stream of a flow-controller context
CudaContext context = flowController.prepareAction(point);
tContext = context;
val prof = PerformanceTracker.getInstance().helperStartTransaction();
if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost,
context.getSpecialStream()) == 0)
throw new IllegalStateException(
"MemcpyAsync H2H failed: [" + srcPointer.address() + "] -> [" + dP.address() + "]");
flowController.commitTransfer(tContext.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_HOST);
// for DEVICE points the action is registered after the device mirror copy below instead
if (point.getAllocationStatus() == AllocationStatus.HOST)
flowController.registerAction(context, point);
}
// if we're copying something into host memory, but we're on device - we need to provide exact copy to device as well
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
// TODO: this sounds wrong, and probably memcpy whould check initial direction, like relocate did before
// same byte offset inside the device allocation
Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);
if (tContext == null)
tContext = flowController.prepareAction(point);
val prof = PerformanceTracker.getInstance().helperStartTransaction();
// source is the host region written above, so the device receives exactly the same bytes
if (nativeOps.memcpyAsync(rDP, dP, length, CudaConstants.cudaMemcpyHostToDevice,
tContext.getSpecialStream()) == 0)
throw new IllegalStateException(
"MemcpyAsync H2D failed: [" + dP.address() + "] -> [" + rDP.address() + "]");
flowController.commitTransfer(tContext.getSpecialStream());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), prof, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_DEVICE);
flowController.registerAction(tContext, point);
}
// ticked unconditionally, including for HOST-status points
point.tickDeviceWrite();
}
/**
 * Enqueues a device-to-device copy of {@code length} bytes from {@code srcPointer} into the destination
 * buffer's device memory at byte offset {@code dstOffset}. The copy runs on the context's old stream, and the
 * point is then ticked as device-written.
 *
 * @throws ND4JIllegalStateException if the native memcpyAsync reports failure (returns 0)
 */
@Override
public void memcpyDevice(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset,
                CudaContext context) {
    final AllocationPoint target = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
    final Pointer destination = new CudaPointer(target.getPointers().getDevicePointer().address() + dstOffset);

    final boolean copyFailed = nativeOps.memcpyAsync(destination, srcPointer, length,
                    CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0;
    if (copyFailed) {
        throw new ND4JIllegalStateException("memcpyAsync failed");
    }

    target.tickDeviceWrite();
}
/**
 * Special memcpy version, addressing shapeInfoDataBuffer copies.
 *
 * Copies {@code length} bytes from {@code srcPointer} into the destination's host memory at byte offset
 * {@code dstOffset} (HostToHost on the old stream). For DEVICE-status points it also enqueues a HostToDevice
 * copy of that region. The old stream is synchronized exactly once before returning, so the call is blocking
 * for both copies.
 *
 * @param dstBuffer  destination buffer; cast unconditionally to BaseCudaDataBuffer
 * @param srcPointer source memory
 * @param length     number of bytes to copy
 * @param dstOffset  byte offset into the destination memory
 * @throws ND4JIllegalStateException if a native memcpyAsync reports failure (returns 0)
 */
@Override
public void memcpySpecial(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
    CudaContext context = getCudaContext();
    AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();

    Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset);

    val profH = PerformanceTracker.getInstance().helperStartTransaction();

    if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getOldStream()) == 0)
        throw new ND4JIllegalStateException("memcpyAsync failed");

    PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profH, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_HOST);

    if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
        Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);

        val profD = PerformanceTracker.getInstance().helperStartTransaction();

        // Same stream as the H2H copy above, so stream order guarantees dP is filled before it is read.
        if (nativeOps.memcpyAsync(rDP, dP, length, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0)
            throw new ND4JIllegalStateException("memcpyAsync failed");

        // One sync covers both queued copies; the original issued a second, redundant sync right after.
        context.syncOldStream();

        PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(),MemcpyDirection.HOST_TO_DEVICE);
    } else {
        context.syncOldStream();
    }

    point.tickDeviceWrite();
}
/**
 * Blocking memcpy: delegates to memcpyAsync() and then synchronizes the old stream of the current CUDA context.
 *
 * @param dstBuffer  destination buffer
 * @param srcPointer source memory
 * @param length     number of bytes to copy
 * @param dstOffset  byte offset into the destination memory
 */
@Override
public void memcpyBlocking(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
    // the context is looked up before the copy so the sync below targets the same context as before
    final CudaContext ctx = getCudaContext();
    memcpyAsync(dstBuffer, srcPointer, length, dstOffset);
    ctx.syncOldStream();
}
/**
 * Synchronous buffer-to-buffer copy into the destination's device memory.
 *
 * If the source is up to date on the device side, a DeviceToDevice copy is issued. Otherwise the source's host
 * memory is copied HostToDevice. The copy size is {@code srcBuffer.length() * srcBuffer.getElementSize()} bytes
 * and runs on the old stream. The stream is synchronized before returning.
 *
 * @param dstBuffer destination buffer; cast unconditionally to BaseCudaDataBuffer
 * @param srcBuffer source buffer; cast unconditionally to BaseCudaDataBuffer
 * @throws ND4JIllegalStateException if the native memcpyAsync reports failure (returns 0)
 */
@Override
public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
    CudaContext context = getCudaContext();
    AllocationPoint dstPoint = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
    AllocationPoint srcPoint = ((BaseCudaDataBuffer) srcBuffer).getAllocationPoint();

    val profDH = PerformanceTracker.getInstance().helperStartTransaction();

    // NOTE(review): presumably flushes queued executioner ops before copying — confirm
    Nd4j.getExecutioner().push();

    MemcpyDirection direction;
    if (srcPoint.isActualOnDeviceSide()) {
        Pointer sP = AtomicAllocator.getInstance().getPointer(srcBuffer, context);
        Pointer dP = AtomicAllocator.getInstance().getPointer(dstBuffer, context);

        if (nativeOps.memcpyAsync(dP, sP, srcBuffer.length() * srcBuffer.getElementSize(),
                        CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        direction = MemcpyDirection.DEVICE_TO_DEVICE;
    } else {
        Pointer sP = AtomicAllocator.getInstance().getHostPointer(srcBuffer);
        Pointer dP = AtomicAllocator.getInstance().getPointer(dstBuffer, context);

        if (nativeOps.memcpyAsync(dP, sP, srcBuffer.length() * srcBuffer.getElementSize(),
                        CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        direction = MemcpyDirection.HOST_TO_DEVICE;
    }

    // both paths write device memory; tick once (the D2D path used to tick twice)
    dstPoint.tickDeviceWrite();

    // it has to be blocking call
    context.syncOldStream();

    // NOTE(review): profDH is the value returned by helperStartTransaction(); halving it is suspicious unless
    // the tracker expects it — confirm against PerformanceTracker.
    PerformanceTracker.getInstance().helperRegisterTransaction(srcPoint.getDeviceId(), profDH / 2, dstPoint.getNumberOfBytes(), direction);
}
/**
 * Returns a pointer into the buffer's device-side memory, offset by {@code buffer.offset() * elementSize}
 * bytes and cast to the buffer's data type (double/float/int/short-for-HALF, otherwise untyped).
 *
 * For DEVICE-status points that are not up to date on the device side, host data is relocated to the device
 * first. The point is always ticked as device-read.
 *
 * PLEASE NOTE: Specific implementation, on systems without special devices can return HostPointer here.
 *
 * @param buffer  buffer whose device pointer is requested; cast unconditionally to BaseCudaDataBuffer
 * @param context context used for a possible HOST -> DEVICE relocation
 * @return typed pointer into the point's device pointer with the buffer offset applied
 */
@Override
public org.bytedeco.javacpp.Pointer getDevicePointer(DataBuffer buffer, CudaContext context) {
    // TODO: It would be awesome to get rid of typecasting here
    AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();

    // Automatic HOST -> DEVICE promotion used to be attempted here, but it was permanently disabled by an
    // always-false guard and has been removed. promoteObject() remains available for explicit promotion.

    // if that's device state, make sure device memory is current before handing the pointer out
    if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && !dstPoint.isActualOnDeviceSide()) {
        relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), context);
    }

    // we update memory use counter, to announce that it's somehow used on device
    dstPoint.tickDeviceRead();

    // return pointer with offset if needed. length is specified for constructor compatibility purposes
    CudaPointer p = new CudaPointer(dstPoint.getPointers().getDevicePointer(), buffer.length(),
                    (buffer.offset() * buffer.getElementSize()));

    switch (buffer.dataType()) {
        case DOUBLE:
            return p.asDoublePointer();
        case FLOAT:
            return p.asFloatPointer();
        case INT:
            return p.asIntPointer();
        case HALF:
            return p.asShortPointer();
        default:
            return p;
    }
}
/**
 * Returns a pointer into the buffer's host memory, offset by {@code buffer.offset() * elementSize} bytes and
 * cast to the buffer's data type (double/float/int/short-for-HALF, otherwise untyped).
 *
 * Calls synchronizeThreadDevice() for the point before building the pointer.
 *
 * PLEASE NOTE: This method always returns pointer within OS memory space.
 *
 * @param buffer buffer whose host pointer is requested; cast unconditionally to BaseCudaDataBuffer
 * @return typed pointer into host memory
 * @throws RuntimeException if the point has no host pointer
 */
@Override
public org.bytedeco.javacpp.Pointer getHostPointer(DataBuffer buffer) {
    AllocationPoint allocationPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();

    // dump what is known about the point before failing on a missing host allocation
    if (allocationPoint.getPointers().getHostPointer() == null) {
        log.info("DevicePointer: " + allocationPoint.getPointers().getDevicePointer());
        log.info("HostPointer: " + allocationPoint.getPointers().getHostPointer());
        log.info("AllocStatus: " + allocationPoint.getAllocationStatus());
        throw new RuntimeException("pointer is null");
    }

    synchronizeThreadDevice(Thread.currentThread().getId(), allocationPoint.getDeviceId(), allocationPoint);

    CudaPointer hostView = new CudaPointer(allocationPoint.getPointers().getHostPointer(), buffer.length(),
                    (buffer.offset() * buffer.getElementSize()));

    org.bytedeco.javacpp.Pointer typed;
    switch (buffer.dataType()) {
        case DOUBLE:
            typed = hostView.asDoublePointer();
            break;
        case FLOAT:
            typed = hostView.asFloatPointer();
            break;
        case INT:
            typed = hostView.asIntPointer();
            break;
        case HALF:
            typed = hostView.asShortPointer();
            break;
        default:
            typed = hostView;
            break;
    }
    return typed;
}
/**
 * Moves a DEVICE-resident buffer to the current thread's device.
 *
 * Does nothing for non-DEVICE points or when the point is already on the current device. Host data is
 * synchronized first. Then:
 * - attached buffers with no current workspace get new device memory and are marked detached;
 * - attached buffers inside a workspace are copied into a new buffer created via Nd4j.createBuffer and take
 *   over its pointers;
 * - other buffers have their old memory freed and new device memory allocated and filled from host memory.
 *
 * @param buffer buffer to relocate
 * @throws RuntimeException if host synchronization fails or the buffer is constant
 * @throws ND4JIllegalStateException if a native memcpyAsync reports failure (returns 0)
 */
@Override
public synchronized void relocateObject(DataBuffer buffer) {
    AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);

    // we don't relocate non-DEVICE buffers (i.e HOST or CONSTANT)
    if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE)
        return;

    int deviceId = getDeviceId();

    // already on this thread's device: nothing to do
    if (dstPoint.getDeviceId() >= 0 && dstPoint.getDeviceId() == deviceId ) {
        return;
    }

    // FIXME: cross-thread access, might cause problems
    if (!dstPoint.isActualOnHostSide())
        AtomicAllocator.getInstance().synchronizeHostData(buffer);

    if (!dstPoint.isActualOnHostSide())
        throw new RuntimeException("Buffer synchronization failed");

    if (buffer.isAttached() || dstPoint.isAttached()) {
        // if this buffer is Attached, we just relocate to new workspace

        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();

        if (workspace == null) {
            // if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually
            alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);

            CudaContext context = getCudaContext();

            val profD = PerformanceTracker.getInstance().helperStartTransaction();

            if (nativeOps.memcpyAsync(dstPoint.getDevicePointer(), dstPoint.getHostPointer(),
                            buffer.length() * buffer.getElementSize(), CudaConstants.cudaMemcpyHostToDevice,
                            context.getSpecialStream()) == 0)
                throw new ND4JIllegalStateException("memcpyAsync failed");

            context.syncSpecialStream();

            PerformanceTracker.getInstance().helperRegisterTransaction(dstPoint.getDeviceId(), profD / 2, dstPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);

            // updating host pointer now
            // NOTE(review): the HOST branch of alloc() returns a new pair without storing it in the point, and
            // the result is discarded here. This looks like it allocates host memory without updating
            // dstPoint's host pointer — confirm intent.
            alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);

            // marking it as detached
            dstPoint.setAttached(false);

            // marking it as proper on device
            dstPoint.tickHostRead();
            dstPoint.tickDeviceWrite();
        } else {
            // this call will automagically take care of workspaces, so it'll be either
            // NOTE(review): createBuffer(length) presumably uses the default data type, which may differ from
            // buffer.dataType() — verify before relying on this path for non-default types.
            BaseCudaDataBuffer nBuffer = (BaseCudaDataBuffer) Nd4j.createBuffer(buffer.length());

            Nd4j.getMemoryManager().memcpy(nBuffer, buffer);

            dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
            dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
            dstPoint.setDeviceId(deviceId);

            dstPoint.tickDeviceRead();
            dstPoint.tickHostRead();
        }

        return;
    }

    if (buffer.isConstant()) {
        // we can't relocate or modify buffers
        throw new RuntimeException("Can't relocateObject() for constant buffer");
    } else {
        memoryProvider.free(dstPoint);
        deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));

        // we replace original device pointer with new one
        alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);

        val profD = PerformanceTracker.getInstance().helperStartTransaction();

        CudaContext context = getCudaContext();
        if (nativeOps.memcpyAsync(dstPoint.getDevicePointer(), dstPoint.getHostPointer(),
                        buffer.length() * buffer.getElementSize(), CudaConstants.cudaMemcpyHostToDevice,
                        context.getSpecialStream()) == 0)
            throw new ND4JIllegalStateException("memcpyAsync failed");

        context.syncSpecialStream();

        PerformanceTracker.getInstance().helperRegisterTransaction(dstPoint.getDeviceId(), profD, dstPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);

        dstPoint.tickDeviceRead();
        dstPoint.tickHostRead();
    }
}
/**
 * Moves the given buffer from zero-copy (host) memory to device memory.
 *
 * PLEASE NOTE: DO NOT EVER USE THIS METHOD MANUALLY, UNLESS YOU 100% HAVE TO
 *
 * Promotion is only performed under the {@code DELAYED} memory model. Constant buffers are handed
 * over to the constant handler. Any other buffer gets a device allocation from the memory
 * provider, and its bookkeeping is moved from its zero-copy bucket to the device of the calling
 * thread.
 *
 * @param buffer buffer to promote
 * @return {@code false} if the buffer's allocation point is not in {@code HOST} status;
 *         {@code true} otherwise (including when the memory model is not {@code DELAYED}
 *         and nothing was moved)
 * @throws RuntimeException if the memory provider returns no device allocation
 */
@Override
public boolean promoteObject(DataBuffer buffer) {
    AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);

    if (dstPoint.getAllocationStatus() != AllocationStatus.HOST) {
        return false;
    }

    // The status is already known to be HOST here, so only the memory model needs checking.
    if (configuration.getMemoryModel() != Configuration.MemoryModel.DELAYED) {
        return true;
    }

    if (buffer.isConstant()) {
        // constant buffers (aka shapeInfo or other constant stuff) go to constant space
        Nd4j.getConstantHandler().moveToConstantSpace(buffer);
        return true;
    }

    PointersPair pair = memoryProvider.malloc(dstPoint.getShape(), dstPoint, AllocationStatus.DEVICE);
    if (pair == null) {
        throw new RuntimeException("Failed to allocate device memory while promoting object ["
                + dstPoint.getObjectId() + "] to device");
    }

    Integer deviceId = getDeviceId();
    dstPoint.getPointers().setDevicePointer(pair.getDevicePointer());
    dstPoint.setAllocationStatus(AllocationStatus.DEVICE);

    // Move tracking from the zero-copy bucket to the current thread's device.
    // NOTE(review): dstPoint's own deviceId is not updated here, while forget(point, DEVICE)
    // looks the point up via point.getDeviceId(); confirm the two always agree.
    deviceAllocations.get(deviceId).put(dstPoint.getObjectId(), dstPoint.getObjectId());
    zeroAllocations.get(dstPoint.getBucketId()).remove(dstPoint.getObjectId());
    deviceMemoryTracker.addToAllocation(Thread.currentThread().getId(), deviceId,
            AllocationUtils.getRequiredMemory(dstPoint.getShape()));

    dstPoint.tickHostWrite();
    return true;
}
/**
 * Builds a snapshot of memory usage.
 *
 * The snapshot has one {@code HOST} cell (column {@code 0}) holding the zero-copy usage counter,
 * and one {@code DEVICE} cell per available device holding that device's allocated memory.
 *
 * @return table indexed by (allocation status, device id)
 */
@Override
public Table getAllocationStatistics() {
    Table stats = HashBasedTable.create();

    // zero-copy usage is reported under a single HOST cell
    stats.put(AllocationStatus.HOST, 0, zeroUseCounter.get());

    for (Integer device : configuration.getAvailableDevices()) {
        long usedOnDevice = getAllocatedDeviceMemory(device);
        stats.put(AllocationStatus.DEVICE, device, usedOnDevice);
    }

    return stats;
}
/**
 * Reports how much memory the device memory tracker currently attributes to a device.
 *
 * @param device device id to query
 * @return allocated size as reported by the device memory tracker
 */
@Override
public long getAllocatedDeviceMemory(Integer device) {
    long allocated = deviceMemoryTracker.getAllocatedSize(device);
    return allocated;
}
/**
 * Reports the current value of the zero-copy (host) usage counter kept by this MemoryHandler.
 *
 * @return current zero-copy usage counter value
 */
@Override
public long getAllocatedHostMemory() {
    long hostUsage = zeroUseCounter.get();
    return hostUsage;
}
/**
 * Counts the objects tracked for the given device.
 *
 * @param deviceId device to inspect
 * @return number of entries in that device's tracking map
 */
@Override
public long getAllocatedDeviceObjects(Integer deviceId) {
    int tracked = deviceAllocations.get(deviceId).size();
    return tracked;
}
/**
 * Returns the number of objects tracked in the given zero-copy bucket.
 *
 * @param bucketId bucket to inspect
 * @return number of tracked objects, or {@code 0} if the bucket is unknown
 */
@Override
public long getAllocatedHostObjects(Long bucketId) {
    // Single lookup instead of containsKey() + get(): avoids hashing twice, and avoids the
    // window in which the bucket could disappear between the two calls.
    java.util.Map<?, ?> bucket = zeroAllocations.get(bucketId);
    return bucket == null ? 0L : bucket.size();
}
/**
 * Returns the total number of objects tracked in zero-copy (host) memory, summed over all buckets.
 *
 * @return total number of host-allocated objects
 */
@Override
public long getAllocatedHostObjects() {
    // A plain local accumulator is enough; no atomic is needed for a method-local sum.
    long total = 0L;
    for (Long bucketId : zeroAllocations.keySet()) {
        // The per-bucket overload returns 0 for a bucket that is not present, so a bucket
        // removed during iteration cannot cause an NPE here.
        total += getAllocatedHostObjects(bucketId);
    }
    return total;
}
/**
 * Exposes the ids of the objects tracked on a device.
 *
 * The returned set is the key set of that device's tracking map, not a copy.
 *
 * @param deviceId device to inspect
 * @return key set of the device's tracking map
 */
@Override
public Set getDeviceTrackingPoints(Integer deviceId) {
    Set trackedIds = deviceAllocations.get(deviceId).keySet();
    return trackedIds;
}
/**
 * Returns the ids of the objects tracked in the given zero-copy bucket.
 *
 * For a known bucket the result is that bucket's key set (not a copy). For an unknown bucket a
 * new, empty, mutable set is returned.
 *
 * @param bucketId bucket to inspect
 * @return tracked object ids of the bucket, or an empty set if the bucket is unknown
 */
@Override
public Set getHostTrackingPoints(Long bucketId) {
    // Single lookup instead of containsKey() + get(), so the bucket cannot vanish in between.
    java.util.Map<?, ?> bucket = zeroAllocations.get(bucketId);
    if (bucket == null) {
        return new HashSet<>();
    }
    return bucket.keySet();
}
/**
 * This method explicitly removes object from device memory.
 *
 * Does nothing unless the point is currently in DEVICE status. Otherwise it waits for the flow
 * controller to release the point, frees its DEVICE allocation, drops it from the device
 * tracking map, subtracts its size from the device memory tracker and marks the point as HOST.
 *
 * @param threadId thread id passed to the device memory tracker when subtracting the allocation
 * @param deviceId device whose tracking map is checked, and whose tracker entry is decremented
 * @param objectId object id expected to be present in the device tracking map
 * @param point allocation point to purge
 * @param copyback currently not read by this implementation; no copy back to the JVM side happens here
 * @throws IllegalStateException if the object is not tracked on {@code deviceId} before forget(),
 *         or is still tracked after it
 */
@Override
public void purgeDeviceObject(Long threadId, Integer deviceId, Long objectId, AllocationPoint point,
boolean copyback) {
// only device-resident points can be purged from device memory
if (point.getAllocationStatus() != AllocationStatus.DEVICE)
return;
// make sure nothing is still using this point before releasing it
flowController.waitTillReleased(point);
// NOTE(review): memory is freed before the tracking check below, so if that check throws,
// the allocation has already been released
free(point, AllocationStatus.DEVICE);
if (!deviceAllocations.get(deviceId).containsKey(objectId))
throw new IllegalStateException("Can't happen ever");
// forget() removes the entry via point.getDeviceId(), while the checks use the deviceId
// parameter; the second check fails if the two disagree
forget(point, AllocationStatus.DEVICE);
if (deviceAllocations.get(deviceId).containsKey(objectId))
throw new IllegalStateException("Can't happen ever");
deviceMemoryTracker.subFromAllocation(threadId, deviceId, AllocationUtils.getRequiredMemory(point.getShape()));
point.setAllocationStatus(AllocationStatus.HOST);
//environment.trackAllocatedMemory(deviceId, AllocationUtils.getRequiredMemory(point.getShape()));
}
/**
 * Releases an object held in zero-copy memory.
 *
 * Steps, in order:
 * <ol>
 * <li>drop the point from its bucket's tracking map;</li>
 * <li>wait until the flow controller releases the point;</li>
 * <li>free the HOST allocation;</li>
 * <li>mark the point DEALLOCATED;</li>
 * <li>decrease the zero-copy usage counter by the point's required memory.</li>
 * </ol>
 *
 * @param bucketId not read here; the bucket is taken from the point itself
 * @param objectId not read here; the id is taken from the point itself
 * @param point allocation point to release
 * @param copyback not read by this implementation
 */
@Override
public void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) {
    forget(point, AllocationStatus.HOST);
    flowController.waitTillReleased(point);

    free(point, AllocationStatus.HOST);
    point.setAllocationStatus(AllocationStatus.DEALLOCATED);

    long delta = -AllocationUtils.getRequiredMemory(point.getShape());
    zeroUseCounter.addAndGet(delta);
}
/**
 * Removes the point from the tracking map that matches {@code location}.
 *
 * HOST uses the point's bucket map, DEVICE uses the point's device map. Any other location is a
 * no-op.
 */
@Override
public void forget(AllocationPoint point, AllocationStatus location) {
    if (location == AllocationStatus.HOST) {
        zeroAllocations.get(point.getBucketId()).remove(point.getObjectId());
        return;
    }

    if (location == AllocationStatus.DEVICE) {
        deviceAllocations.get(point.getDeviceId()).remove(point.getObjectId());
    }
}
/**
 * Asks the affinity manager which device the calling thread is bound to.
 *
 * @return device id of the current thread
 */
public Integer getDeviceId() {
    // the cast keeps the original unbox-then-rebox semantics of going through an int local
    return (int) Nd4j.getAffinityManager().getDeviceForCurrentThread();
}
/**
 * Wraps the current thread's device id (see {@link #getDeviceId()}) into a {@link Pointer}.
 *
 * @return a {@code CudaPointer} built from the current device id
 */
@Override
public Pointer getDeviceIdPointer() {
    Integer current = getDeviceId();
    return new CudaPointer(current);
}
/**
 * Lists the devices known to the configuration.
 *
 * @return a fresh {@code HashSet} copy of the configured device ids
 */
@Override
public Set getAvailableDevices() {
    Set devices = new HashSet<>(configuration.getAvailableDevices());
    return devices;
}
/**
 * Wraps the current thread's CUDA context into an {@link ExternalContext}.
 *
 * @return a new ExternalContext around {@link #getCudaContext()}
 */
@Override
public ExternalContext getDeviceContext() {
    CudaContext current = getCudaContext();
    return new ExternalContext(current);
}
/**
 * Acquires a CUDA context from the context pool for the device of the calling thread.
 *
 * @return context handed out by {@code contextPool.acquireContextForDevice}
 */
public CudaContext getCudaContext() {
    return contextPool.acquireContextForDevice(getDeviceId());
}
/**
 * This method does initialization for thread.
 *
 * Sets the native device to the current thread's device, then creates a new CudaContext and
 * initializes its handle and streams.
 *
 * NOTE(review): the created context is neither stored nor returned (the contextPool.put call is
 * commented out), so the initialized handle and streams appear to be unreachable after this
 * method returns. Confirm whether this method is still used.
 *
 * @param threadId currently not read by this implementation
 */
protected void initCudaContextForThread(Long threadId) {
// we set device to be used prior to stream creation
nativeOps.setDevice(getDeviceIdPointer());
CudaContext context = new CudaContext();
context.initHandle();
context.initOldStream();
context.initStream();
context.associateHandle();
//contextPool.put(threadId, context);
}
/**
 * Tells whether this MemoryHandler depends on a device (CUDA).
 *
 * The current implementation always answers {@code true}.
 *
 * @return always {@code true}
 */
@Override
public boolean isDeviceDependant() {
    return true;
}
/**
 * Synchronizes the given allocation point to the host side.
 *
 * The work is delegated entirely to the flow controller. {@code threadId} and {@code deviceId}
 * are not read here.
 *
 * @param threadId not read by this implementation
 * @param deviceId not read by this implementation
 * @param point allocation point to synchronize
 */
@Override
public void synchronizeThreadDevice(Long threadId, Integer deviceId, AllocationPoint point) {
    flowController.synchronizeToHost(point);
}
/**
 * Forwards an executed action (result plus operands in the given context) to the flow controller.
 */
@Override
public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
    flowController.registerAction(context, result, operands);
}
/**
 * @return the flow controller used by this handler
 */
@Override
public FlowController getFlowController() {
    return this.flowController;
}
/**
 * @return the context pool this handler acquires CUDA contexts from
 */
@Override
public ContextPool getContextPool() {
    return this.contextPool;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy