
smile.classification.GradientTreeBoost Maven / Gradle / Ivy
/*
* Copyright (c) 2010-2025 Haifeng Li. All rights reserved.
*
* Smile is free software: you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Smile is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Smile. If not, see .
*/
package smile.classification;
import java.io.Serial;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.base.cart.*;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.feature.importance.SHAP;
import smile.math.MathEx;
import smile.regression.RegressionTree;
import smile.util.IntSet;
import smile.util.IterativeAlgorithmController;
import smile.validation.ClassificationMetrics;
/**
* Gradient boosting for classification. Gradient boosting is typically used
* with decision trees (especially CART regression trees) of a fixed size as
* base learners. For this special case Friedman proposes a modification to
* gradient boosting method which improves the quality of fit of each base
* learner.
*
* Generic gradient boosting at the t-th step would fit a regression tree to
* pseudo-residuals. Let J be the number of its leaves. The tree partitions
* the input space into J disjoint regions and predicts a constant value in
* each region. The parameter J controls the maximum allowed
* level of interaction between variables in the model. With J = 2 (decision
* stumps), no interaction between variables is allowed. With J = 3 the model
* may include effects of the interaction between up to two variables, and
* so on. Hastie et al. comment that typically {@code 4 <= J <= 8} work well
* for boosting and results are fairly insensitive to the choice of in
* this range, J = 2 is insufficient for many applications, and {@code J > 10} is
* unlikely to be required.
*
* Fitting the training set too closely can lead to degradation of the model's
* generalization ability. Several so-called regularization techniques reduce
* this over-fitting effect by constraining the fitting procedure.
* One natural regularization parameter is the number of gradient boosting
* iterations T (i.e. the number of trees in the model when the base learner
* is a decision tree). Increasing T reduces the error on training set,
* but setting it too high may lead to over-fitting. An optimal value of T
* is often selected by monitoring prediction error on a separate validation
* data set.
*
* Another regularization approach is the shrinkage which times a parameter
* η (called the "learning rate") to update term.
* Empirically it has been found that using small learning rates (such as
* η {@code < 0.1}) yields dramatic improvements in model's generalization ability
* over gradient boosting without shrinking (η = 1). However, it comes at
* the price of increasing computational time both during training and
* prediction: lower learning rate requires more iterations.
*
* Soon after the introduction of gradient boosting Friedman proposed a
* minor modification to the algorithm, motivated by Breiman's bagging method.
* Specifically, he proposed that at each iteration of the algorithm, a base
* learner should be fit on a subsample of the training set drawn at random
* without replacement. Friedman observed a substantial improvement in
* gradient boosting's accuracy with this modification.
*
* Subsample size is some constant fraction f of the size of the training set.
* When f = 1, the algorithm is deterministic and identical to the one
* described above. Smaller values of f introduce randomness into the
* algorithm and help prevent over-fitting, acting as a kind of regularization.
* The algorithm also becomes faster, because regression trees have to be fit
* to smaller datasets at each iteration. Typically, f is set to 0.5, meaning
* that one half of the training set is used to build each base learner.
*
* Also, like in bagging, sub-sampling allows one to define an out-of-bag
* estimate of the prediction performance improvement by evaluating predictions
* on those observations which were not used in the building of the next
* base learner. Out-of-bag estimates help avoid the need for an independent
* validation dataset, but often underestimate actual performance improvement
* and the optimal number of iterations.
*
* Gradient tree boosting implementations often also use regularization by
* limiting the minimum number of observations in trees' terminal nodes.
* It's used in the tree building process by ignoring any splits that lead
* to nodes containing fewer than this number of training set instances.
* Imposing this limit helps to reduce variance in predictions at leaves.
*
*
References
*
* - J. H. Friedman. Greedy Function Approximation: A Gradient Boosting Machine, 1999.
* - J. H. Friedman. Stochastic Gradient Boosting, 1999.
*
*
* @author Haifeng Li
*/
public class GradientTreeBoost extends AbstractClassifier implements DataFrameClassifier, SHAP {
@Serial
private static final long serialVersionUID = 2L;
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(GradientTreeBoost.class);
/**
* The model formula.
*/
private final Formula formula;
/**
* The number of classes.
*/
private final int k;
/**
* Forest of regression trees. Each row is the model associated with
* a class (OVR). For binary classification, it has only one row.
*/
private final RegressionTree[][] trees;
/**
* Variable importance. Every time a split of a node is made on variable
* the impurity criterion for the two descendent nodes is less than the
* parent node. Adding up the decreases for each individual variable over
* all trees in the forest gives a simple variable importance.
*/
private final double[] importance;
/**
* The intercept for binary classification.
*/
private final double b;
/**
* The shrinkage parameter in (0, 1] controls the learning rate of procedure.
*/
private final double shrinkage;
/**
* Constructor of binary class.
*
* @param formula a symbolic description of the model to be fitted.
* @param trees the regression trees.
* @param b the intercept
* @param shrinkage the shrinkage parameter in (0, 1] controls the learning rate of procedure.
* @param importance variable importance
*/
public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance) {
this(formula, trees, b, shrinkage, importance, IntSet.of(2));
}
/**
* Constructor of binary class.
*
* @param formula a symbolic description of the model to be fitted.
* @param trees the regression trees.
* @param b the intercept
* @param shrinkage the shrinkage parameter in (0, 1] controls the learning rate of procedure.
* @param importance variable importance
* @param labels class labels
*/
public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance, IntSet labels) {
super(labels);
this.formula = formula;
this.k = 2;
this.trees = new RegressionTree[][] { trees };
this.b = b;
this.shrinkage = shrinkage;
this.importance = importance;
}
/**
* Constructor of multi-class.
*
* @param formula a symbolic description of the model to be fitted.
* @param trees the regression trees, one row per class.
* @param shrinkage the shrinkage parameter in (0, 1] controls the learning rate of procedure.
* @param importance variable importance
*/
public GradientTreeBoost(Formula formula, RegressionTree[][] trees, double shrinkage, double[] importance) {
this(formula, trees, shrinkage, importance, IntSet.of(trees.length));
}
/**
* Constructor of multi-class.
*
* @param formula a symbolic description of the model to be fitted.
* @param trees the regression trees, one row per class.
* @param shrinkage the shrinkage parameter in (0, 1] controls the learning rate of procedure.
* @param importance variable importance
* @param labels the class label encoder.
*/
public GradientTreeBoost(Formula formula, RegressionTree[][] trees, double shrinkage, double[] importance, IntSet labels) {
super(labels);
this.formula = formula;
this.k = trees.length;
this.trees = trees;
this.b = 0.0;
this.shrinkage = shrinkage;
this.importance = importance;
}
/**
* Training status per tree.
* @param tree the tree index, starting at 1.
* @param loss the current loss function value.
* @param metrics the optional validation metrics if test data is provided.
*/
public record TrainingStatus(int tree, double loss, ClassificationMetrics metrics) {
}
/**
* Gradient tree boosting hyperparameters.
* @param ntrees the number of iterations (trees).
* @param maxDepth the maximum depth of the tree.
* @param maxNodes the maximum number of leaf nodes in the tree.
* @param nodeSize the minimum size of leaf nodes.
* Setting nodeSize = 5 generally gives good results.
* @param shrinkage the shrinkage parameter in (0, 1] controls the learning rate of procedure.
* @param subsample the sampling fraction for stochastic tree boosting.
* @param test the optional test data for validation per epoch.
* @param controller the optional training controller.
*/
public record Options(int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample,
DataFrame test, IterativeAlgorithmController controller) {
/** Constructor. */
public Options {
if (ntrees < 1) {
throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
}
if (maxDepth < 2) {
throw new IllegalArgumentException("Invalid maximal tree depth: " + maxDepth);
}
if (maxNodes < 2) {
throw new IllegalArgumentException("Invalid maximum number of nodes: " + maxNodes);
}
if (nodeSize < 1) {
throw new IllegalArgumentException("Invalid node size: " + nodeSize);
}
if (shrinkage <= 0 || shrinkage > 1) {
throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
}
if (subsample <= 0 || subsample > 1) {
throw new IllegalArgumentException("Invalid sampling fraction: " + subsample);
}
}
/**
* Constructor.
* @param ntrees the number of trees.
*/
public Options(int ntrees) {
this(ntrees, 20, 6, 5, 0.05, 0.7, null, null);
}
/**
* Returns the persistent set of hyperparameters.
* @return the persistent set.
*/
public Properties toProperties() {
Properties props = new Properties();
props.setProperty("smile.gradient_boost.trees", Integer.toString(ntrees));
props.setProperty("smile.gradient_boost.max_depth", Integer.toString(maxDepth));
props.setProperty("smile.gradient_boost.max_nodes", Integer.toString(maxNodes));
props.setProperty("smile.gradient_boost.node_size", Integer.toString(nodeSize));
props.setProperty("smile.gradient_boost.shrinkage", Double.toString(shrinkage));
props.setProperty("smile.gradient_boost.sampling_rate", Double.toString(subsample));
return props;
}
/**
* Returns the options from properties.
*
* @param props the hyperparameters.
* @return the options.
*/
public static Options of(Properties props) {
int ntrees = Integer.parseInt(props.getProperty("smile.gradient_boost.trees", "500"));
int maxDepth = Integer.parseInt(props.getProperty("smile.gradient_boost.max_depth", "20"));
int maxNodes = Integer.parseInt(props.getProperty("smile.gradient_boost.max_nodes", "6"));
int nodeSize = Integer.parseInt(props.getProperty("smile.gradient_boost.node_size", "5"));
double shrinkage = Double.parseDouble(props.getProperty("smile.gradient_boost.shrinkage", "0.05"));
double subsample = Double.parseDouble(props.getProperty("smile.gradient_boost.sampling_rate", "0.7"));
return new Options(ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample, null, null);
}
}
/**
* Fits a gradient tree boosting for classification.
*
* @param formula a symbolic description of the model to be fitted.
* @param data the data frame of the explanatory and response variables.
* @return the model.
*/
public static GradientTreeBoost fit(Formula formula, DataFrame data) {
return fit(formula, data, new Options(500));
}
/**
* Fits a gradient tree boosting for classification.
*
* @param formula a symbolic description of the model to be fitted.
* @param data the data frame of the explanatory and response variables.
* @param options the hyperparameters.
* @return the model.
*/
public static GradientTreeBoost fit(Formula formula, DataFrame data, Options options) {
formula = formula.expand(data.schema());
DataFrame x = formula.x(data);
ValueVector y = formula.y(data);
int[][] order = CART.order(x);
ClassLabels codec = ClassLabels.fit(y);
if (codec.k == 2) {
return train2(formula, x, codec, order, options);
} else {
return traink(formula, x, codec, order, options);
}
}
@Override
public Formula formula() {
return formula;
}
@Override
public StructType schema() {
return trees[0][0].schema();
}
/**
* Returns the variable importance. Every time a split of a node is made
* on variable the impurity criterion for the two descendent nodes is less
* than the parent node. Adding up the decreases for each individual
* variable over all trees in the forest gives a simple measure of variable
* importance.
*
* @return the variable importance
*/
public double[] importance() {
return importance;
}
/**
* Train L2 tree boost.
*/
private static GradientTreeBoost train2(Formula formula, DataFrame x, ClassLabels codec, int[][] order, Options options) {
long startTime = System.nanoTime();
int n = x.nrow();
int p = x.ncol();
int k = codec.k;
int[] y = codec.y;
int[] nc = new int[k];
for (int i = 0; i < n; i++) nc[y[i]]++;
Loss loss = Loss.logistic(y);
double b = loss.intercept(null);
double[] h = loss.residual(); // this is actually the output of boost trees.
StructField field = new StructField("residual", DataTypes.DoubleType);
DataFrame testx = null;
int[] testy = null;
int[] prediction = null;
double[] logit = null;
double[] probability = null;
if (options.test != null) {
testx = formula.x(options.test);
testy = codec.indexOf(formula.y(options.test).toIntArray());
prediction = new int[testy.length];
logit = new double[testy.length];
probability = new double[testy.length];
Arrays.fill(logit, b);
}
int ntrees = options.ntrees;
double shrinkage = options.shrinkage;
RegressionTree[] trees = new RegressionTree[ntrees];
int[] permutation = IntStream.range(0, n).toArray();
int[] samples = new int[n];
for (int t = 0; t < ntrees; t++) {
sampling(samples, permutation, nc, y, options.subsample);
RegressionTree tree = new RegressionTree(x, loss, field, options.maxDepth, options.maxNodes, options.nodeSize, p, samples, order);
trees[t] = tree;
for (int i = 0; i < n; i++) {
h[i] += shrinkage * tree.predict(x.get(i));
}
double lossValue = loss.value();
logger.info("Tree {}: loss = {}", t+1, lossValue);
double fitTime = (System.nanoTime() - startTime) / 1E6;
ClassificationMetrics metrics = null;
if (options.test != null) {
long testStartTime = System.nanoTime();
for (int i = 0; i < testy.length; i++) {
logit[i] += shrinkage * tree.predict(testx.get(i));
prediction[i] = logit[i] > 0 ? 1 : 0;
probability[i] = 1 - 1.0 / (1.0 + Math.exp(2 * logit[i]));
}
double scoreTime = (System.nanoTime() - testStartTime) / 1E6;
metrics = ClassificationMetrics.binary(fitTime, scoreTime, testy, prediction, probability);
logger.info("Validation metrics = {} ", metrics);
}
if (options.controller != null) {
options.controller.submit(new TrainingStatus(t+1, lossValue, metrics));
if (options.controller.isInterrupted()) {
trees = Arrays.copyOf(trees, t);
break;
}
}
}
double[] importance = new double[p];
for (RegressionTree tree : trees) {
double[] imp = tree.importance();
for (int i = 0; i < imp.length; i++) {
importance[i] += imp[i];
}
}
return new GradientTreeBoost(formula, trees, b, shrinkage, importance, codec.classes);
}
/**
* Train L-k tree boost.
*/
private static GradientTreeBoost traink(Formula formula, DataFrame x, ClassLabels codec, int[][] order, Options options) {
long startTime = System.nanoTime();
int n = x.size();
int p = x.ncol();
int k = codec.k;
int[] y = codec.y;
int[] nc = new int[k];
for (int i = 0; i < n; i++) nc[y[i]]++;
DataFrame testx = null;
int[] testy = null;
int[] prediction = null;
double[][] logit = null;
double[][] probability = null;
if (options.test != null) {
testx = formula.x(options.test);
testy = codec.indexOf(formula.y(options.test).toIntArray());
prediction = new int[testy.length];
logit = new double[testy.length][k];
probability = new double[testy.length][k];
}
int ntrees = options.ntrees;
double shrinkage = options.shrinkage;
StructField field = new StructField("residual", DataTypes.DoubleType);
RegressionTree[][] forest = new RegressionTree[k][ntrees];
double[][] prob = new double[n][k]; // posteriori probabilities.
double[][] h = new double[k][]; // boost tree output.
Loss[] loss = new Loss[k];
for (int i = 0; i < k; i++) {
loss[i] = Loss.logistic(i, k, y, prob);
h[i] = loss[i].residual();
}
int[] permutation = IntStream.range(0, n).toArray();
int[] samples = new int[n];
for (int t = 0; t < ntrees; t++) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < k; j++) {
prob[i][j] = h[j][i];
}
MathEx.softmax(prob[i]);
}
for (int j = 0; j < k; j++) {
sampling(samples, permutation, nc, y, options.subsample);
RegressionTree tree = new RegressionTree(x, loss[j], field, options.maxDepth, options.maxNodes, options.nodeSize, p, samples, order);
forest[j][t] = tree;
double[] hj = h[j];
for (int i = 0; i < n; i++) {
hj[i] += shrinkage * tree.predict(x.get(i));
}
}
double lossValue = loss[0].value();
logger.info("Tree {}: loss = {}", t+1, lossValue);
double fitTime = (System.nanoTime() - startTime) / 1E6;
ClassificationMetrics metrics = null;
if (options.test != null) {
long testStartTime = System.nanoTime();
for (int i = 0; i < testy.length; i++) {
var xt = testx.get(i);
for (int j = 0; j < k; j++) {
logit[i][j] += shrinkage * forest[j][t].predict(xt);
}
prediction[i] = MathEx.whichMax(logit[i]);
double max = logit[i][prediction[i]];
double Z = 0.0;
for (int j = 0; j < k; j++) {
probability[i][j] = Math.exp(logit[i][j] - max);
Z += probability[i][j];
}
for (int j = 0; j < k; j++) {
probability[i][j] /= Z;
}
}
double scoreTime = (System.nanoTime() - testStartTime) / 1E6;
metrics = ClassificationMetrics.of(fitTime, scoreTime, testy, prediction, probability);
logger.info("Validation metrics = {} ", metrics);
}
if (options.controller != null) {
options.controller.submit(new TrainingStatus(t+1, lossValue, metrics));
if (options.controller.isInterrupted()) {
for (int j = 0; j < k; j++) {
forest[j] = Arrays.copyOf(forest[j], t);
}
break;
}
}
}
double[] importance = new double[p];
for (RegressionTree[] grove : forest) {
for (RegressionTree tree : grove) {
double[] imp = tree.importance();
for (int i = 0; i < imp.length; i++) {
importance[i] += imp[i];
}
}
}
return new GradientTreeBoost(formula, forest, shrinkage, importance, codec.classes);
}
/**
* Stratified sampling.
*/
private static void sampling(int[] samples, int[] permutation, int[] nc, int[] y, double subsample) {
int n = samples.length;
int k = nc.length;
Arrays.fill(samples, 0);
MathEx.permutate(permutation);
for (int j = 0; j < k; j++) {
int subj = (int) Math.round(nc[j] * subsample);
for (int i = 0, nj = 0; i < n && nj < subj; i++) {
int xi = permutation[i];
if (y[xi] == j) {
samples[xi] = 1;
nj++;
}
}
}
}
/**
* Returns the number of trees in the model.
*
* @return the number of trees in the model.
*/
public int size() {
return trees().length;
}
/**
* Returns the regression trees.
* @return the regression trees.
*/
public RegressionTree[][] trees() {
return trees;
}
/**
* Trims the tree model set to a smaller size in case of over-fitting.
* Or if extra decision trees in the model don't improve the performance,
* we may remove them to reduce the model size and also improve the speed of
* prediction.
*
* @param ntrees the new (smaller) size of tree model set.
* @return the trimmed model.
*/
public GradientTreeBoost trim(int ntrees) {
if (ntrees < 1) {
throw new IllegalArgumentException("Invalid new model size: " + ntrees);
}
if (ntrees > trees[0].length) {
throw new IllegalArgumentException("The new model size is larger than the current one.");
}
if (k == 2) {
return new GradientTreeBoost(formula, Arrays.copyOf(trees[0], ntrees), b, shrinkage, importance, classes);
} else {
RegressionTree[][] forest = new RegressionTree[k][];
for (int i = 0; i < k; i++) {
forest[i] = Arrays.copyOf(trees[i], ntrees);
}
return new GradientTreeBoost(formula, forest, shrinkage, importance, classes);
}
}
@Override
public int predict(Tuple x) {
Tuple xt = formula.x(x);
if (k == 2) {
double y = b;
for (RegressionTree tree : trees[0]) {
y += shrinkage * tree.predict(xt);
}
return classes.valueOf(y > 0 ? 1 : 0);
} else {
double max = Double.NEGATIVE_INFINITY;
int y = -1;
for (int j = 0; j < k; j++) {
double yj = 0.0;
for (RegressionTree tree : trees[j]) {
yj += shrinkage * tree.predict(xt);
}
if (yj > max) {
max = yj;
y = j;
}
}
return classes.valueOf(y);
}
}
@Override
public boolean soft() {
return true;
}
@Override
public int predict(Tuple x, double[] posteriori) {
if (posteriori.length != k) {
throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
}
Tuple xt = formula.x(x);
if (k == 2) {
double y = b;
for (RegressionTree tree : trees[0]) {
y += shrinkage * tree.predict(xt);
}
posteriori[0] = 1.0 / (1.0 + Math.exp(2 * y));
posteriori[1] = 1.0 - posteriori[0];
return classes.valueOf(y > 0 ? 1 : 0);
} else {
double max = Double.NEGATIVE_INFINITY;
int y = -1;
for (int j = 0; j < k; j++) {
posteriori[j] = 0.0;
for (RegressionTree tree : trees[j]) {
posteriori[j] += shrinkage * tree.predict(xt);
}
if (posteriori[j] > max) {
max = posteriori[j];
y = j;
}
}
double Z = 0.0;
for (int i = 0; i < k; i++) {
posteriori[i] = Math.exp(posteriori[i] - max);
Z += posteriori[i];
}
for (int i = 0; i < k; i++) {
posteriori[i] /= Z;
}
return classes.valueOf(y);
}
}
/**
* Test the model on a validation dataset.
*
* @param data the test data set.
* @return the predictions with first 1, 2, ..., decision trees.
*/
public int[][] test(DataFrame data) {
DataFrame x = formula.x(data);
int n = x.size();
int ntrees = trees[0].length;
int[][] prediction = new int[ntrees][n];
if (k == 2) {
for (int j = 0; j < n; j++) {
Tuple xj = x.get(j);
double base = 0;
for (int i = 0; i < ntrees; i++) {
base += shrinkage * trees[0][i].predict(xj);
prediction[i][j] = base > 0 ? 1 : 0;
}
}
} else {
double[] p = new double[k];
for (int j = 0; j < n; j++) {
Tuple xj = x.get(j);
Arrays.fill(p, 0);
for (int i = 0; i < ntrees; i++) {
for (int l = 0; l < k; l++) {
p[l] += shrinkage * trees[l][i].predict(xj);
}
prediction[i][j] = MathEx.whichMax(p);
}
}
}
return prediction;
}
/**
* Returns the average of absolute SHAP values over a data frame.
* @param data the data set.
* @return the average of absolute SHAP values.
*/
public double[] shap(DataFrame data) {
// Binds the formula to the data frame's schema in case that
// it is different from that of training data.
formula.bind(data.schema());
return shap(data.stream().parallel());
}
@Override
public double[] shap(Tuple x) {
Tuple xt = formula.x(x);
int p = xt.length();
double[] phi = new double[p * k];
int ntrees = trees[0].length;
if (k == 2) {
for (RegressionTree tree : trees[0]) {
double[] phii = tree.shap(xt);
for (int i = 0; i < p; i++) {
phi[2*i] += phii[i];
phi[2*i+1] += phii[i];
}
}
} else {
for (int i = 0; i < k; i++) {
for (RegressionTree tree : trees[i]) {
double[] phii = tree.shap(xt);
for (int j = 0; j < p; j++) {
phi[j*k + i] += phii[j];
}
}
}
}
for (int i = 0; i < phi.length; i++) {
phi[i] /= ntrees;
}
return phi;
}
}