All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.api.Word2VecHandler Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.api;

import hex.schemas.Word2VecSynonymsV3;
import hex.schemas.Word2VecTransformV3;
import hex.word2vec.Word2VecModel;
import water.DKV;
import water.api.Handler;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;

import java.util.*;

public class Word2VecHandler extends Handler {

  public Word2VecSynonymsV3 findSynonyms(int version, Word2VecSynonymsV3 args) {
    Word2VecModel model = DKV.getGet(args.model.key());
    if (model == null)
      throw new IllegalArgumentException("missing source model " + args.model);

    Map synonyms = model.findSynonyms(args.word, args.count);

    List> result = new ArrayList<>(synonyms.entrySet());
    Collections.sort(result, new Comparator>() {
      @Override
      public int compare(Map.Entry o1, Map.Entry o2) {
        return o2.getValue().compareTo(o1.getValue()); // reverse sort
      }
    });
    args.synonyms = new String[result.size()];
    args.scores = new double[result.size()];
    int i = 0;
    for (Map.Entry entry : result) {
      args.synonyms[i] = entry.getKey();
      args.scores[i] = entry.getValue();
      i++;
    }
    return args;
  }

  public Word2VecTransformV3 transform(int version, Word2VecTransformV3 args) {
    Word2VecModel model = DKV.getGet(args.model.key());
    if (model == null)
      throw new IllegalArgumentException("missing source model " + args.model);

    Frame words = DKV.getGet(args.words_frame.key());
    if (words == null)
      throw new IllegalArgumentException("missing words frame " + args.words_frame);

    if (words.numCols() != 1) {
      throw new IllegalArgumentException("words frame is expected to have a single string column, got" + words.numCols());
    }

    if (args.aggregate_method == null)
      args.aggregate_method = Word2VecModel.AggregateMethod.NONE;

    Frame vectors = model.transform(words.vec(0), args.aggregate_method);
    args.vectors_frame = new KeyV3.FrameKeyV3(vectors._key);
    return args;
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy