
hex.grid.GridSearch Maven / Gradle / Ivy
package hex.grid;
import java.util.Map;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.grid.HyperSpaceWalker.CartesianWalker;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;
import water.util.PojoUtils;
/**
* Grid search job.
*
* This job represents a generic interface to launch "any" hyper space search. It triggers sub-jobs
* for each point in hyper space. It produces Grid
object which contains a list of
* build models. A triggered model builder job can fail!
*
* Grid search is parametrized by: - model factory ({@link hex.grid.ModelFactory}) defines
* model build process
- hyper space walk strategy ({@link hex.grid.HyperSpaceWalker} defines
* how the space of hyper parameters is traversed
*
* The job is started by the startGridSearch
method which create a new grid search, put
* representation of Grid into distributed KV store, and for each parameter in hyper space of
* possible parameters, it launches a separated model building job. The launch of jobs is sequential
* and blocking. So after finish the last model, whole grid search job is done as well.
*
* By default, the grid search invokes cartezian grid search, but it can be modified by passing
* explicit hyper space walk strategy via the {@link #startGridSearch(Key, ModelFactory,
* HyperSpaceWalker)} method.
*
* If any of forked jobs fails then the failure is ignored, and grid search normally continue in
* traversing the hyper space.
*
* Typical usage from Java is:
* {@code
* // Create initial parameters and fill them by references to data
* GBMModel.GBMParameters params = new GBMModel.GBMParameters();
* params._train = fr._key;
* params._response_column = "cylinders";
*
* // Define hyper-space to search
* HashMap hyperParms = new HashMap<>();
* hyperParms.put("_ntrees", new Integer[]{1, 2});
* hyperParms.put("_distribution",new Distribution.Family[] {Distribution.Family.multinomial});
* hyperParms.put("_max_depth",new Integer[]{1,2,5});
* hyperParms.put("_learn_rate",new Float[]{0.01f,0.1f,0.3f});
*
* // Launch grid search job creating GBM models
* GridSearch gridSearchJob = GridSearch.startGridSearch(params, hyperParms, GBM_MODEL_FACTORY);
*
* // Block till the end of the job and get result
* Grid grid = gridSearchJob.get()
*
* // Get built models
* Model[] models = grid.getModels()
* }
*
* @see hex.grid.ModelFactory
* @see hex.grid.HyperSpaceWalker
* @see #startGridSearch(Key, ModelFactory, HyperSpaceWalker)
*/
// FIXME: this class should be driver which is passed to job as H2OCountedCompleter. Will be
// FIXME: refactored as part of Job refactoring.
public final class GridSearch extends Job {
/**
* Produces a new model builder for given parameters.
*/
private final transient ModelFactory _modelFactory;
/**
* Walks hyper space and for each point produces model parameters. It is used only locally to fire
* new model builders via ModelFactory.
*/
private final transient HyperSpaceWalker _hyperSpaceWalker;
private GridSearch(Key gkey,
ModelFactory modelFactory,
HyperSpaceWalker hyperSpaceWalker) {
super(gkey, modelFactory.getModelName() + " Grid Search");
assert modelFactory != null : "Grid search needs to know how to build a new model!";
assert hyperSpaceWalker != null : "Grid search needs to know to how walk around hyper space!";
//_paramsBuilderFactory = paramsBuilderFactory;
_modelFactory = modelFactory;
_hyperSpaceWalker = hyperSpaceWalker;
// Note: do not validate parameters of created model builders here!
// Leave it to launch time, and just mark the corresponding model builder job as failed.
}
GridSearch start() {
final int gridSize = _hyperSpaceWalker.getHyperSpaceSize();
Log.info("Starting gridsearch: estimated size of search space = " + gridSize);
// Create grid object and lock it
// Creation is done here, since we would like make sure that after leaving
// this function the grid object is in DKV and accessible.
Grid grid = DKV.getGet(dest());
if (grid != null) {
Frame specTrainFrame = _hyperSpaceWalker.getParams().train();
Frame oldTrainFrame = grid.getTrainingFrame();
if (!specTrainFrame._key.equals(oldTrainFrame._key) ||
specTrainFrame.checksum() != oldTrainFrame.checksum()) {
throw new H2OIllegalArgumentException("training_frame", "grid", "Cannot append new models"
+ " to a grid with different training input");
}
grid.write_lock(jobKey());
} else {
grid =
new Grid<>(dest(),
_hyperSpaceWalker.getParams(),
_hyperSpaceWalker.getHyperParamNames(),
_modelFactory.getModelName(),
_hyperSpaceWalker.getParametersBuilderFactory().getFieldNamingStrategy());
grid.delete_and_lock(jobKey());
}
// Java trick
final Grid gridToExpose = grid;
// Install this as job functions
start(new H2O.H2OCountedCompleter() {
@Override
public void compute2() {
gridSearch(gridToExpose);
tryComplete();
}
}, gridSize, true);
return this;
}
/**
* Returns expected number of models in resulting Grid object.
*
* The number can differ from final number of models due to visiting duplicate points in hyper
* space.
*
* @return expected number of models produced by this grid search
*/
public int getModelCount() {
return _hyperSpaceWalker.getHyperSpaceSize();
}
/**
* @return the set of models covered by this grid search, some may be null if the search is in
* progress or otherwise incomplete. It can also contain duplicate entries if input grid search
* specification includes them.
*
* FIXME: cannot iterate over space again since it is already used below in model building
*/
/*
public Model[] models() {
MP paramsPrototype = _params;
Model[] ms = new Model[_totalModels];
int mcnt = 0;
Object[] hypers = new Object[_hyperParamNames.length];
for (int[] hidx = new int[_hyperParamNames.length]; hidx != null; hidx = nextModel(hidx)) {
MP params = getModelParams((MP) paramsPrototype.clone(), hypers(hidx, hypers));
ms[mcnt++] = model(params).get();
}
return ms;
} */
/**
* Invokes grid search based on specified hyper space walk strategy.
*
* It updates passed grid object in distributed store.
*
* @param grid grid object to save results
*/
private void gridSearch(Grid grid) {
Model model = null;
try {
HyperSpaceWalker.HyperSpaceIterator it = _hyperSpaceWalker.iterator();
while (it.hasNext(model)) {
// Handle end-user cancel request
if (!isRunning()) {
// FIXME: propagate cancellation event to sub jobs, block till they are cancelled
cancel();
return;
}
MP params = null;
try {
// Get parameters for next model
params = it.nextModelParameters(model);
// Sequential model building, should never propagate
// exception up, just mark combination of model parameters as wrong
try {
model = buildModel(params, grid);
} catch (RuntimeException e) { // Catch everything
Log.warn("Grid search: model builder for parameters " + params + " failed! Exception: ", e);
grid.appendFailedModelParameters(params, e);
}
} catch (IllegalArgumentException e) {
Log.warn("Grid search: construction of model parameters failed! Exception: ", e);
// Model parameters cannot be constructed for some reason
Object[] rawParams = it.getCurrentRawParameters();
grid.appendFailedModelParameters(rawParams, e);
} finally {
// Update progress by 1 increment
this.update(1L);
// Always update grid in DKV after model building attempt
grid.update(jobKey());
}
}
// Grid search is done
done();
} catch(Throwable e) {
// Something wrong happened during hyper-space walking
// So cancel this job
// FIXME: should I delete grid here? it failed but user can be interested in partial result
Job thisJob = DKV.getGet(jobKey());
if (thisJob._state == JobState.CANCELLED) {
Log.info("Job " + jobKey() + " cancelled by user.");
} else {
// Mark job as failed
failed(e);
// And propagate unknown exception up
throw e;
}
} finally {
// Unlock grid object
grid.unlock(jobKey());
}
}
/**
* Build a model based on specified parameters and save it to resulting Grid object.
*
* Returns a model run with these parameters, typically built on demand and cached - expected to
* be an expensive operation. If the model in question is "in progress", a 2nd build will NOT be
* kicked off. This is a blocking call.
*
* If a new model is created, then the Grid object is updated in distributed store. If a model for
* given parameters already exists, it is directly returned without updating the Grid object. If
* model building fails then the Grid object is not updated and the method returns
* null
.
*
* @param params parameters for a new model
* @return return a new model if it does not exist
*/
private Model buildModel(final MP params, Grid grid) {
// Make sure that the model is not yet built (can be case of duplicated hyper parameters)
// FIXME: get checksum here since model builder will modify instance of params!!!
long checksum = params.checksum();
Key key = grid.getModelKey(checksum);
// It was already built
if (key != null) {
return key.get();
}
// Build a new model
// THIS IS BLOCKING call since we do not have enough information about free resources
// FIXME: we should allow here any launching strategy (not only sequential)
Model m = (Model) (startBuildModel(params, grid).get());
grid.putModel(checksum, m._key);
return m;
}
/**
* Triggers model building process but do not block on it.
*
* @param params parameters for a new model
* @param grid resulting grid object
* @return A Future of a model run with these parameters, typically built on demand and not cached
* - expected to be an expensive operation. If the model in question is "in progress", a 2nd
* build will NOT be kicked off. This is a non-blocking call.
*/
private ModelBuilder startBuildModel(MP params, Grid grid) {
if (grid.getModel(params) != null) {
return null;
}
ModelBuilder mb = _modelFactory.buildModel(params);
mb.trainModel();
return mb;
}
/**
* Defines a key for a new Grid object holding results of grid search.
*
* @return a grid key for a particular modeling class and frame.
* @throws java.lang.IllegalArgumentException if frame is not saved to distributed store.
*/
protected static Key gridKeyName(String modelName, Frame fr) {
if (fr._key == null) {
throw new IllegalArgumentException("The frame being grid-searched over must have a Key");
}
return Key.make("Grid_" + modelName + "_" + fr._key.toString() + H2O.calcNextUniqueModelId(""));
}
/**
* Start a new grid search job.
*
* This method launches "classical" grid search traversing cartezian grid of parameters
* point-by-point.
*
* @param destKey A key to store result of grid search under.
* @param params Default parameters for model builder. This object is used to create
* a specific model parameters for a combination of hyper parameters.
* @param hyperParams A set of arrays of hyper parameter values, used to specify a simple
* fully-filled-in grid search.
* @param modelFactory defines a strategy for creating new model builders
* @param paramsBuilderFactory defines a strategy for creating a new model parameters based on
* common parameters and list of hyper-parameters
* @return GridSearch Job, with models run with these parameters, built as needed - expected to be
* an expensive operation. If the models in question are "in progress", a 2nd build will NOT be
* kicked off. This is a non-blocking call.
*/
public static GridSearch startGridSearch(
final Key destKey,
final MP params,
final Map hyperParams,
final ModelFactory modelFactory,
final ModelParametersBuilderFactory paramsBuilderFactory) {
// Create a walker to traverse hyper space of model parameters
CartesianWalker
hyperSpaceWalker =
new CartesianWalker<>(params, hyperParams, paramsBuilderFactory);
return startGridSearch(destKey, modelFactory, hyperSpaceWalker);
}
/**
* Start a new grid search job.
*
* This method launches "classical" grid search traversing cartezian grid of parameters
* point-by-point.
*
* @param destKey A key to store result of grid search under.
* @param params Default parameters for model builder. This object is used to create a
* specific model parameters for a combination of hyper parameters.
* @param hyperParams A set of arrays of hyper parameter values, used to specify a simple
* fully-filled-in grid search.
* @param modelFactory defines a strategy for creating new model builders
* @return GridSearch Job, with models run with these parameters, built as needed - expected to be
* an expensive operation. If the models in question are "in progress", a 2nd build will NOT be
* kicked off. This is a non-blocking call.
*/
public static GridSearch startGridSearch(final Key destKey,
final MP params,
final Map hyperParams,
final ModelFactory modelFactory) {
return startGridSearch(destKey, params, hyperParams, modelFactory,
new SimpleParametersBuilderFactory());
}
public static GridSearch startGridSearch(final MP params,
final Map hyperParams,
final ModelFactory modelFactory) {
return startGridSearch(null, params, hyperParams, modelFactory);
}
/**
* Start a new grid search job. This method launches any grid search traversing space of hyper
* parameters based on specified strategy.
*
* @param destKey A key to store result of grid search under.
* @param modelFactory defines a strategy for creating new model builders
* @param hyperSpaceWalker defines a strategy for traversing a hyper space. The object itself
* holds definition of hyper space.
* @return GridSearch Job, with models run with these parameters, built as needed - expected to be
* an expensive operation. If the models in question are "in progress", a 2nd build will NOT be
* kicked off. This is a non-blocking call.
*/
public static GridSearch startGridSearch(
final Key destKey,
final ModelFactory modelFactory,
final HyperSpaceWalker hyperSpaceWalker) {
// Compute key for destination object representing grid
Key
gridKey =
destKey != null ? destKey : gridKeyName(modelFactory.getModelName(),
hyperSpaceWalker.getParams().train());
// Start the search
return new GridSearch(gridKey, modelFactory, hyperSpaceWalker).start();
}
/**
* The factory is producing a parameters builder which uses reflection to setup field values.
*
* @param type of model parameters object
*/
static class SimpleParametersBuilderFactory
implements ModelParametersBuilderFactory {
@Override
public ModelParametersBuilder get(MP initialParams) {
return new SimpleParamsBuilder<>(initialParams);
}
@Override
public PojoUtils.FieldNaming getFieldNamingStrategy() {
return PojoUtils.FieldNaming.CONSISTENT;
}
/**
* The builder modifies initial model parameters directly by reflection.
*
* Usage:
* {@code
* GBMModel.GBMParameters params =
* new SimpleParamsBuilder(initialParams)
* .set("_ntrees", 30).set("_learn_rate", 0.01).build()
* }
*
* @param type of model parameters object
*/
public static class SimpleParamsBuilder
implements ModelParametersBuilder {
final private MP params;
public SimpleParamsBuilder(MP initialParams) {
params = initialParams;
}
@Override
public ModelParametersBuilder set(String name, Object value) {
PojoUtils.setField(params, name, value, PojoUtils.FieldNaming.CONSISTENT);
return this;
}
@Override
public MP build() {
return params;
}
}
}
}