/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LMT.java
* Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees;
import weka.classifiers.Classifier;
import weka.classifiers.trees.j48.C45ModelSelection;
import weka.classifiers.trees.j48.ModelSelection;
import weka.classifiers.trees.lmt.LMTNode;
import weka.classifiers.trees.lmt.ResidualModelSelection;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import java.util.Enumeration;
import java.util.Vector;
/**
* Classifier for building 'logistic model trees', which are classification trees with logistic regression functions at the leaves. The algorithm can deal with binary and multi-class target variables, numeric and nominal attributes and missing values.
*
* For more information see:
*
* Niels Landwehr, Mark Hall, Eibe Frank (2005). Logistic Model Trees. Machine Learning. 59(1-2):161-205.
*
* Marc Sumner, Eibe Frank, Mark Hall: Speeding up Logistic Model Tree Induction. In: 9th European Conference on Principles and Practice of Knowledge Discovery in Databases, 675-683, 2005.
*
*
* BibTeX:
*
* @article{Landwehr2005,
* author = {Niels Landwehr and Mark Hall and Eibe Frank},
* journal = {Machine Learning},
* number = {1-2},
* pages = {161-205},
* title = {Logistic Model Trees},
* volume = {59},
* year = {2005}
* }
*
* @inproceedings{Sumner2005,
* author = {Marc Sumner and Eibe Frank and Mark Hall},
* booktitle = {9th European Conference on Principles and Practice of Knowledge Discovery in Databases},
* pages = {675-683},
* publisher = {Springer},
* title = {Speeding up Logistic Model Tree Induction},
* year = {2005}
* }
*
*
*
* Valid options are:
*
* -B
* Binary splits (convert nominal attributes to binary ones)
*
* -R
* Split on residuals instead of class values
*
* -C
* Use cross-validation for boosting at all nodes (i.e., disable heuristic)
*
* -P
* Use error on probabilities instead of misclassification error for stopping criterion of LogitBoost.
*
* -I <numIterations>
* Set fixed number of iterations for LogitBoost (instead of using cross-validation)
*
* -M <numInstances>
* Set minimum number of instances at which a node can be split (default 15)
*
* -W <beta>
* Set beta for weight trimming for LogitBoost. Set to 0 (default) for no weight trimming.
*
* -A
* The AIC is used to choose the best iteration.
*
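* A minimal usage sketch (the ARFF path and option values below are
* illustrative assumptions, not defaults of this class):
*
*   Instances data = new ConverterUtils.DataSource("weather.arff").getDataSet();
*   data.setClassIndex(data.numAttributes() - 1);
*   LMT lmt = new LMT();
*   lmt.setOptions(new String[]{"-I", "10", "-M", "30"}); // fixed LogitBoost iterations, larger split minimum
*   lmt.buildClassifier(data);
*   double[] dist = lmt.distributionForInstance(data.instance(0));
*
* (ConverterUtils.DataSource above is weka.core.converters.ConverterUtils.DataSource.)
*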
*
* @author Niels Landwehr
* @author Marc Sumner
* @version $Revision: 5535 $
*/
public class LMT
extends Classifier
implements OptionHandler, AdditionalMeasureProducer, Drawable,
TechnicalInformationHandler {
/** for serialization */
static final long serialVersionUID = -1113212459618104943L;
/** Filter to replace missing values*/
protected ReplaceMissingValues m_replaceMissing;
/** Filter to replace nominal attributes*/
protected NominalToBinary m_nominalToBinary;
/** root of the logistic model tree*/
protected LMTNode m_tree;
/** use heuristic that determines the number of LogitBoost iterations only once in the beginning?*/
protected boolean m_fastRegression;
/** convert nominal attributes to binary ?*/
protected boolean m_convertNominal;
/** split on residuals?*/
protected boolean m_splitOnResiduals;
/**use error on probabilities instead of misclassification for stopping criterion of LogitBoost?*/
protected boolean m_errorOnProbabilities;
/**minimum number of instances at which a node is considered for splitting*/
protected int m_minNumInstances;
/**if non-zero, use fixed number of iterations for LogitBoost*/
protected int m_numBoostingIterations;
/**Threshold for trimming weights. Instances with a weight lower than this (as a percentage
* of total weights) are not included in the regression fit.
**/
protected double m_weightTrimBeta;
/** If true, the AIC is used to choose the best LogitBoost iteration*/
private boolean m_useAIC = false;
/**
* Creates an instance of LMT with standard options
*/
public LMT() {
m_fastRegression = true;
m_numBoostingIterations = -1;
m_minNumInstances = 15;
m_weightTrimBeta = 0;
m_useAIC = false;
}
/**
* Returns default capabilities of the classifier.
*
* @return the capabilities of this classifier
*/
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
result.disableAll();
// attributes
result.enable(Capability.NOMINAL_ATTRIBUTES);
result.enable(Capability.NUMERIC_ATTRIBUTES);
result.enable(Capability.DATE_ATTRIBUTES);
result.enable(Capability.MISSING_VALUES);
// class
result.enable(Capability.NOMINAL_CLASS);
result.enable(Capability.MISSING_CLASS_VALUES);
return result;
}
/**
* Builds the classifier.
*
* @param data the data to train with
* @throws Exception if classifier can't be built successfully
*/
public void buildClassifier(Instances data) throws Exception{
// can classifier handle the data?
getCapabilities().testWithFail(data);
// remove instances with missing class
Instances filteredData = new Instances(data);
filteredData.deleteWithMissingClass();
//replace missing values
m_replaceMissing = new ReplaceMissingValues();
m_replaceMissing.setInputFormat(filteredData);
filteredData = Filter.useFilter(filteredData, m_replaceMissing);
//possibly convert nominal attributes globally
if (m_convertNominal) {
m_nominalToBinary = new NominalToBinary();
m_nominalToBinary.setInputFormat(filteredData);
filteredData = Filter.useFilter(filteredData, m_nominalToBinary);
}
int minNumInstances = 2;
//create ModelSelection object, either for splits on the residuals or for splits on the class value
ModelSelection modSelection;
if (m_splitOnResiduals) {
modSelection = new ResidualModelSelection(minNumInstances);
} else {
modSelection = new C45ModelSelection(minNumInstances, filteredData);
}
//create tree root
m_tree = new LMTNode(modSelection, m_numBoostingIterations, m_fastRegression,
m_errorOnProbabilities, m_minNumInstances, m_weightTrimBeta, m_useAIC);
//build tree
m_tree.buildClassifier(filteredData);
if (modSelection instanceof C45ModelSelection) ((C45ModelSelection)modSelection).cleanup();
}
/**
* Returns class probabilities for an instance.
*
* @param instance the instance to compute the distribution for
* @return the class probabilities
* @throws Exception if distribution can't be computed successfully
*/
public double [] distributionForInstance(Instance instance) throws Exception {
//replace missing values
m_replaceMissing.input(instance);
instance = m_replaceMissing.output();
//possibly convert nominal attributes
if (m_convertNominal) {
m_nominalToBinary.input(instance);
instance = m_nominalToBinary.output();
}
return m_tree.distributionForInstance(instance);
}
/**
* Classifies an instance.
*
* @param instance the instance to classify
* @return the classification
* @throws Exception if instance can't be classified successfully
*/
public double classifyInstance(Instance instance) throws Exception {
double maxProb = -1;
int maxIndex = 0;
//classify by maximum probability
double[] probs = distributionForInstance(instance);
for (int j = 0; j < instance.numClasses(); j++) {
if (Utils.gr(probs[j], maxProb)) {
maxIndex = j;
maxProb = probs[j];
}
}
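// (equivalent to weka.core.Utils.maxIndex(probs), which likewise
// returns the first index in case of ties)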
return (double)maxIndex;
}
/**
* Returns a description of the classifier.
*
* @return a string representation of the classifier
*/
public String toString() {
if (m_tree!=null) {
return "Logistic model tree \n------------------\n" + m_tree.toString();
} else {
return "No tree build";
}
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration listOptions() {
Vector newVector = new Vector(8);
newVector.addElement(new Option("\tBinary splits (convert nominal attributes to binary ones)",
"B", 0, "-B"));
newVector.addElement(new Option("\tSplit on residuals instead of class values",
"R", 0, "-R"));
newVector.addElement(new Option("\tUse cross-validation for boosting at all nodes (i.e., disable heuristic)",
"C", 0, "-C"));
newVector.addElement(new Option("\tUse error on probabilities instead of misclassification error "+
"for stopping criterion of LogitBoost.",
"P", 0, "-P"));
newVector.addElement(new Option("\tSet fixed number of iterations for LogitBoost (instead of using "+
"cross-validation)",
"I",1,"-I "));
newVector.addElement(new Option("\tSet minimum number of instances at which a node can be split (default 15)",
"M",1,"-M "));
newVector.addElement(new Option("\tSet beta for weight trimming for LogitBoost. Set to 0 (default) for no weight trimming.",
"W",1,"-W "));
newVector.addElement(new Option("\tThe AIC is used to choose the best iteration.",
"A", 0, "-A"));
return newVector.elements();
}
/**
* Parses a given list of options.
*
* Valid options are:
*
* -B
* Binary splits (convert nominal attributes to binary ones)
*
* -R
* Split on residuals instead of class values
*
* -C
* Use cross-validation for boosting at all nodes (i.e., disable heuristic)
*
* -P
* Use error on probabilities instead of misclassification error for stopping criterion of LogitBoost.
*
* -I <numIterations>
* Set fixed number of iterations for LogitBoost (instead of using cross-validation)
*
* -M <numInstances>
* Set minimum number of instances at which a node can be split (default 15)
*
* -W <beta>
* Set beta for weight trimming for LogitBoost. Set to 0 (default) for no weight trimming.
*
* -A
* The AIC is used to choose the best iteration.
*
*
* @param options the list of options as an array of strings
* @throws Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
setConvertNominal(Utils.getFlag('B', options));
setSplitOnResiduals(Utils.getFlag('R', options));
setFastRegression(!Utils.getFlag('C', options));
setErrorOnProbabilities(Utils.getFlag('P', options));
String optionString = Utils.getOption('I', options);
if (optionString.length() != 0) {
setNumBoostingIterations(Integer.parseInt(optionString));
}
optionString = Utils.getOption('M', options);
if (optionString.length() != 0) {
setMinNumInstances(Integer.parseInt(optionString));
}
optionString = Utils.getOption('W', options);
if (optionString.length() != 0) {
setWeightTrimBeta(Double.parseDouble(optionString));
}
setUseAIC(Utils.getFlag('A', options));
Utils.checkForRemainingOptions(options);
}
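// Example (hypothetical values, mirroring the option list above):
//   LMT lmt = new LMT();
//   lmt.setOptions(new String[]{"-B", "-I", "10", "-W", "0.1"});
// This leaves lmt with binary splits, 10 fixed LogitBoost iterations
// and weight trimming at beta = 0.1.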
/**
* Gets the current settings of the Classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
public String[] getOptions() {
String[] options = new String[11];
int current = 0;
if (getConvertNominal()) {
options[current++] = "-B";
}
if (getSplitOnResiduals()) {
options[current++] = "-R";
}
if (!getFastRegression()) {
options[current++] = "-C";
}
if (getErrorOnProbabilities()) {
options[current++] = "-P";
}
options[current++] = "-I";
options[current++] = ""+getNumBoostingIterations();
options[current++] = "-M";
options[current++] = ""+getMinNumInstances();
options[current++] = "-W";
options[current++] = ""+getWeightTrimBeta();
if (getUseAIC()) {
options[current++] = "-A";
}
while (current < options.length) {
options[current++] = "";
}
return options;
}
/**
* Get the value of weightTrimBeta.
*/
public double getWeightTrimBeta(){
return m_weightTrimBeta;
}
/**
* Get the value of useAIC.
*
* @return Value of useAIC.
*/
public boolean getUseAIC(){
return m_useAIC;
}
/**
* Set the value of weightTrimBeta.
*/
public void setWeightTrimBeta(double n){
m_weightTrimBeta = n;
}
/**
* Set the value of useAIC.
*
* @param c Value to assign to useAIC.
*/
public void setUseAIC(boolean c){
m_useAIC = c;
}
/**
* Get the value of convertNominal.
*
* @return Value of convertNominal.
*/
public boolean getConvertNominal(){
return m_convertNominal;
}
/**
* Get the value of splitOnResiduals.
*
* @return Value of splitOnResiduals.
*/
public boolean getSplitOnResiduals(){
return m_splitOnResiduals;
}
/**
* Get the value of fastRegression.
*
* @return Value of fastRegression.
*/
public boolean getFastRegression(){
return m_fastRegression;
}
/**
* Get the value of errorOnProbabilities.
*
* @return Value of errorOnProbabilities.
*/
public boolean getErrorOnProbabilities(){
return m_errorOnProbabilities;
}
/**
* Get the value of numBoostingIterations.
*
* @return Value of numBoostingIterations.
*/
public int getNumBoostingIterations(){
return m_numBoostingIterations;
}
/**
* Get the value of minNumInstances.
*
* @return Value of minNumInstances.
*/
public int getMinNumInstances(){
return m_minNumInstances;
}
/**
* Set the value of convertNominal.
*
* @param c Value to assign to convertNominal.
*/
public void setConvertNominal(boolean c){
m_convertNominal = c;
}
/**
* Set the value of splitOnResiduals.
*
* @param c Value to assign to splitOnResiduals.
*/
public void setSplitOnResiduals(boolean c){
m_splitOnResiduals = c;
}
/**
* Set the value of fastRegression.
*
* @param c Value to assign to fastRegression.
*/
public void setFastRegression(boolean c){
m_fastRegression = c;
}
/**
* Set the value of errorOnProbabilities.
*
* @param c Value to assign to errorOnProbabilities.
*/
public void setErrorOnProbabilities(boolean c){
m_errorOnProbabilities = c;
}
/**
* Set the value of numBoostingIterations.
*
* @param c Value to assign to numBoostingIterations.
*/
public void setNumBoostingIterations(int c){
m_numBoostingIterations = c;
}
/**
* Set the value of minNumInstances.
*
* @param c Value to assign to minNumInstances.
*/
public void setMinNumInstances(int c){
m_minNumInstances = c;
}
/**
* Returns the type of graph this classifier
* represents.
* @return Drawable.TREE
*/
public int graphType() {
return Drawable.TREE;
}
/**
* Returns graph describing the tree.
*
* @return the graph describing the tree
* @throws Exception if graph can't be computed
*/
public String graph() throws Exception {
return m_tree.graph();
}
/**
* Returns the size of the tree
* @return the size of the tree
*/
public int measureTreeSize(){
return m_tree.numNodes();
}
/**
* Returns the number of leaves in the tree
* @return the number of leaves in the tree
*/
public int measureNumLeaves(){
return m_tree.numLeaves();
}
/**
* Returns an enumeration of the additional measure names
* @return an enumeration of the measure names
*/
public Enumeration enumerateMeasures() {
Vector newVector = new Vector(2);
newVector.addElement("measureTreeSize");
newVector.addElement("measureNumLeaves");
return newVector.elements();
}
/**
* Returns the value of the named measure
* @param additionalMeasureName the name of the measure to query for its value
* @return the value of the named measure
* @throws IllegalArgumentException if the named measure is not supported
*/
public double getMeasure(String additionalMeasureName) {
if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
return measureTreeSize();
} else if (additionalMeasureName.compareToIgnoreCase("measureNumLeaves") == 0) {
return measureNumLeaves();
} else {
throw new IllegalArgumentException(additionalMeasureName
+ " not supported (LMT)");
}
}
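// Example (assuming an already trained LMT instance named lmt):
//   double size = lmt.getMeasure("measureTreeSize");
//   double leaves = lmt.getMeasure("measureNumLeaves");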
/**
* Returns a string describing classifier
* @return a description suitable for
* displaying in the explorer/experimenter gui
*/
public String globalInfo() {
return "Classifier for building 'logistic model trees', which are classification trees with "
+"logistic regression functions at the leaves. The algorithm can deal with binary and multi-class "
+"target variables, numeric and nominal attributes and missing values.\n\n"
+"For more information see: \n\n"
+ getTechnicalInformation().toString();
}
/**
* Returns an instance of a TechnicalInformation object, containing
* detailed information about the technical background of this class,
* e.g., paper reference or book this class is based on.
*
* @return the technical information about this class
*/
public TechnicalInformation getTechnicalInformation() {
TechnicalInformation result;
TechnicalInformation additional;
result = new TechnicalInformation(Type.ARTICLE);
result.setValue(Field.AUTHOR, "Niels Landwehr and Mark Hall and Eibe Frank");
result.setValue(Field.TITLE, "Logistic Model Trees");
result.setValue(Field.JOURNAL, "Machine Learning");
result.setValue(Field.YEAR, "2005");
result.setValue(Field.VOLUME, "59");
result.setValue(Field.PAGES, "161-205");
result.setValue(Field.NUMBER, "1-2");
additional = result.add(Type.INPROCEEDINGS);
additional.setValue(Field.AUTHOR, "Marc Sumner and Eibe Frank and Mark Hall");
additional.setValue(Field.TITLE, "Speeding up Logistic Model Tree Induction");
additional.setValue(Field.BOOKTITLE, "9th European Conference on Principles and Practice of Knowledge Discovery in Databases");
additional.setValue(Field.YEAR, "2005");
additional.setValue(Field.PAGES, "675-683");
additional.setValue(Field.PUBLISHER, "Springer");
return result;
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String convertNominalTipText() {
return "Convert all nominal attributes to binary ones before building the tree. "
+"This means that all splits in the final tree will be binary.";
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String splitOnResidualsTipText() {
return "Set splitting criterion based on the residuals of LogitBoost. "
+"There are two possible splitting criteria for LMT: the default is to use the C4.5 "
+"splitting criterion that uses information gain on the class variable. The other splitting "
+"criterion tries to improve the purity in the residuals produces when fitting the logistic "
+"regression functions. The choice of the splitting criterion does not usually affect classification "
+"accuracy much, but can produce different trees.";
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String fastRegressionTipText() {
return "Use heuristic that avoids cross-validating the number of Logit-Boost iterations at every node. "
+"When fitting the logistic regression functions at a node, LMT has to determine the number of LogitBoost "
+"iterations to run. Originally, this number was cross-validated at every node in the tree. "
+"To save time, this heuristic cross-validates the number only once and then uses that number at every "
+"node in the tree. Usually this does not decrease accuracy but improves runtime considerably.";
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String errorOnProbabilitiesTipText() {
return "Minimize error on probabilities instead of misclassification error when cross-validating the number "
+"of LogitBoost iterations. When set, the number of LogitBoost iterations is chosen that minimizes "
+"the root mean squared error instead of the misclassification error.";
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String numBoostingIterationsTipText() {
return "Set a fixed number of iterations for LogitBoost. If >= 0, this sets a fixed number of LogitBoost "
+"iterations that is used everywhere in the tree. If < 0, the number is cross-validated.";
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String minNumInstancesTipText() {
return "Set the minimum number of instances at which a node is considered for splitting. "
+"The default value is 15.";
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String weightTrimBetaTipText() {
return "Set the beta value used for weight trimming in LogitBoost. "
+"Only instances carrying (1 - beta)% of the weight from previous iteration "
+"are used in the next iteration. Set to 0 for no weight trimming. "
+"The default value is 0.";
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String useAICTipText() {
return "The AIC is used to determine when to stop LogitBoost iterations. "
+"The default is not to use AIC.";
}
/**
* Returns the revision string.
*
* @return the revision
*/
public String getRevision() {
return RevisionUtils.extract("$Revision: 5535 $");
}
/**
* Main method for testing this class
*
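* Example command line (the dataset path is an illustrative assumption):
* java weka.classifiers.trees.LMT -t weather.arff -I 10
*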
* @param argv the commandline options
*/
public static void main (String [] argv) {
runClassifier(new LMT(), argv);
}
}