All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.rules.OneR Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see .
 */

/*
 *    OneR.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.rules;

import java.io.Serializable;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedList;
import java.util.ListIterator;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Sourcable;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WekaException;

/**
 *  Class for building and using a 1R classifier; in
 * other words, uses the minimum-error attribute for prediction, discretizing
 * numeric attributes. For more information, see:
*
* R.C. Holte (1993). Very simple classification rules perform well on most * commonly used datasets. Machine Learning. 11:63-91. *

* * * BibTeX: * *

 * @article{Holte1993,
 *    author = {R.C. Holte},
 *    journal = {Machine Learning},
 *    pages = {63-91},
 *    title = {Very simple classification rules perform well on most commonly used datasets},
 *    volume = {11},
 *    year = {1993}
 * }
 * 
*

* * * Valid options are: *

* *

 * -B <minimum bucket size>
 *  The minimum number of objects in a bucket (default: 6).
 * 
* * * * @author Ian H. Witten ([email protected]) * @version $Revision: 10153 $ */ public class OneR extends AbstractClassifier implements TechnicalInformationHandler, Sourcable { /** for serialization */ static final long serialVersionUID = -3459427003147861443L; /** * Returns a string describing classifier * * @return a description suitable for displaying in the explorer/experimenter * gui */ public String globalInfo() { return "Class for building and using a 1R classifier; in other words, uses " + "the minimum-error attribute for prediction, discretizing numeric " + "attributes. For more information, see:\n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing detailed * information about the technical background of this class, e.g., paper * reference or book this class is based on. * * @return the technical information about this class */ @Override public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.ARTICLE); result.setValue(Field.AUTHOR, "R.C. Holte"); result.setValue(Field.YEAR, "1993"); result .setValue(Field.TITLE, "Very simple classification rules perform well on most commonly used datasets"); result.setValue(Field.JOURNAL, "Machine Learning"); result.setValue(Field.VOLUME, "11"); result.setValue(Field.PAGES, "63-91"); return result; } /** * Class for storing store a 1R rule. */ private class OneRRule implements Serializable, RevisionHandler { /** for serialization */ static final long serialVersionUID = 2252814630957092281L; /** The class attribute. */ private final Attribute m_class; /** The number of instances used for building the rule. */ private final int m_numInst; /** Attribute to test */ private final Attribute m_attr; /** Training set examples this rule gets right */ private int m_correct; /** Predicted class for each value of attr */ private final int[] m_classifications; /** Predicted class for missing values */ private int m_missingValueClass = -1; /** Breakpoints (numeric attributes only) */ private double[] m_breakpoints; /** * Constructor for nominal attribute. * * @param data the data to work with * @param attribute the attribute to use * @throws Exception if something goes wrong */ public OneRRule(Instances data, Attribute attribute) throws Exception { m_class = data.classAttribute(); m_numInst = data.numInstances(); m_attr = attribute; m_correct = 0; m_classifications = new int[m_attr.numValues()]; } /** * Constructor for numeric attribute. * * @param data the data to work with * @param attribute the attribute to use * @param nBreaks the break point * @throws Exception if something goes wrong */ public OneRRule(Instances data, Attribute attribute, int nBreaks) throws Exception { m_class = data.classAttribute(); m_numInst = data.numInstances(); m_attr = attribute; m_correct = 0; m_classifications = new int[nBreaks]; m_breakpoints = new double[nBreaks - 1]; // last breakpoint is infinity } /** * Returns a description of the rule. * * @return a string representation of the rule */ @Override public String toString() { try { StringBuffer text = new StringBuffer(); text.append(m_attr.name() + ":\n"); for (int v = 0; v < m_classifications.length; v++) { text.append("\t"); if (m_attr.isNominal()) { text.append(m_attr.value(v)); } else if (v < m_breakpoints.length) { text.append("< " + m_breakpoints[v]); } else if (v > 0) { text.append(">= " + m_breakpoints[v - 1]); } else { text.append("not ?"); } text.append("\t-> " + m_class.value(m_classifications[v]) + "\n"); } if (m_missingValueClass != -1) { text.append("\t?\t-> " + m_class.value(m_missingValueClass) + "\n"); } text .append("(" + m_correct + "/" + m_numInst + " instances correct)\n"); return text.toString(); } catch (Exception e) { return "Can't print OneR classifier!"; } } /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 10153 $"); } } /** A 1-R rule */ private OneRRule m_rule; /** The minimum bucket size */ private int m_minBucketSize = 6; /** a ZeroR model in case no model can be built from the data */ private Classifier m_ZeroR; /** * Classifies a given instance. * * @param inst the instance to be classified * @return the classification of the instance */ @Override public double classifyInstance(Instance inst) throws Exception { // default model? if (m_ZeroR != null) { return m_ZeroR.classifyInstance(inst); } int v = 0; if (inst.isMissing(m_rule.m_attr)) { if (m_rule.m_missingValueClass != -1) { return m_rule.m_missingValueClass; } else { return 0; // missing values occur in test but not training set } } if (m_rule.m_attr.isNominal()) { v = (int) inst.value(m_rule.m_attr); } else { while (v < m_rule.m_breakpoints.length && inst.value(m_rule.m_attr) >= m_rule.m_breakpoints[v]) { v++; } } return m_rule.m_classifications[v]; } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ @Override public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); return result; } /** * Generates the classifier. * * @param instances the instances to be used for building the classifier * @throws Exception if the classifier can't be built successfully */ @Override public void buildClassifier(Instances instances) throws Exception { boolean noRule = true; // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class Instances data = new Instances(instances); data.deleteWithMissingClass(); // only class? -> build ZeroR model if (data.numAttributes() == 1) { System.err .println("Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_ZeroR = new weka.classifiers.rules.ZeroR(); m_ZeroR.buildClassifier(data); return; } else { m_ZeroR = null; } // for each attribute ... Enumeration enu = instances.enumerateAttributes(); while (enu.hasMoreElements()) { try { OneRRule r = newRule(enu.nextElement(), data); // if this attribute is the best so far, replace the rule if (noRule || r.m_correct > m_rule.m_correct) { m_rule = r; } noRule = false; } catch (Exception ex) { } } if (noRule) { throw new WekaException("No attributes found to work with!"); } } /** * Create a rule branching on this attribute. * * @param attr the attribute to branch on * @param data the data to be used for creating the rule * @return the generated rule * @throws Exception if the rule can't be built successfully */ public OneRRule newRule(Attribute attr, Instances data) throws Exception { OneRRule r; // ... create array to hold the missing value counts int[] missingValueCounts = new int[data.classAttribute().numValues()]; if (attr.isNominal()) { r = newNominalRule(attr, data, missingValueCounts); } else { r = newNumericRule(attr, data, missingValueCounts); } r.m_missingValueClass = Utils.maxIndex(missingValueCounts); if (missingValueCounts[r.m_missingValueClass] == 0) { r.m_missingValueClass = -1; // signal for no missing value class } else { r.m_correct += missingValueCounts[r.m_missingValueClass]; } return r; } /** * Create a rule branching on this nominal attribute. * * @param attr the attribute to branch on * @param data the data to be used for creating the rule * @param missingValueCounts to be filled in * @return the generated rule * @throws Exception if the rule can't be built successfully */ public OneRRule newNominalRule(Attribute attr, Instances data, int[] missingValueCounts) throws Exception { // ... create arrays to hold the counts int[][] counts = new int[attr.numValues()][data.classAttribute() .numValues()]; // ... calculate the counts Enumeration enu = data.enumerateInstances(); while (enu.hasMoreElements()) { Instance i = enu.nextElement(); if (i.isMissing(attr)) { missingValueCounts[(int) i.classValue()]++; } else { counts[(int) i.value(attr)][(int) i.classValue()]++; } } OneRRule r = new OneRRule(data, attr); // create a new rule for (int value = 0; value < attr.numValues(); value++) { int best = Utils.maxIndex(counts[value]); r.m_classifications[value] = best; r.m_correct += counts[value][best]; } return r; } /** * Create a rule branching on this numeric attribute * * @param attr the attribute to branch on * @param data the data to be used for creating the rule * @param missingValueCounts to be filled in * @return the generated rule * @throws Exception if the rule can't be built successfully */ public OneRRule newNumericRule(Attribute attr, Instances data, int[] missingValueCounts) throws Exception { // make a copy before sorting so that ties are treated consistently // and aren't affected by sorting performed for any numeric // attributes processed before this one data = new Instances(data); int lastInstance = data.numInstances(); // missing values get sorted to the end of the instances data.sort(attr); while (lastInstance > 0 && data.instance(lastInstance - 1).isMissing(attr)) { lastInstance--; missingValueCounts[(int) data.instance(lastInstance).classValue()]++; } if (lastInstance == 0) { throw new Exception("Only missing values in the training data!"); } // gather class distributions for all values double lastValue = 0; LinkedList distributions = new LinkedList(); LinkedList values = new LinkedList(); int[] distribution = null; for (int i = 0; i < lastInstance; i++) { // new value? if ((i == 0) || (data.instance(i).value(attr) > lastValue)) { if (i != 0) { values.add((lastValue + data.instance(i).value(attr)) / 2.0); } lastValue = data.instance(i).value(attr); distribution = new int[data.numClasses()]; distributions.add(distribution); } distribution[(int) data.instance(i).classValue()]++; } values.add(Double.MAX_VALUE); // create iterator to go through list ListIterator it = distributions.listIterator(); ListIterator itVals = values.listIterator(); int[] oldDist = null; while (it.hasNext()) { // grab next trivial bucket and iterate to next value as well int[] newDist = it.next(); itVals.next(); // should we merge the two buckets? if ((oldDist != null) && // classes the same? ((Utils.maxIndex(newDist) == Utils.maxIndex(oldDist)) || // bucket not large enough? (oldDist[Utils.maxIndex(oldDist)] < m_minBucketSize))) { // add counts for (int j = 0; j < oldDist.length; j++) { newDist[j] += oldDist[j]; } // remove distribution it.previous(); // element just visited it.previous(); // previous element we want to remove it.remove(); it.next(); // back to element just visited // remove value itVals.previous(); // element just visited itVals.previous(); // previous element we want to remove itVals.remove(); itVals.next(); // back to element just visited } // make progress oldDist = newDist; } // last scan, merge adjacent intervals with same class and calculate correct // classifications int numCorrect = 0; it = distributions.listIterator(); itVals = values.listIterator(); oldDist = null; while (it.hasNext()) { // grab next trivial bucket and iterate to next value as well int[] newDist = it.next(); itVals.next(); // number of correct classifications does not change by merging numCorrect += newDist[Utils.maxIndex(newDist)]; // should we merge the two buckets? if ((oldDist != null) && // classes the same? (Utils.maxIndex(newDist) == Utils.maxIndex(oldDist))) { // add counts for (int j = 0; j < oldDist.length; j++) { newDist[j] += oldDist[j]; } // remove distribution it.previous(); // element just visited it.previous(); // previous element we want to remove it.remove(); it.next(); // back to element just visited // remove value itVals.previous(); // element just visited itVals.previous(); // previous element we want to remove itVals.remove(); itVals.next(); // back to element just visited } // make progress oldDist = newDist; } OneRRule r = new OneRRule(data, attr, distributions.size()); // new rule // with cl // branches r.m_correct = numCorrect; it = distributions.listIterator(); itVals = values.listIterator(); int v = 0; while (it.hasNext()) { r.m_classifications[v] = Utils.maxIndex(it.next()); double splitPoint = itVals.next(); if (itVals.hasNext()) { r.m_breakpoints[v] = splitPoint; } v++; } return r; } /** * Returns an enumeration describing the available options.. * * @return an enumeration of all the available options. */ @Override public Enumeration




© 2015 - 2024 Weber Informatics LLC | Privacy Policy