All downloads are free. Search and download functionality uses the official Maven repository.

weka.gui.beans.ClassifierPerformanceEvaluator Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA) is a machine learning workbench. This is the stable version: apart from bug fixes, it does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ClassifierPerformanceEvaluator.java
 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.gui.beans;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.gui.visualize.PlotData2D;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Vector;

/**
 * A bean that evaluates the performance of batch trained classifiers
 *
 * @author Mark Hall
 * @version $Revision: 7059 $
 */
public class ClassifierPerformanceEvaluator 
  extends AbstractEvaluator
  implements BatchClassifierListener, 
	     Serializable, UserRequestAcceptor, EventConstraints {

  /** for serialization */
  private static final long serialVersionUID = -3511801418192148690L;

  /**
   * Evaluation object used for evaluating a classifier
   */
  private transient Evaluation m_eval;

  private transient Thread m_evaluateThread = null;
  
  private transient long m_currentBatchIdentifier;
  private transient int m_setsComplete;
  
  private Vector m_textListeners = new Vector();
  private Vector m_thresholdListeners = new Vector();
  private Vector m_visualizableErrorListeners = new Vector();

  public ClassifierPerformanceEvaluator() {
    m_visual.loadIcons(BeanVisual.ICON_PATH
		       +"ClassifierPerformanceEvaluator.gif",
		       BeanVisual.ICON_PATH
		       +"ClassifierPerformanceEvaluator_animated.gif");
    m_visual.setText("ClassifierPerformanceEvaluator");
  }

  /**
   * Set a custom (descriptive) name for this bean
   * 
   * @param name the name to use
   */
  public void setCustomName(String name) {
    m_visual.setText(name);
  }

  /**
   * Get the custom (descriptive) name for this bean (if one has been set)
   * 
   * @return the custom name (or the default name)
   */
  public String getCustomName() {
    return m_visual.getText();
  }
  
  /**
   * Global info for this bean
   *
   * @return a String value
   */
  public String globalInfo() {
    return Messages.getInstance().getString("ClassifierPerformanceEvaluator_GlobalInfo_Text");
  }

  // ----- Stuff for ROC curves
  private boolean m_rocListenersConnected = false;
  // Plottable Instances with predictions appended
  private transient Instances m_predInstances = null;
  // Actual predictions
  private transient FastVector m_plotShape = null;
  private transient FastVector m_plotSize = null;

  /**
   * Accept a classifier to be evaluated
   *
   * @param ce a BatchClassifierEvent value
   */
  public void acceptClassifier(final BatchClassifierEvent ce) {
    if (ce.getTestSet() == null || ce.getTestSet().isStructureOnly()) {
      return; // cant evaluate empty/non-existent test instances
    }
    try {
      if (m_evaluateThread == null) {
	m_evaluateThread = new Thread() {
	    public void run() {
	      boolean errorOccurred = false;
//	      final String oldText = m_visual.getText();
	      Classifier classifier = ce.getClassifier();
	      try {
		//if (ce.getSetNumber() == 1) {
	        if (ce.getGroupIdentifier() != m_currentBatchIdentifier) {
                  if (ce.getTrainSet().getDataSet() == null ||
                      ce.getTrainSet().getDataSet().numInstances() == 0) {
                    // we have no training set to estimate majority class
                    // or mean of target from
                    m_eval = new Evaluation(ce.getTestSet().getDataSet());
                    m_eval.useNoPriors();
                  } else {
                    m_eval = new Evaluation(ce.getTrainSet().getDataSet());
                  }

//		  m_classifier = ce.getClassifier();
		  if (m_visualizableErrorListeners.size() > 0) {
		    m_predInstances = 
		      weka.gui.explorer.ClassifierPanel.
		      setUpVisualizableInstances(new Instances(ce.getTestSet().getDataSet()));
		    m_plotShape = new FastVector();
		    m_plotSize = new FastVector();
		  }
		  
		  m_currentBatchIdentifier = ce.getGroupIdentifier();
                  m_setsComplete = 0;
		}
//		if (ce.getSetNumber() <= ce.getMaxSetNumber()) {
	        if (m_setsComplete < ce.getMaxSetNumber()) {
		  if (ce.getTrainSet().getDataSet() != null &&
                      ce.getTrainSet().getDataSet().numInstances() > 0) {
                    // set the priors
                    m_eval.setPriors(ce.getTrainSet().getDataSet());
                  }
		  
//		  m_visual.setText("Evaluating ("+ce.getSetNumber()+")...");
		  if (m_logger != null) {
		    m_logger.statusMessage(statusMessagePrefix()
					   + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_Visual_SetText_Text_First") + ce.getSetNumber()
					   + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_Visual_SetText_Text_Second"));
		  }
		  m_visual.setAnimated();
		  /*
		  m_eval.evaluateModel(ce.getClassifier(), 
		  ce.getTestSet().getDataSet()); */
		  for (int i = 0; i < ce.getTestSet().getDataSet().numInstances(); i++) {
		    Instance temp = ce.getTestSet().getDataSet().instance(i);
		    weka.gui.explorer.ClassifierPanel.
		    processClassifierPrediction(temp, ce.getClassifier(),
						m_eval, m_predInstances, m_plotShape,
						m_plotSize);
		  }
		  
		  m_setsComplete++;
		}
		
		if (ce.getSetNumber() == ce.getMaxSetNumber()) {
                  //		  System.err.println(m_eval.toSummaryString());
		  // m_resultsString.append(m_eval.toSummaryString());
		  // m_outText.setText(m_resultsString.toString());
		  String textTitle = classifier.getClass().getName();
		  String textOptions = "";
		  if (classifier instanceof OptionHandler) {
	             textOptions = 
	               Utils.joinOptions(((OptionHandler)classifier).getOptions()); 
		  }
		  textTitle = 
		    textTitle.substring(textTitle.lastIndexOf('.')+1,
					textTitle.length());
		  String resultT = Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_ResultT_Text_First") + textTitle + "\n"
		    + ((textOptions.length() > 0) ? Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_ResultT_Text_Second") + textOptions + "\n": "")
		    + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_ResultT_Text_Third") + ce.getTestSet().getDataSet().relationName()
		    + "\n\n" + m_eval.toSummaryString();
                  
                  if (ce.getTestSet().getDataSet().
                      classAttribute().isNominal()) {
                    resultT += "\n" + m_eval.toClassDetailsString()
                      + "\n" + m_eval.toMatrixString();
                  }
                  
		  TextEvent te = 
		    new TextEvent(ClassifierPerformanceEvaluator.this, 
				  resultT,
				  textTitle);
		  notifyTextListeners(te);

                  // set up visualizable errors
                  if (m_visualizableErrorListeners.size() > 0) {
                    PlotData2D errorD = new PlotData2D(m_predInstances);
                    errorD.setShapeSize(m_plotSize);
                    errorD.setShapeType(m_plotShape);
                    errorD.setPlotName(textTitle + " " +textOptions + " ("
                                       +ce.getTestSet().getDataSet().relationName()
                                       +")");
                    errorD.addInstanceNumberAttribute();
                    VisualizableErrorEvent vel = 
                      new VisualizableErrorEvent(ClassifierPerformanceEvaluator.this,
                                                 errorD);
                    notifyVisualizableErrorListeners(vel);
                  }
                  

		  if (ce.getTestSet().getDataSet().classAttribute().isNominal() &&
		      m_thresholdListeners.size() > 0) {
		    ThresholdCurve tc = new ThresholdCurve();
		    Instances result = tc.getCurve(m_eval.predictions(), 0);
		    result.
		      setRelationName(ce.getTestSet().getDataSet().relationName());
		    PlotData2D pd = new PlotData2D(result);
		    String htmlTitle = Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_HtmlTitle_Text_First")
		      + textTitle;
		    String newOptions = "";
		    if (classifier instanceof OptionHandler) {
		      String[] options = 
		        ((OptionHandler) classifier).getOptions();
		      if (options.length > 0) {
		        for (int ii = 0; ii < options.length; ii++) {
		          if (options[ii].length() == 0) {
		            continue;
		          }
		          if (options[ii].charAt(0) == '-' && 
		              !(options[ii].charAt(1) >= '0' &&
		                  options[ii].charAt(1)<= '9')) {
		            newOptions += "
"; } newOptions += options[ii]; } } } htmlTitle += " " + newOptions + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_HtmlTitle_Text_Second") +ce.getTestSet().getDataSet(). classAttribute().value(0) + ")" + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_HtmlTitle_Text_Third"); pd.setPlotName(textTitle + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_HtmlTitle_Text_Fourth") +ce.getTestSet().getDataSet(). classAttribute().value(0) + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_HtmlTitle_Text_Fifth")); pd.setPlotNameHTML(htmlTitle); boolean [] connectPoints = new boolean [result.numInstances()]; for (int jj = 1; jj < connectPoints.length; jj++) { connectPoints[jj] = true; } pd.setConnectPoints(connectPoints); ThresholdDataEvent rde = new ThresholdDataEvent(ClassifierPerformanceEvaluator.this, pd, ce.getTestSet().getDataSet().classAttribute()); notifyThresholdListeners(rde); /*te = new TextEvent(ClassifierPerformanceEvaluator.this, result.toString(), "ThresholdCurveInst"); notifyTextListeners(te); */ } if (m_logger != null) { m_logger.statusMessage(statusMessagePrefix() + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_StatusMessage_Text_Third")); } // save memory m_predInstances = null; m_plotShape = null; m_plotSize = null; } } catch (Exception ex) { errorOccurred = true; ClassifierPerformanceEvaluator.this.stop(); // stop all processing if (m_logger != null) { m_logger.logMessage(Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_LogMessage_Text_First") + statusMessagePrefix() + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_LogMessage_Text_Second") + ex.getMessage()); } ex.printStackTrace(); } finally { // m_visual.setText(oldText); m_visual.setStatic(); m_evaluateThread = null; if (m_logger != null) { if (errorOccurred) { 
m_logger.statusMessage(statusMessagePrefix() + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_StatusMessage_Text_Fourth")); } else if (isInterrupted()) { m_logger.logMessage(Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_LogMessage_Text_Third") + getCustomName() + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_LogMessage_Text_Fourth")); m_logger.statusMessage(statusMessagePrefix() + Messages.getInstance().getString("ClassifierPerformanceEvaluator_AcceptClassifier_StatusMessage_Text_Fifth")); } } block(false); } } }; m_evaluateThread.setPriority(Thread.MIN_PRIORITY); m_evaluateThread.start(); // make sure the thread is still running before we block // if (m_evaluateThread.isAlive()) { block(true); // } m_evaluateThread = null; } } catch (Exception ex) { ex.printStackTrace(); } } /** * Returns true if. at this time, the bean is busy with some * (i.e. perhaps a worker thread is performing some calculation). * * @return true if the bean is busy. */ public boolean isBusy() { return (m_evaluateThread != null); } /** * Try and stop any action */ public void stop() { // tell the listenee (upstream bean) to stop if (m_listenee instanceof BeanCommon) { // System.err.println("Listener is BeanCommon"); ((BeanCommon)m_listenee).stop(); } // stop the evaluate thread if (m_evaluateThread != null) { m_evaluateThread.interrupt(); m_evaluateThread.stop(); m_evaluateThread = null; m_visual.setStatic(); } } /** * Function used to stop code that calls acceptClassifier. This is * needed as classifier evaluation is performed inside a separate * thread of execution. * * @param tf a boolean value */ private synchronized void block(boolean tf) { if (tf) { try { // only block if thread is still doing something useful! 
if (m_evaluateThread != null && m_evaluateThread.isAlive()) { wait(); } } catch (InterruptedException ex) { } } else { notifyAll(); } } /** * Return an enumeration of user activated requests for this bean * * @return an Enumeration value */ public Enumeration enumerateRequests() { Vector newVector = new Vector(0); if (m_evaluateThread != null) { newVector.addElement("Stop"); } return newVector.elements(); } /** * Perform the named request * * @param request the request to perform * @exception IllegalArgumentException if an error occurs */ public void performRequest(String request) { if (request.compareTo("Stop") == 0) { stop(); } else { throw new IllegalArgumentException(request + Messages.getInstance().getString("ClassifierPerformanceEvaluator_PerformRequest_IllegalArgumentException_Text")); } } /** * Add a text listener * * @param cl a TextListener value */ public synchronized void addTextListener(TextListener cl) { m_textListeners.addElement(cl); } /** * Remove a text listener * * @param cl a TextListener value */ public synchronized void removeTextListener(TextListener cl) { m_textListeners.remove(cl); } /** * Add a threshold data listener * * @param cl a ThresholdDataListener value */ public synchronized void addThresholdDataListener(ThresholdDataListener cl) { m_thresholdListeners.addElement(cl); } /** * Remove a Threshold data listener * * @param cl a ThresholdDataListener value */ public synchronized void removeThresholdDataListener(ThresholdDataListener cl) { m_thresholdListeners.remove(cl); } /** * Add a visualizable error listener * * @param vel a VisualizableErrorListener value */ public synchronized void addVisualizableErrorListener(VisualizableErrorListener vel) { m_visualizableErrorListeners.add(vel); } /** * Remove a visualizable error listener * * @param vel a VisualizableErrorListener value */ public synchronized void removeVisualizableErrorListener(VisualizableErrorListener vel) { m_visualizableErrorListeners.remove(vel); } /** * Notify all text 
listeners of a TextEvent * * @param te a TextEvent value */ private void notifyTextListeners(TextEvent te) { Vector l; synchronized (this) { l = (Vector)m_textListeners.clone(); } if (l.size() > 0) { for(int i = 0; i < l.size(); i++) { // System.err.println("Notifying text listeners " // +"(ClassifierPerformanceEvaluator)"); ((TextListener)l.elementAt(i)).acceptText(te); } } } /** * Notify all ThresholdDataListeners of a ThresholdDataEvent * * @param te a ThresholdDataEvent value */ private void notifyThresholdListeners(ThresholdDataEvent re) { Vector l; synchronized (this) { l = (Vector)m_thresholdListeners.clone(); } if (l.size() > 0) { for(int i = 0; i < l.size(); i++) { // System.err.println("Notifying text listeners " // +"(ClassifierPerformanceEvaluator)"); ((ThresholdDataListener)l.elementAt(i)).acceptDataSet(re); } } } /** * Notify all VisualizableErrorListeners of a VisualizableErrorEvent * * @param te a VisualizableErrorEvent value */ private void notifyVisualizableErrorListeners(VisualizableErrorEvent re) { Vector l; synchronized (this) { l = (Vector)m_visualizableErrorListeners.clone(); } if (l.size() > 0) { for(int i = 0; i < l.size(); i++) { // System.err.println("Notifying text listeners " // +"(ClassifierPerformanceEvaluator)"); ((VisualizableErrorListener)l.elementAt(i)).acceptDataSet(re); } } } /** * Returns true, if at the current time, the named event could * be generated. Assumes that supplied event names are names of * events that could be generated by this bean. * * @param eventName the name of the event in question * @return true if the named event could be generated at this point in * time */ public boolean eventGeneratable(String eventName) { if (m_listenee == null) { return false; } if (m_listenee instanceof EventConstraints) { if (!((EventConstraints)m_listenee). eventGeneratable("batchClassifier")) { return false; } } return true; } private String statusMessagePrefix() { return getCustomName() + "$" + hashCode() + "|"; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy