All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.samediff.internal.DependencyMapIdentity Maven / Gradle / Ivy

The newest version!
package org.nd4j.autodiff.samediff.internal; 
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.function.Predicate;

public class DependencyMapIdentity implements IDependencyMap {
    //IDependeeGroup will act as dummy interface and will be ignored

    private IdentityHashMap> map = new IdentityHashMap>();  
    @Override
    public void clear() {
        map.clear();
    }

    @Override
    public void add(K dependeeGroup, V element) {
      HashSet s = map.get(dependeeGroup);
      if(s==null){
        s= new HashSet ();
        map.put(dependeeGroup, s);
      }
       s.add(element);
    }

    @Override
    public Iterable getDependantsForEach(K dependeeGroup) {
        return map.get(dependeeGroup);
    }

    @Override
    public Iterable getDependantsForGroup(K dependeeGroup) {
        return map.get(dependeeGroup);
    }

    @Override
    public boolean containsAny(K dependeeGroup) {
        return map.containsKey(dependeeGroup);
    }

    @Override
    public boolean containsAnyForGroup(K dependeeGroup) {
        return map.containsKey(dependeeGroup);
    }

    @Override
    public boolean isEmpty() {
        return map.isEmpty();
    }

    @Override
    public void removeGroup(K dependeeGroup) {
        map.remove(dependeeGroup);
    }

    @Override
    public Iterable removeGroupReturn(K dependeeGroup) {
        return map.remove(dependeeGroup);
    }

    @Override
    public void removeForEach(K dependeeGroup) {
          map.remove(dependeeGroup);
    }

    @Override
    public Iterable removeForEachResult(K dependeeGroup) {
        return map.remove(dependeeGroup);
    }

    @Override
    public Iterable removeGroupReturn(K dependeeGroup, Predicate predicate) {
        HashSet s= new HashSet ();
        HashSet ret = map.get(dependeeGroup);
        if(ret!=null){
            long prevSize = ret.size();
            for (V v : ret) {
                if(predicate.test(v)) s.add(v);
            }
            for (V v : s) {
                ret.remove(s);
            }
            //remove the key as well
            if(prevSize == s.size()){
                //remove the key
                //as we are testing containsAny using key
                map.remove(dependeeGroup);
            }
        }
        return s;
    }
    
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy