All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.gui.beans.IncrementalClassifierEvaluator Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    IncrementalClassifierEvaluator.java
 *    Copyright (C) 2002-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.gui.beans;

import java.util.LinkedList;
import java.util.Vector;

import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.Utils;

/**
 * Bean that evaluates incremental classifiers.
 * <p>
 * For every {@link IncrementalClassifierEvent} received, the prediction of the
 * encapsulated classifier is folded into a running {@link Evaluation}, a data
 * point is pushed to all registered chart listeners and, once the stream is
 * finished, a textual summary is pushed to all registered text listeners.
 * Optionally the charted statistics can be computed over a sliding window of
 * the most recent labelled instances instead of the whole stream.
 * 
 * @author Mark Hall
 * @version $Revision: 10220 $
 */
public class IncrementalClassifierEvaluator extends AbstractEvaluator implements
  IncrementalClassifierListener, EventConstraints {

  /** for serialization */
  private static final long serialVersionUID = -3105419818939541291L;

  /** Evaluation accumulated over the whole stream of the current batch */
  private transient Evaluation m_eval;

  /** Tracks and reports throughput for the current batch */
  protected transient StreamThroughput m_throughput;

  /** Registered chart listeners */
  private final Vector<ChartListener> m_listeners =
    new Vector<ChartListener>();

  /** Registered text listeners */
  private final Vector<TextListener> m_textListeners =
    new Vector<TextListener>();

  /** Legend entries matching the slots of {@link #m_dataPoint} */
  private Vector<String> m_dataLegend = new Vector<String>();

  /** Single chart event instance that is re-populated for every data point */
  private final ChartEvent m_ce = new ChartEvent(this);
  private double[] m_dataPoint = new double[1];

  /** True until the first data point of a new batch has been sent */
  private boolean m_reset = false;

  // Running range of charted values for a numeric class. The maximum has to
  // start at the most negative double: Double.MIN_VALUE is the smallest
  // POSITIVE double and would hide negative predictions.
  private double m_min = Double.MAX_VALUE;
  private double m_max = -Double.MAX_VALUE;

  // how often (in milliseconds) to report throughput to the log
  private int m_statusFrequency = 2000;

  // number of instances received in the current batch
  private int m_instanceCount = 0;

  // output info retrieval and auc stats for each class (if class is nominal)
  private boolean m_outputInfoRetrievalStats = false;

  // window size for computing performance metrics - 0 means no window, i.e
  // don't "forget" performance on any instances
  private int m_windowSize = 0;
  private Evaluation m_windowEval;
  private LinkedList<Instance> m_window;
  private LinkedList<double[]> m_windowedPreds;

  /**
   * Creates the bean and sets up its icons and default display name.
   */
  public IncrementalClassifierEvaluator() {
    m_visual.loadIcons(BeanVisual.ICON_PATH
      + "IncrementalClassifierEvaluator.gif", BeanVisual.ICON_PATH
      + "IncrementalClassifierEvaluator_animated.gif");
    m_visual.setText("IncrementalClassifierEvaluator");
  }

  /**
   * Set a custom (descriptive) name for this bean
   * 
   * @param name the name to use
   */
  @Override
  public void setCustomName(String name) {
    m_visual.setText(name);
  }

  /**
   * Get the custom (descriptive) name for this bean (if one has been set)
   * 
   * @return the custom name (or the default name)
   */
  @Override
  public String getCustomName() {
    return m_visual.getText();
  }

  /**
   * Global info for this bean
   * 
   * @return a String value
   */
  public String globalInfo() {
    return "Evaluate the performance of incrementally trained classifiers.";
  }

  /**
   * Accepts and processes a classifier encapsulated in an incremental
   * classifier event. A NEW_BATCH event (re)initialises all evaluation state;
   * any other event evaluates the current instance (if there is one) and, when
   * the event marks the end of the batch or carries no instance, finishes the
   * batch. Any exception is logged and causes the upstream bean to be stopped.
   * 
   * @param ce an IncrementalClassifierEvent value
   */
  @Override
  public void acceptClassifier(final IncrementalClassifierEvent ce) {
    try {
      if (ce.getStatus() == IncrementalClassifierEvent.NEW_BATCH) {
        startNewBatch(ce);
      } else {
        Instance inst = ce.getCurrentInstance();
        if (inst != null) {
          processInstance(ce, inst);
        }

        if (ce.getStatus() == IncrementalClassifierEvent.BATCH_FINISHED
          || inst == null) {
          finishBatch(ce);
        }
      }
    } catch (Exception ex) {
      if (m_logger != null) {
        m_logger.logMessage("[IncrementalClassifierEvaluator]"
          + statusMessagePrefix() + " Error processing prediction "
          + ex.getMessage());
        m_logger.statusMessage(statusMessagePrefix()
          + "ERROR: problem processing prediction (see log for details)");
      }
      ex.printStackTrace();
      stop();
    }
  }

  /**
   * Resets all per-batch state: throughput tracking, the whole-stream
   * evaluation, the chart legend/data point/range and, if a window size is
   * configured, the sliding window structures.
   * 
   * @param ce the NEW_BATCH event providing the structure of the data
   * @throws Exception if an evaluation object can't be created
   */
  private void startNewBatch(IncrementalClassifierEvent ce) throws Exception {
    m_throughput = new StreamThroughput(statusMessagePrefix());
    m_throughput.setSamplePeriod(m_statusFrequency);

    m_eval = new Evaluation(ce.getStructure());
    m_eval.useNoPriors();

    m_dataLegend = new Vector<String>();
    m_reset = true;
    m_dataPoint = new double[0];
    // the chart is reset for a new batch, so its value range has to restart
    // as well instead of inheriting the extremes of the previous batch
    m_min = Double.MAX_VALUE;
    m_max = -Double.MAX_VALUE;
    System.err.println("NEW BATCH");
    m_instanceCount = 0;

    if (m_windowSize > 0) {
      m_window = new LinkedList<Instance>();
      m_windowEval = new Evaluation(ce.getStructure());
      m_windowEval.useNoPriors();
      m_windowedPreds = new LinkedList<double[]>();

      if (m_logger != null) {
        m_logger.logMessage(statusMessagePrefix()
          + "[IncrementalClassifierEvaluator] Chart output using windowed "
          + "evaluation over " + m_windowSize + " instances");
      }
    }
  }

  /**
   * Evaluates the classifier of the event on one instance, updates the chart
   * event accordingly and notifies the chart listeners.
   * 
   * @param ce the event holding the classifier
   * @param inst the (non-null) current instance of the event
   * @throws Exception if prediction or evaluation fails
   */
  private void processInstance(IncrementalClassifierEvent ce, Instance inst)
    throws Exception {
    m_throughput.updateStart();
    m_instanceCount++;

    final int classIndex = inst.classIndex();
    final double[] dist = ce.getClassifier().distributionForInstance(inst);
    final boolean classMissing = inst.isMissing(classIndex);
    double pred = 0;

    if (!classMissing) {
      if (m_outputInfoRetrievalStats) {
        // store predictions so AUC etc can be output.
        m_eval.evaluateModelOnceAndRecordPrediction(dist, inst);
      } else {
        m_eval.evaluateModelOnce(dist, inst);
      }

      if (windowActive()) {
        updateWindow(dist, inst);
      }
    } else {
      // no label to evaluate against - only the raw prediction is available
      pred = ce.getClassifier().classifyInstance(inst);
    }

    if (classIndex >= 0) {
      if (inst.attribute(classIndex).isNominal()) {
        prepareNominalDataPoint(dist, classMissing);
      } else {
        prepareNumericDataPoint(pred, classMissing);
      }
      notifyChartListeners(m_ce);
    }
    m_throughput.updateEnd(m_logger);
  }

  /**
   * Adds a labelled instance and its predicted distribution to the sliding
   * window and "forgets" the oldest entry once the window holds more than the
   * configured number of instances.
   * 
   * @param dist the predicted distribution for the instance
   * @param inst the labelled instance
   * @throws Exception if evaluation fails
   */
  private void updateWindow(double[] dist, Instance inst) throws Exception {
    m_windowEval.evaluateModelOnce(dist, inst);
    m_window.addFirst(inst);
    m_windowedPreds.addFirst(dist);

    // Compare against the window's own size: only labelled instances enter the
    // window, so the overall instance count would evict entries too early on
    // streams that also contain unlabelled instances.
    if (m_window.size() > m_windowSize) {
      Instance oldest = m_window.removeLast();
      double[] oldDist = m_windowedPreds.removeLast();

      // "forget" the oldest prediction by evaluating it again with a negated
      // weight (presumably Evaluation accumulates weighted statistics, so
      // this subtracts the earlier contribution - confirm against Evaluation)
      oldest.setWeight(-oldest.weight());
      try {
        m_windowEval.evaluateModelOnce(oldDist, oldest);
      } finally {
        // always hand the instance back with its original weight
        oldest.setWeight(-oldest.weight());
      }
    }
  }

  /**
   * Populates the chart event for a nominal class. With a known class value
   * the point holds accuracy, RMSE and kappa; with a missing class value it
   * holds the largest probability of the predicted distribution.
   * 
   * @param dist the predicted distribution for the current instance
   * @param classMissing true if the class value of the instance is missing
   * @throws Exception if a statistic can't be computed
   */
  private void prepareNominalDataPoint(double[] dist, boolean classMissing)
    throws Exception {
    double primaryMeasure;

    if (!classMissing) {
      if (m_dataPoint.length < 2) {
        m_dataPoint = new double[3];
        m_dataLegend.addElement("Accuracy");
        m_dataLegend.addElement("RMSE (prob)");
        m_dataLegend.addElement("Kappa");
      }

      Evaluation chartEval = chartingEvaluation();
      m_dataPoint[1] = chartEval.rootMeanSquaredError();
      m_dataPoint[2] = chartEval.kappa();
      primaryMeasure = 1.0 - chartEval.errorRate();
    } else {
      if (m_dataPoint.length < 1) {
        m_dataPoint = new double[1];
        m_dataLegend.addElement("Confidence");
      }
      // record confidence as the primary measure (another possibility would
      // be entropy of the distribution, or perhaps average confidence)
      primaryMeasure = dist[Utils.maxIndex(dist)];
    }
    m_dataPoint[0] = primaryMeasure;

    m_ce.setLegendText(m_dataLegend);
    m_ce.setMin(0);
    m_ce.setMax(1);
    m_ce.setDataPoint(m_dataPoint);
    m_ce.setReset(m_reset);
    m_reset = false;
  }

  /**
   * Populates the chart event for a numeric class. With a known class value
   * the point holds the RMSE; with a missing class value it holds the
   * prediction itself. The running minimum/maximum of the charted values is
   * maintained in both cases.
   * 
   * @param pred the prediction for the current instance (only meaningful if
   *          the class value is missing)
   * @param classMissing true if the class value of the instance is missing
   * @throws Exception if a statistic can't be computed
   */
  private void prepareNumericDataPoint(double pred, boolean classMissing)
    throws Exception {
    if (m_dataPoint.length < 1) {
      m_dataPoint = new double[1];
      m_dataLegend.addElement(classMissing ? "Prediction" : "RMSE");
    }

    // The prediction has to be charted for unlabelled instances too; guarding
    // this whole section on "class not missing" made that case unreachable.
    double update = classMissing ? pred : chartingEvaluation()
      .rootMeanSquaredError();

    m_dataPoint[0] = update;
    if (update > m_max) {
      m_max = update;
    }
    if (update < m_min) {
      m_min = update;
    }

    m_ce.setLegendText(m_dataLegend);
    // an error measure is charted from zero, a raw prediction from the
    // smallest value seen so far
    m_ce.setMin(classMissing ? m_min : 0);
    m_ce.setMax(m_max);
    m_ce.setDataPoint(m_dataPoint);
    m_ce.setReset(m_reset);
    m_reset = false;
  }

  /**
   * Finishes the current batch: logs completion, releases the sliding window
   * structures and, if there are text listeners, sends them a performance
   * report built from the whole-stream evaluation.
   * 
   * @param ce the event that ended the batch
   * @throws Exception if the report can't be generated
   */
  private void finishBatch(IncrementalClassifierEvent ce) throws Exception {
    if (m_logger != null) {
      m_logger.logMessage("[IncrementalClassifierEvaluator]"
        + statusMessagePrefix() + " Finished processing.");
    }
    m_throughput.finished(m_logger);

    // save memory if using windowed evaluation for charting
    m_windowEval = null;
    m_window = null;
    m_windowedPreds = null;

    if (m_textListeners.size() > 0) {
      // use the unqualified class name of the classifier as the scheme name
      String textTitle = ce.getClassifier().getClass().getName();
      textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1,
        textTitle.length());

      String results = "=== Performance information ===\n\n"
        + "Scheme:   " + textTitle + "\n" + "Relation: "
        + m_eval.getHeader().relationName() + "\n\n"
        + m_eval.toSummaryString();

      boolean nominalClass = m_eval.getHeader().classIndex() >= 0
        && m_eval.getHeader().classAttribute().isNominal();

      if (nominalClass && m_outputInfoRetrievalStats) {
        results += "\n" + m_eval.toClassDetailsString();
      }

      if (nominalClass) {
        results += "\n" + m_eval.toMatrixString();
      }

      textTitle = "Results: " + textTitle;
      TextEvent te = new TextEvent(this, results, textTitle);
      notifyTextListeners(te);
    }
  }

  /**
   * Returns true if a window size is configured and the window structures
   * have been created for the current batch. The second condition keeps a
   * window size that was changed mid-stream (or an event arriving after the
   * batch was finished) from dereferencing structures that don't exist.
   * 
   * @return true if windowed evaluation is in effect
   */
  private boolean windowActive() {
    return m_windowSize > 0 && m_windowEval != null;
  }

  /**
   * Returns the evaluation that charted statistics are read from.
   * 
   * @return the windowed evaluation if windowing is in effect, otherwise the
   *         whole-stream evaluation
   */
  private Evaluation chartingEvaluation() {
    return windowActive() ? m_windowEval : m_eval;
  }

  /**
   * Returns true, if at the current time, the named event could be generated.
   * Assumes that supplied event names are names of events that could be
   * generated by this bean.
   * 
   * @param eventName the name of the event in question
   * @return true if the named event could be generated at this point in time
   */
  @Override
  public boolean eventGeneratable(String eventName) {
    if (m_listenee == null) {
      return false;
    }

    // we can only produce output if upstream can produce incremental
    // classifier events
    if (m_listenee instanceof EventConstraints) {
      if (!((EventConstraints) m_listenee)
        .eventGeneratable("incrementalClassifier")) {
        return false;
      }
    }
    return true;
  }

  /**
   * Stop all action
   */
  @Override
  public void stop() {
    // tell the listenee (upstream bean) to stop
    if (m_listenee instanceof BeanCommon) {
      ((BeanCommon) m_listenee).stop();
    }
  }

  /**
   * Returns true if. at this time, the bean is busy with some (i.e. perhaps a
   * worker thread is performing some calculation).
   * 
   * @return true if the bean is busy.
   */
  @Override
  public boolean isBusy() {
    return false;
  }

  /**
   * Notify all chart listeners of a ChartEvent. The listener list is copied
   * while holding this bean's lock and the listeners are called outside of it.
   * 
   * @param ce a ChartEvent value
   */
  @SuppressWarnings("unchecked")
  private void notifyChartListeners(ChartEvent ce) {
    Vector<ChartListener> l;
    synchronized (this) {
      l = (Vector<ChartListener>) m_listeners.clone();
    }
    for (ChartListener listener : l) {
      listener.acceptDataPoint(ce);
    }
  }

  /**
   * Notify all text listeners of a TextEvent. The listener list is copied
   * while holding this bean's lock and the listeners are called outside of it.
   * 
   * @param te a TextEvent value
   */
  @SuppressWarnings("unchecked")
  private void notifyTextListeners(TextEvent te) {
    Vector<TextListener> l;
    synchronized (this) {
      l = (Vector<TextListener>) m_textListeners.clone();
    }
    for (TextListener listener : l) {
      listener.acceptText(te);
    }
  }

  /**
   * Set how often throughput/progress is reported. The value is used as the
   * sample period of the throughput tracker created for the next batch.
   * 
   * @param s the reporting period in milliseconds
   */
  public void setStatusFrequency(int s) {
    m_statusFrequency = s;
  }

  /**
   * Get how often throughput/progress is reported.
   * 
   * @return the reporting period in milliseconds
   */
  public int getStatusFrequency() {
    return m_statusFrequency;
  }

  /**
   * Return a tip text string for this property
   * 
   * @return a string for the tip text
   */
  public String statusFrequencyTipText() {
    return "How often to report progress to the status bar.";
  }

  /**
   * Set whether to output per-class information retrieval statistics (nominal
   * class only).
   * 
   * @param i true if info retrieval stats are to be output
   */
  public void setOutputPerClassInfoRetrievalStats(boolean i) {
    m_outputInfoRetrievalStats = i;
  }

  /**
   * Get whether per-class information retrieval stats are to be output.
   * 
   * @return true if info retrieval stats are to be output
   */
  public boolean getOutputPerClassInfoRetrievalStats() {
    return m_outputInfoRetrievalStats;
  }

  /**
   * Return a tip text string for this property
   * 
   * @return a string for the tip text
   */
  public String outputPerClassInfoRetrievalStatsTipText() {
    return "Output per-class info retrieval stats. If set to true, predictions get "
      + "stored so that stats such as AUC can be computed. Note: this consumes some memory.";
  }

  /**
   * Set whether to compute evaluation for charting over a fixed sized window of
   * the most recent instances (rather than the whole stream). The window
   * structures are created when a new batch starts.
   * 
   * @param windowSize the size of the window to use for computing the
   *          evaluation metrics used for charting. Setting a value of zero or
   *          less specifies that no windowing is to be used.
   */
  public void setChartingEvalWindowSize(int windowSize) {
    m_windowSize = windowSize;
  }

  /**
   * Get whether to compute evaluation for charting over a fixed sized window of
   * the most recent instances (rather than the whole stream).
   * 
   * @return the size of the window to use for computing the evaluation metrics
   *         used for charting. A value of zero or less specifies that no
   *         windowing is to be used.
   */
  public int getChartingEvalWindowSize() {
    return m_windowSize;
  }

  /**
   * Return a tip text string for this property
   * 
   * @return a string for the tip text
   */
  public String chartingEvalWindowSizeTipText() {
    return "For charting only, specify a sliding window size over which to compute "
      + "performance stats. <= 0 means eval on whole stream";
  }

  /**
   * Add a chart listener
   * 
   * @param cl a ChartListener value
   */
  public synchronized void addChartListener(ChartListener cl) {
    m_listeners.addElement(cl);
  }

  /**
   * Remove a chart listener
   * 
   * @param cl a ChartListener value
   */
  public synchronized void removeChartListener(ChartListener cl) {
    m_listeners.remove(cl);
  }

  /**
   * Add a text listener
   * 
   * @param cl a TextListener value
   */
  public synchronized void addTextListener(TextListener cl) {
    m_textListeners.addElement(cl);
  }

  /**
   * Remove a text listener
   * 
   * @param cl a TextListener value
   */
  public synchronized void removeTextListener(TextListener cl) {
    m_textListeners.remove(cl);
  }

  /**
   * Builds the prefix used for status and log messages of this bean.
   * 
   * @return the custom name, a '$', this object's hash code and a '|'
   */
  private String statusMessagePrefix() {
    return getCustomName() + "$" + hashCode() + "|";
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy