All downloads are free. The search and download functionality uses the official Maven repository.

cc.mallet.fst.confidence.ConfidenceEvaluator Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.confidence;

import java.util.Vector;
import java.util.Collections;
import java.util.Comparator;

import cc.mallet.fst.*;
import cc.mallet.types.*;

/**
 * Computes summary statistics relating confidence scores to correctness
 * for a collection of {@link EntityConfidence} records: correlations,
 * average precision, and accuracy at various coverage/recall levels.
 *
 * <p>The ranking-based measures ({@link #getAveragePrecision()},
 * {@link #numCorrectAtCoverage(double)}, {@link #getAverageAccuracy()})
 * read the records from the last index towards the first, so they expect
 * {@link #confidences} to be ordered by ascending confidence.  The
 * constructors sort the records that way unless the caller states that
 * they are already sorted.
 */
public class ConfidenceEvaluator
{
	/** Number of bins used by the constructors that take no explicit bin count. */
	static int DEFAULT_NUM_BINS = 20;
	/** The {@link EntityConfidence} records being evaluated. */
	Vector confidences;
	/** Number of coverage points produced by the accuracy/coverage methods. */
	int nBins;
	/** Number of records marked correct, counted once at construction time. */
	int numCorrect;

	/**
	 * Evaluates the given records using <code>nBins</code> coverage points.
	 * The supplied vector is kept by reference and sorted in place by
	 * ascending confidence.
	 *
	 * @param confidences vector of {@link EntityConfidence} objects
	 * @param nBins number of coverage points for the accuracy/coverage methods
	 */
	public ConfidenceEvaluator (Vector confidences, int nBins)
	{
		this.confidences = confidences;
		this.nBins = nBins;
		this.numCorrect = getNumCorrectEntities();
		// sort confidences by score
		Collections.sort (confidences, new ConfidenceComparator());
	}

	/**
	 * Same as {@link #ConfidenceEvaluator(Vector, int)} with
	 * {@link #DEFAULT_NUM_BINS} bins.
	 *
	 * @param confidences vector of {@link EntityConfidence} objects
	 */
	public ConfidenceEvaluator (Vector confidences)
	{
		this (confidences, DEFAULT_NUM_BINS);
	}

	/**
	 * Builds one record per segment from its confidence, correctness,
	 * input sequence and start/end positions.
	 *
	 * @param segments segments to evaluate
	 * @param sorted true if <code>segments</code> is already ordered by
	 *        confidence, in which case no sort is performed
	 */
	public ConfidenceEvaluator (Segment[] segments, boolean sorted)
	{
		this.confidences = new Vector ();
		for (int i = 0; i < segments.length; i++) {
			Segment seg = segments[i];
			confidences.add (new EntityConfidence (seg.getConfidence(), seg.correct(),
																						 seg.getInput(), seg.getStart(), seg.getEnd()));
		}
		finishInit (sorted);
	}

	/**
	 * Builds one record per instance; the record text is taken from the
	 * whole input sequence (positions 0 through size-1) of the instance.
	 *
	 * @param instances instances to evaluate
	 * @param sorted true if <code>instances</code> is already ordered by
	 *        confidence, in which case no sort is performed
	 */
	public ConfidenceEvaluator (InstanceWithConfidence[] instances, boolean sorted) {
		this.confidences = new Vector ();
		for (int i = 0; i < instances.length; i++) {
			InstanceWithConfidence inst = instances[i];
			Sequence input = (Sequence) inst.getInstance().getData();
			confidences.add (new EntityConfidence (inst.getConfidence(), inst.correct(),
																						 input, 0, input.size() - 1));
		}
		finishInit (sorted);
	}

	/**
	 * Builds one record per instance.  No input sequence is passed on, so
	 * the records carry an empty entity text.
	 *
	 * @param instances instances to evaluate
	 * @param sorted true if <code>instances</code> is already ordered by
	 *        confidence, in which case no sort is performed
	 */
	public ConfidenceEvaluator (PipedInstanceWithConfidence[] instances, boolean sorted) {
		this.confidences = new Vector ();
		for (int i = 0; i < instances.length; i++) {
			PipedInstanceWithConfidence inst = instances[i];
			confidences.add (new EntityConfidence (inst.getConfidence(), inst.correct(),
																						 null, 0, 1));
		}
		finishInit (sorted);
	}

	/** Shared tail of the array constructors: optional sort, default bin count, correct count. */
	private void finishInit (boolean sorted)
	{
		if (!sorted)
			Collections.sort (confidences, new ConfidenceComparator());
		this.nBins = DEFAULT_NUM_BINS;
		this.numCorrect = getNumCorrectEntities ();
	}

	/** Returns the record at position <code>i</code>, cast to its element type. */
	private EntityConfidence entityAt (int i)
	{
		return (EntityConfidence) confidences.get (i);
	}

	/**
	 * Correlation when one variable (Y) is binary:
	 * r = (bar(x1) - bar(x0)) * sqrt(p(1-p)) / sx, where bar(x1) is the
	 * mean of X when Y is 1, bar(x0) is the mean of X when Y is 0, sx is
	 * the standard deviation of X and p is the proportion of values where
	 * Y=1.  Here Y is correctness (incorrect = 0, correct = 1) and X is
	 * the confidence score.
	 *
	 * @return the point-biserial correlation; NaN when any of the
	 *         component statistics is undefined (for example when all
	 *         records are correct or all are incorrect)
	 */
	public double pointBiserialCorrelation ()
	{
		double x0bar = getAverageIncorrectConfidence ();
		double x1bar = getAverageCorrectConfidence ();
		double p = (double) this.numCorrect / size();
		double sx = getConfidenceStandardDeviation ();
		return (x1bar - x0bar) * Math.sqrt (p * (1 - p)) / sx;
	}

	/**
	 * IR average precision measure.  Analogous to ranking _correct_
	 * documents by confidence score: records are visited from the last
	 * index to the first and the precision so far is accumulated at every
	 * correct record.
	 *
	 * @return mean of the accumulated precisions; NaN if no record is correct
	 */
	public double getAveragePrecision () {
		int correctSeen = 0;
		int incorrectSeen = 0;
		double totalPrecision = 0.0;
		for (int i = confidences.size() - 1; i >= 0; i--) {
			if (entityAt (i).correct()) {
				correctSeen++;
				totalPrecision += (double) correctSeen / (correctSeen + incorrectSeen);
			}
			else
				incorrectSeen++;
		}
		return totalPrecision / correctSeen;
	}

	/**
	 * For comparison, the average precision obtained when records are
	 * ranked as badly as possible (all "incorrect" before "correct").
	 *
	 * @return the worst-case average precision; NaN if no record is correct
	 */
	public double getWorstAveragePrecision () {
		int numIncorrect = confidences.size() - this.numCorrect;
		double totalPrecision = 0.0;
		for (int nc = 1; nc <= this.numCorrect; nc++)
			totalPrecision += (double) nc / (nc + numIncorrect);
		return totalPrecision / this.numCorrect;
	}

	/** @return the sum of all confidence scores */
	public double getConfidenceSum()
	{
		double sum = 0.0;
		for (int i = 0; i < size(); i++)
			sum += entityAt (i).confidence();
		return sum;
	}

	/** @return the mean confidence score; NaN if there are no records */
	public double getConfidenceMean ()
	{
		return getConfidenceSum() / size();
	}

	/**
	 * Standard deviation of confidence scores, computed with the
	 * population formula (divides by the number of records).
	 */
	public double getConfidenceStandardDeviation ()
	{
		double mean = getConfidenceMean();
		double sumSquaredDifference = 0.0;
		for (int i = 0; i < size(); i++) {
			double diff = entityAt (i).confidence() - mean;
			sumSquaredDifference += diff * diff;
		}
		return Math.sqrt (sumSquaredDifference / (double) size());
	}

	/**
	 * Calculates Pearson's r for the correlation between confidence and
	 * correctness, where 1 = correct and -1 = incorrect.
	 */
	public double correlation ()
	{
		double xSum = 0;
		double xSumOfSquares = 0;
		double ySum = 0;
		double ySumOfSquares = 0;
		double xySum = 0; // sum of the products of x and y
		for (int i = 0; i < size(); i++) {
			EntityConfidence ec = entityAt (i);
			double value = ec.correct() ? 1.0 : -1.0;
			double conf = ec.confidence();
			xSum += value;
			xSumOfSquares += value * value;
			ySum += conf;
			ySumOfSquares += conf * conf;
			xySum += value * conf;
		}
		double xVariance = xSumOfSquares - (xSum * xSum / size());
		double yVariance = ySumOfSquares - (ySum * ySum / size());
		double crossVariance = xySum - (xSum * ySum / size());
		return crossVariance / Math.sqrt (xVariance * yVariance);
	}

	/**
	 * Coverage fraction for a zero-based bin index: (bin+1)/nBins.
	 * Computed in floating point so that bin counts that do not divide
	 * 100 evenly (or exceed 100) still span (0, 1] uniformly; an integer
	 * percentage step would be truncated, possibly down to zero.
	 */
	private double coverageForBin (int bin)
	{
		return (bin + 1) / (double) this.nBins;
	}

	/**
	 * Gets accuracy at coverage for each bin of values.
	 *
	 * @return array of length <code>nBins</code> whose element i is the
	 *         accuracy at coverage (i+1)/nBins
	 */
	public double[] getAccuracyCoverageValues ()
	{
		double[] values = new double [this.nBins];
		for (int i = 0; i < values.length; i++)
			values[i] = accuracyAtCoverage (coverageForBin (i));
		return values;
	}

	/**
	 * Formats the accuracy/coverage values, one bin per line, as
	 * coverage, a tab, then accuracy.
	 */
	public String accuracyCoverageValuesToString () {
		double[] vals = getAccuracyCoverageValues ();
		StringBuffer buf = new StringBuffer ();
		for (int i = 0; i < vals.length; i++)
			buf.append (coverageForBin (i)).append ("\t").append (vals[i]).append ("\n");
		return buf.toString();
	}

	/**
	 * Gets accuracy at recall for each bin of values.
	 *
	 * @param totalTrue total number of true Segments
	 * @return 2-d array where values[i][0] is the recall (number correct
	 *         at the bin's coverage divided by <code>totalTrue</code>) and
	 *         values[i][1] is the accuracy at position i.
	 */
	public double[][] getAccuracyRecallValues (int totalTrue)
	{
		double[][] values = new double [this.nBins][2];
		for (int i = 0; i < values.length; i++) {
			double coverage = coverageForBin (i);
			values[i][0] = (double) numCorrectAtCoverage (coverage) / totalTrue;
			values[i][1] = accuracyAtCoverage (coverage);
		}
		return values;
	}

	/**
	 * Formats the accuracy/recall values, one bin per line, as recall,
	 * a tab, then accuracy.
	 *
	 * @param totalTrue total number of true Segments
	 */
	public String accuracyRecallValuesToString (int totalTrue) {
		double[][] vals = getAccuracyRecallValues (totalTrue);
		StringBuffer buf = new StringBuffer ();
		for (int i = 0; i < vals.length; i++)
			buf.append (vals[i][0]).append ("\t").append (vals[i][1]).append ("\n");
		return buf.toString();
	}

	/** Number of records covered by the fraction <code>cov</code>, rounded to the nearest integer. */
	private int numPointsAtCoverage (double cov)
	{
		return (int) Math.round ((double) size() * cov);
	}

	/**
	 * Fraction of correct records among the top <code>cov</code> share of
	 * records (those at the highest indices).
	 *
	 * @param cov coverage in (0, 1]
	 * @return accuracy at that coverage; NaN when the coverage rounds to
	 *         zero records
	 */
	public double accuracyAtCoverage (double cov)
	{
		assert (cov <= 1 && cov > 0);
		int numPoints = numPointsAtCoverage (cov);
		return (double) numCorrectAtCoverage (cov) / numPoints;
	}

	/**
	 * Number of correct records among the top <code>cov</code> share of
	 * records (those at the highest indices).
	 *
	 * @param cov coverage in (0, 1]
	 */
	public int numCorrectAtCoverage (double cov) {
		assert (cov <= 1 && cov > 0);
		// number of records to inspect for this value of cov
		int numPoints = numPointsAtCoverage (cov);
		int correctCount = 0;
		for (int i = 0; i < numPoints; i++) {
			if (entityAt (size() - i - 1).correct())
				correctCount++;
		}
		return correctCount;
	}

	/**
	 * Mean, over every prefix of the ranking (taken from the last index
	 * towards the first), of the accuracy within that prefix.
	 *
	 * @return the averaged accuracy; NaN if there are no records
	 */
	public double getAverageAccuracy ()
	{
		int total = confidences.size();
		int correctSoFar = 0;
		double totalArea = 0.0;
		for (int i = total - 1; i >= 0; i--) {
			if (entityAt (i).correct())
				correctSoFar++;
			totalArea += (double) correctSoFar / (total - i);
		}
		return totalArea / total;
	}

	/** @return the number of correct records counted at construction time */
	public int numCorrect()
	{
		return this.numCorrect;
	}

	/**
	 * Counts the records currently marked correct.
	 */
	private int getNumCorrectEntities ()
	{
		int count = 0;
		for (int i = 0; i < confidences.size(); i++) {
			if (entityAt (i).correct())
				count++;
		}
		return count;
	}

	/**
	 * Average confidence score for the incorrect entities.
	 *
	 * @return the mean; NaN if there are no incorrect entities
	 */
	public double getAverageIncorrectConfidence ()
	{
		double sum = 0.0;
		for (int i = 0; i < confidences.size(); i++) {
			EntityConfidence ec = entityAt (i);
			if (!ec.correct())
				sum += ec.confidence();
		}
		return sum / ((double) size() - (double) this.numCorrect);
	}

	/**
	 * Average confidence score for the correct entities.
	 *
	 * @return the mean; NaN if there are no correct entities
	 */
	public double getAverageCorrectConfidence ()
	{
		double sum = 0.0;
		for (int i = 0; i < confidences.size(); i++) {
			EntityConfidence ec = entityAt (i);
			if (ec.correct())
				sum += ec.confidence();
		}
		return sum / (double) this.numCorrect;
	}

	/** @return the number of records being evaluated */
	public int size()
	{
		return confidences.size();
	}

	/** Concatenates the string form of every record, each followed by a space. */
	public String toString()
	{
		StringBuffer toReturn = new StringBuffer();
		for (int i = 0; i < size(); i++)
			toReturn.append (entityAt (i).toString()).append (" ");
		return toReturn.toString();
	}

	/**
	 * A simple class to store a confidence score and whether or not this
	 * labeling is correct, together with a text rendering of the entity.
	 */
	public static class EntityConfidence
	{
		double confidence;
		boolean correct;
		String entity;

		/**
		 * @param conf confidence score
		 * @param corr whether the labeling is correct
		 * @param text text of the entity
		 */
		public EntityConfidence (double conf, boolean corr, String text) {
			this.confidence = conf;
			this.correct = corr;
			this.entity = text;
		}

		/**
		 * Builds the entity text from the features of positions
		 * <code>start</code> through <code>end</code> (inclusive) of
		 * <code>input</code>: every feature whose name starts with "W="
		 * and contains no "@" contributes the part after the '=',
		 * followed by a space.  A null <code>input</code> yields an empty
		 * text.
		 *
		 * @param conf confidence score
		 * @param corr whether the labeling is correct
		 * @param input sequence of FeatureVectors, or null
		 * @param start first position to read
		 * @param end last position to read (inclusive)
		 */
		public EntityConfidence (double conf, boolean corr, Sequence input, int start, int end) {
			this.confidence = conf;
			this.correct = corr;
			StringBuffer buff = new StringBuffer();
			if (input != null) {
				for (int j = start; j <= end; j++) {
					FeatureVector fv = (FeatureVector) input.get(j);
					for (int k = 0; k < fv.numLocations(); k++) {
						String featureName = fv.getAlphabet().lookupObject (fv.indexAtLocation (k)).toString();
						boolean isWordFeature = featureName.startsWith ("W=") && featureName.indexOf ("@") == -1;
						if (isWordFeature)
							buff.append (featureName.substring (featureName.indexOf ('=') + 1)).append (" ");
					}
				}
			}
			this.entity = buff.toString();
		}

		/** @return the confidence score */
		public double confidence () {return confidence;}

		/** @return whether the labeling is correct */
		public boolean correct () {return correct;}

		/** Renders as "entity / confidence / correct|incorrect" followed by a newline. */
		public String toString ()
		{
			return this.entity + " / " + this.confidence + " / "
				+ (this.correct ? "correct" : "incorrect") + "\n";
		}
	}

	/**
	 * Orders {@link EntityConfidence} objects by ascending confidence.
	 * Uses {@link Double#compare(double, double)} so that the ordering is
	 * total even when a confidence is NaN; deriving the sign from a
	 * subtraction would report NaN as equal to every other value, which
	 * is not a consistent ordering for sorting.
	 */
	private static class ConfidenceComparator implements Comparator
	{
		public final int compare (Object a, Object b)
		{
			double x = ((EntityConfidence) a).confidence();
			double y = ((EntityConfidence) b).confidence();
			return Double.compare (x, y);
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy