org.nd4j.autodiff.samediff.internal.memory.CloseValidationMemoryMgr Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.autodiff.samediff.internal.memory;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.common.primitives.Pair;
import java.util.*;
/**
 * A validating {@link SessionMemMgr} wrapper, intended for testing/debugging array lifecycle in SameDiff.
 * <p>
 * Allocations are delegated to an underlying memory manager, and every array handed out is recorded (by object
 * identity) together with a flag saying whether it has been released. This allows detecting:
 * <ul>
 *     <li>releasing an array that was not allocated by this memory manager</li>
 *     <li>releasing the same array more than once</li>
 *     <li>arrays that were never released, via {@link #assertAllReleasedExcept(Collection)}</li>
 * </ul>
 * <p>
 * {@link #release(INDArray)} only marks the array as released; it does not pass it on to the underlying manager.
 * NOTE(review): presumably deliberate, so released arrays are never recycled/re-allocated (which would defeat the
 * identity-based tracking) - confirm before changing.
 * <p>
 * The tracking map is an unsynchronized {@link IdentityHashMap} that holds strong references to every array
 * allocated through this instance.
 */
@Slf4j
public class CloseValidationMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {

    private final SameDiff sd;
    private final SessionMemMgr underlying;
    /** Every array allocated through this manager (compared by identity), mapped to "has been released". */
    private final Map<INDArray, Boolean> released = new IdentityHashMap<>();

    /**
     * @param sd         SameDiff instance; used for per-thread session diagnostics and to look up
     *                   variables/constants/placeholders in {@link #assertAllReleasedExcept(Collection)}
     * @param underlying Memory manager that performs the actual allocations, and that {@link #close()} delegates to
     */
    public CloseValidationMemoryMgr(SameDiff sd, @NonNull SessionMemMgr underlying) {
        this.sd = sd;
        this.underlying = underlying;
    }

    /**
     * Allocates via the underlying manager and records the new array as not-yet-released.
     */
    @Override
    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
        INDArray out = underlying.allocate(detached, dataType, shape);
        released.put(out, false);
        return out;
    }

    /**
     * Allocates via the underlying manager and records the new array as not-yet-released.
     */
    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        INDArray out = underlying.allocate(detached, descriptor);
        released.put(out, false);
        return out;
    }

    /**
     * Marks the given array as released. The array is NOT forwarded to the underlying memory manager.
     *
     * @param array Array to release; must have been allocated by this manager and not released before
     * @throws IllegalStateException if the array was not allocated by this manager, or was already released
     */
    @Override
    public void release(INDArray array) {
        //Values in the map are never null, so null here means "not allocated by this manager"
        Boolean alreadyReleased = released.get(array);
        Preconditions.checkState(alreadyReleased != null, "Attempting to release an array that was not allocated by" +
                " this memory manager: id=%s", array.getId());
        if (alreadyReleased) {
            //Double release: dump the array's dependency state before failing, to help find the early release
            log.info("Array id {} was already released; dependency state:", array.getId());
            IdentityDependencyTracker<INDArray, InferenceSession.Dep> tracker = currentArrayUseTracker();
            if (tracker != null) {
                logAllDependencies(tracker, array);
            } else {
                log.info(" No inference session found for the current thread");
            }
        }
        Preconditions.checkState(!alreadyReleased, "Attempting to release an array that was already deallocated by" +
                " an earlier release call to this memory manager: id=%s", array.getId());
        log.trace("Released array: id = {}", array.getId());
        released.put(array, true);
    }

    /**
     * Closes the underlying memory manager. The release-tracking state of this instance is left unchanged.
     */
    @Override
    public void close() {
        underlying.close();
    }

    /**
     * Check that all arrays have been released (after an inference call) except for the specified arrays.
     *
     * @param except Arrays that should not have been closed (usually network outputs). Arrays not allocated by this
     *               manager are allowed only if they are the arrays of constants, variables or placeholders.
     * @throws IllegalStateException if an array in {@code except} was released or was not allocated by this manager
     *                               (and is not a constant/variable/placeholder), or if any other allocated array
     *                               has not been released
     */
    public void assertAllReleasedExcept(@NonNull Collection<INDArray> except) {
        Set<INDArray> allVarPhConst = null;     //Built lazily: only needed for outputs not allocated here
        for (INDArray arr : except) {
            Boolean wasReleased = released.get(arr);
            if (wasReleased == null) {
                //Check if constant, variable or placeholder - maybe user requested that out
                if (allVarPhConst == null)
                    allVarPhConst = identitySetAllConstPhVar();
                if (allVarPhConst.contains(arr))
                    continue; //OK - output is a constant, variable or placeholder, hence it's fine it's not allocated by the memory manager
                throw new IllegalStateException("Array " + arr.getId() + " was not originally allocated by the memory manager");
            }
            if (wasReleased) {
                throw new IllegalStateException("Specified output array (id=" + arr.getId() + ") should not have been deallocated but was");
            }
        }

        Set<INDArray> exceptSet = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
        exceptSet.addAll(except);

        //May be null if this thread has no session; diagnostics are then skipped, but validation still happens
        IdentityDependencyTracker<INDArray, InferenceSession.Dep> tracker = currentArrayUseTracker();
        List<INDArray> notReleased = new ArrayList<>();
        for (Map.Entry<INDArray, Boolean> e : released.entrySet()) {
            INDArray a = e.getKey();
            if (exceptSet.contains(a) || e.getValue())
                continue;
            notReleased.add(a);
            log.info("Not released: array id {}", a.getId());
            if (tracker != null) {
                logUnsatisfiedDependencies(tracker, a);
            }
        }

        if (!notReleased.isEmpty()) {
            log.info("SameDiff summary:\n{}", sd.summary());
            throw new IllegalStateException(notReleased.size() + " arrays were not released but should have been");
        }
    }

    /**
     * @return An identity-based set of the current arrays of all VARIABLE, CONSTANT and PLACEHOLDER variables
     */
    protected Set<INDArray> identitySetAllConstPhVar() {
        Set<INDArray> set = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
        for (SDVariable v : sd.variables()) {
            if (v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT || v.getVariableType() == VariableType.PLACEHOLDER) {
                set.add(v.getArr());
            }
        }
        return set;
    }

    /**
     * @return The array use tracker of the inference session registered for the current thread, or null if there is
     *         no such session
     */
    private IdentityDependencyTracker<INDArray, InferenceSession.Dep> currentArrayUseTracker() {
        InferenceSession is = sd.getSessions().get(Thread.currentThread().getId());
        return is == null ? null : is.getArrayUseTracker();
    }

    /** Logs the dependency list of {@code array}, with the satisfied state of each dependency and OR-dependency. */
    private static void logAllDependencies(IdentityDependencyTracker<INDArray, InferenceSession.Dep> tracker, INDArray array) {
        DependencyList<INDArray, InferenceSession.Dep> dl = tracker.getDependencies(array);
        if (dl == null) {
            log.info(" No dependency information for array id {}", array.getId());
            return;
        }
        log.info(" {}", dl);
        List<InferenceSession.Dep> deps = dl.getDependencies();
        if (deps != null) {
            for (InferenceSession.Dep d : deps) {
                log.info(" {}: {}", d, tracker.isSatisfied(d));
            }
        }
        List<Pair<InferenceSession.Dep, InferenceSession.Dep>> orDeps = dl.getOrDependencies();
        if (orDeps != null) {
            for (Pair<InferenceSession.Dep, InferenceSession.Dep> p : orDeps) {
                log.info(" {} - ({}, {})", p, tracker.isSatisfied(p.getFirst()), tracker.isSatisfied(p.getSecond()));
            }
        }
    }

    /** Logs the unsatisfied dependencies of {@code array}, including OR-pairs where neither side is satisfied. */
    private static void logUnsatisfiedDependencies(IdentityDependencyTracker<INDArray, InferenceSession.Dep> tracker, INDArray array) {
        DependencyList<INDArray, InferenceSession.Dep> dl = tracker.getDependencies(array);
        if (dl == null) {
            return;
        }
        List<InferenceSession.Dep> deps = dl.getDependencies();
        if (deps != null) {
            for (InferenceSession.Dep d : deps) {
                if (!tracker.isSatisfied(d)) {
                    log.info(" Not satisfied: {}", d);
                }
            }
        }
        List<Pair<InferenceSession.Dep, InferenceSession.Dep>> orDeps = dl.getOrDependencies();
        if (orDeps != null) {
            for (Pair<InferenceSession.Dep, InferenceSession.Dep> d : orDeps) {
                if (!tracker.isSatisfied(d.getFirst()) && !tracker.isSatisfied(d.getSecond())) {
                    log.info(" Not satisfied: {}", d);
                }
            }
        }
    }
}