All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.meta.MultiClassClassifier Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other breaking updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    MultiClassClassifier.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.meta;

import java.io.Serializable;
import java.util.*;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.rules.ZeroR;
import weka.core.*;
import weka.core.Capabilities.Capability;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.MakeIndicator;
import weka.filters.unsupervised.instance.RemoveWithValues;

/**
 * A metaclassifier for handling multi-class datasets with 2-class classifiers.
 * This classifier is also capable of applying error correcting output codes for increased accuracy.
 * If the base classifier cannot handle instance weights, and the instance weights are not uniform,
 * the data will be resampled with replacement based on the weights before being passed to the base classifier.
 *
 * <p>Valid options are:</p>
 *
 * <pre> -M &lt;num&gt;
 *  Sets the method to use. Valid values are 0 (1-against-all),
 *  1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0)</pre>
 *
 * <pre> -R &lt;num&gt;
 *  Sets the multiplier when using random codes. (default 2.0)</pre>
 *
 * <pre> -P
 *  Use pairwise coupling (only has an effect for 1-against-1)</pre>
 *
 * <pre> -L
 *  Use log loss decoding for random and exhaustive codes.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -W
 *  Full name of base classifier.
 *  (default: weka.classifiers.functions.Logistic)</pre>
 *
 * <p>Options specific to classifier weka.classifiers.functions.Logistic:</p>
 *
 * <pre> -D
 *  Turn on debugging output.</pre>
 *
 * <pre> -R &lt;ridge&gt;
 *  Set the ridge in the log-likelihood.</pre>
 *
 * <pre> -M &lt;number&gt;
 *  Set the maximum number of iterations (default -1, until convergence).</pre>
* * * @author Eibe Frank ([email protected]) * @author Len Trigg ([email protected]) * @author Richard Kirkby ([email protected]) * @version $Revision: 15476 $ */ public class MultiClassClassifier extends RandomizableSingleClassifierEnhancer implements OptionHandler, WeightedInstancesHandler { /** for serialization */ static final long serialVersionUID = -3879602011542849141L; /** The classifiers. */ protected Classifier [] m_Classifiers; /** Use pairwise coupling with 1-vs-1 */ protected boolean m_pairwiseCoupling = false; /** Needed for pairwise coupling */ protected double [] m_SumOfWeights; /** The filters used to transform the class. */ protected Filter[] m_ClassFilters; /** ZeroR classifier for when all base classifier return zero probability. */ private ZeroR m_ZeroR; /** Internal copy of the class attribute for output purposes */ protected Attribute m_ClassAttribute; /** A transformed dataset header used by the 1-against-1 method */ protected Instances m_TwoClassDataset; /** * The multiplier when generating random codes. Will generate * numClasses * m_RandomWidthFactor codes */ private double m_RandomWidthFactor = 2.0; /** True if log loss decoding is to be used for random and exhaustive codes. */ protected boolean m_logLossDecoding = false; /** The multiclass method to use */ protected int m_Method = METHOD_1_AGAINST_ALL; /** 1-against-all */ public static final int METHOD_1_AGAINST_ALL = 0; /** random correction code */ public static final int METHOD_ERROR_RANDOM = 1; /** exhaustive correction code */ public static final int METHOD_ERROR_EXHAUSTIVE = 2; /** 1-against-1 */ public static final int METHOD_1_AGAINST_1 = 3; /** The error correction modes */ public static final Tag [] TAGS_METHOD = { new Tag(METHOD_1_AGAINST_ALL, "1-against-all"), new Tag(METHOD_ERROR_RANDOM, "Random correction code"), new Tag(METHOD_ERROR_EXHAUSTIVE, "Exhaustive correction code"), new Tag(METHOD_1_AGAINST_1, "1-against-1") }; /** * Constructor. 
*/ public MultiClassClassifier() { m_Classifier = new weka.classifiers.functions.Logistic(); } /** * String describing default classifier. * * @return the default classifier classname */ protected String defaultClassifierString() { return "weka.classifiers.functions.Logistic"; } /** * Interface for the code constructors */ private abstract class Code implements Serializable, RevisionHandler { /** for serialization */ static final long serialVersionUID = 418095077487120846L; /** * Subclasses must allocate and fill these. * First dimension is number of codes. * Second dimension is number of classes. */ protected boolean [][] m_Codebits; /** * Returns the number of codes. * @return the number of codes */ public int size() { return m_Codebits.length; } /** * Returns the indices of the values set to true for this code, * using 1-based indexing (for input to Range). * * @param which the index * @return the 1-based indices */ public String getIndices(int which) { StringBuffer sb = new StringBuffer(); for (int i = 0; i < m_Codebits[which].length; i++) { if (m_Codebits[which][i]) { if (sb.length() != 0) { sb.append(','); } sb.append(i + 1); } } return sb.toString(); } /** * Returns a human-readable representation of the codes. * @return a string representation of the codes */ public String toString() { StringBuffer sb = new StringBuffer(); for(int i = 0; i < m_Codebits[0].length; i++) { for (int j = 0; j < m_Codebits.length; j++) { sb.append(m_Codebits[j][i] ? " 1" : " 0"); } sb.append('\n'); } return sb.toString(); } /** * Returns the revision string. 
* * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 15476 $"); } } /** * Constructs a code with no error correction */ private class StandardCode extends Code { /** for serialization */ static final long serialVersionUID = 3707829689461467358L; /** * constructor * * @param numClasses the number of classes */ public StandardCode(int numClasses) { m_Codebits = new boolean[numClasses][numClasses]; for (int i = 0; i < numClasses; i++) { m_Codebits[i][i] = true; } //System.err.println("Code:\n" + this); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 15476 $"); } } /** * Constructs a random code assignment */ private class RandomCode extends Code { /** for serialization */ static final long serialVersionUID = 4413410540703926563L; /** random number generator */ Random r = null; /** * constructor * * @param numClasses the number of classes * @param numCodes the number of codes * @param data the data to use */ public RandomCode(int numClasses, int numCodes, Instances data) { r = data.getRandomNumberGenerator(m_Seed); numCodes = Math.max(2, numCodes); // Need at least two classes m_Codebits = new boolean[numCodes][numClasses]; int i = 0; do { randomize(); //System.err.println(this); } while (!good() && (i++ < 100)); //System.err.println("Code:\n" + this); } private boolean good() { boolean [] ninClass = new boolean[m_Codebits[0].length]; boolean [] ainClass = new boolean[m_Codebits[0].length]; for (int i = 0; i < ainClass.length; i++) { ainClass[i] = true; } for (int i = 0; i < m_Codebits.length; i++) { boolean ninCode = false; boolean ainCode = true; for (int j = 0; j < m_Codebits[i].length; j++) { boolean current = m_Codebits[i][j]; ninCode = ninCode || current; ainCode = ainCode && current; ninClass[j] = ninClass[j] || current; ainClass[j] = ainClass[j] && current; } if (!ninCode || ainCode) { return false; } } for (int j = 0; j < 
ninClass.length; j++) { if (!ninClass[j] || ainClass[j]) { return false; } } return true; } /** * randomizes */ private void randomize() { for (int i = 0; i < m_Codebits.length; i++) { for (int j = 0; j < m_Codebits[i].length; j++) { double temp = r.nextDouble(); m_Codebits[i][j] = (temp < 0.5) ? false : true; } } } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 15476 $"); } } /* * TODO: Constructs codes as per: * Bose, R.C., Ray Chaudhuri (1960), On a class of error-correcting * binary group codes, Information and Control, 3, 68-79. * Hocquenghem, A. (1959) Codes corecteurs d'erreurs, Chiffres, 2, 147-156. */ //private class BCHCode extends Code {...} /** Constructs an exhaustive code assignment */ private class ExhaustiveCode extends Code { /** for serialization */ static final long serialVersionUID = 8090991039670804047L; /** * constructor * * @param numClasses the number of classes */ public ExhaustiveCode(int numClasses) { int width = (int)Math.pow(2, numClasses - 1) - 1; m_Codebits = new boolean[width][numClasses]; for (int j = 0; j < width; j++) { m_Codebits[j][0] = true; } for (int i = 1; i < numClasses; i++) { int skip = (int) Math.pow(2, numClasses - (i + 1)); for(int j = 0; j < width; j++) { m_Codebits[j][i] = ((j / skip) % 2 != 0); } } //System.err.println("Code:\n" + this); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 15476 $"); } } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // class result.disableAllClasses(); result.disableAllClassDependencies(); result.enable(Capability.NOMINAL_CLASS); return result; } /** * Builds the classifiers. * * @param insts the training data. 
* @throws Exception if a classifier can't be built */ public void buildClassifier(Instances insts) throws Exception { Instances newInsts; // can classifier handle the data? getCapabilities().testWithFail(insts); // zero training instances - could be incremental boolean zeroTrainingInstances = insts.numInstances() == 0; // remove instances with missing class insts = new Instances(insts); insts.deleteWithMissingClass(); if (m_Classifier == null) { throw new Exception("No base classifier has been set!"); } if (!insts.allInstanceWeightsIdentical() && !(m_Classifier instanceof WeightedInstancesHandler)) { Random r = (insts.numInstances() > 0) ? insts.getRandomNumberGenerator(getSeed()) : new Random(getSeed()); insts = insts.resampleWithWeights(r); } m_ZeroR = new ZeroR(); m_ZeroR.buildClassifier(insts); m_TwoClassDataset = null; int numClassifiers = insts.numClasses(); if (numClassifiers <= 2) { m_Classifiers = AbstractClassifier.makeCopies(m_Classifier, 1); m_Classifiers[0].buildClassifier(insts); m_ClassFilters = null; } else if (m_Method == METHOD_1_AGAINST_1) { // generate fastvector of pairs ArrayListpairs = new ArrayList(); for (int i=0; i 0 || zeroTrainingInstances) { newInsts.setClassIndex(insts.classIndex()); m_Classifiers[i].buildClassifier(newInsts); m_ClassFilters[i] = classFilter; m_SumOfWeights[i] = newInsts.sumOfWeights(); } else { m_Classifiers[i] = null; m_ClassFilters[i] = null; } } // construct a two-class header version of the dataset m_TwoClassDataset = new Instances(insts, 0); int classIndex = m_TwoClassDataset.classIndex(); m_TwoClassDataset.setClassIndex(-1); ArrayList classLabels = new ArrayList(); classLabels.add("class0"); classLabels.add("class1"); m_TwoClassDataset.replaceAttributeAt(new Attribute("class", classLabels), classIndex); m_TwoClassDataset.setClassIndex(classIndex); } else { // use error correcting code style methods Code code = null; switch (m_Method) { case METHOD_ERROR_EXHAUSTIVE: code = new ExhaustiveCode(numClassifiers); 
break; case METHOD_ERROR_RANDOM: code = new RandomCode(numClassifiers, (int)(numClassifiers * m_RandomWidthFactor), insts); break; case METHOD_1_AGAINST_ALL: code = new StandardCode(numClassifiers); break; default: throw new Exception("Unrecognized correction code type"); } numClassifiers = code.size(); m_Classifiers = AbstractClassifier.makeCopies(m_Classifier, numClassifiers); m_ClassFilters = new MakeIndicator[numClassifiers]; for (int i = 0; i < m_Classifiers.length; i++) { m_ClassFilters[i] = new MakeIndicator(); MakeIndicator classFilter = (MakeIndicator) m_ClassFilters[i]; classFilter.setAttributeIndex("" + (insts.classIndex() + 1)); classFilter.setValueIndices(code.getIndices(i)); classFilter.setNumeric(false); classFilter.setInputFormat(insts); newInsts = Filter.useFilter(insts, m_ClassFilters[i]); m_Classifiers[i].buildClassifier(newInsts); } } m_ClassAttribute = insts.classAttribute(); } /** * Returns the individual predictions of the base classifiers * for an instance. Used by StackedMultiClassClassifier. * Returns the probability for the second "class" predicted * by each base classifier. * * @param inst the instance to get the prediction for * @return the individual predictions * @throws Exception if the predictions can't be computed successfully */ public double[] individualPredictions(Instance inst) throws Exception { double[] result = null; if (m_Classifiers.length == 1) { result = new double[1]; result[0] = m_Classifiers[0].distributionForInstance(inst)[1]; } else { result = new double[m_ClassFilters.length]; for(int i = 0; i < m_ClassFilters.length; i++) { if (m_Classifiers[i] != null) { if (m_Method == METHOD_1_AGAINST_1) { Instance tempInst = (Instance)inst.copy(); tempInst.setDataset(m_TwoClassDataset); result[i] = m_Classifiers[i].distributionForInstance(tempInst)[1]; } else { m_ClassFilters[i].input(inst); m_ClassFilters[i].batchFinished(); result[i] = m_Classifiers[i]. 
distributionForInstance(m_ClassFilters[i].output())[1]; } } } } return result; } /** * Returns the distribution for an instance. * * @param inst the instance to get the distribution for * @return the distribution * @throws Exception if the distribution can't be computed successfully */ public double[] distributionForInstance(Instance inst) throws Exception { if (m_Classifiers.length == 1) { return m_Classifiers[0].distributionForInstance(inst); } double[] probs = new double[inst.numClasses()]; if (m_Method == METHOD_1_AGAINST_1) { double[][] r = new double[inst.numClasses()][inst.numClasses()]; double[][] n = new double[inst.numClasses()][inst.numClasses()]; for(int i = 0; i < m_ClassFilters.length; i++) { if (m_Classifiers[i] != null) { Instance tempInst = (Instance)inst.copy(); tempInst.setDataset(m_TwoClassDataset); double [] current = m_Classifiers[i].distributionForInstance(tempInst); Range range = new Range(((RemoveWithValues)m_ClassFilters[i]).getNominalIndices()); range.setUpper(m_ClassAttribute.numValues()); int[] pair = range.getSelection(); if (m_pairwiseCoupling && inst.numClasses() > 2) { r[pair[0]][pair[1]] = current[0]; n[pair[0]][pair[1]] = m_SumOfWeights[i]; } else { if (current[0] > current[1]) { probs[pair[0]] += 1.0; } else { probs[pair[1]] += 1.0; } } } } if (m_pairwiseCoupling && inst.numClasses() > 2) { return pairwiseCoupling(n, r); } } else if (m_Method == METHOD_1_AGAINST_ALL) { for(int i = 0; i < m_ClassFilters.length; i++) { m_ClassFilters[i].input(inst); m_ClassFilters[i].batchFinished(); probs[i] = m_Classifiers[i].distributionForInstance(m_ClassFilters[i].output())[1]; } } else { if (getLogLossDecoding()) { Arrays.fill(probs, 1.0); for (int i = 0; i < m_ClassFilters.length; i++) { m_ClassFilters[i].input(inst); m_ClassFilters[i].batchFinished(); double[] current = m_Classifiers[i].distributionForInstance(m_ClassFilters[i].output()); for (int j = 0; j < m_ClassAttribute.numValues(); j++) { if (((MakeIndicator) 
m_ClassFilters[i]).getValueRange().isInRange(j)) { probs[j] += Math.log(Utils.SMALL + (1.0 - 2 * Utils.SMALL) * current[1]); } else { probs[j] += Math.log(Utils.SMALL + (1.0 - 2 * Utils.SMALL) * current[0]); } } } probs = Utils.logs2probs(probs); } else { // Use old-style decoding for (int i = 0; i < m_ClassFilters.length; i++) { m_ClassFilters[i].input(inst); m_ClassFilters[i].batchFinished(); double[] current = m_Classifiers[i].distributionForInstance(m_ClassFilters[i].output()); for (int j = 0; j < m_ClassAttribute.numValues(); j++) { if (((MakeIndicator) m_ClassFilters[i]).getValueRange().isInRange(j)) { probs[j] += current[1]; } else { probs[j] += current[0]; } } } } } if (Utils.gr(Utils.sum(probs), 0)) { Utils.normalize(probs); return probs; } else { return m_ZeroR.distributionForInstance(inst); } } /** * Prints the classifiers. * * @return a string representation of the classifier */ public String toString() { if (m_Classifiers == null) { return "MultiClassClassifier: No model built yet."; } StringBuffer text = new StringBuffer(); text.append("MultiClassClassifier\n\n"); for (int i = 0; i < m_Classifiers.length; i++) { text.append("Classifier ").append(i + 1); if (m_Classifiers[i] != null) { if ((m_ClassFilters != null) && (m_ClassFilters[i] != null)) { if (m_ClassFilters[i] instanceof RemoveWithValues) { Range range = new Range(((RemoveWithValues)m_ClassFilters[i]) .getNominalIndices()); range.setUpper(m_ClassAttribute.numValues()); int[] pair = range.getSelection(); text.append(", " + (pair[0]+1) + " vs " + (pair[1]+1)); } else if (m_ClassFilters[i] instanceof MakeIndicator) { text.append(", using indicator values: "); text.append(((MakeIndicator)m_ClassFilters[i]).getValueRange()); } } text.append('\n'); text.append(m_Classifiers[i].toString() + "\n\n"); } else { text.append(" Skipped (no training examples)\n"); } } return text.toString(); } /** * Returns an enumeration describing the available options * * @return an enumeration of all the available 
options */ public Enumeration




© 2015 - 2024 Weber Informatics LLC | Privacy Policy