All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.gui.visualize.ThresholdVisualizePanel Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    ThresholdVisualizePanel.java
 *    Copyright (C) 2003-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.gui.visualize;

import java.awt.BorderLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;

import javax.swing.BorderFactory;
import javax.swing.JFrame;
import javax.swing.border.TitledBorder;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.Prediction;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.Instances;
import weka.core.SingleIndex;
import weka.core.Utils;

/**
 * This panel is a VisualizePanel, with the added ablility to display the area
 * under the ROC curve if an ROC curve is chosen.
 * 
 * @author Dale Fletcher ([email protected])
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 10222 $
 */
public class ThresholdVisualizePanel extends VisualizePanel {

  /** for serialization */
  private static final long serialVersionUID = 3070002211779443890L;

  /** The string to add to the Plot Border. */
  private String m_ROCString = "";

  /** Original border text */
  private final String m_savePanelBorderText;

  /**
   * default constructor
   */
  public ThresholdVisualizePanel() {
    super();

    // Save the current border text
    TitledBorder tb = (TitledBorder) m_plotSurround.getBorder();
    m_savePanelBorderText = tb.getTitle();
  }

  /**
   * Set the string with ROC area
   * 
   * @param str ROC area string to add to border
   */
  public void setROCString(String str) {
    m_ROCString = str;
  }

  /**
   * This extracts the ROC area string
   * 
   * @return ROC area string
   */
  public String getROCString() {
    return m_ROCString;
  }

  /**
   * This overloads VisualizePanel's setUpComboBoxes to add ActionListeners to
   * watch for when the X/Y Axis comboboxes are changed.
   * 
   * @param inst a set of instances with data for plotting
   */
  @Override
  public void setUpComboBoxes(Instances inst) {
    super.setUpComboBoxes(inst);

    m_XCombo.addActionListener(new ActionListener() {
      @Override
      public void actionPerformed(ActionEvent e) {
        setBorderText();
      }
    });
    m_YCombo.addActionListener(new ActionListener() {
      @Override
      public void actionPerformed(ActionEvent e) {
        setBorderText();
      }
    });

    // Just in case the default is ROC
    setBorderText();
  }

  /**
   * This checks the current selected X/Y Axis comboBoxes to see if an ROC graph
   * is selected. If so, add the ROC area string to the plot border, otherwise
   * display the original border text.
   */
  private void setBorderText() {

    String xs = m_XCombo.getSelectedItem().toString();
    String ys = m_YCombo.getSelectedItem().toString();

    if (xs.equals("X: False Positive Rate (Num)")
      && ys.equals("Y: True Positive Rate (Num)")) {
      m_plotSurround.setBorder((BorderFactory
        .createTitledBorder(m_savePanelBorderText + " " + m_ROCString)));
    } else {
      m_plotSurround.setBorder((BorderFactory
        .createTitledBorder(m_savePanelBorderText)));
    }
  }

  /**
   * displays the previously saved instances
   * 
   * @param insts the instances to display
   * @throws Exception if display is not possible
   */
  @Override
  protected void openVisibleInstances(Instances insts) throws Exception {
    super.openVisibleInstances(insts);

    setROCString("(Area under ROC = "
      + Utils.doubleToString(ThresholdCurve.getROCArea(insts), 4) + ")");

    setBorderText();
  }

  /**
   * Starts the ThresholdVisualizationPanel with parameters from the command
   * line.
   * 

* * Valid options are: *

* -h
* lists all the commandline parameters *

* * -t file
* Dataset to process with given classifier. *

* * -W classname
* Full classname of classifier to run.
* Options after '--' are passed to the classifier.
* (default weka.classifiers.functions.Logistic) *

* * -r number
* The number of runs to perform (default 2). *

* * -x number
* The number of Cross-validation folds (default 10). *

* * -l file
* Previously saved threshold curve ARFF file. *

* * @param args optional commandline parameters */ public static void main(String[] args) { Instances inst; Classifier classifier; int runs; int folds; String tmpStr; boolean compute; Instances result; String[] options; SingleIndex classIndex; SingleIndex valueIndex; int seed; inst = null; classifier = null; runs = 2; folds = 10; compute = true; result = null; classIndex = null; valueIndex = null; seed = 1; try { // help? if (Utils.getFlag('h', args)) { System.out.println("\nOptions for " + ThresholdVisualizePanel.class.getName() + ":\n"); System.out.println("-h\n\tThis help."); System.out .println("-t \n\tDataset to process with given classifier."); System.out .println("-c \n\tThe class index. first and last are valid, too (default: last)."); System.out .println("-C \n\tThe index of the class value to get the the curve for (default: first)."); System.out .println("-W \n\tFull classname of classifier to run.\n\tOptions after '--' are passed to the classifier.\n\t(default: weka.classifiers.functions.Logistic)"); System.out .println("-r \n\tThe number of runs to perform (default: 1)."); System.out .println("-x \n\tThe number of Cross-validation folds (default: 10)."); System.out .println("-S \n\tThe seed value for randomizing the data (default: 1)."); System.out .println("-l \n\tPreviously saved threshold curve ARFF file."); return; } // regular options tmpStr = Utils.getOption('l', args); if (tmpStr.length() != 0) { result = new Instances(new BufferedReader(new FileReader(tmpStr))); compute = false; } if (compute) { tmpStr = Utils.getOption('r', args); if (tmpStr.length() != 0) { runs = Integer.parseInt(tmpStr); } else { runs = 1; } tmpStr = Utils.getOption('x', args); if (tmpStr.length() != 0) { folds = Integer.parseInt(tmpStr); } else { folds = 10; } tmpStr = Utils.getOption('S', args); if (tmpStr.length() != 0) { seed = Integer.parseInt(tmpStr); } else { seed = 1; } tmpStr = Utils.getOption('t', args); if (tmpStr.length() != 0) { inst = new Instances(new 
BufferedReader(new FileReader(tmpStr))); inst.setClassIndex(inst.numAttributes() - 1); } tmpStr = Utils.getOption('W', args); if (tmpStr.length() != 0) { options = Utils.partitionOptions(args); } else { tmpStr = weka.classifiers.functions.Logistic.class.getName(); options = new String[0]; } classifier = AbstractClassifier.forName(tmpStr, options); tmpStr = Utils.getOption('c', args); if (tmpStr.length() != 0) { classIndex = new SingleIndex(tmpStr); } else { classIndex = new SingleIndex("last"); } tmpStr = Utils.getOption('C', args); if (tmpStr.length() != 0) { valueIndex = new SingleIndex(tmpStr); } else { valueIndex = new SingleIndex("first"); } } // compute if necessary if (compute) { if (classIndex != null) { classIndex.setUpper(inst.numAttributes() - 1); inst.setClassIndex(classIndex.getIndex()); } else { inst.setClassIndex(inst.numAttributes() - 1); } if (valueIndex != null) { valueIndex.setUpper(inst.classAttribute().numValues() - 1); } ThresholdCurve tc = new ThresholdCurve(); EvaluationUtils eu = new EvaluationUtils(); ArrayList predictions = new ArrayList(); for (int i = 0; i < runs; i++) { eu.setSeed(seed + i); predictions.addAll(eu.getCVPredictions(classifier, inst, folds)); } if (valueIndex != null) { result = tc.getCurve(predictions, valueIndex.getIndex()); } else { result = tc.getCurve(predictions); } } // setup GUI ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")"); if (compute) { vmc.setName(result.relationName() + ". 
(Class value " + inst.classAttribute().value(valueIndex.getIndex()) + ")"); } else { vmc.setName(result.relationName() + " (display only)"); } PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); vmc.addPlot(tempd); String plotName = vmc.getName(); final JFrame jf = new JFrame("Weka Classifier Visualize: " + plotName); jf.setSize(500, 400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.addWindowListener(new WindowAdapter() { @Override public void windowClosing(WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); } catch (Exception e) { e.printStackTrace(); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy