weka.classifiers.trees.LADTree Maven / Gradle / Ivy
Show all versions of weka-stable Show documentation
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LADTree.java
* Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees;
import weka.classifiers.*;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.*;
import weka.classifiers.trees.adtree.ReferenceInstances;
import java.util.*;
import java.io.*;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
/**
* Class for generating a multi-class alternating decision tree using the LogitBoost strategy. For more info, see
*
* Geoffrey Holmes, Bernhard Pfahringer, Richard Kirkby, Eibe Frank, Mark Hall: Multiclass alternating decision trees. In: ECML, 161-172, 2001.
*
*
* BibTeX:
*
* @inproceedings{Holmes2001,
* author = {Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall},
* booktitle = {ECML},
* pages = {161-172},
* publisher = {Springer},
* title = {Multiclass alternating decision trees},
* year = {2001}
* }
*
*
*
* Valid options are:
*
* -B <number of boosting iterations>
* Number of boosting iterations.
* (Default = 10)
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
*
* @author Richard Kirkby
* @version $Revision: 10279 $
*/
public class LADTree
extends Classifier implements Drawable,
AdditionalMeasureProducer,
TechnicalInformationHandler {
/**
* For serialization
*/
private static final long serialVersionUID = -4940716114518300302L;
// Constant from LogitBoost
protected double Z_MAX = 4;
// Number of classes
protected int m_numOfClasses;
// Instances as reference instances
protected ReferenceInstances m_trainInstances;
// Root of the tree
protected PredictionNode m_root = null;
// To keep track of the order in which splits are added
protected int m_lastAddedSplitNum = 0;
// Indices for numeric attributes
protected int[] m_numericAttIndices;
// Variables to keep track of best options
protected double m_search_smallestLeastSquares;
protected PredictionNode m_search_bestInsertionNode;
protected Splitter m_search_bestSplitter;
protected Instances m_search_bestPathInstances;
// A collection of splitter nodes
protected FastVector m_staticPotentialSplitters2way;
// statistics
protected int m_nodesExpanded = 0;
protected int m_examplesCounted = 0;
// options
protected int m_boostingIterations = 10;
/**
 * Builds the human-readable description of this classifier: a fixed
 * introductory sentence followed by the textual form of the technical
 * reference returned by {@link #getTechnicalInformation()}.
 *
 * @return a description suitable for displaying in the
 *         explorer/experimenter gui
 */
public String globalInfo() {
  final String intro =
      "Class for generating a multi-class alternating decision tree using "
      + "the LogitBoost strategy. For more info, see\n\n";
  final String reference = getTechnicalInformation().toString();
  return intro + reference;
}
/**
 * Creates the bibliographic record (an INPROCEEDINGS entry) for the paper
 * this classifier is based on.
 *
 * @return the technical information about this class
 */
public TechnicalInformation getTechnicalInformation() {
  final TechnicalInformation paper =
      new TechnicalInformation(Type.INPROCEEDINGS);

  paper.setValue(
      Field.AUTHOR,
      "Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall");
  paper.setValue(Field.TITLE, "Multiclass alternating decision trees");
  paper.setValue(Field.BOOKTITLE, "ECML");
  paper.setValue(Field.YEAR, "2001");
  paper.setValue(Field.PAGES, "161-172");
  paper.setValue(Field.PUBLISHER, "Springer");

  return paper;
}
/** helper classes ********************************************************************/
protected class LADInstance extends Instance {
public double[] fVector;
public double[] wVector;
public double[] pVector;
public double[] zVector;
public LADInstance(Instance instance) {
super(instance); // copy the instance
setDataset(instance.dataset()); // preserve dataset
// set up vectors
fVector = new double[m_numOfClasses];
wVector = new double[m_numOfClasses];
pVector = new double[m_numOfClasses];
zVector = new double[m_numOfClasses];
// set initial probabilities
double initProb = 1.0 / ((double) m_numOfClasses);
for (int i=0; i Z_MAX) { // threshold
zVector[i] = Z_MAX;
}
} else {
zVector[i] = -1.0 / (1.0 - pVector[i]);
if (zVector[i] < -Z_MAX) { // threshold
zVector[i] = -Z_MAX;
}
}
}
}
/**
 * Indicator of whether the given class index is this instance's class.
 *
 * @param index the class index to test
 * @return 1.0 if {@code index} equals the (truncated) class value of this
 *         instance, 0.0 otherwise
 */
public double yVector(int index) {
  final int actualClass = (int) classValue();
  if (index == actualClass) {
    return 1.0;
  }
  return 0.0;
}
/**
 * Produces a copy of this instance: a new LADInstance is built from the
 * superclass copy, then the contents of the four per-class working vectors
 * (f, w, p, z) are copied into it.
 *
 * @return the copied LADInstance
 */
public Object copy() {
  final LADInstance duplicate = new LADInstance((Instance) super.copy());

  // pair every source vector with its counterpart in the duplicate
  final double[][] from = {fVector, wVector, pVector, zVector};
  final double[][] to = {
    duplicate.fVector, duplicate.wVector, duplicate.pVector, duplicate.zVector
  };
  for (int v = 0; v < from.length; v++) {
    System.arraycopy(from[v], 0, to[v], 0, from[v].length);
  }
  return duplicate;
}
public String toString() {
StringBuffer text = new StringBuffer();
text.append(" * F(");
for (int i=0; i= splitPoint)
filteredInstances.addReference(inst);
}
}
return filteredInstances;
}
/**
 * Looks up the attribute this split tests in the training instances.
 *
 * @return the name of the attribute at {@code attIndex}
 */
public String attributeString() {
  final Attribute tested = m_trainInstances.attribute(attIndex);
  return tested.name();
}
/**
 * Describes the test that leads down one branch of this split.
 *
 * @param branchNum the branch; 0 yields the "&lt; " form, any other value
 *                  the "&gt;= " form
 * @return the comparison operator followed by the split point formatted
 *         with {@code Utils.doubleToString(splitPoint, 3)}
 */
public String comparisonString(int branchNum) {
  final String operator;
  if (branchNum == 0) {
    operator = "< ";
  } else {
    operator = ">= ";
  }
  return operator + Utils.doubleToString(splitPoint, 3);
}
/**
 * Tests whether another splitter represents the same two-way numeric test.
 *
 * @param compare the splitter to compare against
 * @return true only if {@code compare} is a TwoWayNumericSplit with the
 *         same attribute index and exactly the same split point
 */
public boolean equalTo(Splitter compare) {
  if (!(compare instanceof TwoWayNumericSplit)) {
    return false; // different splitter type (or null)
  }
  final TwoWayNumericSplit other = (TwoWayNumericSplit) compare;
  if (attIndex != other.attIndex) {
    return false;
  }
  return splitPoint == other.splitPoint;
}
/**
 * Stores the prediction node hanging off the given branch.
 *
 * @param branchNum      index into the children array
 * @param childPredictor the prediction node to attach
 */
public void setChildForBranch(int branchNum,
                              PredictionNode childPredictor) {
  children[branchNum] = childPredictor;
}
/**
 * Fetches the prediction node hanging off the given branch.
 *
 * @param branchNum index into the children array
 * @return the entry stored for that branch
 */
public PredictionNode getChildForBranch(int branchNum) {
  final PredictionNode child = children[branchNum];
  return child;
}
/**
 * Deep copy: creates a new split on the same attribute and split point and
 * clones each of the two children that is present.
 *
 * @return the cloned TwoWayNumericSplit
 */
public Object clone() {
  final TwoWayNumericSplit replica =
      new TwoWayNumericSplit(attIndex, splitPoint);
  for (int branch = 0; branch < 2; branch++) {
    if (children[branch] == null) {
      continue; // nothing attached on this branch
    }
    replica.setChildForBranch(branch,
                              (PredictionNode) children[branch].clone());
  }
  return replica;
}
/**
 * Searches for the cut point on a numeric attribute that minimises the
 * class entropy conditioned on the three-way partition
 * (below cut / at-or-above cut / missing).
 *
 * Note: the passed instances are sorted in place on the attribute.
 *
 * Fix over the previous version: the initial class counts are now
 * accumulated with {@code inst.weight()}. They used to be incremented by 1
 * while the scan below moves {@code inst.weight()} between rows, which made
 * the contingency table inconsistent (and possibly negative) for instances
 * whose weight is not 1.
 *
 * @param instances the instances to search; sorted as a side effect
 * @param index     index of the numeric attribute to split on
 * @return the best cut point found, or 0 if no candidate cut exists
 * @exception Exception declared for compatibility with existing callers
 */
private double findSplit(Instances instances, int index) throws Exception {
  double splitPoint = 0;
  double bestVal = Double.MAX_VALUE;
  int numMissing = 0;
  // row 0: below the cut, row 1: at or above the cut, row 2: missing
  double[][] distribution = new double[3][instances.numClasses()];

  // Start with every non-missing instance on the "at or above" side
  for (int i = 0; i < instances.numInstances(); i++) {
    Instance inst = instances.instance(i);
    if (!inst.isMissing(index)) {
      distribution[1][(int) inst.classValue()] += inst.weight();
    } else {
      distribution[2][(int) inst.classValue()] += inst.weight();
      numMissing++;
    }
  }

  // Sort instances; the loop bound below skips the last numMissing
  // entries, i.e. it assumes missing values end up at the end
  instances.sort(index);

  // Move instances one by one to the "below" side and evaluate each
  // boundary between two distinct attribute values
  int lastCandidate = instances.numInstances() - (numMissing + 1);
  for (int i = 0; i < lastCandidate; i++) {
    Instance inst = instances.instance(i);
    Instance instPlusOne = instances.instance(i + 1);
    distribution[0][(int) inst.classValue()] += inst.weight();
    distribution[1][(int) inst.classValue()] -= inst.weight();
    if (Utils.sm(inst.value(index), instPlusOne.value(index))) {
      double currCutPoint =
          (inst.value(index) + instPlusOne.value(index)) / 2.0;
      double currVal =
          ContingencyTables.entropyConditionedOnRows(distribution);
      if (Utils.sm(currVal, bestVal)) {
        splitPoint = currCutPoint;
        bestVal = currVal;
      }
    }
  }
  return splitPoint;
}
}
/**
 * Prepares the tree for training: resets the search statistics, validates
 * the data, wraps every instance with a known class value in a LADInstance
 * held by reference, creates the root prediction node and pre-computes the
 * static splitters and numeric attribute indices.
 *
 * @param instances the instances to train the tree with
 * @exception Exception if the data contains string attributes or the class
 *                      attribute is not nominal
 */
public void initClassifier(Instances instances) throws Exception {

  // reset bookkeeping from any earlier run
  m_nodesExpanded = 0;
  m_examplesCounted = 0;
  m_lastAddedSplitNum = 0;

  m_numOfClasses = instances.numClasses();

  // reject data this classifier cannot work with
  if (instances.checkForStringAttributes()) {
    throw new Exception("Can't handle string attributes!");
  }
  if (!instances.classAttribute().isNominal()) {
    throw new Exception("Class must be nominal!");
  }

  // build the training set out of LADInstance wrappers
  final int total = instances.numInstances();
  m_trainInstances = new ReferenceInstances(instances, total);
  for (int idx = 0; idx < total; idx++) {
    final Instance candidate = instances.instance(idx);
    if (candidate.classIsMissing()) {
      continue; // instances without a class value are not used
    }
    final LADInstance wrapped = new LADInstance(candidate);
    m_trainInstances.addReference(wrapped);
    wrapped.setDataset(m_trainInstances);
  }

  // root prediction node starts with one zero-initialised value per class
  m_root = new PredictionNode(new double[m_numOfClasses]);

  // pre-calculate what we can
  generateStaticPotentialSplittersAndNumericIndices();
}
/**
 * Runs one boosting iteration by delegating to {@code boost()}.
 *
 * @param iteration the iteration number; not used by this implementation
 * @exception Exception if the boosting step fails
 */
public void next(int iteration) throws Exception {
  this.boost();
}
/**
 * Called when iterating is finished; this implementation does nothing.
 *
 * @exception Exception never thrown by this implementation
 */
public void done() throws Exception {
  // intentionally empty
}
/**
* Performs a single boosting iteration.
* Will add a new splitter node and two prediction nodes to the tree
* (unless merging takes place).
*
* @exception Exception if try to boost without setting up tree first
*/
private void boost() throws Exception {
if (m_trainInstances == null)
throw new Exception("Trying to boost with no training data");
// perform the search
searchForBestTest();
if (m_Debug) {
System.out.println("Best split found: "
+ m_search_bestSplitter.getNumOfBranches() + "-way split on "
+ m_search_bestSplitter.attributeString()
//+ "\nsmallestLeastSquares = " + m_search_smallestLeastSquares);
+ "\nBestGain = " + m_search_smallestLeastSquares);
}
if (m_search_bestSplitter == null) return; // handle empty instances
// create the new nodes for the tree, updating the weights
for (int i=0; i m_search_smallestLeastSquares) {
if (m_Debug) {
System.out.print(" (best so far)");
}
m_search_smallestLeastSquares = leastSquares;
m_search_bestInsertionNode = currentNode;
m_search_bestSplitter = split;
m_search_bestPathInstances = instances;
}
if (m_Debug) {
System.out.print("\n");
}
}
/**
 * Evaluates the best two-way split of a numeric attribute at the given
 * insertion point and records it as the current best candidate if its gain
 * exceeds the best value seen so far.
 *
 * The gain is the least-squares value of the instances with a non-missing
 * attribute value minus the least-squares value of the best split. Note
 * that {@code m_search_smallestLeastSquares} is compared with "&gt;" and
 * assigned the gain here, i.e. larger is better despite the field's name.
 *
 * @param currentNode the prediction node under which the split would go
 * @param instances   the instances that reach {@code currentNode}
 * @param attIndex    index of the numeric attribute to evaluate
 */
private void evaluateNumericSplit(PredictionNode currentNode,
                                  Instances instances, int attIndex)
{
  // element 0: split point, element 1: least-squares value of that split
  final double[] pointAndLS = findNumericSplitpointAndLS(instances, attIndex);
  final double gain =
      leastSquaresNonMissing(instances, attIndex) - pointAndLS[1];
  final boolean bestSoFar = gain > m_search_smallestLeastSquares;

  if (m_Debug) {
    System.out.print("Numeric split on " + instances.attribute(attIndex).name()
                     + " has leastSquares value of "
                     + Utils.doubleToString(gain,3));
    if (bestSoFar) {
      System.out.print(" (best so far)");
    }
    System.out.print("\n");
  }

  if (bestSoFar) {
    m_search_smallestLeastSquares = gain;
    m_search_bestInsertionNode = currentNode;
    m_search_bestSplitter = new TwoWayNumericSplit(attIndex, pointAndLS[0]);
    m_search_bestPathInstances = instances;
  }
}
private double[] findNumericSplitpointAndLS(Instances instances, int attIndex) {
double allLS = leastSquares(instances);
// all instances in right subset
double[] term1L = new double[m_numOfClasses];
double[] term2L = new double[m_numOfClasses];
double[] term3L = new double[m_numOfClasses];
double[] meanNumL = new double[m_numOfClasses];
double[] meanDenL = new double[m_numOfClasses];
double[] term1R = new double[m_numOfClasses];
double[] term2R = new double[m_numOfClasses];
double[] term3R = new double[m_numOfClasses];
double[] meanNumR = new double[m_numOfClasses];
double[] meanDenR = new double[m_numOfClasses];
double temp1, temp2, temp3;
double[] classMeans = new double[m_numOfClasses];
double[] classTotals = new double[m_numOfClasses];
// fill up RHS
for (int j=0; j instances.instance(i).value(attIndex))
newSplit = true;
else newSplit = false;
LADInstance inst = (LADInstance) instances.instance(i);
leastSquares = 0.0;
for (int j=0; j 0 ? smallestLeastSquares : 0;
return result;
}
private double leastSquares(Instances instances) {
double numerator=0, denominator=0, w, t;
double[] classMeans = new double[m_numOfClasses];
double[] classTotals = new double[m_numOfClasses];
for (int i=0; i 0 ? numerator : 0;// / denominator;
}
private double leastSquaresNonMissing(Instances instances, int attIndex) {
double numerator=0, denominator=0, w, t;
double[] classMeans = new double[m_numOfClasses];
double[] classTotals = new double[m_numOfClasses];
for (int i=0; i 0 ? numerator : 0;// / denominator;
}
private double[] calcPredictionValues(Instances instances) {
double[] classMeans = new double[m_numOfClasses];
double meansSum = 0;
double multiplier = ((double) (m_numOfClasses-1)) / ((double) (m_numOfClasses));
double[] classTotals = new double[m_numOfClasses];
for (int i=0; i 0.0) Utils.normalize(distribution, sum);
return distribution;
}
/**
* Returns the class prediction values (votes) for an instance.
*
* @param inst the instance
* @param currentNode the root of the tree to get the values from
* @param currentValues the current values before adding the values contained in the
* subtree
* @return the class prediction values (votes)
*/
private double[] predictionValuesForInstance(Instance inst, PredictionNode currentNode,
double[] currentValues) {
double[] predValues = currentNode.getValues();
for (int i=0; i= 0)
currentValues = predictionValuesForInstance(inst, split.getChildForBranch(branch),
currentValues);
}
return currentValues;
}
/** model output functions ************************************************************/
/**
 * Returns a description of the classifier: the tree itself, its legend and
 * a set of size and search statistics. If no tree has been built yet, only
 * a short notice is returned.
 *
 * @return a string containing a description of the classifier
 */
public String toString() {
  final String className = getClass().getName();

  if (m_root == null) {
    return className + " not built yet";
  }

  final StringBuffer report = new StringBuffer(className);
  report.append(":\n\n").append(toString(m_root, 1));
  report.append("\nLegend: ").append(legend());
  report.append("\n#Tree size (total): ").append(numOfAllNodes(m_root));
  report.append("\n#Tree size (number of predictor nodes): ")
        .append(numOfPredictionNodes(m_root));
  report.append("\n#Leaves (number of predictor nodes): ")
        .append(numOfLeafNodes(m_root));
  report.append("\n#Expanded nodes: ").append(m_nodesExpanded);
  report.append("\n#Processed examples: ").append(m_examplesCounted);
  // floating-point division of the two counters
  report.append("\n#Ratio e/n: ")
        .append((double) m_examplesCounted / (double) m_nodesExpanded);
  return report.toString();
}
/**
* Traverses the tree, forming a string that describes it.
*
* @param currentNode the current node under investigation
* @param level the current level in the tree
* @return the string describing the subtree
*/
private String toString(PredictionNode currentNode, int level) {
StringBuffer text = new StringBuffer();
text.append(": ");
double[] predValues = currentNode.getValues();
for (int i=0; i" + "S" + split.orderAdded +
" [style=dotted]\n");
text.append("S" + split.orderAdded + " [label=\"" + split.orderAdded + ": " +
split.attributeString() + "\"]\n");
for (int i=0; i" + "S" + split.orderAdded + "P" + i +
" [label=\"" + Utils.backQuoteChars(split.comparisonString(i)) + "\"]\n");
graphTraverse(child, text, split.orderAdded, i);
}
}
}
}
/**
* Returns the legend of the tree, describing how results are to be interpreted.
*
* @return a string containing the legend of the classifier
*/
public String legend() {
Attribute classAttribute = null;
if (m_trainInstances == null) return "";
try {classAttribute = m_trainInstances.classAttribute();} catch (Exception x){};
if (m_numOfClasses == 1) {
return ("-ve = " + classAttribute.value(0)
+ ", +ve = " + classAttribute.value(1));
} else {
StringBuffer text = new StringBuffer();
for (int i=0; i0) text.append(", ");
text.append(classAttribute.value(i));
}
return text.toString();
}
}
/** option handling ******************************************************************/
/**
 * Tip text for the numOfBoostingIterations property.
 *
 * @return tip text for this property suitable for
 *         displaying in the explorer/experimenter gui
 */
public String numOfBoostingIterationsTipText() {
  final String tip =
      "The number of boosting iterations to use, which determines the size of the tree.";
  return tip;
}
/**
 * Reports the configured number of boosting iterations.
 *
 * @return the current value of the boosting-iterations setting
 */
public int getNumOfBoostingIterations() {
  final int iterations = m_boostingIterations;
  return iterations;
}
/**
 * Stores the number of boosting iterations to use. The value is stored
 * as given; no range check is performed here.
 *
 * @param b the number of boosting iterations to use
 */
public void setNumOfBoostingIterations(int b) {
  this.m_boostingIterations = b;
}
/**
 * Returns an enumeration describing the available options: the -B option
 * defined here, followed by every option listed by the superclass.
 *
 * The synopsis of -B names its argument, matching the option list in the
 * class documentation; previously the synopsis was just "-B " with no
 * argument placeholder.
 *
 * @return an enumeration of all the available options
 */
public Enumeration listOptions() {
  Vector newVector = new Vector(1);
  newVector.addElement(new Option(
      "\tNumber of boosting iterations.\n"
      + "\t(Default = 10)",
      "B", 1, "-B <number of boosting iterations>"));

  // append whatever the superclass offers
  Enumeration enu = super.listOptions();
  while (enu.hasMoreElements()) {
    newVector.addElement(enu.nextElement());
  }
  return newVector.elements();
}
/**
 * Parses a given list of options. Valid options are:
 *
 * -B num
 * Set the number of boosting iterations
 * (default 10)
 *
 * When -B is absent the current number of iterations is left unchanged.
 * The remaining options are handed to the superclass, after which any
 * option still left over is reported as an error.
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
  final String iterationsArg = Utils.getOption('B', options);
  if (iterationsArg.length() > 0) {
    final int iterations = Integer.parseInt(iterationsArg);
    setNumOfBoostingIterations(iterations);
  }

  super.setOptions(options);

  Utils.checkForRemainingOptions(options);
}
/**
 * Gets the current settings of LADTree: "-B" with the number of boosting
 * iterations, followed by the options reported by the superclass.
 *
 * Fix over the previous version: after copying the superclass options into
 * the result, a padding loop restarted at index 2 (the write position was
 * never advanced past the copied entries) and overwrote every superclass
 * option with an empty string. The array is sized exactly, so no padding
 * is needed and the loop has been removed.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String[] getOptions() {
  // query the superclass once; its options go after our own
  String[] superOptions = super.getOptions();
  String[] options = new String[2 + superOptions.length];

  int current = 0;
  options[current++] = "-B";
  options[current++] = "" + getNumOfBoostingIterations();

  System.arraycopy(superOptions, 0, options, current, superOptions.length);
  return options;
}
/** additional measures ***************************************************************/
/**
 * Additional measure: total node count of the tree, as computed by
 * {@code numOfAllNodes} starting from the root.
 *
 * @return the tree size
 */
public double measureTreeSize() {
  final double totalNodes = numOfAllNodes(m_root);
  return totalNodes;
}
/**
 * Additional measure: the value of {@code numOfPredictionNodes} for the
 * tree rooted at {@code m_root}.
 *
 * @return the number of prediction nodes
 */
public double measureNumLeaves() {
  final double predictionNodes = numOfPredictionNodes(m_root);
  return predictionNodes;
}
/**
 * Additional measure: the value of {@code numOfLeafNodes} for the tree
 * rooted at {@code m_root}.
 *
 * @return the number of leaf nodes
 */
public double measureNumPredictionLeaves() {
  final double leafNodes = numOfLeafNodes(m_root);
  return leafNodes;
}
/**
 * Additional measure: the nodes-expanded counter, widened to double.
 *
 * @return the number of nodes expanded during search
 */
public double measureNodesExpanded() {
  return (double) m_nodesExpanded;
}
/**
 * Additional measure: the examples-counted counter, widened to double.
 *
 * @return the number of examples "counted" during search
 */
public double measureExamplesCounted() {
  return (double) m_examplesCounted;
}
/**
 * Lists the names of the additional measures this classifier can report,
 * in a fixed order.
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {
  final String[] measureNames = {
    "measureTreeSize",
    "measureNumLeaves",
    "measureNumPredictionLeaves",
    "measureNodesExpanded",
    "measureExamplesCounted"
  };

  Vector names = new Vector(5);
  for (int i = 0; i < measureNames.length; i++) {
    names.addElement(measureNames[i]);
  }
  return names.elements();
}
/**
 * Returns the value of the named measure. Names are matched ignoring case.
 *
 * Fix over the previous version: the error message for an unknown measure
 * named the wrong classifier ("ADTree"); it now says "LADTree". The
 * parameter documentation also used a name that did not match the
 * signature.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @exception IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  if (additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
    return measureTreeSize();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureNodesExpanded")) {
    return measureNodesExpanded();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureNumLeaves")) {
    return measureNumLeaves();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureNumPredictionLeaves")) {
    return measureNumPredictionLeaves();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureExamplesCounted")) {
    return measureExamplesCounted();
  }
  throw new IllegalArgumentException(additionalMeasureName
                                     + " not supported (LADTree)");
}
/**
* Returns the number of prediction nodes in a tree.
*
* @param root the root of the tree being measured
* @return tree size in number of prediction nodes
*/
protected int numOfPredictionNodes(PredictionNode root) {
int numSoFar = 0;
if (root != null) {
numSoFar++;
for (Enumeration e = root.children(); e.hasMoreElements(); ) {
Splitter split = (Splitter) e.nextElement();
for (int i=0; i 0) {
for (Enumeration e = root.children(); e.hasMoreElements(); ) {
Splitter split = (Splitter) e.nextElement();
for (int i=0; i=0; i--) {
Instance inst = test.instance(i);
try {
if (classifyInstance(inst) != inst.classValue())
error++;
} catch (Exception e) { error++;}
}
return error;
}
/**
 * Merges another tree into this one by merging the two root prediction
 * nodes. The existing contract states that the tree passed in is left
 * untouched (cloned); that depends on {@code PredictionNode.merge}. No
 * check is made that the training instances of the two trees are
 * compatible - strange things could occur if they are not.
 *
 * @param mergeWith the tree to merge with
 * @exception Exception if either tree has no root yet, or the two trees
 *                      were built for a different number of classes
 */
public void merge(LADTree mergeWith) throws Exception {

  // both trees need a root; this tree is checked first
  if (m_root == null) {
    throw new Exception("Trying to merge an uninitialized tree");
  }
  if (mergeWith.m_root == null) {
    throw new Exception("Trying to merge an uninitialized tree");
  }

  final boolean sameClassCount = (m_numOfClasses == mergeWith.m_numOfClasses);
  if (!sameClassCount) {
    throw new Exception("Trees not suitable for merge - "
                        + "different sized prediction nodes");
  }

  m_root.merge(mergeWith.m_root);
}
/**
 * Reports the kind of graph this classifier can be drawn as.
 *
 * @return Drawable.TREE
 */
public int graphType() {
  final int type = Drawable.TREE;
  return type;
}
/**
 * Returns the revision string, extracted from the embedded revision
 * keyword by {@code RevisionUtils.extract}.
 *
 * @return the revision
 */
public String getRevision() {
  final String keyword = "$Revision: 10279 $";
  return RevisionUtils.extract(keyword);
}
/**
 * Returns default capabilities of the classifier: starting from the
 * superclass capabilities with everything disabled, it enables nominal,
 * numeric and date attributes, missing values, a nominal class and
 * missing class values.
 *
 * @return the capabilities of this classifier
 */
public Capabilities getCapabilities() {
  final Capabilities caps = super.getCapabilities();
  caps.disableAll();

  // enabled in the order listed: attribute capabilities, then class ones
  final Capability[] supported = {
    Capability.NOMINAL_ATTRIBUTES,
    Capability.NUMERIC_ATTRIBUTES,
    Capability.DATE_ATTRIBUTES,
    Capability.MISSING_VALUES,
    Capability.NOMINAL_CLASS,
    Capability.MISSING_CLASS_VALUES
  };
  for (int i = 0; i < supported.length; i++) {
    caps.enable(supported[i]);
  }
  return caps;
}
/**
 * Main method for testing this class: creates a fresh LADTree and passes
 * it, together with the command-line options, to {@code runClassifier}.
 *
 * @param argv the options
 */
public static void main(String [] argv) {
  final LADTree classifier = new LADTree();
  runClassifier(classifier, argv);
}
}