// hex.api.Word2VecHandler
package hex.api;
import hex.schemas.Word2VecSynonymsV3;
import hex.schemas.Word2VecTransformV3;
import hex.word2vec.Word2VecModel;
import water.DKV;
import water.api.Handler;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;
import java.util.*;
public class Word2VecHandler extends Handler {
public Word2VecSynonymsV3 findSynonyms(int version, Word2VecSynonymsV3 args) {
Word2VecModel model = DKV.getGet(args.model.key());
if (model == null)
throw new IllegalArgumentException("missing source model " + args.model);
Map synonyms = model.findSynonyms(args.word, args.count);
List> result = new ArrayList<>(synonyms.entrySet());
Collections.sort(result, new Comparator>() {
@Override
public int compare(Map.Entry o1, Map.Entry o2) {
return o2.getValue().compareTo(o1.getValue()); // reverse sort
}
});
args.synonyms = new String[result.size()];
args.scores = new double[result.size()];
int i = 0;
for (Map.Entry entry : result) {
args.synonyms[i] = entry.getKey();
args.scores[i] = entry.getValue();
i++;
}
return args;
}
public Word2VecTransformV3 transform(int version, Word2VecTransformV3 args) {
Word2VecModel model = DKV.getGet(args.model.key());
if (model == null)
throw new IllegalArgumentException("missing source model " + args.model);
Frame words = DKV.getGet(args.words_frame.key());
if (words == null)
throw new IllegalArgumentException("missing words frame " + args.words_frame);
if (words.numCols() != 1) {
throw new IllegalArgumentException("words frame is expected to have a single string column, got" + words.numCols());
}
if (args.aggregate_method == null)
args.aggregate_method = Word2VecModel.AggregateMethod.NONE;
Frame vectors = model.transform(words.vec(0), args.aggregate_method);
args.vectors_frame = new KeyV3.FrameKeyV3(vectors._key);
return args;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy