weka.classifiers.trees.ADTree Maven / Gradle / Ivy
Show all versions of weka-stable Show documentation
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ADTree.java
* Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees;
import weka.classifiers.Classifier;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.trees.adtree.PredictionNode;
import weka.classifiers.trees.adtree.ReferenceInstances;
import weka.classifiers.trees.adtree.Splitter;
import weka.classifiers.trees.adtree.TwoWayNominalSplit;
import weka.classifiers.trees.adtree.TwoWayNumericSplit;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
/**
* Class for generating an alternating decision tree. The basic algorithm is based on:
*
* Freund, Y., Mason, L.: The alternating decision tree learning algorithm. In: Proceeding of the Sixteenth International Conference on Machine Learning, Bled, Slovenia, 124-133, 1999.
*
* This version currently only supports two-class problems. The number of boosting iterations needs to be manually tuned to suit the dataset and the desired complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic search methods have been introduced to speed learning.
*
*
* BibTeX:
*
* @inproceedings{Freund1999,
* address = {Bled, Slovenia},
* author = {Freund, Y. and Mason, L.},
* booktitle = {Proceeding of the Sixteenth International Conference on Machine Learning},
* pages = {124-133},
* title = {The alternating decision tree learning algorithm},
* year = {1999}
* }
*
*
*
* Valid options are:
*
* -B <number of boosting iterations>
* Number of boosting iterations.
* (Default = 10)
*
* -E <-3|-2|-1|>=0>
* Expand nodes: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
* (Default = -3)
*
* -D
* Save the instance data with the model
*
*
* @author Richard Kirkby ([email protected])
* @author Bernhard Pfahringer ([email protected])
* @version $Revision: 10290 $
*/
public class ADTree
extends Classifier
implements OptionHandler, Drawable, AdditionalMeasureProducer,
WeightedInstancesHandler, IterativeClassifier,
TechnicalInformationHandler {
/** Version identifier used by Java serialization; change it only when the serialized form changes incompatibly. */
static final long serialVersionUID = -1532264837167690683L;
/**
 * Describes this classifier for display in the Explorer/Experimenter GUI.
 * The text embeds the bibliographic reference produced by
 * {@link #getTechnicalInformation()}.
 *
 * @return a human-readable description of the classifier
 */
public String globalInfo() {
  StringBuilder info = new StringBuilder();
  info.append("Class for generating an alternating decision tree. The basic ");
  info.append("algorithm is based on:\n\n");
  info.append(getTechnicalInformation().toString());
  info.append("\n\n");
  info.append("This version currently only supports two-class problems. The number of boosting ");
  info.append("iterations needs to be manually tuned to suit the dataset and the desired ");
  info.append("complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic ");
  info.append("search methods have been introduced to speed learning.");
  return info.toString();
}
/** Search mode: expand all paths (exhaustive search). Selected on the command line by -E -3. */
public static final int SEARCHPATH_ALL = 0;
/** Search mode: expand the most heavily weighted path. Selected on the command line by -E -2. */
public static final int SEARCHPATH_HEAVIEST = 1;
/** Search mode: expand the path with the best z-pure estimate. Selected on the command line by -E -1. */
public static final int SEARCHPATH_ZPURE = 2;
/** Search mode: expand a random path. Selected on the command line by -E with a seed >= 0. */
public static final int SEARCHPATH_RANDOM = 3;
/**
 * Tags for the search modes, used by getSearchPath/setSearchPath.
 * setOptions maps a negative -E value v onto the ID v + 3, so the numeric
 * values of the constants above must stay 0..3 in this order.
 */
public static final Tag [] TAGS_SEARCHPATH = {
new Tag(SEARCHPATH_ALL, "Expand all paths"),
new Tag(SEARCHPATH_HEAVIEST, "Expand the heaviest path"),
new Tag(SEARCHPATH_ZPURE, "Expand the best z-pure path"),
new Tag(SEARCHPATH_RANDOM, "Expand a random path")
};
/** The instances used to train the tree (built from the data passed to initClassifier) */
protected Instances m_trainInstances;
/** The root of the tree */
protected PredictionNode m_root = null;
/** The random number generator - used for the random search heuristic; re-seeded from m_randomSeed in initClassifier */
protected Random m_random = null;
/** The number of the last splitter added to the tree */
protected int m_lastAddedSplitNum = 0;
/** An array containing the indices to the numeric attributes in the data */
protected int[] m_numericAttIndices;
/** An array containing the indices to the nominal attributes in the data */
protected int[] m_nominalAttIndices;
/** The total weight of the instances - used to speed Z calculations */
protected double m_trainTotalWeight;
/** The training instances with positive class (class value != 0) - referencing the training dataset */
protected ReferenceInstances m_posTrainInstances;
/** The training instances with negative class (class value 0) - referencing the training dataset */
protected ReferenceInstances m_negTrainInstances;
/** The best node to insert under, as found so far by the latest search */
protected PredictionNode m_search_bestInsertionNode;
/** The best splitter to insert, as found so far by the latest search; null when none was found */
protected Splitter m_search_bestSplitter;
/** The smallest Z value found so far by the latest search */
protected double m_search_smallestZ;
/** The positive instances that apply to the best path found so far */
protected Instances m_search_bestPathPosInstances;
/** The negative instances that apply to the best path found so far */
protected Instances m_search_bestPathNegInstances;
/** Statistics - the number of prediction nodes investigated during search */
protected int m_nodesExpanded = 0;
/** Statistics - the number of instances processed during search */
protected int m_examplesCounted = 0;
/** Option - the number of boosting iterations to perform */
protected int m_boostingIterations = 10;
/** Option - the search mode (one of the SEARCHPATH_* constants; 0 = SEARCHPATH_ALL) */
protected int m_searchPath = 0;
/** Option - the seed to use for a random search */
protected int m_randomSeed = 0;
/** Option - whether the tree should remember the instance data */
protected boolean m_saveInstanceData = false;
/**
 * Builds the bibliographic reference for the algorithm this classifier
 * implements (Freund and Mason's alternating decision tree paper).
 *
 * @return the technical information about this class
 */
public TechnicalInformation getTechnicalInformation() {
  TechnicalInformation info = new TechnicalInformation(Type.INPROCEEDINGS);
  info.setValue(Field.AUTHOR, "Freund, Y. and Mason, L.");
  info.setValue(Field.YEAR, "1999");
  info.setValue(Field.TITLE, "The alternating decision tree learning algorithm");
  info.setValue(Field.BOOKTITLE, "Proceeding of the Sixteenth International Conference on Machine Learning");
  info.setValue(Field.ADDRESS, "Bled, Slovenia");
  info.setValue(Field.PAGES, "124-133");
  return info;
}
/**
 * Prepares the tree for training with the two-class optimized method.
 *
 * It resets the search statistics and re-seeds the random generator. It
 * builds the working training set and partitions it into negative (class
 * value 0) and positive (any other class value) reference subsets. It
 * then creates the root prediction node, pre-adjusts the instance weights
 * with the root prediction, and indexes the attribute types.
 *
 * @param instances the instances to train the tree with
 * @exception Exception if training data is unsuitable
 */
public void initClassifier(Instances instances) throws Exception {
  // reset search statistics
  m_nodesExpanded = 0;
  m_examplesCounted = 0;
  m_lastAddedSplitNum = 0;

  // the random-walk search draws from this generator
  m_random = new Random(m_randomSeed);

  // working training set
  m_trainInstances = new Instances(instances);
  int total = m_trainInstances.numInstances();

  // partition into negative/positive subsets that reference the training set
  m_posTrainInstances = new ReferenceInstances(m_trainInstances, total);
  m_negTrainInstances = new ReferenceInstances(m_trainInstances, total);
  for (int idx = 0; idx < total; idx++) {
    Instance current = m_trainInstances.instance(idx);
    ReferenceInstances target =
        ((int) current.classValue() == 0) ? m_negTrainInstances : m_posTrainInstances;
    target.addReference(current);
  }
  m_posTrainInstances.compactify();
  m_negTrainInstances.compactify();

  // root prediction node, then pre-adjust weights with its prediction
  double rootValue = calcPredictionValue(m_posTrainInstances, m_negTrainInstances);
  m_root = new PredictionNode(rootValue);
  updateWeights(m_posTrainInstances, m_negTrainInstances, rootValue);

  // pre-calculate the nominal/numeric attribute index arrays
  generateAttributeIndicesSingle();
}
/**
 * Runs one iteration of iterative training by performing a single
 * boosting step. The iteration index itself is not used.
 *
 * @param iteration the index of the current iteration (0-based); unused
 * @exception Exception if this iteration fails
 */
public void next(int iteration) throws Exception {
  this.boost();
}
/**
 * Performs a single boosting iteration with the two-class optimized method.
 *
 * It searches for the best splitter and attaches a new prediction node to
 * each of its two branches, updating instance weights along the way. It
 * then adds the splitter under the best insertion node (which may merge it
 * into an existing equivalent splitter). If the search finds no splitter,
 * the tree is left unchanged.
 *
 * @exception Exception if try to boost without setting up tree first or there are no
 * instances to train with
 */
public void boost() throws Exception {
  boolean noData = (m_trainInstances == null) || (m_trainInstances.numInstances() == 0);
  if (noData) {
    throw new Exception("Trying to boost with no training data");
  }

  searchForBestTestSingle();
  Splitter splitter = m_search_bestSplitter;
  if (splitter == null) {
    return; // the search found nothing to add (e.g. empty instances)
  }

  // build one prediction node per branch, adjusting weights as we go
  for (int branch = 0; branch < 2; branch++) {
    Instances branchPos = splitter.instancesDownBranch(branch, m_search_bestPathPosInstances);
    Instances branchNeg = splitter.instancesDownBranch(branch, m_search_bestPathNegInstances);
    double value = calcPredictionValue(branchPos, branchNeg);
    PredictionNode predictor = new PredictionNode(value);
    updateWeights(branchPos, branchNeg, value);
    splitter.setChildForBranch(branch, predictor);
  }

  m_search_bestInsertionNode.addChild(splitter, this);

  // drop references held from the search so they can be garbage collected
  m_search_bestPathPosInstances = null;
  m_search_bestPathNegInstances = null;
  m_search_bestSplitter = null;
}
/**
* Generates the m_nominalAttIndices and m_numericAttIndices arrays to index
* the respective attribute types in the training data.
*
*/
private void generateAttributeIndicesSingle() {
// insert indices into vectors
FastVector nominalIndices = new FastVector();
FastVector numericIndices = new FastVector();
for (int i=0; i= m_search_smallestZ) return;
// keep stats
m_nodesExpanded++;
m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();
// evaluate static splitters (nominal)
for (int i=0; i 0) {
// merge the two sets of instances into one
Instances allInstances = new Instances(posInstances);
for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); )
allInstances.add((Instance) e.nextElement());
// use method of finding the optimal Z split-point
for (int i=0; i largestWeight) {
heaviestSplit = split;
heaviestBranch = i;
largestWeight = weight;
}
}
}
if (heaviestSplit != null)
searchForBestTestSingle(heaviestSplit.getChildForBranch(heaviestBranch),
heaviestSplit.instancesDownBranch(heaviestBranch,
posInstances),
heaviestSplit.instancesDownBranch(heaviestBranch,
negInstances));
}
/**
* Continues single (two-class optimized) search by investigating only the path
* with the best Z-pure value at each branch.
*
* @param currentNode the root of the subtree to be searched
* @param posInstances the positive-class instances that apply at this node
* @param negInstances the negative-class instances that apply at this node
* @exception Exception if search fails
*/
private void goDownZpurePathSingle(PredictionNode currentNode,
Instances posInstances, Instances negInstances)
throws Exception {
double lowestZpure = m_search_smallestZ; // do z-pure cutoff
PredictionNode bestPath = null;
Instances bestPosSplit = null, bestNegSplit = null;
// search for branch with lowest Z-pure
for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
Splitter split = (Splitter) e.nextElement();
for (int i=0; i= 0)
currentValue = predictionValueForInstance(inst, split.getChildForBranch(branch),
currentValue);
}
return currentValue;
}
/**
 * Describes the trained tree: its textual form, the class legend, the total
 * node count and the number of prediction nodes.
 *
 * @return a description of the classifier, or a notice if it has not been built
 */
public String toString() {
  if (m_root != null) {
    StringBuilder out = new StringBuilder("Alternating decision tree:\n\n");
    out.append(toString(m_root, 1));
    out.append("\nLegend: ").append(legend());
    out.append("\nTree size (total number of nodes): ").append(numOfAllNodes(m_root));
    out.append("\nLeaves (number of predictor nodes): ").append(numOfPredictionNodes(m_root));
    return out.toString();
  }
  return "ADTree not built yet";
}
/**
* Traverses the tree, forming a string that describes it.
*
* @param currentNode the current node under investigation
* @param level the current level in the tree
* @return the string describing the subtree
*/
protected String toString(PredictionNode currentNode, int level) {
StringBuffer text = new StringBuffer();
text.append(": " + Utils.doubleToString(currentNode.getValue(),3));
for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
Splitter split = (Splitter) e.nextElement();
for (int j=0; j 0) text.append(" data=\n" + instances + "\n,\n");
text.append("]\n");
for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
Splitter split = (Splitter) e.nextElement();
text.append("S" + splitOrder + "P" + predOrder + "->" + "S" + split.orderAdded +
" [style=dotted]\n");
text.append("S" + split.orderAdded + " [label=\"" + split.orderAdded + ": " +
Utils.backQuoteChars(split.attributeString(m_trainInstances)) + "\"]\n");
for (int i=0; i" + "S" + split.orderAdded + "P" + i +
" [label=\"" + Utils.backQuoteChars(split.comparisonString(i, m_trainInstances)) + "\"]\n");
graphTraverse(child, text, split.orderAdded, i,
split.instancesDownBranch(i, instances));
}
}
}
}
/**
 * Returns the legend of the tree, describing how results are to be interpreted:
 * negative predictions favour class value 0, positive ones class value 1.
 *
 * @return a string containing the legend of the classifier, or an empty string
 * if there is no training data or no class attribute is available
 */
public String legend() {
  if (m_trainInstances == null) return "";
  Attribute classAttribute;
  try {
    classAttribute = m_trainInstances.classAttribute();
  } catch (Exception x) {
    // No class attribute is available. Previously this was swallowed and the
    // null attribute caused a NullPointerException below; report nothing instead.
    return "";
  }
  return ("-ve = " + classAttribute.value(0) +
          ", +ve = " + classAttribute.value(1));
}
/**
 * Returns the tip text for the numOfBoostingIterations property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String numOfBoostingIterationsTipText() {
  // fixed: the old concatenation produced a double space ("more  accurate")
  return "Sets the number of boosting iterations to perform. You will need to manually "
    + "tune this parameter to suit the dataset and the desired complexity/accuracy "
    + "tradeoff. More boosting iterations will result in larger (potentially more "
    + "accurate) trees, but will make learning slower. Each iteration will add 3 nodes "
    + "(1 split + 2 prediction) to the tree unless merging occurs.";
}
/**
 * Returns how many boosting iterations training will perform.
 *
 * @return the configured number of boosting iterations
 */
public int getNumOfBoostingIterations() {
  int iterations = this.m_boostingIterations;
  return iterations;
}
/**
 * Configures how many boosting iterations training will perform.
 * The value is stored as given; no validation is done here.
 *
 * @param b the number of boosting iterations to use
 */
public void setNumOfBoostingIterations(int b) {
  this.m_boostingIterations = b;
}
/**
 * Returns the tip text for the searchPath property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String searchPathTipText() {
  String tip = "Sets the type of search to perform when building the tree. The default option"
      + " (Expand all paths) will do an exhaustive search. The other search methods are"
      + " heuristic, so they are not guaranteed to find an optimal solution but they are"
      + " much faster. Expand the heaviest path: searches the path with the most heavily"
      + " weighted instances. Expand the best z-pure path: searches the path determined"
      + " by the best z-pure estimate. Expand a random path: the fastest method, simply"
      + " searches down a single random path on each iteration.";
  return tip;
}
/**
 * Returns the search mode used to find new insertion points, as a tag drawn
 * from TAGS_SEARCHPATH (SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE
 * or SEARCHPATH_RANDOM).
 *
 * @return the tree searching mode
 */
public SelectedTag getSearchPath() {
  SelectedTag current = new SelectedTag(m_searchPath, TAGS_SEARCHPATH);
  return current;
}
/**
 * Selects the search mode used to find new insertion points. Only tags taken
 * from TAGS_SEARCHPATH (compared by reference) are accepted; any other tag
 * set is silently ignored.
 *
 * @param newMethod the new tree searching mode
 */
public void setSearchPath(SelectedTag newMethod) {
  if (newMethod.getTags() != TAGS_SEARCHPATH) {
    return;
  }
  m_searchPath = newMethod.getSelectedTag().getID();
}
/**
 * Returns the tip text for the randomSeed property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String randomSeedTipText() {
  final String tip = "Sets the random seed to use for a random search.";
  return tip;
}
/**
 * Returns the seed used by the random-walk search.
 *
 * @return the random seed
 */
public int getRandomSeed() {
  int seed = this.m_randomSeed;
  return seed;
}
/**
 * Stores the seed for the random-walk search. The generator itself is only
 * created when the tree is initialized, so this takes effect on the next
 * training run.
 *
 * @param seed the random seed
 */
public void setRandomSeed(int seed) {
  this.m_randomSeed = seed;
}
/**
 * Returns the tip text for the saveInstanceData property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String saveInstanceDataTipText() {
  String tip = "Sets whether the tree is to save instance data - the model will take up more"
      + " memory if it does. If enabled you will be able to visualize the instances at"
      + " the prediction nodes when visualizing the tree.";
  return tip;
}
/**
 * Gets whether the tree is to save instance data with the model.
 *
 * @return true if the tree saves instance data, false otherwise
 */
public boolean getSaveInstanceData() {
  return m_saveInstanceData;
}
/**
 * Controls whether the tree keeps the instance data with the model.
 *
 * @param v true to make the tree save instance data
 */
public void setSaveInstanceData(boolean v) {
  this.m_saveInstanceData = v;
}
/**
 * Returns an enumeration describing the available options:
 * -B (boosting iterations), -E (search mode / random seed) and -D (save data).
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
  Vector newVector = new Vector(3);
  newVector.addElement(new Option(
      "\tNumber of boosting iterations.\n"
      + "\t(Default = 10)",
      // the synopsis must name the argument, matching the class documentation
      "B", 1, "-B <number of boosting iterations>"));
  newVector.addElement(new Option(
      "\tExpand nodes: -3(all), -2(weight), -1(z_pure), "
      + ">=0 seed for random walk\n"
      + "\t(Default = -3)",
      "E", 1, "-E <-3|-2|-1|>=0>"));
  newVector.addElement(new Option(
      "\tSave the instance data with the model",
      "D", 0, "-D"));
  return newVector.elements();
}
/**
 * Parses a given list of options. Valid options are:
 *
 * -B num
 * Set the number of boosting iterations
 * (default 10)
 *
 * -E num
 * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
 * (default -3)
 *
 * -D
 * Save the instance data with the model
 *
 * Options that are absent leave the corresponding setting unchanged, except
 * -D, whose absence turns instance-data saving off.
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported, or if the -E value is
 * below -3 (IllegalArgumentException)
 */
public void setOptions(String[] options) throws Exception {
  String bString = Utils.getOption('B', options);
  if (bString.length() != 0) setNumOfBoostingIterations(Integer.parseInt(bString));
  String eString = Utils.getOption('E', options);
  if (eString.length() != 0) {
    int value = Integer.parseInt(eString);
    if (value >= 0) {
      // non-negative values select the random walk and double as its seed
      setSearchPath(new SelectedTag(SEARCHPATH_RANDOM, TAGS_SEARCHPATH));
      setRandomSeed(value);
    } else if (value >= -3) {
      // -3, -2, -1 map onto SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE
      setSearchPath(new SelectedTag(value + 3, TAGS_SEARCHPATH));
    } else {
      // previously surfaced as SelectedTag's opaque "Selected tag is not valid"
      throw new IllegalArgumentException("Invalid value for -E: " + value
          + " (expected -3, -2, -1 or a random seed >= 0)");
    }
  }
  setSaveInstanceData(Utils.getFlag('D', options));
  Utils.checkForRemainingOptions(options);
}
/**
 * Reports the current settings as a command-line option array that can be
 * fed back to setOptions(). The -E value is the random seed when the random
 * walk is selected, otherwise the search mode shifted down by 3. Unused
 * trailing slots are padded with empty strings.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String[] getOptions() {
  String[] result = new String[6];
  int pos = 0;
  result[pos++] = "-B";
  result[pos++] = String.valueOf(getNumOfBoostingIterations());
  result[pos++] = "-E";
  int eValue;
  if (m_searchPath == SEARCHPATH_RANDOM) {
    eValue = m_randomSeed;
  } else {
    eValue = m_searchPath - 3;
  }
  result[pos++] = String.valueOf(eValue);
  if (getSaveInstanceData()) {
    result[pos++] = "-D";
  }
  for (; pos < result.length; pos++) {
    result[pos] = "";
  }
  return result;
}
/**
 * Measures the tree size, i.e. the total count of splitter and prediction nodes.
 *
 * @return the tree size
 */
public double measureTreeSize() {
  double size = numOfAllNodes(m_root);
  return size;
}
/**
 * Measures the leaf size, counted as the number of prediction nodes.
 *
 * @return the leaf size
 */
public double measureNumLeaves() {
  double leaves = numOfPredictionNodes(m_root);
  return leaves;
}
/**
 * Measures the prediction leaf size: prediction nodes that have no children.
 *
 * @return the leaf size
 */
public double measureNumPredictionLeaves() {
  double leaves = numOfPredictionLeafNodes(m_root);
  return leaves;
}
/**
 * Reports the search statistic counting prediction nodes investigated.
 *
 * @return the number of nodes expanded during search
 */
public double measureNodesExpanded() {
  return (double) m_nodesExpanded;
}
/**
 * Reports the search statistic counting instances processed ("counted").
 *
 * @return the number of examples processed during search
 */
public double measureExamplesProcessed() {
  return (double) m_examplesCounted;
}
/**
 * Returns an enumeration of the additional measure names understood by
 * getMeasure().
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {
  // capacity matches the five measures added below (was 4)
  Vector newVector = new Vector(5);
  newVector.addElement("measureTreeSize");
  newVector.addElement("measureNumLeaves");
  newVector.addElement("measureNumPredictionLeaves");
  newVector.addElement("measureNodesExpanded");
  newVector.addElement("measureExamplesProcessed");
  return newVector.elements();
}
/**
 * Looks up one of the additional measures by name (case-insensitive).
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @exception IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  if (additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
    return measureTreeSize();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureNumLeaves")) {
    return measureNumLeaves();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureNumPredictionLeaves")) {
    return measureNumPredictionLeaves();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureNodesExpanded")) {
    return measureNodesExpanded();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureExamplesProcessed")) {
    return measureExamplesProcessed();
  }
  throw new IllegalArgumentException(additionalMeasureName
      + " not supported (ADTree)");
}
/**
* Returns the total number of nodes in a tree.
*
* @param root the root of the tree being measured
* @return tree size in number of splitter + prediction nodes
*/
protected int numOfAllNodes(PredictionNode root) {
int numSoFar = 0;
if (root != null) {
numSoFar++;
for (Enumeration e = root.children(); e.hasMoreElements(); ) {
numSoFar++;
Splitter split = (Splitter) e.nextElement();
for (int i=0; i 0) {
for (Enumeration e = root.children(); e.hasMoreElements(); ) {
Splitter split = (Splitter) e.nextElement();
for (int i=0; i