All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.ie.machinereading.RelationExtractorResultsPrinter Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.ie.machinereading;

import java.text.DecimalFormat;
import java.util.*;
import java.io.*;

import edu.stanford.nlp.ie.machinereading.structure.AnnotationUtils;
import edu.stanford.nlp.ie.machinereading.structure.RelationMention;
import edu.stanford.nlp.ie.machinereading.structure.RelationMentionFactory;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;

public class RelationExtractorResultsPrinter extends ResultsPrinter {
  
  protected boolean createUnrelatedRelations;
  
  protected final RelationMentionFactory relationMentionFactory;
  
  public RelationExtractorResultsPrinter(RelationMentionFactory factory) {
    this(factory, true);
  }
  
  public RelationExtractorResultsPrinter() {
    this(new RelationMentionFactory(), true);
  }
  
  public RelationExtractorResultsPrinter(boolean createUnrelatedRelations) {
    this(new RelationMentionFactory(), createUnrelatedRelations);
  }
  
  public RelationExtractorResultsPrinter(RelationMentionFactory factory, boolean createUnrelatedRelations) {
    this.createUnrelatedRelations = createUnrelatedRelations;
    this.relationMentionFactory = factory;
  }
  
  private static final int MAX_LABEL_LENGTH = 31;

  @Override
  public void printResults(PrintWriter pw, 
      List goldStandard,
      List extractorOutput) {
    ResultsPrinter.align(goldStandard, extractorOutput);
    
    // the mention factory cannot be null here
    assert relationMentionFactory != null: "ERROR: RelationExtractorResultsPrinter.relationMentionFactory cannot be null in printResults!";
    
    // Count predicted-actual relation type pairs
    Counter> results = new ClassicCounter<>();
    ClassicCounter labelCount = new ClassicCounter<>();
    
    // TODO: assumes binary relations
    for (int goldSentenceIndex = 0; goldSentenceIndex < goldStandard.size(); goldSentenceIndex++) {
      for (RelationMention goldRelation : AnnotationUtils.getAllRelations(relationMentionFactory, goldStandard.get(goldSentenceIndex), createUnrelatedRelations)) {
        CoreMap extractorSentence = extractorOutput.get(goldSentenceIndex);
        List extractorRelations = AnnotationUtils.getRelations(relationMentionFactory, extractorSentence, goldRelation.getArg(0), goldRelation.getArg(1));
        labelCount.incrementCount(goldRelation.getType());
        for (RelationMention extractorRelation : extractorRelations) {
          results.incrementCount(new Pair<>(extractorRelation.getType(), goldRelation.getType()));
        }
      }
    }

    printResultsInternal(pw, results, labelCount);
  }
  
  private void printResultsInternal(PrintWriter pw, Counter> results, ClassicCounter labelCount) {
    ClassicCounter correct = new ClassicCounter<>();
    ClassicCounter predictionCount = new ClassicCounter<>();
    boolean countGoldLabels = false;
    if (labelCount == null) {
      labelCount = new ClassicCounter<>();
      countGoldLabels = true;
    }

    for (Pair predictedActual : results.keySet()) {
      String predicted = predictedActual.first;
      String actual = predictedActual.second;
      if (predicted.equals(actual)) {
        correct.incrementCount(actual, results.getCount(predictedActual));
      }
      predictionCount.incrementCount(predicted, results.getCount(predictedActual));
      if (countGoldLabels) {
        labelCount.incrementCount(actual, results.getCount(predictedActual));	
      }
    }

    DecimalFormat formatter = new DecimalFormat();
    formatter.setMaximumFractionDigits(1);
    formatter.setMinimumFractionDigits(1);

    double totalCount = 0;
    double totalCorrect = 0;
    double totalPredicted = 0;
    pw.println("Label\tCorrect\tPredict\tActual\tPrecn\tRecall\tF");
    List labels = new ArrayList<>(labelCount.keySet());
    Collections.sort(labels);
    for (String label : labels) {
      double numcorrect = correct.getCount(label);
      double predicted = predictionCount.getCount(label);
      double trueCount = labelCount.getCount(label);
      double precision = (predicted > 0) ? (numcorrect / predicted) : 0;
      double recall = numcorrect / trueCount;
      double f = (precision + recall > 0) ? 2 * precision * recall / (precision + recall) : 0.0;
      pw.println(StringUtils.padOrTrim(label, MAX_LABEL_LENGTH) + "\t" + numcorrect + "\t" + predicted + "\t" + trueCount + "\t"
          + formatter.format(precision * 100) + "\t" + formatter.format(100 * recall) + "\t"
          + formatter.format(100 * f));
      if (!RelationMention.isUnrelatedLabel(label)) {
        totalCount += trueCount;
        totalCorrect += numcorrect;
        totalPredicted += predicted;
      }
    }
    double precision = (totalPredicted > 0) ? (totalCorrect / totalPredicted) : 0;
    double recall = totalCorrect / totalCount;
    double f = (totalPredicted > 0 && totalCorrect > 0) ? 2 * precision * recall / (precision + recall) : 0.0;
    pw.println("Total\t" + totalCorrect + "\t" + totalPredicted + "\t" + totalCount + "\t"
        + formatter.format(100 * precision) + "\t" + formatter.format(100 * recall) + "\t" + formatter.format(100 * f));
  }

  public void printResultsUsingLabels(PrintWriter pw, 
      List goldStandard,
      List extractorOutput) {
    
    // Count predicted-actual relation type pairs
    Counter> results = new ClassicCounter<>();
    assert(goldStandard.size() == extractorOutput.size());
    for(int i = 0; i < goldStandard.size(); i ++) 
      results.incrementCount(new Pair<>(extractorOutput.get(i), goldStandard.get(i)));
    
    printResultsInternal(pw, results, null);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy