// hex.example.Example Maven / Gradle / Ivy
package hex.example;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.example.ExampleModel.ExampleOutput;
import hex.example.ExampleModel.ExampleParameters;
import water.MRTask;
import water.fvec.Chunk;
import water.util.Log;
import java.util.Arrays;
/**
 * Example model builder: builds a trivial {@link ExampleModel} whose output
 * holds the per-column maximum of the training frame, recomputed once per
 * iteration.
 *
 * <p>The type arguments on {@code ModelBuilder} tie this builder to its model,
 * parameter and output classes; they are what give {@code _parms} the
 * {@code ExampleParameters} type used by {@link #init(boolean)} and the driver.
 */
public class Example
    extends ModelBuilder<ExampleModel, ExampleModel.ExampleParameters, ExampleModel.ExampleOutput> {

  /** @return the model categories this builder can produce (only {@code Unknown}). */
  @Override public ModelCategory[] can_build() {
    return new ModelCategory[]{ ModelCategory.Unknown, };
  }

  /** @return {@code Experimental}: this builder is flagged as experimental. */
  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Experimental; }

  /** @return {@code false}: this builder is declared unsupervised. */
  @Override public boolean isSupervised() { return false; }

  /**
   * Creates the builder and immediately runs the cheap validation pass.
   * Called from Nano thread; start the Example Job on a F/J thread.
   *
   * @param parms the parameters for the model to build
   */
  public Example( ExampleModel.ExampleParameters parms ) {
    super(parms);
    init(false);
  }

  /** @return a fresh driver that performs the actual training work. */
  @Override protected ExampleDriver trainModelImpl() { return new ExampleDriver(); }

  /** Initialize the ModelBuilder, validating all arguments and preparing the
   *  training frame.  This call is expected to be overridden in the subclasses
   *  and each subclass will start with "super.init();".  This call is made
   *  by the front-end whenever the GUI is clicked, and needs to be fast;
   *  heavy-weight prep needs to wait for the trainModel() call.
   *
   *  Validate the max_iterations.
   *
   *  @param expensive passed through unchanged to {@code super.init}
   */
  @Override public void init(boolean expensive) {
    super.init(expensive);
    // Accepted range is 1..9999999 inclusive; anything else is reported as a
    // validation error on the "max_iterations" field.
    if( _parms._max_iterations < 1 || _parms._max_iterations > 9999999 )
      error("max_iterations", "must be between 1 and 10 million");
  }

  // ----------------------
  /** Driver that builds the model: one {@link Max} pass over the training frame per iteration. */
  private class ExampleDriver extends Driver {
    @Override public void computeImpl() {
      ExampleModel model = null;
      try {
        init(true);

        // The model to be built
        model = new ExampleModel(_job._result, _parms, new ExampleModel.ExampleOutput(Example.this));
        model.delete_and_lock(_job);

        // ---
        // Run the main Example Loop
        // Stop after enough iterations
        for( ; model._output._iterations < _parms._max_iterations; model._output._iterations++ ) {
          if( stop_requested() ) break; // Stopped/cancelled

          double[] maxs = new Max().doAll(_parms.train())._maxs;

          // Fill in the model
          model._output._maxs = maxs;
          model.update(_job); // Update model in K/V store
          _job.update(1);     // One unit of work

          Log.info("Example: iter: " + model._output._iterations);
        }
      } finally {
        // Unlock on every exit path, but only if the model was actually created.
        if( model != null ) model.unlock(_job);
      }
    }
  }

  // -------------------------------------------------------------------------
  /**
   * Find max per-column.
   *
   * <p>Parameterised on itself so that {@code doAll(...)} returns a {@code Max}
   * (making {@code _maxs} reachable on the result) and {@code reduce} receives
   * a {@code Max}.
   */
  private static class Max extends MRTask<Max> {
    // OUT
    /** Per-column maximum; index i corresponds to chunk column i. */
    double[] _maxs;

    /** Computes the maximum of every column over the rows of the given chunks. */
    @Override public void map(Chunk[] cs) {
      _maxs = new double[cs.length];
      // Seed with the most negative finite double so any finite value replaces it.
      Arrays.fill(_maxs,-Double.MAX_VALUE);
      // NOTE(review): Math.max propagates NaN, so if atd() yields NaN for a row
      // (presumably how missing values surface -- confirm) that column's max is NaN.
      for( int col = 0; col < cs.length; col++ )
        for( int row = 0; row < cs[col]._len; row++ )
          _maxs[col] = Math.max(_maxs[col],cs[col].atd(row));
    }

    /** Folds another partial result into this one, column by column. */
    @Override public void reduce(Max that) {
      for( int col = 0; col < _maxs.length; col++ )
        _maxs[col] = Math.max(_maxs[col],that._maxs[col]);
    }
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy