All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.evaluation.MarginCurve Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MarginCurve.java
 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.evaluation;

import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 * Generates points illustrating the prediction margin. The margin is defined
 * as the difference between the probability predicted for the actual class and
 * the highest probability predicted for the other classes. One hypothesis
 * as to the good performance of boosting algorithms is that they increase the
 * margins on the training data and this gives better performance on test data.
 *
 * @author Len Trigg ([email protected])
 * @version $Revision: 1.11 $
 */
public class MarginCurve
  implements RevisionHandler {

  /**
   * Calculates the cumulative margin distribution for the set of
   * predictions, returning the result as a set of Instances. The
   * structure of these Instances is as follows:

    *
  • Margin contains the margin value (which should be plotted * as an x-coordinate) *
  • Current contains the count of instances with the current * margin (plot as y axis) *
  • Cumulative contains the count of instances with margin * less than or equal to the current margin (plot as y axis) *

* * @return datapoints as a set of instances, null if no predictions * have been made. */ public Instances getCurve(FastVector predictions) { if (predictions.size() == 0) { return null; } Instances insts = makeHeader(); double [] margins = getMargins(predictions); int [] sorted = Utils.sort(margins); int binMargin = 0; int totalMargin = 0; insts.add(makeInstance(-1, binMargin, totalMargin)); for (int i = 0; i < sorted.length; i++) { double current = margins[sorted[i]]; double weight = ((NominalPrediction)predictions.elementAt(sorted[i])) .weight(); totalMargin += weight; binMargin += weight; if (true) { insts.add(makeInstance(current, binMargin, totalMargin)); binMargin = 0; } } return insts; } /** * Pulls all the margin values out of a vector of NominalPredictions. * * @param predictions a FastVector containing NominalPredictions * @return an array of margin values. */ private double [] getMargins(FastVector predictions) { // sort by predicted probability of the desired class. double [] margins = new double [predictions.size()]; for (int i = 0; i < margins.length; i++) { NominalPrediction pred = (NominalPrediction)predictions.elementAt(i); margins[i] = pred.margin(); } return margins; } /** * Creates an Instances object with the attributes we will be calculating. * * @return the Instances structure. */ private Instances makeHeader() { FastVector fv = new FastVector(); fv.addElement(new Attribute("Margin")); fv.addElement(new Attribute("Current")); fv.addElement(new Attribute("Cumulative")); return new Instances("MarginCurve", fv, 100); } /** * Creates an Instance object with the attributes calculated. * * @param margin the margin for this data point. * @param current the number of instances with this margin. * @param cumulative the number of instances with margin less than or equal * to this margin. * @return the Instance object. 
*/ private Instance makeInstance(double margin, int current, int cumulative) { int count = 0; double [] vals = new double[3]; vals[count++] = margin; vals[count++] = current; vals[count++] = cumulative; return new Instance(1.0, vals); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 1.11 $"); } /** * Tests the MarginCurve generation from the command line. * The classifier is currently hardcoded. Pipe in an arff file. * * @param args currently ignored */ public static void main(String [] args) { try { Utils.SMALL = 0; Instances inst = new Instances(new java.io.InputStreamReader(System.in)); inst.setClassIndex(inst.numAttributes() - 1); MarginCurve tc = new MarginCurve(); EvaluationUtils eu = new EvaluationUtils(); weka.classifiers.meta.LogitBoost classifier = new weka.classifiers.meta.LogitBoost(); classifier.setNumIterations(20); FastVector predictions = eu.getTrainTestPredictions(classifier, inst, inst); Instances result = tc.getCurve(predictions); System.out.println(result); } catch (Exception ex) { ex.printStackTrace(); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy