All downloads are free. Search and download functionality uses the official Maven repository.

weka.knowledgeflow.steps.BoundaryPlotter Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    BoundaryPlotter.java
 *    Copyright (C) 2015 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.knowledgeflow.steps;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.clusterers.AbstractClusterer;
import weka.clusterers.DensityBasedClusterer;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.OptionMetadata;
import weka.core.SerializedObject;
import weka.core.Utils;
import weka.core.WekaException;
import weka.gui.ProgrammaticProperty;
import weka.gui.boundaryvisualizer.DataGenerator;
import weka.gui.boundaryvisualizer.KDDataGenerator;
import weka.gui.knowledgeflow.KFGUIConsts;
import weka.knowledgeflow.Data;
import weka.knowledgeflow.ExecutionResult;
import weka.knowledgeflow.StepManager;
import weka.knowledgeflow.StepTask;

import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Future;

/**
 * A step that computes visualization data for class/cluster decision
 * boundaries.
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: $
 */
@KFStep(name = "BoundaryPlotter", category = "Visualization",
  toolTipText = "Visualize class/cluster decision boundaries in a 2D plot",
  iconPath = KFGUIConsts.BASE_ICON_PATH + "DefaultDataVisualizer.gif")
public class BoundaryPlotter extends BaseStep implements DataCollector {

  /** default colours for classes */
  public static final Color[] DEFAULT_COLORS = { Color.red, Color.green,
    Color.blue, new Color(0, 255, 255), // cyan
    new Color(255, 0, 255), // magenta
    new Color(255, 255, 0), // yellow
    new Color(255, 255, 255), // white
    new Color(0, 0, 0) }; // black

  private static final long serialVersionUID = 7864251468395026619L;

  /** Colours to use for classes/clusters (copied from DEFAULT_COLORS) */
  protected List<Color> m_Colors = new ArrayList<Color>();

  /**
   * Number of rows of the visualization to compute in parallel. We don't want
   * to dominate the thread pool that is used for executing all steps and step
   * sub-tasks in the KF (this is currently fixed at 50 threads by FlowRunner).
   */
  protected int m_maxRowsInParallel = 10;

  /** Width of images to generate */
  protected int m_imageWidth = 400;

  /** Height of images to generate */
  protected int m_imageHeight = 400;

  /** X axis attribute name/index ("/first" = first attribute) */
  protected String m_xAttName = "/first";

  /** Y axis attribute name/index (numeric values are 1-based indices) */
  protected String m_yAttName = "2";

  /** Superimpose the training data on the plot? */
  protected boolean m_plotTrainingData = true;

  // attribute indices (0-based) for visualizing on
  protected int m_xAttribute;
  protected int m_yAttribute;

  // min, max and ranges of these attributes
  protected double m_minX;
  protected double m_minY;
  protected double m_maxX;
  protected double m_maxY;
  protected double m_rangeX;
  protected double m_rangeY;

  // pixel width and height in terms of attribute values
  protected double m_pixHeight;
  protected double m_pixWidth;

  /** The currently rendering image */
  protected transient BufferedImage m_osi;

  /** The spec of the scheme being used to render the current image */
  protected String m_currentDescription;

  /** Completed images, keyed by scheme description */
  protected transient Map<String, BufferedImage> m_completedImages;

  /** Classifiers to use */
  protected List<Classifier> m_classifierTemplates;

  /** Clusterers to use */
  protected List<DensityBasedClusterer> m_clustererTemplates;

  /** Copies of trained classifier to use in parallel for prediction */
  protected weka.classifiers.Classifier[] m_threadClassifiers;

  /** Copies of trained clusterer to use in parallel for prediction */
  protected weka.clusterers.Clusterer[] m_threadClusterers;

  /** Data generator copies to use in parallel */
  protected DataGenerator[] m_threadGenerators;

  /** The data generator to use */
  protected KDDataGenerator m_dataGenerator;

  /** User-specified bandwidth (may contain environment variables) */
  protected String m_kBand = "3";

  /** User-specified num samples (may contain environment variables) */
  protected String m_nSamples = "2";

  /** User-specified base for sampling (may contain environment variables) */
  protected String m_sBase = "2";

  /** Parsed bandwidth */
  protected int m_kernelBandwidth = 3;

  /** Parsed samples */
  protected int m_numSamplesPerRegion = 2;

  /** Parsed base */
  protected int m_samplesBase = 2;

  /** Listener notified of rendering progress (null if none is registered) */
  protected transient RenderingUpdateListener m_plotListener;

  /** True if we've been reset (set by stepInit()) */
  protected boolean m_isReset;

  /**
   * Constructor. Fills the working colour list with independent copies of the
   * entries in {@link #DEFAULT_COLORS}, so per-instance changes never touch the
   * shared defaults.
   */
  public BoundaryPlotter() {
    for (int i = 0; i < DEFAULT_COLORS.length; i++) {
      Color template = DEFAULT_COLORS[i];
      m_Colors.add(new Color(template.getRed(), template.getGreen(),
        template.getBlue()));
    }
  }

  /**
   * Set the attribute plotted on the X axis. The value may be an attribute
   * name, a 1-based attribute index, or one of "first", "/first", "last" or
   * "/last"; environment variables are substituted when data is processed.
   *
   * @param xAttName name/index of the X axis attribute
   */
  // programmatic property: the custom step editor dialog edits this value
  // directly instead of deferring to the GenericObjectEditor
  @ProgrammaticProperty
  @OptionMetadata(displayName = "X attribute",
    description = "Attribute to visualize on the x-axis", displayOrder = 1)
  public void setXAttName(String xAttName) {
    this.m_xAttName = xAttName;
  }

  /**
   * Get the X axis attribute specification as last set.
   *
   * @return name/index of the X axis attribute
   */
  public String getXAttName() {
    return this.m_xAttName;
  }

  /**
   * Set the attribute plotted on the Y axis. The value may be an attribute
   * name, a 1-based attribute index, or one of "first", "/first", "last" or
   * "/last"; environment variables are substituted when data is processed.
   *
   * @param attName name/index of the Y axis attribute
   */
  // programmatic property: the custom step editor dialog edits this value
  // directly instead of deferring to the GenericObjectEditor
  @ProgrammaticProperty
  @OptionMetadata(displayName = "Y attribute",
    description = "Attribute to visualize on the y-axis", displayOrder = 2)
  public void setYAttName(String attName) {
    this.m_yAttName = attName;
  }

  /**
   * Get the Y axis attribute specification as last set.
   *
   * @return name/index of the Y axis attribute
   */
  public String getYAttName() {
    return this.m_yAttName;
  }

  /**
   * Set the base for sampling. The string is environment-substituted and
   * parsed as an integer during step initialization; if it cannot be parsed
   * the previously parsed value is kept and a warning is logged.
   *
   * @param base the base to use
   */
  @OptionMetadata(displayName = "Base for sampling (r)",
    description = "The base for sampling", displayOrder = 3)
  public void setBaseForSampling(String base) {
    this.m_sBase = base;
  }

  /**
   * Get the (unparsed) base for sampling.
   *
   * @return the base to use
   */
  public String getBaseForSampling() {
    return this.m_sBase;
  }

  /**
   * Set the number of locations (samples) per pixel. The string is
   * environment-substituted and parsed as an integer during step
   * initialization; unparseable values leave the previous value in place.
   *
   * @param num the number of samples to use
   */
  @OptionMetadata(displayName = "Num. locations per pixel",
    description = "Number of locations per pixel", displayOrder = 4)
  public void setNumLocationsPerPixel(String num) {
    this.m_nSamples = num;
  }

  /**
   * Get the (unparsed) number of locations per pixel.
   *
   * @return the number of samples to use
   */
  public String getNumLocationsPerPixel() {
    return this.m_nSamples;
  }

  /**
   * Set the kernel bandwidth used by the data generator. The string is
   * environment-substituted and parsed as an integer during step
   * initialization; unparseable values leave the previous value in place.
   *
   * @param band the bandwidth
   */
  @OptionMetadata(displayName = "Kernel bandwidth (k)",
    description = "Kernel bandwidth", displayOrder = 4)
  public void setKernelBandwidth(String band) {
    this.m_kBand = band;
  }

  /**
   * Get the (unparsed) kernel bandwidth.
   *
   * @return the bandwidth
   */
  public String getKernelBandwidth() {
    return this.m_kBand;
  }

  /**
   * Set the width (in pixels) of the generated images. Non-positive values are
   * ignored and the current width is kept: an image needs at least one pixel
   * column, and the width is also used as a divisor when mapping attribute
   * values to pixels.
   *
   * @param width the width to use (must be positive to take effect)
   */
  @OptionMetadata(displayName = "Image width (pixels)",
    description = "Image width in pixels", displayOrder = 5)
  public void setImageWidth(int width) {
    if (width > 0) {
      m_imageWidth = width;
    }
  }

  /**
   * Get the width (in pixels) of the generated images.
   *
   * @return the width to use
   */
  public int getImageWidth() {
    return m_imageWidth;
  }

  /**
   * Set the height (in pixels) of the generated images. Non-positive values
   * are ignored and the current height is kept: an image needs at least one
   * pixel row, and the height is also used as a divisor when mapping attribute
   * values to pixels.
   *
   * @param height the height to use (must be positive to take effect)
   */
  @OptionMetadata(displayName = "Image height (pixels)",
    description = "Image height in pixels", displayOrder = 6)
  public void setImageHeight(int height) {
    if (height > 0) {
      m_imageHeight = height;
    }
  }

  /**
   * Get the height (in pixels) of the generated images.
   *
   * @return the height to use
   */
  public int getImageHeight() {
    return m_imageHeight;
  }

  /**
   * Set how many image rows are computed concurrently (one task per row).
   * Values below one are ignored.
   *
   * @param max maximum number of rows to compute in parallel
   */
  @OptionMetadata(displayName = "Max image rows to compute in parallel",
    description = "Use this many tasks for computing rows of the image",
    displayOrder = 7)
  public void setComputeMaxRowsInParallel(int max) {
    if (max <= 0) {
      return;
    }
    m_maxRowsInParallel = max;
  }

  /**
   * Get how many image rows are computed concurrently.
   *
   * @return the maximum number of rows to compute in parallel
   */
  public int getComputeMaxRowsInParallel() {
    return this.m_maxRowsInParallel;
  }

  /**
   * Choose whether the training instances are drawn on top of the computed
   * boundary image.
   *
   * @param plot true to plot the training data
   */
  @OptionMetadata(displayName = "Plot training points",
    description = "Superimpose the training data over the top of the plot",
    displayOrder = 8)
  public void setPlotTrainingData(boolean plot) {
    this.m_plotTrainingData = plot;
  }

  /**
   * Report whether the training instances are drawn on top of the plot.
   *
   * @return true if plotting the training data
   */
  public boolean getPlotTrainingData() {
    return this.m_plotTrainingData;
  }

  /**
   * Initialize the step. Collects the classifier and density-based clusterer
   * templates supplied by steps connected via "info" connections, resets the
   * map of completed images and parses the numeric sampling parameters (after
   * environment variable substitution).
   *
   * @throws WekaException if there are no incoming "info" connections, or if a
   *           connected clusterer is not a {@code DensityBasedClusterer}
   */
  @Override
  public void stepInit() throws WekaException {

    List<StepManager> infos =
      getStepManager().getIncomingConnectedStepsOfConnectionType(
        StepManager.CON_INFO);
    if (infos.isEmpty()) {
      throw new WekaException(
        "One or more classifiers/clusterers need to be supplied via an 'info' "
          + "connection type");
    }

    m_classifierTemplates = new ArrayList<Classifier>();
    m_clustererTemplates = new ArrayList<DensityBasedClusterer>();
    for (StepManager m : infos) {
      Step info = m.getInfoStep();

      if (info instanceof weka.knowledgeflow.steps.Classifier) {
        m_classifierTemplates.add(((weka.knowledgeflow.steps.Classifier) info)
          .getClassifier());
      } else if (info instanceof weka.knowledgeflow.steps.Clusterer) {
        weka.clusterers.Clusterer c =
          ((weka.knowledgeflow.steps.Clusterer) info).getClusterer();
        if (!(c instanceof DensityBasedClusterer)) {
          throw new WekaException("Clusterer "
            + c.getClass().getCanonicalName()
            + " is not a DensityBasedClusterer");
        }
        m_clustererTemplates.add((DensityBasedClusterer) c);
      } else {
        // previously such connections were dropped silently
        getStepManager().logWarning(
          "Ignoring 'info' connection from step of type "
            + (info == null ? "null" : info.getClass().getName())
            + ": it is neither a Classifier nor a Clusterer step");
      }
    }

    m_completedImages = new LinkedHashMap<String, BufferedImage>();

    m_numSamplesPerRegion =
      parseIntegerStepOption(m_nSamples, m_numSamplesPerRegion,
        "num samples per region");
    m_samplesBase =
      parseIntegerStepOption(m_sBase, m_samplesBase, "the base for sampling");
    m_kernelBandwidth =
      parseIntegerStepOption(m_kBand, m_kernelBandwidth, "kernel bandwidth");

    m_isReset = true;
  }

  /**
   * Parse an integer step option after environment variable substitution.
   * Null/empty options and unparseable values leave the current value in
   * place; the latter is reported as a warning.
   *
   * @param raw the option string as set by the user (may be null)
   * @param current the value to keep if {@code raw} is empty or unparseable
   * @param description human readable option name used in the warning
   * @return the parsed value, or {@code current}
   * @throws WekaException if logging the warning fails
   */
  private int parseIntegerStepOption(String raw, int current,
    String description) throws WekaException {
    if (raw == null || raw.length() == 0) {
      return current;
    }
    String substituted = environmentSubstitute(raw);
    try {
      return Integer.parseInt(substituted);
    } catch (NumberFormatException ex) {
      getStepManager().logWarning(
        "Unable to parse '" + substituted + "' for " + description
          + " parameter, using default: " + current);
      return current;
    }
  }

  /**
   * Compute the plotting ranges of the X and Y visualization attributes over
   * the training data, and from them the attribute-space width/height of one
   * image pixel. Only instances with non-missing values for both attributes
   * are considered.
   *
   * Range rules:
   * <ul>
   * <li>no usable instances: both axes span [0, 1];</li>
   * <li>a constant axis value v is widened to include zero ([0, v] for v &gt;
   * 0, [v, 0] for v &lt; 0, [0, 1] for v == 0), so the range is never zero;</li>
   * <li>if no X or Y value exceeds 1.0, both maxima are set to 1.0.</li>
   * </ul>
   *
   * @param trainingData the training data
   */
  protected void computeMinMaxAtts(Instances trainingData) {
    // NOTE: Double.MIN_VALUE is the smallest *positive* double, so it must
    // not be used as the starting point for a running maximum
    m_minX = Double.POSITIVE_INFINITY;
    m_minY = Double.POSITIVE_INFINITY;
    m_maxX = Double.NEGATIVE_INFINITY;
    m_maxY = Double.NEGATIVE_INFINITY;

    boolean sawPoint = false;
    boolean allPointsLessThanOne = true;

    for (int i = 0; i < trainingData.numInstances(); i++) {
      Instance inst = trainingData.instance(i);
      double x = inst.value(m_xAttribute);
      double y = inst.value(m_yAttribute);
      if (Utils.isMissingValue(x) || Utils.isMissingValue(y)) {
        continue;
      }
      sawPoint = true;
      m_minX = Math.min(m_minX, x);
      m_maxX = Math.max(m_maxX, x);
      m_minY = Math.min(m_minY, y);
      m_maxY = Math.max(m_maxY, y);
      if (x > 1.0 || y > 1.0) {
        allPointsLessThanOne = false;
      }
    }

    if (!sawPoint) {
      m_minX = m_minY = 0.0;
      m_maxX = m_maxY = 1.0;
    } else {
      // widen zero-width ranges so that they include zero
      if (m_minX == m_maxX) {
        if (m_maxX > 0) {
          m_minX = 0;
        } else if (m_maxX < 0) {
          m_maxX = 0;
        } else {
          m_maxX = 1;
        }
      }
      if (m_minY == m_maxY) {
        if (m_maxY > 0) {
          m_minY = 0;
        } else if (m_maxY < 0) {
          m_maxY = 0;
        } else {
          m_maxY = 1;
        }
      }
    }

    if (allPointsLessThanOne) {
      // presumably data normalized into [0, 1]; show the full unit range
      m_maxX = m_maxY = 1.0;
    }

    m_rangeX = (m_maxX - m_minX);
    m_rangeY = (m_maxY - m_minY);

    m_pixWidth = m_rangeX / m_imageWidth;
    m_pixHeight = m_rangeY / m_imageHeight;
  }

  /**
   * Resolve an attribute specification to a 0-based attribute index. After
   * environment variable substitution the specification may be "first" or
   * "/first", "last" or "/last" (case-insensitive), an attribute name, or a
   * 1-based attribute index.
   *
   * @param attName the attribute specification
   * @param data the data whose attributes are searched
   * @return the 0-based index of the attribute
   * @throws WekaException if the specification does not resolve to an
   *           attribute that exists in {@code data}
   */
  protected int getAttIndex(String attName, Instances data)
    throws WekaException {
    attName = environmentSubstitute(attName);
    int index = -1;

    if (attName.equalsIgnoreCase("first") || attName.equalsIgnoreCase("/first")) {
      index = 0;
    } else if (attName.equalsIgnoreCase("last")
      || attName.equalsIgnoreCase("/last")) {
      index = data.numAttributes() - 1;
    } else {
      Attribute a = data.attribute(attName);
      if (a != null) {
        index = a.index();
      } else {
        // try parsing as a 1-based index
        try {
          index = Integer.parseInt(attName) - 1;
        } catch (NumberFormatException ex) {
          // neither a name nor a number: reported by the range check below
          index = -1;
        }
      }
    }

    // reject anything outside the data, e.g. "0", "-5", or "99" for 5 atts
    if (index < 0 || index >= data.numAttributes()) {
      throw new WekaException("Unable to find attribute '" + attName
        + "' in the data " + "or to parse it as an index");
    }

    return index;
  }

  /**
   * Build the kernel density data generator used to sample the non-visualized
   * attributes. Only the X and Y visualization attributes are flagged as
   * weighting dimensions.
   *
   * @param trainingData the data to build the generator from
   * @throws WekaException if building the generator fails
   */
  protected void initDataGenerator(Instances trainingData) throws WekaException {
    boolean[] weightOn = new boolean[trainingData.numAttributes()];
    weightOn[m_xAttribute] = true;
    weightOn[m_yAttribute] = true;

    m_dataGenerator = new KDDataGenerator();
    KDDataGenerator generator = m_dataGenerator;
    generator.setWeightingDimensions(weightOn);
    generator.setKernelBandwidth(m_kernelBandwidth);
    try {
      generator.buildGenerator(trainingData);
    } catch (Exception cause) {
      throw new WekaException(cause);
    }
  }

  /**
   * Process an incoming data set: resolve the visualization attributes,
   * compute their ranges, build the data generator and then render one
   * boundary plot per classifier template followed by one per clusterer
   * template. Processing stops early (and is reported as interrupted) if a
   * stop is requested between schemes.
   *
   * @param data the incoming data
   * @throws WekaException if a problem occurs
   */
  @Override
  public synchronized void processIncoming(Data data) throws WekaException {
    getStepManager().processing();

    Instances training = data.getPrimaryPayload();
    Integer setNum =
      data.getPayloadElement(StepManager.CON_AUX_DATA_SET_NUM, 1);
    Integer maxSetNum =
      data.getPayloadElement(StepManager.CON_AUX_DATA_MAX_SET_NUM, 1);

    m_xAttribute = getAttIndex(m_xAttName, training);
    m_yAttribute = getAttIndex(m_yAttName, training);
    computeMinMaxAtts(training);
    initDataGenerator(training);

    for (Classifier template : m_classifierTemplates) {
      if (isStopRequested()) {
        getStepManager().interrupted();
        return;
      }
      doScheme(template, null, training, setNum, maxSetNum);
    }

    for (DensityBasedClusterer template : m_clustererTemplates) {
      if (isStopRequested()) {
        getStepManager().interrupted();
        return;
      }
      doScheme(null, template, training, setNum, maxSetNum);
    }

    if (!isStopRequested()) {
      getStepManager().finished();
    } else {
      getStepManager().interrupted();
    }
  }

  protected void doScheme(Classifier classifier, DensityBasedClusterer clust,
    Instances trainingData, int setNum, int maxSetNum) throws WekaException {
    try {
      m_osi =
        new BufferedImage(m_imageWidth, m_imageHeight,
          BufferedImage.TYPE_INT_RGB);
      m_currentDescription =
        makeSchemeSpec(classifier != null ? classifier : clust, setNum,
          maxSetNum);
      // notify listeners
      getStepManager()
        .logBasic("Starting new plot for " + m_currentDescription);
      if (m_plotListener != null) {
        m_plotListener.newPlotStarted(m_currentDescription);
      }

      Graphics m = m_osi.getGraphics();
      m.fillRect(0, 0, m_imageWidth, m_imageHeight);

      Classifier toTrainClassifier = null;
      weka.clusterers.DensityBasedClusterer toTrainClusterer = null;
      if (classifier != null) {
        toTrainClassifier =
          (Classifier) AbstractClassifier.makeCopy(classifier);
        toTrainClassifier.buildClassifier(trainingData);
      } else {
        int tempClassIndex = trainingData.classIndex();
        trainingData.setClassIndex(-1);
        toTrainClusterer =
          (DensityBasedClusterer) weka.clusterers.AbstractClusterer
            .makeCopy((weka.clusterers.Clusterer) clust);
        toTrainClusterer.buildClusterer(trainingData);
        trainingData.setClassIndex(tempClassIndex);
      }

      // populate the thread classifiers ready for parallel processing
      if (toTrainClassifier != null) {
        m_threadClassifiers =
          AbstractClassifier.makeCopies(toTrainClassifier, m_maxRowsInParallel);
      } else {
        m_threadClusterers =
          AbstractClusterer.makeCopies(toTrainClusterer, m_maxRowsInParallel);
      }
      m_threadGenerators = new DataGenerator[m_maxRowsInParallel];
      SerializedObject so = new SerializedObject(m_dataGenerator);
      for (int i = 0; i < m_maxRowsInParallel; i++) {
        m_threadGenerators[i] = (DataGenerator) so.getObject();
      }

      int taskCount = 0;
      List>> results =
        new ArrayList>>();
      for (int i = 0; i < m_imageHeight; i++) {
        if (taskCount < m_maxRowsInParallel) {
          getStepManager().logDetailed(
            "Launching task to compute image row " + i);
          SchemeRowTask t = new SchemeRowTask(this);
          t.setResourceIntensive(isResourceIntensive());
          t.m_classifier = null;
          t.m_clusterer = null;
          if (toTrainClassifier != null) {
            t.m_classifier = m_threadClassifiers[taskCount];
          } else {
            t.m_clusterer =
              (DensityBasedClusterer) m_threadClusterers[taskCount];
          }
          t.m_rowNum = i;
          t.m_xAtt = m_xAttribute;
          t.m_yAtt = m_yAttribute;
          t.m_imageWidth = m_imageWidth;
          t.m_imageHeight = m_imageHeight;
          t.m_pixWidth = m_pixWidth;
          t.m_pixHeight = m_pixHeight;
          t.m_dataGenerator = m_threadGenerators[taskCount];
          t.m_trainingData = trainingData;
          t.m_minX = m_minX;
          t.m_maxX = m_maxX;
          t.m_minY = m_minY;
          t.m_maxY = m_maxY;
          t.m_numOfSamplesPerRegion = m_numSamplesPerRegion;
          t.m_samplesBase = m_samplesBase;

          results.add(getStepManager().getExecutionEnvironment().submitTask(t));
          taskCount++;
        } else {
          // wait for running tasks
          for (Future> r : results) {
            double[][] rowProbs = r.get().getResult().m_rowProbs;
            for (int j = 0; j < m_imageWidth; j++) {
              plotPoint(m_osi, j, r.get().getResult().m_rowNumber, rowProbs[j],
                j == m_imageWidth - 1);
            }
            getStepManager().statusMessage(
              "Completed row " + r.get().getResult().m_rowNumber);
            getStepManager().logDetailed(
              "Completed image row " + r.get().getResult().m_rowNumber);
          }
          results.clear();
          taskCount = 0;
          if (i != m_imageHeight - 1) {
            i--;
          }
          if (isStopRequested()) {
            return;
          }
        }
      }
      if (results.size() > 0) {
        // wait for running tasks
        for (Future> r : results) {
          double[][] rowProbs = r.get().getResult().m_rowProbs;
          for (int i = 0; i < m_imageWidth; i++) {
            plotPoint(m_osi, i, r.get().getResult().m_rowNumber, rowProbs[i],
              i == m_imageWidth - 1);
          }
          getStepManager().statusMessage(
            "Completed row " + r.get().getResult().m_rowNumber);
          getStepManager().logDetailed(
            "Completed image row " + r.get().getResult().m_rowNumber);
        }
        if (isStopRequested()) {
          return;
        }
      }

      if (m_plotTrainingData) {
        plotTrainingData(trainingData);
      }

      m_completedImages.put(m_currentDescription, m_osi);
      Data imageOut = new Data(StepManager.CON_IMAGE, m_osi);
      imageOut.setPayloadElement(StepManager.CON_AUX_DATA_TEXT_TITLE,
        m_currentDescription);
      getStepManager().outputData(imageOut);
    } catch (Exception ex) {
      throw new WekaException(ex);
    }
  }

  /**
   * Build a human-readable description of a scheme. The description is the
   * scheme's unqualified class name, followed by its options (if it is an
   * {@code OptionHandler}), followed by "(set i of n)" when there is more
   * than one data set.
   *
   * @param scheme the classifier or clusterer
   * @param setNum the number of the current data set
   * @param maxSetNum the total number of data sets
   * @return the description
   */
  protected String makeSchemeSpec(Object scheme, int setNum, int maxSetNum) {
    Class<?> schemeClass = scheme.getClass();
    String name = schemeClass.getCanonicalName();
    if (name == null) {
      // anonymous and local classes have no canonical name
      name = schemeClass.getName();
    }
    name = name.substring(name.lastIndexOf('.') + 1);
    if (scheme instanceof OptionHandler) {
      name += " " + Utils.joinOptions(((OptionHandler) scheme).getOptions());
    }
    if (maxSetNum != 1) {
      name += " (set " + setNum + " of " + maxSetNum + ")";
    }

    return name;
  }

  /**
   * Colour one pixel of an image by blending the class/cluster colours,
   * weighted by the given probabilities. Each channel is clamped to [0, 1].
   * Coordinates outside the image are ignored, as a clipped fill would be.
   * When {@code update} is true, the rendering listener (if any) is told that
   * row {@code y} is complete.
   *
   * @param osi the image to draw into
   * @param x the pixel column
   * @param y the pixel row
   * @param probs the probability for each class/cluster
   * @param update true if this is the last pixel of the row
   */
  protected void plotPoint(BufferedImage osi, int x, int y, double[] probs,
    boolean update) {
    float[] colVal = new float[3];

    float[] tempCols = new float[3];
    for (int k = 0; k < probs.length; k++) {
      Color curr = m_Colors.get(k % m_Colors.size());

      curr.getRGBColorComponents(tempCols);
      for (int z = 0; z < 3; z++) {
        colVal[z] += probs[k] * tempCols[z];
      }
    }

    for (int z = 0; z < 3; z++) {
      if (colVal[z] < 0) {
        colVal[z] = 0;
      } else if (colVal[z] > 1) {
        colVal[z] = 1;
      }
    }

    // Write the pixel directly instead of allocating (and never disposing) a
    // Graphics context per pixel; the bounds check mirrors fillRect clipping.
    if (x >= 0 && x < osi.getWidth() && y >= 0 && y < osi.getHeight()) {
      osi.setRGB(x, y, new Color(colVal[0], colVal[1], colVal[2]).getRGB());
    }

    if (update) {
      // end of row
      // generate an update event for interactive viewer to consume
      if (m_plotListener != null) {
        m_plotListener.currentPlotRowCompleted(y);
      }
    }
  }

  /**
   * Superimpose the training instances on the current image. Instances missing
   * the X or Y value are skipped. Each point is drawn in the colour of its
   * class value, or in white when there is no class attribute or the class
   * value is missing. A contrasting outline is drawn so the point stands out
   * from the background. The rendering listener (if any) is notified
   * afterwards.
   *
   * @param trainingData the training data
   */
  public void plotTrainingData(Instances trainingData) {
    Graphics2D osg = (Graphics2D) m_osi.getGraphics();
    try {
      osg.setRenderingHint(RenderingHints.KEY_ANTIALIASING,
        RenderingHints.VALUE_ANTIALIAS_ON);
      int classIndex = trainingData.classIndex();
      int numColors = m_Colors.size();

      for (int i = 0; i < trainingData.numInstances(); i++) {
        Instance inst = trainingData.instance(i);
        if (inst.isMissing(m_xAttribute) || inst.isMissing(m_yAttribute)) {
          continue;
        }

        int panelX = convertToImageX(inst.value(m_xAttribute));
        int panelY = convertToImageY(inst.value(m_yAttribute));
        Color colorToPlotWith = Color.white;
        // index 0 is a valid class attribute; only a negative index means none
        if (classIndex >= 0 && !inst.isMissing(classIndex)) {
          int colorIndex = ((int) inst.value(classIndex)) % numColors;
          if (colorIndex < 0) {
            // negative (numeric) class values would otherwise index out of range
            colorIndex += numColors;
          }
          colorToPlotWith = m_Colors.get(colorIndex);
        }
        osg.setColor(colorToPlotWith.equals(Color.white) ? Color.black
          : Color.white);
        osg.fillOval(panelX - 3, panelY - 3, 7, 7);
        osg.setColor(colorToPlotWith);
        osg.fillOval(panelX - 2, panelY - 2, 5, 5);
      }
    } finally {
      osg.dispose();
    }

    if (m_plotListener != null) {
      m_plotListener.renderingImageUpdate();
    }
  }

  /** Map an X attribute value to a pixel column (truncated towards zero). */
  private int convertToImageX(double xval) {
    double fraction = (xval - m_minX) / m_rangeX;
    return (int) (fraction * m_imageWidth);
  }

  /**
   * Map a Y attribute value to a pixel row (truncated towards zero). Larger
   * values map to rows nearer the top of the image.
   */
  private int convertToImageY(double yval) {
    double fromBottom = ((yval - m_minY) / m_rangeY) * m_imageHeight;
    return (int) (m_imageHeight - fromBottom);
  }

  /**
   * Get a list of incoming connection types that this step can accept. This
   * step accepts data set, training set and info connections, regardless of
   * the connections already present.
   *
   * @return a list of incoming connections that this step can accept given its
   *         current state
   */
  @Override
  public List<String> getIncomingConnectionTypes() {
    return Arrays.asList(StepManager.CON_DATASET, StepManager.CON_TRAININGSET,
      StepManager.CON_INFO);
  }

  /**
   * Get a list of outgoing connection types that this step can produce. This
   * step only ever produces image connections.
   *
   * @return a list of outgoing connections that this step can produce
   */
  @Override
  public List<String> getOutgoingConnectionTypes() {
    return Arrays.asList(StepManager.CON_IMAGE);
  }

  /**
   * Get the completed images, keyed by scheme description. The returned map is
   * the step's live map (null before the step has been initialized).
   *
   * @return a map of completed images
   */
  public Map<String, BufferedImage> getImages() {
    return m_completedImages;
  }

  /**
   * Get the image currently being rendered (or the most recently rendered
   * one); null if no plot has been started yet.
   *
   * @return the current image
   */
  public BufferedImage getCurrentImage() {
    return this.m_osi;
  }

  /**
   * Register the listener that receives rendering updates. Only one listener
   * is held; setting a new one replaces any previous listener.
   *
   * @param l the {@code RenderingUpdateListener} to add
   */
  public void setRenderingListener(RenderingUpdateListener l) {
    this.m_plotListener = l;
  }

  /**
   * Unregister the rendering update listener, but only if {@code l} is the
   * listener currently registered (compared by identity).
   *
   * @param l the {@code RenderingUpdateListener} to remove
   */
  public void removeRenderingListener(RenderingUpdateListener l) {
    if (m_plotListener != l) {
      return;
    }
    m_plotListener = null;
  }

  /**
   * When running in a graphical execution environment a step can make one or
   * more popup Viewer components available. These might be used to display
   * results, graphics etc. Returning null indicates that the step has no such
   * additional graphical views. The map returned by this method should be keyed
   * by action name (e.g. "View results"), and values should be fully qualified
   * names of the corresponding StepInteractiveView implementation. Furthermore,
   * the contents of this map can (and should) be dependent on whether a
   * particular viewer should be made available - i.e. if execution hasn't
   * occurred yet, or if a particular incoming connection type is not present,
   * then it might not be possible to view certain results.
   *
   * Viewers can implement StepInteractiveView directly (in which case they need
   * to extend JPanel), or extend the AbstractInteractiveViewer class. The
   * latter extends JPanel, uses a BorderLayout, provides a "Close" button and a
   * method to add additional buttons.
   *
   * This step offers its plot viewer only while no rendering listener is
   * registered (presumably because a registered listener is an already open
   * viewer).
   *
   * @return a map of viewer component names (never null; possibly empty)
   */
  @Override
  public Map<String, String> getInteractiveViewers() {
    Map<String, String> views = new LinkedHashMap<String, String>();
    if (m_plotListener == null) {
      views.put("Show plots",
        "weka.gui.knowledgeflow.steps.BoundaryPlotterInteractiveView");
    }
    return views;
  }

  /**
   * Name the custom editor used for this step's properties instead of the
   * dynamically generated GenericObjectEditor dialog.
   *
   * @return the fully qualified class name of the step editor dialog
   */
  @Override
  public String getCustomEditorForStep() {
    final String editorClassName =
      "weka.gui.knowledgeflow.steps.BoundaryPlotterStepEditorDialog";
    return editorClassName;
  }

  /**
   * Get the completed images in a serializable form: the result of converting
   * the completed image map with
   * {@code ImageViewer.bufferedImageMapToSerializableByteMap}.
   *
   * @return the serializable form of the completed images
   */
  @Override
  public Object retrieveData() {
    Object serializable =
      ImageViewer.bufferedImageMapToSerializableByteMap(m_completedImages);
    return serializable;
  }

  /**
   * Restore the completed images from the serialized form produced by
   * {@link #retrieveData()}: a map from image description to encoded image
   * bytes.
   *
   * @param data the images to set (must be a {@code Map})
   * @throws IllegalArgumentException if {@code data} is not a {@code Map}
   * @throws WekaException if the images cannot be decoded
   */
  @Override
  @SuppressWarnings("unchecked")
  public void restoreData(Object data) throws WekaException {
    if (!(data instanceof Map)) {
      throw new IllegalArgumentException("Argument must be a Map");
    }

    try {
      // unchecked: the map is assumed to come from retrieveData()
      m_completedImages =
        ImageViewer
          .byteArrayImageMapToBufferedImageMap((Map<String, byte[]>) data);
    } catch (IOException ex) {
      throw new WekaException(ex);
    }
  }

  /**
   * Callback interface for components that want to follow the progress of a
   * boundary plot while it is being rendered.
   */
  public interface RenderingUpdateListener {

    /**
     * Invoked when rendering of a new plot begins.
     *
     * @param description the title/description of the new plot
     */
    void newPlotStarted(String description);

    /**
     * Invoked once every pixel in one row of the current plot has been drawn.
     *
     * @param row index of the completed row
     */
    void currentPlotRowCompleted(int row);

    /**
     * Invoked when the current plot changes in some way other than a row
     * being completed (for example, when training points are superimposed).
     */
    void renderingImageUpdate();
  }

  /**
   * Result of computing one row of a boundary image: the row's index plus the
   * probability distribution computed for each pixel in that row.
   */
  protected static class RowResult {
    /** Index of the image row these results belong to */
    protected int m_rowNumber;

    /** Per-pixel probability distributions, indexed by pixel column */
    protected double[][] m_rowProbs;
  }

  /**
   * A task for computing a row of an image using a trained model. For every
   * pixel in row {@code m_rowNum} it draws {@code m_numOfSamplesPerRegion}
   * random points inside the pixel, uses the {@link DataGenerator} to weight
   * and generate instances around each point, and accumulates the scheme's
   * weighted predicted distributions into one probability vector per pixel.
   */
  protected static class SchemeRowTask extends StepTask implements
    Serializable {

    private static final long serialVersionUID = -4144732293602550066L;

    /** Index of the attribute plotted on the x axis */
    protected int m_xAtt;
    /** Index of the attribute plotted on the y axis */
    protected int m_yAtt;
    /** The image row computed by this task */
    protected int m_rowNum;
    /** Width of the image in pixels (number of pixels in the row) */
    protected int m_imageWidth;
    /** Height of the image in pixels */
    protected int m_imageHeight;
    /** Width of one pixel, in x-attribute units */
    protected double m_pixWidth;
    /** Height of one pixel, in y-attribute units */
    protected double m_pixHeight;
    /** Classifier to query; takes precedence over m_clusterer when non-null */
    protected weka.classifiers.Classifier m_classifier;
    /** Clusterer to query when no classifier is set */
    protected weka.clusterers.DensityBasedClusterer m_clusterer;
    /** Generates (weighted) instances around a sampled location */
    protected DataGenerator m_dataGenerator;
    /** Header/training data that defines the attribute space */
    protected Instances m_trainingData;
    /** Minimum x-attribute value, i.e. the left edge of pixel 0 */
    protected double m_minX;
    protected double m_maxX;
    /** Minimum y-attribute value, i.e. the lower edge of vertical pixel 0 */
    protected double m_minY;
    protected double m_maxY;
    /** Number of random locations sampled within each pixel */
    protected int m_numOfSamplesPerRegion;
    /** Base of the power that determines samples per generator call */
    protected double m_samplesBase;

    private Random m_random;
    private int m_numOfSamplesPerGenerator;
    private boolean[] m_attsToWeightOn;
    private double[] m_weightingAttsValues;
    private double[] m_vals;
    private double[] m_dist;
    Instance m_predInst;

    public SchemeRowTask(Step source) {
      super(source);
    }

    /**
     * Computes the probability distributions for every pixel of row
     * {@code m_rowNum} and stores them, as a {@link RowResult}, in this
     * task's execution result.
     *
     * @throws Exception if the task is not fully configured, a plotted
     *           attribute is nominal, or the scheme fails to predict
     */
    @Override
    public void process() throws Exception {
      // Validate before touching any configured object, so a missing piece
      // produces a descriptive exception instead of a NullPointerException.
      validateConfiguration();

      RowResult result = new RowResult();
      result.m_rowNumber = m_rowNum;
      result.m_rowProbs = new double[m_imageWidth][0];

      // Seeds depend only on the row number, so a row's sampling is
      // deterministic.
      m_random = new Random(m_rowNum * 11);
      m_dataGenerator.setSeed(m_rowNum * 11);

      // NOTE(review): the "- 3" presumably discounts the x, y and class
      // attributes. A negative exponent (e.g. two-attribute data) would
      // truncate to zero samples and make the per-location average in
      // calculateRegionProbs divide by zero, so at least one sample is used.
      m_numOfSamplesPerGenerator =
        Math.max(1,
          (int) Math.pow(m_samplesBase, m_trainingData.numAttributes() - 3));

      m_attsToWeightOn = new boolean[m_trainingData.numAttributes()];
      m_attsToWeightOn[m_xAtt] = true;
      m_attsToWeightOn[m_yAtt] = true;

      // generate samples
      m_weightingAttsValues = new double[m_attsToWeightOn.length];
      m_vals = new double[m_trainingData.numAttributes()];
      // calculateRegionProbs writes into m_vals and then predicts m_predInst;
      // this relies on DenseInstance(double, double[]) keeping a reference to
      // the array rather than copying it.
      m_predInst = new DenseInstance(1.0, m_vals);
      m_predInst.setDataset(m_trainingData);
      getLogHandler().logDetailed("Computing row number: " + m_rowNum);
      for (int j = 0; j < m_imageWidth; j++) {
        result.m_rowProbs[j] = calculateRegionProbs(j, m_rowNum);
      }

      getExecutionResult().setResult(result);
    }

    /**
     * Checks that training data, a scheme and a data generator are set, and
     * that both plotted attributes are numeric.
     *
     * @throws Exception describing the first problem found
     */
    private void validateConfiguration() throws Exception {
      if (m_trainingData == null) {
        throw new Exception("No training data set");
      }
      if (m_classifier == null && m_clusterer == null) {
        throw new Exception("No scheme set");
      }
      if (m_dataGenerator == null) {
        throw new Exception("No data generator set");
      }
      if (m_trainingData.attribute(m_xAtt).isNominal()
        || m_trainingData.attribute(m_yAtt).isNominal()) {
        throw new Exception("Visualization dimensions must be numeric");
      }
    }

    /**
     * Returns the length of a predicted distribution: the number of class
     * values when a classifier is used, otherwise the number of clusters.
     *
     * @return the number of entries in a prediction
     * @throws Exception if the clusterer cannot report its cluster count
     */
    private int numOutputs() throws Exception {
      return m_classifier != null
        ? m_trainingData.classAttribute().numValues()
        : m_clusterer.numberOfClusters();
    }

    /**
     * Estimates the distribution for pixel (column {@code j}, row {@code i}).
     * This averages the scheme's predictions over random locations within the
     * pixel, weighting each generated instance by the generator's weights.
     *
     * @param j the horizontal pixel (column) index
     * @param i the image row index
     * @return the normalized distribution for the pixel
     * @throws Exception if prediction fails, or if the total generator weight
     *           is not positive (arithmetic underflow)
     */
    private double[] calculateRegionProbs(int j, int i) throws Exception {
      int numOutputs = numOutputs();
      double[] sumOfProbsForRegion = new double[numOutputs];

      double sumOfSums = 0;
      for (int u = 0; u < m_numOfSamplesPerRegion; u++) {
        double[] sumOfProbsForLocation = new double[numOutputs];

        // The row index is flipped before mapping to y, presumably because
        // image row 0 is the top of the plot while y grows upwards.
        m_weightingAttsValues[m_xAtt] = getRandomX(j);
        m_weightingAttsValues[m_yAtt] = getRandomY(m_imageHeight - i - 1);

        m_dataGenerator.setWeightingValues(m_weightingAttsValues);
        double[] weights = m_dataGenerator.getWeights();
        double sumOfWeights = Utils.sum(weights);
        sumOfSums += sumOfWeights;
        int[] indices = heaviestIndices(weights, sumOfWeights);

        for (int z = 0; z < m_numOfSamplesPerGenerator; z++) {
          m_dataGenerator.setWeightingValues(m_weightingAttsValues);
          double[][] values = m_dataGenerator.generateInstances(indices);

          // NOTE(review): assumes values[q] lines up with weights[q], with
          // null slots for instances that were not generated.
          for (int q = 0; q < values.length; q++) {
            if (values[q] != null) {
              System.arraycopy(values[q], 0, m_vals, 0, m_vals.length);
              m_vals[m_xAtt] = m_weightingAttsValues[m_xAtt];
              m_vals[m_yAtt] = m_weightingAttsValues[m_yAtt];

              // classify/cluster the instance
              m_dist =
                m_classifier != null
                  ? m_classifier.distributionForInstance(m_predInst)
                  : m_clusterer.distributionForInstance(m_predInst);

              for (int k = 0; k < numOutputs; k++) {
                sumOfProbsForLocation[k] += m_dist[k] * weights[q];
              }
            }
          }
        }

        for (int k = 0; k < numOutputs; k++) {
          sumOfProbsForRegion[k] +=
            sumOfProbsForLocation[k] / m_numOfSamplesPerGenerator;
        }
      }

      // Written as !(x > 0) so that a NaN total is also rejected here.
      if (!(sumOfSums > 0)) {
        throw new Exception(
          "Arithmetic underflow. Please increase value of kernel bandwidth " +
                  "parameter (k).");
      }
      // average
      Utils.normalize(sumOfProbsForRegion, sumOfSums);

      return sumOfProbsForRegion;
    }

    /**
     * Prunes the lightest ~1% of the weight mass. Starting from the heaviest
     * weight, indices are kept until their combined weight exceeds 99% of
     * {@code sumOfWeights}; all indices are kept if that never happens.
     *
     * @param weights the instance weights
     * @param sumOfWeights the sum of {@code weights}
     * @return the kept indices, in ascending order of weight
     */
    private static int[] heaviestIndices(double[] weights, double sumOfWeights) {
      int[] ascending = Utils.sort(weights);
      double criticalMass = 0.99 * sumOfWeights;
      double sumSoFar = 0;
      int keep = 0;
      for (int z = ascending.length - 1; z >= 0; z--) {
        sumSoFar += weights[ascending[z]];
        keep++;
        if (sumSoFar > criticalMass) {
          break;
        }
      }
      int[] kept = new int[keep];
      System.arraycopy(ascending, ascending.length - keep, kept, 0, keep);
      return kept;
    }

    /**
     * Return a random x attribute value contained within the pix'th horizontal
     * pixel
     *
     * @param pix the horizontal pixel number
     * @return a value in attribute space
     */
    private double getRandomX(int pix) {

      double minPix = m_minX + (pix * m_pixWidth);

      return minPix + m_random.nextDouble() * m_pixWidth;
    }

    /**
     * Return a random y attribute value contained within the pix'th vertical
     * pixel
     *
     * @param pix the vertical pixel number
     * @return a value in attribute space
     */
    private double getRandomY(int pix) {

      double minPix = m_minY + (pix * m_pixHeight);

      return minPix + m_random.nextDouble() * m_pixHeight;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy