
/*
* Copyright (c) 2010-2025 Haifeng Li. All rights reserved.
*
* Smile is free software: you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Smile is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Smile. If not, see <https://www.gnu.org/licenses/>.
*/
package smile.regression;
import java.io.Serial;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.feature.importance.TreeSHAP;
import smile.math.MathEx;
import smile.util.IterativeAlgorithmController;
import smile.validation.RegressionMetrics;
/**
* Random forest for regression. Random forest is an ensemble method that
* consists of many regression trees and outputs the average of individual
* trees. The method combines the bagging idea and the random selection of features.
*
* Each tree is constructed using the following algorithm:
* <ol>
* <li> If the number of cases in the training set is N, randomly sample N cases
* with replacement from the original data. This sample will
* be the training set for growing the tree.
* <li> If there are M input variables, a number {@code m << M} is specified such
* that at each node, m variables are selected at random out of the M and
* the best split on these m is used to split the node. The value of m is
* held constant during the forest growing.
* <li> Each tree is grown to the largest extent possible. There is no pruning.
* </ol>
* The advantages of random forest are:
* <ul>
* <li> For many data sets, it produces a highly accurate model.
* <li> It runs efficiently on large data sets.
* <li> It can handle thousands of input variables without variable deletion.
* <li> It gives estimates of what variables are important in the regression.
* <li> It generates an internal unbiased estimate of the generalization error
* as the forest building progresses.
* <li> It has an effective method for estimating missing data and maintains
* accuracy when a large proportion of the data are missing.
* </ul>
* The disadvantages are:
* <ul>
* <li> Random forests are prone to over-fitting for some datasets. This is
* even more pronounced in noisy classification/regression tasks.
* <li> For data including categorical variables with different numbers of
* levels, random forests are biased in favor of those attributes with more
* levels. Therefore, the variable importance scores from random forest are
* not reliable for this type of data.
* </ul>
*
*
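* <p>
* A minimal usage sketch, assuming a DataFrame {@code data} whose numeric
* column "y" is the response variable (data loading omitted):
* <pre>{@code
* var options = new RandomForest.Options(500);
* var model = RandomForest.fit(Formula.lhs("y"), data, options);
* System.out.println(model.metrics());      // out-of-bag estimates
* double[] importance = model.importance(); // impurity-based importance
* }</pre>
*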
* @author Haifeng Li
*/
public class RandomForest implements DataFrameRegression, TreeSHAP {
@Serial
private static final long serialVersionUID = 2L;
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(RandomForest.class);
/**
* The base model.
* @param tree The regression tree.
* @param metrics The validation metrics on out-of-bag samples.
*/
public record Model(RegressionTree tree, RegressionMetrics metrics) implements Serializable, Comparable<Model> {
@Override
public int compareTo(Model o) {
return Double.compare(metrics.rmse(), o.metrics.rmse());
}
}
/**
* The model formula.
*/
private final Formula formula;
/**
* Forest of regression trees.
*/
private final Model[] models;
/**
* The overall out-of-bag metrics, which are quite accurate given that
* enough trees have been grown (otherwise the OOB error estimate can
* be biased upward).
*/
private final RegressionMetrics metrics;
/**
* Variable importance. Every time a node is split on a variable, the
* impurity criterion for the two descendant nodes is less than that of
* the parent node. Adding up the decreases for each individual variable
* over all trees in the forest gives a fast measure of variable
* importance that is often very consistent with the permutation
* importance measure.
*/
private final double[] importance;
/**
* Constructor.
* @param formula a symbolic description of the model to be fitted.
* @param models the base models.
* @param metrics the overall out-of-bag metric estimations.
* @param importance the feature importance.
*/
public RandomForest(Formula formula, Model[] models, RegressionMetrics metrics, double[] importance) {
this.formula = formula;
this.models = models;
this.metrics = metrics;
this.importance = importance;
}
/**
* Training status per tree.
* @param tree the tree index, starting at 1.
* @param metrics the validation metrics on out-of-bag samples.
*/
public record TrainingStatus(int tree, RegressionMetrics metrics) {
}
/**
* Random forest hyperparameters.
* @param ntrees the number of trees.
* @param mtry the number of input variables to be used to determine the
* decision at a node of the tree. p/3 generally gives good
* performance, where p is the number of variables.
* @param maxDepth the maximum depth of the tree.
* @param maxNodes the maximum number of leaf nodes in the tree.
* @param nodeSize the minimum size of leaf nodes.
* Setting nodeSize = 5 generally gives good results.
* @param subsample the sampling rate for training each tree. 1.0 means sampling with
* replacement. {@code < 1.0} means sampling without replacement.
* @param seeds optional RNG seeds for each regression tree.
* @param controller the optional training controller.
*/
public record Options(int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample,
long[] seeds, IterativeAlgorithmController controller) {
/** Constructor. */
public Options {
if (ntrees < 1) {
throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
}
if (maxDepth < 2) {
throw new IllegalArgumentException("Invalid maximal tree depth: " + maxDepth);
}
if (nodeSize < 1) {
throw new IllegalArgumentException("Invalid node size: " + nodeSize);
}
if (subsample <= 0 || subsample > 1) {
throw new IllegalArgumentException("Invalid sampling rate: " + subsample);
}
if (seeds != null && seeds.length < ntrees) {
throw new IllegalArgumentException("Fewer RNG seeds than trees: " + seeds.length);
}
}
/**
* Constructor.
* @param ntrees the number of trees.
*/
public Options(int ntrees) {
this(ntrees, 0);
}
/**
* Constructor.
* @param ntrees the number of trees.
* @param mtry the number of input variables to be used to determine the
* decision at a node of the tree. p/3 generally gives good
* performance, where p is the number of variables.
*/
public Options(int ntrees, int mtry) {
this(ntrees, mtry, 20, 0, 5, 1.0, null, null);
}
/**
* Returns the persistent set of hyperparameters.
* @return the persistent set.
*/
public Properties toProperties() {
Properties props = new Properties();
props.setProperty("smile.random_forest.trees", Integer.toString(ntrees));
props.setProperty("smile.random_forest.mtry", Integer.toString(mtry));
props.setProperty("smile.random_forest.max_depth", Integer.toString(maxDepth));
props.setProperty("smile.random_forest.max_nodes", Integer.toString(maxNodes));
props.setProperty("smile.random_forest.node_size", Integer.toString(nodeSize));
props.setProperty("smile.random_forest.sampling_rate", Double.toString(subsample));
return props;
}
/**
* Returns the options from properties.
*
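* A minimal sketch of supplying options via properties (the keys are the
* ones parsed below; unset keys fall back to the defaults shown there):
* <pre>{@code
* var props = new Properties();
* props.setProperty("smile.random_forest.trees", "300");
* props.setProperty("smile.random_forest.max_depth", "25");
* var options = Options.of(props);
* }</pre>
*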
* @param props the hyperparameters.
* @return the options.
*/
public static Options of(Properties props) {
int ntrees = Integer.parseInt(props.getProperty("smile.random_forest.trees", "500"));
int mtry = Integer.parseInt(props.getProperty("smile.random_forest.mtry", "0"));
int maxDepth = Integer.parseInt(props.getProperty("smile.random_forest.max_depth", "20"));
int maxNodes = Integer.parseInt(props.getProperty("smile.random_forest.max_nodes", "0"));
int nodeSize = Integer.parseInt(props.getProperty("smile.random_forest.node_size", "5"));
double subsample = Double.parseDouble(props.getProperty("smile.random_forest.sampling_rate", "1.0"));
return new Options(ntrees, mtry, maxDepth, maxNodes, nodeSize, subsample, null, null);
}
}
/**
* Fits a random forest for regression.
*
* @param formula a symbolic description of the model to be fitted.
* @param data the data frame of the explanatory and response variables.
* @return the model.
*/
public static RandomForest fit(Formula formula, DataFrame data) {
return fit(formula, data, new Options(500));
}
/**
* Fits a random forest for regression.
*
* @param formula a symbolic description of the model to be fitted.
* @param data the data frame of the explanatory and response variables.
* @param options the hyperparameters.
* @return the model.
*/
public static RandomForest fit(Formula formula, DataFrame data, Options options) {
formula = formula.expand(data.schema());
DataFrame x = formula.x(data);
ValueVector response = formula.y(data);
StructField field = response.field();
double[] y = response.toDoubleArray();
if (options.mtry > x.ncol()) {
throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + options.mtry);
}
int mtry = options.mtry > 0 ? options.mtry : Math.max(x.ncol()/3, 1);
int maxNodes = options.maxNodes > 0 ? options.maxNodes : Math.max(2, data.size() / 5);
int ntrees = options.ntrees;
var subsample = options.subsample;
final int n = x.size();
double[] prediction = new double[n];
int[] oob = new int[n];
final int[][] order = CART.order(x);
// train trees with parallel stream
Model[] models = IntStream.range(0, ntrees).parallel().mapToObj(t -> {
// set RNG seed for the tree
if (options.seeds != null) MathEx.setSeed(options.seeds[t]);
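// samples[i] counts how many times example i is drawn for this tree;
// entries left at 0 are the tree's out-of-bag samples.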
final int[] samples = new int[n];
if (subsample == 1.0) {
// Training samples drawn with replacement.
for (int i = 0; i < n; i++) {
samples[MathEx.randomInt(n)]++;
}
} else {
// Training samples drawn without replacement.
int[] permutation = MathEx.permutate(n);
int N = (int) Math.round(n * subsample);
for (int i = 0; i < N; i++) {
samples[permutation[i]] = 1;
}
}
long start = System.nanoTime();
RegressionTree tree = new RegressionTree(x, Loss.ls(y), field, options.maxDepth, maxNodes, options.nodeSize, mtry, samples, order);
double fitTime = (System.nanoTime() - start) / 1E6;
// estimate OOB metrics
start = System.nanoTime();
int noob = 0;
for (int i = 0; i < n; i++) {
if (samples[i] == 0) {
noob++;
}
}
double[] truth = new double[noob];
double[] predict = new double[noob];
for (int i = 0, j = 0; i < n; i++) {
if (samples[i] == 0) {
truth[j] = y[i];
double yi = tree.predict(x.get(i));
predict[j] = yi;
oob[i]++;
prediction[i] += yi;
j++;
}
}
double scoreTime = (System.nanoTime() - start) / 1E6;
var metrics = RegressionMetrics.of(fitTime, scoreTime, truth, predict);
logger.info("Tree {}: OOB = {}, R2 = {}%", t+1, noob, String.format("%.2f", 100*metrics.r2()));
if (options.controller != null) {
options.controller.submit(new TrainingStatus(t+1, metrics));
}
return new Model(tree, metrics);
}).toArray(Model[]::new);
double fitTime = 0.0, scoreTime = 0.0;
for (Model model : models) {
fitTime += model.metrics.fitTime();
scoreTime += model.metrics.scoreTime();
}
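// average the out-of-bag predictions accumulated across trees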
for (int i = 0; i < n; i++) {
if (oob[i] > 0) {
prediction[i] /= oob[i];
}
}
var metrics = RegressionMetrics.of(fitTime, scoreTime, y, prediction);
double[] importance = calculateImportance(models);
return new RandomForest(formula, models, metrics, importance);
}
/** Returns the variable importance summed over all trees. */
private static double[] calculateImportance(Model[] models) {
double[] importance = new double[models[0].tree.importance().length];
for (Model model : models) {
double[] imp = model.tree.importance();
for (int i = 0; i < imp.length; i++) {
importance[i] += imp[i];
}
}
return importance;
}
@Override
public Formula formula() {
return formula;
}
@Override
public StructType schema() {
return models[0].tree.schema();
}
/**
* Returns the overall out-of-bag metric estimations. The OOB estimate is
* quite accurate given that enough trees have been grown. Otherwise, the
* OOB error estimate can be biased upward.
*
* @return the overall out-of-bag metric estimations.
*/
public RegressionMetrics metrics() {
return metrics;
}
/**
* Returns the variable importance. Every time a node is split on a
* variable, the impurity criterion for the two descendant nodes is less
* than that of the parent node. Adding up the decreases for each
* individual variable over all trees in the forest gives a fast measure
* of variable importance that is often very consistent with the
* permutation importance measure.
*
* @return the variable importance
*/
public double[] importance() {
return importance;
}
/**
* Returns the number of trees in the model.
*
* @return the number of trees in the model
*/
public int size() {
return models.length;
}
/**
* Returns the base models.
* @return the base models.
*/
public Model[] models() {
return models;
}
@Override
public RegressionTree[] trees() {
return Arrays.stream(models).map(model -> model.tree).toArray(RegressionTree[]::new);
}
/**
* Trims the tree model set to a smaller size in case of over-fitting,
* or when extra decision trees don't improve the performance. Removing
* them reduces the model size and speeds up prediction.
*
* @param ntrees the new (smaller) size of tree model set.
* @return the trimmed model.
*/
public RandomForest trim(int ntrees) {
if (ntrees > models.length) {
throw new IllegalArgumentException("The new model size is larger than the current size.");
}
if (ntrees <= 0) {
throw new IllegalArgumentException("Invalid new model size: " + ntrees);
}
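// Model sorts ascending by OOB RMSE, so the copy below keeps the most accurate trees.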
Arrays.sort(models);
return new RandomForest(formula, Arrays.copyOf(models, ntrees), metrics, importance);
}
/**
* Merges two random forests.
*
* @param other the model to merge with.
* @return the merged model.
*/
public RandomForest merge(RandomForest other) {
if (!formula.equals(other.formula)) {
throw new IllegalArgumentException("The random forests have different model formulas");
}
Model[] forest = new Model[models.length + other.models.length];
System.arraycopy(models, 0, forest, 0, models.length);
System.arraycopy(other.models, 0, forest, models.length, other.models.length);
// rough estimate: average the two forests' error metrics rather than recomputing them
RegressionMetrics mergedMetrics = new RegressionMetrics(
metrics.fitTime() + other.metrics.fitTime(),
metrics.scoreTime() + other.metrics.scoreTime(),
metrics.size(),
(metrics.rss() + other.metrics.rss()) / 2,
(metrics.mse() + other.metrics.mse()) / 2,
(metrics.rmse() + other.metrics.rmse()) / 2,
(metrics.mad() + other.metrics.mad()) / 2,
(metrics.r2() + other.metrics.r2()) / 2
);
double[] mergedImportance = importance.clone();
for (int i = 0; i < importance.length; i++) {
mergedImportance[i] += other.importance[i];
}
return new RandomForest(formula, forest, mergedMetrics, mergedImportance);
}
@Override
public double predict(Tuple x) {
Tuple xt = formula.x(x);
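// average the predictions of all trees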
double y = 0;
for (Model model : models) {
y += model.tree.predict(xt);
}
return y / models.length;
}
/**
* Tests the model on a validation dataset.
*
* @param data the test data set.
* @return the predictions with the first 1, 2, ..., ntrees regression trees.
*/
public double[][] test(DataFrame data) {
DataFrame x = formula.x(data);
int n = x.size();
int ntrees = models.length;
double[][] prediction = new double[ntrees][n];
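// prediction[i][j] is the average prediction of the first i+1 trees on sample j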
for (int j = 0; j < n; j++) {
Tuple xj = x.get(j);
double base = 0;
for (int i = 0; i < ntrees; i++) {
base = base + models[i].tree.predict(xj);
prediction[i][j] = base / (i+1);
}
}
return prediction;
}
}