weka.gui.visualize.ThresholdVisualizePanel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* ThresholdVisualizePanel.java
* Copyright (C) 2003-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.gui.visualize;
import java.awt.BorderLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import javax.swing.BorderFactory;
import javax.swing.JFrame;
import javax.swing.border.TitledBorder;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.Prediction;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.Instances;
import weka.core.SingleIndex;
import weka.core.Utils;
/**
 * This panel is a VisualizePanel, with the added ability to display the area
 * under the ROC curve if an ROC curve is chosen.
 *
 * @author Dale Fletcher ([email protected])
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 10222 $
 */
public class ThresholdVisualizePanel extends VisualizePanel {

  /** for serialization */
  private static final long serialVersionUID = 3070002211779443890L;

  /** X-axis combo box entry that, together with {@link #ROC_Y_AXIS}, denotes an ROC plot. */
  private static final String ROC_X_AXIS = "X: False Positive Rate (Num)";

  /** Y-axis combo box entry that, together with {@link #ROC_X_AXIS}, denotes an ROC plot. */
  private static final String ROC_Y_AXIS = "Y: True Positive Rate (Num)";

  /** The string to add to the Plot Border. */
  private String m_ROCString = "";

  /** Original border text, restored whenever a non-ROC plot is shown. */
  private final String m_savePanelBorderText;

  /**
   * Default constructor. Remembers the title of the plot border created by the
   * superclass so it can be restored later.
   */
  public ThresholdVisualizePanel() {
    super();

    // Save the current border text
    TitledBorder tb = (TitledBorder) m_plotSurround.getBorder();
    m_savePanelBorderText = tb.getTitle();
  }

  /**
   * Set the string with ROC area. The string is appended to the plot border
   * title while an ROC plot (FPR vs. TPR) is selected.
   *
   * @param str ROC area string to add to border
   */
  public void setROCString(String str) {
    m_ROCString = str;
  }

  /**
   * This extracts the ROC area string
   *
   * @return ROC area string
   */
  public String getROCString() {
    return m_ROCString;
  }

  /**
   * This overrides VisualizePanel's setUpComboBoxes to add ActionListeners to
   * watch for when the X/Y Axis comboboxes are changed.
   *
   * @param inst a set of instances with data for plotting
   */
  @Override
  public void setUpComboBoxes(Instances inst) {
    super.setUpComboBoxes(inst);

    // This public method can be invoked more than once on the same panel (the
    // combo boxes themselves are fields and persist), so register the border
    // listener only once per combo box instead of piling up duplicates.
    if (!hasBorderTextUpdater(m_XCombo.getActionListeners())) {
      m_XCombo.addActionListener(new BorderTextUpdater());
    }
    if (!hasBorderTextUpdater(m_YCombo.getActionListeners())) {
      m_YCombo.addActionListener(new BorderTextUpdater());
    }

    // Just in case the default is ROC
    setBorderText();
  }

  /**
   * Checks whether one of the given listeners is a {@link BorderTextUpdater}.
   *
   * @param listeners the listeners currently registered on a combo box
   * @return true if a border updater is already registered
   */
  private static boolean hasBorderTextUpdater(ActionListener[] listeners) {
    for (ActionListener listener : listeners) {
      if (listener instanceof BorderTextUpdater) {
        return true;
      }
    }
    return false;
  }

  /**
   * Listener that refreshes the plot border text whenever an axis selection
   * changes.
   */
  private class BorderTextUpdater implements ActionListener {
    @Override
    public void actionPerformed(ActionEvent e) {
      setBorderText();
    }
  }

  /**
   * This checks the current selected X/Y Axis comboBoxes to see if an ROC graph
   * is selected. If so, add the ROC area string to the plot border, otherwise
   * display the original border text. An empty selection (no item selected)
   * is treated as "not ROC".
   */
  private void setBorderText() {
    Object xSelected = m_XCombo.getSelectedItem();
    Object ySelected = m_YCombo.getSelectedItem();
    boolean isROC = xSelected != null && ySelected != null
      && xSelected.toString().equals(ROC_X_AXIS)
      && ySelected.toString().equals(ROC_Y_AXIS);

    String title = isROC ? m_savePanelBorderText + " " + m_ROCString
      : m_savePanelBorderText;
    m_plotSurround.setBorder(BorderFactory.createTitledBorder(title));
  }

  /**
   * displays the previously saved instances and updates the ROC area string
   * accordingly.
   *
   * @param insts the instances to display
   * @throws Exception if display is not possible
   */
  @Override
  protected void openVisibleInstances(Instances insts) throws Exception {
    super.openVisibleInstances(insts);
    setROCString("(Area under ROC = "
      + Utils.doubleToString(ThresholdCurve.getROCArea(insts), 4) + ")");
    setBorderText();
  }

  /**
   * Starts the ThresholdVisualizationPanel with parameters from the command
   * line.
   * <p>
   * Valid options are:
   *
   * <pre>
   * -h
   *  lists all the commandline parameters
   *
   * -t &lt;file&gt;
   *  Dataset to process with given classifier.
   *
   * -c &lt;index&gt;
   *  The class index; first and last are valid, too (default: last).
   *
   * -C &lt;index&gt;
   *  The index of the class value to get the curve for (default: first).
   *
   * -W &lt;classname&gt;
   *  Full classname of classifier to run.
   *  Options after '--' are passed to the classifier.
   *  (default weka.classifiers.functions.Logistic)
   *
   * -r &lt;number&gt;
   *  The number of runs to perform (default 1).
   *
   * -x &lt;number&gt;
   *  The number of Cross-validation folds (default 10).
   *
   * -S &lt;number&gt;
   *  The seed value for randomizing the data (default 1).
   *
   * -l &lt;file&gt;
   *  Previously saved threshold curve ARFF file; when given, all
   *  other options are ignored and the curve is only displayed.
   * </pre>
   *
   * @param args optional commandline parameters
   */
  public static void main(String[] args) {
    try {
      // help?
      if (Utils.getFlag('h', args)) {
        printHelp();
        return;
      }

      Instances result;
      String panelName;

      String tmpStr = Utils.getOption('l', args);
      if (tmpStr.length() != 0) {
        // display a previously saved curve only
        result = loadInstances(tmpStr);
        panelName = result.relationName() + " (display only)";
      } else {
        tmpStr = Utils.getOption('r', args);
        int runs = (tmpStr.length() != 0) ? Integer.parseInt(tmpStr) : 1;

        tmpStr = Utils.getOption('x', args);
        int folds = (tmpStr.length() != 0) ? Integer.parseInt(tmpStr) : 10;

        tmpStr = Utils.getOption('S', args);
        int seed = (tmpStr.length() != 0) ? Integer.parseInt(tmpStr) : 1;

        tmpStr = Utils.getOption('t', args);
        if (tmpStr.length() == 0) {
          // without a dataset there is nothing to compute
          System.err.println("No dataset supplied: use -t <file> to compute a "
            + "curve, or -l <file> to display a saved one (see -h).");
          return;
        }
        Instances inst = loadInstances(tmpStr);

        String classname;
        String[] options;
        tmpStr = Utils.getOption('W', args);
        if (tmpStr.length() != 0) {
          classname = tmpStr;
          options = Utils.partitionOptions(args);
        } else {
          classname = weka.classifiers.functions.Logistic.class.getName();
          options = new String[0];
        }
        Classifier classifier = AbstractClassifier.forName(classname, options);

        tmpStr = Utils.getOption('c', args);
        SingleIndex classIndex =
          new SingleIndex((tmpStr.length() != 0) ? tmpStr : "last");

        tmpStr = Utils.getOption('C', args);
        SingleIndex valueIndex =
          new SingleIndex((tmpStr.length() != 0) ? tmpStr : "first");

        classIndex.setUpper(inst.numAttributes() - 1);
        inst.setClassIndex(classIndex.getIndex());
        valueIndex.setUpper(inst.classAttribute().numValues() - 1);

        result = computeCurve(classifier, inst, runs, folds, seed,
          valueIndex.getIndex());
        panelName = result.relationName() + ". (Class value "
          + inst.classAttribute().value(valueIndex.getIndex()) + ")";
      }

      // setup GUI; the ROC string must be set before addPlot(), which
      // (re)initializes the axis combo boxes and thus the border text
      ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
      vmc.setROCString("(Area under ROC = "
        + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
      vmc.setName(panelName);

      PlotData2D tempd = new PlotData2D(result);
      tempd.setPlotName(result.relationName());
      tempd.addInstanceNumberAttribute();
      vmc.addPlot(tempd);

      String plotName = vmc.getName();
      final JFrame jf = new JFrame("Weka Classifier Visualize: " + plotName);
      jf.setSize(500, 400);
      jf.getContentPane().setLayout(new BorderLayout());
      jf.getContentPane().add(vmc, BorderLayout.CENTER);
      jf.addWindowListener(new WindowAdapter() {
        @Override
        public void windowClosing(WindowEvent e) {
          jf.dispose();
        }
      });
      jf.setVisible(true);
    } catch (Exception e) {
      e.printStackTrace();
    }
  }

  /**
   * Prints the command-line options understood by {@link #main(String[])}.
   */
  private static void printHelp() {
    System.out.println("\nOptions for "
      + ThresholdVisualizePanel.class.getName() + ":\n");
    System.out.println("-h\n\tThis help.");
    System.out
      .println("-t <file>\n\tDataset to process with given classifier.");
    System.out
      .println("-c <index>\n\tThe class index. first and last are valid, too (default: last).");
    System.out
      .println("-C <index>\n\tThe index of the class value to get the curve for (default: first).");
    System.out
      .println("-W <classname>\n\tFull classname of classifier to run.\n\tOptions after '--' are passed to the classifier.\n\t(default: weka.classifiers.functions.Logistic)");
    System.out
      .println("-r <number>\n\tThe number of runs to perform (default: 1).");
    System.out
      .println("-x <number>\n\tThe number of Cross-validation folds (default: 10).");
    System.out
      .println("-S <number>\n\tThe seed value for randomizing the data (default: 1).");
    System.out
      .println("-l <file>\n\tPreviously saved threshold curve ARFF file.");
  }

  /**
   * Reads a dataset from the given file, always closing the underlying reader.
   *
   * @param filename the file to read
   * @return the loaded instances
   * @throws Exception if the file cannot be opened or parsed
   */
  private static Instances loadInstances(String filename) throws Exception {
    BufferedReader reader = new BufferedReader(new FileReader(filename));
    try {
      return new Instances(reader);
    } finally {
      reader.close();
    }
  }

  /**
   * Collects cross-validated predictions over several runs (seed incremented
   * per run) and turns them into a threshold curve.
   *
   * @param classifier the classifier to evaluate
   * @param data the dataset, with its class index already set
   * @param runs the number of cross-validation runs
   * @param folds the number of folds per run
   * @param seed the seed of the first run
   * @param classValueIndex the index of the class value to compute the curve for
   * @return the threshold curve
   * @throws Exception if evaluation fails
   */
  private static Instances computeCurve(Classifier classifier, Instances data,
    int runs, int folds, int seed, int classValueIndex) throws Exception {
    EvaluationUtils eu = new EvaluationUtils();
    ArrayList<Prediction> predictions = new ArrayList<Prediction>();
    for (int i = 0; i < runs; i++) {
      eu.setSeed(seed + i);
      predictions.addAll(eu.getCVPredictions(classifier, data, folds));
    }
    return new ThresholdCurve().getCurve(predictions, classValueIndex);
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy