All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.h2o.automl.modeling.StackedEnsembleStepsProvider Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package ai.h2o.automl.modeling;

import ai.h2o.automl.*;
import ai.h2o.automl.WorkAllocations.Work;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.TargetEncoding;
import hex.KeyValue;
import hex.Model;
import hex.ensemble.Metalearner;
import hex.ensemble.StackedEnsembleModel;
import hex.ensemble.StackedEnsembleModel.StackedEnsembleParameters;
import hex.glm.GLMModel;
import water.DKV;
import water.Job;
import water.Key;
import water.util.PojoUtils;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class StackedEnsembleStepsProvider
        implements ModelingStepsProvider
                 , ModelParametersProvider {

    public static class StackedEnsembleSteps extends ModelingSteps {

        static final String NAME = Algo.StackedEnsemble.name();

        static abstract class StackedEnsembleModelStep extends ModelingStep.ModelStep {
            
            protected final Metalearner.Algorithm _metalearnerAlgo;

            StackedEnsembleModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, int weight, AutoML autoML) {
                super(NAME, Algo.StackedEnsemble, id, priorityGroup, weight, autoML);
                _metalearnerAlgo = algo;
                _ignoredConstraints = new AutoML.Constraint[] {AutoML.Constraint.MODEL_COUNT};
            }

            @Override
            protected void setCrossValidationParams(Model.Parameters params) {
                //added in the stack: we could probably move this here.
            }

            @Override
            protected void setWeightingParams(Model.Parameters params) {
                //Disabled: StackedEnsemble doesn't support weights in score0? 
            }

            @Override
            protected void setClassBalancingParams(Model.Parameters params) {
                //Disabled
            }

            @Override
            protected PreprocessingConfig getPreprocessingConfig() {
                //SE should not have TE applied, the base models already do it.
                PreprocessingConfig config = super.getPreprocessingConfig();
                config.put(TargetEncoding.CONFIG_ENABLED, false);
                return config;
            }

            @Override
            @SuppressWarnings("unchecked")
            public boolean canRun() {
                Key[] keys = getBaseModels();
                Work seWork = getAllocatedWork();
                if (!super.canRun()) {
                    aml().job().update(0, "Skipping this StackedEnsemble");
                    aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("Skipping StackedEnsemble '%s' due to the exclude_algos option or it is already trained.", _id));
                    return false;
                } else if (keys.length == 0) {
                    aml().job().update(seWork.consume(), "No base models; skipping this StackedEnsemble");
                    aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("No base models, due to timeouts or the exclude_algos option. Skipping StackedEnsemble '%s'.", _id));
                    return false;
                } else if (keys.length == 1) {
                    aml().job().update(seWork.consume(), "Only one base model; skipping this StackedEnsemble");
                    aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("Skipping StackedEnsemble '%s' since there is only one model to stack", _id));
                    return false;
                } else if (!isCVEnabled() && aml().getBlendingFrame() == null) {
                    aml().job().update(seWork.consume(), "Cross-validation disabled by the user and no blending frame provided; Skipping this StackedEnsemble");
                    aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("Cross-validation is disabled by the user and no blending frame was provided; skipping StackedEnsemble '%s'.", _id));
                    return false;
                }
                return !hasDoppelganger(keys);
            }
            
            @SuppressWarnings("unchecked")
            protected boolean hasDoppelganger(Key[] baseModelsKeys) {
                Key[] seModels = Arrays
                        .stream(getTrainedModelsKeys())
                        .filter(k -> isStackedEnsemble(k))
                        .toArray(Key[]::new);

                Set keySet = new HashSet<>(Arrays.asList(baseModelsKeys));
                for (Key seKey: seModels) {
                    StackedEnsembleModelStep seStep = (StackedEnsembleModelStep)aml().session().getModelingStep(seKey);
                    if (seStep._metalearnerAlgo != _metalearnerAlgo) continue;
                    
                    final StackedEnsembleParameters seParams = seKey.get()._parms;
                    final Key[] seBaseModels = seParams._base_models;
                    if (seBaseModels.length != baseModelsKeys.length) continue;
                    if (keySet.equals(new HashSet<>(Arrays.asList(seBaseModels))))
                        return true; // We already have a SE with the same base models
                }

                return false;
            }

            protected abstract Key[] getBaseModels();

            protected String getModelType(Key key) {
                String keyStr = key.toString();
                return keyStr.substring(0, keyStr.indexOf('_'));
            }

            protected boolean isStackedEnsemble(Key key) {
                ModelingStep step = aml().session().getModelingStep(key);
                return step != null && step.getAlgo() == Algo.StackedEnsemble;
            }

            @Override
            public StackedEnsembleParameters prepareModelParameters() {
                StackedEnsembleParameters params = new StackedEnsembleParameters();
                params._valid = (aml().getValidationFrame() == null ? null : aml().getValidationFrame()._key);
                params._blending = (aml().getBlendingFrame() == null ? null : aml().getBlendingFrame()._key);
                params._keep_levelone_frame = true; //TODO Why is this true? Can be optionally turned off
                return params;
            }
            
            protected void setMetalearnerParameters(StackedEnsembleParameters params) {
                AutoMLBuildSpec buildSpec = aml().getBuildSpec();
                params._metalearner_fold_column = buildSpec.input_spec.fold_column;
                params._metalearner_nfolds = buildSpec.build_control.nfolds;
                params.initMetalearnerParams(_metalearnerAlgo);
                params._metalearner_parameters._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
                params._metalearner_parameters._keep_cross_validation_predictions = buildSpec.build_control.keep_cross_validation_predictions;
            }

            Job stack(String modelName, Key[] baseModels, boolean isLast) {
                StackedEnsembleParameters params = prepareModelParameters();
                params._base_models = baseModels;
                params._keep_base_model_predictions = !isLast; //avoids recomputing some base predictions for each SE
                
                setMetalearnerParameters(params);
                if (_metalearnerAlgo == Metalearner.Algorithm.AUTO) setAutoMetalearnerSEParameters(params);
                
                return stack(modelName, params);
            }
            
            Job stack(String modelName, StackedEnsembleParameters stackedEnsembleParameters) {
                Key modelKey = makeKey(modelName, true);
                return trainModel(modelKey, stackedEnsembleParameters);
            }

            protected void setAutoMetalearnerSEParameters(StackedEnsembleParameters stackedEnsembleParameters) {
                // add custom alpha in GLM metalearner
                GLMModel.GLMParameters metalearnerParams = (GLMModel.GLMParameters)stackedEnsembleParameters._metalearner_parameters;
                metalearnerParams._alpha = new double[]{0.5, 1.0};

                if (aml().getResponseColumn().isCategorical()) {
                    // Add logit transform
                    stackedEnsembleParameters._metalearner_transform = StackedEnsembleParameters.MetalearnerTransform.Logit;
                }
            } 

        }
        
        static class BestOfFamilySEModelStep extends StackedEnsembleModelStep {

            public BestOfFamilySEModelStep(String id, int priorityGroup, AutoML autoML) {
                this(id, Metalearner.Algorithm.AUTO, priorityGroup, autoML);
            }

            public BestOfFamilySEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, AutoML autoML) {
                this(id, algo, priorityGroup, DEFAULT_MODEL_TRAINING_WEIGHT, autoML);
            }

            public BestOfFamilySEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, int weight, AutoML autoML) {
                super((id == null ? "best_of_family_"+algo.name() : id), algo, priorityGroup, weight, autoML);
                _description = _description+" (built with "+algo.name()+" metalearner, using top model from each algorithm type)";
            }
            
            @Override
            @SuppressWarnings("unchecked")
            protected Key[] getBaseModels() {
                // Set aside List for best models per model type. Meaning best GLM, GBM, DRF, XRT, and DL (5 models).
                // This will give another ensemble that is smaller than the original which takes all models into consideration.
                List> bestModelsOfEachType = new ArrayList<>();
                Set typesOfGatheredModels = new HashSet<>();

                for (Key key : getTrainedModelsKeys()) {
                    // trained models are sorted (taken from leaderboard), so we only need to pick the first of each type (excluding other StackedEnsembles)
                    String type = getModelType(key);
                    if (isStackedEnsemble(key) || typesOfGatheredModels.contains(type)) continue;
                    typesOfGatheredModels.add(type);
                    bestModelsOfEachType.add(key);
                }
                return bestModelsOfEachType.toArray(new Key[0]);
            }

            @Override
            protected Job startJob() {
                return stack(_provider+"_BestOfFamily", getBaseModels(), false);
            }
        }

        static class BestNModelsSEModelStep extends StackedEnsembleModelStep {

            private final int _N;
            
            public BestNModelsSEModelStep(String id, int N, int priorityGroup, AutoML autoML) {
                this(id, Metalearner.Algorithm.AUTO, N, priorityGroup, DEFAULT_MODEL_TRAINING_WEIGHT, autoML);
            }

            public BestNModelsSEModelStep(String id, Metalearner.Algorithm algo, int N, int priorityGroup, int weight, AutoML autoML) {
                super((id == null ? "best_"+N+"_"+algo.name() : id), algo, priorityGroup, weight, autoML);
                _N = N;
                _description = _description+" (built with "+algo.name()+" metalearner, using best "+N+" non-SE models)";
            }
            
            @Override
            @SuppressWarnings("unchecked")
            protected Key[] getBaseModels() {
                return Stream.of(getTrainedModelsKeys())
                        .filter(k -> !isStackedEnsemble(k))
                        .limit(_N)
                        .toArray(Key[]::new);
            }

            @Override
            protected Job startJob() {
                return stack(_provider+"_Best"+_N, getBaseModels(), false);
            }
        }
        
        static class AllSEModelStep extends StackedEnsembleModelStep {
            public AllSEModelStep(String id, int priorityGroup, AutoML autoML) {
                this(id, Metalearner.Algorithm.AUTO, priorityGroup, autoML);
            }

            public AllSEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, AutoML autoML) {
                this(id, algo, priorityGroup, DEFAULT_MODEL_TRAINING_WEIGHT, autoML);
            }
            
            public AllSEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, int weight, AutoML autoML) {
                super((id == null ? "all_"+algo.name() : id), algo, priorityGroup, weight, autoML);
                _description = _description+" (built with "+algo.name()+" metalearner, using all AutoML models)";
            }
            
            @Override
            @SuppressWarnings("unchecked")
            protected Key[] getBaseModels() {
                return Stream.of(getTrainedModelsKeys())
                        .filter(k -> !isStackedEnsemble(k))
                        .toArray(Key[]::new);
            }

            @Override
            protected Job startJob() {
                return stack(_provider+"_AllModels", getBaseModels(), false);
            }
        }
        
        static class MonotonicSEModelStep extends StackedEnsembleModelStep {

            public MonotonicSEModelStep(String id, int priorityGroup, AutoML autoML) {
                this(id, Metalearner.Algorithm.AUTO, priorityGroup, DEFAULT_MODEL_TRAINING_WEIGHT, autoML);
            }
            
            public MonotonicSEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, int weight, AutoML autoML) {
                super((id == null ? "monotonic" : id), algo, priorityGroup, weight, autoML);
                _description = _description+" (built with "+algo.name()+" metalearner, using monotonically constrained AutoML models)";
            }
            
            boolean hasMonotoneConstrains(Key modelKey) {
                Model model = DKV.getGet(modelKey);
                try {
                    KeyValue[] mc = (KeyValue[]) PojoUtils.getFieldValue(
                            model._parms, "_monotone_constraints",
                            PojoUtils.FieldNaming.CONSISTENT);
                    return mc != null && mc.length > 0;
                } catch (IllegalArgumentException e) {
                    return false;
                }
            }

            @Override
            public boolean canRun() {
                boolean canRun = super.canRun();
                if (!canRun) return false;
                int monotoneModels=0;
                for (Key modelKey: getTrainedModelsKeys()) {
                    if (hasMonotoneConstrains(modelKey))
                        monotoneModels++;
                    if (monotoneModels >= 2)
                        return true;
                }
                if (monotoneModels == 1) {
                    aml().job().update(getAllocatedWork().consume(),
                            "Only one monotonic base model; skipping this StackedEnsemble");
                    aml().eventLog().info(EventLogEntry.Stage.ModelTraining,
                            String.format("Skipping StackedEnsemble '%s' since there is only one monotonic model to stack", _id));
                } else {
                    aml().job().update(getAllocatedWork().consume(),
                            "No monotonic base model; skipping this StackedEnsemble");
                    aml().eventLog().info(EventLogEntry.Stage.ModelTraining,
                            String.format("Skipping StackedEnsemble '%s' since there is no monotonic model to stack", _id));
                }
                return false;
            }

            @Override
            @SuppressWarnings("unchecked")
            protected Key[] getBaseModels() {
                return Stream.of(getTrainedModelsKeys())
                        .filter(k -> !isStackedEnsemble(k) && hasMonotoneConstrains(k))
                        .toArray(Key[]::new);
            }

            @Override
            protected Job startJob() {
                return stack(_provider + "_Monotonic", getBaseModels(), false);
            }
        }
        
        private final ModelingStep[] defaults;
        private final ModelingStep[] optionals;

        {
            // we're going to cheat a bit: ModelingSteps needs to instantiated by the AutoML instance 
            // to convert each StepDefinition into one or more ModelingStep(s)
            // so at that time, we have access to the entire modeling plan
            // and we can dynamically generate the modeling steps that we're going to need.
            StepDefinition[] modelingPlan = aml().getBuildSpec().build_models.modeling_plan;
            if (Stream.of(modelingPlan).noneMatch(sd -> sd.getName().equals(NAME))) {
                defaults = new ModelingStep[0];
                optionals = new ModelingStep[0];
            } else {
                List defaultSeSteps = new ArrayList<>();
                // starting to generate the SE for each "base" group
                // ie for each group with algo steps.
                Set defaultAlgoProviders = Stream.of(Algo.values())
                        .filter(a -> a != Algo.StackedEnsemble)
                        .map(Algo::name)
                        .collect(Collectors.toSet());
                int[] baseAlgoGroups = Stream.of(modelingPlan)
                        .filter(sd -> defaultAlgoProviders.contains(sd.getName()))
                        .flatMapToInt(sd -> 
                                sd.getAlias() == StepDefinition.Alias.defaults ? IntStream.of(ModelingStep.ModelStep.DEFAULT_MODEL_GROUP)
                                : sd.getAlias() == StepDefinition.Alias.grids ? IntStream.of(ModelingStep.GridStep.DEFAULT_GRID_GROUP)
                                : sd.getAlias() == StepDefinition.Alias.all ? IntStream.of(ModelingStep.ModelStep.DEFAULT_MODEL_GROUP, ModelingStep.GridStep.DEFAULT_GRID_GROUP)
                                : sd.getSteps().stream().flatMapToInt(s -> s.getGroup() == StepDefinition.Step.DEFAULT_GROUP 
                                        ? IntStream.of(ModelingStep.ModelStep.DEFAULT_MODEL_GROUP, ModelingStep.GridStep.DEFAULT_GRID_GROUP)
                                        : IntStream.of(s.getGroup())))
                        .distinct().sorted().toArray();
                
                for (int group : baseAlgoGroups) {
                    defaultSeSteps.add(new BestOfFamilySEModelStep("best_of_family_" + group, group, aml()));
                    defaultSeSteps.add(new AllSEModelStep("all_" + group, group, aml()));  // groups <=0 are ignored.
                }
                defaults = defaultSeSteps.toArray(new ModelingStep[0]);

                // now all the additional SEs are available as optionals (usually requested by id).
                int maxBaseGroup = IntStream.of(baseAlgoGroups).max().orElse(0);
                List optionalSeSteps = new ArrayList<>();
                if (maxBaseGroup > 0) {
                    int optionalGroup = maxBaseGroup+1;
                    optionalSeSteps.add(new MonotonicSEModelStep("monotonic", optionalGroup, aml()));
                    optionalSeSteps.add(new BestOfFamilySEModelStep("best_of_family", optionalGroup, aml()));
                    optionalSeSteps.add(new AllSEModelStep("all", optionalGroup, aml()));
                    if (Algo.XGBoost.enabled()) {
                        optionalSeSteps.add(new BestOfFamilySEModelStep("best_of_family_xgboost", Metalearner.Algorithm.xgboost, optionalGroup, aml()));
                        optionalSeSteps.add(new AllSEModelStep("all_xgboost", Metalearner.Algorithm.xgboost, optionalGroup, aml()));
                    }
                    optionalSeSteps.add(new BestOfFamilySEModelStep("best_of_family_gbm", Metalearner.Algorithm.gbm, optionalGroup, aml()));
                    optionalSeSteps.add(new AllSEModelStep("all_gbm", Metalearner.Algorithm.gbm, optionalGroup, aml()));
                    optionalSeSteps.add(new BestOfFamilySEModelStep("best_of_family_xglm", optionalGroup, aml()) {
                        @Override
                        protected boolean hasDoppelganger(Key[] baseModelsKeys) {
                            return false;
                        }

                        @Override
                        protected void setMetalearnerParameters(StackedEnsembleParameters params) {
                            super.setMetalearnerParameters(params);
                            GLMModel.GLMParameters metalearnerParams = (GLMModel.GLMParameters) params._metalearner_parameters;
                            metalearnerParams._lambda_search = true;
                        }
                    });
                    optionalSeSteps.add(new AllSEModelStep("all_xglm", optionalGroup, aml()) {
                        @Override
                        protected boolean hasDoppelganger(Key[] baseModelsKeys) {
                            Set modelTypes = new HashSet<>();
                            for (Key key : baseModelsKeys) {
                                String modelType = getModelType(key);
                                if (modelTypes.contains(modelType)) return false;
                                modelTypes.add(modelType);
                            }
                            return true;
                        }

                        @Override
                        protected void setMetalearnerParameters(StackedEnsembleParameters params) {
                            super.setMetalearnerParameters(params);
                            GLMModel.GLMParameters metalearnerParams = (GLMModel.GLMParameters) params._metalearner_parameters;
                            metalearnerParams._lambda_search = true;
                        }
                    });
//                    optionalSeSteps.add(new BestNModelsSEModelStep("best_20", 20, optionalGroup, aml()));
                    int card = aml().getResponseColumn().cardinality();
                    int maxModels = card <= 2 ? 1_000 : Math.max(100, 1_000 / (card - 1));
                    optionalSeSteps.add(new BestNModelsSEModelStep("best_N", maxModels, optionalGroup, aml()));
                }
                optionals = optionalSeSteps.toArray(new ModelingStep[0]);
            }
        }

        public StackedEnsembleSteps(AutoML autoML) {
            super(autoML);
        }

        @Override
        public String getProvider() {
            return NAME;
        }

        @Override
        protected ModelingStep[] getDefaultModels() {
            return defaults;
        }

        @Override
        protected ModelingStep[] getOptionals() {
            return optionals;
        }
    }

    @Override
    public String getName() {
        return StackedEnsembleSteps.NAME;
    }

    @Override
    public StackedEnsembleSteps newInstance(AutoML aml) {
        return new StackedEnsembleSteps(aml);
    }

    @Override
    public StackedEnsembleParameters newDefaultParameters() {
        return new StackedEnsembleParameters();
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy