// org.nd4j.linalg.workspace.BaseWorkspaceMgr (Maven / Gradle / Ivy)
package org.nd4j.linalg.workspace;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* A standard/baseline {@link WorkspaceMgr} implementation
*
* @param Array type
* @author Alex Black
*/
@Slf4j
public abstract class BaseWorkspaceMgr> implements WorkspaceMgr {
private static final boolean DISABLE_LEVERAGE = false; //Mainly for debugging/optimization purposes
protected final Set scopeOutOfWs;
protected final Map configMap;
protected final Map workspaceNames;
/**
 * Creates a manager backed by the supplied collections.
 * <p>
 * The arguments are stored by reference, not copied, so changes made through this manager
 * (for example by {@link #setConfiguration} or {@link #setScopedOutFor}) are visible in the
 * caller's collections, and vice versa.
 *
 * @param scopeOutOfWs   array types that are scoped out of workspaces
 * @param configMap      workspace configuration for each array type
 * @param workspaceNames workspace name for each array type
 */
protected BaseWorkspaceMgr(Set<T> scopeOutOfWs, Map<T, WorkspaceConfiguration> configMap,
        Map<T, String> workspaceNames) {
    this.scopeOutOfWs = scopeOutOfWs;
    this.configMap = configMap;
    this.workspaceNames = workspaceNames;
}
/**
 * Creates a manager with no scoped-out array types, no configurations and no workspace
 * names: each backing collection is a freshly created, empty instance.
 */
protected BaseWorkspaceMgr() {
    this.scopeOutOfWs = new HashSet<>();
    this.configMap = new HashMap<>();
    this.workspaceNames = new HashMap<>();
}
/**
 * Stores the workspace configuration for the given array type in {@code configMap}.
 *
 * @param arrayType     array type to configure; must not be null
 * @param configuration workspace configuration to associate with {@code arrayType}
 */
@Override
public void setConfiguration(@NonNull T arrayType, WorkspaceConfiguration configuration) {
    this.configMap.put(arrayType, configuration);
}
/**
 * Looks up the workspace configuration stored for the given array type.
 *
 * @param arrayType array type to look up; must not be null
 * @return the value held in {@code configMap} for {@code arrayType}
 */
@Override
public WorkspaceConfiguration getConfiguration(@NonNull T arrayType) {
    return this.configMap.get(arrayType);
}
/**
 * Marks the given array type as scoped out of workspaces.
 * <p>
 * Besides recording the array type in the scoped-out set, this discards any workspace
 * configuration and workspace name previously stored for it.
 *
 * @param arrayType array type to scope out; must not be null
 */
@Override
public void setScopedOutFor(@NonNull T arrayType) {
    this.scopeOutOfWs.add(arrayType);

    // A scoped-out array type keeps no workspace settings.
    this.configMap.remove(arrayType);
    this.workspaceNames.remove(arrayType);
}
/**
 * Reports whether the given array type has been marked as scoped out of workspaces.
 *
 * @param arrayType array type to check; must not be null
 * @return {@code true} if {@code arrayType} is in the scoped-out set
 */
@Override
public boolean isScopedOut(@NonNull T arrayType) {
    return this.scopeOutOfWs.contains(arrayType);
}
/**
 * Reports whether the given array type has been set up in this manager, either by being
 * scoped out or by having a workspace name registered.
 *
 * @param arrayType array type to check; must not be null
 * @return {@code true} if {@code arrayType} is scoped out or has an entry in
 *         {@code workspaceNames}
 */
@Override
public boolean hasConfiguration(@NonNull T arrayType) {
    if (scopeOutOfWs.contains(arrayType)) {
        return true;
    }
    return workspaceNames.containsKey(arrayType);
}
/**
 * Enters the workspace scope associated with the given array type.
 * <p>
 * The array type is first passed to {@code validateConfig}. If it is scoped out, the
 * result of {@code scopeOutOfWorkspaces()} on the ND4J workspace manager is returned.
 * Otherwise the workspace for the current thread is obtained using this array type's
 * configuration and workspace name, and the result of entering its scope is returned.
 *
 * @param arrayType array type whose workspace scope should be entered; must not be null
 * @return the workspace returned by the scope-out or scope-entered call
 */
@Override
public MemoryWorkspace notifyScopeEntered(@NonNull T arrayType) {
    validateConfig(arrayType);

    if (isScopedOut(arrayType)) {
        return Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
    }

    MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
            .getWorkspaceForCurrentThread(getConfiguration(arrayType), getWorkspaceName(arrayType));
    return workspace.notifyScopeEntered();
}
@Override
public WorkspacesCloseable notifyScopeEntered(@NonNull T... arrayTypes) {
MemoryWorkspace[] ws = new MemoryWorkspace[arrayTypes.length];
for(int i=0; i