All downloads are free. Search and download functionality uses the official Maven repository.

hex.example.Example Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.example;

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.example.ExampleModel.ExampleOutput;
import hex.example.ExampleModel.ExampleParameters;
import hex.schemas.ExampleV3;
import hex.schemas.ModelBuilderSchema;
import water.H2O.H2OCountedCompleter;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.util.Log;

import java.util.Arrays;

/**
 *  Example model builder... building a trivial ExampleModel.
 *
 *  <p>Every training iteration recomputes the per-column maximum of the
 *  training frame (see {@link Max}) and stores it in the model output.
 */
public class Example extends ModelBuilder<ExampleModel,ExampleModel.ExampleParameters,ExampleModel.ExampleOutput> {
  /** @return the model categories this builder declares; only {@code Unknown}. */
  @Override
  public ModelCategory[] can_build() {
    return new ModelCategory[]{
            ModelCategory.Unknown,
    };
  }

  /** @return {@code Experimental}; this builder is not advertised as stable. */
  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Experimental; }

  // Called from Nano thread; start the Example Job on a F/J thread
  /** Builds the job named "Example" and runs the cheap (non-expensive) validation pass.
   *  @param parms user-supplied parameters, including {@code _max_iterations} */
  public Example( ExampleModel.ExampleParameters parms ) { super("Example",parms); init(false); }

  /** @return a fresh {@code ExampleV3} schema instance for this builder. */
  public ModelBuilderSchema schema() { return new ExampleV3(); }

  /** Starts an {@link ExampleDriver}, passing {@code _max_iterations} as the
   *  second argument of {@code start} (presumably the total units of work,
   *  since the driver reports one unit per iteration -- confirm against ModelBuilder).
   *  @return this builder */
  @Override public Example trainModel() {
    return (Example)start(new ExampleDriver(), _parms._max_iterations);
  }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".  This call is made
   *  by the front-end whenever the GUI is clicked, and needs to be fast;
   *  heavy-weight prep needs to wait for the trainModel() call.
   *
   *  Validate the max_iterations: values below 1 or above 9,999,999 are
   *  reported as an error on the "max_iterations" field.
   *
   *  @param expensive forwarded unchanged to {@code super.init} */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    if( _parms._max_iterations < 1 || _parms._max_iterations > 9999999 )
      error("max_iterations", "must be between 1 and 10 million");
  }

  // ----------------------
  /** Worker that performs the actual training loop for the enclosing builder. */
  private class ExampleDriver extends H2OCountedCompleter {

    /** Locks the input frames, creates and locks the model, then runs up to
     *  {@code _max_iterations} passes of {@link Max} over the training frame,
     *  publishing the model after each pass.  Locks and the Scope are always
     *  released in the {@code finally} block. */
    @Override protected void compute2() {
      ExampleModel model = null;
      try {
        Scope.enter();
        _parms.read_lock_frames(Example.this); // Fetch & read-lock source frame
        init(true);

        // The model to be built
        model = new ExampleModel(dest(), _parms, new ExampleModel.ExampleOutput(Example.this));
        model.delete_and_lock(_key);

        // ---
        // Run the main Example Loop
        // Stop after enough iterations
        for( ; model._output._iterations < _parms._max_iterations; model._output._iterations++ ) {
          if( !isRunning() ) break; // Stopped/cancelled

          double[] maxs = new Max().doAll(_parms.train())._maxs;

          // Fill in the model
          model._output._maxs = maxs;
          model.update(_key); // Update model in K/V store
          update(1);          // One unit of work

          StringBuilder sb = new StringBuilder();
          sb.append("Example: iter: ").append(model._output._iterations);
          Log.info(sb);
        }
        done();                 // Job done!
      } catch( Throwable t ) {
        if (_state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          t.printStackTrace();
          failed(t);
          throw t;
        }
      } finally {
        if( model != null ) model.unlock(_key);
        _parms.read_unlock_frames(Example.this);
        // 'model' is still null when anything above failed before the model was
        // constructed.  Dereferencing it here would raise a NullPointerException
        // from the finally block and hide the original failure, so only pass the
        // model key to keep when a model actually exists.
        if( model != null ) Scope.exit(model._key);
        else                Scope.exit();
      }
      tryComplete();
    }
  }


  // -------------------------------------------------------------------------
  // Find max per-column
  /** Map/reduce task computing, for each column, the largest value seen. */
  private static class Max extends MRTask<Max> {
    // IN

    // OUT
    /** Per-column maxima; entry {@code i} belongs to chunk/column {@code i}.
     *  Starts at {@code -Double.MAX_VALUE} for every column. */
    double[] _maxs;

    /** Scans every row of every chunk handed in and records each column's maximum.
     *  @param cs one chunk per column */
    @Override public void map(Chunk[] cs) {
      _maxs = new double[cs.length];
      Arrays.fill(_maxs,-Double.MAX_VALUE);
      for( int col = 0; col < cs.length; col++ )
        for( int row = 0; row < cs[col]._len; row++ )
          _maxs[col] = Math.max(_maxs[col],cs[col].atd(row));
    }

    /** Folds another partial result into this one, element-wise by column.
     *  @param that partial result whose {@code _maxs} is merged into this task's */
    @Override public void reduce(Max that) {
      for( int col = 0; col < _maxs.length; col++ )
        _maxs[col] = Math.max(_maxs[col],that._maxs[col]);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy