All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.meta.Vote Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    Vote.java
 *    Copyright (C) 2000-2012 University of Waikato
 *
 */

package weka.classifiers.meta;

import weka.classifiers.Classifier;
import weka.classifiers.RandomizableMultipleClassifiersCombiner;
import weka.core.Aggregateable;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Environment;
import weka.core.EnvironmentHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;

/**
 * Class for combining classifiers. Different
 * combinations of probability estimates for classification are available.
 *
 * For more information see:
 *
 * Ludmila I. Kuncheva (2004). Combining Pattern Classifiers: Methods and
 * Algorithms. John Wiley and Sons, Inc..
 *
 * J. Kittler, M. Hatef, Robert P.W. Duin, J. Matas (1998). On combining
 * classifiers. IEEE Transactions on Pattern Analysis and Machine Intelligence.
 * 20(3):226-239.
 *
 * Valid options are:
 *

 * -P <path to serialized classifier>
 *  Full path to serialized classifier to include.
 *  May be specified multiple times to include
 *  multiple serialized classifiers. Note: it does
 *  not make sense to use pre-built classifiers in
 *  a cross-validation.
 * 
* *
 * -R <AVG|PROD|MAJ|MIN|MAX|MED>
 *  The combination rule to use
 *  (default: AVG)
 * 
* *
 * -do-not-print
 *  Suppress the printing of the individual models in the output
 * 
* *
 * -S <num>
 *  Random number seed.
 *  (default 1)
 * 
* *
 * -B <classifier specification>
 *  Full class name of classifier to include, followed
 *  by scheme options. May be specified multiple times.
 *  (default: "weka.classifiers.rules.ZeroR")
 * 
* *
 * -output-debug-info
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * 
* *
 * -do-not-check-capabilities
 *  If set, classifier capabilities are not checked before classifier is built
 *  (use with caution).
 * 
* *
 * Options specific to classifier weka.classifiers.rules.ZeroR:
 * 
* *
 * -output-debug-info
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * 
* *
 * -do-not-check-capabilities
 *  If set, classifier capabilities are not checked before classifier is built
 *  (use with caution).
 * 
 *
 * BibTeX:
 *
 * @book{Kuncheva2004,
 *    author = {Ludmila I. Kuncheva},
 *    publisher = {John Wiley and Sons, Inc.},
 *    title = {Combining Pattern Classifiers: Methods and Algorithms},
 *    year = {2004}
 * }
 * 
 * @article{Kittler1998,
 *    author = {J. Kittler and M. Hatef and Robert P.W. Duin and J. Matas},
 *    journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
 *    number = {3},
 *    pages = {226-239},
 *    title = {On combining classifiers},
 *    volume = {20},
 *    year = {1998}
 * }
 * 
*

* * * @author Alexander K. Seewald ([email protected]) * @author Eibe Frank ([email protected]) * @author Roberto Perdisci ([email protected]) * @version $Revision: 12424 $ */ public class Vote extends RandomizableMultipleClassifiersCombiner implements TechnicalInformationHandler, EnvironmentHandler, Aggregateable { /** for serialization */ static final long serialVersionUID = -637891196294399624L; /** combination rule: Average of Probabilities */ public static final int AVERAGE_RULE = 1; /** combination rule: Product of Probabilities (only nominal classes) */ public static final int PRODUCT_RULE = 2; /** combination rule: Majority Voting (only nominal classes) */ public static final int MAJORITY_VOTING_RULE = 3; /** combination rule: Minimum Probability */ public static final int MIN_RULE = 4; /** combination rule: Maximum Probability */ public static final int MAX_RULE = 5; /** combination rule: Median Probability (only numeric class) */ public static final int MEDIAN_RULE = 6; /** combination rules */ public static final Tag[] TAGS_RULES = { new Tag(AVERAGE_RULE, "AVG", "Average of Probabilities"), new Tag(PRODUCT_RULE, "PROD", "Product of Probabilities"), new Tag(MAJORITY_VOTING_RULE, "MAJ", "Majority Voting"), new Tag(MIN_RULE, "MIN", "Minimum Probability"), new Tag(MAX_RULE, "MAX", "Maximum Probability"), new Tag(MEDIAN_RULE, "MED", "Median") }; /** Combination Rule variable */ protected int m_CombinationRule = AVERAGE_RULE; /** List of file paths to serialized models to load */ protected List m_classifiersToLoad = new ArrayList(); /** List of de-serialized pre-built classifiers to include in the ensemble */ protected List m_preBuiltClassifiers = new ArrayList(); /** Environment variables */ protected transient Environment m_env = Environment.getSystemWide(); /** Structure of the training data */ protected Instances m_structure; /** Print the individual models in the output */ protected boolean m_dontPrintModels; /** * Returns a string describing classifier * 
* The text is a fixed description followed by the string form of the
   * references reported by getTechnicalInformation().
   *
   * @return a description suitable for displaying in the explorer/experimenter
   *         gui
   */
  public String globalInfo() {
    return "Class for combining classifiers. Different combinations of "
      + "probability estimates for classification are available.\n\n"
      + "For more information see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  @Override
  public Enumeration

* * Valid options are: *

* *

   * -P <path to serialized classifier>
   *  Full path to serialized classifier to include.
   *  May be specified multiple times to include
   *  multiple serialized classifiers. Note: it does
   *  not make sense to use pre-built classifiers in
   *  a cross-validation.
   * 
* *
   * -R <AVG|PROD|MAJ|MIN|MAX|MED>
   *  The combination rule to use
   *  (default: AVG)
   * 
* *
   * -do-not-print
   *  Suppress the printing of the individual models in the output
   * 
* *
   * -S <num>
   *  Random number seed.
   *  (default 1)
   * 
* *
   * -B <classifier specification>
   *  Full class name of classifier to include, followed
   *  by scheme options. May be specified multiple times.
   *  (default: "weka.classifiers.rules.ZeroR")
   * 
* *
   * -output-debug-info
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console
   * 
* *
   * -do-not-check-capabilities
   *  If set, classifier capabilities are not checked before classifier is built
   *  (use with caution).
   * 
* *
   * Options specific to classifier weka.classifiers.rules.ZeroR:
   * 
* *
   * -output-debug-info
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console
   * 
* *
   * -do-not-check-capabilities
   *  If set, classifier capabilities are not checked before classifier is built
   *  (use with caution).
   * 
* * * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ @Override public void setOptions(String[] options) throws Exception { String tmpStr; tmpStr = Utils.getOption('R', options); if (tmpStr.length() != 0) { setCombinationRule(new SelectedTag(tmpStr, TAGS_RULES)); } else { setCombinationRule(new SelectedTag(AVERAGE_RULE, TAGS_RULES)); } m_classifiersToLoad.clear(); while (true) { String loadString = Utils.getOption('P', options); if (loadString.length() == 0) { break; } m_classifiersToLoad.add(loadString); } setDoNotPrintModels(Utils.getFlag("-do-not-print", options)); super.setOptions(options); } /** * Returns an instance of a TechnicalInformation object, containing detailed * information about the technical background of this class, e.g., paper * reference or book this class is based on. * * @return the technical information about this class */ @Override public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; TechnicalInformation additional; result = new TechnicalInformation(Type.BOOK); result.setValue(Field.AUTHOR, "Ludmila I. Kuncheva"); result.setValue(Field.TITLE, "Combining Pattern Classifiers: Methods and Algorithms"); result.setValue(Field.YEAR, "2004"); result.setValue(Field.PUBLISHER, "John Wiley and Sons, Inc."); additional = result.add(Type.ARTICLE); additional.setValue(Field.AUTHOR, "J. Kittler and M. Hatef and Robert P.W. Duin and J. Matas"); additional.setValue(Field.YEAR, "1998"); additional.setValue(Field.TITLE, "On combining classifiers"); additional.setValue(Field.JOURNAL, "IEEE Transactions on Pattern Analysis and Machine Intelligence"); additional.setValue(Field.VOLUME, "20"); additional.setValue(Field.NUMBER, "3"); additional.setValue(Field.PAGES, "226-239"); return result; } /** * Returns default capabilities of the classifier. 
* * @return the capabilities of this classifier */ @Override public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); if (m_preBuiltClassifiers.size() == 0 && m_classifiersToLoad.size() > 0) { try { loadClassifiers(null); } catch (Exception e) { e.printStackTrace(); } } if (m_preBuiltClassifiers.size() > 0) { if (m_Classifiers.length == 0) { result = (Capabilities) m_preBuiltClassifiers.get(0).getCapabilities().clone(); } for (int i = 1; i < m_preBuiltClassifiers.size(); i++) { result.and(m_preBuiltClassifiers.get(i).getCapabilities()); } for (Capability cap : Capability.values()) { result.enableDependency(cap); } } // class if ((m_CombinationRule == PRODUCT_RULE) || (m_CombinationRule == MAJORITY_VOTING_RULE)) { result.disableAllClasses(); result.disableAllClassDependencies(); result.enable(Capability.NOMINAL_CLASS); result.enableDependency(Capability.NOMINAL_CLASS); } else if (m_CombinationRule == MEDIAN_RULE) { result.disableAllClasses(); result.disableAllClassDependencies(); result.enable(Capability.NUMERIC_CLASS); result.enableDependency(Capability.NUMERIC_CLASS); } return result; } /** * Buildclassifier selects a classifier from the set of classifiers by * minimising error on the training data. * * @param data the training data to be used for generating the boosted * classifier. * @throws Exception if the classifier could not be built successfully */ @Override public void buildClassifier(Instances data) throws Exception { // remove instances with missing class Instances newData = new Instances(data); newData.deleteWithMissingClass(); m_structure = new Instances(newData, 0); if (m_classifiersToLoad.size() > 0) { m_preBuiltClassifiers.clear(); loadClassifiers(data); if (m_Classifiers.length == 1 && m_Classifiers[0] instanceof weka.classifiers.rules.ZeroR) { // remove the single ZeroR m_Classifiers = new Classifier[0]; } } // can classifier handle the data? 
getCapabilities().testWithFail(data); for (int i = 0; i < m_Classifiers.length; i++) { getClassifier(i).buildClassifier(newData); } } /** * Load serialized models to include in the ensemble * * @param data training instances (used in a header compatibility check with * each of the loaded models) * * @throws Exception if there is a problem de-serializing a model */ private void loadClassifiers(Instances data) throws Exception { for (String path : m_classifiersToLoad) { if (Environment.containsEnvVariables(path)) { try { path = m_env.substitute(path); } catch (Exception ex) { } } File toLoad = new File(path); if (!toLoad.isFile()) { throw new Exception("\"" + path + "\" does not seem to be a valid file!"); } ObjectInputStream is = new ObjectInputStream(new BufferedInputStream(new FileInputStream( toLoad))); Object c = is.readObject(); if (!(c instanceof Classifier)) { is.close(); throw new Exception("\"" + path + "\" does not contain a classifier!"); } Object header = null; header = is.readObject(); if (header instanceof Instances) { if (data != null && !data.equalHeaders((Instances) header)) { is.close(); throw new Exception("\"" + path + "\" was trained with data that is " + "of a differnet structure than the incoming training data"); } } if (header == null) { System.out.println("[Vote] warning: no header instances for \"" + path + "\""); } is.close(); addPreBuiltClassifier((Classifier) c); } } /** * Add a prebuilt classifier to the list for use in the ensemble * * @param c a prebuilt Classifier to add. */ public void addPreBuiltClassifier(Classifier c) { m_preBuiltClassifiers.add(c); } /** * Remove a prebuilt classifier from the list to use in the ensemble * * @param c the classifier to remove */ public void removePreBuiltClassifier(Classifier c) { m_preBuiltClassifiers.remove(c); } /** * Classifies the given test instance. 
* * @param instance the instance to be classified * @return the predicted most likely class for the instance or * Utils.missingValue() if no prediction is made * @throws Exception if an error occurred during the prediction */ @Override public double classifyInstance(Instance instance) throws Exception { double result; double[] dist; int index; switch (m_CombinationRule) { case AVERAGE_RULE: case PRODUCT_RULE: case MAJORITY_VOTING_RULE: case MIN_RULE: case MAX_RULE: dist = distributionForInstance(instance); if (instance.classAttribute().isNominal()) { index = Utils.maxIndex(dist); if (dist[index] == 0) { result = Utils.missingValue(); } else { result = index; } } else if (instance.classAttribute().isNumeric()) { result = dist[0]; } else { result = Utils.missingValue(); } break; case MEDIAN_RULE: result = classifyInstanceMedian(instance); break; default: throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!"); } return result; } /** * Classifies the given test instance, returning the median from all * classifiers. Can assume that class is numeric. 
* * @param instance the instance to be classified * @return the predicted most likely class for the instance or * Utils.missingValue() if no prediction is made * @throws Exception if an error occurred during the prediction */ protected double classifyInstanceMedian(Instance instance) throws Exception { double[] results = new double[m_Classifiers.length + m_preBuiltClassifiers.size()]; int numResults = 0; for (Classifier m_Classifier : m_Classifiers) { double pred = m_Classifier.classifyInstance(instance); if (!Utils.isMissingValue(pred)) { results[numResults++] = pred; } } for (int i = 0; i < m_preBuiltClassifiers.size(); i++) { double pred = m_preBuiltClassifiers.get(i).classifyInstance(instance); if (!Utils.isMissingValue(pred)) { results[numResults++] = pred; } } if (numResults == 0) { return Utils.missingValue(); } else if (numResults == 1) { return results[0]; } else { double[] actualResults = new double[numResults]; System.arraycopy(results, 0, actualResults, 0, numResults); return Utils.kthSmallestValue(actualResults, actualResults.length / 2); } } /** * Classifies a given instance using the selected combination rule. 
* * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified successfully */ @Override public double[] distributionForInstance(Instance instance) throws Exception { double[] result = new double[instance.numClasses()]; switch (m_CombinationRule) { case AVERAGE_RULE: result = distributionForInstanceAverage(instance); break; case PRODUCT_RULE: result = distributionForInstanceProduct(instance); break; case MAJORITY_VOTING_RULE: result = distributionForInstanceMajorityVoting(instance); break; case MIN_RULE: result = distributionForInstanceMin(instance); break; case MAX_RULE: result = distributionForInstanceMax(instance); break; case MEDIAN_RULE: result[0] = classifyInstance(instance); break; default: throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!"); } if (!instance.classAttribute().isNumeric() && (Utils.sum(result) > 0)) { Utils.normalize(result); } return result; } /** * Classifies a given instance using the Average of Probabilities combination * rule. 
* * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified successfully */ protected double[] distributionForInstanceAverage(Instance instance) throws Exception { double[] probs = new double[instance.numClasses()]; double numPredictions = 0; for (int i = 0; i < m_Classifiers.length; i++) { double[] dist = getClassifier(i).distributionForInstance(instance); if (!instance.classAttribute().isNumeric() || !Utils.isMissingValue(dist[0])) { for (int j = 0; j < dist.length; j++) { probs[j] += dist[j]; } numPredictions++; } } for (int i = 0; i < m_preBuiltClassifiers.size(); i++) { double[] dist = m_preBuiltClassifiers.get(i).distributionForInstance(instance); if (!instance.classAttribute().isNumeric() || !Utils.isMissingValue(dist[0])) { for (int j = 0; j < dist.length; j++) { probs[j] += dist[j]; } numPredictions++; } } if (instance.classAttribute().isNumeric()) { if (numPredictions == 0) { probs[0] = Utils.missingValue(); } else { for (int j = 0; j < probs.length; j++) { probs[j] /= numPredictions; } } } else { // Should normalize "probability" distribution if (Utils.sum(probs) > 0) { Utils.normalize(probs); } } return probs; } /** * Classifies a given instance using the Product of Probabilities combination * rule. Can assume that class is nominal. 
* * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified successfully */ protected double[] distributionForInstanceProduct(Instance instance) throws Exception { double[] probs = new double[instance.numClasses()]; for (int i = 0; i < probs.length; i++) { probs[i] = 1.0; } int numPredictions = 0; for (int i = 0; i < m_Classifiers.length; i++) { double[] dist = getClassifier(i).distributionForInstance(instance); if (Utils.sum(dist) > 0) { for (int j = 0; j < dist.length; j++) { probs[j] *= dist[j]; } numPredictions++; } } for (int i = 0; i < m_preBuiltClassifiers.size(); i++) { double[] dist = m_preBuiltClassifiers.get(i).distributionForInstance(instance); if (Utils.sum(dist) > 0) { for (int j = 0; j < dist.length; j++) { probs[j] *= dist[j]; } numPredictions++; } } // No predictions? if (numPredictions == 0) { return new double[instance.numClasses()]; } // Should normalize to get "probabilities" if (Utils.sum(probs) > 0) { Utils.normalize(probs); } return probs; } /** * Classifies a given instance using the Majority Voting combination rule. Can * assume that class is nominal. 
* * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified successfully */ protected double[] distributionForInstanceMajorityVoting(Instance instance) throws Exception { double[] probs = new double[instance.classAttribute().numValues()]; double[] votes = new double[probs.length]; for (int i = 0; i < m_Classifiers.length; i++) { probs = getClassifier(i).distributionForInstance(instance); int maxIndex = 0; for (int j = 0; j < probs.length; j++) { if (probs[j] > probs[maxIndex]) { maxIndex = j; } } // Consider the cases when multiple classes happen to have the same // probability if (probs[maxIndex] > 0) { for (int j = 0; j < probs.length; j++) { if (probs[j] == probs[maxIndex]) { votes[j]++; } } } } for (int i = 0; i < m_preBuiltClassifiers.size(); i++) { probs = m_preBuiltClassifiers.get(i).distributionForInstance(instance); int maxIndex = 0; for (int j = 0; j < probs.length; j++) { if (probs[j] > probs[maxIndex]) { maxIndex = j; } } // Consider the cases when multiple classes happen to have the same // probability if (probs[maxIndex] > 0) { for (int j = 0; j < probs.length; j++) { if (probs[j] == probs[maxIndex]) { votes[j]++; } } } } int tmpMajorityIndex = 0; for (int k = 1; k < votes.length; k++) { if (votes[k] > votes[tmpMajorityIndex]) { tmpMajorityIndex = k; } } // No votes received if (votes[tmpMajorityIndex] == 0) { return new double[instance.numClasses()]; } // Consider the cases when multiple classes receive the same amount of votes Vector majorityIndexes = new Vector(); for (int k = 0; k < votes.length; k++) { if (votes[k] == votes[tmpMajorityIndex]) { majorityIndexes.add(k); } } int majorityIndex = tmpMajorityIndex; if (majorityIndexes.size() > 1) { // resolve ties by looking at the predicted distribution double[] distPreds = distributionForInstanceAverage(instance); majorityIndex = Utils.maxIndex(distPreds); // Resolve the ties according to a uniform random distribution // 
majorityIndex = majorityIndexes.get(m_Random.nextInt(majorityIndexes.size())); } // set probs to 0 probs = new double[probs.length]; probs[majorityIndex] = 1; // the class that have been voted the most // receives 1 return probs; } /** * Classifies a given instance using the Maximum Probability combination rule. * * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified successfully */ protected double[] distributionForInstanceMax(Instance instance) throws Exception { double[] probs = new double[instance.numClasses()]; double numPredictions = 0; for (int i = 0; i < m_Classifiers.length; i++) { double[] dist = getClassifier(i).distributionForInstance(instance); if (!instance.classAttribute().isNumeric() || !Utils.isMissingValue(dist[0])) { for (int j = 0; j < dist.length; j++) { if ((probs[j] < dist[j]) || (numPredictions == 0)) { probs[j] = dist[j]; } } numPredictions++; } } for (int i = 0; i < m_preBuiltClassifiers.size(); i++) { double[] dist = m_preBuiltClassifiers.get(i).distributionForInstance(instance); if (!instance.classAttribute().isNumeric() || !Utils.isMissingValue(dist[0])) { for (int j = 0; j < dist.length; j++) { if ((probs[j] < dist[j]) || (numPredictions == 0)) { probs[j] = dist[j]; } } numPredictions++; } } if (instance.classAttribute().isNumeric()) { if (numPredictions == 0) { probs[0] = Utils.missingValue(); } } else { // Should normalize "probability" distribution if (Utils.sum(probs) > 0) { Utils.normalize(probs); } } return probs; } /** * Classifies a given instance using the Minimum Probability combination rule. 
* * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified successfully */ protected double[] distributionForInstanceMin(Instance instance) throws Exception { double[] probs = new double[instance.numClasses()]; double numPredictions = 0; for (int i = 0; i < m_Classifiers.length; i++) { double[] dist = getClassifier(i).distributionForInstance(instance); if (!instance.classAttribute().isNumeric() || !Utils.isMissingValue(dist[0])) { for (int j = 0; j < dist.length; j++) { if ((probs[j] > dist[j]) || (numPredictions == 0)) { probs[j] = dist[j]; } } numPredictions++; } } for (int i = 0; i < m_preBuiltClassifiers.size(); i++) { double[] dist = m_preBuiltClassifiers.get(i).distributionForInstance(instance); if (!instance.classAttribute().isNumeric() || !Utils.isMissingValue(dist[0])) { for (int j = 0; j < dist.length; j++) { if ((probs[j] > dist[j]) || (numPredictions == 0)) { probs[j] = dist[j]; } } numPredictions++; } } if (instance.classAttribute().isNumeric()) { if (numPredictions == 0) { probs[0] = Utils.missingValue(); } } else { // Should normalize "probability" distribution if (Utils.sum(probs) > 0) { Utils.normalize(probs); } } return probs; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String combinationRuleTipText() { return "The combination rule used."; } /** * Gets the combination rule used * * @return the combination rule used */ public SelectedTag getCombinationRule() { return new SelectedTag(m_CombinationRule, TAGS_RULES); } /** * Sets the combination rule to use. 
Values other than * * @param newRule the combination rule method to use */ public void setCombinationRule(SelectedTag newRule) { if (newRule.getTags() == TAGS_RULES) { m_CombinationRule = newRule.getSelectedTag().getID(); } } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String preBuiltClassifiersTipText() { return "The pre-built serialized classifiers to include. Multiple " + "serialized classifiers can be included alongside those " + "that are built from scratch when this classifier runs. " + "Note that it does not make sense to include pre-built " + "classifiers in a cross-validation since they are static " + "and their models do not change from fold to fold."; } /** * Set the paths to pre-built serialized classifiers to load and include in * the ensemble * * @param preBuilt an array of File paths to serialized models */ public void setPreBuiltClassifiers(File[] preBuilt) { m_classifiersToLoad.clear(); if (preBuilt != null && preBuilt.length > 0) { for (File element : preBuilt) { String path = element.toString(); m_classifiersToLoad.add(path); } } } /** * Get the paths to pre-built serialized classifiers to load and include in * the ensemble * * @return an array of File paths to serialized models */ public File[] getPreBuiltClassifiers() { File[] result = new File[m_classifiersToLoad.size()]; for (int i = 0; i < m_classifiersToLoad.size(); i++) { result[i] = new File(m_classifiersToLoad.get(i)); } return result; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String doNotPrintModelsTipText() { return "Do not print the individual trees in the output"; } /** * Set whether to print the individual ensemble models in the output * * @param print true if the individual models are to be printed */ public void setDoNotPrintModels(boolean print) { 
m_dontPrintModels = print; } /** * Get whether to print the individual ensemble models in the output * * @return true if the individual models are to be printed */ public boolean getDoNotPrintModels() { return m_dontPrintModels; } /** * Output a representation of this classifier * * @return a string representation of the classifier */ @Override public String toString() { if (m_Classifiers == null) { return "Vote: No model built yet."; } String result = "Vote combines"; result += " the probability distributions of these base learners:\n"; for (int i = 0; i < m_Classifiers.length; i++) { result += '\t' + getClassifierSpec(i) + '\n'; } for (Classifier c : m_preBuiltClassifiers) { result += "\t" + c.getClass().getName() + Utils.joinOptions(((OptionHandler) c).getOptions()) + "\n"; } result += "using the '"; switch (m_CombinationRule) { case AVERAGE_RULE: result += "Average"; break; case PRODUCT_RULE: result += "Product"; break; case MAJORITY_VOTING_RULE: result += "Majority Voting"; break; case MIN_RULE: result += "Minimum"; break; case MAX_RULE: result += "Maximum"; break; case MEDIAN_RULE: result += "Median"; break; default: throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!"); } result += "' combination rule \n"; StringBuilder resultBuilder = null; if (!m_dontPrintModels) { resultBuilder = new StringBuilder(); resultBuilder.append(result).append("\nAll the models:\n\n"); for (Classifier c : m_Classifiers) { resultBuilder.append(c).append("\n"); } for (Classifier c : m_preBuiltClassifiers) { resultBuilder.append(c).append("\n"); } } return resultBuilder == null ? result : resultBuilder.toString(); } /** * Returns the revision string. 
*
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 12424 $");
  }

  /**
   * Set environment variable values to substitute in the paths of serialized
   * models to load
   *
   * @param env the environment variables to use
   */
  @Override
  public void setEnvironment(Environment env) {
    m_env = env;
  }

  /**
   * Aggregate an object with this one. The supplied classifier is added to
   * the list of pre-built classifiers as is.
   *
   * @param toAggregate the object to aggregate
   * @return the result of aggregation (this Vote instance)
   * @throws Exception if the supplied object can't be aggregated for some
   *           reason
   */
  @Override
  public Classifier aggregate(Classifier toAggregate) throws Exception {

    if (m_structure == null && m_Classifiers.length == 1
      && (m_Classifiers[0] instanceof weka.classifiers.rules.ZeroR)) {
      // remove the single untrained ZeroR
      setClassifiers(new Classifier[0]);
    }

    // Can't do any training data compatibility checks unfortunately
    addPreBuiltClassifier(toAggregate);

    return this;
  }

  /**
   * Call to complete the aggregation process. Allows implementers to do any
   * final processing based on how many objects were aggregated.
   *
   * @throws Exception if the aggregation can't be finalized for some reason
   */
  @Override
  public void finalizeAggregation() throws Exception {
    // nothing to do
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments: -t training file [-T
   *          test file] [-c class index]
   */
  public static void main(String[] argv) {
    runClassifier(new Vote(), argv);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy