All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.samediff.internal.memory.CloseValidationMemoryMgr Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.autodiff.samediff.internal.memory;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.primitives.Pair;

import java.util.*;

/**
 * A {@link SessionMemMgr} that wraps an existing memory manager, to ensure that:
* - All arrays that are supposed to be closed, have been closed
* - Arrays are only passed to the close method exactly one (unless they are requested outputs)
* - Arrays that are passed to the close method were originally allocated by the session memory manager
*
* How to use:
* 1. Perform an inference or training iteration, as normal
* 2. Call {@link #assertAllReleasedExcept(Collection)} with the output arrays
*

* NOTE: This is intended for debugging and testing only * * @author Alex Black */ @Slf4j public class CloseValidationMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr { private final SameDiff sd; private final SessionMemMgr underlying; private final Map released = new IdentityHashMap<>(); public CloseValidationMemoryMgr(SameDiff sd, SessionMemMgr underlying) { this.sd = sd; this.underlying = underlying; } @Override public INDArray allocate(boolean detached, DataType dataType, long... shape) { INDArray out = underlying.allocate(detached, dataType, shape); released.put(out, false); return out; } @Override public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) { INDArray out = underlying.allocate(detached, descriptor); released.put(out, false); return out; } @Override public void release(INDArray array) { Preconditions.checkState(released.containsKey(array), "Attempting to release an array that was not allocated by" + " this memory manager: id=%s", array.getId()); if (released.get(array)) { //Already released InferenceSession is = sd.getSessions().get(Thread.currentThread().getId()); IdentityDependencyTracker arrayUseTracker = is.getArrayUseTracker(); DependencyList dl = arrayUseTracker.getDependencies(array); System.out.println(dl); if (dl.getDependencies() != null) { for (InferenceSession.Dep d : dl.getDependencies()) { System.out.println(d + ": " + arrayUseTracker.isSatisfied(d)); } } if (dl.getOrDependencies() != null) { for (Pair p : dl.getOrDependencies()) { System.out.println(p + " - (" + arrayUseTracker.isSatisfied(p.getFirst()) + "," + arrayUseTracker.isSatisfied(p.getSecond())); } } } Preconditions.checkState(!released.get(array), "Attempting to release an array that was already deallocated by" + " an earlier release call to this memory manager: id=%s", array.getId()); log.trace("Released array: id = {}", array.getId()); released.put(array, true); } @Override public void close() { underlying.close(); } /** * Check that all arrays have been released (after an inference call) except for the specified arrays. * * @param except Arrays that should not have been closed (usually network outputs) */ public void assertAllReleasedExcept(@NonNull Collection except) { Set allVarPhConst = null; for (INDArray arr : except) { if (!released.containsKey(arr)) { //Check if constant, variable or placeholder - maybe user requested that out if (allVarPhConst == null) allVarPhConst = identitySetAllConstPhVar(); if (allVarPhConst.contains(arr)) continue; //OK - output is a constant, variable or placeholder, hence it's fine it's not allocated by the memory manager throw new IllegalStateException("Array " + arr.getId() + " was not originally allocated by the memory manager"); } boolean released = this.released.get(arr); if (released) { throw new IllegalStateException("Specified output array (id=" + arr.getId() + ") should not have been deallocated but was"); } } Set exceptSet = Collections.newSetFromMap(new IdentityHashMap()); exceptSet.addAll(except); int numNotClosed = 0; Set notReleased = Collections.newSetFromMap(new IdentityHashMap()); InferenceSession is = sd.getSessions().get(Thread.currentThread().getId()); IdentityDependencyTracker arrayUseTracker = is.getArrayUseTracker(); for (Map.Entry e : released.entrySet()) { INDArray a = e.getKey(); if (!exceptSet.contains(a)) { boolean b = e.getValue(); if (!b) { notReleased.add(a); numNotClosed++; log.info("Not released: array id {}", a.getId()); DependencyList list = arrayUseTracker.getDependencies(a); List l = list.getDependencies(); List> l2 = list.getOrDependencies(); if (l != null) { for (InferenceSession.Dep d : l) { if (!arrayUseTracker.isSatisfied(d)) { log.info(" Not satisfied: {}", d); } } } if (l2 != null) { for (Pair d : l2) { if (!arrayUseTracker.isSatisfied(d.getFirst()) && !arrayUseTracker.isSatisfied(d.getSecond())) { log.info(" Not satisfied: {}", d); } } } } } } if (numNotClosed > 0) { System.out.println(sd.summary()); throw new IllegalStateException(numNotClosed + " arrays were not released but should have been"); } } protected Set identitySetAllConstPhVar() { Set set = Collections.newSetFromMap(new IdentityHashMap()); for (SDVariable v : sd.variables()) { if (v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT || v.getVariableType() == VariableType.PLACEHOLDER) { set.add(v.getArr()); } } return set; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy