package org.nd4j.jita.memory.impl;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
 * This MemoryProvider implementation caches both host and device memory within predefined limits:
 * host-side caching is inherited from CudaCachingZeroProvider, and this class adds a per-device,
 * per-shape cache for DEVICE allocations.
 *
 * @author [email protected]
 */
public class CudaFullCachingProvider extends CudaCachingZeroProvider {
    protected final long MAX_GPU_ALLOCATION = configuration.getMaximumSingleDeviceAllocation();
    protected final long MAX_GPU_CACHE = configuration.getMaximumDeviceCache();

    protected volatile ConcurrentHashMap<Integer, ConcurrentHashMap<AllocationShape, CacheHolder>> deviceCache = new ConcurrentHashMap<>();

    private static final Logger log = LoggerFactory.getLogger(CudaFullCachingProvider.class);
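
    /**
     * Allocates memory for the given shape. DEVICE requests below the
     * single-allocation limit are served from the per-device, per-shape cache
     * when a previously released chunk is available; otherwise the request
     * falls through to the parent provider's allocator.
     */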
    @Override
    public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
        long reqMemory = AllocationUtils.getRequiredMemory(shape);
        if (location == AllocationStatus.DEVICE && reqMemory < MAX_GPU_ALLOCATION) {
            ensureDeviceCacheHolder(point.getDeviceId(), shape);

            CacheHolder cache = deviceCache.get(point.getDeviceId()).get(shape);
            if (cache != null) {
                Pointer pointer = cache.poll();
                if (pointer != null) {
                    cacheDeviceHit.incrementAndGet();
                    // the chunk leaves the cache, so shrink the cached-amount accounting
                    deviceCachedAmount.addAndGet(-1 * reqMemory);

                    // log.info("Serving from cache {} bytes", reqMemory);

                    PointersPair pair = new PointersPair();
                    pair.setDevicePointer(pointer);

                    point.setAllocationStatus(AllocationStatus.DEVICE);
                    return pair;
                }
            }
            cacheDeviceMiss.incrementAndGet();
        }

        // cache miss, or a host-side / oversized request: delegate to the parent provider
        return super.malloc(shape, point, location);
    }
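
    /**
     * Releases the given allocation. DEVICE chunks below the single-allocation
     * limit are returned to the per-device, per-shape cache (as long as the
     * overall device cache limit has not been reached) instead of being freed;
     * everything else is delegated to the parent provider for real deallocation.
     */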
    @Override
    public void free(AllocationPoint point) {
        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
            AllocationShape shape = point.getShape();
            long reqMemory = AllocationUtils.getRequiredMemory(shape);

            // don't cache oversized chunks, and stop caching once the device cache limit is reached
            if (reqMemory > MAX_GPU_ALLOCATION || deviceCachedAmount.get() >= MAX_GPU_CACHE) {
                super.free(point);
                return;
            }

            ensureDeviceCacheHolder(point.getDeviceId(), shape);

            CacheHolder cache = deviceCache.get(point.getDeviceId()).get(shape);

            // chunks below the threshold are cached no matter what
            if (reqMemory <= FORCED_CACHE_THRESHOLD) {
                cache.put(new CudaPointer(point.getDevicePointer().address()));
            } else {
                // NOTE: a per-bucket depth limit (cache entries * reqMemory vs. cache size / bucket
                // count) was sketched here but is disabled, so larger chunks are also cached unconditionally
                cache.put(new CudaPointer(point.getDevicePointer().address()));
            }
            return;
        }
        super.free(point);
    }
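
    /**
     * Lazily creates the per-device cache map and the per-shape CacheHolder
     * bucket, using check-lock-recheck on the shared semaphore so concurrent
     * callers don't race the initialization.
     */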
    protected void ensureDeviceCacheHolder(Integer deviceId, AllocationShape shape) {
        if (!deviceCache.containsKey(deviceId)) {
            try {
                singleLock.acquire();
                if (!deviceCache.containsKey(deviceId)) {
                    deviceCache.put(deviceId, new ConcurrentHashMap<AllocationShape, CacheHolder>());
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                singleLock.release();
            }
        }

        if (!deviceCache.get(deviceId).containsKey(shape)) {
            try {
                singleLock.acquire();
                if (!deviceCache.get(deviceId).containsKey(shape)) {
                    deviceCache.get(deviceId).put(shape, new CacheHolder(shape, deviceCachedAmount));
                }
            } catch (Exception e) {
                // mirror the first block instead of silently swallowing the failure
                throw new RuntimeException(e);
            } finally {
                singleLock.release();
            }
        }
    }
}
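
/*
 * Hedged usage sketch, not part of the original source: it assumes the surrounding JITA
 * allocator has prepared an AllocationShape and AllocationPoint, that the class has a
 * usable no-arg constructor, and that PointersPair exposes a getDevicePointer() accessor
 * matching the setter used above.
 *
 *   CudaFullCachingProvider provider = new CudaFullCachingProvider();
 *   PointersPair pair = provider.malloc(shape, point, AllocationStatus.DEVICE);
 *   // ... use pair.getDevicePointer() for device-side work ...
 *   provider.free(point); // small DEVICE chunks return to the cache for reuse
 */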