All downloads are free. Search and download functionality uses the official Maven repository.

weka.classifiers.trees.LADTree Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA) is a machine learning workbench. This is the stable version; apart from bug fixes, it receives no further updates.

There is a newer version: 3.8.6
Show newest version
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    LADTree.java
 *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees;

import weka.classifiers.*;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.*;
import weka.classifiers.trees.adtree.ReferenceInstances;
import java.util.*;
import java.io.*;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

/**
 * Class for generating a multi-class alternating decision tree using the
 * LogitBoost strategy. For more info, see<br>
 * <br>
 * Geoffrey Holmes, Bernhard Pfahringer, Richard Kirkby, Eibe Frank, Mark Hall:
 * Multiclass alternating decision trees. In: ECML, 161-172, 2001.
 * <p>
 *
 * BibTeX:
 * <pre>
 * &#64;inproceedings{Holmes2001,
 *    author = {Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall},
 *    booktitle = {ECML},
 *    pages = {161-172},
 *    publisher = {Springer},
 *    title = {Multiclass alternating decision trees},
 *    year = {2001}
 * }
 * </pre>
 * <p>
 *
 * Valid options are:
 *
 * <pre> -B &lt;number of boosting iterations&gt;
 *  Number of boosting iterations.
 *  (Default = 10)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
* NOTE(review): this listing was recovered from a web page and is NOT compilable
* as shown. The line breaks inside the class were collapsed, so every line
* comment now swallows the code that follows it on the same line. All text
* between angle brackets was also stripped: loop bounds, less-than/greater-than
* comparisons and the Javadoc HTML markup. For example, the loops now read
* "for (int i=0; i Z_MAX)". Restore the class from the upstream Weka source
* rather than patching this copy by hand. It is presumably the 3.6 line, since
* it extends Classifier (confirm).
*
* Issues visible even in this copy (check them against upstream):
*  - boost() dereferences m_search_bestSplitter in its m_Debug print before the
*    null check that follows it. Debug mode can therefore throw a
*    NullPointerException when no split was found.
*  - getOptions() copies super.getOptions() into the array starting at index
*    "current" but never advances "current" afterwards. The padding loop that
*    follows then overwrites every inherited option (e.g. -D) with "".
*  - toString() prints (double) m_examplesCounted / (double) m_nodesExpanded.
*    This is NaN or Infinity when m_nodesExpanded is 0, and initClassifier()
*    resets it to 0.
*  - TwoWayNumericSplit.findSplit() seeds the class distribution with unit
*    counts (++) but moves inst.weight() between the bags while scanning.
*  - legend() silently ignores an exception from classAttribute() and would then
*    dereference a null classAttribute. Its "m_numOfClasses == 1" branch also
*    reads classAttribute.value(1), while initClassifier() sets m_numOfClasses
*    to instances.numClasses() (probably inherited from ADTree; verify).
*  - getMeasure() reports unsupported names with the message "(ADTree)".
* * * @author Richard Kirkby * @version $Revision: 10279 $ */ public class LADTree extends Classifier implements Drawable, AdditionalMeasureProducer, TechnicalInformationHandler { /** * For serialization */ private static final long serialVersionUID = -4940716114518300302L; // Constant from LogitBoost protected double Z_MAX = 4; // Number of classes protected int m_numOfClasses; // Instances as reference instances protected ReferenceInstances m_trainInstances; // Root of the tree protected PredictionNode m_root = null; // To keep track of the order in which splits are added protected int m_lastAddedSplitNum = 0; // Indices for numeric attributes protected int[] m_numericAttIndices; // Variables to keep track of best options protected double m_search_smallestLeastSquares; protected PredictionNode m_search_bestInsertionNode; protected Splitter m_search_bestSplitter; protected Instances m_search_bestPathInstances; // A collection of splitter nodes protected FastVector m_staticPotentialSplitters2way; // statistics protected int m_nodesExpanded = 0; protected int m_examplesCounted = 0; // options protected int m_boostingIterations = 10; /** * Returns a string describing classifier * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Class for generating a multi-class alternating decision tree using " + "the LogitBoost strategy. For more info, see\n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. 
* * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.INPROCEEDINGS); result.setValue(Field.AUTHOR, "Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall"); result.setValue(Field.TITLE, "Multiclass alternating decision trees"); result.setValue(Field.BOOKTITLE, "ECML"); result.setValue(Field.YEAR, "2001"); result.setValue(Field.PAGES, "161-172"); result.setValue(Field.PUBLISHER, "Springer"); return result; } /** helper classes ********************************************************************/ protected class LADInstance extends Instance { public double[] fVector; public double[] wVector; public double[] pVector; public double[] zVector; public LADInstance(Instance instance) { super(instance); // copy the instance setDataset(instance.dataset()); // preserve dataset // set up vectors fVector = new double[m_numOfClasses]; wVector = new double[m_numOfClasses]; pVector = new double[m_numOfClasses]; zVector = new double[m_numOfClasses]; // set initial probabilities double initProb = 1.0 / ((double) m_numOfClasses); for (int i=0; i Z_MAX) { // threshold zVector[i] = Z_MAX; } } else { zVector[i] = -1.0 / (1.0 - pVector[i]); if (zVector[i] < -Z_MAX) { // threshold zVector[i] = -Z_MAX; } } } } public double yVector(int index) { return (index == (int) classValue() ? 
1.0 : 0.0); } public Object copy() { LADInstance copy = new LADInstance((Instance) super.copy()); System.arraycopy(fVector, 0, copy.fVector, 0, fVector.length); System.arraycopy(wVector, 0, copy.wVector, 0, wVector.length); System.arraycopy(pVector, 0, copy.pVector, 0, pVector.length); System.arraycopy(zVector, 0, copy.zVector, 0, zVector.length); return copy; } public String toString() { StringBuffer text = new StringBuffer(); text.append(" * F("); for (int i=0; i= splitPoint) filteredInstances.addReference(inst); } } return filteredInstances; } public String attributeString() { return m_trainInstances.attribute(attIndex).name(); } public String comparisonString(int branchNum) { return ((branchNum == 0 ? "< " : ">= ") + Utils.doubleToString(splitPoint, 3)); } public boolean equalTo(Splitter compare) { if (compare instanceof TwoWayNumericSplit) { // test object type TwoWayNumericSplit compareSame = (TwoWayNumericSplit) compare; return (attIndex == compareSame.attIndex && splitPoint == compareSame.splitPoint); } else return false; } public void setChildForBranch(int branchNum, PredictionNode childPredictor) { children[branchNum] = childPredictor; } public PredictionNode getChildForBranch(int branchNum) { return children[branchNum]; } public Object clone() { // deep copy TwoWayNumericSplit clone = new TwoWayNumericSplit(attIndex, splitPoint); if (children[0] != null) clone.setChildForBranch(0, (PredictionNode) children[0].clone()); if (children[1] != null) clone.setChildForBranch(1, (PredictionNode) children[1].clone()); return clone; } private double findSplit(Instances instances, int index) throws Exception { double splitPoint = 0; double bestVal = Double.MAX_VALUE, currVal, currCutPoint; int numMissing = 0; double[][] distribution = new double[3][instances.numClasses()]; // Compute counts for all the values for (int i = 0; i < instances.numInstances(); i++) { Instance inst = instances.instance(i); if (!inst.isMissing(index)) { 
distribution[1][(int)inst.classValue()] ++; } else { distribution[2][(int)inst.classValue()] ++; numMissing++; } } // Sort instances instances.sort(index); // Make split counts for each possible split and evaluate for (int i = 0; i < instances.numInstances() - (numMissing + 1); i++) { Instance inst = instances.instance(i); Instance instPlusOne = instances.instance(i + 1); distribution[0][(int)inst.classValue()] += inst.weight(); distribution[1][(int)inst.classValue()] -= inst.weight(); if (Utils.sm(inst.value(index), instPlusOne.value(index))) { currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0; currVal = ContingencyTables.entropyConditionedOnRows(distribution); if (Utils.sm(currVal, bestVal)) { splitPoint = currCutPoint; bestVal = currVal; } } } return splitPoint; } } /** * Sets up the tree ready to be trained. * * @param instances the instances to train the tree with * @exception Exception if training data is unsuitable */ public void initClassifier(Instances instances) throws Exception { // clear stats m_nodesExpanded = 0; m_examplesCounted = 0; m_lastAddedSplitNum = 0; m_numOfClasses = instances.numClasses(); // make sure training data is suitable if (instances.checkForStringAttributes()) { throw new Exception("Can't handle string attributes!"); } if (!instances.classAttribute().isNominal()) { throw new Exception("Class must be nominal!"); } // create training set (use LADInstance class) m_trainInstances = new ReferenceInstances(instances, instances.numInstances()); for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { Instance inst = (Instance) e.nextElement(); if (!inst.classIsMissing()) { LADInstance adtInst = new LADInstance(inst); m_trainInstances.addReference(adtInst); adtInst.setDataset(m_trainInstances); } } // create the root prediction node m_root = new PredictionNode(new double[m_numOfClasses]); // pre-calculate what we can generateStaticPotentialSplittersAndNumericIndices(); } public void next(int iteration) 
throws Exception { boost(); } public void done() throws Exception {} /** * Performs a single boosting iteration. * Will add a new splitter node and two prediction nodes to the tree * (unless merging takes place). * * @exception Exception if try to boost without setting up tree first */ private void boost() throws Exception { if (m_trainInstances == null) throw new Exception("Trying to boost with no training data"); // perform the search searchForBestTest(); if (m_Debug) { System.out.println("Best split found: " + m_search_bestSplitter.getNumOfBranches() + "-way split on " + m_search_bestSplitter.attributeString() //+ "\nsmallestLeastSquares = " + m_search_smallestLeastSquares); + "\nBestGain = " + m_search_smallestLeastSquares); } if (m_search_bestSplitter == null) return; // handle empty instances // create the new nodes for the tree, updating the weights for (int i=0; i m_search_smallestLeastSquares) { if (m_Debug) { System.out.print(" (best so far)"); } m_search_smallestLeastSquares = leastSquares; m_search_bestInsertionNode = currentNode; m_search_bestSplitter = split; m_search_bestPathInstances = instances; } if (m_Debug) { System.out.print("\n"); } } private void evaluateNumericSplit(PredictionNode currentNode, Instances instances, int attIndex) { double[] splitAndLS = findNumericSplitpointAndLS(instances, attIndex); double gain = leastSquaresNonMissing(instances,attIndex) - splitAndLS[1]; if (m_Debug) { //System.out.println("Instances considered are: " + instances); System.out.print("Numeric split on " + instances.attribute(attIndex).name() + " has leastSquares value of " //+ Utils.doubleToString(splitAndLS[1],3)); + Utils.doubleToString(gain,3)); } if (gain > m_search_smallestLeastSquares) { if (m_Debug) { System.out.print(" (best so far)"); } m_search_smallestLeastSquares = gain; //splitAndLS[1]; m_search_bestInsertionNode = currentNode; m_search_bestSplitter = new TwoWayNumericSplit(attIndex, splitAndLS[0]);; m_search_bestPathInstances = instances; } if 
(m_Debug) { System.out.print("\n"); } } private double[] findNumericSplitpointAndLS(Instances instances, int attIndex) { double allLS = leastSquares(instances); // all instances in right subset double[] term1L = new double[m_numOfClasses]; double[] term2L = new double[m_numOfClasses]; double[] term3L = new double[m_numOfClasses]; double[] meanNumL = new double[m_numOfClasses]; double[] meanDenL = new double[m_numOfClasses]; double[] term1R = new double[m_numOfClasses]; double[] term2R = new double[m_numOfClasses]; double[] term3R = new double[m_numOfClasses]; double[] meanNumR = new double[m_numOfClasses]; double[] meanDenR = new double[m_numOfClasses]; double temp1, temp2, temp3; double[] classMeans = new double[m_numOfClasses]; double[] classTotals = new double[m_numOfClasses]; // fill up RHS for (int j=0; j instances.instance(i).value(attIndex)) newSplit = true; else newSplit = false; LADInstance inst = (LADInstance) instances.instance(i); leastSquares = 0.0; for (int j=0; j 0 ? numerator : 0;// / denominator; } private double leastSquaresNonMissing(Instances instances, int attIndex) { double numerator=0, denominator=0, w, t; double[] classMeans = new double[m_numOfClasses]; double[] classTotals = new double[m_numOfClasses]; for (int i=0; i 0 ? numerator : 0;// / denominator; } private double[] calcPredictionValues(Instances instances) { double[] classMeans = new double[m_numOfClasses]; double meansSum = 0; double multiplier = ((double) (m_numOfClasses-1)) / ((double) (m_numOfClasses)); double[] classTotals = new double[m_numOfClasses]; for (int i=0; i 0.0) Utils.normalize(distribution, sum); return distribution; } /** * Returns the class prediction values (votes) for an instance. 
* * @param inst the instance * @param currentNode the root of the tree to get the values from * @param currentValues the current values before adding the values contained in the * subtree * @return the class prediction values (votes) */ private double[] predictionValuesForInstance(Instance inst, PredictionNode currentNode, double[] currentValues) { double[] predValues = currentNode.getValues(); for (int i=0; i= 0) currentValues = predictionValuesForInstance(inst, split.getChildForBranch(branch), currentValues); } return currentValues; } /** model output functions ************************************************************/ /** * Returns a description of the classifier. * * @return a string containing a description of the classifier */ public String toString() { String className = getClass().getName(); if (m_root == null) return (className +" not built yet"); else { return (className + ":\n\n" + toString(m_root, 1) + "\nLegend: " + legend() + "\n#Tree size (total): " + numOfAllNodes(m_root) + "\n#Tree size (number of predictor nodes): " + numOfPredictionNodes(m_root) + "\n#Leaves (number of predictor nodes): " + numOfLeafNodes(m_root) + "\n#Expanded nodes: " + m_nodesExpanded + "\n#Processed examples: " + m_examplesCounted + "\n#Ratio e/n: " + ((double)m_examplesCounted/(double)m_nodesExpanded) ); } } /** * Traverses the tree, forming a string that describes it. 
* * @param currentNode the current node under investigation * @param level the current level in the tree * @return the string describing the subtree */ private String toString(PredictionNode currentNode, int level) { StringBuffer text = new StringBuffer(); text.append(": "); double[] predValues = currentNode.getValues(); for (int i=0; i" + "S" + split.orderAdded + " [style=dotted]\n"); text.append("S" + split.orderAdded + " [label=\"" + split.orderAdded + ": " + split.attributeString() + "\"]\n"); for (int i=0; i" + "S" + split.orderAdded + "P" + i + " [label=\"" + Utils.backQuoteChars(split.comparisonString(i)) + "\"]\n"); graphTraverse(child, text, split.orderAdded, i); } } } } /** * Returns the legend of the tree, describing how results are to be interpreted. * * @return a string containing the legend of the classifier */ public String legend() { Attribute classAttribute = null; if (m_trainInstances == null) return ""; try {classAttribute = m_trainInstances.classAttribute();} catch (Exception x){}; if (m_numOfClasses == 1) { return ("-ve = " + classAttribute.value(0) + ", +ve = " + classAttribute.value(1)); } else { StringBuffer text = new StringBuffer(); for (int i=0; i0) text.append(", "); text.append(classAttribute.value(i)); } return text.toString(); } } /** option handling ******************************************************************/ /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numOfBoostingIterationsTipText() { return "The number of boosting iterations to use, which determines the size of the tree."; } /** * Gets the number of boosting iterations. * * @return the number of boosting iterations */ public int getNumOfBoostingIterations() { return m_boostingIterations; } /** * Sets the number of boosting iterations. 
* * @param b the number of boosting iterations to use */ public void setNumOfBoostingIterations(int b) { m_boostingIterations = b; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(1); newVector.addElement(new Option( "\tNumber of boosting iterations.\n" +"\t(Default = 10)", "B", 1,"-B ")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Parses a given list of options. Valid options are:

* * -B num
* Set the number of boosting iterations * (default 10)

* * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String bString = Utils.getOption('B', options); if (bString.length() != 0) setNumOfBoostingIterations(Integer.parseInt(bString)); super.setOptions(options); Utils.checkForRemainingOptions(options); } /** * Gets the current settings of LADTree. * * @return an array of strings suitable for passing to setOptions() */ public String[] getOptions() { String[] options = new String[2 + super.getOptions().length]; int current = 0; options[current++] = "-B"; options[current++] = "" + getNumOfBoostingIterations(); System.arraycopy(super.getOptions(), 0, options, current, super.getOptions().length); while (current < options.length) options[current++] = ""; return options; } /** additional measures ***************************************************************/ /** * Calls measure function for tree size. * * @return the tree size, as counted by numOfAllNodes */ public double measureTreeSize() { return numOfAllNodes(m_root); } /** * Returns the number of prediction nodes in the tree (reported under the "leaves" measure name). * * @return the count returned by numOfPredictionNodes */ public double measureNumLeaves() { return numOfPredictionNodes(m_root); } /** * Returns the leaf count of the tree (labelled "#Leaves" in toString()). * * @return the count returned by numOfLeafNodes */ public double measureNumPredictionLeaves() { return numOfLeafNodes(m_root); } /** * Returns the number of nodes expanded. * * @return the number of nodes expanded during search */ public double measureNodesExpanded() { return m_nodesExpanded; } /** * Returns the number of examples "counted". * * @return the number of examples counted (m_examplesCounted) */ public double measureExamplesCounted() { return m_examplesCounted; } /** * Returns an enumeration of the additional measure names. 
* * @return an enumeration of the measure names */ public Enumeration enumerateMeasures() { Vector newVector = new Vector(5); newVector.addElement("measureTreeSize"); newVector.addElement("measureNumLeaves"); newVector.addElement("measureNumPredictionLeaves"); newVector.addElement("measureNodesExpanded"); newVector.addElement("measureExamplesCounted"); return newVector.elements(); } /** * Returns the value of the named measure. * * @param additionalMeasureName the name of the measure to query for its value (matched case-insensitively) * @return the value of the named measure * @exception IllegalArgumentException if the named measure is not supported */ public double getMeasure(String additionalMeasureName) { if (additionalMeasureName.equalsIgnoreCase("measureTreeSize")) { return measureTreeSize(); } else if (additionalMeasureName.equalsIgnoreCase("measureNodesExpanded")) { return measureNodesExpanded(); } else if (additionalMeasureName.equalsIgnoreCase("measureNumLeaves")) { return measureNumLeaves(); } else if (additionalMeasureName.equalsIgnoreCase("measureNumPredictionLeaves")) { return measureNumPredictionLeaves(); } else if (additionalMeasureName.equalsIgnoreCase("measureExamplesCounted")) { return measureExamplesCounted(); } else {throw new IllegalArgumentException(additionalMeasureName + " not supported (ADTree)"); } } /** * Returns the number of prediction nodes in a tree. 
* * @param root the root of the tree being measured * @return tree size in number of prediction nodes */ protected int numOfPredictionNodes(PredictionNode root) { int numSoFar = 0; if (root != null) { numSoFar++; for (Enumeration e = root.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); for (int i=0; i 0) { for (Enumeration e = root.children(); e.hasMoreElements(); ) { Splitter split = (Splitter) e.nextElement(); for (int i=0; i=0; i--) { Instance inst = test.instance(i); try { if (classifyInstance(inst) != inst.classValue()) error++; } catch (Exception e) { error++;} } return error; } /** * Merges two trees together. Modifies the tree being acted on, leaving tree passed * as a parameter untouched (cloned). Does not check to see whether training instances * are compatible - strange things could occur if they are not. * * @param mergeWith the tree to merge with * @exception Exception if merge could not be performed */ public void merge(LADTree mergeWith) throws Exception { if (m_root == null || mergeWith.m_root == null) throw new Exception("Trying to merge an uninitialized tree"); if (m_numOfClasses != mergeWith.m_numOfClasses) throw new Exception("Trees not suitable for merge - " + "different sized prediction nodes"); m_root.merge(mergeWith.m_root); } /** * Returns the type of graph this classifier * represents. * @return Drawable.TREE */ public int graphType() { return Drawable.TREE; } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 10279 $"); } /** * Returns default capabilities of the classifier. 
* * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); return result; } /** * Main method for testing this class. * * @param argv the options */ public static void main(String [] argv) { runClassifier(new LADTree(), argv); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy