All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.word2vec.Word2Vec Maven / Gradle / Ivy

package hex.word2vec;

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.word2vec.Word2VecModel.*;
import water.Job;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.StringUtils;

import java.util.LinkedList;
import java.util.List;

public class Word2Vec extends ModelBuilder {
  public enum WordModel { SkipGram, CBOW }
  public enum NormModel { HSM }

  @Override public ModelCategory[] can_build() { return new ModelCategory[]{ ModelCategory.WordEmbedding, }; }
  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; }
  @Override public boolean isSupervised() { return false; }

  public Word2Vec(boolean startup_once) {
    super(new Word2VecParameters(), startup_once);
  }

  public Word2Vec(Word2VecModel.Word2VecParameters parms) {
    super(parms);
    init(false);
  }

  @Override protected Word2VecDriver trainModelImpl() { return new Word2VecDriver(); }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".
   *
   *  Verify that at the first column contains strings. Validate _vec_size, _window_size,
   *  _sent_sample_rate, _init_learning_rate, and epochs for values within range.
   */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    if (_parms._train != null) { // Can be called without an existing frame, but when present check for a string col
      if (_parms.train().vecs().length == 0 || ! _parms.trainVec().isString())
        error("_train", "The first column of the training input frame has to be column of Strings.");
    }
    if (_parms._vec_size > Word2VecParameters.MAX_VEC_SIZE) error("_vec_size", "Requested vector size of "+_parms._vec_size +" in Word2Vec, exceeds limit of "+Word2VecParameters.MAX_VEC_SIZE+".");
    if (_parms._vec_size < 1) error("_vec_size", "Requested vector size of " + _parms._vec_size + " in Word2Vec, is not allowed.");
    if (_parms._window_size < 1) error("_window_size", "Negative window size not allowed for Word2Vec.  Expected value > 0, received " + _parms._window_size);
    if (_parms._sent_sample_rate < 0.0) error("_sent_sample_rate", "Negative sentence sample rate not allowed for Word2Vec.  Expected a value > 0.0, received " + _parms._sent_sample_rate);
    if (_parms._init_learning_rate < 0.0) error("_init_learning_rate", "Negative learning rate not allowed for Word2Vec.  Expected a value > 0.0, received " + _parms._init_learning_rate);
    if (_parms._epochs < 1) error("_epochs", "Negative epoch count not allowed for Word2Vec.  Expected value > 0, received " + _parms._epochs);
  }

  @Override
  protected void ignoreBadColumns(int npredictors, boolean expensive) {
    // Do not remove String columns - these are the ones we need!
  }

  @Override
  public boolean haveMojo() { return true; }

  private class Word2VecDriver extends Driver {
    @Override public void computeImpl() {
      Word2VecModel model = null;
      try {
        init(! _parms.isPreTrained()); // expensive == true IFF the model is not pre-trained

        // The model to be built
        model = new Word2VecModel(_job._result, _parms, new Word2VecOutput(Word2Vec.this));
        model.delete_and_lock(_job);

        if (_parms.isPreTrained())
          convertToModel(_parms._pre_trained.get(), model);
        else
          trainModel(model);
      } finally {
        if (model != null) model.unlock(_job);
      }
    }
    private void trainModel(Word2VecModel model) {
      Log.info("Word2Vec: Initializing model training.");
      Word2VecModelInfo modelInfo = Word2VecModelInfo.createInitialModelInfo(_parms);

      // main loop
      Log.info("Word2Vec: Starting to train model, " + _parms._epochs + " epochs.");
      long tstart = System.currentTimeMillis();
      for (int i = 0; i < _parms._epochs; i++) {
        long start = System.currentTimeMillis();
        WordVectorTrainer trainer = new WordVectorTrainer(_job, modelInfo).doAll(_parms.trainVec());
        long stop = System.currentTimeMillis();
        long actProcessedWords = trainer._processedWords;
        long estProcessedWords = trainer._nodeProcessedWords._val;
        if (estProcessedWords < 0.95 * actProcessedWords)
          Log.warn("Estimated number processed words " + estProcessedWords +
                  " is significantly lower than actual number processed words " + actProcessedWords);
        trainer.updateModelInfo(modelInfo);
        model.update(_job); // Early version of model is visible
        double duration = (stop - start) / 1000.0;
        Log.info("Epoch " + i + " took "  + duration + "s; Words trained/s: " + actProcessedWords / duration);
        model._output._epochs=i;

        if (stop_requested()) { // do at least one iteration to avoid null model being returned and all hell will break loose
          break;
        }
      }
      long tstop  = System.currentTimeMillis();
      Log.info("Total time: " + (tstop - tstart) / 1000.0);
      Log.info("Finished training the Word2Vec model.");
      model.buildModelOutput(modelInfo);
    }
    private void convertToModel(Frame preTrained, Word2VecModel model) {
      if (_parms._vec_size != preTrained.numCols() - 1) {
        throw new IllegalStateException("Frame with pre-trained model doesn't conform to the specified vector length.");
      }
      WordVectorConverter result = new WordVectorConverter(_job, _parms._vec_size, (int) preTrained.numRows()).doAll(preTrained);
      model.buildModelOutput(result._words, result._syn0);
    }
  }

  public static Job fromPretrainedModel(Frame model) {
    if (model == null || model.numCols() < 2) {
      throw new IllegalArgumentException("Frame representing an external word2vec needs to have at least 2 columns.");
    }
    if (model.vec(0).get_type() != Vec.T_STR) {
      throw new IllegalArgumentException("First column is expected to contain the dictionary words and be represented as String, " +
              "instead got " + model.vec(0).get_type_str());
    }
    List colErrors = new LinkedList<>();
    for (int i = 1; i < model.numCols(); i++) {
      if (model.vec(i).get_type() != Vec.T_NUM) {
        colErrors.add(model.name(i) + " (type " + model.vec(i).get_type_str() + ")");
      }
    }
    if (! colErrors.isEmpty()) {
      throw new IllegalArgumentException("All components of word2vec mapping are expected to be numeric. Invalid columns: " +
              StringUtils.join(", ", colErrors));
    }

    Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
    p._vec_size = model.numCols() - 1;
    p._pre_trained = model._key;
    return new Word2Vec(p).trainModel();
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy