All downloads are FREE. Search and download functionality uses the official Maven repository.

edu.berkeley.nlp.parser.EnglishPennTreebankParseEvaluator Maven / Gradle / Ivy

Go to download

The Berkeley parser analyzes the grammatical structure of natural language using probabilistic context-free grammars (PCFGs).

The newest version!
package edu.berkeley.nlp.parser;

import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees;

import java.util.*;
import java.io.PrintWriter;
import java.io.StringReader;

/**
 * Evaluates precision and recall for English Penn Treebank parse trees.  NOTE: Unlike the standard evaluation, multiplicity over each span is ignored.  Also, punctuation is NOT currently deleted properly (approximate hack), and other normalizations (like ADVP ~ PRT) are NOT done.
 *
 * @author Dan Klein
 */
public class EnglishPennTreebankParseEvaluator {
	static class UnlabeledConstituent {
	    
	    int start;
	    int end;

	   

	    public int getStart() {
	      return start;
	    }

	    public int getEnd() {
	      return end;
	    }

	    public boolean equals(Object o) {
	      if (this == o) return true;
	      if (!(o instanceof UnlabeledConstituent)) return false;

	      final UnlabeledConstituent unlabeledConstituent = (UnlabeledConstituent) o;

	      if (end != unlabeledConstituent.end) return false;
	      if (start != unlabeledConstituent.start) return false;
	    

	      return true;
	    }

	    public int hashCode() {
	      int result;
	    
	      result = start;
	      result = 29 * result + end;
	      return result;
	    }

	    public String toString() {
	      return "["+start+","+end+"]";
	    }

	    public UnlabeledConstituent(int start, int end) {
	      
	      this.start = start;
	      this.end = end;
	    }
	  }

	abstract static class AbstractEval {

    protected String str = "";

    private int exact = 0;
    private int total = 0;

    private int correctEvents = 0;
    private int guessedEvents = 0;
    private int goldEvents = 0;

    abstract Set makeObjects(Tree tree);

    public double evaluate(Tree guess, Tree gold) {
      return evaluate(guess, gold, new PrintWriter(System.out, true));
    }

    public double evaluate(Tree guess, Tree gold, boolean b) {
      return evaluate(guess, gold, null);
    }
    /* evaluates precision and recall by calling makeObjects() to make a
     * set of structures for guess Tree and gold Tree, and compares them
     * with each other.  */
    public double evaluate(Tree guess, Tree gold, PrintWriter pw) {
      Set guessedSet = makeObjects(guess);
      Set goldSet = makeObjects(gold);
      Set correctSet = new HashSet();
      correctSet.addAll(goldSet);
      correctSet.retainAll(guessedSet);

      correctEvents += correctSet.size();
      guessedEvents += guessedSet.size();
      goldEvents += goldSet.size();

      int currentExact = 0;
      if (correctSet.size() == guessedSet.size() &&
          correctSet.size() == goldSet.size()) {
        exact++;
        currentExact = 1;
      }
      total++;

//      guess.pennPrint(pw);
//      gold.pennPrint(pw);
      double f1 = displayPRF(str+" [Current] ", correctSet.size(), guessedSet.size(), goldSet.size(), currentExact, 1, pw);
      return f1;

    }

    public double evaluateMultiple(List> guesses, List> golds,
    		PrintWriter pw) {
    	assert(guesses.size() == golds.size());
    	int correctCount = 0;
    	int guessedCount = 0;
    	int goldCount = 0;
    	for (int i=0; i guess = guesses.get(i);
    		Tree gold = golds.get(i);
            Set guessedSet = makeObjects(guess);
            Set goldSet = makeObjects(gold);
            Set correctSet = new HashSet();
            correctSet.addAll(goldSet);
            correctSet.retainAll(guessedSet);
            correctCount += correctSet.size();
            guessedCount += guessedSet.size();
            goldCount += goldSet.size();
    	}

        correctEvents += correctCount;
        guessedEvents += guessedCount;
        goldEvents += goldCount;

        int currentExact = 0;
        if (correctCount == guessedCount &&
            correctCount == goldCount) {
          exact++;
          currentExact = 1;
        }
        total++;

//        guess.pennPrint(pw);
//        gold.pennPrint(pw);
        double f1 = displayPRF(str+" [Current] ", correctCount, guessedCount, goldCount, currentExact, 1, pw);
        return f1;
        
    }
    public double[] massEvaluate(Tree guess, Tree[] goldTrees) {
      Set guessedSet = makeObjects(guess);
      double cEvents = 0;
      double guEvents = 0;
      double goEvents = 0;
      double exactM = 0, precision=0, recall=0, f1=0;
      	
      for (int treeI=0; treeI gold = goldTrees[treeI];
	      Set goldSet = makeObjects(gold);
	      Set correctSet = new HashSet();
	      correctSet.addAll(goldSet);
	      correctSet.retainAll(guessedSet);
	      cEvents = correctSet.size();
	      guEvents = guessedSet.size();
	      goEvents = goldSet.size();

	      double p = cEvents / guEvents;
	      double r = cEvents / goEvents;
		    double f = (p > 0.0 && r > 0.0 ? 2.0 / (1.0 / p + 1.0 / r) : 0.0);
		    
		    precision += p;
	      recall += r;
	      f1 += f;
	      
	      if (cEvents == guEvents && cEvents == goEvents) {
	        exactM++;
	      }
      }
      double ex = exactM/goldTrees.length;
      double[] results = {precision, recall, f1, ex};
      
	    return results;

    }
    
    private double displayPRF(String prefixStr, int correct, int guessed, int gold, int exact, int total, PrintWriter pw) {
      double precision = (guessed > 0 ? correct / (double) guessed : 1.0);
      double recall = (gold > 0 ? correct / (double) gold : 1.0);
      double f1 = (precision > 0.0 && recall > 0.0 ? 2.0 / (1.0 / precision + 1.0 / recall) : 0.0);

      double exactMatch = exact / (double) total;

      String displayStr = " P: " + ((int) (precision * 10000)) / 100.0 + " R: " + ((int) (recall * 10000)) / 100.0 + " F1: " + ((int) (f1 * 10000)) / 100.0 + " EX: "+((int) (exactMatch * 10000)) / 100.0 ;

      if (pw!=null) pw.println(prefixStr+displayStr);
      return f1;
    }
    
    

    public double display(boolean verbose) {
      return display(verbose, new PrintWriter(System.out, true));
    }

    public double display(boolean verbose, PrintWriter pw) {
     return displayPRF(str+" [Average] ", correctEvents, guessedEvents, goldEvents, exact, total, pw);
    }
  }

  static class LabeledConstituent {
    L label;
    int start;
    int end;

    public L getLabel() {
      return label;
    }

    public int getStart() {
      return start;
    }

    public int getEnd() {
      return end;
    }

    public boolean equals(Object o) {
      if (this == o) return true;
      if (!(o instanceof LabeledConstituent)) return false;

      final LabeledConstituent labeledConstituent = (LabeledConstituent) o;

      if (end != labeledConstituent.end) return false;
      if (start != labeledConstituent.start) return false;
      if (label != null ? !label.equals(labeledConstituent.label) : labeledConstituent.label != null) return false;

      return true;
    }

    public int hashCode() {
      int result;
      result = (label != null ? label.hashCode() : 0);
      result = 29 * result + start;
      result = 29 * result + end;
      return result;
    }

    public String toString() {
      return label+"["+start+","+end+"]";
    }

    public LabeledConstituent(L label, int start, int end) {
      this.label = label;
      this.start = start;
      this.end = end;
    }
  }
  
  public static class UnlabeledConstituentEval extends AbstractEval
  {

	public UnlabeledConstituentEval()
	{
		
		
	}

	@Override
	 Set makeObjects(Tree tree) {
	      Tree noLeafTree = LabeledConstituentEval.stripLeaves(tree);
	      Set set = new HashSet();
	      addConstituents(noLeafTree, set, 0);
	      return set;
	    }

	    private int addConstituents(Tree tree, Set set, int start) {
	    	if (tree==null)
	    		return 0;
	      if (tree.getYield().size() == 1) {
	        
	          return 1;
	      }
	      int end = start;
	      for (Tree child : tree.getChildren()) {
	        int childSpan = addConstituents(child, set, end);
	        end += childSpan;
	      }
	     
	     
	        set.add(new UnlabeledConstituent(start, end));
	      
	      return end - start;
	    }

	

	
	  
  }

  public static class LabeledConstituentEval extends AbstractEval {

    Set labelsToIgnore;
    Set punctuationTags;

    static  Tree stripLeaves(Tree tree) {
      if (tree.isLeaf())
        return null;
      if (tree.isPreTerminal())
        return new Tree(tree.getLabel());
      List> children = new ArrayList>();
      for (Tree child : tree.getChildren()) {
        children.add(stripLeaves(child));
      }
      return new Tree(tree.getLabel(), children);
    }

    Set makeObjects(Tree tree) {
      Tree noLeafTree = stripLeaves(tree);
      Set set = new HashSet();
      addConstituents(noLeafTree, set, 0);
      return set;
    }

    private int addConstituents(Tree tree, Set set, int start) {
    	if (tree==null)
    		return 0;
      if (tree.isLeaf()) {
        if (punctuationTags.contains(tree.getLabel()))
          return 0;
        else
          return 1;
      }
      int end = start;
      for (Tree child : tree.getChildren()) {
        int childSpan = addConstituents(child, set, end);
        end += childSpan;
      }
      L label = tree.getLabel();
      if (! labelsToIgnore.contains(label)) {
        set.add(new LabeledConstituent(label, start, end));
      }
      return end - start;
    }

	


    public LabeledConstituentEval(Set labelsToIgnore, Set punctuationTags) {
      this.labelsToIgnore = labelsToIgnore;
      this.punctuationTags = punctuationTags;
    }

	public int getHammingDistance(Tree guess, Tree gold) {
	      Set guessedSet = makeObjects(guess);
	      Set goldSet = makeObjects(gold);
	      Set correctSet = new HashSet();
	      correctSet.addAll(goldSet);
	      correctSet.retainAll(guessedSet);
	      return (guessedSet.size() - correctSet.size()) + (goldSet.size() - correctSet.size());
	}

  }

  public static void main(String[] args) throws Throwable {
    Tree goldTree = (new Trees.PennTreeReader(new StringReader("(ROOT (S (NP (DT the) (NN can)) (VP (VBD fell))))"))).next();
    Tree guessedTree = (new Trees.PennTreeReader(new StringReader("(ROOT (S (NP (DT the)) (VP (MB can) (VP (VBD fell)))))"))).next();
    LabeledConstituentEval eval = new LabeledConstituentEval(Collections.singleton("ROOT"), new HashSet());
    RuleEval rule_eval = new RuleEval(Collections.singleton("ROOT"), new HashSet());
    System.out.println("Gold tree:\n"+Trees.PennTreeRenderer.render(goldTree));
    System.out.println("Guessed tree:\n"+Trees.PennTreeRenderer.render(guessedTree));
    eval.evaluate(guessedTree, goldTree);
    eval.display(true);
    rule_eval.evaluate(guessedTree, goldTree);
    rule_eval.display(true);
  }
  
  public static class RuleEval extends AbstractEval {
    Set labelsToIgnore;
    Set punctuationTags;

    static  Tree stripLeaves(Tree tree) {
      if (tree.isLeaf())
        return null;
      if (tree.isPreTerminal())
        return new Tree(tree.getLabel());
      List> children = new ArrayList>();
      for (Tree child : tree.getChildren()) {
        children.add(stripLeaves(child));
      }
      return new Tree(tree.getLabel(), children);
    }

    Set makeObjects(Tree tree) {
      Tree noLeafTree = stripLeaves(tree);
      Set set = new HashSet();
      addConstituents(noLeafTree, set, 0);
      return set;
    }

    private int addConstituents(Tree tree, Set set, int start) {
    	if (tree==null)
    		return 0;
      if (tree.isLeaf()) {
/*        if (punctuationTags.contains(tree.getLabel()))
          return 0;
        else*/
          return 1;
      }
      int end = start, i=0;
      L lC=null, rC=null;
      for (Tree child : tree.getChildren()) {
        int childSpan = addConstituents(child, set, end);
        if (i==0) lC = child.getLabel();
        else /*i==1*/ rC = child.getLabel();
        i++;
        end += childSpan;
      }
      L label = tree.getLabel();
      if (! labelsToIgnore.contains(label)) {
        set.add(new RuleConstituent(label, lC, rC, start, end));
      }
      return end - start;
    }


    public RuleEval(Set labelsToIgnore, Set punctuationTags) {
      this.labelsToIgnore = labelsToIgnore;
      this.punctuationTags = punctuationTags;
    }

  }

  
  static class RuleConstituent {
    L label, lChild, rChild;
    int start;
    int end;

    public L getLabel() {
      return label;
    }

    public int getStart() {
      return start;
    }

    public int getEnd() {
      return end;
    }

    public boolean equals(Object o) {
      if (this == o) return true;
      if (!(o instanceof RuleConstituent)) return false;

      final RuleConstituent labeledConstituent = (RuleConstituent) o;

      if (end != labeledConstituent.end) return false;
      if (start != labeledConstituent.start) return false;
      if (label != null ? !label.equals(labeledConstituent.label) : labeledConstituent.label != null) return false;
      if (lChild != null ? !lChild.equals(labeledConstituent.lChild) : labeledConstituent.lChild != null) return false;
      if (rChild != null ? !rChild.equals(labeledConstituent.rChild) : labeledConstituent.rChild != null) return false;

      return true;
    }

    public int hashCode() {
      int result;
      result = (label != null ? label.hashCode() : 0) + 17 * (lChild != null ? lChild.hashCode() : 0) - 7*(rChild != null ? rChild.hashCode() : 0);
      result = 29 * result + start;
      result = 29 * result + end;
      return result;
    }

    public String toString() {
    	String rChildStr = (rChild==null) ? "" : rChild.toString();
      return label+"->"+lChild+" "+rChildStr+"["+start+","+end+"]";
    }

    public RuleConstituent(L label, L lChild, L rChild, int start, int end) {
      this.label = label;
      this.lChild = lChild;
      this.rChild = rChild;
      this.start = start;
      this.end = end;
    }
  }

}