All downloads are FREE. The search and download functionality uses the official Maven repository.

edu.stanford.nlp.parser.metrics.AbstractEval Maven / Gradle / Ivy

package edu.stanford.nlp.parser.metrics;

import java.util.*;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.text.DecimalFormat;

import edu.stanford.nlp.parser.KBestViterbiParser;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;

/**
 * A framework for Set-based precision/recall/F1 evaluation.
 *
 * @author Dan Klein
 */
public abstract class AbstractEval implements Eval {

  /** Debug switch; when true, {@code precision} and {@code evaluate} print diagnostics to {@code System.err}. */
  private static final boolean DEBUG = false;

  /** Label printed in front of the F1 line in {@code evaluate} and of the summary line in {@code display}. */
  protected final String str;
  /** When true, {@code evaluate} also prints the cumulative "sent ave" and "evalb" figures. */
  protected final boolean runningAverages;

  // Per-tree ("sent ave") accumulators: every evaluated tree adds its own score times its weight.
  /** Sum over evaluated trees of (per-tree precision * weight). */
  private double precision = 0.0;
  /** Sum over evaluated trees of (per-tree recall * weight). */
  private double recall = 0.0;
  /** Sum over evaluated trees of (per-tree F1 * weight). */
  private double f1 = 0.0;
  /** Sum of the weights of all evaluated trees; denominator of the per-tree averages. */
  protected double num = 0.0;
  /** Accumulator for trees whose F1 exceeded 0.9999, i.e. trees counted as exact matches. */
  private double exact = 0.0;

  // Pooled ("evalb") accumulators: objects are counted across all trees before dividing.
  /** Sum of (guess-set size * per-tree precision * weight), i.e. the weighted count of matched guess objects. */
  private double precision2 = 0.0;
  /** Sum of (gold-set size * per-tree recall * weight), i.e. the weighted count of matched gold objects. */
  private double recall2 = 0.0;
  /** Sum of (guess-set size * weight); denominator of the pooled precision. */
  private double pnum2 = 0.0;
  /** Sum of (gold-set size * weight); denominator of the pooled recall. */
  private double rnum2 = 0.0;

  /** F1 (a fraction between 0 and 1) computed by the most recent {@code evaluate} call. */
  protected double curF1 = 0.0;

  /**
   * Creates an evaluator with running averages switched on.
   * Delegates to {@link #AbstractEval(boolean)} with {@code true}.
   */
  public AbstractEval() {
    this(true);
  }

  /**
   * Creates an evaluator with an empty label.
   *
   * @param runningAverages whether {@code evaluate} should print cumulative figures
   */
  public AbstractEval(boolean runningAverages) {
    this("", runningAverages);
  }

  /**
   * Creates a labelled evaluator with running averages switched on.
   *
   * @param str label printed in front of this evaluator's output lines
   */
  public AbstractEval(String str) {
    this(str, true);
  }

  /**
   * Creates an evaluator; all other constructors delegate here.
   *
   * @param str             label printed in front of this evaluator's output lines
   * @param runningAverages whether {@code evaluate} should print cumulative figures
   */
  public AbstractEval(String str, boolean runningAverages) {
    this.str = str;
    this.runningAverages = runningAverages;
  }

  /**
   * Returns the per-tree average F1: the accumulated (weighted) per-tree F1
   * divided by the accumulated weight.
   *
   * @return the average as a fraction; NaN when nothing has been evaluated (0.0 / 0.0)
   */
  public double getSentAveF1() {
    final double sentenceAverage = f1 / num;
    return sentenceAverage;
  }

  /**
   * Returns the pooled ("evalb") F1: the harmonic mean of pooled precision
   * ({@code precision2 / pnum2}) and pooled recall ({@code recall2 / rnum2}),
   * computed from their reciprocals.
   *
   * @return the pooled F1 as a fraction
   */
  public double getEvalbF1() {
    final double inverseRecall = rnum2 / recall2;
    final double inversePrecision = pnum2 / precision2;
    return 2.0 / (inverseRecall + inversePrecision);
  }

  /**
   * Returns the F1 computed by the most recent call to {@link #evaluate},
   * scaled to a percentage.
   *
   * @return {@code curF1} multiplied by 100
   */
  public double getLastF1() {
    final double lastF1Percent = curF1 * 100.0;
    return lastF1Percent;
  }

  /**
   * Returns the pooled ("evalb") F1 as a percentage.
   *
   * @return {@link #getEvalbF1()} multiplied by 100
   */
  public double getEvalbF1Percent() {
    final double fraction = getEvalbF1();
    return fraction * 100.0;
  }

