weka.classifiers.rules.DecisionTable Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-stable Show documentation
Show all versions of weka-stable Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version. Apart from bugfixes, this version
does not receive any other updates.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* DecisionTable.java
* Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.rules;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Random;
import java.util.Vector;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.ASSearch;
import weka.attributeSelection.BestFirst;
import weka.attributeSelection.SubsetEvaluator;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.lazy.IBk;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
/**
* Class for building and using a simple decision
* table majority classifier.
*
* For more information see:
*
* Ron Kohavi: The Power of Decision Tables. In: 8th European Conference on
* Machine Learning, 174-189, 1995.
*
*
*
* BibTeX:
*
*
* @inproceedings{Kohavi1995,
* author = {Ron Kohavi},
* booktitle = {8th European Conference on Machine Learning},
* pages = {174-189},
* publisher = {Springer},
* title = {The Power of Decision Tables},
* year = {1995}
* }
*
*
*
*
* Valid options are:
*
*
*
* -S <search method specification>
* Full class name of search method, followed
* by its options.
* eg: "weka.attributeSelection.BestFirst -D 1"
* (default weka.attributeSelection.BestFirst)
*
*
*
* -X <number of folds>
* Use cross validation to evaluate features.
* Use number of folds = 1 for leave one out CV.
* (Default = leave one out CV)
*
*
*
* -E <acc | rmse | mae | auc>
* Performance evaluation measure to use for selecting attributes.
* (Default = accuracy for discrete class and rmse for numeric class)
*
*
*
* -I
* Use nearest neighbour instead of global table majority.
*
*
*
* -R
* Display decision table rules.
*
*
*
* Options specific to search method weka.attributeSelection.BestFirst:
*
*
*
* -P <start set>
* Specify a starting set of attributes.
* Eg. 1,3,5-7.
*
*
*
* -D <0 = backward | 1 = forward | 2 = bi-directional>
* Direction of search. (default = 1).
*
*
*
* -N <num>
* Number of non-improving nodes to
* consider before terminating search.
*
*
*
* -S <num>
* Size of lookup cache for evaluated subsets.
* Expressed as a multiple of the number of
* attributes in the data set. (default = 1)
*
*
*
*
* @author Mark Hall (mhall@cs.waikato.ac.nz)
* @version $Revision: 15520 $
*/
public class DecisionTable extends AbstractClassifier implements OptionHandler,
WeightedInstancesHandler, AdditionalMeasureProducer,
TechnicalInformationHandler {
/** for serialization */
static final long serialVersionUID = 2888557078165701326L;

/**
 * The hashtable used to hold training instances. Each key is built from the
 * selected feature values of an instance; each value is either the per-class
 * weight distribution of that cell (nominal class) or a two-element array
 * holding {weighted class-value sum, total weight} (numeric class).
 */
protected Hashtable<DecisionTableHashKey, double[]> m_entries;

/**
 * Per-class prior counts, adjusted by instance weights during internal cross
 * validation; normalised copies of these are used when a table cell holds no
 * remaining evidence.
 */
protected double[] m_classPriorCounts;

/**
 * The class priors to use when there is no match in the table.
 * NOTE(review): presumably the normalised form of m_classPriorCounts; it is
 * assigned outside the code visible here - confirm.
 */
protected double[] m_classPriors;

/** Holds the final feature set */
protected int[] m_decisionFeatures;

/** Discretization filter */
protected Filter m_disTransform;

/** Filter used to remove columns discarded by feature selection */
protected Remove m_delTransform;

/** IB1 used to classify non matching instances rather than majority class */
protected IBk m_ibk;

/** Holds the original training instances */
protected Instances m_theInstances;

/** Holds the final feature selected set of instances */
protected Instances m_dtInstances;

/** The number of attributes in the dataset */
protected int m_numAttributes;

/** The number of instances in the dataset */
private int m_numInstances;

/** Class is nominal */
protected boolean m_classIsNominal;

/** Use the IBk classifier rather than majority class */
protected boolean m_useIBk;

/** Display Rules */
protected boolean m_displayRules;

/** Number of folds for cross validating feature sets (1 = leave one out) */
private int m_CVFolds;

/** Random numbers for use in cross validation */
private Random m_rr;

/**
 * Holds the majority class; for a numeric class it is used as the fallback
 * prediction when a table cell has no remaining weight.
 */
protected double m_majority;

/** The search method to use */
protected ASSearch m_search = new BestFirst();

/** Our own internal evaluator */
protected ASEvaluation m_evaluator;

/** The evaluation object used to evaluate subsets */
protected Evaluation m_evaluation;
/** default is accuracy for discrete class and RMSE for numeric class */
public static final int EVAL_DEFAULT = 1;

/** evaluate feature subsets by accuracy (percent correct) */
public static final int EVAL_ACCURACY = 2;

/** evaluate feature subsets by root mean squared error */
public static final int EVAL_RMSE = 3;

/** evaluate feature subsets by mean absolute error */
public static final int EVAL_MAE = 4;

/** evaluate feature subsets by (class-prior weighted) area under the ROC curve */
public static final int EVAL_AUC = 5;

/** Human-readable tags for the available evaluation measures */
public static final Tag[] TAGS_EVALUATION = {
  new Tag(EVAL_DEFAULT,
    "Default: accuracy (discrete class); RMSE (numeric class)"),
  // closing parenthesis was missing from this label
  new Tag(EVAL_ACCURACY, "Accuracy (discrete class only)"),
  new Tag(EVAL_RMSE, "RMSE (of the class probabilities for discrete class)"),
  new Tag(EVAL_MAE, "MAE (of the class probabilities for discrete class)"),
  new Tag(EVAL_AUC, "AUC (area under the ROC curve - discrete class only)") };

/** The evaluation measure currently selected for scoring feature subsets */
protected int m_evaluationMeasure = EVAL_DEFAULT;
/**
 * Builds the descriptive text shown for this classifier.
 *
 * @return a short summary of the classifier followed by the textual form of
 *         its technical reference, suitable for displaying in the
 *         explorer/experimenter gui
 */
public String globalInfo() {
  final StringBuilder description = new StringBuilder();
  description.append("Class for building and using a simple decision table majority ");
  description.append("classifier.\n\n");
  description.append("For more information see: \n\n");
  description.append(getTechnicalInformation().toString());
  return description.toString();
}
/**
 * Creates the bibliographic record (an in-proceedings entry for Kohavi's
 * "The Power of Decision Tables", 1995) that describes the technical
 * background of this class.
 *
 * @return the technical information about this class
 */
@Override
public TechnicalInformation getTechnicalInformation() {
  final TechnicalInformation paper =
    new TechnicalInformation(Type.INPROCEEDINGS);

  paper.setValue(Field.AUTHOR, "Ron Kohavi");
  paper.setValue(Field.TITLE, "The Power of Decision Tables");
  paper.setValue(Field.BOOKTITLE,
    "8th European Conference on Machine Learning");
  paper.setValue(Field.YEAR, "1995");
  paper.setValue(Field.PAGES, "174-189");
  paper.setValue(Field.PUBLISHER, "Springer");

  return paper;
}
/**
 * Inserts an instance into the hash table, creating a new cell for its key
 * or updating the class statistics of the existing cell.
 *
 * For a nominal class a cell holds one weight per class value, seeded with
 * 1.0 for every class (Laplace estimation). For a numeric class a cell holds
 * {sum of classValue * weight, sum of weight}.
 *
 * @param inst instance to be inserted
 * @param instA feature values to create the hash key from; if null the key
 *          is created from the instance itself
 * @throws Exception if the instance can't be inserted
 */
private void insertIntoTable(Instance inst, double[] instA) throws Exception {
  final DecisionTableHashKey thekey = (instA != null)
    ? new DecisionTableHashKey(instA)
    : new DecisionTableHashKey(inst, inst.numAttributes(), false);

  // The explicit cast keeps this lookup well-typed regardless of whether the
  // table field is declared with type parameters.
  double[] dist = (double[]) m_entries.get(thekey);

  if (dist == null) {
    // first instance for this key: create the cell
    if (m_classIsNominal) {
      dist = new double[m_theInstances.classAttribute().numValues()];
      // Laplace estimation
      Arrays.fill(dist, 1.0);
      // NOTE(review): this assignment overwrites the Laplace seed rather than
      // adding to it, so the first instance of a cell is counted differently
      // from later ones (which use +=). Behaviour kept as-is; confirm whether
      // += was intended.
      dist[(int) inst.classValue()] = inst.weight();
    } else {
      dist = new double[] { inst.classValue() * inst.weight(), inst.weight() };
    }
    m_entries.put(thekey, dist);
    return;
  }

  // The cell already exists: the stored array is updated in place, so there
  // is no need to put it back into the table.
  if (m_classIsNominal) {
    dist[(int) inst.classValue()] += inst.weight();
  } else {
    dist[0] += inst.classValue() * inst.weight();
    dist[1] += inst.weight();
  }
}
/**
 * Classifies an instance for internal leave one out cross validation of
 * feature sets. The instance's own contribution is subtracted from a copy of
 * its table cell (and, for a nominal class, from a copy of the class prior
 * counts), the resulting prediction is recorded in m_evaluation, and the
 * prediction is returned. Neither the table nor the prior counts are
 * modified.
 *
 * @param instance instance to be "left out" and classified
 * @param instA feature values of the selected features for the instance
 * @return the classification of the instance: the index of the most probable
 *         class (nominal class) or the predicted value (numeric class)
 * @throws Exception if something goes wrong
 */
protected double evaluateInstanceLeaveOneOut(Instance instance, double[] instA)
  throws Exception {

  final DecisionTableHashKey thekey = new DecisionTableHashKey(instA);

  // The explicit cast keeps this lookup well-typed regardless of whether the
  // table field is declared with type parameters.
  final double[] tempDist = (double[]) m_entries.get(thekey);
  if (tempDist == null) {
    // every training instance was inserted before evaluation started
    throw new Error("This should never happen!");
  }

  // work on a copy so the stored cell keeps the instance's contribution
  double[] normDist = tempDist.clone();

  if (m_classIsNominal) {
    final int classIndex = (int) instance.classValue();
    normDist[classIndex] -= instance.weight();

    // Cells are seeded with 1.0 per class, so only a value above 1.0 means
    // that evidence from some other instance remains in this cell.
    boolean ok = false;
    for (double element : normDist) {
      if (Utils.gr(element, 1.0)) {
        ok = true;
        break;
      }
    }

    // Downdate the class priors on a private copy. The shared counts are
    // never touched, so they cannot be left decremented if normalisation
    // throws, and repeated subtract/add-back cannot accumulate rounding
    // error in them.
    final double[] classPriors = m_classPriorCounts.clone();
    classPriors[classIndex] -= instance.weight();
    Utils.normalize(classPriors);

    if (!ok) { // nothing left in the cell: fall back to the class priors
      normDist = classPriors;
    }

    Utils.normalize(normDist);
    if (m_evaluationMeasure == EVAL_AUC) {
      m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance);
    } else {
      m_evaluation.evaluateModelOnce(normDist, instance);
    }
    return Utils.maxIndex(normDist);
  }

  // numeric class: cell is {weighted class-value sum, total weight}
  normDist[0] -= (instance.classValue() * instance.weight());
  normDist[1] -= instance.weight();

  // no weight left in the cell -> predict m_majority, otherwise the cell mean
  final double prediction = Utils.eq(normDist[1], 0.0)
    ? m_majority
    : normDist[0] / normDist[1];

  final double[] temp = new double[1];
  temp[0] = prediction;
  m_evaluation.evaluateModelOnce(temp, instance);
  return prediction;
}
/**
 * Evaluates one test fold for internal cross validation of feature sets.
 *
 * Works in three passes: (1) subtract every fold instance from its table
 * cell (and, for a nominal class, from the class prior counts), (2) classify
 * every fold instance against the reduced table and record the prediction in
 * m_evaluation, (3) add the fold instances back again.
 *
 * @param fold set of instances to be "left out" and classified
 * @param fs currently selected feature set
 * @return always 0.0; the fold's results are accumulated in m_evaluation
 *         rather than returned
 * @throws Exception if something goes wrong
 */
protected double evaluateFoldCV(Instances fold, int[] fs) throws Exception {

  final int numFold = fold.numInstances();
  final int numCl = m_theInstances.classAttribute().numValues();
  final int classI = m_theInstances.classIndex();

  // class_distribs[i] aliases the table cell that fold instance i falls into
  final double[][] class_distribs = new double[numFold][];
  final double[] instA = new double[fs.length];
  double[] normDist = m_classIsNominal ? new double[numCl] : new double[2];

  // first *remove* instances
  for (int i = 0; i < numFold; i++) {
    final Instance inst = fold.instance(i);
    for (int j = 0; j < fs.length; j++) {
      if (fs[j] == classI || inst.isMissing(fs[j])) {
        // the class attribute and missing values share one sentinel
        instA[j] = Double.MAX_VALUE;
      } else {
        instA[j] = inst.value(fs[j]);
      }
    }

    // The explicit cast keeps this lookup well-typed regardless of whether
    // the table field is declared with type parameters.
    final double[] cell =
      (double[]) m_entries.get(new DecisionTableHashKey(instA));
    if (cell == null) {
      throw new Error("This should never happen!");
    }
    class_distribs[i] = cell;

    if (m_classIsNominal) {
      cell[(int) inst.classValue()] -= inst.weight();
      // Prior counts are per class value, so they are only meaningful (and
      // only safely indexable by the class value) for a nominal class.
      m_classPriorCounts[(int) inst.classValue()] -= inst.weight();
    } else {
      cell[0] -= (inst.classValue() * inst.weight());
      cell[1] -= inst.weight();
    }
  }

  double[] classPriors = null;
  if (m_classIsNominal) {
    classPriors = m_classPriorCounts.clone();
    Utils.normalize(classPriors);
  }

  // now classify instances
  for (int i = 0; i < numFold; i++) {
    final Instance inst = fold.instance(i);
    System.arraycopy(class_distribs[i], 0, normDist, 0, normDist.length);

    if (m_classIsNominal) {
      // Cells are seeded with 1.0 per class, so only a value above 1.0 means
      // that evidence from outside the fold remains in this cell.
      boolean ok = false;
      for (double element : normDist) {
        if (Utils.gr(element, 1.0)) {
          ok = true;
          break;
        }
      }
      if (!ok) { // nothing left in the cell: fall back to the class priors
        normDist = classPriors.clone();
      }

      Utils.normalize(normDist);
      if (m_evaluationMeasure == EVAL_AUC) {
        m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);
      } else {
        m_evaluation.evaluateModelOnce(normDist, inst);
      }
    } else {
      // no weight left in the cell -> predict m_majority, else the cell mean
      final double[] temp = new double[1];
      if (Utils.eq(normDist[1], 0.0)) {
        temp[0] = m_majority;
      } else {
        temp[0] = normDist[0] / normDist[1];
      }
      m_evaluation.evaluateModelOnce(temp, inst);
    }
  }

  // now re-insert instances
  for (int i = 0; i < numFold; i++) {
    final Instance inst = fold.instance(i);
    if (m_classIsNominal) {
      m_classPriorCounts[(int) inst.classValue()] += inst.weight();
      class_distribs[i][(int) inst.classValue()] += inst.weight();
    } else {
      class_distribs[i][0] += (inst.classValue() * inst.weight());
      class_distribs[i][1] += inst.weight();
    }
  }

  return 0.0;
}
/**
 * Evaluates a feature subset by cross validation. A fresh table is built
 * from all training instances using only the selected features, and is then
 * evaluated by leave one out (m_CVFolds == 1) or by m_CVFolds-fold cross
 * validation; the results are collected in a new m_evaluation object.
 *
 * @param feature_set the subset to be evaluated
 * @param num_atts the number of attributes in the subset
 * @return the merit of the subset under the selected evaluation measure:
 *         percent correct or class-prior weighted AUC as is, RMSE and MAE
 *         negated so that a larger value is always better
 * @throws Exception if subset can't be evaluated
 */
protected double estimatePerformance(BitSet feature_set, int num_atts)
  throws Exception {

  m_evaluation = new Evaluation(m_theInstances);

  final int classI = m_theInstances.classIndex();
  final double[] instA = new double[num_atts];

  // collect the indices of the selected attributes
  final int[] fs = new int[num_atts];
  int index = 0;
  for (int i = 0; i < m_numAttributes; i++) {
    if (feature_set.get(i)) {
      fs[index++] = i;
    }
  }

  // create new hash table
  m_entries = new Hashtable<DecisionTableHashKey, double[]>(
    (int) (m_theInstances.numInstances() * 1.5));

  // insert instances into the hash table
  for (int i = 0; i < m_numInstances; i++) {
    final Instance inst = m_theInstances.instance(i);
    copySelectedFeatureValues(inst, fs, classI, instA);
    insertIntoTable(inst, instA);
  }

  if (m_CVFolds == 1) {
    // calculate leave one out error
    for (int i = 0; i < m_numInstances; i++) {
      final Instance inst = m_theInstances.instance(i);
      copySelectedFeatureValues(inst, fs, classI, instA);
      evaluateInstanceLeaveOneOut(inst, instA);
    }
  } else {
    m_theInstances.randomize(m_rr);
    m_theInstances.stratify(m_CVFolds);

    // calculate m_CVFolds-fold cross validation error
    for (int i = 0; i < m_CVFolds; i++) {
      final Instances insts = m_theInstances.testCV(m_CVFolds, i);
      evaluateFoldCV(insts, fs);
    }
  }

  switch (m_evaluationMeasure) {
  case EVAL_DEFAULT:
    if (m_classIsNominal) {
      return m_evaluation.pctCorrect();
    }
    return -m_evaluation.rootMeanSquaredError();
  case EVAL_ACCURACY:
    return m_evaluation.pctCorrect();
  case EVAL_RMSE:
    return -m_evaluation.rootMeanSquaredError();
  case EVAL_MAE:
    return -m_evaluation.meanAbsoluteError();
  case EVAL_AUC:
    // average the per-class AUC, weighting each class by its prior
    final double[] classPriors = m_evaluation.getClassPriors();
    Utils.normalize(classPriors);
    double weightedAUC = 0;
    final int numClasses = m_theInstances.classAttribute().numValues();
    for (int i = 0; i < numClasses; i++) {
      final double tempAUC = m_evaluation.areaUnderROC(i);
      if (!Utils.isMissingValue(tempAUC)) {
        weightedAUC += (classPriors[i] * tempAUC);
      } else {
        System.err.println("Undefined AUC!!");
      }
    }
    return weightedAUC;
  }

  // shouldn't get here
  return 0.0;
}

/**
 * Copies the values of the selected features of an instance into instA.
 * The class attribute and missing values are both stored as
 * Double.MAX_VALUE.
 *
 * @param inst the instance to read from
 * @param fs indices of the selected attributes
 * @param classI index of the class attribute
 * @param instA receives one value per entry of fs
 */
private static void copySelectedFeatureValues(Instance inst, int[] fs,
  int classI, double[] instA) {
  for (int j = 0; j < fs.length; j++) {
    if (fs[j] == classI || inst.isMissing(fs[j])) {
      instA[j] = Double.MAX_VALUE;
    } else {
      instA[j] = inst.value(fs[j]);
    }
  }
}
/**
 * Restores the defaults: discards the current table and decision features,
 * selects leave one out cross validation (one fold) with the default
 * evaluation measure, and switches off both the IBk fallback and the
 * display of rules.
 */
protected void resetOptions() {
  // table and selected features
  m_decisionFeatures = null;
  m_entries = null;

  // option values
  m_evaluationMeasure = EVAL_DEFAULT;
  m_CVFolds = 1;
  m_displayRules = false;
  m_useIBk = false;
}
/**
 * Constructor for a DecisionTable. Puts the new object into its default
 * configuration by calling resetOptions().
 *
 * NOTE(review): resetOptions() is protected and therefore overridable, so an
 * overriding subclass would have its version run before the subclass's own
 * fields are initialised - confirm that no subclass depends on that.
 */
public DecisionTable() {
resetOptions();
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration
© 2015 - 2025 Weber Informatics LLC | Privacy Policy