
// hex.grid.HyperSpaceWalker Maven / Gradle / Ivy
package hex.grid;
import java.util.Map;
import java.util.NoSuchElementException;
import hex.Model;
import hex.ModelParametersBuilderFactory;
public interface HyperSpaceWalker {
/**
* Iterator over the points of a hyper space.
*
* Each point is handed out as a set of model parameters; the "raw" hyper parameter values
* of the current point are available via {@link #getCurrentRawParameters()}.
*/
interface HyperSpaceIterator {
/**
* Get next model parameters.
*
* It should return model parameters for next point in hyper space.
* Throws {@link java.util.NoSuchElementException} if there is no remaining point in space
* to explore.
*
* The method can optimize based on previousModel, but should be
* able to handle null-value.
*
* @param previousModel model generated for the previous point in hyper space, can be null.
*
* @return model parameters for next point in hyper space; exhaustion of the space is signalled
* by {@link java.util.NoSuchElementException}, not by a null return value.
*
* @throws IllegalArgumentException when model parameters cannot be constructed
* @throws java.util.NoSuchElementException if the iteration has no more elements
*/
MP nextModelParameters(Model previousModel);
/**
* Returns true if the iterator can continue.
*
* @param previousModel optional parameter which helps to determine next step, can be null
* @return true if the iterator can produce one more model parameters configuration.
*/
boolean hasNext(Model previousModel);
/**
* Returns current "raw" state of iterator.
*
* The state is represented by a permutation of values of grid parameters.
* In the CartesianWalker implementation in this file the i-th value belongs to the i-th
* hyper parameter name of the walker.
*
* @return array of "untyped" values representing configuration of grid parameters
*/
Object[] getCurrentRawParameters();
}
/**
* Returns an iterator to traverse this hyper-space.
*
* @return an iterator producing model parameters for the points of this hyper-space
*/
HyperSpaceIterator iterator();
/**
* Returns hyper parameters names which are used for walking the hyper parameters space.
*
* The names have to match the names of attributes in model parameters MP.
*
* @return names of used hyper parameters
*/
String[] getHyperParamNames();
/**
* Returns the estimated size of the hyper space.
*
* Can return -1 if an estimate is not available.
*
* @return estimated number of points in the hyper space to explore, or -1 if unknown
*/
int getHyperSpaceSize();
/**
* Returns the initial model parameters for the search.
*
* @return initial model parameters
*/
MP getParams();
/**
* Returns the factory of model parameters builders used by this walker.
*
* @return parameters builder factory
*/
ModelParametersBuilderFactory getParametersBuilderFactory();
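/**
* Hyper space walker which visits every combination of the given hyper parameter values.
*
* The external Grid API uses a HashMap from String to Object[] to describe a set of
* hyperparameter values, where each String key is a valid field name in the corresponding
* Model.Parameter, and each Object is a field value (boxed as needed).
*/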
class CartesianWalker implements HyperSpaceWalker {
/**
* Parameters builder factory to create new instance of parameters.
*/
final transient ModelParametersBuilderFactory _paramsBuilderFactory;
/**
* Base model parameters for this grid search.
* The object is used as a prototype to create model parameters
* for each point in hyper space.
*/
final MP _params;
/**
* Hyper space description - in this case only dimension and possible values.
*/
final private Map _hyperParams;
/**
* Cached names of used hyper parameters.
*/
final private String[] _hyperParamNames;
/**
* Cached size of the hyper space to walk. Includes duplicates (points in space specified
* multiple times).
*/
final private int _hyperSpaceSize;
/**
 * Creates a walker over all combinations of the given hyper parameter values.
 *
 * @param params base model parameters, kept as the prototype for every point in hyper space
 * @param hyperParams hyper space description: hyper parameter name mapped to its possible values
 * @param paramsBuilderFactory factory of builders used to create new instances of parameters
 */
public CartesianWalker(MP params,
                       Map hyperParams,
                       ModelParametersBuilderFactory paramsBuilderFactory) {
  this._paramsBuilderFactory = paramsBuilderFactory;
  this._params = params;
  this._hyperParams = hyperParams;
  // Names are captured once; their order is the order of dimensions used while walking.
  this._hyperParamNames = hyperParams.keySet().toArray(new String[0]);
  // Has to run after _hyperParams is assigned - the size is derived from that field.
  this._hyperSpaceSize = computeSizeOfHyperSpace();
}
/**
 * Creates a new iterator positioned before the first point of the hyper space.
 *
 * The iterator keeps one index per hyper parameter and advances them like an odometer
 * (see nextModel). Once the space is exhausted the iterator stays exhausted: hasNext keeps
 * returning false and nextModelParameters keeps throwing NoSuchElementException.
 *
 * @return iterator over all combinations of hyper parameter values
 */
@Override
public HyperSpaceIterator iterator() {
  return new HyperSpaceIterator() {
    /**
     * Hyper params permutation: for each hyper parameter the index of its currently selected
     * value. It is null until the first point has been produced.
     */
    private int[] _hidx = null;

    /**
     * Moves to the next point of the hyper space and builds model parameters for it.
     *
     * @param previousModel not used by this implementation, can be null
     * @return model parameters for the next point
     * @throws NoSuchElementException if all points have already been produced
     */
    @Override
    public MP nextModelParameters(Model previousModel) {
      if (_hidx == null) {
        // First call - start with the first value of every hyper parameter
        _hidx = new int[_hyperParamNames.length];
      } else if (nextModel(_hidx) == null) {
        // nextModel advances the indices in place and leaves them untouched when it reports
        // the end. Keeping _hidx (instead of resetting it to null) prevents the iterator from
        // silently restarting the walk after it has been exhausted.
        throw new NoSuchElementException("No more elements to explore in hyper-space!");
      }
      // Fill array of hyper-values
      Object[] hypers = hypers(_hidx, new Object[_hyperParamNames.length]);
      // Get clone of parameters so the shared prototype is not passed to the builder
      MP commonModelParams = (MP) _params.clone();
      // Fill model parameters
      return getModelParams(commonModelParams, hypers);
    }

    /**
     * Tells whether at least one more point can be produced.
     *
     * @param previousModel not used by this implementation, can be null
     * @return true before the first point is produced or while any index can still be advanced
     */
    @Override
    public boolean hasNext(Model previousModel) {
      if (_hidx == null) {
        return true;
      }
      for (int i = 0; i < _hidx.length; i++) {
        if (_hidx[i] + 1 < _hyperParams.get(_hyperParamNames[i]).length) {
          return true;
        }
      }
      return false;
    }

    /**
     * Returns the hyper values of the current point, ordered like the hyper parameter names.
     *
     * @return newly allocated array of the currently selected hyper values
     * @throws IllegalStateException if no point has been produced yet
     */
    @Override
    public Object[] getCurrentRawParameters() {
      if (_hidx == null) {
        throw new IllegalStateException(
            "No current point in hyper-space: nextModelParameters has not been called yet!");
      }
      return hypers(_hidx, new Object[_hyperParamNames.length]);
    }
  };
}
/**
 * Returns the cached names of hyper parameters used by this walker.
 *
 * The internal array itself is handed out (no copy is made).
 *
 * @return names of used hyper parameters
 */
@Override
public String[] getHyperParamNames() {
  return this._hyperParamNames;
}
/**
 * Returns the size of the hyper space computed when this walker was constructed.
 *
 * @return cached size of the hyper space
 */
@Override
public int getHyperSpaceSize() {
  return this._hyperSpaceSize;
}
/**
 * Returns the base model parameters this walker was created with.
 *
 * @return base model parameters (the stored instance, not a clone)
 */
@Override
public MP getParams() {
  return this._params;
}
/**
 * Returns the parameters builder factory this walker was created with.
 *
 * @return parameters builder factory
 */
@Override
public ModelParametersBuilderFactory getParametersBuilderFactory() {
  return this._paramsBuilderFactory;
}
/**
 * Advances the given index array to the next point of the hyper space.
 *
 * Works like an odometer: the first position whose index can still grow is incremented
 * and all positions before it are reset to zero. The array is modified in place.
 *
 * @param hidx indices of currently selected values, one per hyper parameter
 * @return the same (modified) array, or null when no position can be advanced; in that case
 *         the array is left untouched
 */
private int[] nextModel(int[] hidx) {
  for (int pos = 0; pos < hidx.length; pos++) {
    if (hidx[pos] + 1 >= _hyperParams.get(_hyperParamNames[pos]).length) {
      // This dimension already sits on its last value - look at the next one
      continue;
    }
    // Everything before the advanced position starts over from its first value
    for (int lower = pos - 1; lower >= 0; lower--) {
      hidx[lower] = 0;
    }
    hidx[pos] += 1;
    return hidx;
  }
  // No dimension could be advanced - all done
  return null;
}
/**
 * Builds model parameters for one point of the hyper space.
 *
 * A builder is obtained from the parameters builder factory for the given parameters and
 * every hyper parameter is set, in the order of the hyper parameter names, to the value
 * stored at the same position of hyperParams.
 *
 * @param params model parameters handed over to the builder factory
 * @param hyperParams hyper values, ordered like the hyper parameter names
 * @return model parameters produced by the builder
 */
private MP getModelParams(MP params, Object[] hyperParams) {
  ModelParametersBuilderFactory.ModelParametersBuilder builder = _paramsBuilderFactory.get(params);
  int position = 0;
  for (String name : _hyperParamNames) {
    builder.set(name, hyperParams[position]);
    position++;
  }
  return builder.build();
}
/**
 * Computes the number of points in the hyper space as the product of the number of values
 * of all hyper parameters. Entries whose value is null are skipped.
 *
 * The product is accumulated in a long and saturated at Integer.MAX_VALUE, so a very large
 * grid does not silently wrap around to a small, zero or negative size.
 *
 * @return size of the hyper space, at most Integer.MAX_VALUE
 */
protected int computeSizeOfHyperSpace() {
  long work = 1;
  for (Map.Entry p : _hyperParams.entrySet()) {
    if (p.getValue() != null) {
      // Both factors are at most Integer.MAX_VALUE, so the long product itself cannot overflow.
      // Clamping (instead of returning early) still lets a later empty dimension yield 0.
      work = Math.min(work * p.getValue().length, (long) Integer.MAX_VALUE);
    }
  }
  return (int) work;
}
/**
 * Translates value indices into the actual hyper values.
 *
 * @param hidx indices of selected values, one per hyper parameter
 * @param hypers output array which receives the selected value of each hyper parameter
 * @return the filled output array (the same instance that was passed in)
 */
private Object[] hypers(int[] hidx, Object[] hypers) {
  int dim = 0;
  while (dim < hidx.length) {
    String name = _hyperParamNames[dim];
    hypers[dim] = _hyperParams.get(name)[hidx[dim]];
    dim++;
  }
  return hypers;
}
}
/**
* Unfinished placeholder for a walker performing a random walk over the hyper space.
*
* FIXME : finish random walk - the class declares no members yet; being abstract, it leaves
* all {@link HyperSpaceWalker} methods to its subclasses.
*/
abstract public static class RandomWalker
implements HyperSpaceWalker {
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy