All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.ensemble.StackedEnsembleModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.ensemble;

import hex.*;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import hex.glm.GLMModel;
import hex.tree.drf.DRFModel;
import water.*;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.Log;
import water.util.MRUtils;
import water.util.ReflectionUtils;
import water.util.TwoDimTable;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashSet;
import java.util.stream.Stream;

import static hex.Model.Parameters.FoldAssignmentScheme.AUTO;
import static hex.Model.Parameters.FoldAssignmentScheme.Random;
import static hex.util.DistributionUtils.familyToDistribution;

/**
 * An ensemble of other models, created by stacking with the SuperLearner algorithm or a variation.
 */
public class StackedEnsembleModel extends Model {

  // common parameters for the base models (keeping them public for backwards compatibility, although it's nonsense)
  // Model category of the base models; checkAndInheritModelProperties() copies it from the first base model it finds.
  public ModelCategory modelCategory;
  // Row count of the ensemble's training frame; stays -1 until checkAndInheritModelProperties() needs it
  // (it is only filled in when consistent training frames are required).
  public long trainingFrameRows = -1;

  // Response column of the base models; checkAndInheritModelProperties() copies it from the first base model it finds.
  public String responseColumn = null;

  /**
   * How the data used to train the metalearner is obtained; stored in StackedEnsembleOutput#_stacking_strategy.
   */
  public enum StackingStrategy {
    // presumably: metalearner is trained on the base models' cross-validation predictions -- TODO confirm in StackedEnsemble
    cross_validation,
    // presumably: metalearner is trained on base model predictions over the separate blending frame -- TODO confirm in StackedEnsemble
    blending
  }

  // TODO: add a separate holdout dataset for the ensemble
  // TODO: add a separate overall cross-validation for the ensemble, including _fold_column and FoldAssignmentScheme / _fold_assignment

  /**
   * Creates the ensemble model by passing all three arguments straight through to the {@link Model} constructor.
   *
   * @param selfKey key of this model
   * @param parms   ensemble parameters
   * @param output  ensemble output
   */
  public StackedEnsembleModel(Key selfKey, StackedEnsembleParameters parms, StackedEnsembleOutput output) {
    super(selfKey, parms, output);
  }

  /**
   * Resolves the actual parameter values: after the superclass did its part, a metalearner fold assignment
   * left at {@code AUTO} is replaced by {@code Random}.
   */
  @Override
  public void initActualParamValues() {
    super.initActualParamValues();
    // AUTO is reported as Random for the metalearner fold assignment
    if (_parms._metalearner_fold_assignment == AUTO) {
      _parms._metalearner_fold_assignment = Random;
    }
  }

  /**
   * A MOJO can be produced only if the superclass allows it and every base model that the
   * metalearner actually uses can produce a MOJO as well.
   *
   * @return true when the ensemble and all of its useful base models support MOJO export
   */
  @Override
  public boolean haveMojo() {
    if (!super.haveMojo()) {
      return false;
    }
    for (Key baseModelKey : _parms._base_models) {
      // base models ignored by the metalearner do not matter
      if (!isUsefulBaseModel(baseModelKey)) {
        continue;
      }
      Model baseModel = DKV.getGet(baseModelKey);
      if (!baseModel.haveMojo()) {
        return false;
      }
    }
    return true;
  }

  /**
   * Parameters of a Stacked Ensemble: the base models to combine and the configuration of the metalearner
   * that is trained on top of their predictions.
   */
  public static class StackedEnsembleParameters extends Model.Parameters {
    public String algoName() { return "StackedEnsemble"; }
    public String fullName() { return "Stacked Ensemble"; }
    public String javaName() { return StackedEnsembleModel.class.getName(); }
    @Override public long progressUnits() { return 1; }  // TODO

    // base_models is a list of base model keys to ensemble (must have been cross-validated)
    // (generic array creation is not allowed, hence the raw "new Key[0]")
    public Key<Model> _base_models[] = new Key[0];
    // Should we keep the level-one frame of cv preds + response col?
    public boolean _keep_levelone_frame = false;
    // internal flag if we want to avoid having the base model predictions rescored multiple times, esp. for blending.
    public boolean _keep_base_model_predictions = false;

    // Metalearner params
    //for stacking using cross-validation
    public int _metalearner_nfolds;
    public Parameters.FoldAssignmentScheme _metalearner_fold_assignment;
    public String _metalearner_fold_column;
    //the training frame used for blending (from which predictions columns are computed)
    public Key<Frame> _blending;

    /**
     * Optional transformation applied to the base model predictions before they are handed to the metalearner.
     */
    public enum MetalearnerTransform {
      NONE,
      Logit;

      /**
       * Builds a transformed copy of {@code frame} stored under {@code destKey}, keeping the column names.
       * Only {@link #Logit} is implemented: every value is clamped to [1e-9, 1 - 1e-9] and then passed
       * through the logit link function.
       *
       * @param model   the ensemble requesting the transformation (not read by this implementation)
       * @param frame   frame whose columns are all transformed
       * @param destKey key of the resulting frame
       * @return the transformed frame, with as many numeric columns as {@code frame}
       * @throws UnsupportedOperationException if this transform has no frame transformation (i.e. {@link #NONE})
       */
      public Frame transform(StackedEnsembleModel model, Frame frame, Key destKey) {
        if (this != Logit) {
          throw new UnsupportedOperationException("Metalearner transform " + this + " cannot be applied to a frame.");
        }
        return new MRTask() {
          @Override
          public void map(Chunk[] cs, NewChunk[] ncs) {
            LinkFunction logitLink = LinkFunctionFactory.getLinkFunction(LinkFunctionType.logit);
            for (int c = 0; c < cs.length; c++) {
              for (int i = 0; i < cs[c]._len; i++) {
                final double p = Math.min(1 - 1e-9, Math.max(cs[c].atd(i), 1e-9)); // 0 and 1 don't work well with logit
                ncs[c].addNum(logitLink.link(p));
              }
            }
          }
        }.doAll(frame.numCols(), Vec.T_NUM, frame)
                .outputFrame(destKey, frame._names, null);
      }
    }

    public MetalearnerTransform _metalearner_transform = MetalearnerTransform.NONE;

    public Metalearner.Algorithm _metalearner_algorithm = Metalearner.Algorithm.AUTO;
    public String _metalearner_params = ""; //used for clients code-gen only.
    public Model.Parameters _metalearner_parameters;
    public long _score_training_samples = 10_000;

    /**
     * initialize {@link #_metalearner_parameters} with default parameters for the current {@link #_metalearner_algorithm}.
     */
    public void initMetalearnerParams() {
      initMetalearnerParams(_metalearner_algorithm);
    }

    /**
     * initialize {@link #_metalearner_parameters} with default parameters for the given algorithm
     * @param algo the metalearner algorithm we want to use and for which parameters are initialized.
     */
    public void initMetalearnerParams(Metalearner.Algorithm algo) {
      _metalearner_algorithm = algo;
      _metalearner_parameters = Metalearners.createParameters(algo.name());
    }

    /**
     * @return the blending frame fetched through its key, or null when no blending frame is set
     */
    public final Frame blending() { return _blending == null ? null : _blending.get(); }

    /**
     * @return the non-predictor columns reported by the superclass, plus the metalearner fold column when one is set
     */
    @Override
    public String[] getNonPredictors() {
      // typed set: with a raw HashSet, toArray(new String[0]) would be typed as Object[]
      HashSet<String> nonPredictors = new HashSet<>(Arrays.asList(super.getNonPredictors()));
      if (null != _metalearner_fold_column)
        nonPredictors.add(_metalearner_fold_column);
      return nonPredictors.toArray(new String[0]);
    }

    /**
     * @return the distribution family of the metalearner parameters when they exist, the superclass value otherwise
     */
    @Override
    public DistributionFamily getDistributionFamily() {
      if (_metalearner_parameters != null)
        return _metalearner_parameters.getDistributionFamily();
      return super.getDistributionFamily();
    }

    /**
     * Sets the distribution family on the metalearner parameters, which must have been initialized beforehand.
     *
     * @param distributionFamily the family to use for the metalearner
     */
    @Override
    public void setDistributionFamily(DistributionFamily distributionFamily) {
      assert _metalearner_parameters != null;
      _metalearner_parameters.setDistributionFamily(distributionFamily);
    }
  }

  /**
   * Output of a Stacked Ensemble: the trained metalearner plus the optional frames kept from training.
   */
  public static class StackedEnsembleOutput extends Model.Output {
    public StackedEnsembleOutput() { super(); }
    public StackedEnsembleOutput(StackedEnsemble b) { super(b); }

    public StackedEnsembleOutput(Job job) { _job = job; }
    // The metalearner model (e.g., a GLM that has a coefficient for each of the base_learners).
    // Null until it has been set; other code in this file checks for null before using it.
    public Model _metalearner;
    public Frame _levelone_frame_id; //set only if StackedEnsembleParameters#_keep_levelone_frame=true
    public StackingStrategy _stacking_strategy;
    
    //Set of base model predictions that have been cached in DKV to avoid scoring the same model multiple times,
    // it is then the responsibility of the client code to delete those frames from DKV.
    //This especially useful when building SE models incrementally (e.g. in AutoML).
    //The Set is instantiated and filled only if StackedEnsembleParameters#_keep_base_model_predictions=true.
    public Key<Frame>[] _base_model_predictions_keys;

    /**
     * @return the superclass feature count, reduced by one when the metalearner uses a fold column;
     *         while no metalearner is set, the superclass count is returned unchanged
     */
    @Override
    public int nfeatures() {
      final int nfeatures = super.nfeatures();
      if (_metalearner == null) {
        // no metalearner yet: nothing to subtract, and dereferencing it would throw a NullPointerException
        return nfeatures;
      }
      return nfeatures - (_metalearner._parms._fold_column == null ? 0 : 1);
    }
  }

  /**
   * For StackedEnsemble we call score on all the base_models and then combine the results
   * with the metalearner to create the final predictions frame.
   *
   * @see Model#predictScoreImpl(Frame, Frame, String, Job, boolean, CFuncRef)
   * @param fr the frame scored by each useful base model; computed metrics are re-keyed to the (this, fr) pair
   * @param adaptFrm Already adapted frame; it is passed to StackedEnsemble.addNonPredictorsToLevelOneFrame
   * @param destination_key key passed to the metalearner for the predictions frame it produces
   * @param j job passed to the base model scoring and to the metalearner scoring
   * @param computeMetrics whether the metalearner computes metrics, which are then cloned onto this model
   * @param customMetricFunc not read by this implementation: the metalearner is scored with
   *                         {@code CFuncRef.from(_parms._custom_metric_func)}
   * @return A Frame containing the prediction column, and class distribution, wrapped together with the
   *         cloned metrics (null metrics when computeMetrics is false)
   * @throws H2OIllegalArgumentException if a metalearner transform is requested for a model that is neither
   *         a binomial nor a multinomial classifier
   */
  @Override
  protected PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
    // transform stays null unless a transform other than NONE is configured
    final StackedEnsembleParameters.MetalearnerTransform transform; 
    if (_parms._metalearner_transform != null && _parms._metalearner_transform != StackedEnsembleParameters.MetalearnerTransform.NONE) {
      if (!(_output.isBinomialClassifier() || _output.isMultinomialClassifier()))
        throw new H2OIllegalArgumentException("Metalearner transform is supported only for classification!");
      transform = _parms._metalearner_transform;
    } else {
      transform = null;
    }
    final String seKey = this._key.toString();
    // the level-one frame key is derived from this model's key and the scored frame's key
    final Key levelOneFrameKey = Key.make("preds_levelone_" + seKey + fr._key);
    Frame levelOneFrame = transform == null ?
            new Frame(levelOneFrameKey)  // no transform -> this will be the final frame
            :
            new Frame();        // transform -> this is only an intermediate result

    // only the base models the metalearner actually uses are scored
    Model[] usefulBaseModels = Stream.of(_parms._base_models)
            .filter(this::isUsefulBaseModel)
            .map(Key::get)
            .toArray(Model[]::new);

    if (usefulBaseModels.length > 0) {
      Frame[] baseModelPredictions = new Frame[usefulBaseModels.length];

      // Run scoring of base models in parallel
      H2O.submitTask(new LocalMR(new MrFun() {
        @Override
        protected void map(int id) {
          // each task writes only its own slot of the shared predictions array
          baseModelPredictions[id] = usefulBaseModels[id].score(
                  fr,
                  "preds_base_" + seKey + usefulBaseModels[id]._key + fr._key,
                  j,
                  false
          );
        }
      }, usefulBaseModels.length)).join();

      // predictions are added sequentially, in base model order
      for (int i = 0; i < usefulBaseModels.length; i++) {
        StackedEnsemble.addModelPredictionsToLevelOneFrame(usefulBaseModels[i], baseModelPredictions[i], levelOneFrame);
        DKV.remove(baseModelPredictions[i]._key); //Cleanup
        Frame.deleteTempFrameAndItsNonSharedVecs(baseModelPredictions[i], levelOneFrame);
      }
    }

    if (transform != null) {
      // the transformed frame takes the level-one key; the untransformed intermediate frame is removed
      Frame oldLOF = levelOneFrame;
      levelOneFrame = transform.transform(this, levelOneFrame, levelOneFrameKey);
      oldLOF.remove();
    }
    // Add response column, weights columns to level one frame
    StackedEnsemble.addNonPredictorsToLevelOneFrame(_parms, adaptFrm, levelOneFrame, false);
    // TODO: what if we're running multiple in parallel and have a name collision?
    Log.info("Finished creating \"level one\" frame for scoring: " + levelOneFrame.toString());

    // Score the dataset, building the class distribution & predictions

    Model metalearner = this._output._metalearner;
    // NOTE(review): the customMetricFunc argument is ignored here in favour of _parms._custom_metric_func -- confirm this is intended
    Frame predictFr = metalearner.score(
        levelOneFrame, 
        destination_key, 
        j, 
        computeMetrics, 
        CFuncRef.from(_parms._custom_metric_func)
    );
    ModelMetrics mmStackedEnsemble = null;
    if (computeMetrics) {
      // #score has just stored a ModelMetrics object for the (metalearner, preds_levelone) Model/Frame pair.
      // We need to be able to look it up by the (this, fr) pair.
      // The ModelMetrics object for the metalearner will be removed when the metalearner is removed.
      Key[] mms = metalearner._output.getModelMetrics();
      // the last entry is taken to be the metric just computed by the score call above
      ModelMetrics lastComputedMetric = mms[mms.length - 1].get();
      mmStackedEnsemble = lastComputedMetric.deepCloneWithDifferentModelAndFrame(this, fr);
      this.addModelMetrics(mmStackedEnsemble);
      //now that we have the metric set on the SE model, removing the one we just computed on metalearner (otherwise it leaks in client mode)
      for (Key mm : metalearner._output.clearModelMetrics(true)) {
        DKV.remove(mm);
      }
    }
    // the level-one frame is temporary; adaptFrm is passed as the frame to compare against for shared vecs
    Frame.deleteTempFrameAndItsNonSharedVecs(levelOneFrame, adaptFrm);
    return new StackedEnsemblePredictScoreResult(predictFr, mmStackedEnsemble);
  }

  /**
   * Scoring result of the ensemble: it carries the predictions frame together with metrics that were
   * already computed, instead of a metric builder.
   */
  private class StackedEnsemblePredictScoreResult extends PredictScoreResult {
    // metrics computed up front; may be null
    private final ModelMetrics _precomputedMetrics;

    /**
     * @param predictions        predictions frame, handed to the superclass twice, with a null first argument
     * @param precomputedMetrics metrics to hand out from {@link #makeModelMetrics(Frame, Frame)}
     */
    public StackedEnsemblePredictScoreResult(Frame predictions, ModelMetrics precomputedMetrics) {
      super(null, predictions, predictions);
      this._precomputedMetrics = precomputedMetrics;
    }

    /** Returns the metrics given at construction time; both arguments are ignored. */
    @Override
    public ModelMetrics makeModelMetrics(Frame fr, Frame adaptFrm) {
      return this._precomputedMetrics;
    }

    /** Always fails: no metric builder is available for a Stacked Ensemble result. */
    @Override
    public ModelMetrics.MetricBuilder getMetricBuilder() {
      final String message = "Stacked Ensemble model doesn't implement MetricBuilder infrastructure code, " +
              "retrieve your metrics by calling getOrMakeMetrics method.";
      throw new UnsupportedOperationException(message);
    }
  }
  
  /**
   * Is the baseModel's prediction used in the metalearner?
   * <p>
   * For a multinomial ensemble the metalearner columns of a base model are named
   * {@code <base model key>/...}; the base model is useful as soon as one of those columns is used.
   * Otherwise the base model key itself is looked up as the feature name.
   *
   * @param baseModelKey key of the base model to check
   * @return true if the metalearner uses at least one prediction column of the base model
   */
  boolean isUsefulBaseModel(Key baseModelKey) {
    final Model metalearner = _output._metalearner;
    assert metalearner != null : "can't use isUsefulBaseModel during training";
    if (modelCategory != ModelCategory.Multinomial) {
      return metalearner.isFeatureUsedInPredict(baseModelKey.toString());
    }
    // Multinomial models output multiple columns and a base model
    // might be useful just for one category...
    // The prefix does not depend on the feature, so it is built once instead of once per column.
    final String columnPrefix = baseModelKey.toString().concat("/");
    return Stream.of(metalearner._output._names)
            .anyMatch(feature -> feature.startsWith(columnPrefix) && metalearner.isFeatureUsedInPredict(feature));
  }

  /**
   * Should never be called: the code paths that normally go here should call predictScoreImpl().
   * @see Model#score0(double[], double[])
   * @param data ignored
   * @param preds ignored
   * @return never returns normally
   * @throws UnsupportedOperationException always
   */
  @Override
  protected double[] score0(double data[/*ncols*/], double preds[/*nclasses+1*/]) {
    throw new UnsupportedOperationException("StackedEnsembleModel.score0() should never be called: the code paths that normally go here should call predictScoreImpl().");
  }

  /**
   * Not supported by the Stacked Ensemble model.
   * @param domain ignored
   * @return never returns normally
   * @throws UnsupportedOperationException always
   */
  @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    throw new UnsupportedOperationException("StackedEnsembleModel.makeMetricBuilder should never be called!");
  }

  /**
   * Scores the given frame with the whole ensemble and returns the resulting metrics.
   * When {@code _score_training_samples} is positive and smaller than the number of rows,
   * a sample of the frame is scored instead, and that sample is deleted afterwards.
   *
   * @param frame frame to compute the metrics on
   * @param job   job handed to the scoring
   * @return the metrics computed during scoring
   */
  private ModelMetrics doScoreTrainingMetrics(Frame frame, Job job) {
    final long sampleLimit = _parms._score_training_samples;
    final Frame toScore;
    if (sampleLimit > 0 && sampleLimit < frame.numRows()) {
      toScore = MRUtils.sampleFrame(frame, sampleLimit, _parms._seed);
    } else {
      toScore = frame;
    }
    try {
      final Frame adapted = new Frame(toScore);
      final CFuncRef customMetric = CFuncRef.from(_parms._custom_metric_func);
      final PredictScoreResult scored = predictScoreImpl(toScore, adapted, null, job, true, customMetric);
      // only the metrics are of interest, the predictions are dropped right away
      scored.getPredictions().delete();
      return scored.makeModelMetrics(toScore, adapted);
    } finally {
      // never delete the caller's frame, only a sample created here
      if (toScore != frame) {
        toScore.delete();
      }
    }
  }

  /**
   * Fills in the training, validation and cross-validation metrics of the ensemble.
   *
   * @param job not used: training metrics are deliberately scored without a job (see the comment in the body)
   */
  void doScoreOrCopyMetrics(Job job) {
    // Training metrics cannot be cloned from the metalearner: the metalearner was trained on cv preds,
    // not on training preds, so its training metrics are not the ensemble's training metrics.
    // Instead the training frame is re-scored by all the base models, those (biased) predictions are
    // sent through the metalearner, and the metrics are computed there.
    //
    // No job is passed: a `stop_requested()` caused by a timeout would invalidate the whole SE at this point,
    // which would be unfortunate since this is the last step of SE training and it should be relatively fast.
    _output._training_metrics = doScoreTrainingMetrics(_parms.train(), null);

    final Model.Output metalearnerOutput = _output._metalearner._output;

    // Validation metrics are taken as-is from the metalearner (may be null): the validation frame
    // was already piped through, so re-doing that would only give the same results.
    _output._validation_metrics = metalearnerOutput._validation_metrics;

    // Cross-validation metrics of the metalearner (may be null) serve as a proxy for the ensemble's:
    // true k-fold cv metrics for the ensemble would require training k sets of cross-validated base models
    // (rather than a single set), which is extremely computationally expensive and awkward from the
    // standpoint of the existing Stacked Ensemble API.
    // More info: https://0xdata.atlassian.net/browse/PUBDEV-3971
    final ModelMetrics metalearnerCvMetrics = metalearnerOutput._cross_validation_metrics;
    if (metalearnerCvMetrics == null) {
      return;
    }
    // A deep clone is needed because otherwise the frame key is incorrect (metalearner train is levelone, not train)
    _output._cross_validation_metrics =
            metalearnerCvMetrics.deepCloneWithDifferentModelAndFrame(this, _output._metalearner._parms.train());
    _output._cross_validation_metrics_summary =
            (TwoDimTable) metalearnerOutput._cross_validation_metrics_summary.clone();
  }

  /**
   * Determines the distribution family a model effectively uses.
   * <ul>
   *   <li>DRF: derived from the model category (bernoulli / multinomial / gaussian).</li>
   *   <li>Stacked Ensemble: taken from the metalearner parameters (GLM family, or a non-AUTO distribution).</li>
   *   <li>Otherwise: read reflectively from a {@code _family} field of the parameters, or else from a
   *       {@code _dist} field of the model, with AUTO resolved from the model category.</li>
   * </ul>
   *
   * @param aModel the model to inspect
   * @return the distribution family of the model
   * @throws H2OIllegalArgumentException if the model exposes neither a family nor a distribution,
   *         or if reading them fails
   */
  private DistributionFamily distributionFamily(Model aModel) {
    // TODO: hack alert: In DRF, _parms._distribution is always set to multinomial.  Yay.
    if (aModel instanceof DRFModel) {
      if (aModel._output.isBinomialClassifier())
        return DistributionFamily.bernoulli;
      else if (aModel._output.isClassifier())
        return DistributionFamily.multinomial;
      else
        return DistributionFamily.gaussian;
    }

    if (aModel instanceof StackedEnsembleModel) {
      StackedEnsembleModel seModel = (StackedEnsembleModel) aModel;
      if (Metalearners.getActualMetalearnerAlgo(seModel._parms._metalearner_algorithm) == Metalearner.Algorithm.glm) {
        return familyToDistribution(((GLMModel.GLMParameters) seModel._parms._metalearner_parameters)._family);
      }
      if (seModel._parms._metalearner_parameters._distribution != DistributionFamily.AUTO) {
        return seModel._parms._metalearner_parameters._distribution;
      }
      // otherwise fall through to the reflective lookup below
    }

    try {
      Field familyField = ReflectionUtils.findNamedField(aModel._parms, "_family");
      // the "_dist" field is only looked up when there is no "_family" field
      Field distributionField = (familyField != null ? null : ReflectionUtils.findNamedField(aModel, "_dist"));
      if (null != familyField) {
        // GLM only, for now
        GLMModel.GLMParameters.Family thisFamily = (GLMModel.GLMParameters.Family) familyField.get(aModel._parms);
        return familyToDistribution(thisFamily);
      }

      if (null != distributionField) {
        Distribution distribution = ((Distribution)distributionField.get(aModel));
        DistributionFamily distributionFamily;
        if (null != distribution)
          distributionFamily = distribution._family;
        else
          distributionFamily = aModel._parms._distribution;

        // NOTE: If the algo does smart guessing of the distribution family we need to duplicate the logic here.
        if (distributionFamily == DistributionFamily.AUTO) {
          if (aModel._output.isBinomialClassifier())
            distributionFamily = DistributionFamily.bernoulli;
          else if (aModel._output.isClassifier())
            distributionFamily = DistributionFamily.multinomial;
          else
            distributionFamily = DistributionFamily.gaussian;
        } // DistributionFamily.AUTO

        return distributionFamily;
      }

      throw new H2OIllegalArgumentException("Don't know how to stack models that have neither a distribution hyperparameter nor a family hyperparameter.");
    }
    catch (H2OIllegalArgumentException e) {
      // already the right exception type with a readable message: do not re-wrap it into its own toString()
      throw e;
    }
    catch (Exception e) {
      throw new H2OIllegalArgumentException(e.toString(), e.toString());
    }
  }



  /**
   * Makes the metalearner use the distribution of a base model, together with the parameter
   * belonging to that distribution.
   *
   * @param baseModelParms parameters of the base model to inherit from
   */
  private void inheritDistributionAndParms(Model.Parameters baseModelParms) {
    final Model.Parameters metaParms = _parms._metalearner_parameters;

    if (baseModelParms instanceof DRFModel.DRFParameters) {
      inferBasicDistribution();
    } else if (baseModelParms instanceof GLMModel.GLMParameters) {
      final GLMModel.GLMParameters.Family glmFamily = ((GLMModel.GLMParameters) baseModelParms)._family;
      try {
        metaParms.setDistributionFamily(familyToDistribution(glmFamily));
      } catch (IllegalArgumentException e) {
        // the metalearner distribution is left untouched in this case
        Log.warn("Stacked Ensemble is not able to inherit distribution from GLM's family " + glmFamily + ".");
      }
    } else {
      metaParms.setDistributionFamily(baseModelParms._distribution);
    }

    // parameterized distributions also need their parameter copied over
    switch (baseModelParms._distribution) {
      case quantile:
        metaParms._quantile_alpha = baseModelParms._quantile_alpha;
        break;
      case tweedie:
        metaParms._tweedie_power = baseModelParms._tweedie_power;
        break;
      case huber:
        metaParms._huber_alpha = baseModelParms._huber_alpha;
        break;
      case custom:
        metaParms._custom_distribution_func = baseModelParms._custom_distribution_func;
        break;
    }
  }

  /**
   * Makes a GLM metalearner use the family of a base model, together with the parameter
   * belonging to that family.
   *
   * @param baseModelParms parameters of the base model to inherit from
   */
  private void inheritFamilyAndParms(Model.Parameters baseModelParms) {
    final GLMModel.GLMParameters glmMetaParms = (GLMModel.GLMParameters) _parms._metalearner_parameters;

    if (baseModelParms instanceof DRFModel.DRFParameters) {
      inferBasicDistribution();
    } else if (baseModelParms instanceof GLMModel.GLMParameters) {
      // GLM base model: family and link can be copied directly
      final GLMModel.GLMParameters baseGlmParms = (GLMModel.GLMParameters) baseModelParms;
      glmMetaParms._family = baseGlmParms._family;
      glmMetaParms._link = baseGlmParms._link;
    } else {
      final DistributionFamily baseDistribution = baseModelParms._distribution;
      try {
        glmMetaParms.setDistributionFamily(baseDistribution);
      } catch (H2OIllegalArgumentException e) {
        Log.warn("Stacked Ensemble is not able to inherit family from a distribution " + baseDistribution + ".");
        inferBasicDistribution();
      }
    }

    // tweedie is the parameterized family handled here: its power is copied as well
    final boolean tweedie = glmMetaParms._family == GLMModel.GLMParameters.Family.tweedie;
    if (tweedie) {
      glmMetaParms._tweedie_power = baseModelParms._tweedie_power;
    }
  }

  /**
   * Lets the metalearner inherit its family (GLM metalearner) or its distribution (any other metalearner)
   * from the given model, unless a value other than AUTO is already set on the metalearner parameters.
   *
   * @param aModel the model to inherit from
   * @return True if the distribution or family was inferred from a model
   */
  boolean inferDistributionOrFamily(Model aModel) {
    final boolean glmMetalearner =
            Metalearners.getActualMetalearnerAlgo(_parms._metalearner_algorithm) == Metalearner.Algorithm.glm;

    if (glmMetalearner) { //use family
      final GLMModel.GLMParameters.Family currentFamily =
              ((GLMModel.GLMParameters) _parms._metalearner_parameters)._family;
      if (currentFamily != GLMModel.GLMParameters.Family.AUTO) {
        // User specified family - no need to infer one; Link will be also used properly if it is specified
        return false;
      }
      inheritFamilyAndParms(aModel._parms);
      return true;
    }

    // use distribution
    if (_parms._metalearner_parameters._distribution != DistributionFamily.AUTO) {
      // User specified distribution; no need to infer one
      return false;
    }
    inheritDistributionAndParms(aModel._parms);
    return true;
  }

  /**
   * Sets the metalearner distribution to the default one for this model's category:
   * bernoulli for binomial classification, multinomial for any other classification, gaussian otherwise.
   */
  void inferBasicDistribution() {
    final DistributionFamily basicFamily;
    if (this._output.isBinomialClassifier()) {
      basicFamily = DistributionFamily.bernoulli;
    } else if (this._output.isClassifier()) {
      basicFamily = DistributionFamily.multinomial;
    } else {
      basicFamily = DistributionFamily.gaussian;
    }
    _parms._metalearner_parameters.setDistributionFamily(basicFamily);
  }

  /**
   * Validates the ensemble parameters against the base models and inherits shared properties from them.
   * <p>
   * The first base model found in DKV defines the model category, the response column, the
   * cross-validation setup and (unless the user chose one) the metalearner distribution/family.
   * Every further base model is checked for consistency with it; when base models disagree on the
   * distribution or its parameters, the metalearner falls back to the basic distribution of the
   * model category. Base model keys that cannot be found are skipped with a warning.
   *
   * @throws H2OIllegalArgumentException if no base model is specified or none can be found, if both a
   *         metalearner fold column and nfolds are given, or if the base models are unsupervised or inconsistent
   */
  void checkAndInheritModelProperties() {
    if (null == _parms._base_models || 0 == _parms._base_models.length)
      throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; found 0.");

    if (null != _parms._metalearner_fold_column && 0 != _parms._metalearner_nfolds)
      throw new H2OIllegalArgumentException("Cannot specify fold_column and nfolds at the same time.");

    boolean retrievedFirstModelParams = false;
    boolean inferredDistributionFromFirstModel = false;
    GLMModel firstGLM = null;
    boolean blending_mode = _parms._blending != null;
    boolean cv_required_on_base_model = !blending_mode;
    boolean require_consistent_training_frames = !blending_mode && !_parms._is_cv_model;
    
    //following variables are collected from the 1st base model (should be identical across base models),
    // i.e. while retrievedFirstModelParams=false
    int basemodel_nfolds = -1;
    Parameters.FoldAssignmentScheme basemodel_fold_assignment = null;
    String basemodel_fold_column = null;
    long seed = -1;
    //end 1st model collected fields

    // Make sure we can set metalearner's family and link if needed
    if (_parms._metalearner_parameters == null) {
      _parms.initMetalearnerParams();
    }

    for (Key k : _parms._base_models) {
      Model aModel = DKV.getGet(k);
      if (null == aModel) {
        Log.warn("Failed to find base model; skipping: "+k);
        continue;
      }
      Log.debug("Checking properties for model "+k);
      if (!aModel.isSupervised()) {
        throw new H2OIllegalArgumentException("Base model is not supervised: "+aModel._key.toString());
      }

      if (retrievedFirstModelParams) {
        // check that the base models are all consistent with first based model
        // NOTE(review): the cross-validation checks below are not applied to the first base model itself -- confirm this is intended

        if (modelCategory != aModel._output.getModelCategory())
          throw new H2OIllegalArgumentException("Base models are inconsistent: "
                  +"there is a mix of different categories of models among "+Arrays.toString(_parms._base_models));

        if (! responseColumn.equals(aModel._parms._response_column))
          throw new H2OIllegalArgumentException("Base models are inconsistent: they use different response columns."
                  +" Found: "+responseColumn+" (StackedEnsemble) and "+aModel._parms._response_column+" (model "+k+").");

        if (require_consistent_training_frames) {
          if (trainingFrameRows < 0) trainingFrameRows = _parms.train().numRows();
          // when the base model's training frame is gone, its cv holdout predictions give the row count
          long numOfRowsUsedToTrain = aModel._parms.train() == null ?
                  aModel._output._cross_validation_holdout_predictions_frame_id.get().numRows() :
                  aModel._parms.train().numRows();
          if (trainingFrameRows != numOfRowsUsedToTrain)
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use different size (number of rows) training frames."
                    +" Found: "+trainingFrameRows+" (StackedEnsemble) and "+numOfRowsUsedToTrain+" (model "+k+").");
        }

        if (cv_required_on_base_model) {

          if (aModel._parms._fold_assignment != basemodel_fold_assignment
                  && !(aModel._parms._fold_assignment == AUTO && basemodel_fold_assignment == Random)
          ) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use different fold_assignments.");
          }

          if (aModel._parms._fold_column == null) {
            // If we don't have a fold_column require:
            // nfolds > 1
            // nfolds consistent across base models
            if (aModel._parms._nfolds < 2)
              throw new H2OIllegalArgumentException("Base model does not use cross-validation: "+aModel._parms._nfolds);
            if (basemodel_nfolds != aModel._parms._nfolds)
              throw new H2OIllegalArgumentException("Base models are inconsistent: they use different values for nfolds.");

            if (basemodel_fold_assignment == Random && aModel._parms._seed != seed)
              throw new H2OIllegalArgumentException("Base models are inconsistent: they use random-seeded k-fold cross-validation but have different seeds.");

          } else {
            if (!aModel._parms._fold_column.equals(basemodel_fold_column))
              throw new H2OIllegalArgumentException("Base models are inconsistent: they use different fold_columns.");
          }
          if (! aModel._parms._keep_cross_validation_predictions)
            // report the offending model (the message used to print nfolds, which says nothing about this problem)
            throw new H2OIllegalArgumentException("Base model does not keep cross-validation predictions: "+k);
        }

        if (inferredDistributionFromFirstModel) {
          // Check inferred params and if they differ fallback to basic distribution of model category
          if (!(aModel instanceof DRFModel) && distributionFamily(aModel) == distributionFamily(this)) {
            boolean sameParams = true;
            switch (_parms._metalearner_parameters._distribution) {
              case custom:
                sameParams = _parms._metalearner_parameters._custom_distribution_func
                        .equals(aModel._parms._custom_distribution_func);
                break;
              case huber:
                sameParams = _parms._metalearner_parameters._huber_alpha == aModel._parms._huber_alpha;
                break;
              case tweedie:
                sameParams = _parms._metalearner_parameters._tweedie_power == aModel._parms._tweedie_power;
                break;
              case quantile:
                sameParams = _parms._metalearner_parameters._quantile_alpha == aModel._parms._quantile_alpha;
                break;
            }

            if ((aModel instanceof GLMModel) && (Metalearners.getActualMetalearnerAlgo(_parms._metalearner_algorithm) == Metalearner.Algorithm.glm)) {
              if (firstGLM == null) {
                // first GLM among the base models: the GLM metalearner takes over its family and link
                firstGLM = (GLMModel) aModel;
                inheritFamilyAndParms(firstGLM._parms);
              } else {
                sameParams = ((GLMModel.GLMParameters) _parms._metalearner_parameters)._link.equals(((GLMModel) aModel)._parms._link);
              }
            }

            if (!sameParams) {
              Log.warn("Base models are inconsistent; they use same distribution but different parameters of " +
                      "the distribution. Reverting to default distribution.");
              inferBasicDistribution();
              inferredDistributionFromFirstModel = false;
            }
          } else {
            if (distributionFamily(aModel) != distributionFamily(this)) {
              // Distribution of base models differ
              Log.warn("Base models are inconsistent; they use different distributions: "
                      + distributionFamily(this) + " and: " + distributionFamily(aModel) +
                      ". Reverting to default distribution.");
            } // else the first model was DRF/XRT so we don't want to warn
            inferBasicDistribution();
            inferredDistributionFromFirstModel = false;
          }
        }
      } else {
        // !retrievedFirstModelParams: this is the first base_model
        this.modelCategory = aModel._output.getModelCategory();
        inferredDistributionFromFirstModel = inferDistributionOrFamily(aModel);
        firstGLM = aModel instanceof GLMModel && inferredDistributionFromFirstModel ? (GLMModel) aModel : null;
        responseColumn = aModel._parms._response_column;

        if (! _parms._response_column.equals(responseColumn))  // _params._response_column can't be null, validated by ModelBuilder
          throw new H2OIllegalArgumentException("StackedModel response_column must match the response_column of each base model."
                  +" Found: "+_parms._response_column+"(StackedEnsemble) and: "+responseColumn+" (model "+k+").");

        basemodel_nfolds = aModel._parms._nfolds;
        basemodel_fold_assignment = aModel._parms._fold_assignment;
        if (basemodel_fold_assignment == AUTO) basemodel_fold_assignment = Random;
        basemodel_fold_column = aModel._parms._fold_column;
        seed = aModel._parms._seed;
        retrievedFirstModelParams = true;
      }

    } // for all base_models

    // Fail only when NO base model could be found. Testing the last looked-up model for null (as was done
    // before) wrongly failed whenever just the last key was missing although earlier models were found.
    if (!retrievedFirstModelParams)
      throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; "
              +_parms._base_models.length+" were specified but none of those were found: "+Arrays.toString(_parms._base_models));
  }
  
  /**
   * Deletes the cached base model prediction frames referenced by the output, then forgets their keys.
   * When a level-one frame is kept, a prediction frame that still exists is deleted through
   * {@code Frame.deleteTempFrameAndItsNonSharedVecs} with the level-one frame as second argument;
   * in every other case the key is simply removed.
   */
  public void deleteBaseModelPredictions() {
    if (_output._base_model_predictions_keys == null) {
      return;
    }
    final Frame levelOneFrame = _output._levelone_frame_id;
    for (Key<Frame> key : _output._base_model_predictions_keys) {
      // fetch the frame at most once per key (it used to be fetched twice: once for the null check, once for the delete)
      final Frame predictions = levelOneFrame == null ? null : key.get();
      if (predictions != null)
        Frame.deleteTempFrameAndItsNonSharedVecs(predictions, levelOneFrame);
      else
        Keyed.remove(key);
    }
    _output._base_model_predictions_keys = null;
  }

  /**
   * Removes what the ensemble owns before delegating to the superclass: the cached base model
   * predictions, the metalearner (if set) and the level-one frame (if set).
   *
   * @param fs      futures passed to the removals
   * @param cascade passed to the superclass only
   * @return the result of the superclass removal
   */
  @Override protected Futures remove_impl(Futures fs, boolean cascade) {
    deleteBaseModelPredictions(); 
    if (_output._metalearner != null)
      _output._metalearner.remove(fs);
    if (_output._levelone_frame_id != null)
      _output._levelone_frame_id.remove(fs);

    return super.remove_impl(fs, cascade);
  }

  /**
   * Write out models (base + metalearner): first the metalearner key, then each base model key in
   * array order, then whatever the superclass writes. readAll_impl reads them back in the same order.
   * NOTE(review): _output._metalearner is dereferenced without a null check -- confirm it is always set here.
   */
  @Override protected AutoBuffer writeAll_impl(AutoBuffer ab) {
    //Metalearner
    ab.putKey(_output._metalearner._key);
    //Base Models
    for (Key ks : _parms._base_models)
        ab.putKey(ks);
    return super.writeAll_impl(ab);
  }

  /**
   * Read in models (base + metalearner): first the metalearner key, then each base model key in
   * array order, then whatever the superclass reads. This is the order used by writeAll_impl.
   * NOTE(review): _output._metalearner is dereferenced without a null check -- confirm it is always set here.
   */
  @Override protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
    //Metalearner
    ab.getKey(_output._metalearner._key,fs);
    //Base Models
    for (Key ks : _parms._base_models)
      ab.getKey(ks,fs);
    return super.readAll_impl(ab,fs);
  }

  /**
   * @return a new MOJO writer created for this ensemble
   */
  @Override
  public StackedEnsembleMojoWriter getMojo() {
    return new StackedEnsembleMojoWriter(this);
  }

  /** Delegates to the metalearner, if there is one, to delete its cross-validation models. */
  @Override
  public void deleteCrossValidationModels() {
    final Model metalearner = _output._metalearner;
    if (metalearner == null) {
      return;
    }
    metalearner.deleteCrossValidationModels();
  }

  /** Delegates to the metalearner, if there is one, to delete its cross-validation predictions. */
  @Override
  public void deleteCrossValidationPreds() {
    final Model metalearner = _output._metalearner;
    if (metalearner == null) {
      return;
    }
    metalearner.deleteCrossValidationPreds();
  }

  /** Delegates to the metalearner, if there is one, to delete its cross-validation fold assignment. */
  @Override
  public void deleteCrossValidationFoldAssignment() {
    final Model metalearner = _output._metalearner;
    if (metalearner == null) {
      return;
    }
    metalearner.deleteCrossValidationFoldAssignment();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy