All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.stats.MultiClassPrecisionRecallExtendedStats Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.stats;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.Datum;
import java.util.function.Function;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Triple;

import java.io.BufferedReader;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Extension of MultiClassPrecisionRecallStats that also computes accuracy
 * @author Angel Chang
 */
public class MultiClassPrecisionRecallExtendedStats extends MultiClassPrecisionRecallStats {
  protected IntCounter correctGuesses;
  protected IntCounter foundCorrect;
  protected IntCounter foundGuessed;
  protected int tokensCount = 0;
  protected int tokensCorrect = 0;
  protected int noLabel = 0;

  protected Function stringConverter;

  public  MultiClassPrecisionRecallExtendedStats(Classifier classifier, GeneralDataset data, L negLabel)
  {
    super(classifier, data, negLabel);
  }

  public MultiClassPrecisionRecallExtendedStats(L negLabel)
  {
    super(negLabel);
  }

  public MultiClassPrecisionRecallExtendedStats(Index dataLabelIndex, L negLabel)
  {
    this(negLabel);
    setLabelIndex(dataLabelIndex);
  }

  public void setLabelIndex(Index dataLabelIndex) {
    labelIndex = dataLabelIndex;
    negIndex = labelIndex.indexOf(negLabel);
  }

  public  double score(Classifier classifier, GeneralDataset data) {
    labelIndex = new HashIndex<>();
    labelIndex.addAll(classifier.labels());
    labelIndex.addAll(data.labelIndex.objectsList());
    clearCounts();
    int[] labelsArr = data.getLabelsArray();
    for (int i = 0; i < data.size(); i++) {
      Datum d = data.getRVFDatum(i);
      L guess = classifier.classOf(d);
      addGuess(guess, labelIndex.get(labelsArr[i]));
    }
    finalizeCounts();

    return getFMeasure();
  }

  /**
   * Returns the score (F1) for the given list of guesses
   * @param guesses - Guesses by classifier
   * @param trueLabels - Gold labels to compare guesses against
   * @param dataLabelIndex - Index of labels
   * @return F1 score
   */
  public double score(List guesses, List trueLabels, Index dataLabelIndex) {
    setLabelIndex(dataLabelIndex);
    return score(guesses, trueLabels);
  }

  /**
   * Returns the score (F1) for the given list of guesses
   * @param guesses - Guesses by classifier
   * @param trueLabels - Gold labels to compare guesses against
   * @return F1 score
   */
  public double score(List guesses, List trueLabels) {
    clearCounts();
    addGuesses(guesses, trueLabels);
    finalizeCounts();
    return getFMeasure();
  }

  public double score()
  {
    finalizeCounts();
    return getFMeasure();    
  }

  public void clearCounts()
  {
    if (foundCorrect != null) {
      foundCorrect.clear();
    } else {
      foundCorrect = new IntCounter<>();
    }
    if (foundGuessed != null) {
      foundGuessed.clear();
    } else {
      foundGuessed = new IntCounter<>();
    }
    if (correctGuesses != null) {
      correctGuesses.clear();
    } else {
      correctGuesses = new IntCounter<>();
    }
    if (tpCount != null) {
      Arrays.fill(tpCount, 0);
    }
    if (fnCount != null) {
      Arrays.fill(fnCount, 0);
    }
    if (fpCount != null) {
      Arrays.fill(fpCount, 0);
    }
    tokensCount = 0;
    tokensCorrect = 0;
  }

  protected void finalizeCounts()
  {
    negIndex = labelIndex.indexOf(negLabel);
    int numClasses = labelIndex.size();
    if (tpCount == null || tpCount.length != numClasses) {
      tpCount = new int[numClasses];
    }
    if (fpCount == null || fpCount.length != numClasses) {
      fpCount = new int[numClasses];
    }
    if (fnCount == null || fnCount.length != numClasses) {
      fnCount = new int[numClasses];
    }
    for (int i = 0; i < numClasses; i++) {
      L label = labelIndex.get(i);
      tpCount[i] = correctGuesses.getIntCount(label);
      fnCount[i] = foundCorrect.getIntCount(label) - tpCount[i];
      fpCount[i] = foundGuessed.getIntCount(label) - tpCount[i];
    }
  }

  protected void markBoundary()
  {
  }

  protected void addGuess(L guess, L label)
  {
    addGuess(guess, label, true);
  }
  
  protected void addGuess(L guess, L label, boolean addUnknownLabels)
  {
    if (label == null) {
        noLabel++;
        return;
    }
    if (addUnknownLabels) {
      if (labelIndex == null) {
        labelIndex = new HashIndex<>();
      }
      labelIndex.add(guess);
      labelIndex.add(label);
    }
    if (guess.equals(label)) {
      correctGuesses.incrementCount(label);
      tokensCorrect++;
    }

    if (!guess.equals(negLabel)) {
      foundGuessed.incrementCount(guess);
    }

    if (!label.equals(negLabel)) {
      foundCorrect.incrementCount(label);
    }
    tokensCount++;
  }

  public void addGuesses(List guesses, List trueLabels)
  {
    for (int i=0; i < guesses.size(); ++i)
    {
      L guess = guesses.get(i);
      L label = trueLabels.get(i);
      addGuess(guess, label);
    }
  }

  /**
   * Return overall number of correct answers
   */
  public int getCorrect()
  {
    return correctGuesses.totalIntCount();
  }

  public int getCorrect(L label)
  {
    return correctGuesses.getIntCount(label);
  }

  public int getRetrieved(L label)
  {
    return foundGuessed.getIntCount(label);
  }

  public int getRetrieved()
  {
    return foundGuessed.totalIntCount();
  }

  public int getRelevant(L label)
  {
    return foundCorrect.getIntCount(label);
  }

  public int getRelevant()
  {
    return foundCorrect.totalIntCount();
  }

  /**
   * Return overall per token accuracy
   */
  public Triple getAccuracyInfo()
  {
    int totalCorrect = tokensCorrect;
    int totalWrong = tokensCount - tokensCorrect;
    return new Triple<>((((double) totalCorrect) / tokensCount),
            totalCorrect, totalWrong);
  }

  public double getAccuracy() {
    return getAccuracyInfo().first();
  }

  /**
   * Returns a String summarizing overall accuracy that will print nicely.
   */
  public String getAccuracyDescription(int numDigits) {
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(numDigits);
    Triple accu = getAccuracyInfo();
    return nf.format(accu.first()) + "  (" + accu.second() + "/" + (accu.second() + accu.third()) + ")";
  }

  public double score(String filename, String delimiter) throws IOException {
    return score(filename, delimiter, null);
  }

  public double score(String filename, String delimiter, String boundary) throws IOException {
    return score(IOUtils.getBufferedFileReader(filename), delimiter, boundary);
  }

  public double score(BufferedReader br, String delimiter) throws IOException
  {
    return score(br, delimiter, null);
  }
  
  public double score(BufferedReader br, String delimiter, String boundary) throws IOException
  {
    int TOKEN_INDEX = 0;
    int ANSWER_INDEX = 1;
    int GUESS_INDEX = 2;
    String line;
    Pattern delimPattern = Pattern.compile(delimiter);
    clearCounts();
    while ((line = br.readLine()) != null) {
      line = line.trim();
      if (line.length() > 0) {
        String[] fields = delimPattern.split(line);
        if (boundary != null && boundary.equals(fields[TOKEN_INDEX])) {
          markBoundary();
        } else {
          L answer = stringConverter.apply(fields[ANSWER_INDEX]);
          L guess = stringConverter.apply(fields[GUESS_INDEX]);
          addGuess(guess, answer);
        }
      } else {
        markBoundary();
      }
    }
    finalizeCounts();
    return getFMeasure();
  }

  public List getLabels() {
    return labelIndex.objectsList();
  }

  public String getConllEvalString()
  {
    return getConllEvalString(true);
  }

  public String getConllEvalString(boolean ignoreNegLabel)
  {
    List labels = getLabels();
    if (labels.size() > 1 && labels.get(0) instanceof Comparable) {
      List sortedLabels = (List) labels;
      Collections.sort(sortedLabels);
    }
    return getConllEvalString(labels, ignoreNegLabel);
  }

  private String getConllEvalString(List orderedLabels, boolean ignoreNegLabel)
  {
    StringBuilder sb = new StringBuilder();
    int correctPhrases = getCorrect() - getCorrect(negLabel);
    Triple accuracyInfo = getAccuracyInfo();
    int totalCount = accuracyInfo.second() + accuracyInfo.third();
    sb.append("processed " + totalCount + " tokens with " + getRelevant() + " phrases; ");
    sb.append("found: " + getRetrieved() + " phrases; correct: " + correctPhrases + "\n");

    Formatter formatter = new Formatter(sb, Locale.US);
    formatter.format("accuracy: %6.2f%%; ", accuracyInfo.first() * 100);
    formatter.format("precision: %6.2f%%; ", getPrecision() * 100);
    formatter.format("recall: %6.2f%%; ", getRecall() * 100);
    formatter.format("FB1: %6.2f\n", getFMeasure() * 100);
    for (L label: orderedLabels) {
      if (ignoreNegLabel && label.equals(negLabel)) { continue; }
      formatter.format("%17s: ", label);
      formatter.format("precision: %6.2f%%; ", getPrecision(label) * 100);
      formatter.format("recall: %6.2f%%; ", getRecall(label) * 100);
      formatter.format("FB1: %6.2f  %d\n", getFMeasure(label) * 100, getRetrieved(label));
    }
    return sb.toString();
  }

  public static class StringStringConverter implements Function
  {
    public String apply(String str) { return str; }
  }

  public static class MultiClassStringLabelStats extends MultiClassPrecisionRecallExtendedStats
  {
    public  MultiClassStringLabelStats(Classifier classifier, GeneralDataset data, String negLabel)
    {
      super(classifier, data, negLabel);
      stringConverter = new StringStringConverter();
    }

    public MultiClassStringLabelStats(String negLabel)
    {
      super(negLabel);
      stringConverter = new StringStringConverter();
    }

    public MultiClassStringLabelStats(Index dataLabelIndex, String negLabel)
    {
      this(negLabel);
      setLabelIndex(dataLabelIndex);
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy