All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.word2vec.Word2Vec Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.word2vec;

import hex.ModelCategory;
import water.DKV;
import water.Job;
import water.H2O;
import water.fvec.Vec;
import water.util.Log;

import hex.ModelBuilder;
import hex.schemas.Word2VecV3;
import hex.schemas.ModelBuilderSchema;
import hex.word2vec.Word2VecModel.*;

public class Word2Vec extends ModelBuilder {
  @Override
  public ModelCategory[] can_build() {
    return new ModelCategory[]{
            ModelCategory.Unknown,
    };
  }

  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Experimental; };

  public enum WordModel { SkipGram, CBOW }
  public enum NormModel { HSM, NegSampling }

  public Word2Vec(Word2VecModel.Word2VecParameters parms) { super("Word2Vec", parms); }

  public ModelBuilderSchema schema() { return new Word2VecV3(); }

  /** Start the KMeans training Job on an F/J thread. */
  @Override public Job trainModel() { return start(new Word2VecDriver(), _parms._epochs); }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".
   *
   *  Verify that at least one column contains strings.  Validate _vecSize, windowSize,
   *  sentSampleRate, initLearningRate, and epochs for values within range.
   */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    if (_parms._train != null) { //Can be called without an existing frame, but when present check for a string col
      Boolean useableCol = false;
      for (Vec v : _parms.train().vecs()) if (v.isString()) useableCol = true;
      if (!useableCol) error("_train", "Training input frame lacks any string columns for Word2Vec to analyze.");
    }
    if (_parms._vecSize > Word2VecParameters.MAX_VEC_SIZE) error("_vecSize", "Requested vector size of "+_parms._vecSize+" in Word2Vec, exceeds limit of "+Word2VecParameters.MAX_VEC_SIZE+".");
    if (_parms._vecSize < 1) error("_vecSize", "Requested vector size of " + _parms._vecSize + " in Word2Vec, is not allowed.");
    if (_parms._windowSize < 1) error("_windowSize", "Negative window size not allowed for Word2Vec.  Expected value > 0, received " + _parms._windowSize);
    if (_parms._sentSampleRate < 0.0) error("_sentSampleRate", "Negative sentence sample rate not allowed for Word2Vec.  Expected a value > 0.0, received " + _parms._sentSampleRate);
    if (_parms._initLearningRate < 0.0) error("_initLearningRate", "Negative learning rate not allowed for Word2Vec.  Expected a value > 0.0, received " + _parms._initLearningRate);
    if (_parms._epochs < 1) error("_epochs", "Negative epoch count not allowed for Word2Vec.  Expected value > 0, received " + _parms._epochs);
  }

  private class Word2VecDriver extends H2O.H2OCountedCompleter {
    @Override
    protected void compute2() {
      Word2VecModel model = null;
      long start, stop, lastCnt=0;
      long tstart, tstop;
      float tDiff;

      try {
        _parms.read_lock_frames(Word2Vec.this);
        init(true);

        //The model to be built
        model = new Word2VecModel(dest(), _parms, new Word2VecOutput(Word2Vec.this));
        model.delete_and_lock(_key);

        // main loop
        Log.info("Word2Vec: Starting to train model.");
        tstart = System.currentTimeMillis();
        for (int i = 0; i < _parms._epochs; i++) {
          start = System.currentTimeMillis();
          model.setModelInfo(new WordVectorTrainer(model.getModelInfo()).doAll(_parms.train()).getModelInfo());
          stop = System.currentTimeMillis();
          model.getModelInfo().updateLearningRate();
          model.update(_key); // Early version of model is visible
          Job.update(1, _key);
          tDiff = (float)(stop-start)/1000;
          Log.info("Epoch "+i+" "+tDiff+"s  Words trained/s: "+ (model.getModelInfo().getTotalProcessed()-lastCnt)/tDiff);
          lastCnt = model.getModelInfo().getTotalProcessed();
        }
        tstop  = System.currentTimeMillis();
        Log.info("Total time :" + ((float)(tstop-tstart))/1000f);
        Log.info("Finished training the Word2Vec model.");
        model.buildModelOutput();
        done();                 // Job done!
      } catch (Throwable t) {
        Job thisJob = DKV.getGet(_key);
        if (thisJob._state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          t.printStackTrace();
          failed(t);
          throw t;
        }
      } finally {
        if( model != null ) model.unlock(_key);
        _parms.read_unlock_frames(Word2Vec.this);
      }
      tryComplete();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy