// All Downloads are FREE. Search and download functionalities use the official Maven repository.

// weka.gui.visualize.ThresholdVisualizePanel Maven / Gradle / Ivy

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ThresholdVisualizePanel.java
 *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.gui.visualize;

import weka.classifiers.Classifier;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.SingleIndex;
import weka.core.Utils;

import java.awt.BorderLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.io.BufferedReader;
import java.io.FileReader;

import javax.swing.BorderFactory;
import javax.swing.JFrame;
import javax.swing.border.TitledBorder;

/** 
 * This panel is a VisualizePanel, with the added ablility to display the
 * area under the ROC curve if an ROC curve is chosen.
 *
 * @author Dale Fletcher ([email protected])
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 7059 $
 */
public class ThresholdVisualizePanel 
  extends VisualizePanel {

  /** for serialization */
  private static final long serialVersionUID = 3070002211779443890L;

  /** The string to add to the Plot Border. */
  private String m_ROCString="";
 
  /** Original border text */
  private String m_savePanelBorderText;

  /**
   * default constructor
   */
  public ThresholdVisualizePanel() {
    super();

    // Save the current border text
    TitledBorder tb=(TitledBorder) m_plotSurround.getBorder();
    m_savePanelBorderText = tb.getTitle();
  }
  
  /**
   * Set the string with ROC area
   * @param str ROC area string to add to border
   */  
  public void setROCString(String str) {
    m_ROCString=str;
  }

  /**
   * This extracts the ROC area string 
   * @return ROC area string 
   */
  public String getROCString() {
    return m_ROCString;
  }

  /**
   * This overloads VisualizePanel's setUpComboBoxes to add 
   * ActionListeners to watch for when the X/Y Axis comboboxes
   * are changed. 
   * @param inst a set of instances with data for plotting
   */
  public void setUpComboBoxes(Instances inst) {
    super.setUpComboBoxes(inst);

    m_XCombo.addActionListener(new ActionListener() {
	public void actionPerformed(ActionEvent e) {
	  setBorderText();
	}
    });
    m_YCombo.addActionListener(new ActionListener() {
	public void actionPerformed(ActionEvent e) {
	  setBorderText();
	}
    });

    // Just in case the default is ROC
    setBorderText();
  }

  /**
   * This checks the current selected X/Y Axis comboBoxes to see if 
   * an ROC graph is selected. If so, add the ROC area string to the
   * plot border, otherwise display the original border text.
   */
  private void setBorderText() {

    String xs = m_XCombo.getSelectedItem().toString();
    String ys = m_YCombo.getSelectedItem().toString();

    if (xs.equals(Messages.getInstance().getString("ThresholdVisualizePanel_SetBorderText_Text_First")) && ys.equals(Messages.getInstance().getString("ThresholdVisualizePanel_SetBorderText_Text_Second")))   {
        m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText+" "+m_ROCString)));
    } else
        m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText))); 
  }

  /**
   * displays the previously saved instances
   * 
   * @param insts	the instances to display
   * @throws Exception	if display is not possible
   */
  protected void openVisibleInstances(Instances insts) throws Exception {
    super.openVisibleInstances(insts);

    setROCString(
    		Messages.getInstance().getString("ThresholdVisualizePanel_OpenVisibleInstances_Text_First") 
	+ Utils.doubleToString(ThresholdCurve.getROCArea(insts), 4) + Messages.getInstance().getString("ThresholdVisualizePanel_OpenVisibleInstances_Text_Second"));
    
    setBorderText();
  }
  
  /**
   * Starts the ThresholdVisualizationPanel with parameters from the command line. 

* * Valid options are:

* -h
* lists all the commandline parameters

* * -t file
* Dataset to process with given classifier.

* * -W classname
* Full classname of classifier to run.
* Options after '--' are passed to the classifier.
* (default weka.classifiers.functions.Logistic)

* * -r number
* The number of runs to perform (default 2).

* * -x number
* The number of Cross-validation folds (default 10).

* * -l file
* Previously saved threshold curve ARFF file.

* * @param args optional commandline parameters */ public static void main(String [] args) { Instances inst; Classifier classifier; int runs; int folds; String tmpStr; boolean compute; Instances result; String[] options; SingleIndex classIndex; SingleIndex valueIndex; int seed; inst = null; classifier = null; runs = 2; folds = 10; compute = true; result = null; classIndex = null; valueIndex = null; seed = 1; try { // help? if (Utils.getFlag('h', args)) { System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_First") + ThresholdVisualizePanel.class.getName() + Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Second")); System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Third")); System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Fourth")); System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Fifth")); System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Sixth")); System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Seventh")); System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Eighth")); System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Nineth")); System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Tenth")); System.out.println(Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Eleventh")); return; } // regular options tmpStr = Utils.getOption('l', args); if (tmpStr.length() != 0) { result = new Instances(new BufferedReader(new FileReader(tmpStr))); compute = false; } if (compute) { tmpStr = Utils.getOption('r', args); if (tmpStr.length() != 0) runs = Integer.parseInt(tmpStr); else runs = 1; tmpStr = Utils.getOption('x', args); if (tmpStr.length() != 0) folds = Integer.parseInt(tmpStr); else folds = 10; tmpStr = 
Utils.getOption('S', args); if (tmpStr.length() != 0) seed = Integer.parseInt(tmpStr); else seed = 1; tmpStr = Utils.getOption('t', args); if (tmpStr.length() != 0) { inst = new Instances(new BufferedReader(new FileReader(tmpStr))); inst.setClassIndex(inst.numAttributes() - 1); } tmpStr = Utils.getOption('W', args); if (tmpStr.length() != 0) { options = Utils.partitionOptions(args); } else { tmpStr = weka.classifiers.functions.Logistic.class.getName(); options = new String[0]; } classifier = Classifier.forName(tmpStr, options); tmpStr = Utils.getOption('c', args); if (tmpStr.length() != 0) classIndex = new SingleIndex(tmpStr); else classIndex = new SingleIndex("last"); tmpStr = Utils.getOption('C', args); if (tmpStr.length() != 0) valueIndex = new SingleIndex(tmpStr); else valueIndex = new SingleIndex("first"); } // compute if necessary if (compute) { if (classIndex != null) { classIndex.setUpper(inst.numAttributes() - 1); inst.setClassIndex(classIndex.getIndex()); } else { inst.setClassIndex(inst.numAttributes() - 1); } if (valueIndex != null) { valueIndex.setUpper(inst.classAttribute().numValues() - 1); } ThresholdCurve tc = new ThresholdCurve(); EvaluationUtils eu = new EvaluationUtils(); FastVector predictions = new FastVector(); for (int i = 0; i < runs; i++) { eu.setSeed(seed + i); predictions.appendElements(eu.getCVPredictions(classifier, inst, folds)); } if (valueIndex != null) result = tc.getCurve(predictions, valueIndex.getIndex()); else result = tc.getCurve(predictions); } // setup GUI ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString(Messages.getInstance().getString("ThresholdVisualizePanel_OpenVisibleInstances_Text_Third") + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + Messages.getInstance().getString("ThresholdVisualizePanel_OpenVisibleInstances_Text_Fourth")); if (compute) vmc.setName( result.relationName() + Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Twelveth") + 
inst.classAttribute().value(valueIndex.getIndex()) + Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Thirteenth")); else vmc.setName( result.relationName() + Messages.getInstance().getString("ThresholdVisualizePanel_Main_Text_Fourteenth")); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); vmc.addPlot(tempd); String plotName = vmc.getName(); final JFrame jf = new JFrame(Messages.getInstance().getString("ThresholdVisualizePanel_Main_JFrame_Text") + plotName); jf.setSize(500,400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.addWindowListener(new WindowAdapter() { public void windowClosing(WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); } catch (Exception e) { e.printStackTrace(); } } }





// © 2015 - 2025 Weber Informatics LLC | Privacy Policy