Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance. Project price only 1 $
You can buy this project and download/modify it how often you want.
package org.nd4j.autodiff.samediff.internal.memory;
import lombok.*;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.*;
/**
* ArrayCacheMemoryMgr reuses arrays to reduce the number of memory allocations and deallocations.
* Memory allocations and deallocations can be quite expensive, especially on GPUs.
* Note that when arrays are reused, they are reused for the same datatype only.
* If caching a released array would result in the the maximum cache size being is exceeded, the oldest arrays will
* be deallocated first, until the new array can in the cache.
*
* By default, the following parameters are used for the cache:
*
*
Maximum cache size: 0.25 x max memory, where:
*
*
CPU: max memory is determined using {@link Pointer#maxBytes()}
*
GPU: max memory is determined using GPU 0 total memory
*
*
Larger array max multiple: 2.0
*
*
This means: if an exact array size can't be provided from the cache, use the next smallest array with a buffer up to 2.0x larger than requested
*
If no cached arrays of size < 2x requested exists, allocate a new array
*
*
Small array threshold: 1024 elements
*
*
This means: the "larger array max multiple" doesn't apply below this level. For example, we might return a size 1 array backed by a size 1023 buffer
*
*
*
* @author Alex Black
*/
@Getter
public class ArrayCacheMemoryMgr extends AbstractMemoryMgr {
private final double maxMemFrac;
private final long smallArrayThreshold;
private final double largerArrayMaxMultiple;
private final long maxCacheBytes;
private final long totalMemBytes;
private long currentCacheSize = 0;
private Map arrayStores = new HashMap<>();
private LinkedHashSet lruCache = new LinkedHashSet<>();
private Map lruCacheValues = new HashMap<>();
/**
* Create an ArrayCacheMemoryMgr with default settings as per {@link ArrayCacheMemoryMgr}
*/
public ArrayCacheMemoryMgr() {
this(0.25, 1024, 2.0);
}
/**
* @param maxMemFrac Maximum memory fraciton to use as cache
* @param smallArrayThreshold Below this size (elements), don't apply the "largerArrayMaxMultiple" rule
* @param largerArrayMaxMultiple Maximum multiple of the requested size to return from the cache. If an array of size
* 1024 is requested, and largerArrayMaxMultiple is 2.0, then we'll return from the cache
* the array with the smallest data buffer up to 2.0*1024 elements; otherwise we'll return
* a new array
*/
public ArrayCacheMemoryMgr(double maxMemFrac, long smallArrayThreshold, double largerArrayMaxMultiple) {
Preconditions.checkArgument(maxMemFrac > 0 && maxMemFrac < 1, "Maximum memory fraction for cache must be between 0.0 and 1.0, got %s", maxMemFrac);
Preconditions.checkArgument(smallArrayThreshold >= 0, "Small array threshould must be >= 0, got %s", smallArrayThreshold);
Preconditions.checkArgument(largerArrayMaxMultiple >= 1.0, "Larger array max multiple must be >= 1.0, got %s", largerArrayMaxMultiple);
this.maxMemFrac = maxMemFrac;
this.smallArrayThreshold = smallArrayThreshold;
this.largerArrayMaxMultiple = largerArrayMaxMultiple;
if(isCpu()){
totalMemBytes = Pointer.maxBytes();
} else {
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
List devList = (List) p.get("cuda.devicesInformation");
Map m = (Map) devList.get(0);
totalMemBytes = (Long)m.get("cuda.totalMemory");
}
maxCacheBytes = (long)(maxMemFrac * totalMemBytes);
}
private boolean isCpu(){
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
return !"CUDA".equalsIgnoreCase(backend);
}
@Override
public INDArray allocate(boolean detached, DataType dataType, long... shape) {
if (arrayStores.containsKey(dataType)) {
INDArray arr = arrayStores.get(dataType).get(shape);
if (arr != null) {
//Decrement cache size
currentCacheSize -= dataType.width() * arr.data().length();
return arr; //Allocated from cache
}
}
//Allocation failed, allocate new array
return Nd4j.createUninitializedDetached(dataType, shape);
}
@Override
public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
return allocate(detached, descriptor.dataType(), descriptor.getShape());
}
@Override
public void release(@NonNull INDArray array) {
//Check for multiple releases of the array
long id = array.getId();
Preconditions.checkState(!lruCache.contains(id), "Array was released multiple times: id=%s, shape=%ndShape", id, array);
DataType dt = array.dataType();
long thisBytes = array.data().length() * dt.width();
if(array.dataType() == DataType.UTF8) {
//Don't cache string arrays due to variable length buffers
if(array.closeable())
array.close();
} else if (currentCacheSize + thisBytes > maxCacheBytes) {
if(thisBytes > maxCacheBytes){
//Can't store even if we clear everything - too large
if(array.closeable())
array.close();
return;
}
//Need to deallocate some arrays to stay under limit - do in "oldest first" order
Iterator iter = lruCache.iterator();
while(currentCacheSize + thisBytes > maxCacheBytes){
long next = iter.next();
iter.remove();
INDArray nextOldest = lruCacheValues.remove(next);
DataType ndt = nextOldest.dataType();
long nextBytes = ndt.width() * nextOldest.data().length();
arrayStores.get(ndt).removeObject(nextOldest);
currentCacheSize -= nextBytes;
if(nextOldest.closeable())
nextOldest.close();
}
//After clearing space - can now cache
cacheArray(array);
} else {
//OK to cache
cacheArray(array);
}
//Store in LRU cache for "last used" removal if we exceed cache size
lruCache.add(array.getId());
lruCacheValues.put(array.getId(), array);
}
private void cacheArray(INDArray array){
DataType dt = array.dataType();
if (!arrayStores.containsKey(dt))
arrayStores.put(dt, new ArrayStore());
arrayStores.get(dt).add(array);
currentCacheSize += array.data().length() * dt.width();
lruCache.add(array.getId());
lruCacheValues.put(array.getId(), array);
}
@Override
public void close() {
for (ArrayStore as : arrayStores.values()) {
as.close();
}
}
@Getter
public class ArrayStore {
private INDArray[] sorted = new INDArray[1000]; //TODO resizing, don't hardcode
private long[] lengths = new long[1000];
private long lengthSum;
private long bytesSum;
private int size;
private void add(@NonNull INDArray array) {
//Resize arrays
if(size == sorted.length){
sorted = Arrays.copyOf(sorted, 2*sorted.length);
lengths = Arrays.copyOf(lengths, 2*lengths.length);
}
long length = array.data().length();
int idx = Arrays.binarySearch(lengths, 0, size, length);
if (idx < 0) {
idx = -idx - 1; //See binarySearch javadoc
}
for (int i = size - 1; i >= idx; i--) {
sorted[i + 1] = sorted[i];
lengths[i + 1] = lengths[i];
}
sorted[idx] = array;
lengths[idx] = length;
size++;
lengthSum += length;
bytesSum += length * array.dataType().width();
}
private INDArray get(long[] shape) {
if (size == 0)
return null;
long length = shape.length == 0 ? 1 : ArrayUtil.prod(shape);
int idx = Arrays.binarySearch(lengths, 0, size, length);
if (idx < 0) {
idx = -idx - 1;
if (idx >= size) {
//Largest array is smaller than required -> can't return from cache
return null;
}
INDArray nextSmallest = sorted[idx];
long nextSmallestLength = nextSmallest.data().length();
long nextSmallestLengthBytes = nextSmallestLength * nextSmallest.dataType().width();
boolean tooLarge = (length > (long) (nextSmallestLength * largerArrayMaxMultiple));
if (nextSmallestLengthBytes > smallArrayThreshold && tooLarge) {
return null;
} // If less than smallArrayThreshold, ok, return as is
}
//Remove
INDArray arr = removeIdx(idx);
lruCache.remove(arr.getId());
lruCacheValues.remove(arr.getId());
//Create a new array with the specified buffer. This is for 2 reasons:
//(a) the cached array and requested array sizes may differ (though this is easy to check for)
//(b) Some SameDiff array use tracking uses *object identity* - so we want different objects when reusing arrays
// to avoid issues there
return Nd4j.create(arr.data(), shape);
}
private void removeObject(INDArray array){
long length = array.data().length();
int idx = Arrays.binarySearch(lengths, 0, size, length);
Preconditions.checkState(idx > 0, "Cannot remove array from ArrayStore: no array with this length exists in the cache");
boolean found = false;
int i = 0;
while(!found && i <= size && lengths[i] == length){
found = sorted[i++] == array; //Object equality
}
Preconditions.checkState(found, "Cannot remove array: not found in ArrayCache");
removeIdx(i - 1);
}
private INDArray removeIdx(int idx){
INDArray arr = sorted[idx];
for (int i = idx; i < size; i++) {
sorted[i] = sorted[i + 1];
lengths[i] = lengths[i + 1];
}
sorted[size] = null;
lengths[size] = 0;
size--;
bytesSum -= (arr.data().length() * arr.dataType().width());
lengthSum -= arr.data().length();
return arr;
}
private void close() {
for (int i = 0; i < size; i++) {
if (sorted[i].closeable())
sorted[i].close();
lengths[i] = 0;
}
lengthSum = 0;
bytesSum = 0;
size = 0;
}
}
}