All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.h2o.automl.modeling.GLMStepsProvider Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package ai.h2o.automl.modeling;

import ai.h2o.automl.*;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.TargetEncoding;
import hex.Model;
import hex.glm.GLMModel;
import hex.glm.GLMModel.GLMParameters;


public class GLMStepsProvider
        implements ModelingStepsProvider
                 , ModelParametersProvider {

    public static class GLMSteps extends ModelingSteps {

        static final String NAME = Algo.GLM.name();
        
        static abstract class GLMModelStep extends ModelingStep.ModelStep {

            GLMModelStep(String id, AutoML autoML) {
                super(NAME, Algo.GLM, id, autoML);
            }

            @Override
            protected void setStoppingCriteria(Model.Parameters parms, Model.Parameters defaults) {
                // disabled as we're using lambda search
            }

            public GLMParameters prepareModelParameters() {
                GLMParameters params = new GLMParameters();
                params._lambda_search = true;
                params._family =
                        aml().getResponseColumn().isBinary() && !(aml().getResponseColumn().isNumeric()) ? GLMParameters.Family.binomial
                                : aml().getResponseColumn().isCategorical() ? GLMParameters.Family.multinomial
                                : GLMParameters.Family.gaussian;  // TODO: other continuous distributions!
                return params;
            }
            
            @Override
            protected PreprocessingConfig getPreprocessingConfig() {
                //GLM (the exception as usual) doesn't support targetencoding if CV is enabled
                // because it is initializing its lambdas + other params before CV (preventing changes in train frame during CV).
                PreprocessingConfig config = super.getPreprocessingConfig();
                config.put(TargetEncoding.CONFIG_PREPARE_CV_ONLY, aml().isCVEnabled()); 
                return config;
            }
        }


        private final ModelingStep[] defaults = new GLMModelStep[] {
                new GLMModelStep("def_1", aml()) {
                    @Override
                    public GLMParameters prepareModelParameters() {
                        GLMParameters params = super.prepareModelParameters();
                        params._alpha = new double[] {0.0, 0.2, 0.4, 0.6, 0.8, 1.0};
                        params._missing_values_handling = GLMParameters.MissingValuesHandling.MeanImputation;
                        return params;
                    }
                },
        };

        public GLMSteps(AutoML autoML) {
            super(autoML);
        }

        @Override
        public String getProvider() {
            return NAME;
        }

        @Override
        protected ModelingStep[] getDefaultModels() {
            return defaults;
        }
    }

    @Override
    public String getName() {
        return GLMSteps.NAME;
    }

    @Override
    public GLMSteps newInstance(AutoML aml) {
        return new GLMSteps(aml);
    }

    @Override
    public GLMParameters newDefaultParameters() {
        return new GLMParameters();
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy