All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.gui.beans.CostBenefitAnalysis Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    CostBenefitAnalysis.java
 *    Copyright (C) 2009-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.gui.beans;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Graphics;
import java.awt.GraphicsEnvironment;
import java.awt.GridLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;
import java.beans.EventSetDescriptor;
import java.beans.PropertyVetoException;
import java.beans.VetoableChangeListener;
import java.beans.beancontext.BeanContext;
import java.beans.beancontext.BeanContextChild;
import java.beans.beancontext.BeanContextChildSupport;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.EventObject;
import java.util.List;
import java.util.Vector;

import javax.swing.BorderFactory;
import javax.swing.ButtonGroup;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JRadioButton;
import javax.swing.JSlider;
import javax.swing.JTextField;
import javax.swing.SwingConstants;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;

import weka.classifiers.evaluation.Prediction;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.gui.Logger;
import weka.gui.visualize.PlotData2D;
import weka.gui.visualize.VisualizePanel;

/**
 * Bean that aids in analyzing cost/benefit tradeoffs.
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 10220 $
 */
@KFStep(category = "Visualize", toolTipText = "Interactive cost/benefit analysis")
public class CostBenefitAnalysis extends JPanel implements BeanCommon,
  ThresholdDataListener, Visible, UserRequestAcceptor, Serializable,
  BeanContextChild, HeadlessEventCollector {

  /** For serialization */
  private static final long serialVersionUID = 8647471654613320469L;

  /**
   * Visual representation of this bean, created with the static and animated
   * "ModelPerformanceChart" icons.
   */
  protected BeanVisual m_visual = new BeanVisual("CostBenefitAnalysis",
    BeanVisual.ICON_PATH + "ModelPerformanceChart.gif", BeanVisual.ICON_PATH
      + "ModelPerformanceChart_animated.gif");

  /**
   * Pop-up frame (transient, so not serialized). NOTE(review): presumably the
   * window that hosts the analysis panel - confirm against the code that
   * creates it.
   */
  protected transient JFrame m_popupFrame;

  /**
   * NOTE(review): by its name, tracks whether the pop-up frame is currently
   * showing - confirm against the code that sets it.
   */
  protected boolean m_framePoppedUp = false;

  /** The panel holding the plots and control widgets (transient) */
  private transient AnalysisPanel m_analysisPanel;

  /**
   * True if this bean's appearance is the design mode appearance
   */
  protected boolean m_design;

  /**
   * BeanContext that this bean might be contained within
   */
  protected transient BeanContext m_beanContext = null;

  /**
   * BeanContextChild support
   */
  protected BeanContextChildSupport m_bcSupport = new BeanContextChildSupport(
    this);

  /**
   * The object sending us data (we allow only one connection at any one time)
   */
  protected Object m_listenee;

  /**
   * Collected events. NOTE(review): presumably the events gathered when
   * running headless (the class implements HeadlessEventCollector); the
   * element type is not visible from this declaration.
   */
  protected List m_headlessEvents;

  /**
   * Inner class for displaying the plots and all control widgets.
   * 
   * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
   */
  protected static class AnalysisPanel extends JPanel {

    /** For serialization */
    private static final long serialVersionUID = 5364871945448769003L;

    /** Displays the performance graphs(s) */
    protected VisualizePanel m_performancePanel = new VisualizePanel();

    /** Displays the cost/benefit (profit/loss) graph */
    protected VisualizePanel m_costBenefitPanel = new VisualizePanel();

    /**
     * The class attribute from the data that was used to generate the threshold
     * curve
     */
    protected Attribute m_classAttribute;

    /** Data for the threshold curve */
    protected PlotData2D m_masterPlot;

    /** Data for the cost/benefit curve */
    protected PlotData2D m_costBenefit;

    /** The size of the points being plotted */
    protected int[] m_shapeSizes;

    /** The index of the previous plotted point that was highlighted */
    protected int m_previousShapeIndex = -1;

    /** The slider for adjusting the threshold */
    protected JSlider m_thresholdSlider = new JSlider(0, 100, 0);

    // Radio buttons choosing which threshold-curve metric the slider position
    // is matched against (sample size, recall or score threshold)
    protected JRadioButton m_percPop = new JRadioButton("% of Population");
    protected JRadioButton m_percOfTarget = new JRadioButton(
      "% of Target (recall)");
    protected JRadioButton m_threshold = new JRadioButton("Score Threshold");

    // Labels showing the three metric values of the currently selected
    // point on the threshold curve
    protected JLabel m_percPopLab = new JLabel();
    protected JLabel m_percOfTargetLab = new JLabel();
    protected JLabel m_thresholdLab = new JLabel();

    // Confusion matrix widgets: column headers, row labels and the four
    // cells. The cells are filled as aa = true positives, ab = false
    // negatives, ba = false positives, bb = true negatives.
    protected JLabel m_conf_predictedA = new JLabel("Predicted (a)",
      SwingConstants.RIGHT);
    protected JLabel m_conf_predictedB = new JLabel("Predicted (b)",
      SwingConstants.RIGHT);
    protected JLabel m_conf_actualA = new JLabel(" Actual (a):");
    protected JLabel m_conf_actualB = new JLabel(" Actual (b):");
    protected ConfusionCell m_conf_aa = new ConfusionCell();
    protected ConfusionCell m_conf_ab = new ConfusionCell();
    protected ConfusionCell m_conf_ba = new ConfusionCell();
    protected ConfusionCell m_conf_bb = new ConfusionCell();

    // Cost matrix widgets. The editable cells are read as aa = true positive
    // cost, ab = false negative cost, ba = false positive cost, bb = true
    // negative cost.
    protected JLabel m_cost_predictedA = new JLabel("Predicted (a)",
      SwingConstants.RIGHT);
    protected JLabel m_cost_predictedB = new JLabel("Predicted (b)",
      SwingConstants.RIGHT);
    protected JLabel m_cost_actualA = new JLabel(" Actual (a)");
    protected JLabel m_cost_actualB = new JLabel(" Actual (b)");
    protected JTextField m_cost_aa = new JTextField("0.0", 5);
    protected JTextField m_cost_ab = new JTextField("1.0", 5);
    protected JTextField m_cost_ba = new JTextField("1.0", 5);
    protected JTextField m_cost_bb = new JTextField("0.0", 5);
    // Buttons that move the selection to the curve point with the largest /
    // smallest cost-benefit value
    protected JButton m_maximizeCB = new JButton("Maximize Cost/Benefit");
    protected JButton m_minimizeCB = new JButton("Minimize Cost/Benefit");
    // Whether the matrix values are interpreted as costs or as benefits
    protected JRadioButton m_costR = new JRadioButton("Cost");
    protected JRadioButton m_benefitR = new JRadioButton("Benefit");
    // Caption ("Cost: " or "Benefit: ") and value at the selected point
    protected JLabel m_costBenefitL = new JLabel("Cost: ", SwingConstants.RIGHT);
    protected JLabel m_costBenefitV = new JLabel("0");
    // Cost/benefit of a random selection of the same size, and the gain of
    // the model relative to it
    protected JLabel m_randomV = new JLabel("0");
    protected JLabel m_gainV = new JLabel("0");

    /**
     * Population size of the supplied data (true positives plus false
     * positives of the first threshold-curve row)
     */
    protected int m_originalPopSize;

    /** Population text field */
    protected JTextField m_totalPopField = new JTextField(6);
    /** Population value compared against to detect a change by the user */
    protected int m_totalPopPrevious;

    /** Classification accuracy */
    protected JLabel m_classificationAccV = new JLabel("-");

    // Only update curve & stats if values in cost matrix have changed
    protected double m_tpPrevious;
    protected double m_fpPrevious;
    protected double m_tnPrevious;
    protected double m_fnPrevious;

    /**
     * Inner class for handling a single cell in the confusion matrix. Displays
     * the value, value as a percentage of total population and graphical
     * depiction of percentage.
     * 
     * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
     */
    protected static class ConfusionCell extends JPanel {

      /** For serialization */
      private static final long serialVersionUID = 6148640235434494767L;

      /** Shows the (scaled) count held by this cell */
      private final JLabel m_conf_cell = new JLabel("-", SwingConstants.RIGHT);

      /** Shows the cell's share of the total as a percentage */
      JLabel m_conf_perc = new JLabel("-", SwingConstants.RIGHT);

      /** Small panel that paints the percentage as a horizontal bar */
      private final JPanel m_percentageP;

      /** Fraction (0..1) of the total represented by this cell */
      protected double m_percentage = 0;

      /**
       * Lays out the cell: the count on top, the percentage bar and the
       * percentage text below it.
       */
      @SuppressWarnings("serial")
      public ConfusionCell() {
        setLayout(new BorderLayout());
        setBorder(BorderFactory.createEtchedBorder());

        add(m_conf_cell, BorderLayout.NORTH);

        m_percentageP = new JPanel() {
          @Override
          public void paintComponent(Graphics gx) {
            super.paintComponent(gx);

            // nothing to draw for an empty (or undefined) share
            if (m_percentage > 0) {
              gx.setColor(Color.BLUE);
              int height = this.getHeight();
              double width = this.getWidth();
              int barWidth = (int) (m_percentage * width);
              gx.fillRect(0, 0, barWidth, height);
            }
          }
        };

        Dimension d = new Dimension(30, 5);
        m_percentageP.setMinimumSize(d);
        m_percentageP.setPreferredSize(d);
        JPanel percHolder = new JPanel();
        percHolder.setLayout(new BorderLayout());
        percHolder.add(m_percentageP, BorderLayout.CENTER);
        percHolder.add(m_conf_perc, BorderLayout.EAST);

        add(percHolder, BorderLayout.SOUTH);
      }

      /**
       * Set the value of a cell.
       * 
       * @param cellValue the value of the cell
       * @param max the max (for setting value as a percentage); when it is
       *          zero or missing the percentage is shown as 0
       * @param scaleFactor scale the value by this amount
       * @param precision precision for the percentage value
       */
      public void setCellValue(double cellValue, double max,
        double scaleFactor, int precision) {
        // A zero or missing total would otherwise yield NaN/Infinity, which
        // ends up in the label text and in the bar width computation.
        boolean percentageDefined = !Utils.isMissingValue(cellValue)
          && !Utils.isMissingValue(max) && max != 0;
        if (percentageDefined) {
          m_percentage = cellValue / max;
        } else {
          m_percentage = 0;
        }

        m_conf_cell.setText(Utils.doubleToString((cellValue * scaleFactor), 0));
        m_conf_perc.setText(Utils.doubleToString(m_percentage * 100.0,
          precision) + "%");

        // refresh the percentage bar
        m_percentageP.repaint();
      }
    }

    /**
     * Builds the panel: the two plots side by side in the centre and, below
     * them, the threshold controls, the confusion matrix and the cost matrix
     * with its cost/benefit read-outs. Also wires up all widget listeners.
     */
    public AnalysisPanel() {
      setLayout(new BorderLayout());
      m_performancePanel.setShowAttBars(false);
      m_performancePanel.setShowClassPanel(false);
      m_costBenefitPanel.setShowAttBars(false);
      m_costBenefitPanel.setShowClassPanel(false);

      Dimension size = new Dimension(500, 400);
      m_performancePanel.setPreferredSize(size);
      m_performancePanel.setMinimumSize(size);

      size = new Dimension(500, 400);
      m_costBenefitPanel.setMinimumSize(size);
      m_costBenefitPanel.setPreferredSize(size);

      m_thresholdSlider.addChangeListener(new ChangeListener() {
        @Override
        public void stateChanged(ChangeEvent e) {
          // the slider can be moved before any data has been supplied
          if (m_masterPlot == null) {
            return;
          }
          updateInfoForSliderValue(m_thresholdSlider.getValue() / 100.0);
        }
      });

      JPanel plotHolder = new JPanel();
      plotHolder.setLayout(new GridLayout(1, 2));
      plotHolder.add(m_performancePanel);
      plotHolder.add(m_costBenefitPanel);
      add(plotHolder, BorderLayout.CENTER);

      JPanel lowerPanel = new JPanel();
      lowerPanel.setLayout(new BorderLayout());

      ButtonGroup bGroup = new ButtonGroup();
      bGroup.add(m_percPop);
      bGroup.add(m_percOfTarget);
      bGroup.add(m_threshold);

      ButtonGroup bGroup2 = new ButtonGroup();
      bGroup2.add(m_costR);
      bGroup2.add(m_benefitR);
      ActionListener rl = new ActionListener() {
        /** Mode that the currently displayed gain was computed for */
        private boolean m_costMode = true;

        @Override
        public void actionPerformed(ActionEvent e) {
          boolean costNow = m_costR.isSelected();
          if (costNow) {
            m_costBenefitL.setText("Cost: ");
          } else {
            m_costBenefitL.setText("Benefit: ");
          }

          // An action event is also delivered when the already selected
          // button is clicked again; only a real change of mode may flip
          // the sign of the gain.
          if (costNow == m_costMode) {
            return;
          }
          m_costMode = costNow;

          double gain = Double.parseDouble(m_gainV.getText());
          gain = -gain;
          m_gainV.setText(Utils.doubleToString(gain, 2));
        }
      };
      m_costR.addActionListener(rl);
      m_benefitR.addActionListener(rl);
      m_costR.setSelected(true);

      m_percPop.setSelected(true);
      JPanel threshPanel = new JPanel();
      threshPanel.setLayout(new BorderLayout());
      JPanel radioHolder = new JPanel();
      radioHolder.setLayout(new FlowLayout());
      radioHolder.add(m_percPop);
      radioHolder.add(m_percOfTarget);
      radioHolder.add(m_threshold);
      threshPanel.add(radioHolder, BorderLayout.NORTH);
      threshPanel.add(m_thresholdSlider, BorderLayout.SOUTH);

      JPanel threshInfoPanel = new JPanel();
      threshInfoPanel.setLayout(new GridLayout(3, 2));
      threshInfoPanel
        .add(new JLabel("% of Population: ", SwingConstants.RIGHT));
      threshInfoPanel.add(m_percPopLab);
      threshInfoPanel.add(new JLabel("% of Target: ", SwingConstants.RIGHT));
      threshInfoPanel.add(m_percOfTargetLab);
      threshInfoPanel
        .add(new JLabel("Score Threshold: ", SwingConstants.RIGHT));
      threshInfoPanel.add(m_thresholdLab);

      JPanel threshHolder = new JPanel();
      threshHolder.setBorder(BorderFactory.createTitledBorder("Threshold"));
      threshHolder.setLayout(new BorderLayout());
      threshHolder.add(threshPanel, BorderLayout.CENTER);
      threshHolder.add(threshInfoPanel, BorderLayout.EAST);

      lowerPanel.add(threshHolder, BorderLayout.NORTH);

      // holder for the two matrixes
      JPanel matrixHolder = new JPanel();
      matrixHolder.setLayout(new GridLayout(1, 2));

      // confusion matrix
      JPanel confusionPanel = new JPanel();
      confusionPanel.setLayout(new GridLayout(3, 3));
      confusionPanel.add(m_conf_predictedA);
      confusionPanel.add(m_conf_predictedB);
      confusionPanel.add(new JLabel()); // dummy
      confusionPanel.add(m_conf_aa);
      confusionPanel.add(m_conf_ab);
      confusionPanel.add(m_conf_actualA);
      confusionPanel.add(m_conf_ba);
      confusionPanel.add(m_conf_bb);
      confusionPanel.add(m_conf_actualB);
      JPanel tempHolderCA = new JPanel();
      tempHolderCA.setLayout(new BorderLayout());
      tempHolderCA.setBorder(BorderFactory
        .createTitledBorder("Confusion Matrix"));
      tempHolderCA.add(confusionPanel, BorderLayout.CENTER);

      JPanel accHolder = new JPanel();
      accHolder.setLayout(new FlowLayout(FlowLayout.LEFT));
      accHolder.add(new JLabel("Classification Accuracy: "));
      accHolder.add(m_classificationAccV);
      tempHolderCA.add(accHolder, BorderLayout.SOUTH);

      matrixHolder.add(tempHolderCA);

      // cost matrix
      JPanel costPanel = new JPanel();
      costPanel.setBorder(BorderFactory.createTitledBorder("Cost Matrix"));
      costPanel.setLayout(new BorderLayout());

      JPanel cmHolder = new JPanel();
      cmHolder.setLayout(new GridLayout(3, 3));
      cmHolder.add(m_cost_predictedA);
      cmHolder.add(m_cost_predictedB);
      cmHolder.add(new JLabel()); // dummy
      cmHolder.add(m_cost_aa);
      cmHolder.add(m_cost_ab);
      cmHolder.add(m_cost_actualA);
      cmHolder.add(m_cost_ba);
      cmHolder.add(m_cost_bb);
      cmHolder.add(m_cost_actualB);
      costPanel.add(cmHolder, BorderLayout.CENTER);

      // Rebuilds the cost/benefit curve when one of the inputs has changed
      // and refreshes the plot and the derived figures. Shared by the focus
      // and the action listeners of the input fields.
      final Runnable refreshCostBenefit = new Runnable() {
        @Override
        public void run() {
          if (!constructCostBenefitData()) {
            return;
          }
          try {
            m_costBenefitPanel.setMasterPlot(m_costBenefit);
            m_costBenefitPanel.validate();
            m_costBenefitPanel.repaint();
          } catch (Exception ex) {
            ex.printStackTrace();
          }
          updateCostBenefit();
        }
      };

      FocusListener fl = new FocusListener() {
        @Override
        public void focusGained(FocusEvent e) {
          // nothing to do until the field is left again
        }

        @Override
        public void focusLost(FocusEvent e) {
          refreshCostBenefit.run();
        }
      };

      ActionListener al = new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
          refreshCostBenefit.run();
        }
      };

      m_cost_aa.addFocusListener(fl);
      m_cost_aa.addActionListener(al);
      m_cost_ab.addFocusListener(fl);
      m_cost_ab.addActionListener(al);
      m_cost_ba.addFocusListener(fl);
      m_cost_ba.addActionListener(al);
      m_cost_bb.addFocusListener(fl);
      m_cost_bb.addActionListener(al);

      m_totalPopField.addFocusListener(fl);
      m_totalPopField.addActionListener(al);

      JPanel cbHolder = new JPanel();
      cbHolder.setLayout(new BorderLayout());
      JPanel tempP = new JPanel();
      tempP.setLayout(new GridLayout(3, 2));
      tempP.add(m_costBenefitL);
      tempP.add(m_costBenefitV);
      tempP.add(new JLabel("Random: ", SwingConstants.RIGHT));
      tempP.add(m_randomV);
      tempP.add(new JLabel("Gain: ", SwingConstants.RIGHT));
      tempP.add(m_gainV);
      cbHolder.add(tempP, BorderLayout.NORTH);
      JPanel butHolder = new JPanel();
      butHolder.setLayout(new GridLayout(2, 1));
      butHolder.add(m_maximizeCB);
      butHolder.add(m_minimizeCB);
      m_maximizeCB.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
          // the buttons can be pressed before any data has been supplied
          if (m_costBenefit != null && m_masterPlot != null) {
            findMaxMinCB(true);
          }
        }
      });

      m_minimizeCB.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
          if (m_costBenefit != null && m_masterPlot != null) {
            findMaxMinCB(false);
          }
        }
      });

      cbHolder.add(butHolder, BorderLayout.SOUTH);
      costPanel.add(cbHolder, BorderLayout.EAST);

      JPanel popCBR = new JPanel();
      popCBR.setLayout(new GridLayout(1, 2));
      JPanel popHolder = new JPanel();
      popHolder.setLayout(new FlowLayout(FlowLayout.LEFT));
      popHolder.add(new JLabel("Total Population: "));
      popHolder.add(m_totalPopField);

      JPanel radioHolder2 = new JPanel();
      radioHolder2.setLayout(new FlowLayout(FlowLayout.RIGHT));
      radioHolder2.add(m_costR);
      radioHolder2.add(m_benefitR);
      popCBR.add(popHolder);
      popCBR.add(radioHolder2);

      costPanel.add(popCBR, BorderLayout.SOUTH);

      matrixHolder.add(costPanel);

      lowerPanel.add(matrixHolder, BorderLayout.SOUTH);

      add(lowerPanel, BorderLayout.SOUTH);

    }

    /**
     * Locates the point of the cost/benefit curve with the largest or
     * smallest cost/benefit value (the first such point wins on ties), moves
     * the slider to it and refreshes all displayed information for it.
     * 
     * @param max true to search for the maximum, false for the minimum
     */
    private void findMaxMinCB(boolean max) {
      Instances curve = m_costBenefit.getPlotInstances();
      double best = max ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
      int bestIndex = 0;

      for (int i = 0; i < curve.numInstances(); i++) {
        // attribute 1 of the cost/benefit data holds the cost/benefit value
        double candidate = curve.instance(i).value(1);
        boolean improves = max ? candidate > best : candidate < best;
        if (improves) {
          best = candidate;
          bestIndex = i;
        }
      }

      // work out which metric the slider currently represents
      Instances master = m_masterPlot.getPlotInstances();
      int sampleSizeCol = master.attribute(ThresholdCurve.SAMPLE_SIZE_NAME)
        .index();
      int recallCol = master.attribute(ThresholdCurve.RECALL_NAME).index();
      int thresholdCol = master.attribute(ThresholdCurve.THRESHOLD_NAME)
        .index();

      int metricCol = thresholdCol;
      if (m_percPop.isSelected()) {
        metricCol = sampleSizeCol;
      } else if (m_percOfTarget.isSelected()) {
        metricCol = recallCol;
      }

      // the slider works in whole percent, so this is only approximate
      double sliderPosition = master.instance(bestIndex).value(metricCol) * 100.0;
      m_thresholdSlider.setValue((int) sliderPosition);

      // show the values of the true optimum rather than those of the point
      // closest to the (rounded) slider position
      updateInfoGivenIndex(bestIndex);
    }

    /**
     * Refreshes the cost/benefit, random and gain figures for the curve point
     * that corresponds to the current slider position.
     */
    private void updateCostBenefit() {
      Instances master = m_masterPlot.getPlotInstances();
      int sampleSizeCol = master.attribute(ThresholdCurve.SAMPLE_SIZE_NAME)
        .index();
      int recallCol = master.attribute(ThresholdCurve.RECALL_NAME).index();
      int thresholdCol = master.attribute(ThresholdCurve.THRESHOLD_NAME)
        .index();

      // the selected radio button decides which metric the slider refers to
      int metricCol = thresholdCol;
      if (m_percPop.isSelected()) {
        metricCol = sampleSizeCol;
      } else if (m_percOfTarget.isSelected()) {
        metricCol = recallCol;
      }

      double sliderFraction = m_thresholdSlider.getValue() / 100.0;
      int pointIndex = findIndexForValue(sliderFraction, master, metricCol);
      updateCBRandomGainInfo(pointIndex);
    }

    /**
     * Updates the cost/benefit value, the cost/benefit of a random selection
     * of the same size, the resulting gain and the classification accuracy
     * for the given point of the curve.
     * 
     * @param index index of the selected point in the threshold curve (and in
     *          the cost/benefit curve)
     */
    private void updateCBRandomGainInfo(int index) {
      Instances master = m_masterPlot.getPlotInstances();

      double requestedPopSize = parseFieldOrDefault(m_totalPopField,
        m_originalPopSize);
      // same guard as used when the cost/benefit curve is constructed
      double scaleFactor = 1.0;
      if (m_originalPopSize != 0) {
        scaleFactor = requestedPopSize / m_originalPopSize;
      }

      double CB = m_costBenefit.getPlotInstances().instance(index).value(1);
      m_costBenefitV.setText(Utils.doubleToString(CB, 2));

      // the first row of the curve carries the totals of positives/negatives
      Instance first = master.instance(0);
      double unscaledPos = first.value(master.attribute(
        ThresholdCurve.TRUE_POS_NAME).index());
      double unscaledNeg = first.value(master
        .attribute(ThresholdCurve.FALSE_POS_NAME));
      double totalPos = unscaledPos * scaleFactor;
      double totalNeg = unscaledNeg * scaleFactor;

      // a random selection of this size contains positives and negatives in
      // the same proportion as the whole population
      double sampleFraction = Double.parseDouble(m_percPopLab.getText()) / 100.0;
      double posInSample = totalPos * sampleFraction;
      double negInSample = totalNeg * sampleFraction;
      double posOutSample = totalPos - posInSample;
      double negOutSample = totalNeg - negInSample;

      // unparsable cost entries count as zero
      double tpCost = parseFieldOrDefault(m_cost_aa, 0.0);
      double fpCost = parseFieldOrDefault(m_cost_ba, 0.0);
      double tnCost = parseFieldOrDefault(m_cost_bb, 0.0);
      double fnCost = parseFieldOrDefault(m_cost_ab, 0.0);

      double totalRandomCB = 0.0;
      totalRandomCB += posInSample * tpCost;
      totalRandomCB += negInSample * fpCost;
      totalRandomCB += posOutSample * fnCost;
      totalRandomCB += negOutSample * tnCost;

      m_randomV.setText(Utils.doubleToString(totalRandomCB, 2));
      double gain = (m_costR.isSelected()) ? totalRandomCB - CB : CB
        - totalRandomCB;
      m_gainV.setText(Utils.doubleToString(gain, 2));

      // update classification rate; tp and tn are counts on the original
      // data, so they have to be related to the unscaled totals - otherwise
      // the accuracy would change with the requested population size
      Instance currentInst = master.instance(index);
      double tp = currentInst.value(master.attribute(
        ThresholdCurve.TRUE_POS_NAME).index());
      double tn = currentInst.value(master.attribute(
        ThresholdCurve.TRUE_NEG_NAME).index());
      m_classificationAccV.setText(Utils.doubleToString((tp + tn)
        / (unscaledPos + unscaledNeg) * 100.0, 4)
        + "%");
    }

    /**
     * Reads a number from a text field.
     * 
     * @param field the field to read
     * @param fallback value to use when the text is not a valid number
     * @return the parsed value or the fallback
     */
    private static double parseFieldOrDefault(JTextField field, double fallback) {
      try {
        return Double.parseDouble(field.getText());
      } catch (NumberFormatException ignored) {
        // the user may be half way through typing; keep the fallback
        return fallback;
      }
    }

    /**
     * Refreshes everything that depends on the selected curve point: the
     * three threshold labels, the highlighted point in the plots, the
     * confusion matrix and the cost/benefit figures.
     * 
     * @param index index of the selected point in the threshold curve
     */
    private void updateInfoGivenIndex(int index) {
      Instances curve = m_masterPlot.getPlotInstances();
      int sampleSizeCol = curve.attribute(ThresholdCurve.SAMPLE_SIZE_NAME)
        .index();
      int recallCol = curve.attribute(ThresholdCurve.RECALL_NAME).index();
      int thresholdCol = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index();

      // threshold labels (population and recall are shown as percentages)
      double popPercent = 100.0 * curve.instance(index).value(sampleSizeCol);
      m_percPopLab.setText(Utils.doubleToString(popPercent, 4));
      double recallPercent = 100.0 * curve.instance(index).value(recallCol);
      m_percOfTargetLab.setText(Utils.doubleToString(recallPercent, 4));
      double score = curve.instance(index).value(thresholdCol);
      m_thresholdLab.setText(Utils.doubleToString(score, 4));

      // shrink the previously highlighted point, enlarge the new one
      if (m_previousShapeIndex >= 0) {
        m_shapeSizes[m_previousShapeIndex] = 1;
      }
      m_shapeSizes[index] = 10;
      m_previousShapeIndex = index;

      // confusion matrix counts at the selected point
      int tpCol = curve.attribute(ThresholdCurve.TRUE_POS_NAME).index();
      int fpCol = curve.attribute(ThresholdCurve.FALSE_POS_NAME).index();
      int tnCol = curve.attribute(ThresholdCurve.TRUE_NEG_NAME).index();
      int fnCol = curve.attribute(ThresholdCurve.FALSE_NEG_NAME).index();
      Instance selected = curve.instance(index);
      double totalInstances = selected.value(tpCol) + selected.value(fpCol)
        + selected.value(tnCol) + selected.value(fnCol);

      // use the population size typed in by the user when it is a number
      double requestedPopSize = totalInstances;
      try {
        requestedPopSize = Double.parseDouble(m_totalPopField.getText());
      } catch (NumberFormatException e) {
      }
      double scale = requestedPopSize / totalInstances;

      m_conf_aa.setCellValue(selected.value(tpCol), totalInstances, scale, 2);
      m_conf_ab.setCellValue(selected.value(fnCol), totalInstances, scale, 2);
      m_conf_ba.setCellValue(selected.value(fpCol), totalInstances, scale, 2);
      m_conf_bb.setCellValue(selected.value(tnCol), totalInstances, scale, 2);

      updateCBRandomGainInfo(index);

      repaint();
    }

    /**
     * Finds the curve point matching a slider position and refreshes all
     * displayed information for it.
     * 
     * @param value the slider position as a fraction (slider value / 100),
     *          interpreted in terms of the metric chosen by the radio buttons
     */
    private void updateInfoForSliderValue(double value) {
      Instances curve = m_masterPlot.getPlotInstances();
      int sampleSizeCol = curve.attribute(ThresholdCurve.SAMPLE_SIZE_NAME)
        .index();
      int recallCol = curve.attribute(ThresholdCurve.RECALL_NAME).index();
      int thresholdCol = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index();

      int metricCol = thresholdCol;
      if (m_percPop.isSelected()) {
        metricCol = sampleSizeCol;
      } else if (m_percOfTarget.isSelected()) {
        metricCol = recallCol;
      }

      updateInfoGivenIndex(findIndexForValue(value, curve, metricCol));
    }

    /**
     * Finds the index of the curve point whose metric value is closest to the
     * requested value, using a bisection over the curve.
     * 
     * The threshold curve is sorted ascending in the threshold (thus
     * descending for recall and population size), so the search direction
     * depends on whether the score threshold radio button is selected.
     * 
     * @param value the value to look for
     * @param plotInstances the threshold curve data
     * @param indexOfMetric index of the attribute holding the metric
     * @return index of the matching point; among points with an identical
     *         metric value the last one is returned for recall/population
     *         and the first one for the score threshold
     */
    private int findIndexForValue(double value, Instances plotInstances,
      int indexOfMetric) {
      final boolean ascending = m_threshold.isSelected();

      int lo = 0;
      int hi = plotInstances.numInstances() - 1;
      int probe = (hi - lo) / 2;
      int found;

      while (true) {
        if (hi - lo <= 1) {
          // down to (at most) two candidates: take the closer one, preferring
          // the lower index when both are equally close
          double atHi = plotInstances.instance(hi).value(indexOfMetric);
          double atLo = plotInstances.instance(lo).value(indexOfMetric);
          found = (Math.abs(atHi - value) < Math.abs(atLo - value)) ? hi : lo;
          break;
        }

        double atProbe = plotInstances.instance(probe).value(indexOfMetric);
        if (value > atProbe || value < atProbe) {
          // larger values lie at higher indices only for ascending data
          boolean continueAbove = (value > atProbe) == ascending;
          if (continueAbove) {
            lo = probe;
            probe += (hi - lo) / 2;
          } else {
            hi = probe;
            probe -= (hi - lo) / 2;
          }
        } else {
          // exact hit
          found = probe;
          break;
        }
      }

      // now check for ties in the appropriate direction
      if (ascending) {
        while (found - 1 >= 0
          && plotInstances.instance(found - 1).value(indexOfMetric) == plotInstances
            .instance(found).value(indexOfMetric)) {
          found--;
        }
      } else {
        while (found + 1 < plotInstances.numInstances()
          && plotInstances.instance(found + 1).value(indexOfMetric) == plotInstances
            .instance(found).value(indexOfMetric)) {
          found++;
        }
      }
      return found;
    }

    /**
     * Set the threshold data for the panel to use. Installs the data in the
     * performance plot, (re)builds the cost/benefit plot, resets the cached
     * "previous" input values and the highlighted point, initialises the
     * population field from the data and finally refreshes all displayed
     * information for the current slider position.
     * 
     * @param data PlotData2D object encapsulating the threshold data.
     * @param classAtt the class attribute from the original data used to
     *          generate the threshold data.
     * @throws Exception if something goes wrong.
     */
    public synchronized void setDataSet(PlotData2D data, Attribute classAtt)
      throws Exception {
      // make a copy of the PlotData2D object
      m_masterPlot = new PlotData2D(data.getPlotInstances());
      // connect every point to its predecessor (the first point has none)
      boolean[] connectPoints = new boolean[m_masterPlot.getPlotInstances()
        .numInstances()];
      for (int i = 1; i < connectPoints.length; i++) {
        connectPoints[i] = true;
      }
      m_masterPlot.setConnectPoints(connectPoints);

      m_masterPlot.m_alwaysDisplayPointsOfThisSize = 10;
      setClassForConfusionMatrix(classAtt);
      m_performancePanel.setMasterPlot(m_masterPlot);
      m_performancePanel.validate();
      m_performancePanel.repaint();

      // all points start out small; the selected one is enlarged later
      m_shapeSizes = new int[m_masterPlot.getPlotInstances().numInstances()];
      for (int i = 0; i < m_shapeSizes.length; i++) {
        m_shapeSizes[i] = 1;
      }
      m_masterPlot.setShapeSize(m_shapeSizes);
      // NOTE(review): this runs before m_originalPopSize and the population
      // field are refreshed below, so the curve is built from whatever those
      // held before this call - confirm that this ordering is intended.
      constructCostBenefitData();
      m_costBenefitPanel.setMasterPlot(m_costBenefit);
      m_costBenefitPanel.validate();
      m_costBenefitPanel.repaint();

      // forget the cached inputs and the previously highlighted point
      m_totalPopPrevious = 0;
      m_fpPrevious = 0;
      m_tpPrevious = 0;
      m_tnPrevious = 0;
      m_fnPrevious = 0;
      m_previousShapeIndex = -1;

      // set the total population size
      Instance first = m_masterPlot.getPlotInstances().instance(0);
      double totalPos = first.value(m_masterPlot.getPlotInstances()
        .attribute(ThresholdCurve.TRUE_POS_NAME).index());
      double totalNeg = first.value(m_masterPlot.getPlotInstances().attribute(
        ThresholdCurve.FALSE_POS_NAME));
      m_originalPopSize = (int) (totalPos + totalNeg);
      m_totalPopField.setText("" + m_originalPopSize);

      // hard-coded attribute positions for the axes of the two plots;
      // NOTE(review): 5 and 10 presumably match the column order produced by
      // ThresholdCurve - confirm against that class
      m_performancePanel.setYIndex(5);
      m_performancePanel.setXIndex(10);
      m_costBenefitPanel.setXIndex(0);
      m_costBenefitPanel.setYIndex(1);
      // System.err.println(m_masterPlot.getPlotInstances());
      updateInfoForSliderValue(m_thresholdSlider.getValue() / 100.0);
    }

    /**
     * Records the class attribute and labels the "actual" rows of the
     * confusion matrix: row (a) shows the first class value, row (b) shows
     * all remaining class values joined with commas.
     * 
     * @param classAtt the class attribute whose values supply the labels
     */
    private void setClassForConfusionMatrix(Attribute classAtt) {
      m_classAttribute = classAtt;

      String firstLabel = classAtt.value(0);
      m_conf_actualA.setText(" Actual (a): " + firstLabel);
      m_conf_actualA.setToolTipText(firstLabel);

      // everything after the first value is lumped together as (b)
      StringBuilder remaining = new StringBuilder();
      int numValues = classAtt.numValues();
      for (int i = 1; i < numValues; i++) {
        if (i > 1) {
          remaining.append(",");
        }
        remaining.append(classAtt.value(i));
      }
      String otherLabels = remaining.toString();
      m_conf_actualB.setText(" Actual (b): " + otherLabels);
      m_conf_actualB.setToolTipText(otherLabels);
    }

    /**
     * Rebuilds the cost/benefit curve ({@code m_costBenefit}) from the master
     * threshold data, the four cost fields and the requested population size.
     * Nothing is rebuilt when the costs and the population size all equal
     * the values remembered from the previous build.
     * 
     * @return true if a new curve was constructed, false if nothing changed
     */
    private boolean constructCostBenefitData() {
      // unparseable cost fields count as a cost of zero
      double tpCost = parseDoubleOrDefault(m_cost_aa.getText(), 0.0);
      double fpCost = parseDoubleOrDefault(m_cost_ba.getText(), 0.0);
      double tnCost = parseDoubleOrDefault(m_cost_bb.getText(), 0.0);
      double fnCost = parseDoubleOrDefault(m_cost_ab.getText(), 0.0);

      // an unparseable population field falls back to the original size
      double requestedPopSize = parseDoubleOrDefault(
        m_totalPopField.getText(), m_originalPopSize);

      double scaleFactor = 1.0;
      if (m_originalPopSize != 0) {
        scaleFactor = requestedPopSize / m_originalPopSize;
      }

      if (tpCost == m_tpPrevious && fpCost == m_fpPrevious
        && tnCost == m_tnPrevious && fnCost == m_fnPrevious
        && requestedPopSize == m_totalPopPrevious) {
        return false;
      }

      // process the performance data to make this curve
      Instances performanceI = m_masterPlot.getPlotInstances();
      int numPoints = performanceI.numInstances();

      // First construct some Instances for the curve
      ArrayList<Attribute> fv = new ArrayList<Attribute>(3);
      fv.add(new Attribute("Sample Size"));
      fv.add(new Attribute("Cost/Benefit"));
      fv.add(new Attribute("Threshold"));
      Instances costBenefitI = new Instances("Cost/Benefit Curve", fv,
        numPoints);

      for (int i = 0; i < numPoints; i++) {
        Instance current = performanceI.instance(i);

        double[] vals = new double[3];
        vals[0] = current.value(10); // sample size
        // columns 0..3 are weighted by the TP, FN, FP and TN costs in turn
        vals[1] = (current.value(0) * tpCost + current.value(1) * fnCost
          + current.value(2) * fpCost + current.value(3) * tnCost)
          * scaleFactor;
        // the last column is carried over as the "Threshold" value
        vals[2] = current.value(current.numAttributes() - 1);
        costBenefitI.add(new DenseInstance(1.0, vals));
      }

      // now set up the plot data
      m_costBenefit = new PlotData2D(costBenefitI);
      m_costBenefit.m_alwaysDisplayPointsOfThisSize = 10;
      m_costBenefit.setPlotName("Cost/benefit curve");
      boolean[] connectPoints = new boolean[costBenefitI.numInstances()];
      for (int i = 0; i < connectPoints.length; i++) {
        connectPoints[i] = true;
      }
      try {
        m_costBenefit.setConnectPoints(connectPoints);
        m_costBenefit.setShapeSize(m_shapeSizes);
      } catch (Exception ex) {
        // best effort: the curve is still usable without these settings, but
        // report the problem instead of hiding it
        System.err
          .println("[CostBenefitAnalysis] Problem configuring cost/benefit plot: "
            + ex);
      }

      m_tpPrevious = tpCost;
      m_fpPrevious = fpCost;
      m_tnPrevious = tnCost;
      m_fnPrevious = fnCost;

      return true;
    }

    /**
     * Parses the text of an input field as a double.
     * 
     * @param text the text to parse (may be null)
     * @param fallback the value to use when the text is null or not a number
     * @return the parsed value, or the fallback
     */
    private static double parseDoubleOrDefault(String text, double fallback) {
      if (text == null) {
        return fallback;
      }
      try {
        return Double.parseDouble(text);
      } catch (NumberFormatException ignored) {
        // not a number: the caller's default applies
        return fallback;
      }
    }
  }

  /**
   * Constructor. Sets up the final (runtime) appearance when a graphics
   * environment is available; in headless mode it only prepares the list
   * that collects incoming events.
   */
  public CostBenefitAnalysis() {
    if (GraphicsEnvironment.isHeadless()) {
      m_headlessEvents = new ArrayList<EventObject>();
      return;
    }
    appearanceFinal();
  }

  /**
   * Returns the short description text for this bean.
   * 
   * NOTE(review): the text talks about performance charts (such as ROC)
   * rather than cost/benefit analysis; confirm this wording is intended.
   * 
   * @return the description string
   */
  public String globalInfo() {
    return "Visualize performance charts (such as ROC).";
  }

  /**
   * Accept a threshold data event. With a graphics environment the event's
   * data is passed straight to the visualization; in headless mode the event
   * is stored instead. The stored list is replaced on every call, so only
   * the most recent event is kept.
   * 
   * @param e a threshold data event
   */
  @Override
  public void acceptDataSet(ThresholdDataEvent e) {
    if (GraphicsEnvironment.isHeadless()) {
      m_headlessEvents = new ArrayList<EventObject>();
      m_headlessEvents.add(e);
      return;
    }

    try {
      setCurveData(e.getDataSet(), e.getClassAttribute());
    } catch (Exception ex) {
      System.err
        .println("[CostBenefitAnalysis] Problem setting up visualization.");
      ex.printStackTrace();
    }
  }

  /**
   * Hand threshold curve data to the analysis panel, creating the panel
   * first if it does not exist yet.
   * 
   * @param curveData a PlotData2D object set up with the curve data.
   * @param origClassAtt the class attribute from the original data used to
   *          generate the curve.
   * @throws Exception if something goes wrong during the setup process.
   */
  public void setCurveData(PlotData2D curveData, Attribute origClassAtt)
    throws Exception {

    AnalysisPanel panel = m_analysisPanel;
    if (panel == null) {
      panel = new AnalysisPanel();
      m_analysisPanel = panel;
    }
    panel.setDataSet(curveData, origClassAtt);
  }

  /**
   * Get the visual currently held by this bean.
   * 
   * @return the value of the m_visual field
   */
  @Override
  public BeanVisual getVisual() {
    return m_visual;
  }

  /**
   * Replace the visual held by this bean.
   * 
   * @param newVisual the visual to store
   */
  @Override
  public void setVisual(BeanVisual newVisual) {
    m_visual = newVisual;
  }

  /**
   * Load the default static and animated icons into the current visual.
   */
  @Override
  public void useDefaultVisual() {
    final String staticIcon = BeanVisual.ICON_PATH
      + "DefaultDataVisualizer.gif";
    final String animatedIcon = BeanVisual.ICON_PATH
      + "DefaultDataVisualizer_animated.gif";
    m_visual.loadIcons(staticIcon, animatedIcon);
  }

  /**
   * List the requests this bean can currently perform. "Show analysis" is
   * offered only once an analysis panel exists and holds plot data.
   * 
   * @return an enumeration of request names (possibly empty)
   */
  @Override
  public Enumeration enumerateRequests() {
    Vector<String> requests = new Vector<String>(0);
    boolean hasPlotData = m_analysisPanel != null
      && m_analysisPanel.m_masterPlot != null;
    if (hasPlotData) {
      requests.addElement("Show analysis");
    }
    return requests.elements();
  }

  /**
   * Perform a named request. The only supported request is "Show analysis",
   * which opens the analysis panel in its own frame, or brings the existing
   * frame to the front if one is already showing.
   * 
   * @param request the name of the request to perform
   * @throws IllegalArgumentException if the request is not supported
   */
  @Override
  public void performRequest(String request) {
    if (!"Show analysis".equals(request)) {
      throw new IllegalArgumentException(request
        + " not supported (Cost/Benefit Analysis)");
    }

    try {
      if (m_framePoppedUp) {
        m_popupFrame.toFront();
        return;
      }

      // popup visualize panel
      m_framePoppedUp = true;

      final javax.swing.JFrame jf = new javax.swing.JFrame(
        "Cost/Benefit Analysis");
      jf.setSize(1000, 600);
      jf.getContentPane().setLayout(new BorderLayout());
      jf.getContentPane().add(m_analysisPanel, BorderLayout.CENTER);
      jf.addWindowListener(new java.awt.event.WindowAdapter() {
        @Override
        public void windowClosing(java.awt.event.WindowEvent e) {
          jf.dispose();
          // allow the frame to be opened again by a later request
          m_framePoppedUp = false;
        }
      });
      jf.setVisible(true);
      m_popupFrame = jf;
    } catch (Exception ex) {
      ex.printStackTrace();
      // make sure a failed attempt does not block later requests
      m_framePoppedUp = false;
    }
  }

  /**
   * Add a vetoable change listener for the named property by delegating to
   * m_bcSupport.
   * 
   * @param name the property name
   * @param vcl the listener to add
   */
  @Override
  public void addVetoableChangeListener(String name, VetoableChangeListener vcl) {
    m_bcSupport.addVetoableChangeListener(name, vcl);
  }

  /**
   * Get the bean context stored by setBeanContext.
   * 
   * @return the value of the m_beanContext field
   */
  @Override
  public BeanContext getBeanContext() {
    return m_beanContext;
  }

  /**
   * Remove a vetoable change listener for the named property by delegating
   * to m_bcSupport.
   * 
   * @param name the property name
   * @param vcl the listener to remove
   */
  @Override
  public void removeVetoableChangeListener(String name,
    VetoableChangeListener vcl) {
    m_bcSupport.removeVetoableChangeListener(name, vcl);
  }

  /**
   * Switch this component to its final (non-design) appearance: clear all
   * children, install a BorderLayout and let setUpFinal() add the content.
   */
  protected void appearanceFinal() {
    removeAll();
    setLayout(new BorderLayout());
    setUpFinal();
  }

  /**
   * Add the analysis panel to the centre of this component, creating the
   * panel first if it does not exist yet.
   */
  protected void setUpFinal() {
    AnalysisPanel panel = m_analysisPanel;
    if (panel == null) {
      panel = new AnalysisPanel();
      m_analysisPanel = panel;
    }
    add(panel, BorderLayout.CENTER);
  }

  /**
   * Switch this component to its design-time appearance: clear all children,
   * load the default icons and show only the bean visual in the centre.
   */
  protected void appearanceDesign() {
    removeAll();
    useDefaultVisual();
    setLayout(new BorderLayout());
    add(m_visual, BorderLayout.CENTER);
  }

  /**
   * Store the bean context and pick the matching appearance: the design
   * appearance at design time, otherwise the final appearance unless running
   * headless.
   * 
   * @param bc the bean context to use
   * @throws PropertyVetoException declared by the interface
   */
  @Override
  public void setBeanContext(BeanContext bc) throws PropertyVetoException {
    m_beanContext = bc;
    m_design = bc.isDesignTime();

    if (m_design) {
      appearanceDesign();
      return;
    }
    if (!GraphicsEnvironment.isHeadless()) {
      appearanceFinal();
    }
  }

  /**
   * Returns true if, at this time, the object will accept a connection via
   * the named event. A connection is accepted only while no source has been
   * recorded in m_listenee; the event name itself is not consulted.
   * 
   * @param eventName the name of the event in question
   * @return true if the object will accept a connection
   */
  @Override
  public boolean connectionAllowed(String eventName) {
    return (m_listenee == null);
  }

  /**
   * Notify this object that it has been registered as a listener with a
   * source for receiving events described by the named event. The source is
   * recorded only if a connection is currently allowed.
   * 
   * @param eventName the event
   * @param source the source with which this object has been registered as a
   *          listener
   */
  @Override
  public void connectionNotification(String eventName, Object source) {
    if (!connectionAllowed(eventName)) {
      return;
    }
    m_listenee = source;
  }

  /**
   * Returns true if, at this time, the object will accept a connection
   * according to the supplied EventSetDescriptor. Delegates to
   * connectionAllowed(String) using the descriptor's name.
   * 
   * @param esd the EventSetDescriptor
   * @return true if the object will accept a connection
   */
  @Override
  public boolean connectionAllowed(EventSetDescriptor esd) {
    return connectionAllowed(esd.getName());
  }

  /**
   * Notify this object that it has been deregistered as a listener with a
   * source for the named event. The recorded source is cleared only when it
   * is the very object passed in.
   * 
   * @param eventName the event
   * @param source the source with which this object had been registered as a
   *          listener
   */
  @Override
  public void disconnectionNotification(String eventName, Object source) {
    if (m_listenee != source) {
      return;
    }
    m_listenee = null;
  }

  /**
   * Get the custom (descriptive) name for this bean, which is the text held
   * by the bean's visual.
   * 
   * @return the text of the visual
   */
  @Override
  public String getCustomName() {
    return m_visual.getText();
  }

  /**
   * Returns true if, at this time, the bean is busy with some task (e.g. a
   * worker thread performing a calculation). This implementation always
   * reports that it is not busy.
   * 
   * @return false, always
   */
  @Override
  public boolean isBusy() {
    return false;
  }

  /**
   * Set a custom (descriptive) name for this bean by storing it as the text
   * of the bean's visual.
   * 
   * @param name the name to use
   */
  @Override
  public void setCustomName(String name) {
    m_visual.setText(name);
  }

  /**
   * Set a logger. The logger is ignored because this bean does no logging.
   * 
   * @param logger a weka.gui.Logger value (unused)
   */
  @Override
  public void setLog(Logger logger) {
    // we don't need to do any logging
  }

  /**
   * Stop any processing that the bean might be doing. This implementation
   * does nothing.
   */
  @Override
  public void stop() {
    // nothing to do here
  }

  /**
   * Test driver: loads the ARFF file named by the first argument, collects
   * 10-fold cross-validation predictions from a NaiveBayes classifier,
   * builds the threshold curve for class index 0 and shows it in an analysis
   * panel.
   * 
   * @param args command line arguments; args[0] is the ARFF file to load
   */
  public static void main(String[] args) {
    if (args.length < 1) {
      System.err.println("Usage: CostBenefitAnalysis <ARFF file>");
      return;
    }

    try {
      Instances train;
      java.io.BufferedReader reader = new java.io.BufferedReader(
        new java.io.FileReader(args[0]));
      try {
        train = new Instances(reader);
      } finally {
        // release the file handle whether or not loading succeeded
        reader.close();
      }
      train.setClassIndex(train.numAttributes() - 1);

      weka.classifiers.evaluation.ThresholdCurve tc = new weka.classifiers.evaluation.ThresholdCurve();
      weka.classifiers.evaluation.EvaluationUtils eu = new weka.classifiers.evaluation.EvaluationUtils();
      // weka.classifiers.Classifier classifier = new
      // weka.classifiers.functions.Logistic();
      weka.classifiers.Classifier classifier = new weka.classifiers.bayes.NaiveBayes();
      ArrayList<weka.classifiers.evaluation.Prediction> predictions = new ArrayList<weka.classifiers.evaluation.Prediction>();
      eu.setSeed(1);
      predictions.addAll(eu.getCVPredictions(classifier, train, 10));
      Instances result = tc.getCurve(predictions, 0);
      PlotData2D pd = new PlotData2D(result);
      pd.m_alwaysDisplayPointsOfThisSize = 10;

      // every point is joined to its predecessor; entry 0 is left false
      boolean[] connectPoints = new boolean[result.numInstances()];
      for (int i = 1; i < connectPoints.length; i++) {
        connectPoints[i] = true;
      }
      pd.setConnectPoints(connectPoints);

      final javax.swing.JFrame jf = new javax.swing.JFrame("CostBenefitTest");
      jf.setSize(1000, 600);
      // jf.pack();
      jf.getContentPane().setLayout(new BorderLayout());
      final CostBenefitAnalysis.AnalysisPanel analysisPanel = new CostBenefitAnalysis.AnalysisPanel();

      jf.getContentPane().add(analysisPanel, BorderLayout.CENTER);
      jf.addWindowListener(new java.awt.event.WindowAdapter() {
        @Override
        public void windowClosing(java.awt.event.WindowEvent e) {
          jf.dispose();
          System.exit(0);
        }
      });

      jf.setVisible(true);

      analysisPanel.setDataSet(pd, train.classAttribute());

    } catch (Exception ex) {
      ex.printStackTrace();
    }

  }

  /**
   * Get the list of events collected in headless mode. May return null or an
   * empty list if not running in headless mode or if no events were
   * received.
   * 
   * @return a list of EventObjects or null.
   */
  @Override
  public List retrieveHeadlessEvents() {
    return m_headlessEvents;
  }

  /**
   * Process a list of events that have been collected earlier. Has no effect
   * if the component is running in headless mode or if the list is null.
   * Entries that are not ThresholdDataEvents are skipped.
   * 
   * @param headless a list of EventObjects to process (may be null).
   */
  @Override
  public void processHeadlessEvents(List headless) {
    // only process if we're not headless and there is something to process
    if (headless == null || GraphicsEnvironment.isHeadless()) {
      return;
    }

    // the parameter is an untyped list here, so iterate over Object and
    // filter by type
    for (Object event : headless) {
      if (event instanceof ThresholdDataEvent) {
        acceptDataSet((ThresholdDataEvent) event);
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy