All downloads are free. Search and download functionality uses the official Maven repository.

hex.word2vec.Word2Vec Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.word2vec;

import water.Job;
import water.H2O;
import water.util.Log;

import hex.ModelBuilder;
import hex.schemas.Word2VecV2;
import hex.schemas.ModelBuilderSchema;
import hex.word2vec.Word2VecModel.*;

public class Word2Vec extends ModelBuilder {
  public enum WordModel { SkipGram, CBOW }
  public enum NormModel { HSM, NegSampling }

  public Word2Vec(Word2VecModel.Word2VecParameters parms) { super("Word2Vec", parms); }

  public ModelBuilderSchema schema() { return new Word2VecV2(); }

  /** Start the KMeans training Job on an F/J thread. */
  @Override public Job trainModel() { return start(new Word2VecDriver(), _parms._epochs); }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".
   *
   *  Validate K, max_iters and the number of rows.  Precompute the number of
   *  categorical columns. */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    if (_parms._vecSize > Word2VecParameters.MAX_VEC_SIZE) error("vecSize", "Requested vector size of "+_parms._vecSize+" in Word2Vec, exceeds limit of "+Word2VecParameters.MAX_VEC_SIZE+".");
    if (_parms._vecSize < 1) error("vecSize", "Requested vector size of " + _parms._vecSize + " in Word2Vec, is not allowed.");
    if (_parms._windowSize < 1) error("windowSize", "Negative window size not allowed for Word2Vec.  Expected value > 0, received " + _parms._windowSize);
    if (_parms._sentSampleRate < 0.0) error("sentSampleRate", "Negative sentence sample rate not allowed for Word2Vec.  Expected a value > 0.0, received " + _parms._sentSampleRate);
    if (_parms._initLearningRate < 0.0) error("initLearningRate", "Negative learning rate not allowed for Word2Vec.  Expected a value > 0.0, received " + _parms._initLearningRate);
    if (_parms._epochs < 1) error("epochs", "Negative epoch count not allowed for Word2Vec.  Expected value > 0, received " + _parms._epochs);
  }

  private class Word2VecDriver extends H2O.H2OCountedCompleter {
    @Override
    protected void compute2() {
      Word2VecModel model = null;
      long start, stop, lastCnt=0;
      long tstart, tstop;
      float tDiff;

      try {
        init(true);
        _parms.lock_frames(Word2Vec.this);

        //The model to be built
        model = new Word2VecModel(dest(), _parms, new Word2VecOutput(Word2Vec.this));
        model.delete_and_lock(_key);

        // main loop
        Log.info("Word2Vec: Starting to train model.");
        tstart = System.currentTimeMillis();
        for (int i = 0; i < _parms._epochs; i++) {
          start = System.currentTimeMillis();
          model.setModelInfo(new WordVectorTrainer(model.getModelInfo()).doAll(_parms.train()).getModelInfo());
          stop = System.currentTimeMillis();
          model.getModelInfo().updateLearningRate();
          model.update(_key); // Early version of model is visible
          Job.update(1, _key);
          tDiff = (float)(stop-start)/1000;
          Log.info("Epoch "+i+" "+tDiff+"s  Words trained/s: "+ (model.getModelInfo().getTotalProcessed()-lastCnt)/tDiff);
          lastCnt = model.getModelInfo().getTotalProcessed();
        }
        tstop  = System.currentTimeMillis();
        Log.info("Total time :" + ((float)(tstop-tstart))/1000f);
        Log.info("Finished training the Word2Vec model.");
        model.buildModelOutput();

      } catch (Throwable t) {
        //model = DKV.get(dest()).get();
        //_state = JobState.CANCELLED; //for JSON REST response
        Log.info("Word2Vec model building was cancelled.");
        t.printStackTrace();
        cancel2(t);
        throw t;
      } finally {
        if( model != null ) model.unlock(_key);
        _parms.unlock_frames(Word2Vec.this);
        done();                 // Job done!
      }
      tryComplete();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy