package hex.word2vec;

import water.DKV;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.parser.BufferedString;
import hex.word2vec.Word2VecModel.*;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.IcedHashMapGeneric;
import water.util.IcedLong;

import java.util.Iterator;

public class WordVectorTrainer extends MRTask<WordVectorTrainer> {
  private static final int MAX_SENTENCE_LEN = 1000;
  private static final int EXP_TABLE_SIZE = 1000;
  private static final int MAX_EXP = 6;
  private static final float[] _expTable = calcExpTable();
  private static final float LEARNING_RATE_MIN_FACTOR = 0.0001F; // learning rate stops decreasing at (initLearningRate * this factor)

  // Job
  private final Job _job;

  // Params
  private final Word2Vec.WordModel _wordModel;
  private final int _wordVecSize, _windowSize, _epochs;
  private final float _initLearningRate;
  private final float _sentSampleRate;
  private final long _vocabWordCount;

  // Model IN
  private final Key _vocabKey;
  private final Key _wordCountsKey;
  private final Key _treeKey;
  private final long _prevTotalProcessedWords;

  // Model IN & OUT
  // _syn0 is the matrix of synaptic weights connecting the input layer of the NN to the hidden layer;
  // _syn1 is the corresponding weight matrix of the synapses connecting the hidden layer to the output layer.
  // Both matrices are stored as flat 1D arrays, where M[i,j] == array[i * VEC_SIZE + j].
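  // (e.g. the vector of word w occupies _syn0[w * _wordVecSize] .. _syn0[(w + 1) * _wordVecSize - 1])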
  float[] _syn0, _syn1;
  long _processedWords = 0L;

  // Node-Local (Shared)
  IcedLong _nodeProcessedWords; // mutable long, approximates the total number of words processed by this node
  private transient IcedHashMapGeneric<BufferedString, Integer> _vocab;
  private transient IcedHashMap<BufferedString, IcedLong> _wordCounts;
  private transient int[][] _HBWTCode;
  private transient int[][] _HBWTPoint;

  private float _curLearningRate;
  private long _seed = System.nanoTime();

  public WordVectorTrainer(Job job, Word2VecModelInfo input) {
    super(null);
    _job = job;

    _treeKey = input._treeKey;
    _vocabKey = input._vocabKey;
    _wordCountsKey = input._wordCountsKey;

    // Params
    _wordModel = input.getParams()._word_model;
    _wordVecSize = input.getParams()._vec_size;
    _windowSize = input.getParams()._window_size;
    _sentSampleRate = input.getParams()._sent_sample_rate;
    _epochs = input.getParams()._epochs;
    _initLearningRate = input.getParams()._init_learning_rate;

    _vocabWordCount = input._vocabWordCount;
    _prevTotalProcessedWords = input._totalProcessedWords;

    _syn0 = input._syn0;
    _syn1 = input._syn1;
    _curLearningRate = calcLearningRate(_initLearningRate, _epochs, _prevTotalProcessedWords, _vocabWordCount);
  }

  @Override
  protected void setupLocal() {
    _vocab = ((Vocabulary) DKV.getGet(_vocabKey))._data;
    _wordCounts = ((WordCounts) DKV.getGet(_wordCountsKey))._data;
    HBWTree t = DKV.getGet(_treeKey);
    _HBWTCode = t._code;
    _HBWTPoint = t._point;
    _nodeProcessedWords = new IcedLong(0L);
  }

  // Precompute the exp() table
  private static float[] calcExpTable() {
    float[] expTable = new float[EXP_TABLE_SIZE];
    for (int i = 0; i < EXP_TABLE_SIZE; i++) {
      expTable[i] = (float) Math.exp((i / (float) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
      expTable[i] = expTable[i] / (expTable[i] + 1);  // turn e^x into sigmoid(x) = e^x / (e^x + 1)
    }
    return expTable;
  }
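  // For example, with EXP_TABLE_SIZE = 1000 and MAX_EXP = 6, the lookup used in the hierarchical-softmax
  // methods maps a dot product f = 0 to index (0 + 6) * (1000 / 6 / 2) = 498 (integer division), whose
  // entry is sigmoid(~0) ~= 0.5; dot products outside (-MAX_EXP, MAX_EXP) are skipped entirely.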

  @Override public void map(Chunk chk) {
    final int winSize = _windowSize, vecSize = _wordVecSize;
    float[] neu1 = new float[vecSize];
    float[] neu1e = new float[vecSize];
    ChunkSentenceIterator sentIter = new ChunkSentenceIterator(chk);

    int wordCount = 0;
    while (sentIter.hasNext()) {
      int sentLen = sentIter.nextLength();
      int[] sentence = sentIter.next();
      for (int sentIdx = 0; sentIdx < sentLen; sentIdx++) {
        int curWord = sentence[sentIdx];
        int bagSize = 0;
        if (_wordModel == Word2Vec.WordModel.CBOW) {
          for (int j = 0; j < vecSize; j++) neu1[j] = 0;
          for (int j = 0; j < vecSize; j++) neu1e[j] = 0;
        }
        
        // randomly shrink the effective window size (standard word2vec trick), then visit each
        // context word in the window (excluding curWord): SkipGram trains on the (curWord, winWord)
        // pair directly, while CBOW accumulates the context word vectors into neu1
        int winSizeMod = cheapRandInt(winSize);
        for (int winIdx = winSizeMod; winIdx < winSize * 2 + 1 - winSizeMod; winIdx++) {
          if (winIdx != winSize) { // skips curWord in sentence
            int winWordSentIdx = sentIdx - winSize + winIdx;
            if (winWordSentIdx < 0 || winWordSentIdx >= sentLen) continue;
            int winWord = sentence[winWordSentIdx];
            if (_wordModel == Word2Vec.WordModel.SkipGram)
              skipGram(curWord, winWord, neu1e);
            else { // CBOW
              for (int j = 0; j < vecSize; j++) neu1[j] += _syn0[j + winWord * vecSize];
              bagSize++;
            }
          }
        } // end for each item in the window
        if (_wordModel == Word2Vec.WordModel.CBOW && bagSize > 0) {
          CBOW(curWord, sentence, sentIdx, sentLen, winSizeMod, bagSize, neu1, neu1e);
        }

        wordCount++;
        // update learning rate
        if (wordCount % 10000 == 0) {
          _nodeProcessedWords._val += 10000;
          long totalProcessedWordsEst = _prevTotalProcessedWords + _nodeProcessedWords._val;
          _curLearningRate = calcLearningRate(_initLearningRate, _epochs, totalProcessedWordsEst, _vocabWordCount);
        }
      } // for each item in the sentence
    } // while more sentences
    _processedWords = wordCount;
    _nodeProcessedWords._val += wordCount % 10000;
    _job.update(1);
  }

  @Override public void reduce(WordVectorTrainer other) {
    _processedWords += other._processedWords;
    if (_syn0 != other._syn0) { // other task worked on a different syn0
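      // blend the two weight sets, weighting each by its share of the total processed words
      // (simple model averaging of the per-task training results)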
      float c = (float) other._processedWords / _processedWords;
      ArrayUtils.add(1.0f - c, _syn0, c, other._syn0);
      ArrayUtils.add(1.0f - c, _syn1, c, other._syn1);
      // for diagnostics only
      _nodeProcessedWords._val += other._nodeProcessedWords._val;
    }
  }

  private void skipGram(int curWord, int winWord, float[] neu1e) {
    final int vecSize = _wordVecSize;
    final int l1 = winWord * vecSize;
    for (int i = 0; i < vecSize; i++) neu1e[i] = 0;

    hierarchicalSoftmaxSG(curWord, l1, neu1e);

    // Learned weights input -> hidden
    for (int i = 0; i < vecSize; i++) _syn0[i + l1] += neu1e[i];
  }

  private void hierarchicalSoftmaxSG(final int targetWord, final int l1, float[] neu1e) {
    final int vecSize = _wordVecSize, tWrdCodeLen = _HBWTCode[targetWord].length;
    final float alpha = _curLearningRate;

    for (int i = 0; i < tWrdCodeLen; i++) {
      int l2 = _HBWTPoint[targetWord][i] * vecSize;

      float f = 0;
      // Propagate hidden -> output (calc sigmoid)
      for (int j = 0; j < vecSize; j++) f += _syn0[j + l1] * _syn1[j + l2];

      if (f <= -MAX_EXP) continue;
      else if (f >= MAX_EXP) continue;
      else f = _expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
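      // the label of this inner tree node is (1 - code bit), so (1 - code - f) below is the prediction error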

      float gradient = (1 - _HBWTCode[targetWord][i] - f) * alpha;
      // Propagate errors output -> hidden
      for (int j = 0; j < vecSize; j++) neu1e[j] += gradient * _syn1[j + l2];
      // Learn weights hidden -> output
      for (int j = 0; j < vecSize; j++) _syn1[j + l2] += gradient * _syn0[j + l1];
    }
  }

  private void CBOW(
      int curWord,
      int[] sentence,
      int sentIdx,
      int sentLen,
      int winSizeMod,
      int bagSize,
      float[] neu1,
      float[] neu1e
  ) {
    int winWordSentIdx, winWord;
    final int vecSize = _wordVecSize, winSize = _windowSize;
    final int curWinSize = winSize * 2 + 1 - winSizeMod; // same (shrunk) window bounds as used in map()

    for (int i = 0; i < vecSize; i++) neu1[i] /= bagSize;
    hierarchicalSoftmaxCBOW(curWord, neu1, neu1e);

    // hidden -> in
    for (int winIdx = winSizeMod; winIdx < curWinSize; winIdx++) {
      if (winIdx != winSize) {
        winWordSentIdx = sentIdx - winSize + winIdx;
        if (winWordSentIdx < 0 || winWordSentIdx >= sentLen) continue;
        winWord = sentence[winWordSentIdx];
        for (int i = 0; i < vecSize; i++) _syn0[i + winWord * vecSize] += neu1e[i];
      }
    }
  }

  private void hierarchicalSoftmaxCBOW(final int targetWord, float[] neu1, float[] neu1e) {
    final int vecSize = _wordVecSize, tWrdCodeLen = _HBWTCode[targetWord].length;
    final float alpha = _curLearningRate;
    float gradient, f = 0;
    int l2;

    for (int i = 0; i < tWrdCodeLen; i++, f = 0) {
      l2 = _HBWTPoint[targetWord][i] * vecSize;

      // Propagate hidden -> output (calc sigmoid)
      for (int j = 0; j < vecSize; j++) f += neu1[j] * _syn1[j + l2];

      if (f <= -MAX_EXP) continue;
      else if (f >= MAX_EXP) continue;
      else f = _expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];

      gradient = (1 - _HBWTCode[targetWord][i] - f) * alpha;
      // Propagate errors output -> hidden
      for (int j = 0; j < vecSize; j++) neu1e[j] += gradient * _syn1[j + l2];
      // Learn weights hidden -> output
      for (int j = 0; j < vecSize; j++) _syn1[j + l2] += gradient * neu1[j];
    }
  }

  /**
   * Calculates a new global learning rate for the next round of map/reduce calls.
   * The learning rate is a coefficient that controls how strongly each newly processed
   * training example adjusts the previously learned weights. It decays linearly with the
   * number of processed words and is floored at (initLearningRate * LEARNING_RATE_MIN_FACTOR).
   */
  private static float calcLearningRate(float initLearningRate, int epochs, long totalProcessed, long vocabWordCount) {
    float rate = initLearningRate * (1 - totalProcessed / (float) (epochs * vocabWordCount + 1));
    if (rate < initLearningRate * LEARNING_RATE_MIN_FACTOR) rate = initLearningRate * LEARNING_RATE_MIN_FACTOR;
    return rate;
  }
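  // For example (illustrative numbers): with initLearningRate = 0.025f, epochs = 5 and a corpus of
  // 1,000,000 in-vocabulary words, the rate after 2,500,000 processed words is roughly
  // 0.025f * (1 - 0.5) = 0.0125f, and it never drops below 0.025f * 1e-4.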

  public void updateModelInfo(Word2VecModelInfo modelInfo) {
    modelInfo._syn0 = _syn0;
    modelInfo._syn1 = _syn1;
    modelInfo._totalProcessedWords += _processedWords;
  }

  /**
   * Cheap, moderate-quality pseudo-random integer generator (an xorshift step on the task-local seed).
   *
   * @param max upper range limit (exclusive).
   * @return int in the range 0..(max - 1).
   */
  private int cheapRandInt(int max) {
    _seed ^= ( _seed << 21);
    _seed ^= ( _seed >>> 35);
    _seed ^= ( _seed << 4);
    int r = (int) _seed % max;
    return r > 0 ? r : -r;
  }

  private class ChunkSentenceIterator implements Iterator<int[]> {

    private Chunk _chk;
    private int _pos = 0;

    private int _len = -1;
    private int[] _sent = new int[MAX_SENTENCE_LEN + 1];

    private ChunkSentenceIterator(Chunk chk) { _chk = chk; }

    @Override
    public boolean hasNext() {
      return nextLength() >= 0;
    }

    private int nextLength() {
      if (_len >= 0)
        return _len;
      if (_pos >= _chk._len)
        return -1;
      _len = 0;
      BufferedString tmp = new BufferedString();
      for (; _pos < _chk._len && ! _chk.isNA(_pos) && _len < MAX_SENTENCE_LEN; _pos++) {
        BufferedString str = _chk.atStr(tmp, _pos);
        if (! _vocab.containsKey(str)) continue; // not in the vocab, skip
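        // Frequent-word subsampling (same scheme as the reference word2vec implementation): 'ran' below
        // is the keep-probability of the word given its corpus count, so very common words are randomly
        // dropped and rare words get relatively more training exposure.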
        if (_sentSampleRate > 0) {  // sub-sampling while creating a sentence
          long count = _wordCounts.get(str)._val;
          float ran = (float) ((Math.sqrt(count / (_sentSampleRate * _vocabWordCount)) + 1) * (_sentSampleRate * _vocabWordCount) / count);
          if (ran * 65536 < cheapRandInt(0xFFFF)) continue;
        }
        _sent[_len++] = _vocab.get(tmp);
      }
      _sent[_len] = -1;
      _pos++;
      return _len;
    }

    @Override
    public int[] next() {
      if (hasNext()) {
        _len = -1;
        return _sent;
      }
      else
        return null;
    }

    @Override
    public void remove() { throw new UnsupportedOperationException("Remove is not supported"); } // should never be called
  }

}
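
// A hypothetical driver sketch (not part of this file), assuming the standard MRTask#doAll entry point
// and a Word2VecModelInfo prepared by the model builder:
//
//   WordVectorTrainer trainer = new WordVectorTrainer(job, modelInfo);
//   trainer.doAll(tokenizedStringVec);   // one training pass over the tokenized text column
//   trainer.updateModelInfo(modelInfo);  // fold the averaged _syn0/_syn1 and word count back into the model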



