// All downloads are free. Search and download functionality uses the official Maven repository.

// hex.example.Example Maven / Gradle / Ivy

package hex.example;

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.example.ExampleModel.ExampleOutput;
import hex.example.ExampleModel.ExampleParameters;
import water.MRTask;
import water.fvec.Chunk;
import water.util.Log;

import java.util.Arrays;

/**
 *  Example model builder... building a trivial ExampleModel.
 *
 *  <p>Each training iteration computes the per-column maximum of the training
 *  frame (see {@link Max}) and stores it in the model output, up to
 *  {@code _parms._max_iterations} iterations or until a stop is requested.
 *
 *  <p>The builder is parameterized with the model, parameter and output types
 *  so that the typed fields used below ({@code _parms._max_iterations},
 *  {@code model._output._maxs}) resolve without casts.
 */
public class Example extends ModelBuilder<ExampleModel,ExampleParameters,ExampleOutput> {
  /** The only model category this builder advertises is {@code Unknown}. */
  @Override public ModelCategory[] can_build() { return new ModelCategory[]{ ModelCategory.Unknown, }; }
  /** This builder is flagged as experimental. */
  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Experimental; }
  /** No response column is used: the builder is unsupervised. */
  @Override public boolean isSupervised() { return false; }

  // Called from Nano thread; start the Example Job on a F/J thread
  /** Builds from the given parameters and runs the cheap ({@code expensive == false}) validation pass.
   *  @param parms parameters for the example model */
  public Example( ExampleModel.ExampleParameters parms ) { super(parms); init(false); }
  /** @return a fresh driver that performs the actual training work */
  @Override protected ExampleDriver trainModelImpl() { return new ExampleDriver(); }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".  This call is made
   *  by the front-end whenever the GUI is clicked, and needs to be fast;
   *  heavy-weight prep needs to wait for the trainModel() call.
   *
   *  Validate the max_iterations.
   *
   *  @param expensive passed through unchanged to {@code super.init} */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    // Accepted range is 1 .. 9,999,999 inclusive.
    if( _parms._max_iterations < 1 || _parms._max_iterations > 9999999 )
      error("max_iterations", "must be between 1 and 10 million");
  }

  // ----------------------
  /** Training driver: creates the model, then repeatedly recomputes the
   *  per-column maxima and publishes them until the iteration limit is
   *  reached or a stop is requested. */
  private class ExampleDriver extends Driver {
    @Override public void computeImpl() {
      ExampleModel model = null;
      try {
        // Full (expensive) validation pass before any work is done
        init(true);

        // The model to be built
        model = new ExampleModel(_job._result, _parms, new ExampleModel.ExampleOutput(Example.this));
        model.delete_and_lock(_job);

        // ---
        // Run the main Example Loop
        // Stop after enough iterations
        for( ; model._output._iterations < _parms._max_iterations; model._output._iterations++ ) {
          if( stop_requested() ) break; // Stopped/cancelled

          double[] maxs = new Max().doAll(_parms.train())._maxs;

          // Fill in the model
          model._output._maxs = maxs;
          model.update(_job);   // Update model in K/V store
          _job.update(1);       // One unit of work

          StringBuilder sb = new StringBuilder();
          sb.append("Example: iter: ").append(model._output._iterations);
          Log.info(sb);
        }
      } finally {
        // Release the lock taken by delete_and_lock, even on failure or early exit
        if( model != null ) model.unlock(_job);
      }
    }
  }


  // -------------------------------------------------------------------------
  // Find max per-column
  /** Map/reduce task computing the maximum value of every column.
   *
   *  <p>Parameterized on itself so that {@code doAll} returns a {@code Max}
   *  (giving access to {@link #_maxs}) and {@link #reduce(Max)} overrides the
   *  framework's reduce hook.
   *
   *  <p>{@link Math#max(double, double)} returns NaN if either argument is
   *  NaN, so a single NaN value in a column makes that column's result NaN. */
  private static class Max extends MRTask<Max> {
    // IN

    // OUT
    /** Per-column maxima; allocated in {@link #map(Chunk[])}. */
    double[] _maxs;

    /** Computes the maximum of each column over the rows of the given chunks.
     *  @param cs one chunk per column */
    @Override public void map(Chunk[] cs) {
      _maxs = new double[cs.length];
      // Seed with the most negative finite double so any finite value replaces it
      Arrays.fill(_maxs,-Double.MAX_VALUE);
      for( int col = 0; col < cs.length; col++ )
        for( int row = 0; row < cs[col]._len; row++ )
          _maxs[col] = Math.max(_maxs[col],cs[col].atd(row));
    }

    /** Merges another task's partial result into this one, element-wise.
     *  @param that the other partial result */
    @Override public void reduce(Max that) {
      for( int col = 0; col < _maxs.length; col++ )
        _maxs[col] = Math.max(_maxs[col],that._maxs[col]);
    }
  }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy