All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.modelselection.ModelSelectionModel Maven / Gradle / Ivy

package hex.modelselection;

import hex.*;
import hex.deeplearning.DeepLearningModel;
import hex.glm.GLM;
import hex.glm.GLMModel;
import water.*;
import water.fvec.Frame;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.TwoDimTable;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static hex.glm.GLMModel.GLMParameters.Family.AUTO;
import static hex.modelselection.ModelSelectionUtils.*;

public class ModelSelectionModel extends Model {
    
    public ModelSelectionModel(Key selfKey, ModelSelectionParameters parms, ModelSelectionModelOutput output) {
        super(selfKey, parms, output);
    }

    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        assert domain == null;
        switch (_output.getModelCategory()) {
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            default:
                throw H2O.unimpl("Invalid ModelCategory " + _output.getModelCategory());
        }
    }

    @Override
    protected double[] score0(double[] data, double[] preds) {
        throw new UnsupportedOperationException("ModelSelection does not support scoring on data.  It only provide " +
                "information on predictor relevance");
    }

    @Override
    public Frame score(Frame fr, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
        throw new UnsupportedOperationException("AnovaGLM does not support scoring on data.  It only provide " +
                "information on predictor relevance");
    }

    @Override
    public Frame result() {
        return _output.generateResultFrame();
    }
    
    public static class ModelSelectionParameters extends Model.Parameters {
        public double[] _alpha;
        public boolean _standardize = true;
        public boolean _intercept = true;
        GLMModel.GLMParameters.Family _family = AUTO;
        public boolean _lambda_search;
        public GLMModel.GLMParameters.Link _link = GLMModel.GLMParameters.Link.family_default;
        public GLMModel.GLMParameters.Solver _solver = GLMModel.GLMParameters.Solver.IRLSM;
        public String[] _interactions=null;
        public Serializable _missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
        public boolean _compute_p_values = false;
        public boolean _remove_collinear_columns = false;
        public int _nfolds = 0; // disable cross-validation
        public Key _plug_values = null;
        public int _max_predictor_number = 1;
        public int _min_predictor_number = 1;
        public int _nparallelism = 0;
        public double _p_values_threshold = 0;
        public double _tweedie_variance_power;
        public double _tweedie_link_power;
        public Mode _mode = Mode.maxr;  // mode chosen to perform model selection
        public double _beta_epsilon = 1e-4;
        public double _objective_epsilon = -1;  // -1 to use default setting
        public double _gradient_epsilon = -1;   // -1 to use default setting
        public double _obj_reg = -1.0;
        public double[] _lambda = new double[]{0.0};
        public boolean _use_all_factor_levels = false;
        public boolean _build_glm_model = false;
        public GLMModel.GLMParameters.Influence _influence;  // if set to dfbetas will calculate the difference of betas obtained from including and excluding a data row
        public boolean _multinode_mode = false; // for maxrsweep only, if true will run on multiple nodes in cluster
        
        public enum Mode {
            allsubsets, // use combinatorial, exponential runtime
            maxr, // use sequential replacement but calls GLM to build all models, slow but can use cross-validation and validation dataset to build more robust results
            maxrsweep,  // perform incremental maxrsweep without using sweeping vectors, only on CPM.
            backward // use backward selection
        }
        @Override
        public String algoName() {
            return "ModelSelection";
        }

        @Override
        public String fullName() {
            return "Model Selection";
        }

        @Override
        public String javaName() {
            return ModelSelectionModel.class.getName();
        }

        @Override
        public long progressUnits() {
            return 1;
        }

        public GLMModel.GLMParameters.MissingValuesHandling missingValuesHandling() {
            if (_missing_values_handling instanceof GLMModel.GLMParameters.MissingValuesHandling)
                return (GLMModel.GLMParameters.MissingValuesHandling) _missing_values_handling;
            assert _missing_values_handling instanceof DeepLearningModel.DeepLearningParameters.MissingValuesHandling;
            switch ((DeepLearningModel.DeepLearningParameters.MissingValuesHandling) _missing_values_handling) {
                case MeanImputation:
                    return GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
                case Skip:
                    return GLMModel.GLMParameters.MissingValuesHandling.Skip;
                default:
                    throw new IllegalStateException("Unsupported missing values handling value: " + _missing_values_handling);
            }
        }

        public boolean imputeMissing() {
            return missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.MeanImputation ||
                    missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
        }

        public DataInfo.Imputer makeImputer() {
            if (missingValuesHandling() == GLMModel.GLMParameters.MissingValuesHandling.PlugValues) {
                if (_plug_values == null || _plug_values.get() == null) {
                    throw new IllegalStateException("Plug values frame needs to be specified when Missing Value " +
                            "Handling = PlugValues.");
                }
                return new GLM.PlugValuesImputer(_plug_values.get());
            } else { // mean/mode imputation and skip (even skip needs an imputer right now! PUBDEV-6809)
                return new DataInfo.MeanImputer();
            }
        }
    }
    
    public static class ModelSelectionModelOutput extends Model.Output {
        GLMModel.GLMParameters.Family _family;
        DataInfo _dinfo;
        String[][] _coefficient_names; // store for each predictor number, the best model predictors
        double[] _best_r2_values;  // store the best R2 values of the best models with fix number of predictors
        String[][] _predictors_added_per_step;
        String[][] _predictors_removed_per_step;
        public Key[] _best_model_ids;
        double[][] _coef_p_values;
        double[][] _coefficient_values;   // store best predictor subset coefficient values
        double[][] _coefficient_values_normalized;   // store best predictor subset coefficient values
        double[][] _z_values;
        public ModelSelectionParameters.Mode _mode;
        String[][] _best_predictors_subset; // predictor names for subset of each size
        
        public ModelSelectionModelOutput(hex.modelselection.ModelSelection b, DataInfo dinfo) {
            super(b, dinfo._adaptedFrame);
            _dinfo = dinfo;
        }

        public String[][] coefficientNames() {
            return _coefficient_names;
        }
        
        public double[][] beta() {
            int numModel = _best_model_ids.length;
            double[][] coeffs = new double[numModel][];
            for (int index=0; index < numModel; index++) {
                GLMModel oneModel = DKV.getGet(_best_model_ids[index]);
                coeffs[index] = oneModel._output.beta().clone();
            }
            return coeffs;
        }
        
        public double[][] getNormBeta() {
            int numModel = _best_model_ids.length;
            double[][] coeffs = new double[numModel][];
            for (int index=0; index < numModel; index++) {
                GLMModel oneModel = DKV.getGet(_best_model_ids[index]);
                coeffs[index] = oneModel._output.getNormBeta().clone();
            }
            return coeffs;
        }
        
        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.Regression;
        }
        
        private Frame generateResultFrame() {
            int numRows = _coefficient_names.length;
            String[] modelNames = new String[numRows];
            String[] coefNames = new String[numRows];
            String[] predNames = new String[numRows];
            String[] modelIds = _best_model_ids == null ? null : Stream.of(_best_model_ids).map(Key::toString).toArray(String[]::new);
            String[] zvalues = new String[numRows];
            String[] pvalues = new String[numRows];
            String[] predAddedNames = new String[numRows];
            String[] predRemovedNames = new String[numRows];
            boolean backwardMode = _z_values!=null;
            // generate model names and predictor names
            for (int index=0; index < numRows; index++) {
                int numPred = _best_predictors_subset[index].length;
                modelNames[index] = "best "+numPred+" predictors model";
                coefNames[index] = backwardMode ? String.join(", ", _coefficient_names[index])
                        :String.join(", ", _coefficient_names[index]);
                predAddedNames[index] = backwardMode ? "" : String.join(", ", _predictors_added_per_step[index]);
                predRemovedNames[index] = _predictors_removed_per_step[index] == null ? "" : 
                        String.join(", ", _predictors_removed_per_step[index]);
                predNames[index] = String.join(", ", _best_predictors_subset[index]);
                if (backwardMode) {
                    zvalues[index] = joinDouble(_z_values[index]);
                    pvalues[index] = joinDouble(_coef_p_values[index]);
                }
            }
            // generate vectors before forming frame
            Vec.VectorGroup vg = Vec.VectorGroup.VG_LEN1;
            Vec modNames = Vec.makeVec(modelNames, vg.addVec());
            Vec modelIDV = modelIds == null ? null : Vec.makeVec(modelIds, vg.addVec());
            Vec r2=null;
            Vec zval=null;
            Vec pval=null;
            Vec predAdded=null;
            Vec predRemoved;
            if (backwardMode) {
                zval = Vec.makeVec(zvalues, vg.addVec());
                pval = Vec.makeVec(pvalues, vg.addVec());
            } else {
                r2 = Vec.makeVec(_best_r2_values, vg.addVec());
                predAdded = Vec.makeVec(predAddedNames, vg.addVec());
            }
            predRemoved = Vec.makeVec(predRemovedNames, vg.addVec());
            Vec coefN = Vec.makeVec(coefNames, vg.addVec());
            Vec predN = Vec.makeVec(predNames, vg.addVec());
            
            if (backwardMode) {
                String[] colNames = new String[]{"model_name", "model_id", "z_values", "p_values",
                        "coefficient_names", "predictor_names", "predictors_removed"};
                return new Frame(Key.make(), colNames, new Vec[]{modNames, modelIDV, zval, pval, coefN, predN, predRemoved});
            } else {
                if (modelIds == null) {
                    String[] colNames = new String[]{"model_name", "best_r2_value", "coefficient_names", "predictor_names",
                            "predictors_removed", "predictors_added"};
                    return new Frame(Key.make(), colNames, new Vec[]{modNames, r2, coefN, predN, predRemoved, predAdded});
                } else {
                    String[] colNames = new String[]{"model_name", "model_id", "best_r2_value", "coefficient_names", "predictor_names",
                            "predictors_removed", "predictors_added"};
                    return new Frame(Key.make(), colNames, new Vec[]{modNames, modelIDV, r2, coefN, predN, predRemoved, predAdded});
                }
            }
        }
        
        public void shrinkArrays(int numModelsBuilt) {
            if (_coefficient_names.length > numModelsBuilt) {
                _coefficient_names = shrinkStringArray(_coefficient_names, numModelsBuilt);
                _best_predictors_subset = shrinkStringArray(_best_predictors_subset, numModelsBuilt);
                _coefficient_names = shrinkStringArray(_coefficient_names, numModelsBuilt);
                _z_values = shrinkDoubleArray(_z_values, numModelsBuilt);
                _coef_p_values = shrinkDoubleArray(_coef_p_values, numModelsBuilt);
                _best_model_ids = shrinkKeyArray(_best_model_ids, numModelsBuilt);
                _predictors_removed_per_step = shrinkStringArray(_predictors_removed_per_step, numModelsBuilt);
            }
        }
        
        public void generateSummary() {
            int numModels = _best_r2_values.length;
            String[] names = new String[]{"best_r2_value", "coefficient_names", "predictor_names", 
                    "predictors_removed", "predictors_added"};
            String[] types = new String[]{"double", "String", "String", "String", "String"};
            String[] formats = new String[]{"%d", "%s", "%s", "%s", "%s"};
            String[] rowHeaders = new String[numModels];
            for (int index=1; index<=numModels; index++)
                rowHeaders[index-1] = "with "+_best_predictors_subset[index-1].length+" predictors";
            
            _model_summary = new TwoDimTable("ModelSelection Model Summary", "summary", 
                    rowHeaders, names, types, formats, "");
            for (int rIndex=0; rIndex < numModels; rIndex++) {
                int colInd = 0;
                _model_summary.set(rIndex, colInd++, _best_r2_values[rIndex]);
                _model_summary.set(rIndex, colInd++, String.join(", ", _coefficient_names[rIndex]));
                _model_summary.set(rIndex, colInd++, String.join(", ", _best_predictors_subset[rIndex]));
                if (_predictors_removed_per_step[rIndex] != null)
                    _model_summary.set(rIndex, colInd++, String.join(", ", _predictors_removed_per_step[rIndex]));
                else
                    _model_summary.set(rIndex, colInd++, "");
                _model_summary.set(rIndex, colInd++, String.join(", ", _predictors_added_per_step[rIndex]));
            }
        }
        
        // for backward model only
        public void generateSummary(int numModels) {
            String[] names = new String[]{"coefficient_names", "predictor_names", "z_values", "p_values", "predictors_removed"};
            String[] types = new String[]{"string", "string", "string", "string", "string"};
            String[] formats = new String[]{"%s", "%s", "%s", "%s", "%s"};
            String[] rowHeaders = new String[numModels];
            for (int index=0; index < numModels; index++) {
                rowHeaders[index] = "with "+_best_predictors_subset[index].length+" predictors";
            }
            _model_summary = new TwoDimTable("ModelSelection Model Summary", "summary", 
                    rowHeaders, names, types, formats, "");
            for (int rIndex=0; rIndex < numModels; rIndex++) {
                int colInd = 0;
                String coeffNames = String.join(", ", _coefficient_names[rIndex]);
                String predNames = String.join(", ", _best_predictors_subset[rIndex]);
                String pValue = joinDouble(_coef_p_values[rIndex]);
                String zValue = joinDouble(_z_values[rIndex]);
                _model_summary.set(rIndex, colInd++, coeffNames);
                _model_summary.set(rIndex, colInd++, predNames);
                _model_summary.set(rIndex, colInd++, zValue);
                _model_summary.set(rIndex, colInd++, pValue);
                _model_summary.set(rIndex, colInd, _predictors_removed_per_step[rIndex][0]);
            }
        }
        
        void updateBestModels(GLMModel bestModel, int index) {
            _best_model_ids[index] = bestModel.getKey();
            if (bestModel._parms._nfolds > 0) {
                int r2Index = Arrays.asList(bestModel._output._cross_validation_metrics_summary.getRowHeaders()).indexOf("r2");
                Float tempR2 = (Float) bestModel._output._cross_validation_metrics_summary.get(r2Index, 0);
                _best_r2_values[index] = tempR2.doubleValue();
            } else {
                _best_r2_values[index] = bestModel.r2();
            }
            extractCoeffs(bestModel, index);
            updateAddedRemovedPredictors(index);
        }

        void extractCoeffs(GLMModel model, int index) {
            _coefficient_names[index] = model._output.coefficientNames().clone(); // all coefficients
            ArrayList coeffNames = new ArrayList<>(Arrays.asList(model._output.coefficientNames()));
            _coefficient_names[index] = coeffNames.toArray(new String[0]); // without intercept
            List predNames = Stream.of(model.names()).collect(Collectors.toList());
            predNames.remove(model._parms._response_column);
            _best_predictors_subset[index] = predNames.stream().toArray(String[]::new);
        }

        void updateBestModels(String[] predictorNames, List allCoefNames, int index, boolean hasIntercept, 
                              int actualCPMSize, int[] predsubset, double[][] lastCPM, double r2Scale, 
                              CoeffNormalization coeffN, int[][] pred2CPMIndex, DataInfo dinfo) {
            int lastCPMIndex = actualCPMSize-1;
            if (lastCPM[lastCPMIndex][lastCPMIndex] == Double.MAX_VALUE)
                _best_r2_values[index] = -1;
            else
                _best_r2_values[index] = 1-r2Scale * lastCPM[lastCPMIndex][lastCPMIndex];
            extractCoeffs(predictorNames, allCoefNames, lastCPM, index, hasIntercept, actualCPMSize, predsubset, coeffN,
                    pred2CPMIndex, dinfo);
            updateAddedRemovedPredictors(index);
        }

        void extractCoeffs(String[] predNames, List allCoefNames, double[][] cpm, int index, boolean hasIntercept,
                           int actualCPMSize, int[] predSubset, CoeffNormalization coeffN, int[][] predsubset2CPMIndices,
                           DataInfo dinfo) {
            _best_predictors_subset[index] = extractPredsFromPredIndices(predNames, predSubset);
            _coefficient_names[index] = extractCoefsFromPred(allCoefNames, hasIntercept, dinfo, predSubset);
            extractCoefsValues(cpm, _coefficient_names[index].length, hasIntercept, actualCPMSize, coeffN, index, 
                    predSubset, predsubset2CPMIndices);
        }
        
        public void extractCoefsValues(double[][] cpm, int coefValLen, boolean hasIntercept, int actualCPMSize, 
                                           CoeffNormalization coeffN, int predIndex, int[] predSubset, int[][] pred2CPMIndices) {
            _coefficient_values[predIndex] = new double[coefValLen];
            _coefficient_values_normalized[predIndex] = new double[coefValLen];
            int lastCPMIndex = actualCPMSize-1;
            int cpmIndexOffset = hasIntercept?1:0;
            boolean standardize = coeffN._standardize;
            double[] sigmaOrOneOSigma = coeffN._sigmaOrOneOSigma;
            double[] meanOverSigma = coeffN._meanOverSigma;
            double sumBetaMeanOverSigma = 0;
            int numIndexStart = _dinfo._cats;
            int offset =0;
            int predSubsetLen = predSubset.length;
            int cpmInd, coefIndex;
            for (int pIndex = 0; pIndex < predSubsetLen; pIndex++) {
                int predictor = predSubset[pIndex];
                if (predictor >= numIndexStart) {   // numerical columns
                    coefIndex = pIndex+offset;
                    cpmInd = cpmIndexOffset+pIndex;
                    if (standardize) {
                        _coefficient_values[predIndex][coefIndex] = cpm[cpmInd][lastCPMIndex]*sigmaOrOneOSigma[predictor-numIndexStart];
                        _coefficient_values_normalized[predIndex][coefIndex] = cpm[cpmInd][lastCPMIndex];
                    } else {
                        _coefficient_values[predIndex][coefIndex] = cpm[cpmInd][lastCPMIndex];
                        _coefficient_values_normalized[predIndex][coefIndex] = cpm[cpmInd][lastCPMIndex]*sigmaOrOneOSigma[predictor-numIndexStart];
                    }
                    sumBetaMeanOverSigma += _coefficient_values_normalized[predIndex][coefIndex]*meanOverSigma[predictor-numIndexStart];
                } else {    // categorical columns
                    int cpmLen = pred2CPMIndices[predictor].length; // indices of cpm to grab for coefficients info
                    for (int cpmIndex = 0; cpmIndex < cpmLen; cpmIndex++) {
                        coefIndex = offset + cpmIndex + pIndex;
                        cpmInd = cpmIndexOffset + cpmIndex + pIndex;
                        _coefficient_values[predIndex][coefIndex] = cpm[cpmInd][lastCPMIndex];
                        _coefficient_values_normalized[predIndex][coefIndex] = cpm[cpmInd][lastCPMIndex];
                    }
                    offset += cpmLen-1;
                    cpmIndexOffset += cpmLen-1;
                }
            }

            if (hasIntercept) { // extract intercept value
                int lastCoefInd = _coefficient_values[predIndex].length-1;
                if (coeffN._standardize) {
                    _coefficient_values_normalized[predIndex][lastCoefInd] = cpm[0][lastCPMIndex];
                    _coefficient_values[predIndex][lastCoefInd] = cpm[0][lastCPMIndex]-sumBetaMeanOverSigma;
                } else {
                    _coefficient_values_normalized[predIndex][lastCoefInd] = cpm[0][lastCPMIndex]+sumBetaMeanOverSigma;
                    _coefficient_values[predIndex][lastCoefInd] = cpm[0][lastCPMIndex];
                }
            }
        }
        
        public static String[] extractCoefsFromPred(List allCoefList, boolean hasIntercept,
                                                    DataInfo dinfo, int[] predSubset) {
            List coefNames = new ArrayList<>();
            int numPred = predSubset.length;
            int predIndex;
            int numCats = dinfo._cats;
            int catOffsets = dinfo._catOffsets[dinfo._catOffsets.length-1];
            int numCatLevel;
            for (int index=0; index coeffs = IntStream.range(0, numCatLevel).mapToObj(x -> allCoefList.get(x+dinfo._catOffsets[predictorInd])).collect(Collectors.toList());
                    coefNames.addAll(coeffs);
                } else {    // numerical columns
                    coefNames.add(allCoefList.get(predIndex+catOffsets-numCats));
                }
            }
            if (hasIntercept)
                coefNames.add("Intercept");
            return coefNames.toArray(new String[0]);
        }
        public static String[] extractPredsFromPredIndices(String[] allPreds, int[] predSubset) {
            int numPreds = predSubset.length;
            String[] predSubsetNames = new String[numPreds];
            for (int index=0; index newSet = Stream.of(_coefficient_names[index]).collect(Collectors.toList());
            if (index > 0) {
                final List oldSet = Stream.of(_coefficient_names[index - 1]).collect(Collectors.toList());
                List predDeleted = oldSet.stream().filter(x -> (!newSet.contains(x) && 
                        !"Intercept".equals(x))).collect(Collectors.toList());
                _predictors_removed_per_step[index] = predDeleted == null || predDeleted.size()==0 ? new String[]{""} :
                        predDeleted.toArray(new String[predDeleted.size()]);
                if (!ModelSelectionParameters.Mode.backward.equals(_mode)) {
                    List predAdded = newSet.stream().filter(x -> (!oldSet.contains(x) &&
                            !"Intercept".equals(x))).collect(Collectors.toList());
                    _predictors_added_per_step[index] = predAdded.toArray(new String[0]);
                }
                return;
            } else if (!ModelSelectionParameters.Mode.backward.equals(_mode)) {
                _predictors_added_per_step[index] = new String[]{_coefficient_names[index][0]};
                _predictors_removed_per_step[index] = new String[]{""};
                return;
            }
            _predictors_removed_per_step[index] = new String[]{""};
            _predictors_added_per_step[index] = new String[]{""};
        }

        /**
         * Method to remove redundant predictors at the beginning of backward method.
         */
        void resetCoeffs(GLMModel model, List predNames, List numPredNames, List catPredNames) {
            final String[] coeffName = model._output.coefficientNames();
            int[] idxs = model._output.bestSubmodel().idxs;
            if (idxs == null) // no redundant predictors
                return;
            List coeffNames = Arrays.stream(idxs).mapToObj(x -> coeffName[x]).collect(Collectors.toList());
            resetAllPreds(predNames, catPredNames, numPredNames, model, coeffNames); // remove redundant preds
        }
        
        void resetAllPreds(List predNames, List catPredNames, List numPredNames, 
                           GLMModel model, List coeffNames) {
            if (model._output.bestSubmodel().idxs.length == model.coefficients().size())  // no redundant predictors
                return;
            resetNumPredNames(numPredNames, coeffNames);
            resetCatPredNames(model.dinfo(), model._output.bestSubmodel().idxs, catPredNames);
            if (predNames.size() > (numPredNames.size() + catPredNames.size())) {
                predNames.clear();
                predNames.addAll(catPredNames);
                predNames.addAll(numPredNames);
            }
        }
        
        public void resetNumPredNames(List numPredNames, List coeffNames) {
            List newNumPredNames = numPredNames.stream().filter(x -> coeffNames.contains(x)).collect(Collectors.toList());
            numPredNames.clear();
            numPredNames.addAll(newNumPredNames);
        }
        
        public void resetCatPredNames(DataInfo dinfo, int[] idxs, List catPredNames) {
            List newCatPredNames = new ArrayList<>();
            List idxsList = Arrays.stream(idxs).boxed().collect(Collectors.toList());
            int[] catOffset = dinfo._catOffsets;
            int catIndex = catOffset.length;
            int maxCatOffset = catOffset[catIndex-1];
            for (int index=1; index currCatList = IntStream.range(catOffset[offsetedIndex], catOffset[index]).boxed().collect(Collectors.toList());
                if (currCatList.stream().filter(x -> idxsList.contains(x)).count() > 0 && currCatList.get(currCatList.size()-1) < maxCatOffset) {
                    newCatPredNames.add(catPredNames.get(offsetedIndex));
                }
            }
            if (newCatPredNames.size() < catPredNames.size()) {
                catPredNames.clear();
                catPredNames.addAll(newCatPredNames);
            }
        }

        /***
         * Eliminate predictors with lowest z-value (z-score) magnitude as described in III of 
         * ModelSelectionTutorial.pdf in https://h2oai.atlassian.net/browse/PUBDEV-8428
         */
        void extractPredictors4NextModel(GLMModel model, int index, List predNames, List numPredNames, 
                                         List catPredNames) {
            boolean firstRun = (index+1) == predNames.size();
            List oldPredNames = firstRun ? new ArrayList<>(predNames) : null;
            extractCoeffs(model, index);
            int predIndex2Remove = findMinZValue(model, numPredNames, catPredNames, predNames);
            String pred2Remove = predNames.get(predIndex2Remove);
            if (firstRun) // remove redundant predictors if present
                resetCoeffs(model, predNames, numPredNames, catPredNames);
            List redundantPred = firstRun ? 
                    oldPredNames.stream().filter(x -> !predNames.contains(x)).collect(Collectors.toList()) : null;
            _best_model_ids[index] = model.getKey();

            if (redundantPred != null && redundantPred.size() > 0) {
                redundantPred = redundantPred.stream().map(x -> x+"(redundant_predictor)").collect(Collectors.toList());
                redundantPred.add(pred2Remove);
                _predictors_removed_per_step[index] = redundantPred.stream().toArray(String[]::new);
            } else {
                _predictors_removed_per_step[index] = new String[]{pred2Remove};
            }

            _z_values[index] = model._output.zValues().clone();
            _coef_p_values[index] = model._output.pValues().clone();
            predNames.remove(pred2Remove);
            if (catPredNames.contains(pred2Remove))
                catPredNames.remove(pred2Remove);
            else
                numPredNames.remove(pred2Remove);      
        }
    }

    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        super.remove_impl(fs, cascade);
        if (cascade && _output._best_model_ids != null && _output._best_model_ids.length > 0) {
            for (Key oneModelID : _output._best_model_ids)
                if (null != oneModelID)
                    Keyed.remove(oneModelID, fs, cascade);   // remove model key
        }
        return fs;
    }

    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        if (_output._best_model_ids != null && _output._best_model_ids.length > 0) {
            for (Key oneModelID : _output._best_model_ids)
                if (null != oneModelID)
                    ab.putKey(oneModelID);  // add GLM model key
        }
        return super.writeAll_impl(ab);
    }

    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        if (_output._best_model_ids != null && _output._best_model_ids.length > 0) {
            for (Key oneModelID : _output._best_model_ids) {
                if (null != oneModelID)
                    ab.getKey(oneModelID, fs);  // add GLM model key
            }
        }
        return super.readAll_impl(ab, fs);
    }

    public HashMap[] coefficients() {
        return coefficients(false);
    }
    
    public HashMap[] coefficients(boolean standardize) {
        int numModel = _output._best_model_ids.length;
        HashMap[] coeffs = new HashMap[numModel];
        for (int index=0; index < numModel; index++) {
            coeffs[index] = coefficients(index+1, standardize);
        }
        return coeffs;
    }

    public HashMap coefficients(int predictorSize) {
        return coefficients(predictorSize, false);
    }
    
    public HashMap coefficients(int predictorSize, boolean standardize) {
        int numModel = _output._best_model_ids.length;
        if (predictorSize <= 0 || predictorSize > numModel)
            throw new IllegalArgumentException("predictorSize must be between 1 and maximum size of predictor subset" +
                    " size.");
        GLMModel oneModel = DKV.getGet(_output._best_model_ids[predictorSize-1]);
        return oneModel.coefficients(standardize);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy