weka.classifiers.rules.DTNB Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-stable Show documentation
Show all versions of weka-stable Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version. Apart from bugfixes, this version
does not receive any other updates.
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* DTNB.java
* Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.rules;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.ASSearch;
import weka.attributeSelection.SubsetEvaluator;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.TechnicalInformation;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import java.util.BitSet;
import java.util.Enumeration;
import java.util.Vector;
/**
*
* Class for building and using a decision table/naive bayes hybrid classifier. At each point in the search, the algorithm evaluates the merit of dividing the attributes into two disjoint subsets: one for the decision table, the other for naive Bayes. A forward selection search is used, where at each step, selected attributes are modeled by naive Bayes and the remainder by the decision table, and all attributes are modelled by the decision table initially. At each step, the algorithm also considers dropping an attribute entirely from the model.
*
* For more information, see:
*
* Mark Hall, Eibe Frank: Combining Naive Bayes and Decision Tables. In: Proceedings of the 21st Florida Artificial Intelligence Society Conference (FLAIRS), 318-319, 2008.
*
*
* BibTeX:
*
* @inproceedings{Hall2008,
* author = {Mark Hall and Eibe Frank},
* booktitle = {Proceedings of the 21st Florida Artificial Intelligence Society Conference (FLAIRS)},
* pages = {318-319},
* publisher = {AAAI press},
* title = {Combining Naive Bayes and Decision Tables},
* year = {2008}
* }
*
*
*
* Valid options are:
*
* -X <number of folds>
* Use cross validation to evaluate features.
* Use number of folds = 1 for leave one out CV.
* (Default = leave one out CV)
*
* -E <acc | rmse | mae | auc>
* Performance evaluation measure to use for selecting attributes.
* (Default = accuracy for discrete class and rmse for numeric class)
*
* -I
* Use nearest neighbour instead of global table majority.
*
* -R
* Display decision table rules.
*
*
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}org)
* @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
*
* @version $Revision: 6269 $
*
*/
public class DTNB extends DecisionTable {
/**
* The naive Bayes half of the hybrid. It is trained in buildClassifier on
* the full training data, with the weights of the decision-table attributes
* and of the completely deleted attributes set to 0.
*/
protected NaiveBayes m_NB;
/**
* The features used by naive Bayes.
* NOTE(review): not assigned anywhere in the visible code (the assignments
* in buildClassifier are commented out) -- confirm whether it is still needed.
*/
private int [] m_nbFeatures;
/**
* Fraction (not a 0-100 percentage) of the non-class attributes used by the
* decision table. Computed in buildClassifier.
*/
private double m_percentUsedByDT;
/**
* Fraction (not a 0-100 percentage) of the non-class attributes that were
* dropped from the model entirely. Computed in buildClassifier.
*/
private double m_percentDeleted;
/** for serialization */
static final long serialVersionUID = 2999557077765701326L;
/**
* Describes this classifier for the Explorer/Experimenter GUI. The
* description is followed by the bibliographic reference produced by
* getTechnicalInformation().
*
* @return the human-readable description of DTNB
*/
public String globalInfo() {
  String summary =
      "Class for building and using a decision table/naive bayes hybrid classifier. At each point "
      + "in the search, the algorithm evaluates the merit of dividing the attributes into two disjoint "
      + "subsets: one for the decision table, the other for naive Bayes. A forward selection search is "
      + "used, where at each step, selected attributes are modeled by naive Bayes and the remainder "
      + "by the decision table, and all attributes are modelled by the decision table initially. At each "
      + "step, the algorithm also considers dropping an attribute entirely from the model.\n\n";
  String reference = getTechnicalInformation().toString();
  return summary + "For more information, see: \n\n" + reference;
}
/**
* Builds the bibliographic reference for the paper that introduced DTNB
* (Hall and Frank, FLAIRS 2008).
*
* @return the technical information about this class
*/
public TechnicalInformation getTechnicalInformation() {
  TechnicalInformation info = new TechnicalInformation(Type.INPROCEEDINGS);
  info.setValue(Field.AUTHOR, "Mark Hall and Eibe Frank");
  info.setValue(Field.TITLE, "Combining Naive Bayes and Decision Tables");
  info.setValue(Field.BOOKTITLE,
      "Proceedings of the 21st Florida Artificial Intelligence Society Conference (FLAIRS)");
  info.setValue(Field.YEAR, "2008");
  info.setValue(Field.PAGES, "318-319");
  info.setValue(Field.PUBLISHER, "AAAI press");
  return info;
}
/**
* Calculates the accuracy on a test fold for internal cross validation
* of feature sets
*
* @param fold set of instances to be "left out" and classified
* @param fs currently selected feature set
* @return the accuracy for the fold
* @throws Exception if something goes wrong
*/
double evaluateFoldCV(Instances fold, int [] fs) throws Exception {
int i;
int ruleCount = 0;
int numFold = fold.numInstances();
int numCl = m_theInstances.classAttribute().numValues();
double [][] class_distribs = new double [numFold][numCl];
double [] instA = new double [fs.length];
double [] normDist;
DecisionTableHashKey thekey;
double acc = 0.0;
int classI = m_theInstances.classIndex();
Instance inst;
if (m_classIsNominal) {
normDist = new double [numCl];
} else {
normDist = new double [2];
}
// first *remove* instances
for (i=0;i= temp_merit) {
temp_merit = temp_merit_delete;
deleteBetter = true;
}
z = (temp_merit >= temp_best);
if (z) {
temp_best = temp_merit;
temp_index = i;
addone = true;
done = false;
if (deleteBetter) {
deleted = true;
} else {
deleted = false;
}
}
// unset this addition/deletion
temp_group.set(i);
}
}
if (addone) {
best_group.clear(temp_index);
best_merit = temp_best;
if (deleted) {
// ((EvalWithDelete)m_evaluator).getDeletedList().set(temp_index);
((EvalWithDelete)eval).getDeletedList().set(temp_index);
}
//System.err.println("----------------------");
//System.err.println("Best subset: (dec table)" + best_group);
//System.err.println("Best subset: (deleted)" + ((EvalWithDelete)m_evaluator).getDeletedList());
//System.err.println(best_merit);
}
}
return attributeList(best_group);
}
/**
* Converts a BitSet into a sorted array of attribute indexes.
* Only bits below m_numAttributes are reported. Any bit at or beyond that
* bound is ignored.
*
* Attributes removed from DTNB entirely are not filtered here. They are
* excluded later in buildClassifier via the evaluator's deleted list.
*
* @param group the BitSet to convert (not modified)
* @return the indexes of the set bits, in ascending order
**/
protected int[] attributeList (BitSet group) {
  // First pass: count the set bits in range, so the array can be sized exactly.
  int count = 0;
  for (int i = group.nextSetBit(0); i >= 0 && i < m_numAttributes;
       i = group.nextSetBit(i + 1)) {
    count++;
  }
  // Second pass: record the indexes themselves.
  int[] list = new int[count];
  int pos = 0;
  for (int i = group.nextSetBit(0); i >= 0 && i < m_numAttributes;
       i = group.nextSetBit(i + 1)) {
    list[pos++] = i;
  }
  return list;
}
/**
* Reports the source-control revision of this class.
*
* @return the revision extracted from the revision tag
*/
public String getRevision() {
  String revisionTag = "$Revision: 6269 $";
  return RevisionUtils.extract(revisionTag);
}
}
/**
* Creates the BackwardsWithDelete search that DTNB always uses.
*/
private void setUpSearch() {
  this.m_backwardWithDelete = new BackwardsWithDelete();
}
/**
* Generates the classifier.
*
* The superclass builds the decision table using the BackwardsWithDelete
* search. The attributes chosen for the table, and those deleted from the
* model entirely, are then given weight 0, and naive Bayes is trained on the
* training instances so it only uses the remaining attributes.
*
* @param data set of instances serving as training data
* @throws Exception if the class is numeric or the classifier has not been
*                   generated successfully
*/
public void buildClassifier(Instances data) throws Exception {
  m_saveMemory = false;
  if (data.classAttribute().isNumeric()) {
    throw new Exception("Can only handle nominal class!");
  }
  // DTNB always uses its own search; create and install it lazily.
  if (m_backwardWithDelete == null) {
    setUpSearch();
    m_search = m_backwardWithDelete;
  }
  super.buildClassifier(data);

  int numAtts = m_theInstances.numAttributes();
  int classIndex = m_theInstances.classIndex();

  // Start with every attribute fully visible to naive Bayes.
  for (int a = 0; a < numAtts; a++) {
    m_theInstances.attribute(a).setWeight(1.0);
  }

  // Hide the decision table's features (never the class) from naive Bayes.
  int dtCount = 0;
  for (int f = 0; f < m_decisionFeatures.length; f++) {
    int att = m_decisionFeatures[f];
    if (att != classIndex) {
      dtCount++;
      m_theInstances.attribute(att).setWeight(0.0);
    }
  }

  // Also hide attributes that the search dropped from the model entirely.
  double deletedCount = 0;
  BitSet deleted = ((EvalWithDelete) m_evaluator).getDeletedList();
  for (int a = 0; a < numAtts; a++) {
    if (deleted.get(a)) {
      m_theInstances.attribute(a).setWeight(0.0);
      deletedCount++;
    }
  }

  // Both statistics are fractions of the non-class attributes.
  m_percentUsedByDT = (double) dtCount / (numAtts - 1);
  m_percentDeleted = deletedCount / (numAtts - 1);

  m_NB = new NaiveBayes();
  m_NB.buildClassifier(m_theInstances);

  // Keep only empty (capacity 0) copies of the datasets; presumably this
  // releases the training data after the models are built -- confirm.
  m_dtInstances = new Instances(m_dtInstances, 0);
  m_theInstances = new Instances(m_theInstances, 0);
}
/**
* Calculates the class membership probabilities for the given test instance.
*
* Steps:
* 1. Obtain the decision-table distribution: the normalised table entry; or,
*    for an unseen key, the nearest-neighbour prediction (if enabled) or the
*    class priors.
* 2. Divide it by the class priors and multiply by the naive Bayes
*    distribution. This is done in log space.
* 3. Convert back to probabilities and normalise.
*
* @param instance the instance to be classified
* @return predicted class probability distribution
* @exception Exception if distribution can't be computed
*/
public double [] distributionForInstance(Instance instance)
throws Exception {
  // Apply the training-time discretization.
  m_disTransform.input(instance);
  m_disTransform.batchFinished();
  Instance discretized = m_disTransform.output();

  // Reduce to the decision table's attributes to form the lookup key.
  m_delTransform.input(discretized);
  m_delTransform.batchFinished();
  Instance dtInstance = m_delTransform.output();
  DecisionTableHashKey key =
      new DecisionTableHashKey(dtInstance, dtInstance.numAttributes(), false);

  double [] dtDist = (double []) m_entries.get(key);
  if (dtDist != null) {
    // Normalise a copy so the stored table entry is left untouched.
    dtDist = dtDist.clone();
    Utils.normalize(dtDist);
  } else if (m_useIBk) {
    dtDist = m_ibk.distributionForInstance(dtInstance);
  } else {
    // Unseen key: use the priors. They cancel out below, so naive Bayes
    // alone decides the prediction.
    dtDist = m_classPriors.clone();
  }

  double [] nbDist = m_NB.distributionForInstance(discretized);
  for (int c = 0; c < nbDist.length; c++) {
    dtDist[c] = (Math.log(dtDist[c]) - Math.log(m_classPriors[c]))
        + Math.log(nbDist[c]);
  }

  double [] result = Utils.logs2probs(dtDist);
  Utils.normalize(result);
  return result;
}
/**
* Returns a textual description of the model: the decision table output,
* followed by the naive Bayes model when rule display is enabled and naive
* Bayes has been built.
*
* @return a description of the classifier
*/
public String toString() {
  String tableText = super.toString();
  boolean appendNB = m_displayRules && m_NB != null;
  return appendNB ? tableText + m_NB.toString() : tableText;
}
/**
* Returns the fraction of the non-class attributes used by the decision
* table, as computed in buildClassifier. The value is a fraction, not a
* 0-100 percentage, and is 0 before the classifier has been built.
*
* @return the fraction of non-class attributes modelled by the decision table
*/
public double measurePercentAttsUsedByDT() {
  return m_percentUsedByDT;
}
/**
* Lists the names of the additional measures that getMeasure accepts.
*
* @return an enumeration of the measure names
*/
public Enumeration enumerateMeasures() {
  Vector measures = new Vector(2);
  measures.add("measureNumRules");
  measures.add("measurePercentAttsUsedByDT");
  return measures.elements();
}
/**
* Returns the value of the named measure. Names are matched
* case-insensitively.
*
* @param additionalMeasureName the name of the measure to query for its value
* @return the value of the named measure
* @throws IllegalArgumentException if the named measure is not supported
*/
public double getMeasure(String additionalMeasureName) {
  if (additionalMeasureName.equalsIgnoreCase("measureNumRules")) {
    return measureNumRules();
  } else if (additionalMeasureName.equalsIgnoreCase("measurePercentAttsUsedByDT")) {
    return measurePercentAttsUsedByDT();
  } else {
    // Report this class (not the superclass) as the one rejecting the measure.
    throw new IllegalArgumentException(additionalMeasureName
        + " not supported (DTNB)");
  }
}
/**
* Returns the decision table's default capabilities, minus numeric and date
* classes.
*
* @return the capabilities of this classifier
*/
public Capabilities getCapabilities() {
  Capabilities caps = super.getCapabilities();
  // DTNB only handles nominal class attributes.
  caps.disable(Capability.NUMERIC_CLASS);
  caps.disable(Capability.DATE_CLASS);
  return caps;
}
/**
* Ignores the supplied search method. DTNB always uses its own
* BackwardsWithDelete search, so the search cannot be changed.
*
* @param search the requested search method (ignored)
*/
public void setSearch(ASSearch search) {
  // Intentionally a no-op.
}
/**
* Gets the current search method. On first use, the BackwardsWithDelete
* search is created and installed as the active search.
*
* @return the search method used
*/
public ASSearch getSearch() {
  if (m_backwardWithDelete != null) {
    return m_search;
  }
  setUpSearch();
  m_search = m_backwardWithDelete;
  return m_search;
}
/**
* Returns an enumeration describing the available options.
* The -S (search) option of the superclass is not offered, because DTNB
* always uses its own search.
*
* @return an enumeration of all the available options.
*/
public Enumeration listOptions() {
  Vector newVector = new Vector(4);
  // Each synopsis names its argument, matching the documented option list.
  newVector.addElement(new Option(
      "\tUse cross validation to evaluate features.\n" +
      "\tUse number of folds = 1 for leave one out CV.\n" +
      "\t(Default = leave one out CV)",
      "X", 1, "-X <number of folds>"));
  newVector.addElement(new Option(
      "\tPerformance evaluation measure to use for selecting attributes.\n" +
      "\t(Default = accuracy for discrete class and rmse for numeric class)",
      "E", 1, "-E <acc | rmse | mae | auc>"));
  newVector.addElement(new Option(
      "\tUse nearest neighbour instead of global table majority.",
      "I", 0, "-I"));
  newVector.addElement(new Option(
      "\tDisplay decision table rules.\n",
      "R", 0, "-R"));
  return newVector.elements();
}
/**
* Parses the options for this object, after first resetting all options to
* their defaults.
*
* Valid options are:
*
* -X <number of folds>
* Use cross validation to evaluate features.
* Use number of folds = 1 for leave one out CV.
* (Default = leave one out CV)
*
* -E <acc | rmse | mae | auc>
* Performance evaluation measure to use for selecting attributes.
* (Default = accuracy for discrete class and rmse for numeric class)
*
* -I
* Use nearest neighbour instead of global table majority.
*
* -R
* Display decision table rules.
*
* @param options the list of options as an array of strings
* @throws Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
  resetOptions();

  String folds = Utils.getOption('X', options);
  if (folds.length() != 0) {
    setCrossVal(Integer.parseInt(folds));
  }

  m_useIBk = Utils.getFlag('I', options);
  m_displayRules = Utils.getFlag('R', options);

  String measure = Utils.getOption('E', options);
  if (measure.length() == 0) {
    return;
  }
  int tag;
  if ("acc".equals(measure)) {
    tag = EVAL_ACCURACY;
  } else if ("rmse".equals(measure)) {
    tag = EVAL_RMSE;
  } else if ("mae".equals(measure)) {
    tag = EVAL_MAE;
  } else if ("auc".equals(measure)) {
    tag = EVAL_AUC;
  } else {
    throw new IllegalArgumentException("Invalid evaluation measure");
  }
  setEvaluationMeasure(new SelectedTag(tag, TAGS_EVALUATION));
}
/**
* Gets the current settings of the classifier. The returned array always has
* length 9; unused trailing slots hold empty strings.
*
* @return an array of strings suitable for passing to setOptions
*/
public String [] getOptions() {
  String [] result = new String [9];
  int pos = 0;

  result[pos++] = "-X";
  result[pos++] = "" + getCrossVal();

  if (m_evaluationMeasure != EVAL_DEFAULT) {
    result[pos++] = "-E";
    String measureName = null;
    switch (m_evaluationMeasure) {
      case EVAL_ACCURACY:
        measureName = "acc";
        break;
      case EVAL_RMSE:
        measureName = "rmse";
        break;
      case EVAL_MAE:
        measureName = "mae";
        break;
      case EVAL_AUC:
        measureName = "auc";
        break;
    }
    if (measureName != null) {
      result[pos++] = measureName;
    }
  }

  if (m_useIBk) {
    result[pos++] = "-I";
  }
  if (m_displayRules) {
    result[pos++] = "-R";
  }

  for (; pos < result.length; pos++) {
    result[pos] = "";
  }
  return result;
}
/**
* Reports the source-control revision of this class.
*
* @return the revision extracted from the revision tag
*/
public String getRevision() {
  String revisionTag = "$Revision: 6269 $";
  return RevisionUtils.extract(revisionTag);
}
/**
* Command-line entry point. Runs DTNB through the standard Weka classifier
* runner.
*
* @param argv the command-line options
*/
public static void main(String [] argv) {
  DTNB classifier = new DTNB();
  runClassifier(classifier, argv);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy