/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ThresholdSelector.java
* Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.meta;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
/**
* A metaclassifier that selects a mid-point threshold on the probability output by a Classifier. The midpoint threshold is set so that a given performance measure is optimized; by default this is the F-measure. Performance is measured either on the training data, on a hold-out set, or using cross-validation. In addition, the probabilities returned by the base learner can have their range expanded so that the output probabilities will reside between 0 and 1 (this is useful if the scheme normally produces probabilities in a very narrow range).
*
*
* Valid options are:
*
* -C <integer>
* The class for which threshold is determined. Valid values are:
* 1, 2 (for first and second classes, respectively), 3 (for whichever
* class is least frequent), and 4 (for whichever class value is most
* frequent), and 5 (for the first class named any of "yes", "pos(itive)",
* "1", or method 3 if no matches). (default 5).
*
* -X <number of folds>
* Number of folds used for cross validation. If just a
* hold-out set is used, this determines the size of the hold-out set
* (default 3).
*
* -R <integer>
* Sets whether confidence range correction is applied. This
* can be used to ensure the confidences range from 0 to 1.
* Use 0 for no range correction, 1 for correction based on
* the min/max values seen during threshold selection
* (default 0).
*
* -E <integer>
* Sets the evaluation mode. Use 0 for
* evaluation using cross-validation,
* 1 for evaluation using hold-out set,
* and 2 for evaluation on the
* training data (default 1).
*
* -M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]
* Measure used for evaluation (default is FMEASURE).
*
*
* -manual <real>
* Set a manual threshold to use. This option overrides
* automatic selection and options pertaining to
* automatic selection will be ignored.
* (default -1, i.e. do not use a manual threshold).
*
* -S <num>
* Random number seed.
* (default 1)
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -W
* Full name of base classifier.
* (default: weka.classifiers.functions.Logistic)
*
*
* Options specific to classifier weka.classifiers.functions.Logistic:
*
*
* -D
* Turn on debugging output.
*
* -R <ridge>
* Set the ridge in the log-likelihood.
*
* -M <number>
* Set the maximum number of iterations (default -1, until convergence).
*
*
* Options after -- are passed to the designated sub-classifier.
*
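*
* A minimal programmatic usage sketch (illustrative only; assumes a binary-class
* Instances object named "train" has already been loaded with its class
* attribute set):
*
*   ThresholdSelector ts = new ThresholdSelector();
*   ts.setClassifier(new weka.classifiers.functions.Logistic());
*   ts.setDesignatedClass(new SelectedTag(ThresholdSelector.OPTIMIZE_POS_NAME,
*                                         ThresholdSelector.TAGS_OPTIMIZE));
*   ts.setMeasure(new SelectedTag(ThresholdSelector.FMEASURE, ThresholdSelector.TAGS_MEASURE));
*   ts.setEvaluationMode(new SelectedTag(ThresholdSelector.EVAL_TUNED_SPLIT,
*                                        ThresholdSelector.TAGS_EVAL));
*   ts.buildClassifier(train);
*   double[] dist = ts.distributionForInstance(train.instance(0));
*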
* @author Eibe Frank ([email protected])
* @version $Revision: 1.43 $
*/
public class ThresholdSelector
extends RandomizableSingleClassifierEnhancer
implements OptionHandler, Drawable {
/** for serialization */
static final long serialVersionUID = -1795038053239867444L;
/** no range correction */
public static final int RANGE_NONE = 0;
/** Correct based on min/max observed */
public static final int RANGE_BOUNDS = 1;
/** Type of correction applied to threshold range */
public static final Tag [] TAGS_RANGE = {
new Tag(RANGE_NONE, "No range correction"),
new Tag(RANGE_BOUNDS, "Correct based on min/max observed")
};
/** entire training set */
public static final int EVAL_TRAINING_SET = 2;
/** single tuned fold */
public static final int EVAL_TUNED_SPLIT = 1;
/** n-fold cross-validation */
public static final int EVAL_CROSS_VALIDATION = 0;
/** The evaluation modes */
public static final Tag [] TAGS_EVAL = {
new Tag(EVAL_TRAINING_SET, "Entire training set"),
new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),
new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")
};
/** first class value */
public static final int OPTIMIZE_0 = 0;
/** second class value */
public static final int OPTIMIZE_1 = 1;
/** least frequent class value */
public static final int OPTIMIZE_LFREQ = 2;
/** most frequent class value */
public static final int OPTIMIZE_MFREQ = 3;
/** class value named "yes", "pos(itive)", or "1" */
public static final int OPTIMIZE_POS_NAME = 4;
/** How to determine which class value to optimize for */
public static final Tag [] TAGS_OPTIMIZE = {
new Tag(OPTIMIZE_0, "First class value"),
new Tag(OPTIMIZE_1, "Second class value"),
new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),
new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),
new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")
};
/** F-measure */
public static final int FMEASURE = 1;
/** accuracy */
public static final int ACCURACY = 2;
/** true-positive */
public static final int TRUE_POS = 3;
/** true-negative */
public static final int TRUE_NEG = 4;
/** true-positive rate */
public static final int TP_RATE = 5;
/** precision */
public static final int PRECISION = 6;
/** recall */
public static final int RECALL = 7;
/** the measure to use */
public static final Tag[] TAGS_MEASURE = {
new Tag(FMEASURE, "FMEASURE"),
new Tag(ACCURACY, "ACCURACY"),
new Tag(TRUE_POS, "TRUE_POS"),
new Tag(TRUE_NEG, "TRUE_NEG"),
new Tag(TP_RATE, "TP_RATE"),
new Tag(PRECISION, "PRECISION"),
new Tag(RECALL, "RECALL")
};
/** The upper threshold used as the basis of correction */
protected double m_HighThreshold = 1;
/** The lower threshold used as the basis of correction */
protected double m_LowThreshold = 0;
/** The threshold that led to the best performance */
protected double m_BestThreshold = -Double.MAX_VALUE;
/** The best value that has been observed */
protected double m_BestValue = - Double.MAX_VALUE;
/** The number of folds used in cross-validation */
protected int m_NumXValFolds = 3;
/** Designated class value, determined during building */
protected int m_DesignatedClass = 0;
/** Method to determine which class to optimize for */
protected int m_ClassMode = OPTIMIZE_POS_NAME;
/** The evaluation mode */
protected int m_EvalMode = EVAL_TUNED_SPLIT;
/** The range correction mode */
protected int m_RangeMode = RANGE_NONE;
/** evaluation measure used for determining threshold **/
int m_nMeasure = FMEASURE;
/** True if a manually set threshold is being used */
protected boolean m_manualThreshold = false;
/** -1 = not used by default */
protected double m_manualThresholdValue = -1;
/** The minimum value for the criterion. If threshold adjustment
yields less than that, the default threshold of 0.5 is used. */
protected static final double MIN_VALUE = 0.05;
/**
* Constructor.
*/
public ThresholdSelector() {
m_Classifier = new weka.classifiers.functions.Logistic();
}
/**
* String describing default classifier.
*
* @return the default classifier classname
*/
protected String defaultClassifierString() {
return "weka.classifiers.functions.Logistic";
}
/**
* Collects the classifier predictions using the specified evaluation method.
*
* @param instances the set of Instances to generate predictions for.
* @param mode the evaluation mode.
* @param numFolds the number of folds to use if not evaluating on the
* full training set.
* @return a FastVector containing the predictions.
* @throws Exception if an error occurs generating the predictions.
*/
protected FastVector getPredictions(Instances instances, int mode, int numFolds)
throws Exception {
EvaluationUtils eu = new EvaluationUtils();
eu.setSeed(m_Seed);
switch (mode) {
case EVAL_TUNED_SPLIT:
Instances trainData = null, evalData = null;
Instances data = new Instances(instances);
Random random = new Random(m_Seed);
data.randomize(random);
data.stratify(numFolds);
// Make sure that both subsets contain at least one positive instance
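// (if no such split is found, the last train/test split tried is used)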
for (int subsetIndex = 0; subsetIndex < numFolds; subsetIndex++) {
trainData = data.trainCV(numFolds, subsetIndex, random);
evalData = data.testCV(numFolds, subsetIndex);
if (checkForInstance(trainData) && checkForInstance(evalData)) {
break;
}
}
return eu.getTrainTestPredictions(m_Classifier, trainData, evalData);
case EVAL_TRAINING_SET:
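// Train and evaluate on the same data; the selected threshold may overfit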
return eu.getTrainTestPredictions(m_Classifier, instances, instances);
case EVAL_CROSS_VALIDATION:
return eu.getCVPredictions(m_Classifier, instances, numFolds);
default:
throw new RuntimeException("Unrecognized evaluation mode");
}
}
/**
* Tooltip for this property.
*
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String measureTipText() {
return "Sets the measure for determining the threshold.";
}
/**
* set measure used for determining threshold
*
* @param newMeasure Tag representing measure to be used
*/
public void setMeasure(SelectedTag newMeasure) {
if (newMeasure.getTags() == TAGS_MEASURE) {
m_nMeasure = newMeasure.getSelectedTag().getID();
}
}
/**
* get measure used for determining threshold
*
* @return Tag representing measure used
*/
public SelectedTag getMeasure() {
return new SelectedTag(m_nMeasure, TAGS_MEASURE);
}
/**
* Finds the best threshold. This implementation searches for the
* highest value of the selected measure. If no value higher than
* MIN_VALUE is found, the default threshold of 0.5 is used.
*
* @param predictions a FastVector containing the predictions.
*/
protected void findThreshold(FastVector predictions) {
Instances curve = (new ThresholdCurve()).getCurve(predictions, m_DesignatedClass);
double low = 1.0;
double high = 0.0;
//System.err.println(curve);
if (curve.numInstances() > 0) {
Instance maxInst = curve.instance(0);
double maxValue = 0;
int index1 = 0;
int index2 = 0;
switch (m_nMeasure) {
case FMEASURE:
index1 = curve.attribute(ThresholdCurve.FMEASURE_NAME).index();
maxValue = maxInst.value(index1);
break;
case TRUE_POS:
index1 = curve.attribute(ThresholdCurve.TRUE_POS_NAME).index();
maxValue = maxInst.value(index1);
break;
case TRUE_NEG:
index1 = curve.attribute(ThresholdCurve.TRUE_NEG_NAME).index();
maxValue = maxInst.value(index1);
break;
case TP_RATE:
index1 = curve.attribute(ThresholdCurve.TP_RATE_NAME).index();
maxValue = maxInst.value(index1);
break;
case PRECISION:
index1 = curve.attribute(ThresholdCurve.PRECISION_NAME).index();
maxValue = maxInst.value(index1);
break;
case RECALL:
index1 = curve.attribute(ThresholdCurve.RECALL_NAME).index();
maxValue = maxInst.value(index1);
break;
case ACCURACY:
index1 = curve.attribute(ThresholdCurve.TRUE_POS_NAME).index();
index2 = curve.attribute(ThresholdCurve.TRUE_NEG_NAME).index();
maxValue = maxInst.value(index1) + maxInst.value(index2);
break;
}
int indexThreshold = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index();
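// Scan the remaining curve points for the best measure value, tracking the
// observed threshold range when range correction is enabled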
for (int i = 1; i < curve.numInstances(); i++) {
Instance current = curve.instance(i);
double currentValue = 0;
if (m_nMeasure == ACCURACY) {
currentValue= current.value(index1) + current.value(index2);
} else {
currentValue= current.value(index1);
}
if (currentValue> maxValue) {
maxInst = current;
maxValue = currentValue;
}
if (m_RangeMode == RANGE_BOUNDS) {
double thresh = current.value(indexThreshold);
if (thresh < low) {
low = thresh;
}
if (thresh > high) {
high = thresh;
}
}
}
if (maxValue > MIN_VALUE) {
m_BestThreshold = maxInst.value(indexThreshold);
m_BestValue = maxValue;
//System.err.println("maxFM: " + maxFM);
}
if (m_RangeMode == RANGE_BOUNDS) {
m_LowThreshold = low;
m_HighThreshold = high;
//System.err.println("Threshold range: " + low + " - " + high);
}
}
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration listOptions() {
Vector newVector = new Vector(5);
newVector.addElement(new Option(
"\tThe class for which threshold is determined. Valid values are:\n" +
"\t1, 2 (for first and second classes, respectively), 3 (for whichever\n" +
"\tclass is least frequent), and 4 (for whichever class value is most\n" +
"\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n" +
"\t\"1\", or method 3 if no matches). (default 5).",
"C", 1, "-C "));
newVector.addElement(new Option(
"\tNumber of folds used for cross validation. If just a\n" +
"\thold-out set is used, this determines the size of the hold-out set\n" +
"\t(default 3).",
"X", 1, "-X "));
newVector.addElement(new Option(
"\tSets whether confidence range correction is applied. This\n" +
"\tcan be used to ensure the confidences range from 0 to 1.\n" +
"\tUse 0 for no range correction, 1 for correction based on\n" +
"\tthe min/max values seen during threshold selection\n"+
"\t(default 0).",
"R", 1, "-R "));
newVector.addElement(new Option(
"\tSets the evaluation mode. Use 0 for\n" +
"\tevaluation using cross-validation,\n" +
"\t1 for evaluation using hold-out set,\n" +
"\tand 2 for evaluation on the\n" +
"\ttraining data (default 1).",
"E", 1, "-E "));
newVector.addElement(new Option(
"\tMeasure used for evaluation (default is FMEASURE).\n",
"M", 1, "-M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]"));
newVector.addElement(new Option(
"\tSet a manual threshold to use. This option overrides\n"
+ "\tautomatic selection and options pertaining to\n"
+ "\tautomatic selection will be ignored.\n"
+ "\t(default -1, i.e. do not use a manual threshold).",
"manual", 1, "-manual "));
Enumeration enu = super.listOptions();
while (enu.hasMoreElements()) {
newVector.addElement(enu.nextElement());
}
return newVector.elements();
}
/**
* Parses a given list of options.
*
* Valid options are:
*
* -C <integer>
* The class for which threshold is determined. Valid values are:
* 1, 2 (for first and second classes, respectively), 3 (for whichever
* class is least frequent), and 4 (for whichever class value is most
* frequent), and 5 (for the first class named any of "yes", "pos(itive)",
* "1", or method 3 if no matches). (default 5).
*
* -X <number of folds>
* Number of folds used for cross validation. If just a
* hold-out set is used, this determines the size of the hold-out set
* (default 3).
*
* -R <integer>
* Sets whether confidence range correction is applied. This
* can be used to ensure the confidences range from 0 to 1.
* Use 0 for no range correction, 1 for correction based on
* the min/max values seen during threshold selection
* (default 0).
*
* -E <integer>
* Sets the evaluation mode. Use 0 for
* evaluation using cross-validation,
* 1 for evaluation using hold-out set,
* and 2 for evaluation on the
* training data (default 1).
*
* -M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]
* Measure used for evaluation (default is FMEASURE).
*
*
* -manual <real>
* Set a manual threshold to use. This option overrides
* automatic selection and options pertaining to
* automatic selection will be ignored.
* (default -1, i.e. do not use a manual threshold).
*
* -S <num>
* Random number seed.
* (default 1)
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -W
* Full name of base classifier.
* (default: weka.classifiers.functions.Logistic)
*
*
* Options specific to classifier weka.classifiers.functions.Logistic:
*
*
* -D
* Turn on debugging output.
*
* -R <ridge>
* Set the ridge in the log-likelihood.
*
* -M <number>
* Set the maximum number of iterations (default -1, until convergence).
*
*
* Options after -- are passed to the designated sub-classifier.
*
* @param options the list of options as an array of strings
* @throws Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
String manualS = Utils.getOption("manual", options);
if (manualS.length() > 0) {
double val = Double.parseDouble(manualS);
if (val >= 0.0) {
setManualThresholdValue(val);
}
}
String classString = Utils.getOption('C', options);
if (classString.length() != 0) {
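// the command-line value 1..5 is converted to the 0-based OPTIMIZE_* tag ID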
setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1,
TAGS_OPTIMIZE));
} else {
setDesignatedClass(new SelectedTag(OPTIMIZE_POS_NAME, TAGS_OPTIMIZE));
}
String modeString = Utils.getOption('E', options);
if (modeString.length() != 0) {
setEvaluationMode(new SelectedTag(Integer.parseInt(modeString),
TAGS_EVAL));
} else {
setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL));
}
String rangeString = Utils.getOption('R', options);
if (rangeString.length() != 0) {
setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString),
TAGS_RANGE));
} else {
setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE));
}
String measureString = Utils.getOption('M', options);
if (measureString.length() != 0) {
setMeasure(new SelectedTag(measureString, TAGS_MEASURE));
} else {
setMeasure(new SelectedTag(FMEASURE, TAGS_MEASURE));
}
String foldsString = Utils.getOption('X', options);
if (foldsString.length() != 0) {
setNumXValFolds(Integer.parseInt(foldsString));
} else {
setNumXValFolds(3);
}
super.setOptions(options);
}
/**
* Gets the current settings of the Classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
public String [] getOptions() {
String [] superOptions = super.getOptions();
String [] options = new String [superOptions.length + 12];
int current = 0;
if (m_manualThreshold) {
options[current++] = "-manual"; options[current++] = "" + getManualThresholdValue();
}
options[current++] = "-C"; options[current++] = "" + (m_ClassMode + 1);
options[current++] = "-X"; options[current++] = "" + getNumXValFolds();
options[current++] = "-E"; options[current++] = "" + m_EvalMode;
options[current++] = "-R"; options[current++] = "" + m_RangeMode;
options[current++] = "-M"; options[current++] = "" + getMeasure().getSelectedTag().getReadable();
System.arraycopy(superOptions, 0, options, current,
superOptions.length);
current += superOptions.length;
while (current < options.length) {
options[current++] = "";
}
return options;
}
/**
* Returns default capabilities of the classifier.
*
* @return the capabilities of this classifier
*/
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
// class
result.disableAllClasses();
result.disableAllClassDependencies();
result.enable(Capability.BINARY_CLASS);
return result;
}
/**
* Generates the classifier.
*
* @param instances set of instances serving as training data
* @throws Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances)
throws Exception {
// can classifier handle the data?
getCapabilities().testWithFail(instances);
// remove instances with missing class
instances = new Instances(instances);
instances.deleteWithMissingClass();
AttributeStats stats = instances.attributeStats(instances.classIndex());
if (m_manualThreshold) {
m_BestThreshold = m_manualThresholdValue;
} else {
m_BestThreshold = 0.5;
}
m_BestValue = MIN_VALUE;
m_HighThreshold = 1;
m_LowThreshold = 0;
// If the data does not contain examples of both classes,
// no threshold adjustment is possible
if (stats.distinctCount != 2) {
System.err.println("Couldn't find examples of both classes. No adjustment.");
m_Classifier.buildClassifier(instances);
} else {
// Determine which class value to look for
switch (m_ClassMode) {
case OPTIMIZE_0:
m_DesignatedClass = 0;
break;
case OPTIMIZE_1:
m_DesignatedClass = 1;
break;
case OPTIMIZE_POS_NAME:
Attribute cAtt = instances.classAttribute();
boolean found = false;
for (int i = 0; i < cAtt.numValues() && !found; i++) {
String name = cAtt.value(i).toLowerCase();
if (name.startsWith("yes") || name.equals("1") ||
name.startsWith("pos")) {
found = true;
m_DesignatedClass = i;
}
}
if (found) {
break;
}
// No named class found, so fall through to default of least frequent
case OPTIMIZE_LFREQ:
m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 1 : 0;
break;
case OPTIMIZE_MFREQ:
m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 0 : 1;
break;
default:
throw new Exception("Unrecognized class value selection mode");
}
/*
System.err.println("ThresholdSelector: Using mode="
+ TAGS_OPTIMIZE[m_ClassMode].getReadable());
System.err.println("ThresholdSelector: Optimizing using class "
+ m_DesignatedClass + "/"
+ instances.classAttribute().value(m_DesignatedClass));
*/
if (m_manualThreshold) {
m_Classifier.buildClassifier(instances);
return;
}
if (stats.nominalCounts[m_DesignatedClass] == 1) {
System.err.println("Only 1 positive found: optimizing on training data");
findThreshold(getPredictions(instances, EVAL_TRAINING_SET, 0));
} else {
int numFolds = Math.min(m_NumXValFolds, stats.nominalCounts[m_DesignatedClass]);
//System.err.println("Number of folds for threshold selector: " + numFolds);
findThreshold(getPredictions(instances, m_EvalMode, numFolds));
if (m_EvalMode != EVAL_TRAINING_SET) {
m_Classifier.buildClassifier(instances);
}
}
}
}
/**
* Checks whether an instance of the designated class is in the given subset.
*
* @param data the data to check for instance
* @return true if the instance is in the subset
* @throws Exception if checking fails
*/
private boolean checkForInstance(Instances data) throws Exception {
for (int i = 0; i < data.numInstances(); i++) {
if (((int)data.instance(i).classValue()) == m_DesignatedClass) {
return true;
}
}
return false;
}
/**
* Calculates the class membership probabilities for the given test instance.
*
* @param instance the instance to be classified
* @return predicted class probability distribution
* @throws Exception if instance could not be classified
* successfully
*/
public double [] distributionForInstance(Instance instance)
throws Exception {
double [] pred = m_Classifier.distributionForInstance(instance);
double prob = pred[m_DesignatedClass];
// Warp probability
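// Probabilities above the selected threshold are mapped linearly onto
// [0.5, 1] relative to the upper bound; those at or below it are mapped
// onto [0, 0.5] relative to the lower bound, then clipped to [0, 1].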
if (prob > m_BestThreshold) {
prob = 0.5 + (prob - m_BestThreshold) /
((m_HighThreshold - m_BestThreshold) * 2);
} else {
prob = (prob - m_LowThreshold) /
((m_BestThreshold - m_LowThreshold) * 2);
}
if (prob < 0) {
prob = 0.0;
} else if (prob > 1) {
prob = 1.0;
}
// Alter the distribution
pred[m_DesignatedClass] = prob;
if (pred.length == 2) { // Handle case when there's only one class
pred[(m_DesignatedClass + 1) % 2] = 1.0 - prob;
}
return pred;
}
/**
* @return a description of the classifier suitable for
* displaying in the explorer/experimenter gui
*/
public String globalInfo() {
return "A metaclassifier that selecting a mid-point threshold on the "
+ "probability output by a Classifier. The midpoint "
+ "threshold is set so that a given performance measure is optimized. "
+ "Currently this is the F-measure. Performance is measured either on "
+ "the training data, a hold-out set or using cross-validation. In "
+ "addition, the probabilities returned by the base learner can "
+ "have their range expanded so that the output probabilities will "
+ "reside between 0 and 1 (this is useful if the scheme normally "
+ "produces probabilities in a very narrow range).";
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String designatedClassTipText() {
return "Sets the class value for which the optimization is performed. "
+ "The options are: pick the first class value; pick the second "
+ "class value; pick whichever class is least frequent; pick whichever "
+ "class value is most frequent; pick the first class named any of "
+ "\"yes\",\"pos(itive)\", \"1\", or the least frequent if no matches).";
}
/**
* Gets the method to determine which class value to optimize. Will
* be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ,
* OPTIMIZE_POS_NAME.
*
* @return the class selection mode.
*/
public SelectedTag getDesignatedClass() {
return new SelectedTag(m_ClassMode, TAGS_OPTIMIZE);
}
/**
* Sets the method to determine which class value to optimize. Will
* be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ,
* OPTIMIZE_POS_NAME.
*
* @param newMethod the new class selection mode.
*/
public void setDesignatedClass(SelectedTag newMethod) {
if (newMethod.getTags() == TAGS_OPTIMIZE) {
m_ClassMode = newMethod.getSelectedTag().getID();
}
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String evaluationModeTipText() {
return "Sets the method used to determine the threshold/performance "
+ "curve. The options are: perform optimization based on the entire "
+ "training set (may result in overfitting); perform an n-fold "
+ "cross-validation (may be time consuming); perform one fold of "
+ "an n-fold cross-validation (faster but likely less accurate).";
}
/**
* Sets the evaluation mode used. Will be one of
* EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION
*
* @param newMethod the new evaluation mode.
*/
public void setEvaluationMode(SelectedTag newMethod) {
if (newMethod.getTags() == TAGS_EVAL) {
m_EvalMode = newMethod.getSelectedTag().getID();
}
}
/**
* Gets the evaluation mode used. Will be one of
* EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION
*
* @return the evaluation mode.
*/
public SelectedTag getEvaluationMode() {
return new SelectedTag(m_EvalMode, TAGS_EVAL);
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String rangeCorrectionTipText() {
return "Sets the type of prediction range correction performed. "
+ "The options are: do not do any range correction; "
+ "expand predicted probabilities so that the minimum probability "
+ "observed during the optimization maps to 0, and the maximum "
+ "maps to 1 (values outside this range are clipped to 0 and 1).";
}
/**
* Sets the confidence range correction mode used. Will be one of
* RANGE_NONE, or RANGE_BOUNDS
*
* @param newMethod the new correction mode.
*/
public void setRangeCorrection(SelectedTag newMethod) {
if (newMethod.getTags() == TAGS_RANGE) {
m_RangeMode = newMethod.getSelectedTag().getID();
}
}
/**
* Gets the confidence range correction mode used. Will be one of
* RANGE_NONE, or RANGE_BOUNDS
*
* @return the confidence correction mode.
*/
public SelectedTag getRangeCorrection() {
return new SelectedTag(m_RangeMode, TAGS_RANGE);
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String numXValFoldsTipText() {
return "Sets the number of folds used during full cross-validation "
+ "and tuned fold evaluation. This number will be automatically "
+ "reduced if there are insufficient positive examples.";
}
/**
* Get the number of folds used for cross-validation.
*
* @return the number of folds used for cross-validation.
*/
public int getNumXValFolds() {
return m_NumXValFolds;
}
/**
* Set the number of folds used for cross-validation.
*
* @param newNumFolds the number of folds used for cross-validation.
*/
public void setNumXValFolds(int newNumFolds) {
if (newNumFolds < 2) {
throw new IllegalArgumentException("Number of folds must be greater than 1");
}
m_NumXValFolds = newNumFolds;
}
/**
* Returns the type of graph this classifier
* represents.
*
* @return the type of graph this classifier represents
*/
public int graphType() {
if (m_Classifier instanceof Drawable)
return ((Drawable)m_Classifier).graphType();
else
return Drawable.NOT_DRAWABLE;
}
/**
* Returns graph describing the classifier (if possible).
*
* @return the graph of the classifier in dotty format
* @throws Exception if the classifier cannot be graphed
*/
public String graph() throws Exception {
if (m_Classifier instanceof Drawable)
return ((Drawable)m_Classifier).graph();
else throw new Exception("Classifier: " + getClassifierSpec()
+ " cannot be graphed");
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String manualThresholdValueTipText() {
return "Sets a manual threshold value to use. "
+ "If this is set (non-negative value between 0 and 1), then "
+ "all options pertaining to automatic threshold selection are "
+ "ignored. ";
}
/**
* Sets the value for a manual threshold. If this option
* is set (non-negative value between 0 and 1), then options
* pertaining to automatic threshold selection are ignored.
*
* @param threshold the manual threshold to use
* @throws Exception if the threshold is greater than 1
*/
public void setManualThresholdValue(double threshold) throws Exception {
m_manualThresholdValue = threshold;
if (threshold >= 0.0 && threshold <= 1.0) {
m_manualThreshold = true;
} else {
m_manualThreshold = false;
if (threshold >= 0) {
throw new IllegalArgumentException("Threshold must be in the "
+ "range 0..1.");
}
}
}
/**
* Returns the value of the manual threshold (a negative
* value indicates that no manual threshold is being used).
*
* @return the value of the manual threshold.
*/
public double getManualThresholdValue() {
return m_manualThresholdValue;
}
/**
* Returns description of the cross-validated classifier.
*
* @return description of the cross-validated classifier as a string
*/
public String toString() {
if (m_BestValue == -Double.MAX_VALUE)
return "ThresholdSelector: No model built yet.";
String result = "Threshold Selector.\n"
+ "Classifier: " + m_Classifier.getClass().getName() + "\n";
result += "Index of designated class: " + m_DesignatedClass + "\n";
if (m_manualThreshold) {
result += "User supplied threshold: " + m_BestThreshold + "\n";
} else {
result += "Evaluation mode: ";
switch (m_EvalMode) {
case EVAL_CROSS_VALIDATION:
result += m_NumXValFolds + "-fold cross-validation";
break;
case EVAL_TUNED_SPLIT:
result += "tuning on 1/" + m_NumXValFolds + " of the data";
break;
case EVAL_TRAINING_SET:
default:
result += "tuning on the training data";
}
result += "\n";
result += "Threshold: " + m_BestThreshold + "\n";
result += "Best value: " + m_BestValue + "\n";
if (m_RangeMode == RANGE_BOUNDS) {
result += "Expanding range [" + m_LowThreshold + "," + m_HighThreshold
+ "] to [0, 1]\n";
}
result += "Measure: " + getMeasure().getSelectedTag().getReadable() + "\n";
}
result += m_Classifier.toString();
return result;
}
/**
* Returns the revision string.
*
* @return the revision
*/
public String getRevision() {
return RevisionUtils.extract("$Revision: 1.43 $");
}
/**
* Main method for testing this class.
*
* @param argv the options
*/
public static void main(String [] argv) {
runClassifier(new ThresholdSelector(), argv);
}
}