All downloads are free. The search and download functionality uses the official Maven repository.

edu.stanford.nlp.stats.MultiClassAccuracyStats Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.stats;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.util.BinaryHeapPriorityQueue;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.PriorityQueue;
import edu.stanford.nlp.util.StringUtils;

import java.text.NumberFormat;
import java.util.List;


/**
 * Accuracy statistics for a probabilistic multi-class classifier evaluated on a dataset.
 *
 * <p>Once {@link #initMC} has run, an instance holds the plain classification accuracy,
 * the sum of the log-probabilities the classifier assigned to the gold labels, and, for
 * every example, the score of the guessed label together with a flag recording whether
 * that guess matched the gold label. The value reported by {@link #score()} is selected
 * by the score type ({@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}).
 *
 * <p>Note that {@code saveFile} and {@code saveIndex} are {@code static}: passing a file
 * name to any constructor changes the save location for every instance of this class.
 *
 * @param <L> the label type of the classifier and dataset
 * @author Jenny Finkel
 */
public class MultiClassAccuracyStats<L> implements Scorer<L> {
  /** Score of the guessed label for each example, in the order produced by {@link #initMC}. */
  double[] scores; //sorted scores
  /** {@code isCorrect[i]} tells whether the guess whose score is {@code scores[i]} was right. */
  boolean[] isCorrect; // is the i-th example correct
  /** Sum over all examples of the log-probability given to the gold label. */
  double logLikelihood;
  /** {@code correct / total}; NaN when the dataset was empty (0.0 / 0.0). */
  double accuracy;
  /** Number of examples whose guessed label index equals the gold label index. */
  int correct = 0;
  /** Number of examples seen by the last call to {@link #initMC}. */
  int total = 0;
  /** Base name for the accuracy-coverage dump; {@code null} disables the dump. Shared by all instances. */
  static String saveFile = null;
  /** Suffix counter appended to {@link #saveFile}; incremented after every dump. Shared by all instances. */
  static int saveIndex = 1;

  /** Score type: {@link #score()} returns the classification accuracy. */
  public static final int USE_ACCURACY = 1;
  /** Score type: {@link #score()} returns the log-likelihood. */
  public static final int USE_LOGLIKELIHOOD = 2;

  private int scoreType = USE_ACCURACY;


  /** Creates an uninitialised instance that scores by accuracy. */
  public MultiClassAccuracyStats(){
  }

  /**
   * Creates an uninitialised instance with the given score type.
   *
   * @param scoreType {@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}
   */
  public MultiClassAccuracyStats(int scoreType){
    this.scoreType = scoreType;
  }

  /**
   * Creates an uninitialised instance that scores by accuracy and sets the shared save file.
   *
   * @param file new value of the static {@code saveFile}
   */
  public MultiClassAccuracyStats(String file){
    this(file, USE_ACCURACY);
  }

  /**
   * Creates an uninitialised instance with the given score type and sets the shared save file.
   *
   * @param file new value of the static {@code saveFile}
   * @param scoreType {@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}
   */
  public MultiClassAccuracyStats(String file, int scoreType){
    saveFile = file;
    this.scoreType = scoreType;
  }

  /**
   * Evaluates {@code classifier} on {@code data} immediately, scoring by accuracy.
   *
   * @param classifier classifier to evaluate
   * @param data dataset to evaluate on
   * @param file new value of the static {@code saveFile}
   * @param <F> the feature type
   */
  public <F> MultiClassAccuracyStats(ProbabilisticClassifier<L,F> classifier, GeneralDataset<L,F> data, String file) {
    this(classifier, data, file, USE_ACCURACY);
  }

  /**
   * Evaluates {@code classifier} on {@code data} immediately with the given score type.
   *
   * @param classifier classifier to evaluate
   * @param data dataset to evaluate on
   * @param file new value of the static {@code saveFile}
   * @param scoreType {@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}
   * @param <F> the feature type
   */
  public <F> MultiClassAccuracyStats(ProbabilisticClassifier<L,F> classifier, GeneralDataset<L,F> data, String file, int scoreType) {
    saveFile = file;
    this.scoreType = scoreType;
    initMC(classifier, data);
  }

  /**
   * Recomputes all statistics for {@code classifier} on {@code data} and returns the
   * configured score.
   *
   * @param classifier classifier to evaluate
   * @param data dataset to evaluate on
   * @param <F> the feature type
   * @return the value of {@link #score()} after re-initialisation
   */
  public <F> double score(ProbabilisticClassifier<L,F> classifier, GeneralDataset<L,F> data) {
    initMC(classifier, data);
    return score();
  }

  /**
   * Returns the statistic selected by the score type.
   *
   * @return the accuracy for {@link #USE_ACCURACY}, the log-likelihood for
   *         {@link #USE_LOGLIKELIHOOD}
   * @throws RuntimeException if the score type is neither of the two known constants
   */
  public double score() {
    if (scoreType == USE_ACCURACY) {
      return accuracy;
    } else if (scoreType == USE_LOGLIKELIHOOD) {
      return logLikelihood;
    } else {
      throw new RuntimeException("Unknown score type: "+scoreType);
    }
  }

  /**
   * Returns the number of examples recorded by the last {@link #initMC} call.
   * Must not be called before the statistics have been initialised ({@code scores} is
   * {@code null} until then).
   *
   * @return length of the {@code scores} array
   */
  public int numSamples() {
    return scores.length;
  }

  /**
   * Averages, over every recall level {@code r = 1..numSamples()}, the fraction of
   * correct guesses among the last {@code r} entries of the score arrays.
   *
   * @return the mean of {@code numCorrect(r) / r}; NaN when there are no samples
   */
  public double confidenceWeightedAccuracy() {
    int n = numSamples();
    double acc = 0;
    // Running count replaces a fresh numCorrect(recall) scan per recall level:
    // moving from recall-1 to recall adds exactly the entry at index n - recall.
    int correctSoFar = 0;
    for (int recall = 1; recall <= n; recall++) {
      if (isCorrect[n - recall]) {
        correctSoFar++;
      }
      acc += correctSoFar / (double) recall;
    }
    return acc / n;
  }

  /**
   * Runs {@code classifier} over every datum of {@code data} and rebuilds all statistics:
   * {@code total}, {@code correct}, {@code accuracy}, {@code logLikelihood}, and the
   * {@code scores} / {@code isCorrect} arrays.
   *
   * @param classifier classifier whose log-probabilities are evaluated
   * @param data dataset supplying the datums and the label index
   * @param <F> the feature type
   */
  public <F> void initMC(ProbabilisticClassifier<L,F> classifier, GeneralDataset<L,F> data) {
    PriorityQueue<Pair<Integer, Pair<Double, Boolean>>> q = new BinaryHeapPriorityQueue<>();
    total = 0;
    correct = 0;
    logLikelihood = 0.0;
    for (int i = 0; i < data.size(); i++) {
      Datum<L,F> d = data.getRVFDatum(i);
      Counter<L> labelScores = classifier.logProbabilityOf(d);
      L guess = Counters.argmax(labelScores);
      L correctLab = d.label();
      double guessScore = labelScores.getCount(guess);
      double correctScore = labelScores.getCount(correctLab);
      // Correctness is decided on label-index positions, not on the label objects themselves.
      int guessInd = data.labelIndex().indexOf(guess);
      int correctInd = data.labelIndex().indexOf(correctLab);
      boolean guessedRight = guessInd == correctInd;

      total++;
      if (guessedRight) {
        correct++;
      }
      logLikelihood += correctScore;
      // Priority is the negated guess score. NOTE(review): presumably toSortedList() yields
      // highest priority first, which would leave the most confident guesses at the END of
      // the arrays -- consistent with numCorrect reading from the tail; confirm against
      // BinaryHeapPriorityQueue.
      q.add(new Pair<>(Integer.valueOf(i), new Pair<>(Double.valueOf(guessScore), Boolean.valueOf(guessedRight))), -guessScore);
    }
    accuracy = (double) correct / (double) total;
    List<Pair<Integer, Pair<Double, Boolean>>> sorted = q.toSortedList();
    scores = new double[sorted.size()];
    isCorrect = new boolean[sorted.size()];

    for (int i = 0; i < sorted.size(); i++) {
      Pair<Double, Boolean> next = sorted.get(i).second();
      scores[i] = next.first().doubleValue();
      isCorrect[i] = next.second().booleanValue();
    }

  }

  /**
   * How many correct do we have if we return the most confident {@code recall} ones.
   * Counts the correct entries among the last {@code recall} positions of the arrays.
   *
   * @param recall number of trailing entries to inspect; values larger than
   *        {@link #numSamples()} index below zero and fail with
   *        {@link ArrayIndexOutOfBoundsException}
   * @return number of correct guesses among those entries (0 when {@code recall <= 0})
   */
  public int numCorrect(int recall) {
    int hits = 0;
    for (int j = scores.length - 1; j >= scores.length - recall; j--) {
      if (isCorrect[j]) {
        hits++;
      }
    }
    return hits;
  }


  /**
   * Returns, for every recall level, the number of correct guesses at that level.
   *
   * @return array where element {@code r - 1} equals {@code numCorrect(r)}
   */
  public int[] getAccCoverage() {
    int n = numSamples();
    int[] arr = new int[n];
    // Same running-count scheme as confidenceWeightedAccuracy(): linear instead of quadratic.
    int correctSoFar = 0;
    for (int recall = 1; recall <= n; recall++) {
      if (isCorrect[n - recall]) {
        correctSoFar++;
      }
      arr[recall - 1] = correctSoFar;
    }
    return arr;
  }

  /**
   * Builds a multi-line, human-readable summary of the statistics.
   *
   * <p>Side effect: when the static {@code saveFile} is non-null, the accuracy-coverage
   * array is also written to {@code <saveFile>-<saveIndex>.accuracy} and
   * {@code saveIndex} is incremented.
   *
   * @param numDigits maximum number of fraction digits used for the accuracy figures
   * @return the summary text
   */
  public String getDescription(int numDigits) {
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(numDigits);

    StringBuilder sb = new StringBuilder();
    double confWeightedAccuracy = confidenceWeightedAccuracy();
    sb.append("--- Accuracy Stats ---").append("\n");
    sb.append("accuracy: ").append(nf.format(accuracy)).append(" (").append(correct).append("/").append(total).append(")\n");
    sb.append("confidence weighted accuracy :").append(nf.format(confWeightedAccuracy)).append("\n");
    sb.append("log-likelihood: ").append(logLikelihood).append("\n");
    if (saveFile != null) {
      String f = saveFile + "-" + saveIndex;
      sb.append("saving accuracy info to ").append(f).append(".accuracy\n");
      StringUtils.printToFile(f + ".accuracy", AccuracyStats.toStringArr(getAccCoverage()));
      saveIndex++;
    }
    return sb.toString();
  }

  /**
   * Names this scorer together with the kind of score it reports.
   *
   * @return {@code "MultiClassAccuracyStats(<type>)"} where {@code <type>} is
   *         {@code classification_accuracy}, {@code log_likelihood} or {@code unknown}
   */
  @Override
  public String toString() {
    String accuracyType;
    switch (scoreType) {
      case USE_ACCURACY:
        accuracyType = "classification_accuracy";
        break;
      case USE_LOGLIKELIHOOD:
        accuracyType = "log_likelihood";
        break;
      default:
        accuracyType = "unknown";
        break;
    }
    // The numeric constants are intentionally not appended: concatenating them only
    // produced a meaningless digit suffix such as "112".
    return "MultiClassAccuracyStats(" + accuracyType + ")";
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy