weka.classifiers.trees.RandomForest Maven / Gradle / Ivy
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* RandomForest.java
* Copyright (C) 2001-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.meta.Bagging;
import weka.core.Capabilities;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.Utils;
import weka.core.WekaException;
import weka.gui.ProgrammaticProperty;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;
/**
* Class for constructing a forest of random trees.
*
* For more information see:
*
* Leo Breiman (2001). Random Forests. Machine Learning. 45(1):5-32.
*
*
*
* BibTeX:
*
*
* @article{Breiman2001,
* author = {Leo Breiman},
* journal = {Machine Learning},
* number = {1},
* pages = {5-32},
* title = {Random Forests},
* volume = {45},
* year = {2001}
* }
*
*
*
*
*
*
* Valid options are:
*
*
*
* -P
* Size of each bag, as a percentage of the
* training set size. (default 100)
*
*
*
* -O
* Calculate the out of bag error.
*
*
*
* -store-out-of-bag-predictions
* Whether to store out of bag predictions in internal evaluation object.
*
*
*
* -output-out-of-bag-complexity-statistics
* Whether to output complexity-based statistics when out-of-bag evaluation is performed.
*
*
*
* -print
* Print the individual classifiers in the output
*
*
*
* -attribute-importance
* Compute and output attribute importance (mean impurity decrease method)
*
*
*
* -I <num>
* Number of iterations.
* (current value 100)
*
*
*
* -num-slots <num>
* Number of execution slots.
* (default 1 - i.e. no parallelism)
* (use 0 to auto-detect number of cores)
*
*
*
* -K <number of attributes>
* Number of attributes to randomly investigate. (default 0)
* (<1 = int(log_2(#predictors)+1)).
*
*
*
* -M <minimum number of instances>
* Set minimum number of instances per leaf.
* (default 1)
*
*
*
* -V <minimum variance for split>
* Set minimum numeric class variance proportion
* of train variance for split (default 1e-3).
*
*
*
* -S <num>
* Seed for random number generator.
* (default 1)
*
*
*
* -depth <num>
* The maximum depth of the tree, 0 for unlimited.
* (default 0)
*
*
*
* -N <num>
* Number of folds for backfitting (default 0, no backfitting).
*
*
*
* -U
* Allow unclassified instances.
*
*
*
* -B
* Break ties randomly when several attributes look equally good.
*
*
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
*
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
*
*
* -num-decimal-places
* The number of decimal places for the output of numbers in the model (default 2).
*
*
*
* -batch-size
* The desired batch size for batch prediction (default 100).
*
*
*
*
* @author Richard Kirkby ([email protected])
* @version $Revision: 13295 $
*/
public class RandomForest extends Bagging {
/** Serial version identifier used for Java serialization of this class. */
static final long serialVersionUID = 1116839470751428698L;
/**
 * True to compute attribute importance (mean impurity decrease). The setter
 * setComputeAttributeImportance() also forwards this flag to the base
 * RandomTree.
 */
protected boolean m_computeAttributeImportance;
/**
 * Supplies the number of bagging iterations (i.e. trees) used when none has
 * been specified explicitly.
 *
 * @return the default number of iterations (100)
 */
@Override
protected int defaultNumberOfIterations() {
  final int defaultTreeCount = 100;
  return defaultTreeCount;
}
/**
* Constructor that sets base classifier for bagging to RandomTre and default
* number of iterations to 100.
*/
public RandomForest() {
RandomTree rTree = new RandomTree();
rTree.setDoNotCheckCapabilities(true);
super.setClassifier(rTree);
super.setRepresentCopiesUsingWeights(true);
setNumIterations(defaultNumberOfIterations());
}
/**
 * Returns the default capabilities of the base classifier.
 * <p>
 * A freshly constructed RandomTree is queried rather than the configured
 * base tree, because capabilities checking has been turned off for the
 * configured tree.
 *
 * @return the capabilities of the base classifier
 */
@Override
public Capabilities getCapabilities() {
  return (new RandomTree()).getCapabilities();
}
/**
 * Names the class of the default base classifier.
 *
 * @return the fully qualified class name of the default classifier
 */
@Override
protected String defaultClassifierString() {
  final String baseClassName = "weka.classifiers.trees.RandomTree";
  return baseClassName;
}
/**
 * Supplies the default options for the base classifier. The default
 * disables capabilities checking on the base tree.
 *
 * @return a new array holding the default classifier options
 */
@Override
protected String[] defaultClassifierOptions() {
  return new String[] { "-do-not-check-capabilities" };
}
/**
 * Returns a string describing this classifier, followed by its
 * bibliographic reference.
 *
 * @return a description suitable for displaying in the explorer/experimenter
 *         gui
 */
@Override
public String globalInfo() {
  return "Class for constructing a forest of random trees.\n\n"
    + "For more information see: \n\n" + getTechnicalInformation().toString();
}
/**
 * Builds the bibliographic reference for this classifier: Breiman's 2001
 * "Random Forests" article in Machine Learning 45(1):5-32.
 *
 * @return the technical information about this class
 */
@Override
public TechnicalInformation getTechnicalInformation() {
  TechnicalInformation info = new TechnicalInformation(Type.ARTICLE);
  info.setValue(Field.AUTHOR, "Leo Breiman");
  info.setValue(Field.YEAR, "2001");
  info.setValue(Field.TITLE, "Random Forests");
  info.setValue(Field.JOURNAL, "Machine Learning");
  info.setValue(Field.VOLUME, "45");
  info.setValue(Field.NUMBER, "1");
  info.setValue(Field.PAGES, "5-32");
  return info;
}
/**
 * Sets the base classifier. Only RandomTree instances are accepted.
 *
 * @param newClassifier the RandomTree to use
 * @throws IllegalArgumentException if newClassifier is not a RandomTree
 *           (this includes null)
 */
@Override
@ProgrammaticProperty
public void setClassifier(Classifier newClassifier) {
  if (!(newClassifier instanceof RandomTree)) {
    throw new IllegalArgumentException(
      "RandomForest: Argument of setClassifier() must be a RandomTree.");
  }
  super.setClassifier(newClassifier);
}
/**
 * Sets whether bootstrap copies are represented using instance weights.
 * RandomForest only supports weight-based copies, so only true is accepted.
 *
 * @param representUsingWeights must be true
 * @throws IllegalArgumentException if representUsingWeights is false
 */
@Override
@ProgrammaticProperty
public void setRepresentCopiesUsingWeights(boolean representUsingWeights) {
  if (!representUsingWeights) {
    throw new IllegalArgumentException(
      "RandomForest: Argument of setRepresentCopiesUsingWeights() must be true.");
  }
  super.setRepresentCopiesUsingWeights(representUsingWeights);
}
/**
 * Returns the tip text for the numFeatures property. The text is taken from
 * the base RandomTree's KValue tip text.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String numFeaturesTipText() {
  RandomTree baseTree = (RandomTree) getClassifier();
  return baseTree.KValueTipText();
}
/**
 * Gets the number of features used in random selection. This is the base
 * RandomTree's KValue.
 *
 * @return Value of numFeatures.
 */
public int getNumFeatures() {
  RandomTree baseTree = (RandomTree) getClassifier();
  return baseTree.getKValue();
}
/**
 * Sets the number of features to use in random selection. The value is
 * stored as the base RandomTree's KValue.
 *
 * @param newNumFeatures Value to assign to numFeatures.
 */
public void setNumFeatures(int newNumFeatures) {
  RandomTree baseTree = (RandomTree) getClassifier();
  baseTree.setKValue(newNumFeatures);
}
/**
 * Returns the tip text for the computeAttributeImportance property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String computeAttributeImportanceTipText() {
  final String tip = "Compute attribute importance via mean impurity decrease";
  return tip;
}
/**
 * Sets whether to compute and output attribute importance scores.
 * <p>
 * The flag is recorded here and also forwarded to the base RandomTree via
 * setComputeImpurityDecreases(). NOTE(review): this presumably only affects
 * trees built after this call, so set it before building; confirm against
 * Bagging's tree-copying behaviour.
 *
 * @param computeAttributeImportance true to compute attribute importance
 *          scores
 */
public void setComputeAttributeImportance(boolean computeAttributeImportance) {
  m_computeAttributeImportance = computeAttributeImportance;
  // Go through getClassifier(), like every other delegating accessor here.
  ((RandomTree) getClassifier())
    .setComputeImpurityDecreases(computeAttributeImportance);
}
/**
 * Reports whether attribute importance scores are computed and output.
 *
 * @return true if computing attribute importance scores
 */
public boolean getComputeAttributeImportance() {
  return this.m_computeAttributeImportance;
}
/**
 * Returns the tip text for the maxDepth property, taken from the base
 * RandomTree.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String maxDepthTipText() {
  RandomTree baseTree = (RandomTree) getClassifier();
  return baseTree.maxDepthTipText();
}
/**
 * Gets the maximum depth of the tree, 0 for unlimited. The value is read
 * from the base RandomTree.
 *
 * @return the maximum depth.
 */
public int getMaxDepth() {
  RandomTree baseTree = (RandomTree) getClassifier();
  return baseTree.getMaxDepth();
}
/**
 * Sets the maximum depth of the tree, 0 for unlimited. The value is
 * forwarded to the base RandomTree.
 *
 * @param value the maximum depth.
 */
public void setMaxDepth(int value) {
  RandomTree baseTree = (RandomTree) getClassifier();
  baseTree.setMaxDepth(value);
}
/**
 * Returns the tip text for the breakTiesRandomly property, taken from the
 * base RandomTree.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String breakTiesRandomlyTipText() {
  RandomTree baseTree = (RandomTree) getClassifier();
  return baseTree.breakTiesRandomlyTipText();
}
/**
 * Reports whether the base RandomTree breaks ties randomly.
 *
 * @return true if ties are to be broken randomly.
 */
public boolean getBreakTiesRandomly() {
  RandomTree baseTree = (RandomTree) getClassifier();
  return baseTree.getBreakTiesRandomly();
}
/**
 * Sets whether the base RandomTree breaks ties randomly.
 *
 * @param newBreakTiesRandomly true if ties are to be broken randomly
 */
public void setBreakTiesRandomly(boolean newBreakTiesRandomly) {
  RandomTree baseTree = (RandomTree) getClassifier();
  baseTree.setBreakTiesRandomly(newBreakTiesRandomly);
}
/**
 * Sets debugging mode on this forest and on its base RandomTree.
 *
 * @param debug true if debug output should be printed
 */
@Override
public void setDebug(boolean debug) {
  super.setDebug(debug);
  ((RandomTree) getClassifier()).setDebug(debug);
}
/**
 * Sets the number of decimal places used for numbers in the model output,
 * both for this forest and for its base RandomTree.
 *
 * @param num the number of decimal places to use
 */
@Override
public void setNumDecimalPlaces(int num) {
  super.setNumDecimalPlaces(num);
  ((RandomTree) getClassifier()).setNumDecimalPlaces(num);
}
/**
 * Sets the preferred batch size for batch prediction, on this forest first
 * and then on its base RandomTree.
 *
 * @param size the batch size to use
 */
@Override
public void setBatchSize(String size) {
  super.setBatchSize(size);
  RandomTree baseTree = (RandomTree) getClassifier();
  baseTree.setBatchSize(size);
}
/**
 * Sets the seed for the random number generator, on this forest and on its
 * base RandomTree.
 *
 * @param s the seed to be used
 */
@Override
public void setSeed(int s) {
  super.setSeed(s);
  ((RandomTree) getClassifier()).setSeed(s);
}
/**
 * Returns a description of the bagged classifier.
 * <p>
 * The description is the superclass output, prefixed with a "RandomForest"
 * header. When attribute importance computation is enabled, an
 * attribute-importance table follows.
 *
 * @return description of the bagged classifier as a string
 */
@Override
public String toString() {
  if (m_Classifiers == null) {
    return "RandomForest: No model built yet.";
  }
  StringBuilder buffer = new StringBuilder("RandomForest\n\n");
  buffer.append(super.toString());
  if (getComputeAttributeImportance()) {
    appendAttributeImportance(buffer);
  }
  return buffer.toString();
}

/**
 * Appends the attribute-importance table to the given buffer.
 * <p>
 * Each row gives one non-class attribute's average impurity decrease and
 * the number of nodes that used it, most important first. If the
 * statistics cannot be computed, a short explanatory line is appended
 * instead of the table.
 *
 * @param buffer the buffer to append to
 */
private void appendAttributeImportance(StringBuilder buffer) {
  double[] nodeCounts = new double[m_data.numAttributes()];
  double[] impurityScores;
  try {
    impurityScores = computeAverageImpurityDecreasePerAttribute(nodeCounts);
  } catch (WekaException ex) {
    // toString() must not fail, but silently dropping a table the user asked
    // for hides the problem; report why it is missing instead.
    buffer.append("\n\nAttribute importance not available: ")
      .append(ex.getMessage()).append("\n");
    return;
  }
  int[] sortedIndices = Utils.sort(impurityScores);
  buffer.append("\n\nAttribute importance based on average impurity decrease "
    + "(and number of nodes using that attribute)\n\n");
  // Utils.sort orders indices by ascending score, so walk it backwards to
  // list the most important attributes first.
  for (int i = sortedIndices.length - 1; i >= 0; i--) {
    int index = sortedIndices[i];
    if (index == m_data.classIndex()) {
      continue;
    }
    buffer
      .append(Utils.doubleToString(impurityScores[index], 10,
        getNumDecimalPlaces())).append(" (")
      .append(Utils.doubleToString(nodeCounts[index], 6, 0))
      .append(") ").append(m_data.attribute(index).name())
      .append("\n");
  }
}
/**
 * Computes the average impurity decrease per attribute over the trees.
 * <p>
 * For each attribute, the impurity decreases recorded by all trees are
 * summed. The sum is divided by the total number of nodes that split on
 * that attribute. Attributes never used for splitting get 0.
 *
 * @param nodeCounts an optional array that, if non-null, is overwritten with
 *          the number of nodes (summed over all trees) at which each
 *          attribute was used for splitting. It must have at least one entry
 *          per attribute; any earlier contents of those entries are
 *          replaced, not added to.
 * @return the average impurity decrease per attribute over the trees
 * @throws WekaException if the forest has not been built yet, or if
 *           attribute importance statistics have not been collected
 * @throws IllegalArgumentException if nodeCounts is non-null but has fewer
 *           entries than there are attributes
 */
public double[] computeAverageImpurityDecreasePerAttribute(
  double[] nodeCounts) throws WekaException {
  if (m_Classifiers == null) {
    throw new WekaException("Classifier has not been built yet!");
  }
  if (!getComputeAttributeImportance()) {
    throw new WekaException("Stats for attribute importance have not "
      + "been collected!");
  }
  int numAttributes = m_data.numAttributes();
  if (nodeCounts != null && nodeCounts.length < numAttributes) {
    throw new IllegalArgumentException("nodeCounts must have at least "
      + numAttributes + " entries, but has " + nodeCounts.length);
  }
  double[] impurityDecreases = new double[numAttributes];
  // Accumulate into a fresh array. Stale values in a caller-supplied
  // nodeCounts array would otherwise inflate both the reported counts and
  // the divisors used for the averages.
  double[] counts = new double[numAttributes];
  for (Classifier c : m_Classifiers) {
    double[][] forClassifier = ((RandomTree) c).getImpurityDecreases();
    if (forClassifier == null) {
      // NOTE(review): assumes a tree built without impurity tracking reports
      // null here (e.g. importance enabled only after building); confirm
      // against RandomTree. Fails cleanly instead of with an NPE.
      throw new WekaException("Stats for attribute importance have not "
        + "been collected!");
    }
    // Per attribute: [i][0] is the summed impurity decrease, [i][1] the
    // number of nodes splitting on attribute i.
    for (int i = 0; i < numAttributes; i++) {
      impurityDecreases[i] += forClassifier[i][0];
      counts[i] += forClassifier[i][1];
    }
  }
  for (int i = 0; i < numAttributes; i++) {
    if (counts[i] > 0) {
      impurityDecreases[i] /= counts[i];
    }
  }
  if (nodeCounts != null) {
    System.arraycopy(counts, 0, nodeCounts, 0, numAttributes);
  }
  return impurityDecreases;
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options
*/
@Override
public Enumeration
© 2015 - 2025 Weber Informatics LLC | Privacy Policy