All downloads are free. The search and download features use the official Maven repository.

weka.classifiers.rules.DecisionTable Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    DecisionTable.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.rules;

import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Random;
import java.util.Vector;

import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.ASSearch;
import weka.attributeSelection.BestFirst;
import weka.attributeSelection.SubsetEvaluator;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.lazy.IBk;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Class for building and using a simple decision table majority classifier.<br/>
 * <br/>
 * For more information see: <br/>
 * <br/>
 * Ron Kohavi: The Power of Decision Tables. In: 8th European Conference on
 * Machine Learning, 174-189, 1995.
 * <p/>
 *
 * BibTeX:
 *
 * <pre>
 * &#64;inproceedings{Kohavi1995,
 *    author = {Ron Kohavi},
 *    booktitle = {8th European Conference on Machine Learning},
 *    pages = {174-189},
 *    publisher = {Springer},
 *    title = {The Power of Decision Tables},
 *    year = {1995}
 * }
 * </pre>
 * <p/>
 *
 * Valid options are:
 * <p/>
 *
 * <pre>
 * -S &lt;search method specification&gt;
 *  Full class name of search method, followed
 *  by its options.
 *  eg: "weka.attributeSelection.BestFirst -D 1"
 *  (default weka.attributeSelection.BestFirst)
 * </pre>
 *
 * <pre>
 * -X &lt;number of folds&gt;
 *  Use cross validation to evaluate features.
 *  Use number of folds = 1 for leave one out CV.
 *  (Default = leave one out CV)
 * </pre>
 *
 * <pre>
 * -E &lt;acc | rmse | mae | auc&gt;
 *  Performance evaluation measure to use for selecting attributes.
 *  (Default = accuracy for discrete class and rmse for numeric class)
 * </pre>
 *
 * <pre>
 * -I
 *  Use nearest neighbour instead of global table majority.
 * </pre>
 *
 * <pre>
 * -R
 *  Display decision table rules.
 * </pre>
 *
 * <pre>
 * Options specific to search method weka.attributeSelection.BestFirst:
 * </pre>
 *
 * <pre>
 * -P &lt;start set&gt;
 *  Specify a starting set of attributes.
 *  Eg. 1,3,5-7.
 * </pre>
 *
 * <pre>
 * -D &lt;0 = backward | 1 = forward | 2 = bi-directional&gt;
 *  Direction of search. (default = 1).
 * </pre>
 *
 * <pre>
 * -N &lt;num&gt;
 *  Number of non-improving nodes to
 *  consider before terminating search.
 * </pre>
 *
 * <pre>
 * -S &lt;num&gt;
 *  Size of lookup cache for evaluated subsets.
 *  Expressed as a multiple of the number of
 *  attributes in the data set. (default = 1)
 * </pre>
* * * * @author Mark Hall ([email protected]) * @version $Revision: 15520 $ */ public class DecisionTable extends AbstractClassifier implements OptionHandler, WeightedInstancesHandler, AdditionalMeasureProducer, TechnicalInformationHandler { /** for serialization */ static final long serialVersionUID = 2888557078165701326L; /** The hashtable used to hold training instances */ protected Hashtable m_entries; /** The class priors to use when there is no match in the table */ protected double[] m_classPriorCounts; protected double[] m_classPriors; /** Holds the final feature set */ protected int[] m_decisionFeatures; /** Discretization filter */ protected Filter m_disTransform; /** Filter used to remove columns discarded by feature selection */ protected Remove m_delTransform; /** IB1 used to classify non matching instances rather than majority class */ protected IBk m_ibk; /** Holds the original training instances */ protected Instances m_theInstances; /** Holds the final feature selected set of instances */ protected Instances m_dtInstances; /** The number of attributes in the dataset */ protected int m_numAttributes; /** The number of instances in the dataset */ private int m_numInstances; /** Class is nominal */ protected boolean m_classIsNominal; /** Use the IBk classifier rather than majority class */ protected boolean m_useIBk; /** Display Rules */ protected boolean m_displayRules; /** Number of folds for cross validating feature sets */ private int m_CVFolds; /** Random numbers for use in cross validation */ private Random m_rr; /** Holds the majority class */ protected double m_majority; /** The search method to use */ protected ASSearch m_search = new BestFirst(); /** Our own internal evaluator */ protected ASEvaluation m_evaluator; /** The evaluation object used to evaluate subsets */ protected Evaluation m_evaluation; /** default is accuracy for discrete class and RMSE for numeric class */ public static final int EVAL_DEFAULT = 1; public static final int 
EVAL_ACCURACY = 2; public static final int EVAL_RMSE = 3; public static final int EVAL_MAE = 4; public static final int EVAL_AUC = 5; public static final Tag[] TAGS_EVALUATION = { new Tag(EVAL_DEFAULT, "Default: accuracy (discrete class); RMSE (numeric class)"), new Tag(EVAL_ACCURACY, "Accuracy (discrete class only"), new Tag(EVAL_RMSE, "RMSE (of the class probabilities for discrete class)"), new Tag(EVAL_MAE, "MAE (of the class probabilities for discrete class)"), new Tag(EVAL_AUC, "AUC (area under the ROC curve - discrete class only)") }; protected int m_evaluationMeasure = EVAL_DEFAULT; /** * Returns a string describing classifier * * @return a description suitable for displaying in the explorer/experimenter * gui */ public String globalInfo() { return "Class for building and using a simple decision table majority " + "classifier.\n\n" + "For more information see: \n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing detailed * information about the technical background of this class, e.g., paper * reference or book this class is based on. 
* * @return the technical information about this class */ @Override public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.INPROCEEDINGS); result.setValue(Field.AUTHOR, "Ron Kohavi"); result.setValue(Field.TITLE, "The Power of Decision Tables"); result.setValue(Field.BOOKTITLE, "8th European Conference on Machine Learning"); result.setValue(Field.YEAR, "1995"); result.setValue(Field.PAGES, "174-189"); result.setValue(Field.PUBLISHER, "Springer"); return result; } /** * Inserts an instance into the hash table * * @param inst instance to be inserted * @param instA to create the hash key from * @throws Exception if the instance can't be inserted */ private void insertIntoTable(Instance inst, double[] instA) throws Exception { double[] tempClassDist2; double[] newDist; DecisionTableHashKey thekey; if (instA != null) { thekey = new DecisionTableHashKey(instA); } else { thekey = new DecisionTableHashKey(inst, inst.numAttributes(), false); } // see if this one is already in the table tempClassDist2 = m_entries.get(thekey); if (tempClassDist2 == null) { if (m_classIsNominal) { newDist = new double[m_theInstances.classAttribute().numValues()]; // Leplace estimation for (int i = 0; i < m_theInstances.classAttribute().numValues(); i++) { newDist[i] = 1.0; } newDist[(int) inst.classValue()] = inst.weight(); // add to the table m_entries.put(thekey, newDist); } else { newDist = new double[2]; newDist[0] = inst.classValue() * inst.weight(); newDist[1] = inst.weight(); // add to the table m_entries.put(thekey, newDist); } } else { // update the distribution for this instance if (m_classIsNominal) { tempClassDist2[(int) inst.classValue()] += inst.weight(); // update the table m_entries.put(thekey, tempClassDist2); } else { tempClassDist2[0] += (inst.classValue() * inst.weight()); tempClassDist2[1] += inst.weight(); // update the table m_entries.put(thekey, tempClassDist2); } } } /** * Classifies an instance for 
internal leave one out cross validation of * feature sets * * @param instance instance to be "left out" and classified * @param instA feature values of the selected features for the instance * @return the classification of the instance * @throws Exception if something goes wrong */ protected double evaluateInstanceLeaveOneOut(Instance instance, double[] instA) throws Exception { // System.err.println("---------------- superclass leave-one-out ------------"); DecisionTableHashKey thekey; double[] tempDist; double[] normDist; thekey = new DecisionTableHashKey(instA); if (m_classIsNominal) { // if this one is not in the table if ((tempDist = m_entries.get(thekey)) == null) { throw new Error("This should never happen!"); } else { normDist = new double[tempDist.length]; System.arraycopy(tempDist, 0, normDist, 0, tempDist.length); normDist[(int) instance.classValue()] -= instance.weight(); // update the table // first check to see if the class counts are all zero now boolean ok = false; for (double element : normDist) { if (Utils.gr(element, 1.0)) { ok = true; break; } } // downdate the class prior counts m_classPriorCounts[(int) instance.classValue()] -= instance.weight(); double[] classPriors = m_classPriorCounts.clone(); Utils.normalize(classPriors); if (!ok) { // majority class normDist = classPriors; } m_classPriorCounts[(int) instance.classValue()] += instance.weight(); // if (ok) { Utils.normalize(normDist); if (m_evaluationMeasure == EVAL_AUC) { m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance); } else { m_evaluation.evaluateModelOnce(normDist, instance); } return Utils.maxIndex(normDist); /* * } else { normDist = new double [normDist.length]; * normDist[(int)m_majority] = 1.0; if (m_evaluationMeasure == EVAL_AUC) * { m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, * instance); } else { m_evaluation.evaluateModelOnce(normDist, * instance); } return m_majority; } */ } // return Utils.maxIndex(tempDist); } else { // see if this one 
is already in the table if ((tempDist = m_entries.get(thekey)) != null) { normDist = new double[tempDist.length]; System.arraycopy(tempDist, 0, normDist, 0, tempDist.length); normDist[0] -= (instance.classValue() * instance.weight()); normDist[1] -= instance.weight(); if (Utils.eq(normDist[1], 0.0)) { double[] temp = new double[1]; temp[0] = m_majority; m_evaluation.evaluateModelOnce(temp, instance); return m_majority; } else { double[] temp = new double[1]; temp[0] = normDist[0] / normDist[1]; m_evaluation.evaluateModelOnce(temp, instance); return temp[0]; } } else { throw new Error("This should never happen!"); } } // shouldn't get here // return 0.0; } /** * Calculates the accuracy on a test fold for internal cross validation of * feature sets * * @param fold set of instances to be "left out" and classified * @param fs currently selected feature set * @return the accuracy for the fold * @throws Exception if something goes wrong */ protected double evaluateFoldCV(Instances fold, int[] fs) throws Exception { int i; int numFold = fold.numInstances(); int numCl = m_theInstances.classAttribute().numValues(); double[][] class_distribs = new double[numFold][numCl]; double[] instA = new double[fs.length]; double[] normDist; DecisionTableHashKey thekey; double acc = 0.0; int classI = m_theInstances.classIndex(); Instance inst; if (m_classIsNominal) { normDist = new double[numCl]; } else { normDist = new double[2]; } // first *remove* instances for (i = 0; i < numFold; i++) { inst = fold.instance(i); for (int j = 0; j < fs.length; j++) { if (fs[j] == classI) { instA[j] = Double.MAX_VALUE; // missing for the class } else if (inst.isMissing(fs[j])) { instA[j] = Double.MAX_VALUE; } else { instA[j] = inst.value(fs[j]); } } thekey = new DecisionTableHashKey(instA); if ((class_distribs[i] = m_entries.get(thekey)) == null) { throw new Error("This should never happen!"); } else { if (m_classIsNominal) { class_distribs[i][(int) inst.classValue()] -= inst.weight(); } else { 
class_distribs[i][0] -= (inst.classValue() * inst.weight()); class_distribs[i][1] -= inst.weight(); } } m_classPriorCounts[(int) inst.classValue()] -= inst.weight(); } double[] classPriors = m_classPriorCounts.clone(); Utils.normalize(classPriors); // now classify instances for (i = 0; i < numFold; i++) { inst = fold.instance(i); System.arraycopy(class_distribs[i], 0, normDist, 0, normDist.length); if (m_classIsNominal) { boolean ok = false; for (double element : normDist) { if (Utils.gr(element, 1.0)) { ok = true; break; } } if (!ok) { // majority class normDist = classPriors.clone(); } // if (ok) { Utils.normalize(normDist); if (m_evaluationMeasure == EVAL_AUC) { m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst); } else { m_evaluation.evaluateModelOnce(normDist, inst); } /* * } else { normDist[(int)m_majority] = 1.0; if (m_evaluationMeasure == * EVAL_AUC) { * m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst); } * else { m_evaluation.evaluateModelOnce(normDist, inst); } } */ } else { if (Utils.eq(normDist[1], 0.0)) { double[] temp = new double[1]; temp[0] = m_majority; m_evaluation.evaluateModelOnce(temp, inst); } else { double[] temp = new double[1]; temp[0] = normDist[0] / normDist[1]; m_evaluation.evaluateModelOnce(temp, inst); } } } // now re-insert instances for (i = 0; i < numFold; i++) { inst = fold.instance(i); m_classPriorCounts[(int) inst.classValue()] += inst.weight(); if (m_classIsNominal) { class_distribs[i][(int) inst.classValue()] += inst.weight(); } else { class_distribs[i][0] += (inst.classValue() * inst.weight()); class_distribs[i][1] += inst.weight(); } } return acc; } /** * Evaluates a feature subset by cross validation * * @param feature_set the subset to be evaluated * @param num_atts the number of attributes in the subset * @return the estimated accuracy * @throws Exception if subset can't be evaluated */ protected double estimatePerformance(BitSet feature_set, int num_atts) throws Exception { m_evaluation = 
new Evaluation(m_theInstances); int i; int[] fs = new int[num_atts]; double[] instA = new double[num_atts]; int classI = m_theInstances.classIndex(); int index = 0; for (i = 0; i < m_numAttributes; i++) { if (feature_set.get(i)) { fs[index++] = i; } } // create new hash table m_entries = new Hashtable( (int) (m_theInstances.numInstances() * 1.5)); // insert instances into the hash table for (i = 0; i < m_numInstances; i++) { Instance inst = m_theInstances.instance(i); for (int j = 0; j < fs.length; j++) { if (fs[j] == classI) { instA[j] = Double.MAX_VALUE; // missing for the class } else if (inst.isMissing(fs[j])) { instA[j] = Double.MAX_VALUE; } else { instA[j] = inst.value(fs[j]); } } insertIntoTable(inst, instA); } if (m_CVFolds == 1) { // calculate leave one out error for (i = 0; i < m_numInstances; i++) { Instance inst = m_theInstances.instance(i); for (int j = 0; j < fs.length; j++) { if (fs[j] == classI) { instA[j] = Double.MAX_VALUE; // missing for the class } else if (inst.isMissing(fs[j])) { instA[j] = Double.MAX_VALUE; } else { instA[j] = inst.value(fs[j]); } } evaluateInstanceLeaveOneOut(inst, instA); } } else { m_theInstances.randomize(m_rr); m_theInstances.stratify(m_CVFolds); // calculate 10 fold cross validation error for (i = 0; i < m_CVFolds; i++) { Instances insts = m_theInstances.testCV(m_CVFolds, i); evaluateFoldCV(insts, fs); } } switch (m_evaluationMeasure) { case EVAL_DEFAULT: if (m_classIsNominal) { return m_evaluation.pctCorrect(); } return -m_evaluation.rootMeanSquaredError(); case EVAL_ACCURACY: return m_evaluation.pctCorrect(); case EVAL_RMSE: return -m_evaluation.rootMeanSquaredError(); case EVAL_MAE: return -m_evaluation.meanAbsoluteError(); case EVAL_AUC: double[] classPriors = m_evaluation.getClassPriors(); Utils.normalize(classPriors); double weightedAUC = 0; for (i = 0; i < m_theInstances.classAttribute().numValues(); i++) { double tempAUC = m_evaluation.areaUnderROC(i); if (!Utils.isMissingValue(tempAUC)) { weightedAUC += 
(classPriors[i] * tempAUC); } else { System.err.println("Undefined AUC!!"); } } return weightedAUC; } // shouldn't get here return 0.0; } /** * Resets the options. */ protected void resetOptions() { m_entries = null; m_decisionFeatures = null; m_useIBk = false; m_CVFolds = 1; m_displayRules = false; m_evaluationMeasure = EVAL_DEFAULT; } /** * Constructor for a DecisionTable */ public DecisionTable() { resetOptions(); } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ @Override public Enumeration