// weka.classifiers.trees.DecisionStump (source listing; Maven / Gradle / Ivy)
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* DecisionStump.java
* Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Sourcable;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.ContingencyTables;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
* Class for building and using a decision stump. Usually used in conjunction with a boosting algorithm. Does regression (based on mean-squared error) or classification (based on entropy). Missing is treated as a separate value.
*
*
* Typical usage:
* java weka.classifiers.meta.LogitBoost -I 100 -W weka.classifiers.trees.DecisionStump
* -t training_data
*
* Valid options are:
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
*
* @author Eibe Frank ([email protected])
* @version $Revision: 9171 $
*/
public class DecisionStump
extends AbstractClassifier
implements WeightedInstancesHandler, Sourcable {
/** for serialization */
static final long serialVersionUID = 1618384535950391L;
/** The attribute used for classification. */
protected int m_AttIndex;
/** The split point (index respectively). */
protected double m_SplitPoint;
/** The distribution of class values or the means in each subset. */
protected double[][] m_Distribution;
/** The instances used for training. */
protected Instances m_Instances;
/** a ZeroR model in case no model can be built from the data */
protected Classifier m_ZeroR;
/**
* Returns a string describing classifier
* @return a description suitable for
* displaying in the explorer/experimenter gui
*/
public String globalInfo() {
return "Class for building and using a decision stump. Usually used in "
+ "conjunction with a boosting algorithm. Does regression (based on "
+ "mean-squared error) or classification (based on entropy). Missing "
+ "is treated as a separate value.";
}
/**
* Returns default capabilities of the classifier.
*
* @return the capabilities of this classifier
*/
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
result.disableAll();
// attributes
result.enable(Capability.NOMINAL_ATTRIBUTES);
result.enable(Capability.NUMERIC_ATTRIBUTES);
result.enable(Capability.DATE_ATTRIBUTES);
result.enable(Capability.MISSING_VALUES);
// class
result.enable(Capability.NOMINAL_CLASS);
result.enable(Capability.NUMERIC_CLASS);
result.enable(Capability.DATE_CLASS);
result.enable(Capability.MISSING_CLASS_VALUES);
return result;
}
/**
* Generates the classifier.
*
* @param instances set of instances serving as training data
* @throws Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances) throws Exception {
double bestVal = Double.MAX_VALUE, currVal;
double bestPoint = -Double.MAX_VALUE;
int bestAtt = -1, numClasses;
// can classifier handle the data?
getCapabilities().testWithFail(instances);
// remove instances with missing class
instances = new Instances(instances);
instances.deleteWithMissingClass();
// only class? -> build ZeroR model
if (instances.numAttributes() == 1) {
System.err.println(
"Cannot build model (only class attribute present in data!), "
+ "using ZeroR model instead!");
m_ZeroR = new weka.classifiers.rules.ZeroR();
m_ZeroR.buildClassifier(instances);
return;
}
else {
m_ZeroR = null;
}
double[][] bestDist = new double[3][instances.numClasses()];
m_Instances = new Instances(instances);
if (m_Instances.classAttribute().isNominal()) {
numClasses = m_Instances.numClasses();
} else {
numClasses = 1;
}
// For each attribute
boolean first = true;
for (int i = 0; i < m_Instances.numAttributes(); i++) {
if (i != m_Instances.classIndex()) {
// Reserve space for distribution.
m_Distribution = new double[3][numClasses];
// Compute value of criterion for best split on attribute
if (m_Instances.attribute(i).isNominal()) {
currVal = findSplitNominal(i);
} else {
currVal = findSplitNumeric(i);
}
if ((first) || (currVal < bestVal)) {
bestVal = currVal;
bestAtt = i;
bestPoint = m_SplitPoint;
for (int j = 0; j < 3; j++) {
System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
numClasses);
}
}
// First attribute has been investigated
first = false;
}
}
// Set attribute, split point and distribution.
m_AttIndex = bestAtt;
m_SplitPoint = bestPoint;
m_Distribution = bestDist;
if (m_Instances.classAttribute().isNominal()) {
for (int i = 0; i < m_Distribution.length; i++) {
double sumCounts = Utils.sum(m_Distribution[i]);
if (sumCounts == 0) { // This means there were only missing attribute values
System.arraycopy(m_Distribution[2], 0, m_Distribution[i], 0,
m_Distribution[2].length);
Utils.normalize(m_Distribution[i]);
} else {
Utils.normalize(m_Distribution[i], sumCounts);
}
}
}
// Save memory
m_Instances = new Instances(m_Instances, 0);
}
/**
* Calculates the class membership probabilities for the given test instance.
*
* @param instance the instance to be classified
* @return predicted class probability distribution
* @throws Exception if distribution can't be computed
*/
public double[] distributionForInstance(Instance instance) throws Exception {
// default model?
if (m_ZeroR != null) {
return m_ZeroR.distributionForInstance(instance);
}
return m_Distribution[whichSubset(instance)];
}
/**
* Returns the decision tree as Java source code.
*
* @param className the classname of the generated code
* @return the tree as Java source code
* @throws Exception if something goes wrong
*/
public String toSource(String className) throws Exception {
StringBuffer text = new StringBuffer("class ");
Attribute c = m_Instances.classAttribute();
text.append(className)
.append(" {\n"
+" public static double classify(Object[] i) {\n");
text.append(" /* " + m_Instances.attribute(m_AttIndex).name() + " */\n");
text.append(" if (i[").append(m_AttIndex);
text.append("] == null) { return ");
text.append(sourceClass(c, m_Distribution[2])).append(";");
if (m_Instances.attribute(m_AttIndex).isNominal()) {
text.append(" } else if (((String)i[").append(m_AttIndex);
text.append("]).equals(\"");
text.append(m_Instances.attribute(m_AttIndex).value((int)m_SplitPoint));
text.append("\")");
} else {
text.append(" } else if (((Double)i[").append(m_AttIndex);
text.append("]).doubleValue() <= ").append(m_SplitPoint);
}
text.append(") { return ");
text.append(sourceClass(c, m_Distribution[0])).append(";");
text.append(" } else { return ");
text.append(sourceClass(c, m_Distribution[1])).append(";");
text.append(" }\n }\n}\n");
return text.toString();
}
/**
* Returns the value as string out of the given distribution
*
* @param c the attribute to get the value for
* @param dist the distribution to extract the value
* @return the value
*/
protected String sourceClass(Attribute c, double []dist) {
if (c.isNominal()) {
return Integer.toString(Utils.maxIndex(dist));
} else {
return Double.toString(dist[0]);
}
}
/**
* Returns a description of the classifier.
*
* @return a description of the classifier as a string.
*/
public String toString(){
// only ZeroR model?
if (m_ZeroR != null) {
StringBuffer buf = new StringBuffer();
buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
buf.append(m_ZeroR.toString());
return buf.toString();
}
if (m_Instances == null) {
return "Decision Stump: No model built yet.";
}
try {
StringBuffer text = new StringBuffer();
text.append("Decision Stump\n\n");
text.append("Classifications\n\n");
Attribute att = m_Instances.attribute(m_AttIndex);
if (att.isNominal()) {
text.append(att.name() + " = " + att.value((int)m_SplitPoint) +
" : ");
text.append(printClass(m_Distribution[0]));
text.append(att.name() + " != " + att.value((int)m_SplitPoint) +
" : ");
text.append(printClass(m_Distribution[1]));
} else {
text.append(att.name() + " <= " + m_SplitPoint + " : ");
text.append(printClass(m_Distribution[0]));
text.append(att.name() + " > " + m_SplitPoint + " : ");
text.append(printClass(m_Distribution[1]));
}
text.append(att.name() + " is missing : ");
text.append(printClass(m_Distribution[2]));
if (m_Instances.classAttribute().isNominal()) {
text.append("\nClass distributions\n\n");
if (att.isNominal()) {
text.append(att.name() + " = " + att.value((int)m_SplitPoint) +
"\n");
text.append(printDist(m_Distribution[0]));
text.append(att.name() + " != " + att.value((int)m_SplitPoint) +
"\n");
text.append(printDist(m_Distribution[1]));
} else {
text.append(att.name() + " <= " + m_SplitPoint + "\n");
text.append(printDist(m_Distribution[0]));
text.append(att.name() + " > " + m_SplitPoint + "\n");
text.append(printDist(m_Distribution[1]));
}
text.append(att.name() + " is missing\n");
text.append(printDist(m_Distribution[2]));
}
return text.toString();
} catch (Exception e) {
return "Can't print decision stump classifier!";
}
}
/**
* Prints a class distribution.
*
* @param dist the class distribution to print
* @return the distribution as a string
* @throws Exception if distribution can't be printed
*/
protected String printDist(double[] dist) throws Exception {
StringBuffer text = new StringBuffer();
if (m_Instances.classAttribute().isNominal()) {
for (int i = 0; i < m_Instances.numClasses(); i++) {
text.append(m_Instances.classAttribute().value(i) + "\t");
}
text.append("\n");
for (int i = 0; i < m_Instances.numClasses(); i++) {
text.append(dist[i] + "\t");
}
text.append("\n");
}
return text.toString();
}
/**
* Prints a classification.
*
* @param dist the class distribution
* @return the classificationn as a string
* @throws Exception if the classification can't be printed
*/
protected String printClass(double[] dist) throws Exception {
StringBuffer text = new StringBuffer();
if (m_Instances.classAttribute().isNominal()) {
text.append(m_Instances.classAttribute().value(Utils.maxIndex(dist)));
} else {
text.append(dist[0]);
}
return text.toString() + "\n";
}
/**
* Finds best split for nominal attribute and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @throws Exception if something goes wrong
*/
protected double findSplitNominal(int index) throws Exception {
if (m_Instances.classAttribute().isNominal()) {
return findSplitNominalNominal(index);
} else {
return findSplitNominalNumeric(index);
}
}
/**
* Finds best split for nominal attribute and nominal class
* and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @throws Exception if something goes wrong
*/
protected double findSplitNominalNominal(int index) throws Exception {
double bestVal = Double.MAX_VALUE, currVal;
double[][] counts = new double[m_Instances.attribute(index).numValues()
+ 1][m_Instances.numClasses()];
double[] sumCounts = new double[m_Instances.numClasses()];
double[][] bestDist = new double[3][m_Instances.numClasses()];
int numMissing = 0;
// Compute counts for all the values
for (int i = 0; i < m_Instances.numInstances(); i++) {
Instance inst = m_Instances.instance(i);
if (inst.isMissing(index)) {
numMissing++;
counts[m_Instances.attribute(index).numValues()]
[(int)inst.classValue()] += inst.weight();
} else {
counts[(int)inst.value(index)][(int)inst.classValue()] += inst
.weight();
}
}
// Compute sum of counts
for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
for (int j = 0; j < m_Instances.numClasses(); j++) {
sumCounts[j] += counts[i][j];
}
}
// Make split counts for each possible split and evaluate
System.arraycopy(counts[m_Instances.attribute(index).numValues()], 0,
m_Distribution[2], 0, m_Instances.numClasses());
for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
for (int j = 0; j < m_Instances.numClasses(); j++) {
m_Distribution[0][j] = counts[i][j];
m_Distribution[1][j] = sumCounts[j] - counts[i][j];
}
currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
if (currVal < bestVal) {
bestVal = currVal;
m_SplitPoint = (double)i;
for (int j = 0; j < 3; j++) {
System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
m_Instances.numClasses());
}
}
}
// No missing values in training data.
if (numMissing == 0) {
System.arraycopy(sumCounts, 0, bestDist[2], 0,
m_Instances.numClasses());
}
m_Distribution = bestDist;
return bestVal;
}
/**
* Finds best split for nominal attribute and numeric class
* and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @throws Exception if something goes wrong
*/
protected double findSplitNominalNumeric(int index) throws Exception {
double bestVal = Double.MAX_VALUE, currVal;
double[] sumsSquaresPerValue =
new double[m_Instances.attribute(index).numValues()],
sumsPerValue = new double[m_Instances.attribute(index).numValues()],
weightsPerValue = new double[m_Instances.attribute(index).numValues()];
double totalSumSquaresW = 0, totalSumW = 0, totalSumOfWeightsW = 0,
totalSumOfWeights = 0, totalSum = 0;
double[] sumsSquares = new double[3], sumOfWeights = new double[3];
double[][] bestDist = new double[3][1];
// Compute counts for all the values
for (int i = 0; i < m_Instances.numInstances(); i++) {
Instance inst = m_Instances.instance(i);
if (inst.isMissing(index)) {
m_Distribution[2][0] += inst.classValue() * inst.weight();
sumsSquares[2] += inst.classValue() * inst.classValue()
* inst.weight();
sumOfWeights[2] += inst.weight();
} else {
weightsPerValue[(int)inst.value(index)] += inst.weight();
sumsPerValue[(int)inst.value(index)] += inst.classValue()
* inst.weight();
sumsSquaresPerValue[(int)inst.value(index)] +=
inst.classValue() * inst.classValue() * inst.weight();
}
totalSumOfWeights += inst.weight();
totalSum += inst.classValue() * inst.weight();
}
// Check if the total weight is zero
if (totalSumOfWeights <= 0) {
return bestVal;
}
// Compute sum of counts without missing ones
for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
totalSumOfWeightsW += weightsPerValue[i];
totalSumSquaresW += sumsSquaresPerValue[i];
totalSumW += sumsPerValue[i];
}
// Make split counts for each possible split and evaluate
for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
m_Distribution[0][0] = sumsPerValue[i];
sumsSquares[0] = sumsSquaresPerValue[i];
sumOfWeights[0] = weightsPerValue[i];
m_Distribution[1][0] = totalSumW - sumsPerValue[i];
sumsSquares[1] = totalSumSquaresW - sumsSquaresPerValue[i];
sumOfWeights[1] = totalSumOfWeightsW - weightsPerValue[i];
currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
if (currVal < bestVal) {
bestVal = currVal;
m_SplitPoint = (double)i;
for (int j = 0; j < 3; j++) {
if (sumOfWeights[j] > 0) {
bestDist[j][0] = m_Distribution[j][0] / sumOfWeights[j];
} else {
bestDist[j][0] = totalSum / totalSumOfWeights;
}
}
}
}
m_Distribution = bestDist;
return bestVal;
}
/**
* Finds best split for numeric attribute and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @throws Exception if something goes wrong
*/
protected double findSplitNumeric(int index) throws Exception {
if (m_Instances.classAttribute().isNominal()) {
return findSplitNumericNominal(index);
} else {
return findSplitNumericNumeric(index);
}
}
/**
* Finds best split for numeric attribute and nominal class
* and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @throws Exception if something goes wrong
*/
protected double findSplitNumericNominal(int index) throws Exception {
double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
int numMissing = 0;
double[] sum = new double[m_Instances.numClasses()];
double[][] bestDist = new double[3][m_Instances.numClasses()];
// Compute counts for all the values
for (int i = 0; i < m_Instances.numInstances(); i++) {
Instance inst = m_Instances.instance(i);
if (!inst.isMissing(index)) {
m_Distribution[1][(int)inst.classValue()] += inst.weight();
} else {
m_Distribution[2][(int)inst.classValue()] += inst.weight();
numMissing++;
}
}
System.arraycopy(m_Distribution[1], 0, sum, 0, m_Instances.numClasses());
// Save current distribution as best distribution
for (int j = 0; j < 3; j++) {
System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
m_Instances.numClasses());
}
// Sort instances
m_Instances.sort(index);
// Make split counts for each possible split and evaluate
for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
Instance inst = m_Instances.instance(i);
Instance instPlusOne = m_Instances.instance(i + 1);
m_Distribution[0][(int)inst.classValue()] += inst.weight();
m_Distribution[1][(int)inst.classValue()] -= inst.weight();
if (inst.value(index) < instPlusOne.value(index)) {
currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
if (currVal < bestVal) {
m_SplitPoint = currCutPoint;
bestVal = currVal;
for (int j = 0; j < 3; j++) {
System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
m_Instances.numClasses());
}
}
}
}
// No missing values in training data.
if (numMissing == 0) {
System.arraycopy(sum, 0, bestDist[2], 0, m_Instances.numClasses());
}
m_Distribution = bestDist;
return bestVal;
}
/**
* Finds best split for numeric attribute and numeric class
* and returns value.
*
* @param index attribute index
* @return value of criterion for the best split
* @throws Exception if something goes wrong
*/
protected double findSplitNumericNumeric(int index) throws Exception {
double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
int numMissing = 0;
double[] sumsSquares = new double[3], sumOfWeights = new double[3];
double[][] bestDist = new double[3][1];
double totalSum = 0, totalSumOfWeights = 0;
// Compute counts for all the values
for (int i = 0; i < m_Instances.numInstances(); i++) {
Instance inst = m_Instances.instance(i);
if (!inst.isMissing(index)) {
m_Distribution[1][0] += inst.classValue() * inst.weight();
sumsSquares[1] += inst.classValue() * inst.classValue()
* inst.weight();
sumOfWeights[1] += inst.weight();
} else {
m_Distribution[2][0] += inst.classValue() * inst.weight();
sumsSquares[2] += inst.classValue() * inst.classValue()
* inst.weight();
sumOfWeights[2] += inst.weight();
numMissing++;
}
totalSumOfWeights += inst.weight();
totalSum += inst.classValue() * inst.weight();
}
// Check if the total weight is zero
if (totalSumOfWeights <= 0) {
return bestVal;
}
// Sort instances
m_Instances.sort(index);
// Make split counts for each possible split and evaluate
for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
Instance inst = m_Instances.instance(i);
Instance instPlusOne = m_Instances.instance(i + 1);
m_Distribution[0][0] += inst.classValue() * inst.weight();
sumsSquares[0] += inst.classValue() * inst.classValue() * inst.weight();
sumOfWeights[0] += inst.weight();
m_Distribution[1][0] -= inst.classValue() * inst.weight();
sumsSquares[1] -= inst.classValue() * inst.classValue() * inst.weight();
sumOfWeights[1] -= inst.weight();
if (inst.value(index) < instPlusOne.value(index)) {
currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
if (currVal < bestVal) {
m_SplitPoint = currCutPoint;
bestVal = currVal;
for (int j = 0; j < 3; j++) {
if (sumOfWeights[j] > 0) {
bestDist[j][0] = m_Distribution[j][0] / sumOfWeights[j];
} else {
bestDist[j][0] = totalSum / totalSumOfWeights;
}
}
}
}
}
m_Distribution = bestDist;
return bestVal;
}
/**
* Computes variance for subsets.
*
* @param s
* @param sS
* @param sumOfWeights
* @return the variance
*/
protected double variance(double[][] s,double[] sS,double[] sumOfWeights) {
double var = 0;
for (int i = 0; i < s.length; i++) {
if (sumOfWeights[i] > 0) {
var += sS[i] - ((s[i][0] * s[i][0]) / (double) sumOfWeights[i]);
}
}
return var;
}
/**
* Returns the subset an instance falls into.
*
* @param instance the instance to check
* @return the subset the instance falls into
* @throws Exception if something goes wrong
*/
protected int whichSubset(Instance instance) throws Exception {
if (instance.isMissing(m_AttIndex)) {
return 2;
} else if (instance.attribute(m_AttIndex).isNominal()) {
if ((int)instance.value(m_AttIndex) == m_SplitPoint) {
return 0;
} else {
return 1;
}
} else {
if (instance.value(m_AttIndex) <= m_SplitPoint) {
return 0;
} else {
return 1;
}
}
}
/**
* Returns the revision string.
*
* @return the revision
*/
public String getRevision() {
return RevisionUtils.extract("$Revision: 9171 $");
}
/**
* Main method for testing this class.
*
* @param argv the options
*/
public static void main(String [] argv) {
runClassifier(new DecisionStump(), argv);
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy