smile.classification.GradientTreeBoost

/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *   
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *  
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.classification;

import java.util.Arrays;

import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.regression.RegressionTree;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/**
 * Gradient boosting for classification. Gradient boosting is typically used
 * with decision trees (especially CART regression trees) of a fixed size as
 * base learners. For this special case, Friedman proposes a modification to
 * the gradient boosting method which improves the quality of fit of each
 * base learner.
 * <p>
 * Generic gradient boosting at the t-th step would fit a regression tree to
 * pseudo-residuals. Let J be the number of its leaves. The tree partitions
 * the input space into J disjoint regions and predicts a constant value in
 * each region. The parameter J controls the maximum allowed level of
 * interaction between variables in the model. With J = 2 (decision stumps),
 * no interaction between variables is allowed. With J = 3 the model may
 * include effects of the interaction between up to two variables, and so on.
 * Hastie et al. comment that typically 4 ≤ J ≤ 8 work well for boosting and
 * results are fairly insensitive to the choice of J in this range, J = 2 is
 * insufficient for many applications, and J > 10 is unlikely to be required.
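 * <p>
 * For illustration only (this sketch is not part of the original
 * documentation), the tree size J is passed as the fourth argument of the
 * full constructor defined below; loadFeatures() and loadLabels() are
 * hypothetical helpers standing in for whatever code loads the training set:
 * <pre>{@code
 * double[][] x = loadFeatures(); // hypothetical: n-by-p training instances
 * int[] y = loadLabels();        // hypothetical: class labels in {0, 1, ..., k-1}
 * // 500 trees, J = 6 leaves per tree, shrinkage 0.05, sampling fraction 0.7
 * GradientTreeBoost model = new GradientTreeBoost(x, y, 500, 6, 0.05, 0.7);
 * }</pre>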

 * <p>
 * Fitting the training set too closely can lead to degradation of the
 * model's generalization ability. Several so-called regularization
 * techniques reduce this over-fitting effect by constraining the fitting
 * procedure. One natural regularization parameter is the number of gradient
 * boosting iterations T (i.e. the number of trees in the model when the
 * base learner is a decision tree). Increasing T reduces the error on the
 * training set, but setting it too high may lead to over-fitting. An
 * optimal value of T is often selected by monitoring prediction error on a
 * separate validation data set.
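 * <p>
 * As an illustrative sketch of this idea using this class's own API (not
 * from the original documentation; x and y are as in the earlier sketch, and
 * validX and validY are a hypothetical held-out validation set), one can
 * grow a generous number of trees, track accuracy per prefix with
 * {@link #test(double[][], int[])}, and then {@link #trim(int)} the model:
 * <pre>{@code
 * GradientTreeBoost model = new GradientTreeBoost(x, y, 1000);
 * double[] accuracy = model.test(validX, validY); // accuracy with the first 1, 2, ..., 1000 trees
 * int best = Math.whichMax(accuracy) + 1;         // smile.math.Math
 * model.trim(best);                               // keep only the best prefix of trees
 * }</pre>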

 * <p>
 * Another regularization approach is shrinkage, which scales the update
 * term by a parameter η (called the "learning rate"). Empirically it has
 * been found that using small learning rates (such as η < 0.1) yields
 * dramatic improvements in the model's generalization ability over gradient
 * boosting without shrinking (η = 1). However, it comes at the price of
 * increased computational time both during training and prediction: a lower
 * learning rate requires more iterations.
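 * <p>
 * A brief, hypothetical sketch of trading a smaller learning rate for more
 * trees, using the {@link Trainer} API defined further below (x and y as in
 * the earlier sketch):
 * <pre>{@code
 * GradientTreeBoost.Trainer trainer = new GradientTreeBoost.Trainer(2000); // more iterations
 * trainer.setShrinkage(0.01);      // small learning rate
 * trainer.setMaximumLeafNodes(6);
 * GradientTreeBoost model = trainer.train(x, y);
 * }</pre>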

 * <p>
 * Soon after the introduction of gradient boosting, Friedman proposed a
 * minor modification to the algorithm, motivated by Breiman's bagging
 * method. Specifically, he proposed that at each iteration of the
 * algorithm, a base learner should be fit on a subsample of the training
 * set drawn at random without replacement. Friedman observed a substantial
 * improvement in gradient boosting's accuracy with this modification.

 * <p>
 * Subsample size is some constant fraction f of the size of the training
 * set. When f = 1, the algorithm is deterministic and identical to the one
 * described above. Smaller values of f introduce randomness into the
 * algorithm and help prevent over-fitting, acting as a kind of
 * regularization. The algorithm also becomes faster, because regression
 * trees have to be fit to smaller datasets at each iteration. Typically,
 * f is set to 0.5, meaning that one half of the training set is used to
 * build each base learner.
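 * <p>
 * For example (again an illustrative sketch, with x and y as above), the
 * sampling fraction is exposed through {@link Trainer#setSamplingRates(double)}:
 * <pre>{@code
 * GradientTreeBoost model = new GradientTreeBoost.Trainer(500)
 *         .setSamplingRates(0.5) // each tree is fit on a random half of the data
 *         .setShrinkage(0.05)
 *         .train(x, y);
 * }</pre>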

 * <p>
 * Also, like in bagging, sub-sampling allows one to define an out-of-bag
 * estimate of the prediction performance improvement by evaluating
 * predictions on those observations which were not used in the building of
 * the next base learner. Out-of-bag estimates help avoid the need for an
 * independent validation dataset, but often underestimate actual
 * performance improvement and the optimal number of iterations.

 * <p>
 * Gradient tree boosting implementations often also use regularization by
 * limiting the minimum number of observations in trees' terminal nodes.
 * It's used in the tree building process by ignoring any splits that lead
 * to nodes containing fewer than this number of training set instances.
 * Imposing this limit helps to reduce variance in predictions at leaves.
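 * <p>
 * This class does not expose that limit as a parameter; purely as a
 * hypothetical illustration of the rule a tree builder would apply:
 * <pre>{@code
 * // Hypothetical helper, not part of this class's API: a candidate split is
 * // kept only if both children retain at least nodeSize training instances.
 * static boolean splitAllowed(int leftCount, int rightCount, int nodeSize) {
 *     return leftCount >= nodeSize && rightCount >= nodeSize;
 * }
 * }</pre>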

 * <h2>References</h2>
 * <ol>
 * <li> J. H. Friedman. Greedy Function Approximation: A Gradient Boosting Machine, 1999.</li>
 * <li> J. H. Friedman. Stochastic Gradient Boosting, 1999.</li>
 * </ol>
 *
 * @author Haifeng Li
 */
public class GradientTreeBoost implements Classifier<double[]> {
    /**
     * The number of classes.
     */
    private int k = 2;
    /**
     * Forest of regression trees for binary classification.
     */
    private RegressionTree[] trees;
    /**
     * Forest of regression trees for multi-class classification.
     */
    private RegressionTree[][] forest;
    /**
     * Variable importance. Every time a split of a node is made on a variable,
     * the impurity criterion for the two descendent nodes is less than that of
     * the parent node. Adding up the decreases for each individual variable
     * over all trees in the forest gives a simple variable importance.
     */
    private double[] importance;
    /**
     * The intercept.
     */
    private double b = 0.0;
    /**
     * The shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
     */
    private double shrinkage = 0.005;
    /**
     * The number of leaves in each tree.
     */
    private int J = 6;
    /**
     * The number of trees.
     */
    private int T = 500;
    /**
     * The sampling rate for stochastic tree boosting.
     */
    private double f = 0.7;

    /**
     * Trainer for GradientTreeBoost classifiers.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /**
         * The number of trees.
         */
        private int T = 500;
        /**
         * The shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
         */
        private double shrinkage = 0.005;
        /**
         * The number of leaves in each tree.
         */
        private int J = 6;
        /**
         * The sampling rate for stochastic tree boosting.
         */
        private double f = 0.7;

        /**
         * Default constructor of 500 trees.
         */
        public Trainer() {
        }

        /**
         * Constructor.
         *
         * @param T the number of trees.
         */
        public Trainer(int T) {
            if (T < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + T);
            }

            this.T = T;
        }

        /**
         * Constructor.
         *
         * @param attributes the attributes of independent variable.
         * @param T the number of trees.
         */
        public Trainer(Attribute[] attributes, int T) {
            super(attributes);

            if (T < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + T);
            }

            this.T = T;
        }

        /**
         * Sets the number of trees in the random forest.
         * @param T the number of trees.
         */
        public Trainer setNumTrees(int T) {
            if (T < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + T);
            }

            this.T = T;
            return this;
        }

        /**
         * Sets the maximum number of leaf nodes in the tree.
         * @param J the maximum number of leaf nodes in the tree.
         */
        public Trainer setMaximumLeafNodes(int J) {
            if (J < 2) {
                throw new IllegalArgumentException("Invalid number of leaf nodes: " + J);
            }

            this.J = J;
            return this;
        }

        /**
         * Sets the shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
         * @param shrinkage the learning rate.
         */
        public Trainer setShrinkage(double shrinkage) {
            if (shrinkage <= 0 || shrinkage > 1) {
                throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
            }

            this.shrinkage = shrinkage;
            return this;
        }

        /**
         * Sets the sampling rate for stochastic tree boosting.
         * @param f the sampling rate for stochastic tree boosting.
         */
        public Trainer setSamplingRates(double f) {
            if (f <= 0 || f > 1) {
                throw new IllegalArgumentException("Invalid sampling fraction: " + f);
            }

            this.f = f;
            return this;
        }

        @Override
        public GradientTreeBoost train(double[][] x, int[] y) {
            return new GradientTreeBoost(attributes, x, y, T, J, shrinkage, f);
        }
    }

    /**
     * Constructor. Learns a gradient tree boosting for classification.
     * @param x the training instances.
     * @param y the class labels.
     * @param T the number of iterations (trees).
     */
    public GradientTreeBoost(double[][] x, int[] y, int T) {
        this(null, x, y, T);
    }

    /**
     * Constructor. Learns a gradient tree boosting for classification.
     *
     * @param x the training instances.
     * @param y the class labels.
     * @param T the number of iterations (trees).
     * @param J the number of leaves in each tree.
     * @param shrinkage the shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
     * @param f the sampling rate for stochastic tree boosting.
     */
    public GradientTreeBoost(double[][] x, int[] y, int T, int J, double shrinkage, double f) {
        this(null, x, y, T, J, shrinkage, f);
    }

    /**
     * Constructor. Learns a gradient tree boosting for classification.
     *
     * @param attributes the attribute properties.
     * @param x the training instances.
     * @param y the class labels.
     * @param T the number of iterations (trees).
     */
    public GradientTreeBoost(Attribute[] attributes, double[][] x, int[] y, int T) {
        this(attributes, x, y, T, 6, x.length < 2000 ? 0.005 : 0.05, 0.7);
    }

    /**
     * Constructor. Learns a gradient tree boosting for classification.
     *
     * @param attributes the attribute properties.
     * @param x the training instances.
     * @param y the class labels.
     * @param T the number of iterations (trees).
     * @param J the number of leaves in each tree.
     * @param shrinkage the shrinkage parameter in (0, 1] that controls the learning rate of the procedure.
     * @param f the sampling fraction for stochastic tree boosting.
     */
    public GradientTreeBoost(Attribute[] attributes, double[][] x, int[] y, int T, int J, double shrinkage, double f) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        if (shrinkage <= 0 || shrinkage > 1) {
            throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
        }

        if (f <= 0 || f > 1) {
            throw new IllegalArgumentException("Invalid sampling fraction: " + f);
        }

        if (attributes == null) {
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int i = 0; i < p; i++) {
                attributes[i] = new NumericAttribute("V" + (i + 1));
            }
        }

        this.T = T;
        this.J = J;
        this.shrinkage = shrinkage;
        this.f = f;
        this.k = Math.max(y) + 1;

        if (k < 2) {
            throw new IllegalArgumentException("Only one class or negative class labels.");
        }

        importance = new double[attributes.length];
        if (k == 2) {
            train2(attributes, x, y);
            for (RegressionTree tree : trees) {
                double[] imp = tree.importance();
                for (int i = 0; i < imp.length; i++) {
                    importance[i] += imp[i];
                }
            }
        } else {
            traink(attributes, x, y);
            for (RegressionTree[] grove : forest) {
                for (RegressionTree tree : grove) {
                    double[] imp = tree.importance();
                    for (int i = 0; i < imp.length; i++) {
                        importance[i] += imp[i];
                    }
                }
            }
        }
    }

    /**
     * Returns the variable importance. Every time a split of a node is made
     * on a variable, the impurity criterion for the two descendent nodes is
     * less than that of the parent node. Adding up the decreases for each
     * individual variable over all trees in the forest gives a simple
     * measure of variable importance.
     *
     * @return the variable importance
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Train L2 tree boost.
     */
    private void train2(Attribute[] attributes, double[][] x, int[] y) {
        int n = x.length;
        int N = (int) Math.round(n * f);

        int[] y2 = new int[n];
        for (int i = 0; i < n; i++) {
            if (y[i] == 1) {
                y2[i] = 1;
            } else {
                y2[i] = -1;
            }
        }

        y = y2;

        double[] h = new double[n]; // current F(x_i)
        double[] response = new double[n]; // response variable for regression tree.
        double mu = Math.mean(y);
        b = 0.5 * Math.log((1 + mu) / (1 - mu));

        for (int i = 0; i < n; i++) {
            h[i] = b;
        }

        int[][] order = SmileUtils.sort(attributes, x);
        RegressionTree.NodeOutput output = new L2NodeOutput(response);
        trees = new RegressionTree[T];

        int[] perm = new int[n];
        int[] samples = new int[n];
        for (int i = 0; i < n; i++) {
            perm[i] = i;
        }

        for (int m = 0; m < T; m++) {
            Arrays.fill(samples, 0);
            Math.permutate(perm);
            for (int i = 0; i < N; i++) {
                samples[perm[i]] = 1;
            }

            // Pseudo-residuals: negative gradient of the binomial deviance,
            // r_i = 2 y_i / (1 + exp(2 y_i F(x_i))) with y_i in {-1, +1}.
            for (int i = 0; i < n; i++) {
                response[i] = 2.0 * y[i] / (1 + Math.exp(2 * y[i] * h[i]));
            }

            trees[m] = new RegressionTree(attributes, x, response, J, order, samples, output);

            for (int i = 0; i < n; i++) {
                h[i] += shrinkage * trees[m].predict(x[i]);
            }
        }
    }

    /**
     * Train L-k tree boost.
     */
    private void traink(Attribute[] attributes, double[][] x, int[] y) {
        int n = x.length;
        int N = (int) Math.round(n * f);

        double[][] h = new double[k][n]; // boost tree output.
        double[][] p = new double[k][n]; // posteriori probabilities.
        double[][] response = new double[k][n]; // pseudo response.

        int[][] order = SmileUtils.sort(attributes, x);
        forest = new RegressionTree[k][T];

        RegressionTree.NodeOutput[] output = new LKNodeOutput[k];
        for (int i = 0; i < k; i++) {
            output[i] = new LKNodeOutput(response[i]);
        }

        int[] perm = new int[n];
        int[] samples = new int[n];
        for (int i = 0; i < n; i++) {
            perm[i] = i;
        }

        for (int m = 0; m < T; m++) {
            // Class posteriors via softmax, subtracting the maximum for numerical stability.
            for (int i = 0; i < n; i++) {
                double max = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < k; j++) {
                    if (max < h[j][i]) {
                        max = h[j][i];
                    }
                }

                double Z = 0.0;
                for (int j = 0; j < k; j++) {
                    p[j][i] = Math.exp(h[j][i] - max);
                    Z += p[j][i];
                }

                for (int j = 0; j < k; j++) {
                    p[j][i] /= Z;
                }
            }

            for (int j = 0; j < k; j++) {
                for (int i = 0; i < n; i++) {
                    if (y[i] == j) {
                        response[j][i] = 1.0;
                    } else {
                        response[j][i] = 0.0;
                    }

                    response[j][i] -= p[j][i];
                }

                Arrays.fill(samples, 0);
                Math.permutate(perm);
                for (int i = 0; i < N; i++) {
                    samples[perm[i]] = 1;
                }

                forest[j][m] = new RegressionTree(attributes, x, response[j], J, order, samples, output[j]);

                for (int i = 0; i < n; i++) {
                    h[j][i] += shrinkage * forest[j][m].predict(x[i]);
                }
            }
        }
    }

    /**
     * Class to calculate node output for two-class logistic regression.
     */
    class L2NodeOutput implements RegressionTree.NodeOutput {

        /**
         * Pseudo response to fit.
         */
        double[] y;

        /**
         * Constructor.
         * @param y pseudo response to fit.
         */
        public L2NodeOutput(double[] y) {
            this.y = y;
        }

        @Override
        public double calculate(int[] samples) {
            double nu = 0.0;
            double de = 0.0;
            for (int i = 0; i < samples.length; i++) {
                if (samples[i] > 0) {
                    double abs = Math.abs(y[i]);
                    nu += y[i];
                    de += abs * (2.0 - abs);
                }
            }

            return nu / de;
        }
    }

    /**
     * Class to calculate node output for multi-class logistic regression.
     */
    class LKNodeOutput implements RegressionTree.NodeOutput {

        /**
         * Responses to fit.
         */
        double[] y;

        /**
         * Constructor.
         * @param response response to fit.
         */
        public LKNodeOutput(double[] response) {
            this.y = response;
        }

        @Override
        public double calculate(int[] samples) {
            int n = 0;
            double nu = 0.0;
            double de = 0.0;
            for (int i = 0; i < samples.length; i++) {
                if (samples[i] > 0) {
                    n++;
                    double abs = Math.abs(y[i]);
                    nu += y[i];
                    de += abs * (1.0 - abs);
                }
            }

            if (de < 1E-10) {
                return nu / n;
            }

            return ((k - 1.0) / k) * (nu / de);
        }
    }

    /**
     * Returns the number of trees in the model.
     *
     * @return the number of trees in the model
     */
    public int size() {
        return trees.length;
    }

    /**
     * Trims the tree model set to a smaller size in case of over-fitting.
     * Or if extra decision trees in the model don't improve the performance,
     * we may remove them to reduce the model size and also improve the speed
     * of prediction.
     *
     * @param T the new (smaller) size of tree model set.
     */
    public void trim(int T) {
        if (k == 2) {
            if (T > trees.length) {
                throw new IllegalArgumentException("The new model size is larger than the current size.");
            }

            if (T <= 0) {
                throw new IllegalArgumentException("Invalid new model size: " + T);
            }

            if (T < trees.length) {
                trees = Arrays.copyOf(trees, T);
                this.T = T; // keep the iteration count consistent with the trimmed forest.
            }
        } else {
            if (T > forest[0].length) {
                throw new IllegalArgumentException("The new model size is larger than the current one.");
            }

            if (T <= 0) {
                throw new IllegalArgumentException("Invalid new model size: " + T);
            }

            if (T < forest[0].length) {
                for (int i = 0; i < forest.length; i++) {
                    forest[i] = Arrays.copyOf(forest[i], T);
                }
                this.T = T; // keep the iteration count consistent with the trimmed forest.
            }
        }
    }

    @Override
    public int predict(double[] x) {
        if (k == 2) {
            double y = b;
            for (int i = 0; i < T; i++) {
                y += shrinkage * trees[i].predict(x);
            }

            return y > 0 ? 1 : 0;
        } else {
            double max = Double.NEGATIVE_INFINITY;
            int y = -1;

            for (int j = 0; j < k; j++) {
                double yj = 0.0;
                for (int i = 0; i < T; i++) {
                    yj += shrinkage * forest[j][i].predict(x);
                }

                if (yj > max) {
                    max = yj;
                    y = j;
                }
            }

            return y;
        }
    }

    @Override
    public int predict(double[] x, double[] posteriori) {
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        if (k == 2) {
            double y = b;
            for (int i = 0; i < T; i++) {
                y += shrinkage * trees[i].predict(x);
            }

            posteriori[0] = 1.0 / (1.0 + Math.exp(2 * y));
            posteriori[1] = 1.0 - posteriori[0];

            if (y > 0) {
                return 1;
            } else {
                return 0;
            }
        } else {
            double max = Double.NEGATIVE_INFINITY;
            int y = -1;

            for (int j = 0; j < k; j++) {
                posteriori[j] = 0.0;

                for (int i = 0; i < T; i++) {
                    posteriori[j] += shrinkage * forest[j][i].predict(x);
                }

                if (posteriori[j] > max) {
                    max = posteriori[j];
                    y = j;
                }
            }

            double Z = 0.0;
            for (int i = 0; i < k; i++) {
                posteriori[i] = Math.exp(posteriori[i] - max);
                Z += posteriori[i];
            }

            for (int i = 0; i < k; i++) {
                posteriori[i] /= Z;
            }

            return y;
        }
    }

    /**
     * Test the model on a validation dataset.
     *
     * @param x the test data set.
     * @param y the test data response values.
     * @return accuracies with first 1, 2, ..., decision trees.
     */
    public double[] test(double[][] x, int[] y) {
        double[] accuracy = new double[T];

        int n = x.length;
        int[] label = new int[n];

        Accuracy measure = new Accuracy();

        if (k == 2) {
            double[] prediction = new double[n];
            Arrays.fill(prediction, b);

            for (int i = 0; i < T; i++) {
                for (int j = 0; j < n; j++) {
                    prediction[j] += shrinkage * trees[i].predict(x[j]);
                    label[j] = prediction[j] > 0 ? 1 : 0;
                }

                accuracy[i] = measure.measure(y, label);
            }
        } else {
            double[][] prediction = new double[n][k];

            for (int i = 0; i < T; i++) {
                for (int j = 0; j < n; j++) {
                    for (int l = 0; l < k; l++) {
                        prediction[j][l] += shrinkage * forest[l][i].predict(x[j]);
                    }

                    label[j] = Math.whichMax(prediction[j]);
                }

                accuracy[i] = measure.measure(y, label);
            }
        }

        return accuracy;
    }

    /**
     * Test the model on a validation dataset.
     *
     * @param x the test data set.
     * @param y the test data labels.
     * @param measures the performance measures of classification.
     * @return performance measures with first 1, 2, ..., decision trees.
     */
    public double[][] test(double[][] x, int[] y, ClassificationMeasure[] measures) {
        int m = measures.length;
        double[][] results = new double[T][m];

        int n = x.length;
        int[] label = new int[n];

        if (k == 2) {
            double[] prediction = new double[n];
            Arrays.fill(prediction, b);

            for (int i = 0; i < T; i++) {
                for (int j = 0; j < n; j++) {
                    prediction[j] += shrinkage * trees[i].predict(x[j]);
                    label[j] = prediction[j] > 0 ? 1 : 0;
                }

                for (int j = 0; j < m; j++) {
                    results[i][j] = measures[j].measure(y, label);
                }
            }
        } else {
            double[][] prediction = new double[n][k];

            for (int i = 0; i < T; i++) {
                for (int j = 0; j < n; j++) {
                    for (int l = 0; l < k; l++) {
                        prediction[j][l] += shrinkage * forest[l][i].predict(x[j]);
                    }

                    label[j] = Math.whichMax(prediction[j]);
                }

                for (int j = 0; j < m; j++) {
                    results[i][j] = measures[j].measure(y, label);
                }
            }
        }

        return results;
    }
}



