/*
 * Copyright (c) 2010-2025 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile. If not, see <https://www.gnu.org/licenses/>.
 */
package smile.classification;

import java.io.Serial;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.base.cart.*;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.feature.importance.SHAP;
import smile.math.MathEx;
import smile.regression.RegressionTree;
import smile.util.IntSet;
import smile.util.IterativeAlgorithmController;
import smile.validation.ClassificationMetrics;

/**
 * Gradient boosting for classification. Gradient boosting is typically used
 * with decision trees (especially CART regression trees) of a fixed size as
 * base learners. For this special case, Friedman proposes a modification to
 * the gradient boosting method which improves the quality of fit of each
 * base learner.
 * <p>
 * Generic gradient boosting at the t-th step would fit a regression tree to
 * pseudo-residuals. Let J be the number of its leaves. The tree partitions
 * the input space into J disjoint regions and predicts a constant value in
 * each region. The parameter J controls the maximum allowed level of
 * interaction between variables in the model. With J = 2 (decision stumps),
 * no interaction between variables is allowed. With J = 3 the model may
 * include effects of the interaction between up to two variables, and so on.
 * Hastie et al. comment that typically {@code 4 <= J <= 8} work well for
 * boosting and results are fairly insensitive to the choice of J in this
 * range; J = 2 is insufficient for many applications, and {@code J > 10} is
 * unlikely to be required.
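 * <p>
 * For illustration (a sketch, assuming a hypothetical {@code train} data frame
 * whose response column is named {@code "class"}), the interaction level J maps
 * to the {@code maxNodes} hyperparameter of {@link Options}:
 * <pre>{@code
 *     // trees with at most J = 6 leaves, i.e. a modest level of variable interaction
 *     var options = new GradientTreeBoost.Options(300, 20, 6, 5, 0.05, 0.7, null, null);
 *     var model = GradientTreeBoost.fit(Formula.lhs("class"), train, options);
 * }</pre>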

 * <p>
 * Fitting the training set too closely can lead to degradation of the model's
 * generalization ability. Several so-called regularization techniques reduce
 * this over-fitting effect by constraining the fitting procedure.
 * One natural regularization parameter is the number of gradient boosting
 * iterations T (i.e. the number of trees in the model when the base learner
 * is a decision tree). Increasing T reduces the error on the training set,
 * but setting it too high may lead to over-fitting. An optimal value of T
 * is often selected by monitoring prediction error on a separate validation
 * data set.
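 * <p>
 * For illustration (a sketch with the same hypothetical {@code train} frame and
 * a held-out {@code valid} frame), per-iteration validation metrics can be
 * monitored by passing test data in {@link Options}, and a model that turns out
 * to be too large can be cut back afterwards with {@link #trim(int)}:
 * <pre>{@code
 *     var options = new GradientTreeBoost.Options(500, 20, 6, 5, 0.05, 0.7, valid, null);
 *     var model = GradientTreeBoost.fit(Formula.lhs("class"), train, options);
 *     var smaller = model.trim(200); // keep only the first 200 trees
 * }</pre>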

 * <p>
 * Another regularization approach is shrinkage, which scales the update term
 * by a parameter η (called the "learning rate").
 * Empirically it has been found that using small learning rates (such as
 * η {@code < 0.1}) yields dramatic improvements in a model's generalization
 * ability over gradient boosting without shrinking (η = 1). However, it comes
 * at the price of increasing computational time both during training and
 * prediction: a lower learning rate requires more iterations.
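 * <p>
 * In other words, each boosting step applies the shrunken update (a sketch in
 * the notation of this paragraph rather than API names):
 * <pre>{@code
 *     F_t(x) = F_{t-1}(x) + eta * h_t(x)
 * }</pre>
 * where {@code h_t} is the regression tree fit to the pseudo-residuals at step t.
 * This corresponds to the accumulation
 * {@code h[i] += shrinkage * tree.predict(x.get(i))} in the training loops below.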

 * <p>
 * Soon after the introduction of gradient boosting Friedman proposed a
 * minor modification to the algorithm, motivated by Breiman's bagging method.
 * Specifically, he proposed that at each iteration of the algorithm, a base
 * learner should be fit on a subsample of the training set drawn at random
 * without replacement. Friedman observed a substantial improvement in
 * gradient boosting's accuracy with this modification.

 * <p>
 * Subsample size is some constant fraction f of the size of the training set.
 * When f = 1, the algorithm is deterministic and identical to the one
 * described above. Smaller values of f introduce randomness into the
 * algorithm and help prevent over-fitting, acting as a kind of regularization.
 * The algorithm also becomes faster, because regression trees have to be fit
 * to smaller datasets at each iteration. Typically, f is set to 0.5, meaning
 * that one half of the training set is used to build each base learner.
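 * <p>
 * For illustration (same hypothetical {@code train} frame as above), stochastic
 * gradient tree boosting with f = 0.5, i.e. half of the training set per tree:
 * <pre>{@code
 *     var options = new GradientTreeBoost.Options(500, 20, 6, 5, 0.05, 0.5, null, null);
 *     var model = GradientTreeBoost.fit(Formula.lhs("class"), train, options);
 * }</pre>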

 * <p>
 * Also, like in bagging, sub-sampling allows one to define an out-of-bag
 * estimate of the prediction performance improvement by evaluating predictions
 * on those observations which were not used in the building of the next
 * base learner. Out-of-bag estimates help avoid the need for an independent
 * validation dataset, but often underestimate actual performance improvement
 * and the optimal number of iterations.

 * <p>
 * Gradient tree boosting implementations often also use regularization by
 * limiting the minimum number of observations in trees' terminal nodes.
 * This limit is applied during tree building by ignoring any splits that lead
 * to nodes containing fewer than this number of training set instances.
 * Imposing this limit helps to reduce variance in predictions at leaves.
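 * <p>
 * For illustration, the single-argument {@link Options} constructor uses the
 * defaults of this class ({@code maxDepth = 20}, {@code maxNodes = 6},
 * {@code nodeSize = 5}, {@code shrinkage = 0.05}, {@code subsample = 0.7}),
 * so a leaf-size floor of 5 observations applies unless overridden
 * (again assuming a hypothetical {@code train} frame):
 * <pre>{@code
 *     var model = GradientTreeBoost.fit(Formula.lhs("class"), train, new GradientTreeBoost.Options(100));
 * }</pre>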

 *
 * <h2>References</h2>
 * <ol>
 * <li> J. H. Friedman. Greedy Function Approximation: A Gradient Boosting Machine, 1999.</li>
 * <li> J. H. Friedman. Stochastic Gradient Boosting, 1999.</li>
 * </ol>
 *
 * @author Haifeng Li
 */
public class GradientTreeBoost extends AbstractClassifier<Tuple> implements DataFrameClassifier, SHAP<Tuple> {
    @Serial
    private static final long serialVersionUID = 2L;
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(GradientTreeBoost.class);

    /**
     * The model formula.
     */
    private final Formula formula;

    /**
     * The number of classes.
     */
    private final int k;

    /**
     * Forest of regression trees. Each row is the model associated with
     * a class (OVR). For binary classification, it has only one row.
     */
    private final RegressionTree[][] trees;

    /**
     * Variable importance. Every time a split of a node is made on a variable,
     * the impurity criterion for the two descendant nodes is less than that of
     * the parent node. Adding up the decreases for each individual variable
     * over all trees in the forest gives a simple variable importance.
     */
    private final double[] importance;

    /**
     * The intercept for binary classification.
     */
    private final double b;

    /**
     * The shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
     */
    private final double shrinkage;

    /**
     * Constructor of binary class.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param trees the regression trees.
     * @param b the intercept.
     * @param shrinkage the shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
     * @param importance the variable importance.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance) {
        this(formula, trees, b, shrinkage, importance, IntSet.of(2));
    }

    /**
     * Constructor of binary class.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param trees the regression trees.
     * @param b the intercept.
     * @param shrinkage the shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
     * @param importance the variable importance.
     * @param labels the class labels.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance, IntSet labels) {
        super(labels);
        this.formula = formula;
        this.k = 2;
        this.trees = new RegressionTree[][] { trees };
        this.b = b;
        this.shrinkage = shrinkage;
        this.importance = importance;
    }

    /**
     * Constructor of multi-class.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param trees the regression trees, one row per class.
     * @param shrinkage the shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
     * @param importance the variable importance.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] trees, double shrinkage, double[] importance) {
        this(formula, trees, shrinkage, importance, IntSet.of(trees.length));
    }

    /**
     * Constructor of multi-class.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param trees the regression trees, one row per class.
     * @param shrinkage the shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
     * @param importance the variable importance.
     * @param labels the class label encoder.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] trees, double shrinkage, double[] importance, IntSet labels) {
        super(labels);
        this.formula = formula;
        this.k = trees.length;
        this.trees = trees;
        this.b = 0.0;
        this.shrinkage = shrinkage;
        this.importance = importance;
    }
    /**
     * Training status per tree.
     * @param tree the tree index, starting at 1.
     * @param loss the current loss function value.
     * @param metrics the optional validation metrics if test data is provided.
     */
    public record TrainingStatus(int tree, double loss, ClassificationMetrics metrics) {
    }

    /**
     * Gradient tree boosting hyperparameters.
     * @param ntrees the number of iterations (trees).
     * @param maxDepth the maximum depth of the tree.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the minimum size of leaf nodes.
     *                 Setting nodeSize = 5 generally gives good results.
     * @param shrinkage the shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
     * @param subsample the sampling fraction for stochastic tree boosting.
     * @param test the optional test data for validation per epoch.
     * @param controller the optional training controller.
     */
    public record Options(int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage,
                          double subsample, DataFrame test, IterativeAlgorithmController controller) {
        /** Constructor. */
        public Options {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }

            if (maxDepth < 2) {
                throw new IllegalArgumentException("Invalid maximal tree depth: " + maxDepth);
            }

            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of nodes: " + maxNodes);
            }

            if (nodeSize < 1) {
                throw new IllegalArgumentException("Invalid node size: " + nodeSize);
            }

            if (shrinkage <= 0 || shrinkage > 1) {
                throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
            }

            if (subsample <= 0 || subsample > 1) {
                throw new IllegalArgumentException("Invalid sampling fraction: " + subsample);
            }
        }

        /**
         * Constructor.
         * @param ntrees the number of trees.
         */
        public Options(int ntrees) {
            this(ntrees, 20, 6, 5, 0.05, 0.7, null, null);
        }

        /**
         * Returns the persistent set of hyperparameters.
         * @return the persistent set.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.gradient_boost.trees", Integer.toString(ntrees));
            props.setProperty("smile.gradient_boost.max_depth", Integer.toString(maxDepth));
            props.setProperty("smile.gradient_boost.max_nodes", Integer.toString(maxNodes));
            props.setProperty("smile.gradient_boost.node_size", Integer.toString(nodeSize));
            props.setProperty("smile.gradient_boost.shrinkage", Double.toString(shrinkage));
            props.setProperty("smile.gradient_boost.sampling_rate", Double.toString(subsample));
            return props;
        }

        /**
         * Returns the options from properties.
         *
         * @param props the hyperparameters.
         * @return the options.
         */
        public static Options of(Properties props) {
            int ntrees = Integer.parseInt(props.getProperty("smile.gradient_boost.trees", "500"));
            int maxDepth = Integer.parseInt(props.getProperty("smile.gradient_boost.max_depth", "20"));
            int maxNodes = Integer.parseInt(props.getProperty("smile.gradient_boost.max_nodes", "6"));
            int nodeSize = Integer.parseInt(props.getProperty("smile.gradient_boost.node_size", "5"));
            double shrinkage = Double.parseDouble(props.getProperty("smile.gradient_boost.shrinkage", "0.05"));
            double subsample = Double.parseDouble(props.getProperty("smile.gradient_boost.sampling_rate", "0.7"));
            return new Options(ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample, null, null);
        }
    }
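    // Illustrative sketch: the hyperparameters above round-trip through
    // java.util.Properties, which is how they can be persisted or loaded
    // from configuration. For example:
    //
    //     Properties props = new Options(300).toProperties();
    //     Options restored = Options.of(props); // same settings; test and controller are null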
    /**
     * Fits a gradient tree boosting for classification.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @return the model.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Options(500));
    }

    /**
     * Fits a gradient tree boosting for classification.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param options the hyperparameters.
     * @return the model.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data, Options options) {
        formula = formula.expand(data.schema());
        DataFrame x = formula.x(data);
        ValueVector y = formula.y(data);

        int[][] order = CART.order(x);
        ClassLabels codec = ClassLabels.fit(y);

        if (codec.k == 2) {
            return train2(formula, x, codec, order, options);
        } else {
            return traink(formula, x, codec, order, options);
        }
    }

    @Override
    public Formula formula() {
        return formula;
    }

    @Override
    public StructType schema() {
        return trees[0][0].schema();
    }

    /**
     * Returns the variable importance. Every time a split of a node is made
     * on a variable, the impurity criterion for the two descendant nodes is
     * less than that of the parent node. Adding up the decreases for each
     * individual variable over all trees in the forest gives a simple measure
     * of variable importance.
     *
     * @return the variable importance.
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Train L2 tree boost.
     */
    private static GradientTreeBoost train2(Formula formula, DataFrame x, ClassLabels codec, int[][] order, Options options) {
        long startTime = System.nanoTime();
        int n = x.nrow();
        int p = x.ncol();
        int k = codec.k;
        int[] y = codec.y;

        int[] nc = new int[k];
        for (int i = 0; i < n; i++) nc[y[i]]++;

        Loss loss = Loss.logistic(y);
        double b = loss.intercept(null);
        double[] h = loss.residual(); // this is actually the output of boosted trees.
        StructField field = new StructField("residual", DataTypes.DoubleType);

        DataFrame testx = null;
        int[] testy = null;
        int[] prediction = null;
        double[] logit = null;
        double[] probability = null;
        if (options.test != null) {
            testx = formula.x(options.test);
            testy = codec.indexOf(formula.y(options.test).toIntArray());
            prediction = new int[testy.length];
            logit = new double[testy.length];
            probability = new double[testy.length];
            Arrays.fill(logit, b);
        }

        int ntrees = options.ntrees;
        double shrinkage = options.shrinkage;
        RegressionTree[] trees = new RegressionTree[ntrees];
        int[] permutation = IntStream.range(0, n).toArray();
        int[] samples = new int[n];

        for (int t = 0; t < ntrees; t++) {
            sampling(samples, permutation, nc, y, options.subsample);

            RegressionTree tree = new RegressionTree(x, loss, field, options.maxDepth, options.maxNodes, options.nodeSize, p, samples, order);
            trees[t] = tree;

            for (int i = 0; i < n; i++) {
                h[i] += shrinkage * tree.predict(x.get(i));
            }

            double lossValue = loss.value();
            logger.info("Tree {}: loss = {}", t+1, lossValue);

            double fitTime = (System.nanoTime() - startTime) / 1E6;
            ClassificationMetrics metrics = null;
            if (options.test != null) {
                long testStartTime = System.nanoTime();
                for (int i = 0; i < testy.length; i++) {
                    logit[i] += shrinkage * tree.predict(testx.get(i));
                    prediction[i] = logit[i] > 0 ? 1 : 0;
                    probability[i] = 1 - 1.0 / (1.0 + Math.exp(2 * logit[i]));
                }
                double scoreTime = (System.nanoTime() - testStartTime) / 1E6;
                metrics = ClassificationMetrics.binary(fitTime, scoreTime, testy, prediction, probability);
                logger.info("Validation metrics = {} ", metrics);
            }

            if (options.controller != null) {
                options.controller.submit(new TrainingStatus(t+1, lossValue, metrics));
                if (options.controller.isInterrupted()) {
                    trees = Arrays.copyOf(trees, t);
                    break;
                }
            }
        }

        double[] importance = new double[p];
        for (RegressionTree tree : trees) {
            double[] imp = tree.importance();
            for (int i = 0; i < imp.length; i++) {
                importance[i] += imp[i];
            }
        }

        return new GradientTreeBoost(formula, trees, b, shrinkage, importance, codec.classes);
    }
    /**
     * Train L-k tree boost.
     */
    private static GradientTreeBoost traink(Formula formula, DataFrame x, ClassLabels codec, int[][] order, Options options) {
        long startTime = System.nanoTime();
        int n = x.size();
        int p = x.ncol();
        int k = codec.k;
        int[] y = codec.y;

        int[] nc = new int[k];
        for (int i = 0; i < n; i++) nc[y[i]]++;

        DataFrame testx = null;
        int[] testy = null;
        int[] prediction = null;
        double[][] logit = null;
        double[][] probability = null;
        if (options.test != null) {
            testx = formula.x(options.test);
            testy = codec.indexOf(formula.y(options.test).toIntArray());
            prediction = new int[testy.length];
            logit = new double[testy.length][k];
            probability = new double[testy.length][k];
        }

        int ntrees = options.ntrees;
        double shrinkage = options.shrinkage;
        StructField field = new StructField("residual", DataTypes.DoubleType);
        RegressionTree[][] forest = new RegressionTree[k][ntrees];

        double[][] prob = new double[n][k]; // posteriori probabilities.
        double[][] h = new double[k][]; // boost tree output.
        Loss[] loss = new Loss[k];
        for (int i = 0; i < k; i++) {
            loss[i] = Loss.logistic(i, k, y, prob);
            h[i] = loss[i].residual();
        }

        int[] permutation = IntStream.range(0, n).toArray();
        int[] samples = new int[n];

        for (int t = 0; t < ntrees; t++) {
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < k; j++) {
                    prob[i][j] = h[j][i];
                }
                MathEx.softmax(prob[i]);
            }

            for (int j = 0; j < k; j++) {
                sampling(samples, permutation, nc, y, options.subsample);

                RegressionTree tree = new RegressionTree(x, loss[j], field, options.maxDepth, options.maxNodes, options.nodeSize, p, samples, order);
                forest[j][t] = tree;

                double[] hj = h[j];
                for (int i = 0; i < n; i++) {
                    hj[i] += shrinkage * tree.predict(x.get(i));
                }
            }

            double lossValue = loss[0].value();
            logger.info("Tree {}: loss = {}", t+1, lossValue);

            double fitTime = (System.nanoTime() - startTime) / 1E6;
            ClassificationMetrics metrics = null;
            if (options.test != null) {
                long testStartTime = System.nanoTime();
                for (int i = 0; i < testy.length; i++) {
                    var xt = testx.get(i);
                    for (int j = 0; j < k; j++) {
                        logit[i][j] += shrinkage * forest[j][t].predict(xt);
                    }

                    prediction[i] = MathEx.whichMax(logit[i]);
                    double max = logit[i][prediction[i]];
                    double Z = 0.0;
                    for (int j = 0; j < k; j++) {
                        probability[i][j] = Math.exp(logit[i][j] - max);
                        Z += probability[i][j];
                    }

                    for (int j = 0; j < k; j++) {
                        probability[i][j] /= Z;
                    }
                }
                double scoreTime = (System.nanoTime() - testStartTime) / 1E6;
                metrics = ClassificationMetrics.of(fitTime, scoreTime, testy, prediction, probability);
                logger.info("Validation metrics = {} ", metrics);
            }

            if (options.controller != null) {
                options.controller.submit(new TrainingStatus(t+1, lossValue, metrics));
                if (options.controller.isInterrupted()) {
                    for (int j = 0; j < k; j++) {
                        forest[j] = Arrays.copyOf(forest[j], t);
                    }
                    break;
                }
            }
        }

        double[] importance = new double[p];
        for (RegressionTree[] grove : forest) {
            for (RegressionTree tree : grove) {
                double[] imp = tree.importance();
                for (int i = 0; i < imp.length; i++) {
                    importance[i] += imp[i];
                }
            }
        }

        return new GradientTreeBoost(formula, forest, shrinkage, importance, codec.classes);
    }
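    // The stratified sampling helper below flags roughly nc[j] * subsample
    // instances of each class j for the next tree. For example, with class
    // counts nc = {60, 40} and subsample = 0.5, about 30 and 20 rows are
    // marked, so the subsample preserves the class proportions of the
    // training set.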
    /**
     * Stratified sampling.
     */
    private static void sampling(int[] samples, int[] permutation, int[] nc, int[] y, double subsample) {
        int n = samples.length;
        int k = nc.length;
        Arrays.fill(samples, 0);
        MathEx.permutate(permutation);

        for (int j = 0; j < k; j++) {
            int subj = (int) Math.round(nc[j] * subsample);
            for (int i = 0, nj = 0; i < n && nj < subj; i++) {
                int xi = permutation[i];
                if (y[xi] == j) {
                    samples[xi] = 1;
                    nj++;
                }
            }
        }
    }

    /**
     * Returns the number of trees in the model.
     *
     * @return the number of trees in the model.
     */
    public int size() {
        return trees().length;
    }

    /**
     * Returns the regression trees.
     * @return the regression trees.
     */
    public RegressionTree[][] trees() {
        return trees;
    }

    /**
     * Trims the tree model set to a smaller size in case of over-fitting.
     * If the extra decision trees in the model don't improve the performance,
     * we may remove them to reduce the model size and also improve the speed
     * of prediction.
     *
     * @param ntrees the new (smaller) size of tree model set.
     * @return the trimmed model.
     */
    public GradientTreeBoost trim(int ntrees) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }

        if (ntrees > trees[0].length) {
            throw new IllegalArgumentException("The new model size is larger than the current one.");
        }

        if (k == 2) {
            return new GradientTreeBoost(formula, Arrays.copyOf(trees[0], ntrees), b, shrinkage, importance, classes);
        } else {
            RegressionTree[][] forest = new RegressionTree[k][];
            for (int i = 0; i < k; i++) {
                forest[i] = Arrays.copyOf(trees[i], ntrees);
            }
            return new GradientTreeBoost(formula, forest, shrinkage, importance, classes);
        }
    }

    @Override
    public int predict(Tuple x) {
        Tuple xt = formula.x(x);
        if (k == 2) {
            double y = b;
            for (RegressionTree tree : trees[0]) {
                y += shrinkage * tree.predict(xt);
            }

            return classes.valueOf(y > 0 ? 1 : 0);
        } else {
            double max = Double.NEGATIVE_INFINITY;
            int y = -1;
            for (int j = 0; j < k; j++) {
                double yj = 0.0;
                for (RegressionTree tree : trees[j]) {
                    yj += shrinkage * tree.predict(xt);
                }

                if (yj > max) {
                    max = yj;
                    y = j;
                }
            }

            return classes.valueOf(y);
        }
    }

    @Override
    public boolean soft() {
        return true;
    }

    @Override
    public int predict(Tuple x, double[] posteriori) {
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        Tuple xt = formula.x(x);
        if (k == 2) {
            double y = b;
            for (RegressionTree tree : trees[0]) {
                y += shrinkage * tree.predict(xt);
            }

            posteriori[0] = 1.0 / (1.0 + Math.exp(2 * y));
            posteriori[1] = 1.0 - posteriori[0];
            return classes.valueOf(y > 0 ? 1 : 0);
        } else {
            double max = Double.NEGATIVE_INFINITY;
            int y = -1;
            for (int j = 0; j < k; j++) {
                posteriori[j] = 0.0;
                for (RegressionTree tree : trees[j]) {
                    posteriori[j] += shrinkage * tree.predict(xt);
                }

                if (posteriori[j] > max) {
                    max = posteriori[j];
                    y = j;
                }
            }

            double Z = 0.0;
            for (int i = 0; i < k; i++) {
                posteriori[i] = Math.exp(posteriori[i] - max);
                Z += posteriori[i];
            }

            for (int i = 0; i < k; i++) {
                posteriori[i] /= Z;
            }

            return classes.valueOf(y);
        }
    }
    /**
     * Test the model on a validation dataset.
     *
     * @param data the test data set.
     * @return the predictions with first 1, 2, ..., decision trees.
     */
    public int[][] test(DataFrame data) {
        DataFrame x = formula.x(data);

        int n = x.size();
        int ntrees = trees[0].length;
        int[][] prediction = new int[ntrees][n];

        if (k == 2) {
            for (int j = 0; j < n; j++) {
                Tuple xj = x.get(j);
                double base = 0;
                for (int i = 0; i < ntrees; i++) {
                    base += shrinkage * trees[0][i].predict(xj);
                    prediction[i][j] = base > 0 ? 1 : 0;
                }
            }
        } else {
            double[] p = new double[k];
            for (int j = 0; j < n; j++) {
                Tuple xj = x.get(j);
                Arrays.fill(p, 0);
                for (int i = 0; i < ntrees; i++) {
                    for (int l = 0; l < k; l++) {
                        p[l] += shrinkage * trees[l][i].predict(xj);
                    }
                    prediction[i][j] = MathEx.whichMax(p);
                }
            }
        }

        return prediction;
    }

    /**
     * Returns the average of absolute SHAP values over a data frame.
     * @param data the data set.
     * @return the average of absolute SHAP values.
     */
    public double[] shap(DataFrame data) {
        // Binds the formula to the data frame's schema in case that
        // it is different from that of training data.
        formula.bind(data.schema());
        return shap(data.stream().parallel());
    }

    @Override
    public double[] shap(Tuple x) {
        Tuple xt = formula.x(x);
        int p = xt.length();
        double[] phi = new double[p * k];
        int ntrees = trees[0].length;

        if (k == 2) {
            for (RegressionTree tree : trees[0]) {
                double[] phii = tree.shap(xt);
                for (int i = 0; i < p; i++) {
                    phi[2*i] += phii[i];
                    phi[2*i+1] += phii[i];
                }
            }
        } else {
            for (int i = 0; i < k; i++) {
                for (RegressionTree tree : trees[i]) {
                    double[] phii = tree.shap(xt);
                    for (int j = 0; j < p; j++) {
                        phi[j*k + i] += phii[j];
                    }
                }
            }
        }

        for (int i = 0; i < phi.length; i++) {
            phi[i] /= ntrees;
        }

        return phi;
    }
}



