All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.drf.DRFGrid Maven / Gradle / Ivy

package hex.tree.drf;

import hex.*;
import hex.tree.SharedTreeGrid;
import water.DKV;
import water.util.ArrayUtils;
import water.H2O;
import water.Key;
import water.fvec.Frame;

/**
 * A Grid of Distributed Random Forest (DRF) models.
 *
 * <p>A Grid is used to explore a model's hyperparameter space. It is filled in
 * lazily and stands for the potentially unbounded set of hyperparameter
 * combinations for one kind of model on one dataset. There is one Grid subclass
 * per kind of model (e.g. GBM, GLM, DRF, DL). A Grid tracks models together with
 * their hyperparameters. It can find existing models by hyperparameter or build
 * new ones on demand, and it can manage a simple hyperparameter search space.
 *
 * <p>Hyperparameters travel through the API as a flat {@code double[]}. This
 * grid takes the hyperparameters handled by {@link SharedTreeGrid} and appends
 * two DRF-specific ones: {@code _mtries} and {@code _sample_rate}. Internally,
 * {@code _mtries} is converted to an {@code int} and {@code _sample_rate} to a
 * {@code float}.
 */
public class DRFGrid extends SharedTreeGrid {

  /** Algorithm name; also used to build this grid's DKV key in {@link #get(Frame)}. */
  public static final String MODEL_NAME = "DRF";

  /** Shared tree hyperparameter names, followed by the DRF-specific ones. */
  private static final String[] HYPER_NAMES =
      ArrayUtils.append(SharedTreeGrid.HYPER_NAMES, new String[] { "_mtries", "_sample_rate" });

  /**
   * Defaults aligned index-for-index with {@link #HYPER_NAMES}: {@code -1} for
   * {@code _mtries} and {@code 2f/3f} for {@code _sample_rate}.
   */
  private static final double[] HYPER_DEFAULTS =
      ArrayUtils.append(SharedTreeGrid.HYPER_DEFAULTS, new double[] { -1, 2f / 3f });

  /** @return the algorithm name, {@value #MODEL_NAME} */
  @Override
  protected String modelName() {
    return MODEL_NAME;
  }

  /** @return all hyperparameter names understood by this grid, in array order */
  @Override
  protected String[] hyperNames() {
    return HYPER_NAMES;
  }

  /** @return default hyperparameter values, index-aligned with {@link #hyperNames()} */
  @Override
  protected double[] hyperDefaults() {
    return HYPER_DEFAULTS;
  }

  /**
   * Hyperparameter suggestion is not supported for DRF.
   *
   * @throws RuntimeException always; the exception produced by {@code H2O.unimpl()}
   */
  @Override
  protected double suggestedNextHyperValue( int h, Model m, double[] hyperLimits ) {
    throw H2O.unimpl();
  }

  /**
   * Creates the DRF model builder for one grid point.
   *
   * @param params fully populated DRF parameters
   * @return a new {@link DRF} builder wrapping {@code params}
   */
  @Override
  protected ModelBuilder createBuilder(DRFModel.DRFParameters params) {
    return new DRF(params);
  }

  /**
   * Copies a flat hyperparameter vector onto DRF parameters.
   *
   * <p>The superclass handles the shared tree entries. The two entries right
   * after them are {@code _mtries} (truncated to {@code int}) and
   * {@code _sample_rate} (narrowed to {@code float}).
   *
   * @param params base parameters
   * @param hypers values index-aligned with {@link #hyperNames()}
   * @return the parameters object returned by the superclass, with the DRF fields set
   */
  @Override
  protected DRFModel.DRFParameters applyHypers(DRFModel.DRFParameters params, double[] hypers) {
    final DRFModel.DRFParameters result = super.applyHypers(params, hypers);
    final int base = sharedHyperCount();
    result._mtries = (int) hypers[base];
    result._sample_rate = (float) hypers[base + 1];
    return result;
  }

  /**
   * Extracts the flat hyperparameter vector from DRF parameters; the inverse of
   * {@link #applyHypers}.
   *
   * @param params DRF parameters to read
   * @return a new array, index-aligned with {@link #hyperNames()}
   */
  @Override
  public double[] getHypers(DRFModel.DRFParameters params) {
    final double[] values = new double[HYPER_NAMES.length];
    super.getHypers(params, values);
    final int base = sharedHyperCount();
    values[base] = params._mtries;
    values[base + 1] = params._sample_rate;
    return values;
  }

  /** Index of the first DRF-specific slot: the number of shared tree hyperparameters. */
  private static int sharedHyperCount() {
    return SharedTreeGrid.HYPER_NAMES.length;
  }

  private DRFGrid( Key key, Frame fr ) {
    super(key, fr);
  }

  /**
   * Returns the DRF grid for {@code fr}, creating and publishing it to the DKV
   * if none is stored under its key yet.
   *
   * <p>NOTE(review): the lookup and the put are two separate DKV calls, so
   * concurrent callers could each create a grid. Confirm whether that matters.
   *
   * @param fr training frame the grid is keyed on
   * @return the existing grid under the computed key, or a newly stored one
   */
  public static DRFGrid get( Frame fr ) {
    final Key gridKey = Grid.keyName(MODEL_NAME, fr);
    final DRFGrid existing = DKV.getGet(gridKey);
    if (existing != null) {
      return existing;
    }
    final DRFGrid created = new DRFGrid(gridKey, fr);
    DKV.put(created);
    return created;
  }

  /**
   * No-argument constructor required by the REST API layer; do not call it
   * directly (FIXME carried over from the original).
   */
  public DRFGrid() {
    super(null, null);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy