All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.stats.AccuracyStats Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.stats;

import java.text.NumberFormat;
import java.util.ArrayList;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.PRCurve;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;

/**
 * Scorer that evaluates a probabilistic classifier on a dataset with respect to
 * a single designated "positive" label. Each datum is reduced to a pair of
 * (probability assigned to the positive label, 1 if the gold label is the
 * positive label else 0); those pairs are handed to a {@link PRCurve}, and the
 * accuracy, confidence weighted accuracy, their "optimal" variants, and the
 * log-likelihood reported by that curve are stored on this object.
 *
 * <p>If a save file prefix was supplied, {@link #getDescription(int)} also writes
 * the accuracy arrays to disk (see that method for details).
 *
 * @param <L> the type of the class labels
 *
 * @author Kristina Toutanova
 * @author Jenny Finkel
 */
public class AccuracyStats<L> implements Scorer<L> {

  // Statistics copied from the PRCurve built by the most recent call to score().
  double confWeightedAccuracy;
  double accuracy;
  double optAccuracy;
  double optConfWeightedAccuracy;
  double logLikelihood;
  int[] accrecall;
  int[] optaccrecall;

  /** The label treated as the positive class when scoring. */
  L posLabel;

  /** File name prefix for saved accuracy arrays; nothing is saved when {@code null}. */
  String saveFile; // = null;

  /** Suffix appended to {@link #saveFile}; shared by all instances and bumped after each save. */
  static int saveIndex = 1;

  /**
   * Creates the stats object and immediately scores {@code classifier} on {@code data}.
   *
   * @param classifier the classifier to evaluate
   * @param data the dataset to evaluate on
   * @param posLabel the label treated as the positive class
   * @param <F> the feature type of the classifier and dataset
   */
  public <F> AccuracyStats(ProbabilisticClassifier<L, F> classifier, GeneralDataset<L, F> data, L posLabel) {
    this.posLabel = posLabel;
    score(classifier, data);
  }

  /**
   * Creates an unscored stats object; call {@link #score} before asking for a description.
   *
   * @param posLabel the label treated as the positive class
   * @param saveFile file name prefix used by {@link #getDescription(int)} to save the
   *                 accuracy arrays, or {@code null} to save nothing
   */
  public AccuracyStats(L posLabel, String saveFile) {
    this.posLabel = posLabel;
    this.saveFile = saveFile;
  }

  /**
   * Scores {@code classifier} on {@code data} and stores the resulting statistics
   * in this object's fields, replacing those of any earlier call.
   *
   * @param classifier the classifier to evaluate
   * @param data the dataset to evaluate on
   * @param <F> the feature type of the classifier and dataset
   * @return the accuracy reported by the {@link PRCurve} built from the data
   */
  public <F> double score(ProbabilisticClassifier<L, F> classifier, GeneralDataset<L, F> data) {

    // One (P(posLabel | datum), gold-is-positive) pair per datum.
    ArrayList<Pair<Double, Integer>> dataScores = new ArrayList<>();
    for (int i = 0; i < data.size(); i++) {
      Datum<L, F> d = data.getRVFDatum(i);
      Counter<L> scores = classifier.logProbabilityOf(d);
      int labelD = d.label().equals(posLabel) ? 1 : 0;
      // The classifier returns log-probabilities; exponentiate to get a probability.
      dataScores.add(new Pair<>(Math.exp(scores.getCount(posLabel)), labelD));
    }

    PRCurve prc = new PRCurve(dataScores);

    confWeightedAccuracy = prc.cwa();
    accuracy = prc.accuracy();
    optAccuracy = prc.optimalAccuracy();
    optConfWeightedAccuracy = prc.optimalCwa();
    logLikelihood = prc.logLikelihood();
    accrecall = prc.cwaArray();
    optaccrecall = prc.optimalCwaArray();

    return accuracy;
  }

  /**
   * Builds a multi-line, human-readable summary of the stored statistics.
   *
   * <p>Side effect: when a save file prefix was given, the two accuracy arrays are
   * written to {@code <saveFile>-<saveIndex>.accuracy} and
   * {@code <saveFile>-<saveIndex>.optimal_accuracy} in the format produced by
   * {@link #toStringArr(int[])}, and the static {@code saveIndex} is incremented.
   * In that case {@link #score} must have been called first, since the arrays are
   * otherwise {@code null}.
   *
   * @param numDigits maximum number of fraction digits used for the formatted values
   *                  (the log-likelihood is appended unformatted)
   * @return the description text
   */
  public String getDescription(int numDigits) {
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(numDigits);

    StringBuilder sb = new StringBuilder();
    sb.append("--- Accuracy Stats ---").append('\n');
    sb.append("accuracy: ").append(nf.format(accuracy)).append('\n');
    sb.append("optimal fn accuracy: ").append(nf.format(optAccuracy)).append('\n');
    sb.append("confidence weighted accuracy :").append(nf.format(confWeightedAccuracy)).append('\n');
    sb.append("optimal confidence weighted accuracy: ").append(nf.format(optConfWeightedAccuracy)).append('\n');
    sb.append("log-likelihood: ").append(logLikelihood).append('\n');
    if (saveFile != null) {
      String f = saveFile + '-' + saveIndex;
      sb.append("saving accuracy info to ").append(f).append(".accuracy\n");
      StringUtils.printToFile(f + ".accuracy", toStringArr(accrecall));
      sb.append("saving optimal accuracy info to ").append(f).append(".optimal_accuracy\n");
      StringUtils.printToFile(f + ".optimal_accuracy", toStringArr(optaccrecall));
      saveIndex++;
    }
    return sb.toString();
  }

  /**
   * Renders an array of cumulative counts as tab-separated text, one line per index
   * {@code i}: the coverage {@code (i + 1) / acc.length} and the accuracy
   * {@code acc[i] / (i + 1)}, both expressed as percentages truncated to four
   * decimal places.
   *
   * @param acc the cumulative counts; must not be {@code null}
   * @return one {@code coverage<TAB>accuracy} line per array element, each terminated
   *         by a newline; the empty string for an empty array
   */
  public static String toStringArr(int[] acc) {
    StringBuilder sb = new StringBuilder();
    int total = acc.length;
    for (int i = 0; i < total; i++) {
      double coverage = (i + 1) / (double) total;
      double accuracy = acc[i] / (double) (i + 1);
      sb.append(toTruncatedPercent(coverage));
      sb.append('\t');
      sb.append(toTruncatedPercent(accuracy));
      sb.append('\n');
    }
    return sb.toString();
  }

  /** Converts a fraction to a percentage truncated (not rounded) to four decimal places. */
  private static double toTruncatedPercent(double fraction) {
    // Scale to millionths and truncate, then divide in floating point: an int
    // division here would throw away the four decimals the scaling preserved.
    return ((int) (fraction * 1000000)) / 10000.0;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy