// org.nd4j.linalg.memory.BasicMemoryManager (Maven / Gradle / Ivy)
package org.nd4j.linalg.memory;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
 * Base {@link MemoryManager} implementation holding the backend-independent bookkeeping:
 * occasional/periodic GC scheduling, optional averaging of the intervals recorded by
 * {@link #invokeGcOccasionally()}, and the per-thread "current workspace" reference.
 * <p>
 * Memory allocation, collection and cache purging are not provided here; the corresponding
 * methods throw {@link UnsupportedOperationException} unless a subclass overrides them.
 *
 * @author [email protected]
 */
public abstract class BasicMemoryManager implements MemoryManager {
    /** Every {@code frequency}-th call of {@link #invokeGcOccasionally()} may trigger a GC; values <= 0 disable it. */
    protected AtomicInteger frequency = new AtomicInteger(0);

    /** Number of {@link #invokeGcOccasionally()} calls counted while {@link #frequency} was positive. */
    protected AtomicLong freqCounter = new AtomicLong(0);

    /** Wall-clock time (ms) of the last GC issued by this manager; initialised to construction time. */
    protected AtomicLong lastGcTime = new AtomicLong(System.currentTimeMillis());

    /** Flag reported by {@link #isPeriodicGcActive()}. */
    protected AtomicBoolean periodicEnabled = new AtomicBoolean(true);

    /** Not read or written by this class. */
    protected AtomicInteger averageLoopTime = new AtomicInteger(0);

    /** Minimum time (ms) that must have passed since the last GC before an occasional GC is issued. */
    protected AtomicInteger noGcWindow = new AtomicInteger(100);

    /** Whether interval samples are being recorded and averaged. */
    protected AtomicBoolean averagingEnabled = new AtomicBoolean(false);

    /** Size threshold for {@link #intervals}: one old sample is dropped per call once it is exceeded. */
    protected static final int intervalTail = 100;

    /** Recorded samples: milliseconds between an {@link #invokeGcOccasionally()} call and the last GC. */
    protected Queue<Integer> intervals = new ConcurrentLinkedQueue<>();

    /** Workspace currently attached to the calling thread, or {@code null} if none was set. */
    private final ThreadLocal<MemoryWorkspace> workspace = new ThreadLocal<>();

    /** Not used by this class. */
    private final ThreadLocal<MemoryWorkspace> tempWorkspace = new ThreadLocal<>();

    /**
     * Allocates a memory chunk. This base implementation does not support allocation and
     * always throws.
     *
     * @param bytes      requested size in bytes
     * @param kind       kind of memory requested
     * @param initialize initialization flag for the allocated chunk
     * @return never returns normally in this implementation
     * @throws UnsupportedOperationException always
     */
    @Override
    public Pointer allocate(long bytes, MemoryKind kind, boolean initialize) {
        throw new UnsupportedOperationException("This method isn't available for this backend");
    }

    /**
     * This method detaches off-heap memory from passed INDArray instances, and optionally
     * stores them in cache for future reuse.
     * PLEASE NOTE: Cache options depend on specific implementations. This base implementation
     * always throws.
     *
     * @param arrays arrays whose memory should be collected
     * @throws UnsupportedOperationException always
     */
    @Override
    public void collect(INDArray... arrays) {
        throw new UnsupportedOperationException("This method isn't implemented yet");
    }

    /**
     * Enables or disables recording of interval samples used by {@link #getAverageLoopTime()}.
     *
     * @param enabled {@code true} to record samples on every {@link #invokeGcOccasionally()} call
     */
    @Override
    public void toggleAveraging(boolean enabled) {
        averagingEnabled.set(enabled);
    }

    /**
     * This method purges all cached memory chunks. This base implementation always throws.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void purgeCaches() {
        throw new UnsupportedOperationException("This method isn't implemented yet");
    }

    /**
     * Copies the full contents of {@code srcBuffer} into {@code dstBuffer} and reports the
     * transfer to the {@link PerformanceTracker} as a host-to-host copy.
     * <p>
     * The number of bytes copied is {@code srcBuffer.length() * srcBuffer.getElementSize()};
     * no check is made here that the destination is large enough.
     *
     * @param dstBuffer destination buffer
     * @param srcBuffer source buffer
     */
    @Override
    public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
        val tracker = PerformanceTracker.getInstance();
        // Computed once so the copy and the tracker report always use the same byte count.
        val numBytes = srcBuffer.length() * srcBuffer.getElementSize();

        val perfD = tracker.helperStartTransaction();
        Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(), numBytes);
        // NOTE(review): the leading 0 is presumably a device id - confirm against PerformanceTracker.
        tracker.helperRegisterTransaction(0, perfD, numBytes, MemcpyDirection.HOST_TO_HOST);
    }

    /** Scope-entry hook; intentionally a no-op in this implementation. */
    @Override
    public void notifyScopeEntered() {
        // TODO: to be implemented
    }

    /** Scope-exit hook; intentionally a no-op in this implementation. */
    @Override
    public void notifyScopeLeft() {
        // TODO: to be implemented
    }

    /**
     * Issues {@link System#gc()} if the occasional-GC conditions are met, and maintains the
     * interval samples when averaging is enabled.
     * <p>
     * A GC is issued when the configured frequency is positive, the call counter is a multiple
     * of that frequency, and more than {@link #getAutoGcWindow()} milliseconds have passed
     * since {@link #getLastGcTime()}.
     */
    @Override
    public void invokeGcOccasionally() {
        long currentTime = System.currentTimeMillis();

        if (averagingEnabled.get()) {
            intervals.add((int) (currentTime - lastGcTime.get()));
        }

        // Snapshot the frequency: re-reading it for the modulo could observe a concurrent
        // setOccasionalGcFrequency(0) and divide by zero.
        int currentFrequency = frequency.get();

        // not sure if we want to conform autoGcWindow here...
        if (currentFrequency > 0
                && freqCounter.incrementAndGet() % currentFrequency == 0
                && currentTime > getLastGcTime() + getAutoGcWindow()) {
            System.gc();
            lastGcTime.set(System.currentTimeMillis());
        }

        if (averagingEnabled.get() && intervals.size() > intervalTail) {
            // poll() rather than remove(): another thread may have emptied the queue after the
            // size check, and remove() would throw NoSuchElementException in that case.
            intervals.poll();
        }
    }

    /** Unconditionally issues {@link System#gc()} and records the time it was requested. */
    @Override
    public void invokeGc() {
        System.gc();
        lastGcTime.set(System.currentTimeMillis());
    }

    /**
     * @return the current value of the periodic-GC flag
     */
    @Override
    public boolean isPeriodicGcActive() {
        return periodicEnabled.get();
    }

    /**
     * Sets how often {@link #invokeGcOccasionally()} may issue a GC.
     *
     * @param frequency issue a GC on every {@code frequency}-th call; values <= 0 disable it
     */
    @Override
    public void setOccasionalGcFrequency(int frequency) {
        this.frequency.set(frequency);
    }

    /**
     * Sets the minimum time between GC calls issued by {@link #invokeGcOccasionally()}.
     *
     * @param windowMillis window length in milliseconds
     */
    @Override
    public void setAutoGcWindow(int windowMillis) {
        noGcWindow.set(windowMillis);
    }

    /**
     * @return the minimum time between occasional GC calls, in milliseconds
     */
    @Override
    public int getAutoGcWindow() {
        return noGcWindow.get();
    }

    /**
     * @return the configured occasional-GC frequency
     */
    @Override
    public int getOccasionalGcFrequency() {
        return frequency.get();
    }

    /**
     * @return wall-clock time (ms) of the last GC issued by this manager, or the construction
     *         time if none has been issued yet
     */
    @Override
    public long getLastGcTime() {
        return lastGcTime.get();
    }

    /**
     * Sets the periodic-GC flag reported by {@link #isPeriodicGcActive()}.
     *
     * @param enabled new flag value
     */
    @Override
    public void togglePeriodicGc(boolean enabled) {
        periodicEnabled.set(enabled);
    }

    /**
     * Returns the integer average of the currently recorded interval samples.
     *
     * @return the average in milliseconds, or 0 when averaging is disabled or no samples have
     *         been recorded yet
     */
    @Override
    public int getAverageLoopTime() {
        if (!averagingEnabled.get()) {
            return 0;
        }

        // Count while iterating instead of calling size() separately, so the divisor always
        // matches the summed elements even if the queue changes concurrently. The long
        // accumulator avoids int overflow when samples are large.
        long sum = 0;
        int count = 0;
        for (int value : intervals) {
            sum += value;
            count++;
        }

        // An empty queue previously caused a division by zero.
        return count == 0 ? 0 : (int) (sum / count);
    }

    /**
     * @return the workspace attached to the calling thread, or {@code null} if none is set
     */
    @Override
    public MemoryWorkspace getCurrentWorkspace() {
        return workspace.get();
    }

    /**
     * Attaches the given workspace to the calling thread.
     *
     * @param workspace workspace to make current; {@code null} clears the reference
     */
    @Override
    public void setCurrentWorkspace(MemoryWorkspace workspace) {
        this.workspace.set(workspace);
    }

    /**
     * Returns a {@link DummyWorkspace} to be used for work outside of any real workspace.
     * <p>
     * If the global memory manager reports a current workspace, the returned dummy has already
     * been notified of scope entry; otherwise a fresh dummy is returned as-is. The current
     * workspace reference itself is not modified by this method.
     *
     * @return a dummy workspace
     */
    @Override
    public MemoryWorkspace scopeOutOfWorkspaces() {
        MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (current == null) {
            return new DummyWorkspace();
        }

        // Resetting the current workspace (setCurrentWorkspace(null)) and
        // workspace.tagOutOfScopeUse() were considered here but are deliberately not applied.
        return new DummyWorkspace().notifyScopeEntered();
    }
}