  /**
   * Returns the exact-match rate: the exact-match accumulator divided by the
   * accumulated weight.
   *
   * @return the rate as a fraction; NaN when nothing has been evaluated (0.0 / 0.0)
   */
  public double getExact() {
    final double exactRate = exact / num;
    return exactRate;
  }

  /**
   * Returns the exact-match rate as a percentage.
   *
   * @return {@link #getExact()} multiplied by 100
   */
  public double getExactPercent() {
    final double fraction = getExact();
    return fraction * 100.0;
  }

  /**
   * Returns the accumulated weight of all evaluated trees, narrowed to an
   * {@code int} (the fractional part is discarded by the cast).
   *
   * @return {@code num} cast to {@code int}
   */
  public int getNum() {
    final int truncated = (int) num;
    return truncated;
  }

  /**
   * Computes the fraction of the elements of {@code s1} that are also members of {@code s2}.
   * Passing the guess set first gives precision; swapping the arguments gives recall.
   * Membership is decided solely by {@code s2.contains}; there is currently no way to pass
   * in a comparator.
   *
   * @param s1 the set whose elements are looked up
   * @param s2 the set they are looked up in
   * @return (number of elements of {@code s1} found in {@code s2}) / (size of {@code s1}),
   *         or 0.0 when {@code s1} is empty
   */
  protected static double precision(Set<?> s1, Set<?> s2) {
    double total = 0.0;
    double matched = 0.0;
    for (Object candidate : s1) {
      // Look the element up once and reuse the answer for counting and for the debug trace.
      final boolean found = s2.contains(candidate);
      if (found) {
        matched += 1.0;
      }
      if (DEBUG) {
        if (found) {
          System.err.println("Eval Found: " + candidate);
        } else {
          System.err.println("Eval Failed to find: " + candidate);
        }
      }
      total += 1.0;
    }
    if (DEBUG) {
      System.err.println("Matched " + matched + " of " + total);
    }
    // An empty s1 has no defined ratio; report 0.0 rather than dividing by zero.
    return (total > 0.0 ? matched / total : 0.0);
  }

  /**
   * Builds the set of objects that stands for {@code tree} during scoring.
   * {@link #evaluate(Tree, Tree, PrintWriter, double)} calls this once for the guess tree and
   * once for the gold tree and compares the two resulting sets by membership.
   *
   * @param tree the tree to decompose
   * @return the set of objects extracted from {@code tree}
   */
  protected abstract Set makeObjects(Tree tree);

  /**
   * Scores {@code guess} against {@code gold} and prints the report to standard output
   * through a freshly created auto-flushing {@code PrintWriter}.
   *
   * @param guess the tree to be scored
   * @param gold  the reference tree
   */
  public void evaluate(Tree guess, Tree gold) {
    final PrintWriter stdout = new PrintWriter(System.out, true);
    evaluate(guess, gold, stdout);
  }

  /**
   * Evaluates precision and recall by calling {@code makeObjects()} to make a set of
   * structures for the guess tree and for the gold tree, and comparing the two sets with
   * each other. The pair is scored with a weight of 1.0.
   *
   * @param guess the tree to be scored
   * @param gold  the reference tree
   * @param pw    writer handed on to the weighted overload
   */
  public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
    final double unitWeight = 1.0;
    evaluate(guess, gold, pw, unitWeight);
  }

  /**
   * Scores one guess tree against its gold tree, folds the result into the running totals
   * and, if a writer is supplied, prints a per-tree report.
   * <p>
   * Precision is the fraction of guess objects found among the gold objects; recall is the
   * fraction of gold objects found among the guess objects; F1 is their harmonic mean, or
   * 0.0 when either is zero. A tree whose F1 exceeds 0.9999 counts as an exact match.
   *
   * @param guess  the tree to be scored
   * @param gold   the reference tree
   * @param pw     destination of the per-tree report; {@code null} suppresses all printing
   * @param weight factor applied to every contribution of this pair to the running totals
   */
  public void evaluate(Tree guess, Tree gold, PrintWriter pw, double weight) {
    if (DEBUG) {
      System.err.println("Evaluating gold tree:");
      gold.pennPrint(System.err);
      System.err.println("and guess tree");
      guess.pennPrint(System.err);
    }
    final Set<?> guessObjects = makeObjects(guess);
    final Set<?> goldObjects = makeObjects(gold);
    final double curPrecision = precision(guessObjects, goldObjects);
    // Recall is precision with the roles of the two sets swapped.
    final double curRecall = precision(goldObjects, guessObjects);
    curF1 = (curPrecision > 0.0 && curRecall > 0.0 ? 2.0 / (1.0 / curPrecision + 1.0 / curRecall) : 0.0);

    // Per-tree ("sent ave") totals.
    precision += curPrecision * weight;
    recall += curRecall * weight;
    f1 += curF1 * weight;
    num += weight;

    // Pooled ("evalb") totals: set size times ratio recovers the number of matched objects.
    precision2 += guessObjects.size() * curPrecision * weight;
    pnum2 += guessObjects.size() * weight;

    recall2 += goldObjects.size() * curRecall * weight;
    rnum2 += goldObjects.size() * weight;

    if (curF1 > 0.9999) {
      // The exact-match rate is reported as exact / num and num grows by weight, so the
      // numerator has to grow by weight as well; adding a flat 1.0 skews the rate for any
      // weight other than 1.0.
      exact += weight;
    }

    if (pw != null) {
      pw.print(" P: " + truncateToHundredths(curPrecision * 10000));
      if (runningAverages) {
        // The line break after the precision figures is emitted only in this branch;
        // without running averages P and R end up on the same output line.
        pw.println(" (sent ave " + truncateToHundredths(precision * 10000 / num) + ") (evalb " + truncateToHundredths(precision2 * 10000 / pnum2) + ")");
      }
      pw.print(" R: " + truncateToHundredths(curRecall * 10000));
      if (runningAverages) {
        pw.print(" (sent ave " + truncateToHundredths(recall * 10000 / num) + ") (evalb " + truncateToHundredths(recall2 * 10000 / rnum2) + ")");
      }
      pw.println();
      pw.print(str + " F1: " + truncateToHundredths(curF1 * 10000));
      if (runningAverages) {
        final double cF1 = 2.0 / (rnum2 / recall2 + pnum2 / precision2);
        pw.print(" (sent ave " + truncateToHundredths(10000 * f1 / num) + ", evalb " + truncateToHundredths(10000 * cF1) + ")   Exact: " + truncateToHundredths(10000 * exact / num));
      }
      pw.println(" N: " + num);
    }
  }

  /**
   * Turns a value that has already been multiplied by 10000 into a percentage with two
   * decimals by truncating (not rounding) to an {@code int} and dividing by 100.
   */
  private static double truncateToHundredths(double scaledBy10000) {
    return ((int) scaledBy10000) / 100.0;
  }

  /*
  private Counter goodwords = new Counter();
  private Counter badwords = new Counter();

  public void printGoodBad() {
    System.out.println("Printing bad categories");
    for (Object key : Counters.keysAbove(badwords, 5.0)) {
      System.out.println("In badwords 5 times: " + key);
      double numb = badwords.getCount(key);
      double numg = goodwords.getCount(key);
      if (numb / (numb + numg) > 0.1) {
        System.out.println("Bad word!  " + key + " (" +
                           (numb / (numb + numg)) + " bad)");
        // EncodingPrintWriter.out.println("Bad word!  " + key + " (" +
        //                 (numb / (numb + numg)) + " bad)",
        //                              "GB18030");
      }
    }
  }
  */

  /**
   * Prints the summary to standard output through a freshly created auto-flushing
   * {@code PrintWriter}.
   *
   * @param verbose passed on unchanged to {@link #display(boolean, PrintWriter)}
   */
  public void display(boolean verbose) {
    final PrintWriter stdout = new PrintWriter(System.out, true);
    display(verbose, stdout);
  }

  /**
   * Prints a one-line summary of the pooled ("evalb") scores: labelled precision (LP),
   * labelled recall (LR), their harmonic mean (F1), the exact-match rate and the number of
   * evaluated trees. Percentages are truncated, not rounded, to two decimals.
   *
   * @param verbose not used by this implementation
   * @param pw      destination of the summary; {@code null} suppresses the output, matching
   *                how {@code evaluate} treats a {@code null} writer
   */
  public void display(boolean verbose, PrintWriter pw) {
    if (pw == null) {
      return;
    }
    final double prec = precision2 / pnum2;
    final double rec = recall2 / rnum2;
    final double f = 2.0 / (1.0 / prec + 1.0 / rec);
    // With nothing evaluated the ratios are NaN; the int cast turns NaN into 0, so the
    // summary then shows 0.0 for every figure.
    pw.println(str + " summary evalb: LP: " + ((int) (10000.0 * prec)) / 100.0 + " LR: " + ((int) (10000.0 * rec)) / 100.0 + " F1: " + ((int) (10000.0 * f)) / 100.0 + " Exact: " + ((int) (10000.0 * exact / num)) / 100.0 + " N: " + getNum());
  }


  /**
   * Counts which local rules (a node label followed by the labels of its children) appear
   * in one tree of an evaluated pair but not in the other.
   */
  public static class RuleErrorEval extends AbstractEval {

    /** Rules seen in the first tree of a pair but not in the second. */
    private final ClassicCounter<String> over = new ClassicCounter<>();
    /** Rules seen in the second tree of a pair but not in the first. */
    private final ClassicCounter<String> under = new ClassicCounter<>();

    /**
     * Renders the local rule rooted at {@code tree} as {@code "LABEL -> CHILD1 CHILD2 ..."}.
     *
     * @param tree the node to render
     * @return the rule string, or the empty string when {@code tree} is a leaf
     */
    protected static String localize(Tree tree) {
      if (tree.isLeaf()) {
        return "";
      }
      StringBuilder sb = new StringBuilder();
      sb.append(tree.label());
      sb.append(" ->");
      for (Tree child : tree.children()) {
        sb.append(' ');
        sb.append(child.label());
      }
      return sb.toString();
    }

    /**
     * Collects the distinct rule strings of every subtree of {@code tree}.
     *
     * @param tree the tree to decompose
     * @return the set of rule strings produced by {@link #localize(Tree)}
     */
    @Override
    protected Set<String> makeObjects(Tree tree) {
      Set<String> localTrees = Generics.newHashSet();
      for (Tree st : tree.subTreeList()) {
        localTrees.add(localize(st));
      }
      return localTrees;
    }

    /**
     * Records the rules that occur in only one of the two trees. Nothing is printed.
     *
     * @param t1 tree whose unmatched rules are counted as overproposed
     * @param t2 tree whose unmatched rules are counted as underproposed
     * @param pw not used by this implementation
     */
    @Override
    public void evaluate(Tree t1, Tree t2, PrintWriter pw) {
      Set<String> s1 = makeObjects(t1);
      Set<String> s2 = makeObjects(t2);
      for (String o1 : s1) {
        if (!s2.contains(o1)) {
          over.incrementCount(o1);
        }
      }
      for (String o2 : s2) {
        if (!s1.contains(o2)) {
          under.incrementCount(o2);
        }
      }
    }

    /**
     * Prints at most {@code num} keys of {@code c}, ordered by
     * {@code Counters.toComparatorDescending(c)}, each followed by its count.
     */
    private static <T> void display(ClassicCounter<T> c, int num, PrintWriter pw) {
      List<T> rules = new ArrayList<>(c.keySet());
      Collections.sort(rules, Counters.toComparatorDescending(c));
      // Never read past the end of the list when fewer than num keys exist.
      final int limit = Math.min(num, rules.size());
      for (int i = 0; i < limit; i++) {
        T rule = rules.get(i);
        pw.println(rule + " " + c.getCount(rule));
      }
    }

    /**
     * Prints the most frequent underproposed rules followed by the most frequent
     * overproposed rules.
     *
     * @param verbose when true up to 100 rules are listed per section, otherwise up to 10
     * @param pw      destination of the listing
     */
    @Override
    public void display(boolean verbose, PrintWriter pw) {
      final int howMany = (verbose ? 100 : 10);
      pw.println("Most frequently underproposed rules:");
      display(under, howMany, pw);
      pw.println("Most frequently overproposed rules:");
      display(over, howMany, pw);
    }

    /**
     * @param str label passed on to the {@code AbstractEval} constructor
     */
    public RuleErrorEval(String str) {
      super(str);
    }

  } // end class RuleErrorEval


  /** This class counts which categories are over and underproposed in trees.
   *  Categories are compared as multisets, so a category that occurs more often in one
   *  tree of a pair than in the other is counted once per surplus occurrence.
   */
  public static class CatErrorEval extends AbstractEval {

    /** Categories with surplus occurrences in the first tree of a pair. */
    private final ClassicCounter<String> over = new ClassicCounter<>();
    /** Categories with surplus occurrences in the second tree of a pair. */
    private final ClassicCounter<String> under = new ClassicCounter<>();

    /** Unused. Fake satisfying the abstract class. */
    @Override
    protected Set<?> makeObjects(Tree tree) {
      return null;
    }

    /**
     * Lists the value of every subtree of {@code tree}, duplicates included.
     *
     * @param tree the tree to decompose
     * @return a new mutable list with one entry per subtree
     */
    private static List<String> myMakeObjects(Tree tree) {
      List<String> cats = new ArrayList<>();
      for (Tree st : tree.subTreeList()) {
        cats.add(st.value());
      }
      return cats;
    }

    /**
     * Records the categories that occur more often in one tree than in the other.
     * Nothing is printed.
     *
     * @param t1 tree whose surplus categories are counted as overproposed
     * @param t2 tree whose surplus categories are counted as underproposed
     * @param pw not used by this implementation
     */
    @Override
    public void evaluate(Tree t1, Tree t2, PrintWriter pw) {
      List<String> s1 = myMakeObjects(t1);
      List<String> s2 = myMakeObjects(t2);
      List<String> del2 = new ArrayList<>(s2);
      // we delete out as we find them so we can score correctly a cat with
      // a certain cardinality in a tree.
      for (String o1 : s1) {
        if (!del2.remove(o1)) {
          over.incrementCount(o1);
        }
      }
      // s1 is consumed here; that is safe because the loop iterates over s2.
      for (String o2 : s2) {
        if (!s1.remove(o2)) {
          under.incrementCount(o2);
        }
      }
    }

    /**
     * Prints every key of {@code c}, ordered by {@code Counters.toComparatorDescending(c)},
     * each followed by its count.
     */
    private static <T> void display(ClassicCounter<T> c, PrintWriter pw) {
      List<T> cats = new ArrayList<>(c.keySet());
      Collections.sort(cats, Counters.toComparatorDescending(c));
      for (T ob : cats) {
        pw.println(ob + " " + c.getCount(ob));
      }
    }

    /**
     * Prints all underproposed categories followed by all overproposed categories.
     *
     * @param verbose not used by this implementation
     * @param pw      destination of the listing
     */
    @Override
    public void display(boolean verbose, PrintWriter pw) {
      pw.println("Most frequently underproposed categories:");
      display(under, pw);
      pw.println("Most frequently overproposed categories:");
      display(over, pw);
    }

    /**
     * @param str label passed on to the {@code AbstractEval} constructor
     */
    public CatErrorEval(String str) {
      super(str);
    }

  } // end class CatErrorEval


  /**
   * Keeps a running total and average of parser scores.
   * This isn't really a kind of AbstractEval: we're sort of cheating here.
   */
  public static class ScoreEval extends AbstractEval {

    /** Sum of all scores recorded so far. */
    double totScore = 0.0;
    /** Number of scores recorded so far. */
    double n = 0.0;
    /** Formats scores with three decimals. */
    NumberFormat nf = new DecimalFormat("0.000");

    /**
     * @param str             label printed in front of every output line
     * @param runningAverages whether {@link #recordScore} also prints the average so far
     */
    public ScoreEval(String str, boolean runningAverages) {
      super(str, runningAverages);
    }

    /** Unused; exists only to satisfy the abstract superclass. Always returns {@code null}. */
    @Override
    protected Set makeObjects(Tree tree) {
      return null;
    }

    /**
     * Adds the parser's best score to the running total and optionally reports it.
     *
     * @param parser source of the score, read through {@code getBestScore()}
     * @param pw     destination of the report; {@code null} records the score silently
     */
    public void recordScore(KBestViterbiParser parser, PrintWriter pw) {
      final double score = parser.getBestScore();
      totScore += score;
      n++;
      if (pw == null) {
        return;
      }
      pw.print(str + " score: " + nf.format(score));
      if (runningAverages) {
        final double averageSoFar = totScore / n;
        pw.print(" average score: " + nf.format(averageSoFar));
      }
      pw.println();
    }

    /**
     * Prints the total score and the average score, or "N/A" for the average when no score
     * has been recorded.
     *
     * @param verbose not used by this implementation
     * @param pw      destination of the summary; {@code null} suppresses the output
     */
    @Override
    public void display(boolean verbose, PrintWriter pw) {
      if (pw == null) {
        return;
      }
      final String average = (n == 0.0) ? "N/A" : nf.format(totScore / n);
      pw.println(str + " total score: " + nf.format(totScore) +
              " average score: " + average);
    }

  } // end class ScoreEval

} // end class AbstractEval




© 2015 - 2024 Weber Informatics LLC | Privacy Policy