All downloads are free. Search and download functionality uses the official Maven repository.

weka.classifiers.meta.CVParameterSelection Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    CVParameterSelection.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.meta;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Summarizable;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

import java.io.Serializable;
import java.io.StreamTokenizer;
import java.io.StringReader;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 * Class for performing parameter selection by cross-validation for any
 * classifier.<p/>
 *
 * For more information, see:<p/>
 *
 * R. Kohavi (1995). Wrappers for Performance Enhancement and Oblivious
 * Decision Graphs. Department of Computer Science, Stanford University.<p/>
 *
 * BibTeX:
 * <pre>
 * &#64;phdthesis{Kohavi1995,
 *    address = {Department of Computer Science, Stanford University},
 *    author = {R. Kohavi},
 *    school = {Stanford University},
 *    title = {Wrappers for Performance Enhancement and Oblivious Decision Graphs},
 *    year = {1995}
 * }
 * </pre>
 * <p/>
 *
 * Valid options are: <p/>
 *
 * <pre> -X &lt;number of folds&gt;
 *  Number of folds used for cross validation (default 10).</pre>
 *
 * <pre> -P &lt;classifier parameter&gt;
 *  Classifier parameter options.
 *  eg: "N 1 5 10" Sets an optimisation parameter for the
 *  classifier with name -N, with lower bound 1, upper bound
 *  5, and 10 optimisation steps. The upper bound may be the
 *  character 'A' or 'I' to substitute the number of
 *  attributes or instances in the training data,
 *  respectively. This parameter may be supplied more than
 *  once to optimise over several classifier options
 *  simultaneously.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -W
 *  Full name of base classifier.
 *  (default: weka.classifiers.rules.ZeroR)</pre>
 *
 * <pre>
 * Options specific to classifier weka.classifiers.rules.ZeroR:
 * </pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * Options after -- are passed to the designated sub-classifier. <p/>
 *
* * @author Len Trigg ([email protected]) * @version $Revision: 8180 $ */ public class CVParameterSelection extends RandomizableSingleClassifierEnhancer implements Drawable, Summarizable, TechnicalInformationHandler { /** for serialization */ static final long serialVersionUID = -6529603380876641265L; /** * A data structure to hold values associated with a single * cross-validation search parameter */ protected class CVParameter implements Serializable, RevisionHandler { /** for serialization */ static final long serialVersionUID = -4668812017709421953L; /** Char used to identify the option of interest */ private String m_ParamChar; /** Lower bound for the CV search */ private double m_Lower; /** Upper bound for the CV search */ private double m_Upper; /** Number of steps during the search */ private double m_Steps; /** The parameter value with the best performance */ private double m_ParamValue; /** True if the parameter should be added at the end of the argument list */ private boolean m_AddAtEnd; /** True if the parameter should be rounded to an integer */ private boolean m_RoundParam; /** * Constructs a CVParameter. 
* * @param param the parameter definition * @throws Exception if construction of CVParameter fails */ public CVParameter(String param) throws Exception { String[] parts = param.split(" "); if (parts.length < 4 || parts.length > 5) { throw new Exception("CVParameter " + param + ": four or five components expected!"); } try { Double.parseDouble(parts[0]); throw new Exception("CVParameter " + param + ": Character parameter identifier expected"); } catch (NumberFormatException n) { m_ParamChar = parts[0]; } try { m_Lower = Double.parseDouble(parts[1]); } catch (NumberFormatException n) { throw new Exception("CVParameter " + param + ": Numeric lower bound expected"); } if (parts[2].equals("A")) { m_Upper = m_Lower - 1; } else if (parts[2].equals("I")) { m_Upper = m_Lower - 2; } else { try { m_Upper = Double.parseDouble(parts[2]); if (m_Upper < m_Lower) { throw new Exception("CVParameter " + param + ": Upper bound is less than lower bound"); } } catch (NumberFormatException n) { throw new Exception("CVParameter " + param + ": Upper bound must be numeric, or 'A' or 'N'"); } } try { m_Steps = Double.parseDouble(parts[3]); } catch (NumberFormatException n) { throw new Exception("CVParameter " + param + ": Numeric number of steps expected"); } if (parts.length == 5 && parts[4].equals("R")) { m_RoundParam = true; } } /** * Returns a CVParameter as a string. * * @return the CVParameter as string */ public String toString() { String result = m_ParamChar + " " + m_Lower + " "; switch ((int)(m_Lower - m_Upper + 0.5)) { case 1: result += "A"; break; case 2: result += "I"; break; default: result += m_Upper; break; } result += " " + m_Steps; if (m_RoundParam) { result += " R"; } return result; } /** * Returns the revision string. 
* * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 8180 $"); } } /** * The base classifier options (not including those being set * by cross-validation) */ protected String [] m_ClassifierOptions; /** The set of all classifier options as determined by cross-validation */ protected String [] m_BestClassifierOptions; /** The set of all options at initialization time. So that getOptions can return this. */ protected String [] m_InitOptions; /** The cross-validated performance of the best options */ protected double m_BestPerformance; /** The set of parameters to cross-validate over */ protected FastVector m_CVParams = new FastVector(); /** The number of attributes in the data */ protected int m_NumAttributes; /** The number of instances in a training fold */ protected int m_TrainFoldSize; /** The number of folds used in cross-validation */ protected int m_NumFolds = 10; /** * Create the options array to pass to the classifier. The parameter * values and positions are taken from m_ClassifierOptions and * m_CVParams. * * @return the options array */ protected String [] createOptions() { String [] options = new String [m_ClassifierOptions.length + 2 * m_CVParams.size()]; int start = 0, end = options.length; // Add the cross-validation parameters and their values for (int i = 0; i < m_CVParams.size(); i++) { CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i); double paramValue = cvParam.m_ParamValue; if (cvParam.m_RoundParam) { // paramValue = (double)((int) (paramValue + 0.5)); paramValue = Math.rint(paramValue); } boolean isInt = ((paramValue - (int)paramValue) == 0); if (cvParam.m_AddAtEnd) { options[--end] = "" + ((cvParam.m_RoundParam || isInt) ? Utils.doubleToString(paramValue,4) : cvParam.m_ParamValue); //Utils.doubleToString(paramValue,4); options[--end] = "-" + cvParam.m_ParamChar; } else { options[start++] = "-" + cvParam.m_ParamChar; options[start++] = "" + ((cvParam.m_RoundParam || isInt) ? 
Utils.doubleToString(paramValue,4) : cvParam.m_ParamValue); //+ Utils.doubleToString(paramValue,4); } } // Add the static parameters System.arraycopy(m_ClassifierOptions, 0, options, start, m_ClassifierOptions.length); return options; } /** * Finds the best parameter combination. (recursive for each parameter * being optimised). * * @param depth the index of the parameter to be optimised at this level * @param trainData the data the search is based on * @param random a random number generator * @throws Exception if an error occurs */ protected void findParamsByCrossValidation(int depth, Instances trainData, Random random) throws Exception { if (depth < m_CVParams.size()) { CVParameter cvParam = (CVParameter)m_CVParams.elementAt(depth); double upper; switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) { case 1: upper = m_NumAttributes; break; case 2: upper = m_TrainFoldSize; break; default: upper = cvParam.m_Upper; break; } double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1); for(cvParam.m_ParamValue = cvParam.m_Lower; cvParam.m_ParamValue <= upper; cvParam.m_ParamValue += increment) { findParamsByCrossValidation(depth + 1, trainData, random); } } else { Evaluation evaluation = new Evaluation(trainData); // Set the classifier options String [] options = createOptions(); if (m_Debug) { System.err.print("Setting options for " + m_Classifier.getClass().getName() + ":"); for (int i = 0; i < options.length; i++) { System.err.print(" " + options[i]); } System.err.println(""); } ((OptionHandler)m_Classifier).setOptions(options); for (int j = 0; j < m_NumFolds; j++) { // We want to randomize the data the same way for every // learning scheme. 
Instances train = trainData.trainCV(m_NumFolds, j, new Random(1)); Instances test = trainData.testCV(m_NumFolds, j); m_Classifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(m_Classifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4)); } if ((m_BestPerformance == -99) || (error < m_BestPerformance)) { m_BestPerformance = error; m_BestClassifierOptions = createOptions(); } } } /** * Returns a string describing this classifier * @return a description of the classifier suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Class for performing parameter selection by cross-validation " + "for any classifier.\n\n" + "For more information, see:\n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. * * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.PHDTHESIS); result.setValue(Field.AUTHOR, "R. Kohavi"); result.setValue(Field.YEAR, "1995"); result.setValue(Field.TITLE, "Wrappers for Performance Enhancement and Oblivious Decision Graphs"); result.setValue(Field.SCHOOL, "Stanford University"); result.setValue(Field.ADDRESS, "Department of Computer Science, Stanford University"); return result; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. 
*/
  public Enumeration listOptions() {
    Vector newVector = new Vector(2);

    // Synopses carry their argument placeholder, matching the option
    // listing in the class documentation.
    newVector.addElement(new Option(
        "\tNumber of folds used for cross validation (default 10).",
        "X", 1, "-X <number of folds>"));
    newVector.addElement(new Option(
        "\tClassifier parameter options.\n"
        + "\teg: \"N 1 5 10\" Sets an optimisation parameter for the\n"
        + "\tclassifier with name -N, with lower bound 1, upper bound\n"
        + "\t5, and 10 optimisation steps. The upper bound may be the\n"
        + "\tcharacter 'A' or 'I' to substitute the number of\n"
        + "\tattributes or instances in the training data,\n"
        + "\trespectively. This parameter may be supplied more than\n"
        + "\tonce to optimise over several classifier options\n"
        + "\tsimultaneously.",
        "P", 1, "-P <classifier parameter>"));

    // Append the options inherited from the enhancer (seed, debug, base
    // classifier and its own options).
    Enumeration enu = super.listOptions();
    while (enu.hasMoreElements()) {
      newVector.addElement(enu.nextElement());
    }
    return newVector.elements();
  }

  /**
   * Parses a given list of options.

   * <p/>
   *
   * Valid options are: <p/>
   *
   * <pre> -X &lt;number of folds&gt;
   *  Number of folds used for cross validation (default 10).</pre>
   *
   * <pre> -P &lt;classifier parameter&gt;
   *  Classifier parameter options.
   *  eg: "N 1 5 10" Sets an optimisation parameter for the
   *  classifier with name -N, with lower bound 1, upper bound
   *  5, and 10 optimisation steps. The upper bound may be the
   *  character 'A' or 'I' to substitute the number of
   *  attributes or instances in the training data,
   *  respectively. This parameter may be supplied more than
   *  once to optimise over several classifier options
   *  simultaneously.</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   * <pre> -W
   *  Full name of base classifier.
   *  (default: weka.classifiers.rules.ZeroR)</pre>
   *
   * <pre>
   * Options specific to classifier weka.classifiers.rules.ZeroR:
   * </pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   * Options after -- are passed to the designated sub-classifier. <p/>
   *
* * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String foldsString = Utils.getOption('X', options); if (foldsString.length() != 0) { setNumFolds(Integer.parseInt(foldsString)); } else { setNumFolds(10); } String cvParam; m_CVParams = new FastVector(); do { cvParam = Utils.getOption('P', options); if (cvParam.length() != 0) { addCVParameter(cvParam); } } while (cvParam.length() != 0); super.setOptions(options); } /** * Gets the current settings of the Classifier. * * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String[] superOptions; if (m_InitOptions != null) { try { m_Classifier.setOptions((String[])m_InitOptions.clone()); superOptions = super.getOptions(); m_Classifier.setOptions((String[])m_BestClassifierOptions.clone()); } catch (Exception e) { throw new RuntimeException("CVParameterSelection: could not set options " + "in getOptions()."); } } else { superOptions = super.getOptions(); } String [] options = new String [superOptions.length + m_CVParams.size() * 2 + 2]; int current = 0; for (int i = 0; i < m_CVParams.size(); i++) { options[current++] = "-P"; options[current++] = "" + getCVParameter(i); } options[current++] = "-X"; options[current++] = "" + getNumFolds(); System.arraycopy(superOptions, 0, options, current, superOptions.length); return options; } /** * Returns (a copy of) the best options found for the classifier. * * @return the best options */ public String[] getBestClassifierOptions() { return (String[]) m_BestClassifierOptions.clone(); } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.setMinimumNumberInstances(m_NumFolds); return result; } /** * Generates the classifier. 
* * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class Instances trainData = new Instances(instances); trainData.deleteWithMissingClass(); if (!(m_Classifier instanceof OptionHandler)) { throw new IllegalArgumentException("Base classifier should be OptionHandler."); } m_InitOptions = ((OptionHandler)m_Classifier).getOptions(); m_BestPerformance = -99; m_NumAttributes = trainData.numAttributes(); Random random = new Random(m_Seed); trainData.randomize(random); m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances(); // Check whether there are any parameters to optimize if (m_CVParams.size() == 0) { m_Classifier.buildClassifier(trainData); m_BestClassifierOptions = m_InitOptions; return; } if (trainData.classAttribute().isNominal()) { trainData.stratify(m_NumFolds); } m_BestClassifierOptions = null; // Set up m_ClassifierOptions -- take getOptions() and remove // those being optimised. m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions(); for (int i = 0; i < m_CVParams.size(); i++) { Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar, m_ClassifierOptions); } findParamsByCrossValidation(0, trainData, random); String [] options = (String [])m_BestClassifierOptions.clone(); ((OptionHandler)m_Classifier).setOptions(options); m_Classifier.buildClassifier(trainData); } /** * Predicts the class distribution for the given test instance. 
* * @param instance the instance to be classified * @return the predicted class value * @throws Exception if an error occurred during the prediction */ public double[] distributionForInstance(Instance instance) throws Exception { return m_Classifier.distributionForInstance(instance); } /** * Adds a scheme parameter to the list of parameters to be set * by cross-validation * * @param cvParam the string representation of a scheme parameter. The * format is:
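   * <pre>
   * param_char lower_bound upper_bound number_of_steps
   * </pre>
   * The upper bound may instead be 'A' or 'I' (substituted during the search
   * with the number of attributes or the number of instances in a training
   * fold). An optional fifth component "R" rounds the candidate values to the
   * nearest integer.
   * E.g. to search a parameter -P from 1 to 10 by increments of 1:
   * <pre>
   * P 1 10 10
   * </pre>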
* @throws Exception if the parameter specifier is of the wrong format */ public void addCVParameter(String cvParam) throws Exception { CVParameter newCV = new CVParameter(cvParam); m_CVParams.addElement(newCV); } /** * Gets the scheme paramter with the given index. * * @param index the index for the parameter * @return the scheme parameter */ public String getCVParameter(int index) { if (m_CVParams.size() <= index) { return ""; } return ((CVParameter)m_CVParams.elementAt(index)).toString(); } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String CVParametersTipText() { return "Sets the scheme parameters which are to be set "+ "by cross-validation.\n"+ "The format for each string should be:\n"+ "param_char lower_bound upper_bound number_of_steps\n"+ "eg to search a parameter -P from 1 to 10 by increments of 1:\n"+ " \"P 1 10 10\" "; } /** * Get method for CVParameters. * * @return the CVParameters */ public Object[] getCVParameters() { Object[] CVParams = m_CVParams.toArray(); String params[] = new String[CVParams.length]; for(int i=0; i





© 2015 - 2025 Weber Informatics LLC | Privacy Policy