/*
 * Copyright (c) 2010-2025 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile. If not, see <https://www.gnu.org/licenses/>.
 */
package smile.regression;

import java.io.Serial;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.ValueVector;
import smile.feature.importance.TreeSHAP;
import smile.math.MathEx;
import smile.util.IterativeAlgorithmController;
import smile.validation.RegressionMetrics;

/**
 * Random forest for regression. Random forest is an ensemble method that
 * consists of many regression trees and outputs the average of individual
 * trees. The method combines the bagging idea with the random selection
 * of features.
 * <p>
 * Each tree is constructed using the following algorithm:
 * <ol>
 * <li> If the number of cases in the training set is N, randomly sample N
 *      cases with replacement from the original data. This sample will be
 *      the training set for growing the tree.
 * <li> If there are M input variables, a number {@code m << M} is specified
 *      such that at each node, m variables are selected at random out of
 *      the M and the best split on these m is used to split the node. The
 *      value of m is held constant during the forest growing.
 * <li> Each tree is grown to the largest extent possible. There is no
 *      pruning.
 * </ol>
 * The advantages of random forest are:
 * <ul>
 * <li> For many data sets, it produces a highly accurate model.
 * <li> It runs efficiently on large data sets.
 * <li> It can handle thousands of input variables without variable deletion.
 * <li> It gives estimates of which variables are important in the regression.
 * <li> It generates an internal unbiased estimate of the generalization
 *      error as the forest building progresses.
 * <li> It has an effective method for estimating missing data and maintains
 *      accuracy when a large proportion of the data are missing.
 * </ul>
 * The disadvantages are:
 * <ul>
 * <li> Random forests are prone to over-fitting for some datasets. This is
 *      even more pronounced in noisy classification/regression tasks.
 * <li> For data including categorical variables with different numbers of
 *      levels, random forests are biased in favor of those attributes with
 *      more levels. Therefore, the variable importance scores from random
 *      forest are not reliable for this type of data.
 * </ul>
 *
 * @author Haifeng Li
 */
public class RandomForest implements DataFrameRegression, TreeSHAP {
    @Serial
    private static final long serialVersionUID = 2L;
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(RandomForest.class);

    /**
     * The base model.
     * @param tree The regression tree.
     * @param metrics The validation metrics on out-of-bag samples.
     */
    public record Model(RegressionTree tree, RegressionMetrics metrics) implements Serializable, Comparable<Model> {
        @Override
        public int compareTo(Model o) {
            return Double.compare(metrics.rmse(), o.metrics.rmse());
        }
    }

    /**
     * The model formula.
     */
    private final Formula formula;

    /**
     * Forest of regression trees.
     */
    private final Model[] models;

    /**
     * The overall out-of-bag metrics, which are quite accurate given that
     * enough trees have been grown (otherwise the OOB error estimate can
     * bias upward).
     */
    private final RegressionMetrics metrics;

    /**
     * Variable importance. Every time a split of a node is made on a
     * variable, the impurity criterion for the two descendant nodes is less
     * than that of the parent node. Adding up the decreases for each
     * individual variable over all trees in the forest gives a fast variable
     * importance that is often very consistent with the permutation
     * importance measure.
     */
    private final double[] importance;

    /**
     * Constructor.
     * @param formula a symbolic description of the model to be fitted.
     * @param models the base models.
     * @param metrics the overall out-of-bag metric estimations.
     * @param importance the feature importance.
     */
    public RandomForest(Formula formula, Model[] models, RegressionMetrics metrics, double[] importance) {
        this.formula = formula;
        this.models = models;
        this.metrics = metrics;
        this.importance = importance;
    }

    /**
     * Training status per tree.
     * @param tree the tree index, starting at 1.
     * @param metrics the validation metrics on out-of-bag samples.
     */
    public record TrainingStatus(int tree, RegressionMetrics metrics) {
    }

    /**
     * Random forest hyperparameters.
     * @param ntrees the number of trees.
     * @param mtry the number of input variables to be used to determine the
     *             decision at a node of the tree. p/3 generally gives good
     *             performance, where p is the number of variables.
     * @param maxDepth the maximum depth of the tree.
     * @param maxNodes the maximum number of leaf nodes in the tree.
     * @param nodeSize the minimum size of leaf nodes.
     *                 Setting nodeSize = 5 generally gives good results.
     * @param subsample the sampling rate for training trees. 1.0 means
     *                  sampling with replacement. {@code < 1.0} means
     *                  sampling without replacement.
     * @param seeds optional RNG seeds for each regression tree.
     * @param controller the optional training controller.
     */
    public record Options(int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize,
                          double subsample, long[] seeds, IterativeAlgorithmController<TrainingStatus> controller) {
        /** Constructor. */
        public Options {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }

            if (maxDepth < 2) {
                throw new IllegalArgumentException("Invalid maximal tree depth: " + maxDepth);
            }

            if (nodeSize < 1) {
                throw new IllegalArgumentException("Invalid node size: " + nodeSize);
            }

            if (subsample <= 0 || subsample > 1) {
                throw new IllegalArgumentException("Invalid sampling rate: " + subsample);
            }

            if (seeds != null && seeds.length < ntrees) {
                throw new IllegalArgumentException("The number of RNG seeds is fewer than that of trees: " + seeds.length);
            }
        }

        /**
         * Constructor.
         * @param ntrees the number of trees.
         */
        public Options(int ntrees) {
            this(ntrees, 0);
        }

        /**
         * Constructor.
         * @param ntrees the number of trees.
         * @param mtry the number of input variables to be used to determine
         *             the decision at a node of the tree. p/3 generally gives
         *             good performance, where p is the number of variables.
         */
        public Options(int ntrees, int mtry) {
            this(ntrees, mtry, 20, 0, 5, 1.0, null, null);
        }

        /**
         * Returns the persistent set of hyperparameters.
         * @return the persistent set.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.random_forest.trees", Integer.toString(ntrees));
            props.setProperty("smile.random_forest.mtry", Integer.toString(mtry));
            props.setProperty("smile.random_forest.max_depth", Integer.toString(maxDepth));
            props.setProperty("smile.random_forest.max_nodes", Integer.toString(maxNodes));
            props.setProperty("smile.random_forest.node_size", Integer.toString(nodeSize));
            props.setProperty("smile.random_forest.sampling_rate", Double.toString(subsample));
            return props;
        }

        /**
         * Returns the options from properties.
         *
         * @param props the hyperparameters.
         * @return the options.
         */
        public static Options of(Properties props) {
            int ntrees = Integer.parseInt(props.getProperty("smile.random_forest.trees", "500"));
            int mtry = Integer.parseInt(props.getProperty("smile.random_forest.mtry", "0"));
            int maxDepth = Integer.parseInt(props.getProperty("smile.random_forest.max_depth", "20"));
            int maxNodes = Integer.parseInt(props.getProperty("smile.random_forest.max_nodes", "0"));
            int nodeSize = Integer.parseInt(props.getProperty("smile.random_forest.node_size", "5"));
            double subsample = Double.parseDouble(props.getProperty("smile.random_forest.sampling_rate", "1.0"));
            return new Options(ntrees, mtry, maxDepth, maxNodes, nodeSize, subsample, null, null);
        }
    }
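    /*
     * A minimal usage sketch of the Options record above (illustrative only,
     * not part of this class; the hyperparameter values are assumptions):
     *
     *   var options = new RandomForest.Options(300, 4); // 300 trees, mtry = 4
     *   Properties props = options.toProperties();      // persist the settings
     *   var restored = RandomForest.Options.of(props);  // restore; missing keys
     *                                                   // fall back to defaults
     */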
    /**
     * Fits a random forest for regression.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @return the model.
     */
    public static RandomForest fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Options(500));
    }

    /**
     * Fits a random forest for regression.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param options the hyperparameters.
     * @return the model.
     */
    public static RandomForest fit(Formula formula, DataFrame data, Options options) {
        formula = formula.expand(data.schema());
        DataFrame x = formula.x(data);
        ValueVector response = formula.y(data);
        StructField field = response.field();
        double[] y = response.toDoubleArray();

        if (options.mtry > x.ncol()) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + options.mtry);
        }

        int mtry = options.mtry > 0 ? options.mtry : Math.max(x.ncol() / 3, 1);
        int maxNodes = options.maxNodes > 0 ? options.maxNodes : Math.max(2, data.size() / 5);
        int ntrees = options.ntrees;
        var subsample = options.subsample;

        final int n = x.size();
        double[] prediction = new double[n];
        int[] oob = new int[n];
        final int[][] order = CART.order(x);

        // train trees with parallel stream
        Model[] models = IntStream.range(0, ntrees).parallel().mapToObj(t -> {
            // set RNG seed for the tree
            if (options.seeds != null) MathEx.setSeed(options.seeds[t]);

            final int[] samples = new int[n];
            if (subsample == 1.0) {
                // Training samples drawn with replacement.
                for (int i = 0; i < n; i++) {
                    samples[MathEx.randomInt(n)]++;
                }
            } else {
                // Training samples drawn without replacement.
                int[] permutation = MathEx.permutate(n);
                int N = (int) Math.round(n * subsample);
                for (int i = 0; i < N; i++) {
                    samples[permutation[i]] = 1;
                }
            }

            long start = System.nanoTime();
            RegressionTree tree = new RegressionTree(x, Loss.ls(y), field, options.maxDepth, maxNodes, options.nodeSize, mtry, samples, order);
            double fitTime = (System.nanoTime() - start) / 1E6;

            // estimate OOB metrics
            start = System.nanoTime();
            int noob = 0;
            for (int i = 0; i < n; i++) {
                if (samples[i] == 0) {
                    noob++;
                }
            }

            double[] truth = new double[noob];
            double[] predict = new double[noob];
            for (int i = 0, j = 0; i < n; i++) {
                if (samples[i] == 0) {
                    truth[j] = y[i];
                    double yi = tree.predict(x.get(i));
                    predict[j] = yi;
                    oob[i]++;
                    prediction[i] += yi;
                    j++;
                }
            }

            double scoreTime = (System.nanoTime() - start) / 1E6;
            var metrics = RegressionMetrics.of(fitTime, scoreTime, truth, predict);
            logger.info("Tree {}: OOB = {}, R2 = {}%", t+1, noob, String.format("%.2f", 100*metrics.r2()));
            if (options.controller != null) {
                options.controller.submit(new TrainingStatus(t+1, metrics));
            }
            return new Model(tree, metrics);
        }).toArray(Model[]::new);

        double fitTime = 0.0, scoreTime = 0.0;
        for (Model model : models) {
            fitTime += model.metrics.fitTime();
            scoreTime += model.metrics.scoreTime();
        }

        for (int i = 0; i < n; i++) {
            if (oob[i] > 0) {
                prediction[i] /= oob[i];
            }
        }

        var metrics = RegressionMetrics.of(fitTime, scoreTime, y, prediction);
        double[] importance = calculateImportance(models);
        return new RandomForest(formula, models, metrics, importance);
    }
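    /*
     * A fitting sketch (illustrative only; smile.io.Read, the file name, and
     * the response column "y" are assumptions outside this file):
     *
     *   DataFrame data = Read.csv("data.csv");   // load the training data
     *   Formula formula = Formula.lhs("y");      // y ~ . (all other columns)
     *   RandomForest model = RandomForest.fit(formula, data,
     *           new RandomForest.Options(500));
     *   System.out.println(model.metrics());     // out-of-bag estimates
     */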
    /** Returns the sum of importance of all trees. */
    private static double[] calculateImportance(Model[] models) {
        double[] importance = new double[models[0].tree.importance().length];
        for (Model model : models) {
            double[] imp = model.tree.importance();
            for (int i = 0; i < imp.length; i++) {
                importance[i] += imp[i];
            }
        }
        return importance;
    }

    @Override
    public Formula formula() {
        return formula;
    }

    @Override
    public StructType schema() {
        return models[0].tree.schema();
    }

    /**
     * Returns the overall out-of-bag metric estimations. The OOB estimate is
     * quite accurate given that enough trees have been grown. Otherwise, the
     * OOB error estimate can bias upward.
     *
     * @return the overall out-of-bag metric estimations.
     */
    public RegressionMetrics metrics() {
        return metrics;
    }

    /**
     * Returns the variable importance. Every time a split of a node is made
     * on a variable, the impurity criterion for the two descendant nodes is
     * less than that of the parent node. Adding up the decreases for each
     * individual variable over all trees in the forest gives a fast measure
     * of variable importance that is often very consistent with the
     * permutation importance measure.
     *
     * @return the variable importance.
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Returns the number of trees in the model.
     *
     * @return the number of trees in the model.
     */
    public int size() {
        return models.length;
    }

    /**
     * Returns the base models.
     * @return the base models.
     */
    public Model[] models() {
        return models;
    }

    @Override
    public RegressionTree[] trees() {
        return Arrays.stream(models).map(model -> model.tree).toArray(RegressionTree[]::new);
    }
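    /*
     * Inspecting a fitted forest (illustrative only; "model" is a fitted
     * RandomForest, and pairing importance scores with schema fields by
     * index is an assumption about the predictor ordering):
     *
     *   double[] imp = model.importance();
     *   StructType schema = model.schema();
     *   for (int i = 0; i < imp.length; i++) {
     *       System.out.println(schema.field(i).name() + ": " + imp[i]);
     *   }
     */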
    /**
     * Trims the tree model set to a smaller size in case of over-fitting.
     * If extra decision trees in the model don't improve the performance,
     * we may remove them to reduce the model size and to speed up
     * prediction.
     *
     * @param ntrees the new (smaller) size of the tree model set.
     * @return the trimmed model.
     */
    public RandomForest trim(int ntrees) {
        if (ntrees > models.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }

        Arrays.sort(models);
        return new RandomForest(formula, Arrays.copyOf(models, ntrees), metrics, importance);
    }

    /**
     * Merges two random forests.
     *
     * @param other the model to merge with.
     * @return the merged model.
     */
    public RandomForest merge(RandomForest other) {
        if (!formula.equals(other.formula)) {
            throw new IllegalArgumentException("RandomForest have different model formula");
        }

        Model[] forest = new Model[models.length + other.models.length];
        System.arraycopy(models, 0, forest, 0, models.length);
        System.arraycopy(other.models, 0, forest, models.length, other.models.length);

        // rough estimation
        RegressionMetrics mergedMetrics = new RegressionMetrics(
                metrics.fitTime() + other.metrics.fitTime(),
                metrics.scoreTime() + other.metrics.scoreTime(),
                metrics.size(),
                (metrics.rss() + other.metrics.rss()) / 2,
                (metrics.mse() + other.metrics.mse()) / 2,
                (metrics.rmse() + other.metrics.rmse()) / 2,
                (metrics.mad() + other.metrics.mad()) / 2,
                (metrics.r2() + other.metrics.r2()) / 2
        );

        double[] mergedImportance = importance.clone();
        for (int i = 0; i < importance.length; i++) {
            mergedImportance[i] += other.importance[i];
        }

        return new RandomForest(formula, forest, mergedMetrics, mergedImportance);
    }

    @Override
    public double predict(Tuple x) {
        Tuple xt = formula.x(x);
        double y = 0;
        for (Model model : models) {
            y += model.tree.predict(xt);
        }
        return y / models.length;
    }

    /**
     * Test the model on a validation dataset.
     *
     * @param data the test data set.
     * @return the predictions with first 1, 2, ..., regression trees.
     */
    public double[][] test(DataFrame data) {
        DataFrame x = formula.x(data);

        int n = x.size();
        int ntrees = models.length;
        double[][] prediction = new double[ntrees][n];

        for (int j = 0; j < n; j++) {
            Tuple xj = x.get(j);
            double base = 0;
            for (int i = 0; i < ntrees; i++) {
                base = base + models[i].tree.predict(xj);
                prediction[i][j] = base / (i+1);
            }
        }

        return prediction;
    }
}
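/*
 * An end-to-end sketch of the ensemble operations above (illustrative only;
 * "formula", "batch1", and "batch2" are hypothetical):
 *
 *   RandomForest a = RandomForest.fit(formula, batch1);
 *   RandomForest b = RandomForest.fit(formula, batch2);
 *   RandomForest merged = a.merge(b);      // pool the trees of both forests
 *   RandomForest small = merged.trim(200); // keep the 200 trees with the
 *                                          // lowest out-of-bag RMSE
 *   double yhat = small.predict(batch1.get(0));
 */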



