All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.kmeans.KMeansGrid Maven / Gradle / Ivy

package hex.kmeans;

import hex.*;
import water.DKV;
import water.H2O;
import water.Key;
import water.fvec.Frame;

/** A Grid of Models
 *  Used to explore Model hyper-parameter space.  Lazily filled in, this object
 *  represents the potentially infinite variety of hyperparameters of a given
 *  model & dataset.
 *
 *  One subclass per kind of Model, e.g. KMeans or GLM or GBM or DL.  The Grid
 *  tracks Models and their hyperparameters, and will allow discovery of
 *  existing Models by hyperparameter, or building Models on demand by
 *  hyperparameter.  The Grid can manage a (simplistic) hyperparameter search
 *  space.
 *
 *  Hyperparameter values are limited to doubles in the API, but can be
 *  anything the subclass Grid desires internally.  E.g. the Grid for KMeans
 *  will convert the initial center selection Enum to and from a simple integer
 *  value internally.
 */
public class KMeansGrid extends Grid {

  public static final String MODEL_NAME = "KMeans";

  @Override protected String modelName() { return MODEL_NAME; }

  private static final String[] HYPER_NAMES    = new String[] {"_k", "_standardize",          "_init",                         "_seed" };
  private static final double[] HYPER_DEFAULTS = new double[] {  0,          1     , KMeans.Initialization.PlusPlus.ordinal(),123456789L};

  @Override protected String[] hyperNames() { return HYPER_NAMES; }

  @Override protected double[] hyperDefaults() { return HYPER_DEFAULTS; }

  @Override protected double suggestedNextHyperValue( int h, Model m, double[] hyperLimits ) {
    throw H2O.unimpl();
  }

  @Override
  protected ModelBuilder createBuilder(KMeansModel.KMeansParameters params) {
    return new KMeans(params);
  }

  @Override protected KMeansModel.KMeansParameters applyHypers(KMeansModel.KMeansParameters params, double[] hypers) {
    params._train = _fr._key;
    params._k = (int)hypers[0];
    params._standardize = hypers[1]!=0;
    params._init = KMeans.Initialization.values()[(int)hypers[2]];
    params._seed = (long)hypers[3];
    return params;
  }

  @Override public double[] getHypers( KMeansModel.KMeansParameters params ) {
    double[] ds = new double[HYPER_NAMES.length];
    ds[0] = params._k;
    ds[1] = params._standardize ? 1 : 0;
    ds[2] = params._init.ordinal();
    ds[3] = params._seed;
    return ds;
  }

  // Factory for returning a grid based on an algorithm flavor
  private KMeansGrid( Key key, Frame fr ) { super(key,fr); }
  public static KMeansGrid get( Frame fr ) { 
    Key k = Grid.keyName(MODEL_NAME, fr);
    KMeansGrid kmg = DKV.getGet(k);
    if( kmg != null ) return kmg;
    kmg = new KMeansGrid(k,fr);
    DKV.put(kmg);
    return kmg;
  }

  /** FIXME: Rest API requirement - do not call directly */
  public KMeansGrid() { super(null, null); }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy