All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.ensemble.StackedEnsembleModel Maven / Gradle / Ivy

package hex.ensemble;

import hex.*;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import hex.glm.GLMModel;
import hex.tree.drf.DRFModel;
import water.*;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.Log;
import water.util.MRUtils;
import water.util.ReflectionUtils;
import water.util.TwoDimTable;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashSet;
import java.util.stream.Stream;

import static hex.Model.Parameters.FoldAssignmentScheme.AUTO;
import static hex.Model.Parameters.FoldAssignmentScheme.Random;
import static hex.util.DistributionUtils.familyToDistribution;

/**
 * An ensemble of other models, created by stacking with the SuperLearner algorithm or a variation.
 */
public class StackedEnsembleModel extends Model {

  // common parameters for the base models (keeping them public for backwards compatibility, although it's nonsense)
  public ModelCategory modelCategory;
  public long trainingFrameRows = -1;

  public String responseColumn = null;

  public enum StackingStrategy {
    cross_validation,
    blending
  }

  // TODO: add a separate holdout dataset for the ensemble
  // TODO: add a separate overall cross-validation for the ensemble, including _fold_column and FoldAssignmentScheme / _fold_assignment

  public StackedEnsembleModel(Key selfKey, StackedEnsembleParameters parms, StackedEnsembleOutput output) {
    super(selfKey, parms, output);
  }

  @Override
  public void initActualParamValues() {
    super.initActualParamValues();
    if (_parms._metalearner_fold_assignment == AUTO) {
      _parms._metalearner_fold_assignment = Random;
    }
  }

  @Override
  public boolean haveMojo() {
    return super.haveMojo() 
            && Stream.of(_parms._base_models)
                     .filter(this::isUsefulBaseModel)
                     .map(DKV::getGet)
                     .allMatch(Model::haveMojo);
  }

  public static class StackedEnsembleParameters extends Model.Parameters {
    public String algoName() { return "StackedEnsemble"; }
    public String fullName() { return "Stacked Ensemble"; }
    public String javaName() { return StackedEnsembleModel.class.getName(); }
    @Override public long progressUnits() { return 1; }  // TODO

    // base_models is a list of base model keys to ensemble (must have been cross-validated)
    public Key _base_models[] = new Key[0];
    // Should we keep the level-one frame of cv preds + response col?
    public boolean _keep_levelone_frame = false;
    // internal flag if we want to avoid having the base model predictions rescored multiple times, esp. for blending.
    public boolean _keep_base_model_predictions = false;

    // Metalearner params
    //for stacking using cross-validation
    public int _metalearner_nfolds;
    public Parameters.FoldAssignmentScheme _metalearner_fold_assignment;
    public String _metalearner_fold_column;
    //the training frame used for blending (from which predictions columns are computed)
    public Key _blending;

    public enum MetalearnerTransform {
      NONE,
      Logit;

      public Frame transform(StackedEnsembleModel model, Frame frame, Key destKey) {
        if (this == Logit) {
          return new MRTask() {
            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
              LinkFunction logitLink = LinkFunctionFactory.getLinkFunction(LinkFunctionType.logit);
              for (int c = 0; c < cs.length; c++) {
                for (int i = 0; i < cs[c]._len; i++) {
                  final double p = Math.min(1 - 1e-9, Math.max(cs[c].atd(i), 1e-9)); // 0 and 1 don't work well with logit
                  ncs[c].addNum(logitLink.link(p));
                }
              }
            }
          }.doAll(frame.numCols(), Vec.T_NUM, frame)
                  .outputFrame(destKey, frame._names, null);
        } else {
          throw new RuntimeException();
        }
      }
    }

    public MetalearnerTransform _metalearner_transform = MetalearnerTransform.NONE;

    public Metalearner.Algorithm _metalearner_algorithm = Metalearner.Algorithm.AUTO;
    public String _metalearner_params = new String(); //used for clients code-gen only.
    public Model.Parameters _metalearner_parameters;
    public long _score_training_samples = 10_000;

    /**
     * initialize {@link #_metalearner_parameters} with default parameters for the current {@link #_metalearner_algorithm}.
     */
    public void initMetalearnerParams() {
      initMetalearnerParams(_metalearner_algorithm);
    }

    /**
     * initialize {@link #_metalearner_parameters} with default parameters for the given algorithm
     * @param algo the metalearner algorithm we want to use and for which parameters are initialized.
     */
    public void initMetalearnerParams(Metalearner.Algorithm algo) {
      _metalearner_algorithm = algo;
      _metalearner_parameters = Metalearners.createParameters(algo.name());
    }
    
    public final Frame blending() { return _blending == null ? null : _blending.get(); }

    @Override
    public String[] getNonPredictors() {
      HashSet nonPredictors = new HashSet<>();
      nonPredictors.addAll(Arrays.asList(super.getNonPredictors()));
      if (null != _metalearner_fold_column)
        nonPredictors.add(_metalearner_fold_column);
      return nonPredictors.toArray(new String[0]);
    }

    @Override
    public DistributionFamily getDistributionFamily() {
      if (_metalearner_parameters != null)
        return _metalearner_parameters.getDistributionFamily();
      return super.getDistributionFamily();
    }

    @Override
    public void setDistributionFamily(DistributionFamily distributionFamily) {
      assert _metalearner_parameters != null;
      _metalearner_parameters.setDistributionFamily(distributionFamily);
    }
  }

  public static class StackedEnsembleOutput extends Model.Output {
    public StackedEnsembleOutput() { super(); }
    public StackedEnsembleOutput(StackedEnsemble b) { super(b); }

    public StackedEnsembleOutput(Job job) { _job = job; }
    // The metalearner model (e.g., a GLM that has a coefficient for each of the base_learners).
    public Model _metalearner;
    public Frame _levelone_frame_id; //set only if StackedEnsembleParameters#_keep_levelone_frame=true
    public StackingStrategy _stacking_strategy;
    
    //Set of base model predictions that have been cached in DKV to avoid scoring the same model multiple times,
    // it is then the responsibility of the client code to delete those frames from DKV.
    //This especially useful when building SE models incrementally (e.g. in AutoML).
    //The Set is instantiated and filled only if StackedEnsembleParameters#_keep_base_model_predictions=true.
    public Key[] _base_model_predictions_keys;

    @Override
    public int nfeatures() {
      return super.nfeatures() - (_metalearner._parms._fold_column == null ? 0 : 1);
    }
  }

  /**
   * For StackedEnsemble we call score on all the base_models and then combine the results
   * with the metalearner to create the final predictions frame.
   *
   * @see Model#predictScoreImpl(Frame, Frame, String, Job, boolean, CFuncRef)
   * @param adaptFrm Already adapted frame
   * @param computeMetrics
   * @return A Frame containing the prediction column, and class distribution
   */
  @Override
  protected PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
    final StackedEnsembleParameters.MetalearnerTransform transform; 
    if (_parms._metalearner_transform != null && _parms._metalearner_transform != StackedEnsembleParameters.MetalearnerTransform.NONE) {
      if (!(_output.isBinomialClassifier() || _output.isMultinomialClassifier()))
        throw new H2OIllegalArgumentException("Metalearner transform is supported only for classification!");
      transform = _parms._metalearner_transform;
    } else {
      transform = null;
    }
    final String seKey = this._key.toString();
    final Key levelOneFrameKey = Key.make("preds_levelone_" + seKey + fr._key);
    Frame levelOneFrame = transform == null ?
            new Frame(levelOneFrameKey)  // no transform -> this will be the final frame
            :
            new Frame();        // transform -> this is only an intermediate result

    Model[] usefulBaseModels = Stream.of(_parms._base_models)
            .filter(this::isUsefulBaseModel)
            .map(Key::get)
            .toArray(Model[]::new);

    if (usefulBaseModels.length > 0) {
      Frame[] baseModelPredictions = new Frame[usefulBaseModels.length];

      // Run scoring of base models in parallel
      H2O.submitTask(new LocalMR(new MrFun() {
        @Override
        protected void map(int id) {
          baseModelPredictions[id] = usefulBaseModels[id].score(
                  fr,
                  "preds_base_" + seKey + usefulBaseModels[id]._key + fr._key,
                  j,
                  false
          );
        }
      }, usefulBaseModels.length)).join();

      for (int i = 0; i < usefulBaseModels.length; i++) {
        StackedEnsemble.addModelPredictionsToLevelOneFrame(usefulBaseModels[i], baseModelPredictions[i], levelOneFrame);
        DKV.remove(baseModelPredictions[i]._key); //Cleanup
        Frame.deleteTempFrameAndItsNonSharedVecs(baseModelPredictions[i], levelOneFrame);
      }
    }

    if (transform != null) {
      Frame oldLOF = levelOneFrame;
      levelOneFrame = transform.transform(this, levelOneFrame, levelOneFrameKey);
      oldLOF.remove();
    }
    // Add response column, weights columns to level one frame
    StackedEnsemble.addNonPredictorsToLevelOneFrame(_parms, adaptFrm, levelOneFrame, false);
    // TODO: what if we're running multiple in parallel and have a name collision?
    Log.info("Finished creating \"level one\" frame for scoring: " + levelOneFrame.toString());

    // Score the dataset, building the class distribution & predictions

    Model metalearner = this._output._metalearner;
    Frame predictFr = metalearner.score(
        levelOneFrame, 
        destination_key, 
        j, 
        computeMetrics, 
        CFuncRef.from(_parms._custom_metric_func)
    );
    ModelMetrics mmStackedEnsemble = null;
    if (computeMetrics) {
      // #score has just stored a ModelMetrics object for the (metalearner, preds_levelone) Model/Frame pair.
      // We need to be able to look it up by the (this, fr) pair.
      // The ModelMetrics object for the metalearner will be removed when the metalearner is removed.
      Key[] mms = metalearner._output.getModelMetrics();
      ModelMetrics lastComputedMetric = mms[mms.length - 1].get();
      mmStackedEnsemble = lastComputedMetric.deepCloneWithDifferentModelAndFrame(this, fr);
      this.addModelMetrics(mmStackedEnsemble);
      //now that we have the metric set on the SE model, removing the one we just computed on metalearner (otherwise it leaks in client mode)
      for (Key mm : metalearner._output.clearModelMetrics(true)) {
        DKV.remove(mm);
      }
    }
    Frame.deleteTempFrameAndItsNonSharedVecs(levelOneFrame, adaptFrm);
    return new StackedEnsemblePredictScoreResult(predictFr, mmStackedEnsemble);
  }

  private class StackedEnsemblePredictScoreResult extends PredictScoreResult {
    private final ModelMetrics _modelMetrics;

    public StackedEnsemblePredictScoreResult(Frame preds, ModelMetrics modelMetrics) {
      super(null, preds, preds);
      _modelMetrics = modelMetrics;
    }

    @Override
    public ModelMetrics makeModelMetrics(Frame fr, Frame adaptFrm) {
      return _modelMetrics;
    }

    @Override
    public ModelMetrics.MetricBuilder getMetricBuilder() {
      throw new UnsupportedOperationException("Stacked Ensemble model doesn't implement MetricBuilder infrastructure code, " +
              "retrieve your metrics by calling getOrMakeMetrics method.");
    }
  }
  
  /**
   * Is the baseModel's prediction used in the metalearner?
   *
   * @param baseModelKey
   */
  boolean isUsefulBaseModel(Key baseModelKey) {
    Model metalearner = _output._metalearner;
    assert metalearner != null : "can't use isUsefulBaseModel during training";
    if (modelCategory == ModelCategory.Multinomial) {
      // Multinomial models output multiple columns and a base model
      // might be useful just for one category...
      for (String feature : metalearner._output._names) {
        if (feature.startsWith(baseModelKey.toString().concat("/"))){
          if (metalearner.isFeatureUsedInPredict(feature)) {
            return true;
          }
        }
      }
      return false;
    } else {
      return metalearner.isFeatureUsedInPredict(baseModelKey.toString());
    }
  }

  /**
   * Should never be called: the code paths that normally go here should call predictScoreImpl().
   * @see Model#score0(double[], double[])
   */
  @Override
  protected double[] score0(double data[/*ncols*/], double preds[/*nclasses+1*/]) {
    throw new UnsupportedOperationException("StackedEnsembleModel.score0() should never be called: the code paths that normally go here should call predictScoreImpl().");
  }

  @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    throw new UnsupportedOperationException("StackedEnsembleModel.makeMetricBuilder should never be called!");
  }

  private ModelMetrics doScoreTrainingMetrics(Frame frame, Job job) {
    Frame scoredFrame = (_parms._score_training_samples > 0 && _parms._score_training_samples < frame.numRows())
            ? MRUtils.sampleFrame(frame, _parms._score_training_samples, _parms._seed)
            : frame;
    try {
      Frame adaptedFrame = new Frame(scoredFrame);
      PredictScoreResult result = predictScoreImpl(scoredFrame, adaptedFrame, null, job, true, CFuncRef.from(_parms._custom_metric_func));
      result.getPredictions().delete();
      return result.makeModelMetrics(scoredFrame, adaptedFrame);
    } finally {
      if (scoredFrame != frame) scoredFrame.delete();
    }
  }

  void doScoreOrCopyMetrics(Job job) {
    // To get ensemble training metrics, the training frame needs to be re-scored since
    // training metrics from metalearner are not equal to ensemble training metrics.
    // The training metrics for the metalearner do not reflect true ensemble training metrics because
    // the metalearner was trained on cv preds, not training preds.  So, rather than clone the metalearner
    // training metrics, we have to re-score the training frame on all the base models, then send these
    // biased preds through to the metalearner, and then compute the metrics there.
    //
    // Job set to null since `stop_requested()` due to timeout would invalidate the whole SE at this point
    // which would be unfortunate since this is the last step of SE training and it also should be relatively fast.
    this._output._training_metrics = doScoreTrainingMetrics(this._parms.train(), null);

    // Validation metrics can be copied from metalearner (may be null).
    // Validation frame was already piped through so there's no need to re-do that to get the same results.
    this._output._validation_metrics = this._output._metalearner._output._validation_metrics;
    
    // Cross-validation metrics can be copied from metalearner (may be null).
    // For cross-validation metrics, we use metalearner cross-validation metrics as a proxy for the ensemble
    // cross-validation metrics -- the true k-fold cv metrics for the ensemble would require training k sets of
    // cross-validated base models (rather than a single set of cross-validated base models), which is extremely
    // computationally expensive and awkward from the standpoint of the existing Stacked Ensemble API.
    // More info: https://0xdata.atlassian.net/browse/PUBDEV-3971
    // Need to do DeepClone because otherwise framekey is incorrect (metalearner train is levelone not train)
    if (null != this._output._metalearner._output._cross_validation_metrics) {
      this._output._cross_validation_metrics = this._output._metalearner._output._cross_validation_metrics
              .deepCloneWithDifferentModelAndFrame(this, this._output._metalearner._parms.train());
      this._output._cross_validation_metrics_summary = (TwoDimTable) this._output._metalearner._output._cross_validation_metrics_summary.clone();
    }
  }



  
  public void deleteBaseModelPredictions() {
    if (_output._base_model_predictions_keys != null) {
      for (Key key : _output._base_model_predictions_keys) {
        if (_output._levelone_frame_id != null && key.get() != null)
          Frame.deleteTempFrameAndItsNonSharedVecs(key.get(), _output._levelone_frame_id);
        else
          Keyed.remove(key);
      }
      _output._base_model_predictions_keys = null;
    }
  }

  @Override protected Futures remove_impl(Futures fs, boolean cascade) {
    deleteBaseModelPredictions(); 
    if (_output._metalearner != null)
      _output._metalearner.remove(fs);
    if (_output._levelone_frame_id != null)
      _output._levelone_frame_id.remove(fs);

    return super.remove_impl(fs, cascade);
  }

  /** Write out models (base + metalearner) */
  @Override protected AutoBuffer writeAll_impl(AutoBuffer ab) {
    //Metalearner
    ab.putKey(_output._metalearner._key);
    //Base Models
    for (Key ks : _parms._base_models)
        ab.putKey(ks);
    return super.writeAll_impl(ab);
  }

  /** Read in models (base + metalearner) */
  @Override protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
    //Metalearner
    ab.getKey(_output._metalearner._key,fs);
    //Base Models
    for (Key ks : _parms._base_models)
      ab.getKey(ks,fs);
    return super.readAll_impl(ab,fs);
  }

  @Override
  public StackedEnsembleMojoWriter getMojo() {
    return new StackedEnsembleMojoWriter(this);
  }

  @Override
  public void deleteCrossValidationModels() {
    if (_output._metalearner != null)
      _output._metalearner.deleteCrossValidationModels();
  }

  @Override
  public void deleteCrossValidationPreds() {
    if (_output._metalearner != null)
      _output._metalearner.deleteCrossValidationPreds();
  }

  @Override
  public void deleteCrossValidationFoldAssignment() {
    if (_output._metalearner != null)
      _output._metalearner.deleteCrossValidationFoldAssignment();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy