All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.ensemble.Metalearner Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ensemble.StackedEnsembleModel.StackedEnsembleParameters;

import water.DKV;
import water.Job;
import water.Key;
import water.fvec.Frame;
import water.util.Log;

public abstract class Metalearner, M extends Model, P extends Model.Parameters> {

  /**
   * Using an enum to list possible algos is not the greatest idea here
   * as it forces us to hardcode supported algos and creates a dependency to metalearners provided in extensions (XGBoost).
   * Also, it prevents us from loading custom metalearners.
   */
  public enum Algorithm {
    AUTO,
    deeplearning,
    drf,
    gbm,
    glm,
    naivebayes,
    xgboost,
  }

  protected Frame _levelOneTrainingFrame;
  protected Frame _levelOneValidationFrame;
  protected StackedEnsembleModel _model;
  protected StackedEnsembleParameters _parms;
  protected Job _job;
  protected Key _metalearnerKey;
  protected Job _metalearnerJob;
  protected P _metalearner_parameters;
  protected boolean _hasMetalearnerParams;
  protected long _metalearnerSeed;
  protected long _maxRuntimeSecs;

  void init(Frame levelOneTrainingFrame,
            Frame levelOneValidationFrame,
            P metalearner_parameters,
            StackedEnsembleModel model,
            Job StackedEnsembleJob,
            Key metalearnerKey,
            Job metalearnerJob,
            StackedEnsembleParameters parms,
            boolean hasMetalearnerParams,
            long metalearnerSeed,
            long maxRuntimeSecs) {

    _levelOneTrainingFrame = levelOneTrainingFrame;
    _levelOneValidationFrame = levelOneValidationFrame;
    _metalearner_parameters = metalearner_parameters;
    _model = model;
    _job = StackedEnsembleJob;
    _metalearnerKey = metalearnerKey;
    _metalearnerJob = metalearnerJob;
    _parms = parms;
    _hasMetalearnerParams = hasMetalearnerParams;
    _metalearnerSeed = metalearnerSeed;
    _maxRuntimeSecs = maxRuntimeSecs;
  }

  void compute() {
    try {
      _model.write_lock(_job);
      B builder = createBuilder();

      if (_hasMetalearnerParams) {
        builder._parms = _metalearner_parameters;
      }
      setCommonParams(builder._parms);
      setCrossValidationParams(builder._parms);
      setCustomParams(builder._parms);

      validateParams(builder._parms);

      builder.init(false);
      Job j = builder.trainModel();
      while (j.isRunning()) {
        try {
          _job.update(j.getWork(), "training metalearner(" + _model._parms._metalearner_algorithm + ")");
          Thread.sleep(100);
        } catch (InterruptedException ignored) {
        }
      }
      Log.info("Finished training metalearner model(" + _model._parms._metalearner_algorithm + ").");

      _model._output._metalearner = builder.get();
      _model._dist = _model._output._metalearner._dist;

      _model.doScoreOrCopyMetrics(_job);

      if (_parms._keep_levelone_frame) {
        _model._output._levelone_frame_id = _levelOneTrainingFrame; //Keep Level One Training Frame in Stacked Ensemble model object
      }
    } finally {
      cleanup();
      _model.update(_job);
      _model.unlock(_job);
    }
  }

  abstract B createBuilder();

  protected void setCommonParams(P parms) {
    if (parms._seed == -1) { //use _metalearnerSeed only as legacy fallback if not set on metalearner_parameters 
      parms._seed = _metalearnerSeed;
    }
    parms._train = _levelOneTrainingFrame._key;
    parms._valid = (_levelOneValidationFrame == null ? null : _levelOneValidationFrame._key);
    parms._response_column = _model.responseColumn;
    parms._max_runtime_secs = _maxRuntimeSecs;
    parms._weights_column = _model._parms._weights_column;
    parms._offset_column = _model._parms._offset_column;
    parms._main_model_time_budget_factor = _model._parms._main_model_time_budget_factor;
  }

  protected void setCrossValidationParams(P parms) {
    if (_model._parms._metalearner_fold_column == null) {
      parms._nfolds = _model._parms._metalearner_nfolds;
      if (_model._parms._metalearner_nfolds > 1) {
        if (_model._parms._metalearner_fold_assignment == null) {
          parms._fold_assignment = Model.Parameters.FoldAssignmentScheme.AUTO;
        } else {
          parms._fold_assignment = _model._parms._metalearner_fold_assignment;
        }
      }
    } else {
      parms._fold_column = _model._parms._metalearner_fold_column;
    }
  }

  protected void setCustomParams(P parms) { }

  protected void validateParams(P parms) { }

  protected void cleanup() {
    if (!_parms._keep_base_model_predictions) {
      _model.deleteBaseModelPredictions();
    }
    if (!_parms._keep_levelone_frame) {
      DKV.remove(_levelOneTrainingFrame._key); //Remove Level One Training Frame from DKV
    }
    if (null != _levelOneValidationFrame) {
      DKV.remove(_levelOneValidationFrame._key); //Remove Level One Validation Frame from DKV
    }
  }
}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy