All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.rulefit.RuleFit Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.rulefit;

import hex.*;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.SharedTree;
import hex.tree.SharedTreeModel;
import hex.tree.TreeStats;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;

import org.apache.log4j.Logger;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

import java.util.*;
import java.util.stream.Collectors;

import static hex.genmodel.utils.ArrayUtils.difference;
import static hex.genmodel.utils.ArrayUtils.signum;
import static hex.rulefit.RuleFitUtils.*;
import static hex.util.DistributionUtils.distributionToFamily;


/**
 * Rule Fit
* http://statweb.stanford.edu/~jhf/ftp/RuleFit.pdf * https://github.com/h2oai/h2o-tutorials/blob/8df6b492afa172095e2595922f0b67f8d715d1e0/best-practices/explainable-models/rulefit.py */ public class RuleFit extends ModelBuilder { private static final Logger LOG = Logger.getLogger(RuleFit.class); protected static final long WORK_TOTAL = 1000000; private SharedTreeModel.SharedTreeParameters treeParameters = null; private GLMModel.GLMParameters glmParameters = null; @Override public ModelCategory[] can_build() { return new ModelCategory[]{ ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial }; } @Override public boolean isSupervised() { return true; } /** * Start the KMeans training Job on an F/J thread. */ @Override protected RuleFitDriver trainModelImpl() { return new RuleFitDriver(); } // Called from an http request public RuleFit(RuleFitModel.RuleFitParameters parms) { super(parms); init(false); } public RuleFit(boolean startup_once) { super(new RuleFitModel.RuleFitParameters(), startup_once); } @Override public void init(boolean expensive) { super.init(expensive); if (expensive) { _parms.validate(this); if (_parms._fold_column != null) { _train.remove(_parms._fold_column); } if (_parms._algorithm == RuleFitModel.Algorithm.AUTO) { _parms._algorithm = RuleFitModel.Algorithm.DRF; } initTreeParameters(); initGLMParameters(); ignoreBadColumns(separateFeatureVecs(), true); } // if (_train == null) return; // if (expensive && error_count() == 0) checkMemoryFootPrint(); } private void initTreeParameters() { if (_parms._algorithm == RuleFitModel.Algorithm.GBM) { treeParameters = new GBMModel.GBMParameters(); } else if (_parms._algorithm == RuleFitModel.Algorithm.DRF) { treeParameters = new DRFModel.DRFParameters(); } else { throw new RuntimeException("Unsupported algorithm for tree building: " + _parms._algorithm); } treeParameters._response_column = _parms._response_column; treeParameters._train = _parms._train; treeParameters._ignored_columns = _parms._ignored_columns; treeParameters._seed = _parms._seed; treeParameters._weights_column = _parms._weights_column; treeParameters._distribution = _parms._distribution; treeParameters._ntrees = _parms._rule_generation_ntrees; treeParameters._max_categorical_levels = _parms._max_categorical_levels; treeParameters._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.EnumLimited; } private void initGLMParameters() { if (_parms._distribution == DistributionFamily.AUTO) { if (_nclass < 2) { // regression glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian); } else if (_nclass == 2) { // binomial classification glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial); } else { // multinomial classification glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial); } } else { glmParameters = new GLMModel.GLMParameters(distributionToFamily(_parms._distribution)); } if (RuleFitModel.ModelType.RULES_AND_LINEAR.equals(_parms._model_type) && _parms._ignored_columns != null) { glmParameters._ignored_columns = _parms._ignored_columns; } glmParameters._response_column = "linear." + _parms._response_column; glmParameters._seed = _parms._seed; // alpha ignored - set to 1 by rulefit (Lasso) glmParameters._alpha = new double[]{1}; if (_parms._weights_column != null) { glmParameters._weights_column = "linear." + _parms._weights_column; } glmParameters._auc_type = _parms._auc_type; if (_parms._lambda != null) { glmParameters._lambda = _parms._lambda; } glmParameters._ignore_const_cols = false; } private final class RuleFitDriver extends Driver { // Main worker thread @Override public void computeImpl() { String[] dataFromRulesCodes = null; RuleFitModel model = null; GLMModel glmModel; List rulesList; RuleEnsemble ruleEnsemble = null; int ntrees = 0; TreeStats overallTreeStats = new TreeStats(); String[] classNames = null; init(true); if (error_count() > 0) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(RuleFit.this); try { // linearTrain = frame to be used as _train for GLM in 2., will be filled in 1. Frame linearTrain = new Frame(Key.make("paths_frame" + _result)); Frame linearValid = (_valid != null ? new Frame(Key.make("valid_paths_frame" + _result)) : null); // store train frame without bad columns to pass it to tree model builders Frame trainAdapted = new Frame(_train); // 1. Rule generation // get paths from tree models int[] depths = range(_parms._min_rule_length, _parms._max_rule_length); // prepare rules if (RuleFitModel.ModelType.RULES_AND_LINEAR.equals(_parms._model_type) || RuleFitModel.ModelType.RULES.equals(_parms._model_type)) { DKV.put(trainAdapted._key, trainAdapted); treeParameters._train = trainAdapted._key; long startAllTreesTime = System.nanoTime(); SharedTree[] builders = ModelBuilderHelper.trainModelsParallel( makeTreeModelBuilders(_parms._algorithm, depths), nTreeEnsemblesInParallel(depths.length)); rulesList = new ArrayList<>(); for (int modelId = 0; modelId < builders.length; modelId++) { long startModelTime = System.nanoTime(); SharedTreeModel treeModel = builders[modelId].get(); long endModelTime = System.nanoTime() - startModelTime; LOG.info("Tree model n." + modelId + " trained in " + ((double)endModelTime) / 1E9 + "s."); rulesList.addAll(Rule.extractRulesListFromModel(treeModel, modelId, nclasses())); overallTreeStats.mergeWith(treeModel._output._treeStats); ntrees += treeModel._output._ntrees; if (classNames == null) { classNames = treeModel._output.classNames(); } treeModel.delete(); } long endAllTreesTime = System.nanoTime() - startAllTreesTime; LOG.info("All tree models trained in " + ((double)endAllTreesTime) / 1E9 + "s."); LOG.info("Extracting rules from trees..."); ruleEnsemble = new RuleEnsemble(rulesList.toArray(new Rule[] {})); linearTrain.add(ruleEnsemble.createGLMTrainFrame(_train, depths.length, treeParameters._ntrees, classNames, _parms._weights_column, true)); if (_valid != null) linearValid.add(ruleEnsemble.createGLMTrainFrame(_valid, depths.length, treeParameters._ntrees, classNames, _parms._weights_column, false)); } // prepare linear terms if (RuleFitModel.ModelType.RULES_AND_LINEAR.equals(_parms._model_type) || RuleFitModel.ModelType.LINEAR.equals(_parms._model_type)) { String[] names = _train._names; linearTrain.add(RuleFitUtils.getLinearNames(names.length, names), _train.vecs(names)); if (_valid != null) linearValid.add(RuleFitUtils.getLinearNames(names.length, names), _valid.vecs(names)); } else { linearTrain.add(glmParameters._response_column, _train.vec(_parms._response_column)); if (_valid != null) linearValid.add(glmParameters._response_column, _valid.vec(_parms._response_column)); if (_parms._weights_column != null) { linearTrain.add(glmParameters._weights_column, _train.vec(_parms._weights_column)); if (_valid != null) linearValid.add(glmParameters._weights_column, _valid.vec(_parms._weights_column)); } } dataFromRulesCodes = linearTrain.names(); DKV.put(linearTrain); if (_valid != null) { DKV.put(linearValid); glmParameters._valid = linearValid._key; } // 2. Sparse linear model with Lasso glmParameters._train = linearTrain._key; if (_parms._max_num_rules > 0) { glmParameters._max_active_predictors = _parms._max_num_rules + 1; if (_parms._distribution != DistributionFamily.multinomial) { glmParameters._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT; } } else { if (glmParameters._lambda != null) glmParameters._lambda = getOptimalLambda(); } LOG.info("Training GLM..."); long startGLMTime = System.nanoTime(); GLM job = new GLM(glmParameters); glmModel = job.trainModel().get(); long endGLMTime = System.nanoTime() - startGLMTime; LOG.info("GLM trained in " + ((double)endGLMTime) / 1E9 + "s."); DKV.put(glmModel); model = new RuleFitModel(dest(), _parms, new RuleFitModel.RuleFitOutput(RuleFit.this), glmModel, ruleEnsemble); model._output.glmModelKey = glmModel._key; model._output._linear_names = linearTrain.names(); DKV.remove(linearTrain._key); if (linearValid != null) DKV.remove(linearValid._key); DKV.remove(trainAdapted._key); // 3. Step 3 (optional): Feature importance model._output._intercept = getIntercept(glmModel); // TODO: add here coverage_count and coverage percent model._output._rule_importance = convertRulesToTable(sortRules(deduplicateRules(RuleFitUtils.getRules(glmModel.coefficients(), ruleEnsemble, model._output.classNames(), nclasses()), _parms._remove_duplicates)), isClassifier() && nclasses() > 2, false); model._output._model_summary = generateSummary(glmModel, ruleEnsemble != null ? ruleEnsemble.size() : 0, overallTreeStats, ntrees); model._output._dataFromRulesCodes = dataFromRulesCodes; fillModelMetrics(model, glmModel); model.delete_and_lock(_job); model.update(_job); } finally { if (model != null) model.unlock(_job); } } void fillModelMetrics(RuleFitModel model, GLMModel glmModel) { model._output._validation_metrics = glmModel._output._validation_metrics; model._output._training_metrics = glmModel._output._training_metrics; model._output._cross_validation_metrics = glmModel._output._cross_validation_metrics; model._output._cross_validation_metrics_summary = glmModel._output._cross_validation_metrics_summary; Frame inputTrain = model._parms._train.get(); for (Key modelMetricsKey : glmModel._output.getModelMetrics()) { model.addModelMetrics(modelMetricsKey.get().deepCloneWithDifferentModelAndFrame(model, inputTrain)); } } int[] range(int min, int max) { int[] array = new int[max - min + 1]; for (int i = min, j = 0; i <= max; i++, j++) { array[j] = i; } return array; } SharedTree makeTreeModelBuilder(RuleFitModel.Algorithm algorithm, int maxDepth) { SharedTreeModel.SharedTreeParameters p = (SharedTreeModel.SharedTreeParameters) treeParameters.clone(); p._max_depth = maxDepth; final SharedTree builder; if (algorithm.equals(RuleFitModel.Algorithm.DRF)) { builder = new DRF((DRFModel.DRFParameters) p); } else if (algorithm.equals(RuleFitModel.Algorithm.GBM)) { builder = new GBM((GBMModel.GBMParameters) p); } else { // TODO XGB throw new RuntimeException("Unsupported algorithm for tree building: " + _parms._algorithm); } return builder; } SharedTree[] makeTreeModelBuilders(RuleFitModel.Algorithm algorithm, int[] depths) { SharedTree[] builders = new SharedTree[depths.length]; for (int i = 0; i < depths.length; i++) { builders[i] = makeTreeModelBuilder(algorithm, depths[i]); } return builders; } double[] getOptimalLambda() { glmParameters._lambda_search = true; GLM job = new GLM(glmParameters); GLMModel lambdaModel = job.trainModel().get(); glmParameters._lambda_search = false; GLMModel.RegularizationPath regularizationPath = lambdaModel.getRegularizationPath(); double[] deviance = regularizationPath._explained_deviance_train; double[] lambdas = regularizationPath._lambdas; int bestLambdaIndex; if (deviance.length < 5) { bestLambdaIndex = deviance.length - 1; } else { bestLambdaIndex = getBestLambdaIndex(deviance); if (bestLambdaIndex >= lambdas.length) { bestLambdaIndex = getBestLambdaIndexCornerCase(deviance, lambdas); } } lambdaModel.remove(); return new double[]{lambdas[bestLambdaIndex]}; } int getBestLambdaIndex(double[] deviance) { int bestLambdaIndex = deviance.length - 1; if (deviance.length >= 5) { double[] array = difference(signum(difference(difference(deviance)))); for (int i = 0; i < array.length; i++) { if (array[i] != 0 && i > 0) { bestLambdaIndex = 3 * i; break; } } } return bestLambdaIndex; } int getBestLambdaIndexCornerCase(double[] deviance, double[] lambdas) { double[] leftUpPoint = new double[] {deviance[0], lambdas[0]}; double[] rightLowPoint = new double[] {deviance[deviance.length - 1], lambdas[lambdas.length - 1]}; double[] leftActualPoint = new double[2]; double[] rightActualPoint = new double[2]; double leftVolume, rightVolume; int leftActualId = 0; int rightActualId = deviance.length - 1; while (leftActualId < deviance.length && rightActualId < deviance.length) { if (leftActualId >= rightActualId) // volumes overlap break; // leftVolume leftActualPoint[0] = deviance[leftActualId]; leftActualPoint[1] = lambdas[leftActualId]; leftVolume = (leftUpPoint[1] - leftActualPoint[1]) * (leftActualPoint[0] - leftUpPoint[0]); // rightVolume rightActualPoint[0] = deviance[rightActualId]; rightActualPoint[1] = lambdas[rightActualId]; rightVolume = (rightActualPoint[1] - rightLowPoint[1]) * (rightLowPoint[0] - rightActualPoint[0]); if (Math.abs(leftVolume) > Math.abs(rightVolume)) { rightActualId--; // add point to rightvolume } else { leftActualId++; // add point to leftvolume } } return rightActualId; } double[] getIntercept(GLMModel glmModel) { HashMap glmCoefficients = glmModel.coefficients(); double[] intercept = nclasses() > 2 ? new double[nclasses()] : new double[1]; int i = 0; for (Map.Entry coefficient : glmCoefficients.entrySet()) { if ("Intercept".equals(coefficient.getKey()) || coefficient.getKey().contains("Intercept_")) { intercept[i] = coefficient.getValue(); i++; } } return intercept; } } protected int nTreeEnsemblesInParallel(int numDepths) { if (_parms._algorithm == RuleFitModel.Algorithm.GBM) { return nModelsInParallel(numDepths, 2); } else { return nModelsInParallel(numDepths, 1); } } TwoDimTable generateSummary(GLMModel glmModel, int ruleEnsembleSize, TreeStats overallTreeStats, int ntrees) { List colHeaders = new ArrayList<>(); List colTypes = new ArrayList<>(); List colFormats = new ArrayList<>(); TwoDimTable glmModelSummary = glmModel._output._model_summary; String[] glmColHeaders = glmModelSummary.getColHeaders(); String[] glmColTypes = glmModelSummary.getColTypes(); String[] glmColFormats = glmModelSummary.getColFormats(); // linear model info for (int i = 0; i < glmModelSummary.getColDim(); i++) { if (!"Training Frame".equals(glmColHeaders[i])) { colHeaders.add(glmColHeaders[i]); colTypes.add(glmColTypes[i]); colFormats.add(glmColFormats[i]); } } // rule ensemble info colHeaders.add("Rule Ensemble Size"); colTypes.add("long"); colFormats.add("%d"); // trees info colHeaders.add("Number of Trees"); colTypes.add("long"); colFormats.add("%d"); colHeaders.add("Number of Internal Trees"); colTypes.add("long"); colFormats.add("%d"); colHeaders.add("Min. Depth"); colTypes.add("long"); colFormats.add("%d"); colHeaders.add("Max. Depth"); colTypes.add("long"); colFormats.add("%d"); colHeaders.add("Mean Depth"); colTypes.add("double"); colFormats.add("%.5f"); colHeaders.add("Min. Leaves"); colTypes.add("long"); colFormats.add("%d"); colHeaders.add("Max. Leaves"); colTypes.add("long"); colFormats.add("%d"); colHeaders.add("Mean Leaves"); colTypes.add("double"); colFormats.add("%.5f"); final int rows = 1; TwoDimTable summary = new TwoDimTable( "Rulefit Model Summary", null, new String[rows], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormats.toArray(new String[0]), ""); int col = 0, row = 0; for (int i = 0; i < glmModelSummary.getColDim(); i++) { if (!"Training Frame".equals(glmColHeaders[i])) { summary.set(row, col++, glmModelSummary.get(row, i)); } } summary.set(row, col++, ruleEnsembleSize); summary.set(row, col++, ntrees); summary.set(row, col++, overallTreeStats._num_trees); //internal number of trees (more for multinomial) summary.set(row, col++, overallTreeStats._min_depth); summary.set(row, col++, overallTreeStats._max_depth); summary.set(row, col++, overallTreeStats._mean_depth); summary.set(row, col++, overallTreeStats._min_leaves); summary.set(row, col++, overallTreeStats._max_leaves); summary.set(row, col++, overallTreeStats._mean_leaves); return summary; } @Override public boolean haveMojo() { return true; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy