All downloads are FREE. The search and download functionality uses the official Maven repository.

hex.ensemble.StackedEnsemble Maven / Gradle / Ivy

package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.StackedEnsembleModel;
import hex.glm.GLM;
import hex.glm.GLMModel;
import water.*;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;

/**
 * An ensemble of other models, created by stacking with the SuperLearner algorithm or a variation.
 */
public class StackedEnsemble extends ModelBuilder {
  StackedEnsembleDriver _driver;
  // The in-progress model being built
  protected StackedEnsembleModel _model;


  public StackedEnsemble(boolean startup_once) { super(new StackedEnsembleModel.StackedEnsembleParameters(),startup_once); }

  /*
  public StackedEnsemble(Key selfKey, StackedEnsembleModel.StackedEnsembleParameters parms, StackedEnsemble job) {
    super(selfKey, parms, job == null ?
            new StackedEnsembleModel.StackedEnsembleOutput():
            new StackedEnsembleModel.StackedEnsembleOutput(job));
  }
  */

  @Override public ModelCategory[] can_build() {
    return new ModelCategory[]{
            ModelCategory.Regression,
            ModelCategory.Binomial,
         //   ModelCategory.Multinomial, // TODO
    };
  }

  @Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Experimental; }

  @Override public boolean isSupervised() { return true; }

  @Override protected StackedEnsembleDriver trainModelImpl() { return _driver = new StackedEnsembleDriver(); }

  public static void addModelPredictionsToLevelOneFrame(Model aModel, Frame aModelsPredictions, Frame levelOneFrame) {
    if (aModel._output.isBinomialClassifier()) {
      // GLM uses a different column name than the other algos, yay!
      Vec preds = aModelsPredictions.vec(2); // Predictions column names have been changed. . .

      levelOneFrame.add(aModel._key.toString(), preds);
    } else if (aModel._output.isClassifier()) {
      throw new H2OIllegalArgumentException("Don't yet know how to stack multinomial classifiers: " + aModel._key);
    } else if (aModel._output.isAutoencoder()) {
      throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + aModel._key);
    } else if (!aModel._output.isSupervised()) {
      throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + aModel._key);
    } else {
      levelOneFrame.add(aModel._key.toString(), aModelsPredictions.vec("predict"));
    }
  }

  private class StackedEnsembleDriver extends Driver {

    private Frame prepareLevelOneFrame(StackedEnsembleModel.StackedEnsembleParameters parms) {
      // TODO: allow the user to name the level one frame
      Frame levelOneFrame = new Frame(Key.make("levelone_" + _model._key.toString()));
      for (Key k : _parms._base_models) {
        Model aModel = DKV.getGet(k);
        if (null == aModel) {
          Log.warn("Failed to find base model; skipping: " + k);
          continue;
        }

        if (null == aModel._output._cross_validation_holdout_predictions_frame_id)
          throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . .  Looks like keep_cross_validation_predictions wasn't set when building the models.");

        // add the predictions for aModel to levelOneFrame
        // TODO: multinomial classification:
        Frame aModelsPredictions = aModel._output._cross_validation_holdout_predictions_frame_id.get();
        StackedEnsemble.addModelPredictionsToLevelOneFrame(aModel, aModelsPredictions, levelOneFrame);
      } // for all base_models

      levelOneFrame.add(_model.responseColumn, _model._parms.train().vec(_model.responseColumn));

      // TODO: what if we're running multiple in parallel and have a name collision?

      Frame old = DKV.getGet(levelOneFrame._key);
      if (old != null && old instanceof Frame) {
        Frame oldFrame = (Frame)old;
        // Remove ALL the columns so we don't delete them in remove_impl.  Their
        // lifetime is controlled by their model.
        oldFrame.removeAll();
        oldFrame.write_lock(_job);
        oldFrame.update(_job);
        oldFrame.unlock(_job);
      }

      levelOneFrame.delete_and_lock(_job);
      levelOneFrame.unlock(_job);
      Log.info("Finished creating \"level one\" frame for stacking: " + levelOneFrame.toString());
      return levelOneFrame;
    }

    public void computeImpl() {
      init(true);
      if (error_count() > 0)
        throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(StackedEnsemble.this);

      _model = new StackedEnsembleModel(dest(), _parms, new StackedEnsembleModel.StackedEnsembleOutput(StackedEnsemble.this));
      _model.delete_and_lock(_job); // and clear & write-lock it (smashing any prior)

      _model.checkAndInheritModelProperties();

      Frame levelOneFrame = prepareLevelOneFrame(_parms);

      // train the metalearner model
      // TODO: allow types other than GLM
      // Default Job for just this training
      Key metalearnerKey = Key.make("metalearner_" + _model._key);
      Job job = new Job<>(metalearnerKey, ModelBuilder.javaName("glm"), "StackingEnsemble metalearner (GLM)");
      GLM metaBuilder = ModelBuilder.make("GLM", job, metalearnerKey);
      metaBuilder._parms._non_negative = true;
      metaBuilder._parms._train = levelOneFrame._key;
      metaBuilder._parms._response_column = _model.responseColumn;
      // TODO: multinomial
      // TODO: support other families for regression
      metaBuilder._parms._family = _model.modelCategory == ModelCategory.Regression ? GLMModel.GLMParameters.Family.gaussian : GLMModel.GLMParameters.Family.binomial;

      metaBuilder.init(false);

      Job j = metaBuilder.trainModel();

      while (j.isRunning()) {
        try {
          _job.update(j._work, "training metalearner");
          Thread.sleep(100);
        }
        catch (InterruptedException e) {}
      }

      Log.info("Finished training metalearner model.");

      _model._output._metalearner = metaBuilder.get();
      _model.doScoreMetrics(_job);
      // _model._output._model_summary = createModelSummaryTable(model._output);
      _model.update(_job);
      _model.unlock(_job);

    } // computeImpl
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy