All downloads are free. Search and download functionality uses the official Maven repository.

hex.ensemble.StackedEnsemble Maven / Gradle / Ivy

package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import jsr166y.CountedCompleter;
import water.DKV;
import water.Job;
import water.Key;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

import java.util.*;
import java.util.stream.Stream;

import static hex.genmodel.utils.DistributionFamily.*;


/**
 * An ensemble of other models, created by stacking with the SuperLearner algorithm or a variation.
 */
public class StackedEnsemble extends ModelBuilder {
  StackedEnsembleDriver _driver;
  // The in-progress model being built
  protected StackedEnsembleModel _model;

  /**
   * Creates a builder for the given parameters and immediately runs {@link #init(boolean)}
   * with {@code expensive == false}.
   *
   * @param parms the Stacked Ensemble parameters, forwarded to the super constructor
   */
  public StackedEnsemble(StackedEnsembleModel.StackedEnsembleParameters parms) {
    super(parms);
    init(false);
  }

  /**
   * Creates a builder with freshly constructed default parameters. Unlike the
   * parameter-taking constructor, this one does not call {@code init}.
   *
   * @param startup_once forwarded unchanged to the super constructor
   *                     (presumably marks the instance registered once at startup — TODO confirm)
   */
  public StackedEnsemble(boolean startup_once) {
    super(new StackedEnsembleModel.StackedEnsembleParameters(), startup_once);
  }

  /**
   * Lists the model categories this builder can produce.
   *
   * @return a new array holding Regression, Binomial and Multinomial, in that order
   */
  @Override
  public ModelCategory[] can_build() {
    final ModelCategory[] supported = {
            ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial
    };
    return supported;
  }

  /**
   * Reports the visibility level of this builder.
   *
   * @return always {@code BuilderVisibility.Stable}
   */
  @Override
  public BuilderVisibility builderVisibility() {
    return BuilderVisibility.Stable;
  }

  /**
   * Stacked Ensemble is a supervised algorithm.
   *
   * @return always {@code true}
   */
  @Override
  public boolean isSupervised() {
    return true;
  }

  /**
   * Drops every column of the training frame that is used neither by a base model nor by the
   * ensemble itself.
   *
   * <p>A column is kept when its name is a base model's response column, one of a base model's
   * non-predictor columns, one of the names recorded in the base model's output
   * ({@code _origNames} when set, otherwise {@code _names}), or one of the ensemble's own
   * non-predictor columns.
   *
   * <p>When no base models have been set yet (H2O Flow initializes SE with no base_models),
   * only the ensemble's own non-predictor columns are kept.
   *
   * @param npredictors not used by this override
   * @param expensive   forwarded to the column filter
   */
  @Override
  protected void ignoreBadColumns(int npredictors, boolean expensive) {
    final Set<String> usedColumns = new HashSet<>();

    // Guard against a null base-model array, consistent with expandBaseModels/validateBaseModels.
    if (_parms._base_models != null) {
      for (Key k : _parms._base_models) {
        Model model = (Model) DKV.getGet(k);
        usedColumns.add(model._parms._response_column);
        usedColumns.addAll(Arrays.asList(model._parms.getNonPredictors()));
        if (model._output._origNames != null) {
          usedColumns.addAll(Arrays.asList(model._output._origNames));
        } else {
          usedColumns.addAll(Arrays.asList(model._output._names));
        }
      }
    }

    usedColumns.addAll(Arrays.asList(_parms.getNonPredictors()));

    // FilterCols(n=0) because there is no guarantee that non-predictors are
    // at the end of the frame, e.g., `metalearner_fold` column can be anywhere,
    // and `usedColumns` contain all used columns even the non-predictor ones
    new FilterCols(0) {
      @Override protected boolean filter(Vec v, String name) {
        return !usedColumns.contains(name);
      }
    }.doIt(_train, "Dropping unused columns: ", expensive);
  }

  /**
   * Creates the driver that trains the ensemble and remembers it in {@code _driver}.
   *
   * @return a cross-validation stacking driver when {@code _parms._blending} is null,
   *         otherwise a blending driver
   */
  @Override
  protected StackedEnsembleDriver trainModelImpl() {
    if (_parms._blending == null) {
      _driver = new StackedEnsembleCVStackingDriver();
    } else {
      _driver = new StackedEnsembleBlendingDriver();
    }
    return _driver;
  }

  /**
   * Reports MOJO support for this builder.
   *
   * @return always {@code true}
   */
  @Override
  public boolean haveMojo() {
    return true;
  }

  /**
   * Determines the number of classes, taking the metalearner's distribution into account
   * when metalearner parameters are provided.
   *
   * @return the super implementation's value when no metalearner parameters are set; otherwise
   *         {@code _nclass} for multinomial, ordinal and AUTO, 2 for bernoulli, quasibinomial
   *         and fractionalbinomial, and 1 for any other distribution
   */
  @Override
  public int nclasses() {
    if (_parms._metalearner_parameters == null) {
      return super.nclasses();
    }

    final DistributionFamily family = _parms._metalearner_parameters.getDistributionFamily();
    if (family == multinomial || family == ordinal || family == AUTO) {
      return _nclass;
    }
    if (family == bernoulli || family == quasibinomial || family == fractionalbinomial) {
      return 2;
    }
    return 1;
  }

  /**
   * Validates the ensemble parameters.
   *
   * <p>Grid ids among the base models are expanded first, then the super implementation runs.
   * Afterwards the distribution setting, the presence of the fold, weights and offset columns
   * in the train, validation and blending frames, and the base models themselves are checked.
   *
   * @param expensive forwarded to the super implementation
   * @throws H2OIllegalArgumentException if {@code distribution} is set to anything but AUTO
   * @throws IllegalArgumentException    if a base model id or a special column is invalid
   */
  @Override
  public void init(boolean expensive) {
    expandBaseModels();
    super.init(expensive);

    if (DistributionFamily.AUTO != _parms._distribution) {
      throw new H2OIllegalArgumentException("Setting \"distribution\" to StackedEnsemble is unsupported. Please set it in \"metalearner_parameters\".");
    }

    // {label used in the error message, configured column name (may be null)}
    final String[][] specialColumns = {
            {"fold", _parms._metalearner_fold_column},
            {"weights", _parms._weights_column},
            {"offset", _parms._offset_column},
    };
    for (String[] column : specialColumns) {
      checkColumnPresent(column[0], column[1], train(), valid(), _parms.blending());
    }

    validateBaseModels();
  }

  /**
   * Expands the base models: every id that refers to a grid is replaced by the keys of the
   * grid's models, ids that refer to a model are kept as they are. The result is written back
   * to {@code _parms._base_models}. Does nothing when no base models are set.
   *
   * @throws IllegalArgumentException if an id does not exist or refers to something that is
   *                                  neither a model nor a grid
   */
  private void expandBaseModels() {
    // H2O Flow initializes SE with no base_models
    if (_parms._base_models == null) return;

    // Typed list: on a raw List, toArray(new Key[0]) would be typed Object[].
    final List<Key> expandedKeys = new ArrayList<>();
    for (Key baseModelKey : _parms._base_models) {
      final Object retrievedObject = DKV.getGet(baseModelKey);
      if (retrievedObject == null) {
        throw new IllegalArgumentException(String.format("Specified id \"%s\" does not exist.", baseModelKey));
      }
      if (retrievedObject instanceof Model) {
        expandedKeys.add(baseModelKey);
      } else if (retrievedObject instanceof Grid) {
        Collections.addAll(expandedKeys, ((Grid) retrievedObject).getModelKeys());
      } else {
        throw new IllegalArgumentException(String.format("Unsupported type \"%s\" as a base model.", retrievedObject.getClass().toString()));
      }
    }
    _parms._base_models = expandedKeys.toArray(new Key[0]);
  }

  /**
   * Validates the base models against the ensemble's offset and weights settings.
   *
   * <p>If the ensemble has no offset column, it inherits the one of the first base model.
   * Every base model must then use the same offset column as the ensemble. A warning is
   * issued when all base models share one (non-null) weights column but the ensemble has none.
   * Does nothing when no base models are set.
   *
   * @throws IllegalArgumentException if a base model uses a different offset column
   */
  private void validateBaseModels() {
    // H2O Flow initializes SE with no base_models
    if (_parms._base_models == null) {
      return;
    }

    String referenceWeightsColumn = null;
    boolean warnSameWeightsColumns = true;
    boolean isFirst = true;
    for (Key baseModelKey : _parms._base_models) {
      final Model baseModel = DKV.getGet(baseModelKey);

      if (isFirst) {
        isFirst = false;
        if (_parms._offset_column == null) {
          _parms._offset_column = baseModel._parms._offset_column;
        }
        referenceWeightsColumn = baseModel._parms._weights_column;
        // We don't want to warn if no weights are set
        warnSameWeightsColumns = referenceWeightsColumn != null;
      }

      final boolean sameWeights = Objects.equals(referenceWeightsColumn, baseModel._parms._weights_column);
      if (!sameWeights) {
        warnSameWeightsColumns = false;
      }

      final boolean sameOffset = Objects.equals(_parms._offset_column, baseModel._parms._offset_column);
      if (!sameOffset) {
        throw new IllegalArgumentException("All base models must have the same offset_column!");
      }
    }

    final boolean ensembleHasNoWeights = _parms._weights_column == null;
    if (ensembleHasNoWeights && warnSameWeightsColumns && _parms._base_models.length > 0) {
      warn("_weights_column", "All base models use weights_column=\"" + referenceWeightsColumn +
              "\" but Stacked Ensemble does not. If you want to use the same " +
              "weights_column for the meta learner, please specify it as an argument " +
              "in the h2o.stackedEnsemble call.");
    }
  }


  /**
   * Verifies that a column exists in each of the given {@link Frame}s. A null column id
   * disables the check, and null frames are skipped.
   *
   * @param columnName role of the column (such as fold, weights, offset); used in the error message only
   * @param columnId   actual column name in the frames; null means no column has been specified
   * @param frames     the frames that must contain the column; individual entries may be null
   * @throws IllegalArgumentException if a non-null frame does not contain the column
   */
  private static void checkColumnPresent(final String columnName, final String columnId, final Frame... frames) {
    if (columnId == null) {
      return; // Unspecified column implies no checks are needed on the provided frames
    }

    for (final Frame candidate : frames) {
      // A missing frame requires no check
      final boolean columnMissing = candidate != null && candidate.vec(columnId) == null;
      if (columnMissing) {
        throw new IllegalArgumentException(String.format("Specified %s column '%s' not found in one of the supplied data frames. Available column names are: %s",
                columnName, columnId, Arrays.toString(candidate.names())));
      }
    }
  }

  /**
   * Adds the predictions of one base model to the level-one frame.
   *
   * <ul>
   *   <li>binomial models contribute the prediction column at index 2, named after the model key;</li>
   *   <li>multinomial models contribute every column except {@code predict}, each renamed to
   *       {@code <model key>/<column name>};</li>
   *   <li>all other supervised models contribute the {@code predict} column, named after the model key.</li>
   * </ul>
   *
   * @param aModel             the base model that produced the predictions
   * @param aModelsPredictions the predictions frame of {@code aModel}
   * @param levelOneFrame      the frame the prediction columns are added to
   * @throws H2OIllegalArgumentException for autoencoders and unsupervised models
   */
  static void addModelPredictionsToLevelOneFrame(Model aModel, Frame aModelsPredictions, Frame levelOneFrame) {
    if (aModel._output.isBinomialClassifier()) {
      // Picked by position: GLM uses a different column name than the other algos
      final Vec positiveClassPreds = aModelsPredictions.vec(2);
      levelOneFrame.add(aModel._key.toString(), positiveClassPreds);
      return;
    }

    if (aModel._output.isMultinomialClassifier()) {
      // The 'predict' column contains the outcome, so only the remaining columns are stacked
      final Frame probabilities = aModelsPredictions.subframe(ArrayUtils.remove(aModelsPredictions.names(), "predict"));
      final String[] originalNames = probabilities.names();
      final String[] prefixedNames = new String[originalNames.length];
      for (int i = 0; i < originalNames.length; i++) {
        prefixedNames[i] = aModel._key.toString().concat("/").concat(originalNames[i]);
      }
      probabilities.setNames(prefixedNames);
      levelOneFrame.add(probabilities);
      return;
    }

    if (aModel._output.isAutoencoder()) {
      throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + aModel._key);
    }
    if (!aModel._output.isSupervised()) {
      throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + aModel._key);
    }

    final Vec predictions = aModelsPredictions.vec("predict");
    levelOneFrame.add(aModel._key.toString(), predictions);
  }

  /**
   * Adds the non-predictor columns to the level-one frame, i.e., all but those generated by
   * the base models. Columns are added in this order: metalearner fold column (training only),
   * weights column, offset column, response column. Unset (null) optional columns are skipped;
   * the response column is always added.
   *
   * @param parms         StackedEnsembleParameters naming the columns
   * @param fr            frame the column vectors are taken from
   * @param levelOneFrame frame the columns are added to
   * @param training      when true, the metalearner fold column is added as well
   */
  static void addNonPredictorsToLevelOneFrame(final StackedEnsembleModel.StackedEnsembleParameters parms, Frame fr, Frame levelOneFrame, boolean training) {
    final String foldColumn = parms._metalearner_fold_column;
    if (training && foldColumn != null) {
      levelOneFrame.add(foldColumn, fr.vec(foldColumn));
    }

    final String weightsColumn = parms._weights_column;
    if (weightsColumn != null) {
      levelOneFrame.add(weightsColumn, fr.vec(weightsColumn));
    }

    final String offsetColumn = parms._offset_column;
    if (offsetColumn != null) {
      levelOneFrame.add(offsetColumn, fr.vec(offsetColumn));
    }

    final String responseColumn = parms._response_column;
    levelOneFrame.add(responseColumn, fr.vec(responseColumn));
  }

  private abstract class StackedEnsembleDriver extends Driver {

    /**
     * Prepare a "level one" frame for a given set of models, predictions-frames and actuals.  Used for preparing
     * training and validation frames for the metalearning step, and could also be used for bulk predictions for
     * a StackedEnsemble.
     *
     * <p>The prediction columns of every base model are added first (optionally run through the
     * configured metalearner transform), then the non-predictor columns are copied from
     * {@code actuals}. The resulting frame is put into the DKV before it is returned.
     *
     * @param levelOneKey          name of the resulting frame; when null a name is derived from the model key
     *                             and the metalearner transform
     * @param baseModels           base models; null entries are skipped with a warning
     * @param baseModelPredictions predictions of the base models, index-aligned with {@code baseModels};
     *                             null entries are skipped with a warning
     * @param actuals              frame the non-predictor columns (fold, weights, offset, response) are taken from
     * @return the level-one frame
     * @throws H2OIllegalArgumentException if an array is null or empty, the arrays differ in length, or a
     *                                     metalearner transform is requested for a non-classification model
     */
    private Frame prepareLevelOneFrame(String levelOneKey, Model[] baseModels, Frame[] baseModelPredictions, Frame actuals) {
      if (null == baseModels) throw new H2OIllegalArgumentException("Base models array is null.");
      if (null == baseModelPredictions) throw new H2OIllegalArgumentException("Base model predictions array is null.");
      if (baseModels.length == 0) throw new H2OIllegalArgumentException("Base models array is empty.");
      if (baseModelPredictions.length == 0)
        throw new H2OIllegalArgumentException("Base model predictions array is empty.");
      if (baseModels.length != baseModelPredictions.length)
        throw new H2OIllegalArgumentException("Base models and prediction arrays are different lengths.");

      // `transform` stays null both when no transform is configured and when it is NONE
      final StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform transform;
      if (_parms._metalearner_transform != null && _parms._metalearner_transform != StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform.NONE) {
        if (!(_model._output.isBinomialClassifier() || _model._output.isMultinomialClassifier()))
          throw new H2OIllegalArgumentException("Metalearner transform is supported only for classification!");
        transform = _parms._metalearner_transform;
      } else {
        transform = null;
      }

      if (null == levelOneKey) {
        // An unset transform is named like NONE instead of failing with a NullPointerException
        levelOneKey = "levelone_" + _model._key.toString() + "_"
                + (transform == null ? StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform.NONE : transform);
      }

      // TODO: what if we're running multiple in parallel and have a name collision?
      Frame old = DKV.getGet(levelOneKey);
      if (old != null) {
        old.write_lock(_job);
        // Remove ALL the columns, so we don't delete them in remove_impl.  Their
        // lifetime is controlled by their model.
        old.removeAll();
        old.update(_job);
        old.unlock(_job);
      }

      Frame levelOneFrame = transform == null ?
              new Frame(Key.make(levelOneKey))  // no transform -> this will be the final frame
              :
              new Frame();                      // transform -> this is only an intermediate result

      for (int i = 0; i < baseModels.length; i++) {
        Model baseModel = baseModels[i];
        Frame baseModelPreds = baseModelPredictions[i];

        if (null == baseModel) {
          Log.warn("Failed to find base model; skipping: " + baseModels[i]);
          continue;
        }
        if (null == baseModelPreds) {
          // The predictions frame is null here, so identify the skipped entry by the model key
          Log.warn("Failed to find predictions for base model; skipping: " + baseModel._key);
          continue;
        }
        StackedEnsemble.addModelPredictionsToLevelOneFrame(baseModel, baseModelPreds, levelOneFrame);
        Scope.untrack(baseModelPredictions);
      }

      if (transform != null) {
        levelOneFrame = transform.transform(_model, levelOneFrame, Key.make(levelOneKey));
      }

      // Add metalearner fold column, weights column, offset column and response to the level one frame
      addNonPredictorsToLevelOneFrame(_model._parms, actuals, levelOneFrame, true);

      Log.info("Finished creating \"level one\" frame for stacking: " + levelOneFrame.toString());
      DKV.put(levelOneFrame);
      return levelOneFrame;
    }

    /**
     * Prepare a "level one" frame for a given set of models and actuals.
     * Used for preparing validation frames for the metalearning step, and could also be used for bulk predictions for a StackedEnsemble.
     *
     * <p>Base models are skipped once a metalearner exists and the model reports them as not useful.
     * When {@code isTraining} is true and {@code _parms._keep_levelone_frame} is set, a deep copy of the
     * level-one frame is returned and its keys are removed from Scope tracking.
     *
     * @param levelOneKey   name of the resulting frame
     * @param baseModelKeys keys of the base models to collect predictions from
     * @param actuals       frame to score and to take the non-predictor columns from
     * @param isTraining    whether the frame is prepared for training the metalearner
     * @return the level-one frame
     * @throws H2OIllegalArgumentException if a base model cannot be found
     */
    private Frame prepareLevelOneFrame(String levelOneKey, Key[] baseModelKeys, Frame actuals, boolean isTraining) {
      // Typed lists: on raw Lists, toArray(...) below would be typed Object[].
      final List<Model> baseModels = new ArrayList<>();
      final List<Frame> baseModelPredictions = new ArrayList<>();

      for (Key k : baseModelKeys) {
        if (_model._output._metalearner == null || _model.isUsefulBaseModel(k)) {
          Model aModel = DKV.getGet(k);
          if (null == aModel)
            throw new H2OIllegalArgumentException("Failed to find base model: " + k);

          Frame predictions = getPredictionsForBaseModel(aModel, actuals, isTraining);
          baseModels.add(aModel);
          baseModelPredictions.add(predictions);
        }
      }

      boolean keepLevelOneFrame = isTraining && _parms._keep_levelone_frame;
      Frame levelOneFrame = prepareLevelOneFrame(levelOneKey, baseModels.toArray(new Model[0]), baseModelPredictions.toArray(new Frame[0]), actuals);
      if (keepLevelOneFrame) {
        levelOneFrame = levelOneFrame.deepCopy(levelOneFrame._key.toString());
        levelOneFrame.write_lock(_job);
        levelOneFrame.update(_job);
        levelOneFrame.unlock(_job);
        Scope.untrack(levelOneFrame.keysList());
      }
      return levelOneFrame;
    }

    /**
     * Deletes the in-progress model, if any, before delegating to the super implementation.
     *
     * @param ex     the failure
     * @param caller the completer the failure is reported for
     * @return whatever the super implementation returns
     */
    @Override
    public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
      final boolean hasModel = _model != null;
      if (hasModel) {
        _model.delete();
      }
      return super.onExceptionalCompletion(ex, caller);
    }

    /**
     * Returns the predictions of a base model on a frame, scoring the frame only when no
     * predictions are stored in the DKV under the derived predictions key yet. The key is
     * recorded in the ensemble output's {@code _base_model_predictions_keys} (once).
     *
     * @param model the base model
     * @param frame the frame to score
     * @return the predictions frame
     */
    protected Frame buildPredictionsForBaseModel(Model model, Frame frame) {
      final Key predsKey = buildPredsKey(model, frame);

      Frame predictions = DKV.getGet(predsKey);
      if (predictions == null) {
        // no need for metrics here (leaks in client mode)
        predictions = model.score(frame, predsKey.toString(), null, false);
        Scope.untrack(predictions.keysList());
      }

      if (_model._output._base_model_predictions_keys == null) {
        _model._output._base_model_predictions_keys = new Key[0];
      }
      final boolean alreadyRecorded = ArrayUtils.contains(_model._output._base_model_predictions_keys, predsKey);
      if (!alreadyRecorded) {
        _model._output._base_model_predictions_keys = ArrayUtils.append(_model._output._base_model_predictions_keys, predsKey);
      }

      // predictions are cleaned up by metalearner if necessary
      return predictions;
    }

    /**
     * @return the stacking strategy implemented by this driver; {@code computeImpl} stores it
     *         in the model output
     */
    protected abstract StackedEnsembleModel.StackingStrategy strategy();

    /**
     * @return the frame that is used to compute the predictions for the level-one training frame.
     */
    protected abstract Frame getActualTrainingFrame();

    /**
     * Provides the predictions of one base model that go into the level-one frame.
     *
     * @param model           the base model
     * @param actualsFrame    the frame the level-one frame is being prepared for
     * @param isTrainingFrame true when the level-one training frame is being prepared,
     *                        false for the validation frame
     * @return the predictions frame of {@code model}
     */
    protected abstract Frame getPredictionsForBaseModel(Model model, Frame actualsFrame, boolean isTrainingFrame);

    /**
     * Builds the key under which predictions are stored. Only the two checksums go into the
     * key name; {@code model_key} and {@code frame_key} are currently unused.
     *
     * @param model_key      key of the model (unused)
     * @param model_checksum checksum of the model
     * @param frame_key      key of the scored frame (unused)
     * @param frame_checksum checksum of the scored frame
     * @return the key named {@code preds_<model_checksum>_on_<frame_checksum>}
     */
    private Key buildPredsKey(Key model_key, long model_checksum, Key frame_key, long frame_checksum) {
      final String keyName = "preds_" + model_checksum + "_on_" + frame_checksum;
      return Key.make(keyName);
    }

    /**
     * Builds the predictions key for a model/frame pair from their checksums.
     *
     * @param model the model producing the predictions
     * @param frame the frame being scored
     * @return the predictions key, or null when {@code frame} or {@code model} is null
     */
    protected Key buildPredsKey(Model model, Frame frame) {
      if (frame == null || model == null) {
        return null;
      }
      return buildPredsKey(model._key, model.checksum(), frame._key, frame.checksum());
    }

    /**
     * Trains the ensemble: validates the parameters, creates the ensemble model, prepares the
     * level-one training frame (and the validation one when a validation frame is set) and
     * finally trains the metalearner on them.
     *
     * @throws H2OModelBuilderIllegalArgumentException if parameter validation reported errors
     * @throws H2OIllegalArgumentException             if the metalearner algorithm cannot be resolved
     */
    public void computeImpl() {
      init(true);
      if (error_count() > 0) {
        throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(StackedEnsemble.this);
      }

      _model = new StackedEnsembleModel(dest(), _parms, new StackedEnsembleModel.StackedEnsembleOutput(StackedEnsemble.this));
      _model._output._stacking_strategy = strategy();
      try {
        _model.delete_and_lock(_job); // and clear & write-lock it (smashing any prior)
        _model.checkAndInheritModelProperties();
        _model.update(_job);
      } finally {
        _model.unlock(_job);
      }

      // Level-one frames: training always, validation only if a validation frame was given
      final String levelOneTrainKey = "levelone_training_" + _model._key.toString();
      final Frame levelOneTrainingFrame = prepareLevelOneFrame(levelOneTrainKey, _model._parms._base_models, getActualTrainingFrame(), true);

      Frame levelOneValidationFrame = null;
      if (_model._parms.valid() != null) {
        final String levelOneValidKey = "levelone_validation_" + _model._key.toString();
        levelOneValidationFrame = prepareLevelOneFrame(levelOneValidKey, _model._parms._base_models, _model._parms.valid(), false);
      }

      final Metalearner.Algorithm requestedAlgo = _model._parms._metalearner_algorithm;
      final Metalearner.Algorithm resolvedAlgo = Metalearners.getActualMetalearnerAlgo(requestedAlgo);
      if (resolvedAlgo == null) {
        throw new H2OIllegalArgumentException("Invalid `metalearner_algorithm`. Passed in " + requestedAlgo +
                " but must be one of " + Arrays.toString(Metalearner.Algorithm.values()));
      }

      // Compute metalearner
      final Key metalearnerKey = Key.make("metalearner_" + requestedAlgo + "_" + _model._key);
      final Job metalearnerJob = new Job<>(metalearnerKey, ModelBuilder.javaName(resolvedAlgo.toString()),
              "StackingEnsemble metalearner (" + requestedAlgo + ")");
      // Check if metalearner_params are passed in
      final boolean metalearnerParamsGiven = _model._parms._metalearner_parameters != null;
      final long seed = _model._parms._seed;

      final Metalearner metalearner = Metalearners.createInstance(requestedAlgo.name());
      metalearner.init(
              levelOneTrainingFrame,
              levelOneValidationFrame,
              _model._parms._metalearner_parameters,
              _model,
              _job,
              metalearnerKey,
              metalearnerJob,
              _parms,
              metalearnerParamsGiven,
              seed,
              _parms._max_runtime_secs == 0 ? 0 : Math.max(remainingTimeSecs(), 1)
      );
      metalearner.compute();

      // Report the algorithm that was actually used when AUTO was requested
      if (_model.evalAutoParamsEnabled && _model._parms._metalearner_algorithm == Metalearner.Algorithm.AUTO) {
        _model._parms._metalearner_algorithm = resolvedAlgo;
      }
    } // computeImpl
  }

  /**
   * Driver for the cross-validation stacking strategy: the level-one training frame is built
   * from the cross-validation holdout predictions kept by the base models, while validation
   * predictions are obtained by scoring.
   */
  private class StackedEnsembleCVStackingDriver extends StackedEnsembleDriver {

    /** @return {@code StackingStrategy.cross_validation} */
    @Override
    protected StackedEnsembleModel.StackingStrategy strategy() {
      return StackedEnsembleModel.StackingStrategy.cross_validation;
    }

    /** @return the training frame of the ensemble parameters */
    @Override
    protected Frame getActualTrainingFrame() {
      return _model._parms.train();
    }

    /**
     * For training, retrieves the cross-validation holdout predictions frame of the base model;
     * otherwise builds the predictions on {@code actualsFrame}.
     *
     * @throws H2OIllegalArgumentException if the holdout predictions frame id or the frame itself is missing
     */
    @Override
    protected Frame getPredictionsForBaseModel(Model model, Frame actualsFrame, boolean isTraining) {
      if (!isTraining) {
        return buildPredictionsForBaseModel(model, actualsFrame);
      }

      // for training, retrieve predictions from cv holdout predictions frame as all base models are required to get built with keep_cross_validation_frame=true
      if (model._output._cross_validation_holdout_predictions_frame_id == null) {
        throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . .  Looks like keep_cross_validation_predictions wasn't set when building the models.");
      }

      final Frame holdoutPredictions = DKV.getGet(model._output._cross_validation_holdout_predictions_frame_id);
      if (holdoutPredictions == null) {
        throw new H2OIllegalArgumentException("Failed to find the xval predictions frame. . .  Looks like keep_cross_validation_predictions wasn't set when building the models, or the frame was deleted.");
      }
      return holdoutPredictions;
    }

  }

  /**
   * Driver for the blending strategy: base model predictions are always obtained by scoring,
   * and the level-one training frame is computed from the blending frame.
   */
  private class StackedEnsembleBlendingDriver extends StackedEnsembleDriver {

    /** @return {@code StackingStrategy.blending} */
    @Override
    protected StackedEnsembleModel.StackingStrategy strategy() {
      return StackedEnsembleModel.StackingStrategy.blending;
    }

    /** @return the blending frame of the ensemble parameters */
    @Override
    protected Frame getActualTrainingFrame() {
      return _model._parms.blending();
    }

    /**
     * Builds the predictions of the base model on {@code actualsFrame}.
     *
     * @throws Job.JobCancelledException if a stop was requested while preparing the training frame
     */
    @Override
    protected Frame getPredictionsForBaseModel(Model model, Frame actualsFrame, boolean isTrainingFrame) {
      // if training we can stop prematurely due to a timeout but computing validation scores should be allowed to finish
      if (stop_requested()) {
        if (isTrainingFrame) {
          throw new Job.JobCancelledException();
        }
      }
      return buildPredictionsForBaseModel(model, actualsFrame);
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy