All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.classification.RandomForest Maven / Gradle / Ivy

/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.classification;

import java.io.Serial;
import java.io.Serializable;
import java.util.*;
import java.util.stream.LongStream;
import smile.base.cart.CART;
import smile.base.cart.SplitRule;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.feature.importance.TreeSHAP;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.util.Strings;
import smile.validation.ClassificationMetrics;
import smile.validation.metric.*;
import smile.validation.metric.Error;

/**
 * Random forest for classification. Random forest is an ensemble classifier
 * that consists of many decision trees and outputs the majority vote of
 * individual trees. The method combines bagging idea and the random
 * selection of features.
 * 

* Each tree is constructed using the following algorithm: *

    *
  1. If the number of cases in the training set is N, randomly sample N cases * with replacement from the original data. This sample will * be the training set for growing the tree. *
  2. If there are M input variables, a number {@code m << M} is specified such * that at each node, m variables are selected at random out of the M and * the best split on these m is used to split the node. The value of m is * held constant during the forest growing. *
  3. Each tree is grown to the largest extent possible. There is no pruning. *
* The advantages of random forest are: *
    *
  • For many data sets, it produces a highly accurate classifier. *
  • It runs efficiently on large data sets. *
  • It can handle thousands of input variables without variable deletion. *
  • It gives estimates of what variables are important in the classification. *
  • It generates an internal unbiased estimate of the generalization error * as the forest building progresses. *
  • It has an effective method for estimating missing data and maintains * accuracy when a large proportion of the data are missing. *
* The disadvantages are *
    *
  • Random forests are prone to over-fitting for some datasets. This is * even more pronounced on noisy data. *
  • For data including categorical variables with different number of * levels, random forests are biased in favor of those attributes with more * levels. Therefore, the variable importance scores from random forest are * not reliable for this type of data. *
* * @author Haifeng Li */ public class RandomForest extends AbstractClassifier implements DataFrameClassifier, TreeSHAP { @Serial private static final long serialVersionUID = 2L; private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(RandomForest.class); /** * The base model. */ public static class Model implements Serializable { /** The decision tree. */ public final DecisionTree tree; /** The performance metrics on out-of-bag samples. */ public final ClassificationMetrics metrics; /** The weight of tree, which can be used when aggregating tree votes. */ public final double weight; /** Constructor. */ Model(DecisionTree tree, ClassificationMetrics metrics) { this.tree = tree; this.metrics = metrics; this.weight = metrics.accuracy(); } } /** * The model formula. */ private final Formula formula; /** * Forest of decision trees. The second value is the accuracy of * tree on the OOB samples, which can be used a weight when aggregating * tree votes. */ private final Model[] models; /** * The number of classes. */ private final int k; /** * The overall out-of-bag metrics, which are quite accurate given that * enough trees have been grown (otherwise the OOB error estimate can * bias upward). */ private final ClassificationMetrics metrics; /** * Variable importance. Every time a split of a node is made on variable * the (GINI, information gain, etc.) impurity criterion for the two * descendent nodes is less than the parent node. Adding up the decreases * for each individual variable over all trees in the forest gives a fast * variable importance that is often very consistent with the permutation * importance measure. */ private final double[] importance; /** * Constructor. * * @param formula a symbolic description of the model to be fitted. * @param k the number of classes. * @param models forest of decision trees. * @param metrics the overall out-of-bag metric estimation. * @param importance the feature importance. 
*/ public RandomForest(Formula formula, int k, Model[] models, ClassificationMetrics metrics, double[] importance) { this(formula, k, models, metrics, importance, IntSet.of(k)); } /** * Constructor. * * @param formula a symbolic description of the model to be fitted. * @param k the number of classes. * @param models the base models. * @param metrics the overall out-of-bag metric estimation. * @param importance the feature importance. * @param labels the class label encoder. */ public RandomForest(Formula formula, int k, Model[] models, ClassificationMetrics metrics, double[] importance, IntSet labels) { super(labels); this.formula = formula; this.k = k; this.models = models; this.metrics = metrics; this.importance = importance; } /** * Fits a random forest for classification. * * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @return the model. */ public static RandomForest fit(Formula formula, DataFrame data) { return fit(formula, data, new Properties()); } /** * Fits a random forest for classification. * * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param params the hyperparameters. * @return the model. 
*/ public static RandomForest fit(Formula formula, DataFrame data, Properties params) { int ntrees = Integer.parseInt(params.getProperty("smile.random_forest.trees", "500")); int mtry = Integer.parseInt(params.getProperty("smile.random_forest.mtry", "0")); SplitRule rule = SplitRule.valueOf(params.getProperty("smile.random_forest.split_rule", "GINI")); int maxDepth = Integer.parseInt(params.getProperty("smile.random_forest.max_depth", "20")); int maxNodes = Integer.parseInt(params.getProperty("smile.random_forest.max_nodes", String.valueOf(Math.max(2, data.size() / 5)))); int nodeSize = Integer.parseInt(params.getProperty("smile.random_forest.node_size", "5")); double subsample = Double.parseDouble(params.getProperty("smile.random_forest.sampling_rate", "1.0")); int[] classWeight = Strings.parseIntArray(params.getProperty("smile.random_forest.class_weight")); return fit(formula, data, ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, classWeight, null); } /** * Fits a random forest for classification. * * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param ntrees the number of trees. * @param mtry the number of input variables to be used to determine the * decision at a node of the tree. floor(sqrt(p)) generally * gives good performance, where p is the number of variables. * @param rule Decision tree split rule. * @param maxDepth the maximum depth of the tree. * @param maxNodes the maximum number of leaf nodes in the tree. * @param nodeSize the number of instances in a node below which the tree * will not split, nodeSize = 5 generally gives good * results. * @param subsample the sampling rate for training tree. 1.0 means sampling * with replacement. {@code < 1.0} means sampling without * replacement. * @return the model. 
*/ public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample) { return fit(formula, data, ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, null); } /** * Fits a random forest for regression. * * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param ntrees the number of trees. * @param mtry the number of input variables to be used to determine the * decision at a node of the tree. floor(sqrt(p)) generally * gives good performance, where p is the number of variables * @param rule Decision tree split rule. * @param maxDepth the maximum depth of the tree. * @param maxNodes the maximum number of leaf nodes in the tree. * @param nodeSize the number of instances in a node below which the tree * will not split, nodeSize = 5 generally gives good * results. * @param subsample the sampling rate for training tree. 1.0 means sampling * with replacement. {@code < 1.0} means sampling without * replacement. * @param classWeight Priors of the classes. The weight of each class * is roughly the ratio of samples in each class. * For example, if there are 400 positive samples * and 100 negative samples, the classWeight should * be [1, 4] (assuming label 0 is of negative, label * 1 is of positive). * @return the model. */ public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight) { return fit(formula, data, ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, classWeight, null); } /** * Fits a random forest for classification. * * @param formula a symbolic description of the model to be fitted. * @param data the data frame of the explanatory and response variables. * @param ntrees the number of trees. 
* @param mtry the number of input variables to be used to determine the * decision at a node of the tree. floor(sqrt(p)) generally * gives good performance, where p is the number of variables. * @param rule Decision tree split rule. * @param maxDepth the maximum depth of the tree. * @param maxNodes the maximum number of leaf nodes in the tree. * @param nodeSize the number of instances in a node below which the tree * will not split, nodeSize = 5 generally gives good * results. * @param subsample the sampling rate for training tree. 1.0 means sampling * with replacement. {@code < 1.0} means sampling without * replacement. * @param classWeight Priors of the classes. The weight of each class * is roughly the ratio of samples in each class. * For example, if there are 400 positive samples * and 100 negative samples, the classWeight should * be [1, 4] (assuming label 0 is of negative, label 1 is of * positive). * @param seeds optional RNG seeds for each regression tree. * @return the model. */ public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight, LongStream seeds) { if (ntrees < 1) { throw new IllegalArgumentException("Invalid number of trees: " + ntrees); } if (subsample <= 0 || subsample > 1) { throw new IllegalArgumentException("Invalid sampling rating: " + subsample); } formula = formula.expand(data.schema()); DataFrame x = formula.x(data); BaseVector y = formula.y(data); if (mtry > x.ncol()) { throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry); } int mtryFinal = mtry > 0 ? mtry : (int) Math.sqrt(x.ncol()); ClassLabels codec = ClassLabels.fit(y); final int k = codec.k; final int n = x.nrow(); final int[] weight = classWeight != null ? 
classWeight : Collections.nCopies(k, 1).stream().mapToInt(i -> i).toArray(); final int[][] order = CART.order(x); final int[][] prediction = new int[n][k]; // out-of-bag prediction // generate seeds with sequential stream long[] seedArray = (seeds != null ? seeds : LongStream.range(-ntrees, 0)).sequential().distinct().limit(ntrees).toArray(); if (seedArray.length != ntrees) { throw new IllegalArgumentException(String.format("seed stream has only %d distinct values, expected %d", seedArray.length, ntrees)); } // # of samples in each class int[] count = new int[k]; for (int i = 0; i < n; i++) { count[codec.y[i]]++; } // samples in each class int[][] yi = new int[k][]; for (int i = 0; i < k; i++) { yi[i] = new int[count[i]]; } int[] idx = new int[k]; for (int i = 0; i < n; i++) { int j = codec.y[i]; yi[j][idx[j]++] = i; } Model[] models = Arrays.stream(seedArray).parallel().mapToObj(seed -> { // set RNG seed for the tree if (seed > 1) MathEx.setSeed(seed); final int[] samples = new int[n]; // Stratified sampling in case that class is unbalanced. // That is, we sample each class separately. if (subsample == 1.0) { // Training samples draw with replacement. for (int i = 0; i < k; i++) { // We used to do up sampling. // But we switch to down sampling, which seems producing better AUC. int ni = count[i]; int size = ni / weight[i]; int[] yj = yi[i]; for (int j = 0; j < size; j++) { int xj = MathEx.randomInt(ni); samples[yj[xj]] += 1; //classWeight[i]; } } } else { // Training samples draw without replacement. for (int i = 0; i < k; i++) { // We used to do up sampling. // But we switch to down sampling, which seems producing better AUC. 
int size = (int) Math.round(subsample * count[i] / weight[i]); int[] yj = yi[i]; int[] permutation = MathEx.permutate(count[i]); for (int j = 0; j < size; j++) { int xj = permutation[j]; samples[yj[xj]] += 1; //classWeight[i]; } } } long start = System.nanoTime(); DecisionTree tree = new DecisionTree(x, codec.y, y.field(), k, rule, maxDepth, maxNodes, nodeSize, mtryFinal, samples, order); double fitTime = (System.nanoTime() - start) / 1E6; // estimate OOB metrics start = System.nanoTime(); int noob = 0; for (int i = 0; i < n; i++) { if (samples[i] == 0) { noob++; } } int[] truth = new int[noob]; int[] oob = new int[noob]; double[][] posteriori = new double[noob][k]; for (int i = 0, j = 0; i < n; i++) { if (samples[i] == 0) { truth[j] = codec.y[i]; int p = tree.predict(x.get(i), posteriori[j]); oob[j] = p; prediction[i][p]++; j++; } } double scoreTime = (System.nanoTime() - start) / 1E6; // When data is very small, OOB samples may miss some classes. int oobk = MathEx.unique(truth).length; ClassificationMetrics metrics; if (oobk == 2) { double[] probability = Arrays.stream(posteriori).mapToDouble(p -> p[1]).toArray(); metrics = new ClassificationMetrics(fitTime, scoreTime, noob, Error.of(truth, oob), Accuracy.of(truth, oob), Sensitivity.of(truth, oob), Specificity.of(truth, oob), Precision.of(truth, oob), FScore.F1.score(truth, oob), MatthewsCorrelation.of(truth, oob), AUC.of(truth, probability), LogLoss.of(truth, probability) ); } else { metrics = new ClassificationMetrics(fitTime, scoreTime, noob, Error.of(truth, oob), Accuracy.of(truth, oob), CrossEntropy.of(truth, posteriori) ); } if (noob != 0) { logger.info("Decision tree OOB accuracy: {}", String.format("%.2f%%", 100*metrics.accuracy())); } else { logger.error("Decision tree trained without OOB samples."); } return new Model(tree, metrics); }).toArray(Model[]::new); double fitTime = 0.0, scoreTime = 0.0; for (Model model : models) { fitTime += model.metrics.fitTime(); scoreTime += model.metrics.scoreTime(); } 
int[] vote = new int[n]; for (int i = 0; i < n; i++) { vote[i] = MathEx.whichMax(prediction[i]); } ClassificationMetrics metrics = new ClassificationMetrics(fitTime, scoreTime, n, Error.of(codec.y, vote), Accuracy.of(codec.y, vote) ); return new RandomForest(formula, k, models, metrics, importance(models), codec.classes); } /** Calculate the importance of the whole forest. */ private static double[] importance(Model[] models) { int p = models[0].tree.importance().length; double[] importance = new double[p]; for (Model model : models) { double[] imp = model.tree.importance(); for (int i = 0; i < p; i++) { importance[i] += imp[i]; } } return importance; } @Override public Formula formula() { return formula; } @Override public StructType schema() { return models[0].tree.schema(); } /** * Returns the overall out-of-bag metric estimations. The OOB estimate is * quite accurate given that enough trees have been grown. Otherwise, the * OOB error estimate can bias upward. * * @return the out-of-bag metrics estimations. */ public ClassificationMetrics metrics() { return metrics; } /** * Returns the variable importance. Every time a split of a node is made * on variable the (GINI, information gain, etc.) impurity criterion for * the two descendent nodes is less than the parent node. Adding up the * decreases for each individual variable over all trees in the forest * gives a fast measure of variable importance that is often very * consistent with the permutation importance measure. * * @return the variable importance */ public double[] importance() { return importance; } /** * Returns the number of trees in the model. * * @return the number of trees in the model. */ public int size() { return models.length; } /** * Returns the base models. * @return the base models. 
*/ public Model[] models() { return models; } @Override public DecisionTree[] trees() { return Arrays.stream(models).map(model -> model.tree).toArray(DecisionTree[]::new); } /** * Trims the tree model set to a smaller size in case of over-fitting. * Or if extra decision trees in the model don't improve the performance, * we may remove them to reduce the model size and also improve the speed of * prediction. * * @param ntrees the new (smaller) size of tree model set. * @return a new trimmed forest. */ public RandomForest trim(int ntrees) { if (ntrees > models.length) { throw new IllegalArgumentException("The new model size is larger than the current size."); } if (ntrees <= 0) { throw new IllegalArgumentException("Invalid new model size: " + ntrees); } Arrays.sort(models, Comparator.comparingDouble(model -> -model.weight)); // The OOB metrics are still the old one // as we don't access to the training data here. return new RandomForest(formula, k, Arrays.copyOf(models, ntrees), metrics, importance(models), classes); } /** * Merges two random forests. * @param other the other forest to merge with. * @return the merged forest. 
*/ public RandomForest merge(RandomForest other) { if (!formula.equals(other.formula)) { throw new IllegalArgumentException("RandomForest have different model formula"); } Model[] forest = new Model[models.length + other.models.length]; System.arraycopy(models, 0, forest, 0, models.length); System.arraycopy(other.models, 0, forest, models.length, other.models.length); // rough estimation ClassificationMetrics mergedMetrics = new ClassificationMetrics( metrics.fitTime() + other.metrics.fitTime(), metrics.scoreTime() + other.metrics.scoreTime(), metrics.size(), (metrics.error() + other.metrics.error()) / 2, (metrics.accuracy() + other.metrics.accuracy()) / 2, (metrics.sensitivity() + other.metrics.sensitivity()) / 2, (metrics.specificity() + other.metrics.specificity()) / 2, (metrics.precision() + other.metrics.precision()) / 2, (metrics.f1() + other.metrics.f1()) / 2, (metrics.mcc() + other.metrics.mcc()) / 2, (metrics.auc() + other.metrics.auc()) / 2, (metrics.logloss() + other.metrics.logloss()) / 2, (metrics.crossentropy() + other.metrics.crossentropy()) / 2 ); double[] mergedImportance = importance.clone(); for (int i = 0; i < importance.length; i++) { mergedImportance[i] += other.importance[i]; } return new RandomForest(formula, k, forest, mergedMetrics, mergedImportance, classes); } @Override public int predict(Tuple x) { Tuple xt = formula.x(x); int[] y = new int[k]; for (Model model : models) { y[model.tree.predict(xt)]++; } return classes.valueOf(MathEx.whichMax(y)); } @Override public boolean soft() { return true; } @Override public int predict(Tuple x, double[] posteriori) { if (posteriori.length != k) { throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k)); } Tuple xt = formula.x(x); double[] prob = new double[k]; Arrays.fill(posteriori, 0.0); for (Model model : models) { model.tree.predict(xt, prob); for (int i = 0; i < k; i++) { posteriori[i] += model.weight * prob[i]; } } 
MathEx.unitize1(posteriori); return classes.valueOf(MathEx.whichMax(posteriori)); } /** * Predict and estimate the probability by voting. * * @param x the instances to be classified. * @param posteriori a posteriori probabilities on output. * @return the predicted class labels. */ public int vote(Tuple x, double[] posteriori) { if (posteriori.length != k) { throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k)); } Tuple xt = formula.x(x); Arrays.fill(posteriori, 0.0); for (Model model : models) { posteriori[model.tree.predict(xt)]++; } MathEx.unitize1(posteriori); return classes.valueOf(MathEx.whichMax(posteriori)); } /** * Test the model on a validation dataset. * * @param data the test data set. * @return the predictions with first 1, 2, ..., decision trees. */ public int[][] test(DataFrame data) { DataFrame x = formula.x(data); int n = x.size(); int ntrees = models.length; int[] p = new int[k]; int[][] prediction = new int[ntrees][n]; for (int j = 0; j < n; j++) { Tuple xj = x.get(j); Arrays.fill(p, 0); for (int i = 0; i < ntrees; i++) { p[models[i].tree.predict(xj)]++; prediction[i][j] = MathEx.whichMax(p); } } return prediction; } /** * Returns a new random forest by reduced error pruning. * @param test the test data set to evaluate the errors of nodes. * @return a new pruned random forest. */ public RandomForest prune(DataFrame test) { Model[] forest = Arrays.stream(models).parallel() .map(model -> new Model(model.tree.prune(test, formula, classes), model.metrics)) .toArray(Model[]::new); // The tree weight and OOB metrics are still the old one // as we don't access to the training data here. return new RandomForest(formula, k, forest, metrics, importance(forest), classes); } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy