// org.deeplearning4j.models.embeddings.reader.impl.TreeModelUtils Maven / Gradle / Ivy
package org.deeplearning4j.models.embeddings.reader.impl;
import lombok.NonNull;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.util.SetUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
/**
* This is a VPTree-based implementation of the wordsNearest method, suited for multiple consecutive calls.
* Please note: the VPTree will take some memory, dependent on your model size.
*
* @author [email protected]
*/
public class TreeModelUtils<T extends SequenceElement> extends BasicModelUtils<T> {
    /** Search tree over the vocabulary vectors; {@code null} until {@link #checkTree()} builds it. */
    protected VPTree vpTree;

    /**
     * Delegates to the parent initialisation and drops any previously built tree,
     * so that the next search rebuilds it from the new lookup table.
     *
     * @param lookupTable lookup table providing the word vectors
     */
    @Override
    public void init(@NonNull WeightLookupTable<T> lookupTable) {
        super.init(lookupTable);
        vpTree = null;
    }

    /**
     * Builds the VPTree if it does not exist yet: one {@link DataPoint} per vocabulary word,
     * carrying the word's vocabulary index and its vector.
     * The method is synchronized so that concurrent first calls do not each build a tree.
     */
    protected synchronized void checkTree() {
        if (vpTree != null) {
            return;
        }
        List<DataPoint> points = new ArrayList<>();
        for (String word : vocabCache.words()) {
            points.add(new DataPoint(vocabCache.indexOf(word), lookupTable.vector(word)));
        }
        vpTree = new VPTree(points);
    }

    /**
     * This method returns nearest words for target word, based on tree structure.
     * This method is recommended to use if you're going to call for nearest words multiple times.
     * VPTree will be built upon first call to this method.
     *
     * @param label label of element we're looking nearest words to
     * @param n number of nearest elements to return
     * @return the nearest words with {@code label} itself removed; empty if {@code label}
     *         is not a known token
     */
    @Override
    public Collection<String> wordsNearest(String label, int n) {
        if (!vocabCache.hasToken(label)) {
            return new ArrayList<>();
        }
        // Ask for n + 1 results: the label itself is expected to be among its own
        // nearest neighbours and is removed below.
        Collection<String> nearest =
                wordsNearest(Collections.singletonList(label), new ArrayList<String>(), n + 1);
        nearest.remove(label);
        return nearest;
    }

    /**
     * Returns the words nearest to the combination of the given positive and negative words.
     * Each positive word contributes its vector, each negative word its vector multiplied by -1;
     * the mean of those rows along dimension 0 is used as the query point.
     *
     * @param positive words whose vectors are added to the query
     * @param negative words whose vectors are negated before being added to the query
     * @param top number of nearest elements to request
     * @return the nearest words; empty if no words were given or any given word is not in the vocabulary
     */
    @Override
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        // Check every word is in the model
        for (String p : SetUtils.union(new HashSet<>(positive), new HashSet<>(negative))) {
            if (!vocabCache.containsWord(p)) {
                return new ArrayList<>();
            }
        }

        int rows = positive.size() + negative.size();
        if (rows == 0) {
            // No query words: there is nothing to average, hence no neighbours to report.
            return new ArrayList<>();
        }

        INDArray words = Nd4j.create(rows, lookupTable.layerSize());
        int row = 0;
        for (String s : positive) {
            words.putRow(row++, lookupTable.vector(s));
        }
        for (String s : negative) {
            words.putRow(row++, lookupTable.vector(s).mul(-1));
        }

        INDArray mean = words.isMatrix() ? words.mean(0) : words;
        return wordsNearest(mean, top);
    }

    /**
     * Returns the words whose vectors the VPTree reports as nearest to the given vector.
     * The tree is built on the first call.
     *
     * @param words query vector
     * @param top number of nearest elements to request from the tree
     * @return the words found by the tree search, in the order the tree returned them
     */
    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        checkTree();

        List<DataPoint> neighbours = new ArrayList<>();
        List<Double> distances = new ArrayList<>();
        // The query point is given a placeholder index of 0
        // NOTE(review): presumably the index of the target is ignored by the search - confirm.
        vpTree.search(new DataPoint(0, words), top, neighbours, distances);

        // Translate the vocabulary indices found by the tree back into words and return them
        // (previously this result was built and then discarded in favour of the parent lookup).
        Collection<String> ret = new ArrayList<>(neighbours.size());
        for (DataPoint e : neighbours) {
            ret.add(vocabCache.wordAtIndex(e.getIndex()));
        }
        return ret;
    }
}