weka.classifiers.meta.CVParameterSelection Maven / Gradle / Ivy
Show all versions of weka-stable Show documentation
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* CVParameterSelection.java
* Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.meta;
import java.io.Serializable;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Summarizable;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
/**
* Class for performing parameter selection by cross-validation for any classifier.
*
* For more information, see:
*
* R. Kohavi (1995). Wrappers for Performance Enhancement and Oblivious Decision Graphs. Department of Computer Science, Stanford University.
*
*
* BibTeX:
*
* @phdthesis{Kohavi1995,
* address = {Department of Computer Science, Stanford University},
* author = {R. Kohavi},
* school = {Stanford University},
* title = {Wrappers for Performance Enhancement and Oblivious Decision Graphs},
* year = {1995}
* }
*
*
*
* Valid options are:
*
* -X <number of folds>
* Number of folds used for cross validation (default 10).
*
* -P <classifier parameter>
* Classifier parameter options.
* eg: "N 1 5 10" Sets an optimisation parameter for the
* classifier with name -N, with lower bound 1, upper bound
* 5, and 10 optimisation steps. The upper bound may be the
* character 'A' or 'I' to substitute the number of
* attributes or instances in the training data,
* respectively. This parameter may be supplied more than
* once to optimise over several classifier options
* simultaneously.
*
* -S <num>
* Random number seed.
* (default 1)
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -W
* Full name of base classifier.
* (default: weka.classifiers.rules.ZeroR)
*
*
* Options specific to classifier weka.classifiers.rules.ZeroR:
*
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
*
* Options after -- are passed to the designated sub-classifier.
*
* @author Len Trigg ([email protected])
* @version $Revision: 13370 $
*/
public class CVParameterSelection
extends RandomizableSingleClassifierEnhancer
implements Drawable, Summarizable, TechnicalInformationHandler {
/** Serialization version identifier for CVParameterSelection. */
static final long serialVersionUID = -6529603380876641265L;
/**
* A data structure to hold values associated with a single
* cross-validation search parameter
*/
protected class CVParameter
implements Serializable, RevisionHandler {
/** for serialization */
static final long serialVersionUID = -4668812017709421953L;
/** Char used to identify the option of interest */
private String m_ParamChar;
/** Lower bound for the CV search */
private double m_Lower;
/** Upper bound for the CV search */
private double m_Upper;
/** Number of steps during the search */
private double m_Steps;
/** The parameter value with the best performance */
private double m_ParamValue;
/** True if the parameter should be added at the end of the argument list */
private boolean m_AddAtEnd;
/** True if the parameter should be rounded to an integer */
private boolean m_RoundParam;
/**
* Constructs a CVParameter.
*
* @param param the parameter definition
* @throws Exception if construction of CVParameter fails
*/
public CVParameter(String param) throws Exception {
String[] parts = param.split(" ");
if (parts.length < 4 || parts.length > 5) {
throw new Exception("CVParameter " + param
+ ": four or five components expected!");
}
try {
Double.parseDouble(parts[0]);
throw new Exception("CVParameter " + param
+ ": Character parameter identifier expected");
} catch (NumberFormatException n) {
m_ParamChar = parts[0];
}
try {
m_Lower = Double.parseDouble(parts[1]);
} catch (NumberFormatException n) {
throw new Exception("CVParameter " + param
+ ": Numeric lower bound expected");
}
if (parts[2].equals("A")) {
m_Upper = m_Lower - 1;
} else if (parts[2].equals("I")) {
m_Upper = m_Lower - 2;
} else {
try {
m_Upper = Double.parseDouble(parts[2]);
if (m_Upper < m_Lower) {
throw new Exception("CVParameter " + param
+ ": Upper bound is less than lower bound");
}
} catch (NumberFormatException n) {
throw new Exception("CVParameter " + param
+ ": Upper bound must be numeric, or 'A' or 'N'");
}
}
try {
m_Steps = Double.parseDouble(parts[3]);
} catch (NumberFormatException n) {
throw new Exception("CVParameter " + param
+ ": Numeric number of steps expected");
}
if (parts.length == 5 && parts[4].equals("R")) {
m_RoundParam = true;
}
}
/**
* Returns a CVParameter as a string.
*
* @return the CVParameter as string
*/
public String toString() {
String result = m_ParamChar + " " + m_Lower + " ";
switch ((int)(m_Lower - m_Upper + 0.5)) {
case 1:
result += "A";
break;
case 2:
result += "I";
break;
default:
result += m_Upper;
break;
}
result += " " + m_Steps;
if (m_RoundParam) {
result += " R";
}
return result;
}
/**
* Returns the revision string.
*
* @return the revision
*/
public String getRevision() {
return RevisionUtils.extract("$Revision: 13370 $");
}
}
/**
* The base classifier options (not including those being set
* by cross-validation). createOptions() copies them into the slots between
* the cross-validated options placed at the front and those placed at the end.
*/
protected String [] m_ClassifierOptions;
/**
* The complete classifier options (fixed plus cross-validated) that gave the
* lowest cross-validated error; captured via createOptions() whenever
* findParamsByCrossValidation sees an improvement.
*/
protected String [] m_BestClassifierOptions;
/** The set of all options at initialization time, kept so that getOptions
can return them. */
protected String [] m_InitOptions;
/**
* The cross-validated error rate of the best options found so far (lower is
* better). The value -99 is treated as "no setting evaluated yet".
*/
protected double m_BestPerformance;
/** The parameters to cross-validate over; elements are CVParameter objects
(raw Vector, so elements are cast on access). */
protected Vector m_CVParams = new Vector();
/** The number of attributes in the data; used as the search upper bound for
parameters whose upper bound was given as 'A'. */
protected int m_NumAttributes;
/** The number of instances in a training fold; used as the search upper bound
for parameters whose upper bound was given as 'I'. */
protected int m_TrainFoldSize;
/** The number of folds used in cross-validation (default 10) */
protected int m_NumFolds = 10;
/**
 * Builds the option array handed to the base classifier. Each
 * cross-validated parameter in m_CVParams contributes a "-name value" pair,
 * either at the front of the array or (if flagged to go at the end) at the
 * back; the fixed options from m_ClassifierOptions fill the space between.
 *
 * @return the options array
 */
protected String [] createOptions() {
  int paramCount = m_CVParams.size();
  String [] result = new String [m_ClassifierOptions.length + 2 * paramCount];
  int head = 0;
  int tail = result.length;

  for (int idx = 0; idx < paramCount; idx++) {
    CVParameter param = (CVParameter) m_CVParams.elementAt(idx);
    String flag = "-" + param.m_ParamChar;
    String value = formatCVParamValue(param);
    if (!param.m_AddAtEnd) {
      result[head++] = flag;
      result[head++] = value;
    } else {
      // The tail is filled backwards, so write the value first to keep the
      // flag immediately before it.
      result[--tail] = value;
      result[--tail] = flag;
    }
  }

  // The fixed options occupy the gap left between front and back entries.
  System.arraycopy(m_ClassifierOptions, 0, result, head,
    m_ClassifierOptions.length);
  return result;
}

/**
 * Renders the current value of a search parameter as an option string.
 * Rounded or whole-number values are formatted with Utils.doubleToString
 * (4 decimals); any other value uses its plain double representation.
 *
 * @param param the parameter whose current value is rendered
 * @return the value as a string
 */
private String formatCVParamValue(CVParameter param) {
  double value = param.m_ParamValue;
  if (param.m_RoundParam) {
    value = Math.rint(value);
  }
  boolean wholeNumber = (value - (int) value) == 0;
  if (param.m_RoundParam || wholeNumber) {
    return Utils.doubleToString(value, 4);
  }
  return "" + param.m_ParamValue;
}
/**
 * Finds the best parameter combination (recursive for each parameter being
 * optimised). At each depth one CVParameter from m_CVParams is swept from
 * its lower bound towards its upper bound; once every parameter has a value
 * (depth == m_CVParams.size()), the resulting options are evaluated by
 * cross-validation and kept if they beat m_BestPerformance.
 *
 * @param depth the index of the parameter to be optimised at this level
 * @param trainData the data the search is based on
 * @param random a random number generator; it is passed down the recursion,
 *          but the training folds are generated with a fixed Random(1) so
 *          that every setting is evaluated on the same splits
 * @throws Exception if an error occurs
 */
protected void findParamsByCrossValidation(int depth, Instances trainData,
  Random random)
  throws Exception {
  if (depth >= m_CVParams.size()) {
    evaluateCurrentParameters(trainData);
    return;
  }

  CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth);
  // Resolve the symbolic upper bounds ('A' is stored as m_Lower - 1,
  // 'I' as m_Lower - 2).
  double upper;
  switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
    case 1:
      upper = m_NumAttributes;
      break;
    case 2:
      upper = m_TrainFoldSize;
      break;
    default:
      upper = cvParam.m_Upper;
      break;
  }
  double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1);
  // NOTE: values are produced by repeated addition, so rounding error can
  // push the final grid point slightly above 'upper' and skip it.
  for (cvParam.m_ParamValue = cvParam.m_Lower;
       cvParam.m_ParamValue <= upper;
       cvParam.m_ParamValue += increment) {
    double current = cvParam.m_ParamValue;
    findParamsByCrossValidation(depth + 1, trainData, random);
    // A zero, negative or NaN increment (e.g. lower == upper with several
    // steps, an 'A'/'I' bound equal to the lower bound, or fewer than one
    // step), or one too small to change the value, would never move past
    // 'upper': evaluate this value once and stop instead of looping forever.
    if (!(current + increment > current)) {
      break;
    }
  }
}

/**
 * Cross-validates a copy of the base classifier configured with the options
 * from createOptions() for the current parameter values, and records those
 * options as the best so far if the error rate improves on
 * m_BestPerformance (-99 meaning nothing has been evaluated yet).
 *
 * @param trainData the data to cross-validate on
 * @throws Exception if the classifier cannot be copied, configured, built
 *           or evaluated
 */
private void evaluateCurrentParameters(Instances trainData) throws Exception {
  Evaluation evaluation = new Evaluation(trainData);
  // Work with a copy of the base classifier in case the base classifier does
  // not initialize itself properly
  Classifier copiedClassifier = AbstractClassifier.makeCopy(m_Classifier);
  // Set the classifier options
  String[] options = createOptions();
  if (m_Debug) {
    System.err.print("Setting options for "
      + copiedClassifier.getClass().getName() + ":");
    for (int i = 0; i < options.length; i++) {
      System.err.print(" " + options[i]);
    }
    System.err.println("");
  }
  ((OptionHandler) copiedClassifier).setOptions(options);
  for (int j = 0; j < m_NumFolds; j++) {
    // We want to randomize the data the same way for every
    // learning scheme.
    Instances train = trainData.trainCV(m_NumFolds, j, new Random(1));
    Instances test = trainData.testCV(m_NumFolds, j);
    copiedClassifier.buildClassifier(train);
    evaluation.setPriors(train);
    evaluation.evaluateModel(copiedClassifier, test);
  }
  double error = evaluation.errorRate();
  if (m_Debug) {
    System.err.println("Cross-validated error rate: "
      + Utils.doubleToString(error, 6, 4));
  }
  if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {
    m_BestPerformance = error;
    m_BestClassifierOptions = createOptions();
  }
}
/**
 * Describes this classifier for display in the Explorer/Experimenter GUI,
 * followed by the bibliographic reference it is based on.
 *
 * @return a description of the classifier
 */
public String globalInfo() {
  StringBuilder info = new StringBuilder();
  info.append("Class for performing parameter selection by cross-validation ");
  info.append("for any classifier.\n\n");
  info.append("For more information, see:\n\n");
  info.append(getTechnicalInformation().toString());
  return info.toString();
}
/**
 * Builds the bibliographic reference (Kohavi's 1995 PhD thesis) that this
 * classifier is based on.
 *
 * @return the technical information about this class
 */
public TechnicalInformation getTechnicalInformation() {
  TechnicalInformation thesis = new TechnicalInformation(Type.PHDTHESIS);
  thesis.setValue(Field.AUTHOR, "R. Kohavi");
  thesis.setValue(Field.YEAR, "1995");
  thesis.setValue(Field.TITLE, "Wrappers for Performance Enhancement and Oblivious Decision Graphs");
  thesis.setValue(Field.SCHOOL, "Stanford University");
  thesis.setValue(Field.ADDRESS, "Department of Computer Science, Stanford University");
  return thesis;
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration