smile.regression.RandomForest Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of openchemlib Show documentation
Show all versions of openchemlib Show documentation
Open Source Chemistry Library
/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.regression;
import smile.data.Attribute;
import smile.data.AttributeDataset;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.util.MulticoreExecutor;
import smile.util.SmileUtils;
import smile.validation.RMSE;
import smile.validation.RegressionMeasure;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
/**
* Random forest for regression. Random forest is an ensemble method that
* consists of many regression trees and outputs the average of individual
* trees. The method combines bagging idea and the random selection of features.
*
* Each tree is constructed using the following algorithm:
*
* - If the number of cases in the training set is N, randomly sample N cases
* with replacement from the original data. This sample will
* be the training set for growing the tree.
*
- If there are M input variables, a number m << M is specified such
* that at each node, m variables are selected at random out of the M and
* the best split on these m is used to split the node. The value of m is
* held constant during the forest growing.
*
- Each tree is grown to the largest extent possible. There is no pruning.
*
* The advantages of random forest are:
*
* - For many data sets, it produces a highly accurate model.
*
- It runs efficiently on large data sets.
*
- It can handle thousands of input variables without variable deletion.
*
- It gives estimates of what variables are important in the classification.
*
- It generates an internal unbiased estimate of the generalization error
* as the forest building progresses.
*
- It has an effective method for estimating missing data and maintains
* accuracy when a large proportion of the data are missing.
*
* The disadvantages are
*
* - Random forests are prone to over-fitting for some datasets. This is
* even more pronounced in noisy classification/regression tasks.
*
- For data including categorical variables with different number of
* levels, random forests are biased in favor of those attributes with more
* levels. Therefore, the variable importance scores from random forest are
* not reliable for this type of data.
*
*
* @author Haifeng Li
*/
public class RandomForest implements Regression {
private static final long serialVersionUID = 1L;
/**
* Forest of regression trees.
*/
private List trees;
/**
* Out-of-bag estimation of RMSE, which is quite accurate given that
* enough trees have been grown (otherwise the OOB estimate can
* bias upward).
*/
private double error;
/**
* Variable importance. Every time a split of a node is made on variable
* the impurity criterion for the two descendent nodes is less than the
* parent node. Adding up the decreases for each individual variable over
* all trees in the forest gives a fast variable importance that is often
* very consistent with the permutation importance measure.
*/
private double[] importance;
/**
* Values between [-1, 1] that represents monotonic regression coefficient for each attribute.
*
* It can be used to enforce model to keep monotonic relationship between target and the attribute.
* Positive value enforce target to be positively correlated with this feature.
* Positive value enforce target to be negatively correlated with this feature.
* Zero value turns off monotonic regression.
*/
private double[] monotonicRegression;
/**
* Trainer for random forest.
*/
public static class Trainer extends RegressionTrainer {
/**
* The number of trees.
*/
private int ntrees = 500;
/**
* The number of random selected features to be used to determine the decision
* at a node of the tree. floor(sqrt(dim)) seems to give generally good performance,
* where dim is the number of variables.
*/
private int mtry = -1;
/**
* The minimum number of instances in leaf nodes.
*/
private int nodeSize = 5;
/**
* The maximum number of leaf nodes in the tree.
*/
private int maxNodes = 100;
/**
* The sampling rate.
*/
private double subsample = 1.0;
/**
* Constructor.
*
* @param ntrees the number of trees.
*/
public Trainer(int ntrees) {
if (ntrees < 1) {
throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
}
this.ntrees = ntrees;
}
/**
* Constructor.
*
* @param attributes the attributes of independent variable.
* @param ntrees the number of trees.
*/
public Trainer(Attribute[] attributes, int ntrees) {
super(attributes);
if (ntrees < 1) {
throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
}
this.ntrees = ntrees;
}
/**
* Sets the number of trees in the random forest.
* @param ntrees the number of trees.
*/
public Trainer setNumTrees(int ntrees) {
if (ntrees < 1) {
throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
}
this.ntrees = ntrees;
return this;
}
/**
* Sets the number of random selected features for splitting.
* @param mtry the number of random selected features to be used to determine
* the decision at a node of the tree. p/3 seems to give
* generally good performance, where p is the number of variables.
*/
public Trainer setNumRandomFeatures(int mtry) {
if (mtry < 1) {
throw new IllegalArgumentException("Invalid number of random selected features for splitting: " + mtry);
}
this.mtry = mtry;
return this;
}
/**
* Sets the maximum number of leaf nodes.
* @param maxNodes the maximum number of leaf nodes.
*/
public Trainer setMaxNodes(int maxNodes) {
if (maxNodes < 2) {
throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + maxNodes);
}
this.maxNodes = maxNodes;
return this;
}
/**
* Sets the minimum size of leaf nodes.
* @param nodeSize the number of instances in a node below which the tree will not split.
*/
public Trainer setNodeSize(int nodeSize) {
if (nodeSize < 1) {
throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
}
this.nodeSize = nodeSize;
return this;
}
/**
* Sets the sampling rate.
* @param subsample the sampling rate.
*/
public Trainer setSamplingRates(double subsample) {
if (subsample <= 0 || subsample > 1) {
throw new IllegalArgumentException("Invalid sampling rate: " + subsample);
}
this.subsample = subsample;
return this;
}
@Override
public RandomForest train(double[][] x, double[] y) {
return new RandomForest(attributes, x, y, ntrees, maxNodes, nodeSize, mtry, subsample);
}
}
/**
* Trains a regression tree.
*/
static class TrainingTask implements Callable {
/**
* Attribute properties.
*/
Attribute[] attributes;
/**
* Training instances.
*/
double[][] x;
/**
* Training response variable.
*/
double[] y;
/**
* The index of training values in ascending order. Note that only
* numeric attributes will be sorted.
*/
int[][] order;
/**
* The number of variables to pick up in each node.
*/
int mtry;
/**
* The minimum number of instances in leaf nodes.
*/
int nodeSize = 5;
/**
* The maximum number of leaf nodes in the tree.
*/
int maxNodes = 100;
/**
* The sampling rate.
*/
double subsample = 1.0;
final double[] monotonicRegression;
/**
* Predictions of of out-of-bag samples.
*/
double[] prediction;
/**
* Out-of-bag sample
*/
int[] oob;
/**
* Constructor.
*/
TrainingTask(Attribute[] attributes, double[][] x, double[] y, int maxNodes, int nodeSize, int mtry, double subsample, int[][] order, double[] prediction, int[] oob, double[] monotonicRegression) {
this.attributes = attributes;
this.monotonicRegression = monotonicRegression;
this.x = x;
this.y = y;
this.order = order;
this.mtry = mtry;
this.nodeSize = nodeSize;
this.maxNodes = maxNodes;
this.subsample = subsample;
this.prediction = prediction;
this.oob = oob;
}
@Override
public RegressionTree call() {
int n = x.length;
int[] samples = new int[n];
if (subsample == 1.0) {
// Training samples draw with replacement.
for (int i = 0; i < n; i++) {
int xi = Math.randomInt(n);
samples[xi]++;
}
} else {
// Training samples draw without replacement.
int[] perm = new int[n];
for (int i = 0; i < n; i++) {
perm[i] = i;
}
Math.permutate(perm);
int m = (int) Math.round(n * subsample);
for (int i = 0; i < m; i++) {
samples[perm[i]]++;
}
}
RegressionTree tree = new RegressionTree(attributes, x, y, maxNodes, nodeSize, mtry, order, samples, null, monotonicRegression);
for (int i = 0; i < n; i++) {
if (samples[i] == 0) {
double pred = tree.predict(x[i]);
synchronized (x[i]) {
prediction[i] += pred;
oob[i]++;
}
}
}
return tree;
}
}
/**
* Constructor. Learns a random forest for regression.
*
* @param x the training instances.
* @param y the response variable.
* @param ntrees the number of trees.
*/
public RandomForest(double[][] x, double[] y, int ntrees) {
this(null, x, y, ntrees);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param x the training instances.
* @param y the response variable.
* @param ntrees the number of trees.
* @param mtry the number of input variables to be used to determine the decision
* at a node of the tree. p/3 seems to give generally good performance,
* where p is the number of variables.
* @param nodeSize the number of instances in a node below which the tree will
* not split, setting nodeSize = 5 generally gives good results.
* @param maxNodes the maximum number of leaf nodes in the tree.
*/
public RandomForest(double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry) {
this(null, x, y, ntrees, maxNodes, nodeSize, mtry);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
* @param ntrees the number of trees.
*/
public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees) {
this(attributes, x, y, ntrees, 100);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param data the dataset
* @param ntrees the number of trees.
*/
public RandomForest(AttributeDataset data, int ntrees) {
this(data.attributes(), data.x(), data.y(), ntrees);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
* @param ntrees the number of trees.
* @param maxNodes the maximum number of leaf nodes in the tree.
*/
public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes) {
this(attributes, x, y, ntrees, maxNodes, 5);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param data the dataset
* @param ntrees the number of trees.
* @param maxNodes the maximum number of leaf nodes in the tree.
*/
public RandomForest(AttributeDataset data, int ntrees, int maxNodes) {
this(data.attributes(), data.x(), data.y(), ntrees, maxNodes);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
* @param ntrees the number of trees.
* @param nodeSize the number of instances in a node below which the tree will
* not split, setting nodeSize = 5 generally gives good results.
* @param maxNodes the maximum number of leaf nodes in the tree.
*/
public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize) {
this(attributes, x, y, ntrees, maxNodes, nodeSize, x[0].length / 3);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param data the dataset
* @param ntrees the number of trees.
* @param maxNodes the maximum number of leaf nodes in the tree.
* @param nodeSize the number of instances in a node below which the tree will
* not split, setting nodeSize = 5 generally gives good results.
*/
public RandomForest(AttributeDataset data, int ntrees, int maxNodes, int nodeSize) {
this(data.attributes(), data.x(), data.y(), ntrees, maxNodes, nodeSize);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
* @param ntrees the number of trees.
* @param mtry the number of input variables to be used to determine the decision
* at a node of the tree. p/3 seems to give generally good performance,
* where dim is the number of variables.
* @param nodeSize the number of instances in a node below which the tree will
* not split, setting nodeSize = 5 generally gives good results.
* @param maxNodes the maximum number of leaf nodes in the tree.
*/
public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry) {
this(attributes, x, y, ntrees, maxNodes, nodeSize, mtry, 1.0);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
* @param ntrees the number of trees.
* @param mtry the number of input variables to be used to determine the decision
* at a node of the tree. p/3 seems to give generally good performance,
* where dim is the number of variables.
* @param nodeSize the number of instances in a node below which the tree will
* not split, setting nodeSize = 5 generally gives good results.
* @param maxNodes the maximum number of leaf nodes in the tree.
* @param subsample the sampling rate for training tree. 1.0 means sampling with replacement. < 1.0 means
* sampling without replacement.
*/
public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample) {
this(attributes, x, y, ntrees, maxNodes, nodeSize, mtry, subsample, null);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param data the dataset
* @param ntrees the number of trees.
* @param mtry the number of input variables to be used to determine the decision
* at a node of the tree. p/3 seems to give generally good performance,
* where dim is the number of variables.
* @param nodeSize the number of instances in a node below which the tree will
* not split, setting nodeSize = 5 generally gives good results.
* @param maxNodes the maximum number of leaf nodes in the tree.
* @param subsample the sampling rate for training tree. 1.0 means sampling with replacement. < 1.0 means
* sampling without replacement.
*/
public RandomForest(AttributeDataset data, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample, double[] monotonicRegression) {
this(data.attributes(), data.x(), data.y(), ntrees, maxNodes, nodeSize, mtry, subsample, monotonicRegression);
}
/**
* Constructor. Learns a random forest for regression.
*
* @param attributes the attribute properties.
* @param x the training instances.
* @param y the response variable.
* @param ntrees the number of trees.
* @param mtry the number of input variables to be used to determine the decision
* at a node of the tree. p/3 seems to give generally good performance,
* where dim is the number of variables.
* @param nodeSize the number of instances in a node below which the tree will
* not split, setting nodeSize = 5 generally gives good results.
* @param maxNodes the maximum number of leaf nodes in the tree.
* @param subsample the sampling rate for training tree. 1.0 means sampling with replacement. < 1.0 means
* sampling without replacement.
*/
public RandomForest(Attribute[] attributes, double[][] x, double[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample, double[] monotonicRegression) {
if (x.length != y.length) {
throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
}
if (ntrees < 1) {
throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
}
if (mtry < 1 || mtry > x[0].length) {
throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
}
if (nodeSize < 2) {
throw new IllegalArgumentException("Invalid minimum size of leaves: " + nodeSize);
}
if (maxNodes < 2) {
throw new IllegalArgumentException("Invalid maximum number of leaves: " + maxNodes);
}
if (subsample <= 0 || subsample > 1) {
throw new IllegalArgumentException("Invalid sampling rate: " + subsample);
}
if (attributes == null) {
int p = x[0].length;
attributes = new Attribute[p];
for (int i = 0; i < p; i++) {
attributes[i] = new NumericAttribute("V" + (i + 1));
}
}
int n = x.length;
double[] prediction = new double[n];
int[] oob = new int[n];
int[][] order = SmileUtils.sort(attributes, x);
List tasks = new ArrayList<>();
for (int i = 0; i < ntrees; i++) {
tasks.add(new TrainingTask(attributes, x, y, maxNodes, nodeSize, mtry, subsample, order, prediction, oob, monotonicRegression));
}
try {
trees = MulticoreExecutor.run(tasks);
} catch (Exception ex) {
ex.printStackTrace();
trees = new ArrayList<>(ntrees);
for (int i = 0; i < ntrees; i++) {
trees.add(tasks.get(i).call());
}
}
int m = 0;
for (int i = 0; i < n; i++) {
if (oob[i] > 0) {
m++;
double pred = prediction[i] / oob[i];
error += Math.sqr(pred - y[i]);
}
}
if (m > 0) {
error = Math.sqrt(error / m);
}
importance = calculateImportance(trees, attributes.length);
}
/**
* Merges together two random forests and returns a new forest consisting of trees from both input forests.
*/
public RandomForest merge(RandomForest other) {
if (this.importance.length != other.importance.length) {
throw new IllegalArgumentException("RandomForest have different sizes of feature vectors");
}
ArrayList mergedTrees = new ArrayList<>();
mergedTrees.addAll(this.trees);
mergedTrees.addAll(other.trees);
double weightedMergedError = ((this.error * this.trees.size()) + (other.error * other.trees.size())) / (this.trees.size() + other.trees.size());
double[] mergedImportance = calculateImportance(mergedTrees, this.importance.length);
return new RandomForest(mergedTrees, weightedMergedError, mergedImportance);
}
private RandomForest(List trees, double error, double[] importance) {
this.trees = trees;
this.error = error;
this.importance = importance;
}
private static double[] calculateImportance(List trees, int featuresCount) {
double[] importance = new double[featuresCount];
for (RegressionTree tree : trees) {
double[] imp = tree.importance();
for (int i = 0; i < imp.length; i++) {
importance[i] += imp[i];
}
}
return importance;
}
/**
* Returns the out-of-bag estimation of RMSE. The OOB estimate is
* quite accurate given that enough trees have been grown. Otherwise the
* OOB estimate can bias upward.
*
* @return the out-of-bag estimation of RMSE
*/
public double error() {
return error;
}
/**
* Returns the variable importance. Every time a split of a node is made
* on variable the impurity criterion for the two descendent nodes is less
* than the parent node. Adding up the decreases for each individual
* variable over all trees in the forest gives a fast measure of variable
* importance that is often very consistent with the permutation importance
* measure.
*
* @return the variable importance
*/
public double[] importance() {
return importance;
}
/**
* Returns the number of trees in the model.
*
* @return the number of trees in the model
*/
public int size() {
return trees.size();
}
/**
* Trims the tree model set to a smaller size in case of over-fitting.
* Or if extra decision trees in the model don't improve the performance,
* we may remove them to reduce the model size and also improve the speed of
* prediction.
*
* @param ntrees the new (smaller) size of tree model set.
*/
public void trim(int ntrees) {
if (ntrees > trees.size()) {
throw new IllegalArgumentException("The new model size is larger than the current size.");
}
if (ntrees <= 0) {
throw new IllegalArgumentException("Invalid new model size: " + ntrees);
}
List model = new ArrayList<>(ntrees);
for (int i = 0; i < ntrees; i++) {
model.add(trees.get(i));
}
trees = model;
}
@Override
public double predict(double[] x) {
double y = 0;
for (RegressionTree tree : trees) {
y += tree.predict(x);
}
return y / trees.size();
}
/**
* Test the model on a validation dataset.
*
* @param x the test data set.
* @param y the test data response values.
* @return RMSEs with first 1, 2, ..., regression trees.
*/
public double[] test(double[][] x, double[] y) {
int T = trees.size();
double[] rmse = new double[T];
int n = x.length;
double[] sum = new double[n];
double[] prediction = new double[n];
RMSE measure = new RMSE();
for (int i = 0, nt = 1; i < T; i++, nt++) {
for (int j = 0; j < n; j++) {
sum[j] += trees.get(i).predict(x[j]);
prediction[j] = sum[j] / nt;
}
rmse[i] = measure.measure(y, prediction);
}
return rmse;
}
/**
* Test the model on a validation dataset.
*
* @param x the test data set.
* @param y the test data output values.
* @param measures the performance measures of regression.
* @return performance measures with first 1, 2, ..., regression trees.
*/
public double[][] test(double[][] x, double[] y, RegressionMeasure[] measures) {
int T = trees.size();
int m = measures.length;
double[][] results = new double[T][m];
int n = x.length;
double[] sum = new double[n];
double[] prediction = new double[n];
for (int i = 0, nt = 1; i < T; i++, nt++) {
for (int j = 0; j < n; j++) {
sum[j] += trees.get(i).predict(x[j]);
prediction[j] = sum[j] / nt;
}
for (int j = 0; j < m; j++) {
results[i][j] = measures[j].measure(y, prediction);
}
}
return results;
}
/**
* Returns the regression trees.
*/
public RegressionTree[] getTrees() {
return trees.toArray(new RegressionTree[trees.size()]);
}
}