package hex.word2vec;

import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.word2vec.Word2VecModel.Word2VecOutput;
import hex.word2vec.Word2VecModel.Word2VecParameters;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.IcedHashMap;
import water.util.IcedHashMapGeneric;
import water.util.IcedLong;
import water.util.RandomUtils;

import java.util.*;

public class Word2VecModel extends Model<Word2VecModel, Word2VecParameters, Word2VecOutput> {

  public Word2VecModel(Key<Word2VecModel> selfKey, Word2VecParameters params, Word2VecOutput output) {
    super(selfKey, params, output);
    assert(Arrays.equals(_key._kb, selfKey._kb));
  }

  @Override
  public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    throw H2O.unimpl("No Model Metrics for Word2Vec.");
  }

  @Override
  public double[] score0(Chunk[] cs, int foo, double data[], double preds[]) {
    throw H2O.unimpl();
  }

  @Override
  protected double[] score0(double data[], double preds[]) {
    throw H2O.unimpl();
  }

  @Override
  public Word2VecMojoWriter getMojo() {
    return new Word2VecMojoWriter(this);
  }

  /**
   * Converts this word2vec model to a Frame.
   * @return Frame made of columns: Word, V1, .., Vn. Word column holds the vocabulary associated
   * with this word2vec model, and columns V1, .., Vn represent the embeddings in the n-dimensional space.
   */
  public Frame toFrame() {
    Vec zeroVec = null;
    try {
      zeroVec = Vec.makeZero(_output._words.length);
      byte[] types = new byte[1 + _output._vecSize];
      Arrays.fill(types, Vec.T_NUM);
      types[0] = Vec.T_STR;
      String[] colNames = new String[types.length];
      colNames[0] = "Word";
      for (int i = 1; i < colNames.length; i++)
        colNames[i] = "V" + i;
      return new ConvertToFrameTask(this).doAll(types, zeroVec).outputFrame(colNames, null);
    } finally {
      if (zeroVec != null) zeroVec.remove();
    }
  }
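
  // Usage sketch (illustrative only; `w2v` is a hypothetical trained model, and the caller
  // owns the returned Frame and should delete it when done):
  //   Frame embeddings = w2v.toFrame();   // columns: Word, V1, .., Vn
  //   try {
  //     long vocabSize = embeddings.numRows();
  //   } finally {
  //     embeddings.delete();
  //   }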

  private static class ConvertToFrameTask extends MRTask<ConvertToFrameTask> {
    private Key<Word2VecModel> _modelKey;
    private transient Word2VecModel _model;
    public ConvertToFrameTask(Word2VecModel model) { _modelKey = model._key; }
    @Override
    protected void setupLocal() { _model = DKV.getGet(_modelKey); }
    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
      assert cs.length == 1;
      assert ncs.length == _model._output._vecSize + 1;
      Chunk chk = cs[0];
      int wordOffset = (int) chk.start();
      int vecPos = _model._output._vecSize * wordOffset;
      for (int i = 0; i < chk._len; i++) {
        ncs[0].addStr(_model._output._words[wordOffset + i]);
        for (int j = 1; j < ncs.length; j++)
          ncs[j].addNum(_model._output._vecs[vecPos++]);
      }
    }
  }

  /**
   * Takes an input string and returns the word vector for that word.
   *
   * @param target - String of desired word
   * @return float array containing the word vector values or null if
   *  the word isn't present in the vocabulary.
   */
  public float[] transform(String target) {
    return transform(new BufferedString(target));
  }

  private float[] transform(BufferedString word) {
    if (! _output._vocab.containsKey(word))
      return null;
    int wordIdx = _output._vocab.get(word);
    return Arrays.copyOfRange(_output._vecs, wordIdx * _output._vecSize, (wordIdx + 1) * _output._vecSize);
  }
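
  // Usage sketch (illustrative; `w2v` is a hypothetical trained model): the returned array
  // has _output._vecSize elements, or is null for an out-of-vocabulary word.
  //   float[] vec = w2v.transform("tokyo");
  //   if (vec != null)
  //     assert vec.length == w2v._output._vecSize;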

  public enum AggregateMethod { NONE, AVERAGE }

  public Frame transform(Vec wordVec, AggregateMethod aggregateMethod) {
    if (wordVec.get_type() != Vec.T_STR) {
      throw new IllegalArgumentException("Expected a string vector, got " + wordVec.get_type_str() + " vector.");
    }
    byte[] types = new byte[_output._vecSize];
    Arrays.fill(types, Vec.T_NUM);
    MRTask transformTask = aggregateMethod == AggregateMethod.AVERAGE ?
            new Word2VecAggregateTask(this) : new Word2VecTransformTask(this);
    return transformTask.doAll(types, wordVec).outputFrame(Key.make(), null, null);
  }
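
  // Usage sketch (illustrative; `tokens` is a hypothetical Frame whose first column is a
  // string Vec): with AggregateMethod.AVERAGE the input is read as NA-delimited sequences,
  // e.g. ["crime", "scene", <NA>, "court"] produces two output rows, each holding the
  // average of the word embeddings of one sequence.
  //   Frame sentenceEmbeddings = w2v.transform(tokens.vec(0), AggregateMethod.AVERAGE);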

  private static class Word2VecTransformTask extends MRTask<Word2VecTransformTask> {
    private Word2VecModel _model;
    public Word2VecTransformTask(Word2VecModel model) { _model = model; }
    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
      assert cs.length == 1;
      Chunk chk = cs[0];
      BufferedString tmp = new BufferedString();
      for (int i = 0; i < chk._len; i++) {
        if (chk.isNA(i)) {
          for (NewChunk nc : ncs) nc.addNA();
        } else {
          BufferedString word = chk.atStr(tmp, i);
          float[] vs = _model.transform(word);
          if (vs == null)
            for (NewChunk nc : ncs) nc.addNA();
          else
            for (int j = 0; j < ncs.length; j++)
              ncs[j].addNum(vs[j]);
        }
      }
    }
  }

  private static class Word2VecAggregateTask extends MRTask<Word2VecAggregateTask> {
    private Word2VecModel _model;
    public Word2VecAggregateTask(Word2VecModel model) { _model = model; }
    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
      assert cs.length == 1;
      Chunk chk = cs[0];
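      // Ownership protocol: each map() call emits every sequence whose opening delimiter
      // (the NA ending the previous sequence) lies in its own chunk; the map() of the very
      // first chunk also emits the sequence starting at row 0. Sequences crossing chunk
      // boundaries are completed by reading ahead via chk.nextChunk().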
      // skip words that belong to a sequence started in a previous chunk
      int offset = 0;
      if (chk.cidx() > 0) { // first chunk doesn't have an offset
        int naPos = findNA(chk);
        if (naPos < 0)
          return; // chunk doesn't contain an end of sequence and should not be processed
        offset = naPos + 1;
      }
      // process this chunk; if the last sequence is not terminated in this chunk, roll over to the next chunk
      float[] aggregated = new float[ncs.length];
      int seqLength = 0;
      boolean seqOpen = false;
      BufferedString tmp = new BufferedString();
      chunkLoop: do {
        for (int i = offset; i < chk._len; i++) {
          if (chk.isNA(i)) {
            writeAggregate(seqLength, aggregated, ncs);
            Arrays.fill(aggregated, 0.0f);
            seqLength = 0;
            seqOpen = false;
            if (chk != cs[0])
              break chunkLoop; // we just closed a sequence that was left open in one of the previous chunks
          } else {
            BufferedString word = chk.atStr(tmp, i);
            float[] vs = _model.transform(word);
            if (vs != null) {
              for (int j = 0; j < ncs.length; j++)
                aggregated[j] += vs[j];
              seqLength++;
            }
            seqOpen = true;
          }
        }
        offset = 0;
      } while ((chk = chk.nextChunk()) != null);
      // last sequence doesn't have to be terminated by NA
      if (seqOpen)
        writeAggregate(seqLength, aggregated, ncs);
    }
    private void writeAggregate(int seqLength, float[] aggregated, NewChunk[] ncs) {
      if (seqLength == 0)
        for (NewChunk nc : ncs) nc.addNA();
      else
        for (int j = 0; j < ncs.length; j++)
          ncs[j].addNum(aggregated[j] / seqLength);
    }
    private int findNA(Chunk chk) {
      for (int i = 0; i < chk._len; i++)
        if (chk.isNA(i)) return i;
      return -1;
    }
  }

  /**
   * Find synonyms (i.e. word-vectors with the highest cosine similarity)
   *
   * @param target String of desired word
   * @param cnt Number of synonyms to find
   * @return map from each synonym to its cosine similarity with the target word
   */
  public Map<String, Float> findSynonyms(String target, int cnt) {
    float[] vec = transform(target);

    if ((vec == null) || (cnt == 0))
      return Collections.emptyMap();

    int[] synonyms = new int[cnt];
    float[] scores = new float[cnt];

    int min = 0;
    for (int i = 0; i < cnt; i++) {
      synonyms[i] = i;
      scores[i] = cosineSimilarity(vec, i * vec.length, _output._vecs);
      if (scores[i] < scores[min])
        min = i;
    }

    final int vocabSize = _output._vocab.size();
    for (int i = cnt; i < vocabSize; i++) {
      float score = cosineSimilarity(vec, i * vec.length, _output._vecs);
      if ((score <= scores[min]) || (score >= 0.999999))
        continue;
      synonyms[min] = i;
      scores[min] = score;
      // find a new min
      min = 0;
      for (int j = 1; j < cnt; j++)
        if (scores[j] < scores[min])
          min = j;
    }

    Map<String, Float> result = new HashMap<>(cnt);
    for (int i = 0; i < cnt; i++)
      result.put(_output._words[synonyms[i]].toString(), scores[i]);
    return result;
  }
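
  // Usage sketch (illustrative; the word and count are arbitrary examples):
  //   Map<String, Float> syns = w2v.findSynonyms("teacher", 5);
  //   for (Map.Entry<String, Float> e : syns.entrySet())
  //     System.out.println(e.getKey() + " ~ " + e.getValue());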

  /**
   * Basic calculation of cosine similarity: dot(target, v) / (||target|| * ||v||)
   * @param target - a word vector
   * @param pos - position in vecs
   * @param vecs - learned word vectors
   * @return cosine similarity between the two word vectors
   */
  private float cosineSimilarity(float[] target, int pos, float[] vecs) {
    float dotProd = 0, tsqr = 0, csqr = 0;
    for(int i = 0; i < target.length; i++) {
      dotProd += target[i] * vecs[pos + i];
      tsqr += Math.pow(target[i], 2);
      csqr += Math.pow(vecs[pos + i], 2);
    }
    return (float) (dotProd / (Math.sqrt(tsqr) * Math.sqrt(csqr)));
  }

  void buildModelOutput(Word2VecModelInfo modelInfo) {
    IcedHashMapGeneric<BufferedString, Integer> vocab = ((Vocabulary) DKV.getGet(modelInfo._vocabKey))._data;
    BufferedString[] words = new BufferedString[vocab.size()];
    for (BufferedString str : vocab.keySet())
      words[vocab.get(str)] = str;

    _output._vecSize = _parms._vec_size;
    _output._vecs = modelInfo._syn0;
    _output._words = words;
    _output._vocab = vocab;
  }

  void buildModelOutput(BufferedString[] words, float[] syn0) {
    IcedHashMapGeneric<BufferedString, Integer> vocab = new IcedHashMapGeneric<>();
    for (int i = 0; i < words.length; i++)
      vocab.put(words[i], i);

    _output._vecSize = _parms._vec_size;
    _output._vecs = syn0;
    _output._words = words;
    _output._vocab = vocab;
  }

  public static class Word2VecParameters extends Model.Parameters {
    public String algoName() { return "Word2Vec"; }
    public String fullName() { return "Word2Vec"; }
    public String javaName() { return Word2VecModel.class.getName(); }
    @Override public long progressUnits() {
      return isPreTrained() ? _pre_trained.get().anyVec().nChunks() : train().vec(0).nChunks() * _epochs;
    }
    static final int MAX_VEC_SIZE = 10000;

    public Word2Vec.WordModel _word_model = Word2Vec.WordModel.SkipGram;
    public Word2Vec.NormModel _norm_model = Word2Vec.NormModel.HSM;
    public int _min_word_freq = 5;
    public int _vec_size = 100;
    public int _window_size = 5;
    public int _epochs = 5;
    public float _init_learning_rate = 0.025f;
    public float _sent_sample_rate = 1e-3f;
    public Key<Frame> _pre_trained;  // key of a frame that contains a pre-trained word2vec model
    boolean isPreTrained() { return _pre_trained != null; }
    Vec trainVec() { return train().vec(0); }

  }
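
  // Usage sketch (illustrative; assumes the standard H2O builder pattern and an in-DKV
  // training Frame `corpus` with a single string column; the values are examples, not
  // recommendations):
  //   Word2VecParameters p = new Word2VecParameters();
  //   p._train = corpus._key;
  //   p._vec_size = 100;       // embedding dimensionality, capped by MAX_VEC_SIZE
  //   p._epochs = 5;
  //   Word2VecModel w2v = new Word2Vec(p).trainModel().get();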

  public static class Word2VecOutput extends Model.Output {
    public int _vecSize;
    public int _epochs;
    public Word2VecOutput(Word2Vec b) { super(b); }

    public BufferedString[] _words;
    public float[] _vecs;
    public IcedHashMapGeneric<BufferedString, Integer> _vocab;

    @Override public ModelCategory getModelCategory() {
      return ModelCategory.WordEmbedding;
    }
  }

  public static class Word2VecModelInfo extends Iced {

    long _vocabWordCount;
    long _totalProcessedWords = 0L;

    float[] _syn0, _syn1;
    Key<HBWTree> _treeKey;
    Key<Vocabulary> _vocabKey;
    Key<WordCounts> _wordCountsKey;

    private Word2VecParameters _parameters;
    public final Word2VecParameters getParams() { return _parameters; }

    public Word2VecModelInfo() {}

    private Word2VecModelInfo(Word2VecParameters params, WordCounts wordCounts) {
      _parameters = params;

      long vocabWordCount = 0L;
      List<Map.Entry<BufferedString, IcedLong>> wordCountList = new ArrayList<>(wordCounts._data.size());
      for (Map.Entry<BufferedString, IcedLong> wc : wordCounts._data.entrySet()) {
        if (wc.getValue()._val >= _parameters._min_word_freq) {
          wordCountList.add(wc);
          vocabWordCount += wc.getValue()._val;
        }
      }
      Collections.sort(wordCountList, new Comparator<Map.Entry<BufferedString, IcedLong>>() {
        @Override
        public int compare(Map.Entry<BufferedString, IcedLong> o1, Map.Entry<BufferedString, IcedLong> o2) {
          long x = o1.getValue()._val; long y = o2.getValue()._val;
          return (x < y) ? -1 : ((x == y) ? 0 : 1);
        }
      });

      int vocabSize = wordCountList.size();
      long[] countAry = new long[vocabSize];
      Vocabulary vocab = new Vocabulary(new IcedHashMapGeneric<BufferedString, Integer>());
      int idx = 0;
      for (Map.Entry<BufferedString, IcedLong> wc : wordCountList) {
        countAry[idx] = wc.getValue()._val;
        vocab._data.put(wc.getKey(), idx++);
      }
      HBWTree t = HBWTree.buildHuffmanBinaryWordTree(countAry);

      _vocabWordCount = vocabWordCount;
      _treeKey = publish(t);
      _vocabKey = publish(vocab);
      _wordCountsKey = publish(wordCounts);

      //initialize weights to random values
      Random rand = RandomUtils.getRNG(0xDECAF, 0xDA7A);
      _syn1 = MemoryManager.malloc4f(_parameters._vec_size * vocabSize);
      _syn0 = MemoryManager.malloc4f(_parameters._vec_size * vocabSize);
      for (int i = 0; i < _parameters._vec_size * vocabSize; i++) _syn0[i] = (rand.nextFloat() - 0.5f) / _parameters._vec_size;
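      // note: only _syn0 (the input-side embeddings) is randomly initialized, to small values
      // in [-0.5/vec_size, 0.5/vec_size); _syn1 (the output-side weights) starts zero-filled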
    }

    public static Word2VecModelInfo createInitialModelInfo(Word2VecParameters params) {
      Vec v = params.trainVec();
      WordCounts wordCounts = new WordCounts(new WordCountTask().doAll(v)._counts);
      return new Word2VecModelInfo(params, wordCounts);
    }

    private static <T extends Keyed<T>> Key<T> publish(T keyed) {
      Scope.track_generic(keyed);
      DKV.put(keyed);
      return keyed._key;
    }

  }

  // wraps Vocabulary map into a Keyed object
  public static class Vocabulary extends Keyed<Vocabulary> {
    IcedHashMapGeneric<BufferedString, Integer> _data;
    Vocabulary(IcedHashMapGeneric<BufferedString, Integer> data) {
      super(Key.make());
      _data = data;
    }
  }

  // wraps Word-Count map into a Keyed object
  public static class WordCounts extends Keyed<WordCounts> {
    IcedHashMap<BufferedString, IcedLong> _data;
    WordCounts(IcedHashMap<BufferedString, IcedLong> data) {
      super(Key.make());
      _data = data;
    }
  }

}