// org.deeplearning4j.models.glove.count.CountMap (Maven / Gradle / Ivy)
package org.deeplearning4j.models.glove.count;
import com.google.common.util.concurrent.AtomicDouble;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.primitives.Pair;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
/**
* Drop-in replacement for CounterMap
*
* WORK IN PROGRESS, PLEASE DO NOT USE
*
* @author [email protected]
*/
public class CountMap {
private volatile Map, AtomicDouble> backingMap = new ConcurrentHashMap<>();
public CountMap() {
// placeholder
}
public void incrementCount(T element1, T element2, double weight) {
Pair tempEntry = new Pair<>(element1, element2);
if (backingMap.containsKey(tempEntry)) {
backingMap.get(tempEntry).addAndGet(weight);
} else {
backingMap.put(tempEntry, new AtomicDouble(weight));
}
}
public void removePair(T element1, T element2) {
Pair tempEntry = new Pair<>(element1, element2);
backingMap.remove(tempEntry);
}
public void removePair(Pair pair) {
backingMap.remove(pair);
}
public double getCount(T element1, T element2) {
Pair tempEntry = new Pair<>(element1, element2);
if (backingMap.containsKey(tempEntry)) {
return backingMap.get(tempEntry).get();
} else
return 0;
}
public double getCount(Pair pair) {
if (backingMap.containsKey(pair)) {
return backingMap.get(pair).get();
} else
return 0;
}
public Iterator> getPairIterator() {
return new Iterator>() {
private Iterator> iterator = backingMap.keySet().iterator();
@Override
public boolean hasNext() {
return iterator.hasNext();
}
@Override
public Pair next() {
//MapEntry entry = iterator.next();
return iterator.next(); //new Pair<>(entry.getElement1(), entry.getElement2());
}
@Override
public void remove() {
throw new UnsupportedOperationException("remove() isn't supported here");
}
};
}
public int size() {
return backingMap.size();
}
}