All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.parser.lexparser.ParentAnnotationStats Maven / Gradle / Ivy

package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.io.NumberRangeFileFilter;
import edu.stanford.nlp.ling.StringLabelFactory;
import edu.stanford.nlp.trees.*;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import java.io.Reader;
import java.text.NumberFormat;
import java.util.*;

/**
 * Determines which parent (and grandparent) annotations are worth adding to
 * treebank categories, ranked by support and KL divergence.
 *
 * @author Christopher Manning
 * @version 2003/01/04
 */
public class ParentAnnotationStats implements TreeVisitor {

  /**
   * Language pack used to reduce parent and grandparent labels to their
   * basic category; may be {@code null}, in which case labels are used as is.
   */
  private final TreebankLanguagePack tlp;

  /** Whether preterminal (tag) nodes are counted as well as phrasal nodes. */
  private final boolean doTags;

  /** Phrasal rule counts: node label -> counter over child-label sequences. */
  private final Map<String, ClassicCounter<List<String>>> nodeRules = Generics.newHashMap();
  /** Phrasal rule counts keyed by [node, parent]. */
  private final Map<List<String>, ClassicCounter<List<String>>> pRules = Generics.newHashMap();
  /** Phrasal rule counts keyed by [node, parent, grandparent]. */
  private final Map<List<String>, ClassicCounter<List<String>>> gPRules = Generics.newHashMap();

  // corresponding ones for tags
  /** Tag rule counts: tag label -> counter over child-label sequences. */
  private final Map<String, ClassicCounter<List<String>>> tagNodeRules = Generics.newHashMap();
  /** Tag rule counts keyed by [tag, parent]. */
  private final Map<List<String>, ClassicCounter<List<String>>> tagPRules = Generics.newHashMap();
  /** Tag rule counts keyed by [tag, parent, grandparent]. */
  private final Map<List<String>, ClassicCounter<List<String>>> tagGPRules = Generics.newHashMap();

  /**
   * Creates an empty statistics collector.
   *
   * @param tlp    language pack whose {@code basicCategory} is applied to
   *               parent and grandparent labels, or {@code null} to skip that step
   * @param doTags whether preterminal nodes are also counted
   */
  private ParentAnnotationStats(TreebankLanguagePack tlp, boolean doTags) {
    this.tlp = tlp;
    this.doTags = doTags;
  }

  /**
   * Score thresholds used by {@link #printStats()}.  A (node, parent) or
   * (node, parent, grandparent) combination is written into the generated
   * {@code splittersN} array of every cutoff that its support * KL score
   * meets or exceeds.
   */
  public static final double[] CUTOFFS = {100.0, 200.0, 500.0, 1000.0};

  /**
   * Minimum support of parent annotated node for grandparent to be
   * studied.  Just there to reduce runtime and printout size.
   */
  public static final double SUPPCUTOFF = 100.0;

  /**
   * Collects rule statistics for one parse tree.  Processing starts at the
   * given node, with the placeholder label "TOP" standing in for both its
   * parent and its grandparent.
   *
   * @param t the tree to process
   */
  public void visitTree(Tree t) {
    final String rootContext = "TOP";
    processTreeHelper(rootContext, rootContext, t);
  }

  /**
   * Returns the label values of the immediate children of a tree node, in
   * child order.
   *
   * @param t the node whose children are inspected
   * @return a new list holding {@code kid.label().value()} for each child of {@code t}
   */
  public static List<String> kidLabels(Tree t) {
    Tree[] kids = t.children();
    List<String> labels = new ArrayList<>(kids.length);
    for (Tree kid : kids) {
      labels.add(kid.label().value());
    }
    return labels;
  }

  public void processTreeHelper(String gP, String p, Tree t) {
    if (!t.isLeaf() && (doTags || !t.isPreTerminal())) { // stop at words/tags
      Map>> nr;
      Map,ClassicCounter>> pr;
      Map,ClassicCounter>> gpr;
      if (t.isPreTerminal()) {
        nr = tagNodeRules;
        pr = tagPRules;
        gpr = tagGPRules;
      } else {
        nr = nodeRules;
        pr = pRules;
        gpr = gPRules;
      }
      String n = t.label().value();
      if (tlp != null) {
        p = tlp.basicCategory(p);
        gP = tlp.basicCategory(gP);
      }
      List kidn = kidLabels(t);
      ClassicCounter> cntr = nr.get(n);
      if (cntr == null) {
        cntr = new ClassicCounter>();
        nr.put(n, cntr);
      }
      cntr.incrementCount(kidn);
      List pairStr = new ArrayList(2);
      pairStr.add(n);
      pairStr.add(p);
      cntr = pr.get(pairStr);
      if (cntr == null) {
        cntr = new ClassicCounter>();
        pr.put(pairStr, cntr);
      }
      cntr.incrementCount(kidn);
      List tripleStr = new ArrayList(3);
      tripleStr.add(n);
      tripleStr.add(p);
      tripleStr.add(gP);
      cntr = gpr.get(tripleStr);
      if (cntr == null) {
        cntr = new ClassicCounter>();
        gpr.put(tripleStr, cntr);
      }
      cntr.incrementCount(kidn);
      Tree[] kids = t.children();
      for (Tree kid : kids) {
        processTreeHelper(p, n, kid);
      }
    }
  }


  public void printStats() {
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(2);
    // System.out.println("Node rules");
    // System.out.println(nodeRules);
    // System.out.println("Parent rules");
    // System.out.println(pRules);
    // System.out.println("Grandparent rules");
    // System.out.println(gPRules);

    // Store java code for selSplit
    StringBuffer[] javaSB = new StringBuffer[CUTOFFS.length];
    for (int i = 0; i < CUTOFFS.length; i++) {
      javaSB[i] = new StringBuffer("  private static String[] splitters" + (i + 1) + " = new String[] {");
    }

    ClassicCounter> allScores = new ClassicCounter>();
    // do value of parent
    for (String node : nodeRules.keySet()) {
      ArrayList,Double>> answers = Generics.newArrayList();
      ClassicCounter> cntr = nodeRules.get(node);
      double support = (cntr.totalCount());
      System.out.println("Node " + node + " support is " + support);
      for (Iterator> it2 = pRules.keySet().iterator(); it2.hasNext();) {
        List key = it2.next();
        if (key.get(0).equals(node)) {   // only do it if they match
          ClassicCounter> cntr2 = pRules.get(key);
          double support2 = (cntr2.totalCount());
          double kl = Counters.klDivergence(cntr2, cntr);
          System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
          double score = kl * support2;
          answers.add(new Pair,Double>(key, new Double(score)));
          allScores.setCount(key, score);
        }
      }
      System.out.println("----");
      System.out.println("Sorted descending support * KL");
      Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
      for (int i = 0, size = answers.size(); i < size; i++) {
        Pair p = (Pair) answers.get(i);
        double psd = ((Double) p.second()).doubleValue();
        System.out.println(p.first() + ": " + nf.format(psd));
        if (psd >= CUTOFFS[0]) {
          List lst = (List) p.first();
          String nd = (String) lst.get(0);
          String par = (String) lst.get(1);
          for (int j = 0; j < CUTOFFS.length; j++) {
            if (psd >= CUTOFFS[j]) {
              javaSB[j].append("\"").append(nd).append("^");
              javaSB[j].append(par).append("\", ");
            }
          }
        }
      }
      System.out.println();
    }

    /*
          // do value of parent with info gain -- yet to finish this
          for (Iterator it = nodeRules.entrySet().iterator(); it.hasNext(); ) {
              Map.Entry pair = (Map.Entry) it.next();
              String node = (String) pair.getKey();
              Counter cntr = (Counter) pair.getValue();
              double support = (cntr.totalCount());
              System.out.println("Node " + node + " support is " + support);
              ArrayList dtrs = new ArrayList();
              for (Iterator it2 = pRules.entrySet().iterator(); it2.hasNext();) {
                  HashMap annotated = new HashMap();
                  Map.Entry pair2 = (Map.Entry) it2.next();
                  List node2 = (List) pair2.getKey();
                  Counter cntr2 = (Counter) pair2.getValue();
                  if (node2.get(0).equals(node)) {   // only do it if they match
                      annotated.put(node2, cntr2);
                  }
              }

              // upto

              List answers = new ArrayList();
              System.out.println("----");
              System.out.println("Sorted descending support * KL");
              Collections.sort(answers,
                               new Comparator() {
                                   public int compare(Object o1, Object o2) {
                                       Pair p1 = (Pair) o1;
                                       Pair p2 = (Pair) o2;
                                       Double p12 = (Double) p1.second();
                                       Double p22 = (Double) p2.second();
                                       return p22.compareTo(p12);
                                   }
                               });
              for (int i = 0, size = answers.size(); i < size; i++) {
                  Pair p = (Pair) answers.get(i);
                  double psd = ((Double) p.second()).doubleValue();
                  System.out.println(p.first() + ": " + nf.format(psd));
                  if (psd >= CUTOFFS[0]) {
                      List lst = (List) p.first();
                      String nd = (String) lst.get(0);
                      String par = (String) lst.get(1);
                      for (int j=0; j < CUTOFFS.length; j++) {
                          if (psd >= CUTOFFS[j]) {
                              javaSB[j].append("\"").append(nd).append("^");
                              javaSB[j].append(par).append("\", ");
                          }
                      }
                  }
              }
              System.out.println();
          }
    */

    // do value of grandparent
    for (List node : pRules.keySet()) {
      ArrayList, Double>> answers = Generics.newArrayList();
      ClassicCounter> cntr = pRules.get(node);
      double support = (cntr.totalCount());
      if (support < SUPPCUTOFF) {
        continue;
      }
      System.out.println("Node " + node + " support is " + support);
      for (List key : gPRules.keySet()) {
        if (key.get(0).equals(node.get(0)) && key.get(1).equals(node.get(1))) {  // only do it if they match
          ClassicCounter> cntr2 = gPRules.get(key);
          double support2 = (cntr2.totalCount());
          double kl = Counters.klDivergence(cntr2, cntr);
          System.out.println("KL(" + key + "||" + node + ") = " + nf.format(kl) + "\t" + "support(" + key + ") = " + support2);
          double score = kl * support2;
          answers.add(Pair.makePair(key, new Double(score)));
          allScores.setCount(key,score);
        }
      }
      System.out.println("----");
      System.out.println("Sorted descending support * KL");
      Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
      for (int i = 0, size = answers.size(); i < size; i++) {
        Pair p = (Pair) answers.get(i);
        double psd = ((Double) p.second()).doubleValue();
        System.out.println(p.first() + ": " + nf.format(psd));
        if (psd >= CUTOFFS[0]) {
          List lst = (List) p.first();
          String nd = (String) lst.get(0);
          String par = (String) lst.get(1);
          String gpar = (String) lst.get(2);
          for (int j = 0; j < CUTOFFS.length; j++) {
            if (psd >= CUTOFFS[j]) {
              javaSB[j].append("\"").append(nd).append("^");
              javaSB[j].append(par).append("~");
              javaSB[j].append(gpar).append("\", ");
            }
          }
        }
      }
      System.out.println();
    }
    System.out.println();

    System.out.println("All scores:");
    edu.stanford.nlp.util.PriorityQueue> pq = Counters.toPriorityQueue(allScores);
    while (! pq.isEmpty()) {
      List key = pq.getFirst();
      double score = pq.getPriority(key);
      pq.removeFirst();
      System.out.println(key + "\t" + score);
    }

    System.out.println("  // Automatically generated by ParentAnnotationStats -- preferably don't edit");
    for (int i = 0; i < CUTOFFS.length; i++) {
      int len = javaSB[i].length();
      javaSB[i].replace(len - 2, len, "};");
      System.out.println(javaSB[i]);
    }
    System.out.print("  public static HashSet splitters = new HashSet(Arrays.asList(");
    for (int i = CUTOFFS.length; i > 0; i--) {
      if (i == 1) {
        System.out.print("splitters1");
      } else {
        System.out.print("selectiveSplit" + i + " ? splitters" + i + " : (");
      }
    }
    // need to print extra one to close other things open
    for (int i = CUTOFFS.length; i >= 0; i--) {
      System.out.print(")");
    }
    System.out.println(";");
  }


  private static void getSplitters(double cutOff, Map>> nr,
                                   Map,ClassicCounter>> pr,
                                   Map,ClassicCounter>> gpr,
                                   Set splitters) {

    // do value of parent
    for (String node : nr.keySet()) {
      List,Double>> answers = new ArrayList,Double>>();
      ClassicCounter> cntr = nr.get(node);
      double support = (cntr.totalCount());
      for (List key : pr.keySet()) {
        if (key.get(0).equals(node)) {   // only do it if they match
          ClassicCounter> cntr2 = pr.get(key);
          double support2 = cntr2.totalCount();
          double kl = Counters.klDivergence(cntr2, cntr);
          answers.add(new Pair, Double>(key, new Double(kl * support2)));
        }
      }
      Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
      for (int i = 0, size = answers.size(); i < size; i++) {
        Pair,Double> p = answers.get(i);
        double psd = p.second().doubleValue();
        if (psd >= cutOff) {
          List lst = p.first();
          String nd = lst.get(0);
          String par = lst.get(1);
          String name = nd + "^" + par;
          splitters.add(name);
        }
      }
    }

    /*
          // do value of parent with info gain -- yet to finish this
          for (Iterator it = nr.entrySet().iterator(); it.hasNext(); ) {
              Map.Entry pair = (Map.Entry) it.next();
              String node = (String) pair.getKey();
              Counter cntr = (Counter) pair.getValue();
              double support = (cntr.totalCount());
              ArrayList dtrs = new ArrayList();
              for (Iterator it2 = pr.entrySet().iterator(); it2.hasNext();) {
                  HashMap annotated = new HashMap();
                  Map.Entry pair2 = (Map.Entry) it2.next();
                  List node2 = (List) pair2.getKey();
                  Counter cntr2 = (Counter) pair2.getValue();
                  if (node2.get(0).equals(node)) {   // only do it if they match
                      annotated.put(node2, cntr2);
                  }
              }

              // upto

              List answers = new ArrayList();
              Collections.sort(answers,
                               new Comparator() {
                                   public int compare(Object o1, Object o2) {
                                       Pair p1 = (Pair) o1;
                                       Pair p2 = (Pair) o2;
                                       Double p12 = (Double) p1.second();
                                       Double p22 = (Double) p2.second();
                                       return p22.compareTo(p12);
                                   }
                               });
              for (int i = 0, size = answers.size(); i < size; i++) {
                  Pair p = (Pair) answers.get(i);
                  double psd = ((Double) p.second()).doubleValue();
                  if (psd >= cutOff) {
                      List lst = (List) p.first();
                      String nd = (String) lst.get(0);
                      String par = (String) lst.get(1);
                      String name = nd + "^" + par;
                      splitters.add(name);
                  }
              }
          }
    */

    // do value of grandparent
    for (List node : pr.keySet()) {
      ArrayList,Double>> answers = Generics.newArrayList();
      ClassicCounter> cntr = pr.get(node);
      double support = (cntr.totalCount());
      if (support < SUPPCUTOFF) {
        continue;
      }
      for (List key : gpr.keySet()) {
        if (key.get(0).equals(node.get(0)) && key.get(1).equals(node.get(1))) {
          // only do it if they match
          ClassicCounter> cntr2 = gpr.get(key);
          double support2 = (cntr2.totalCount());
          double kl = Counters.klDivergence(cntr2, cntr);
          answers.add(new Pair,Double>(key, new Double(kl * support2)));
        }
      }
      Collections.sort(answers, (o1, o2) -> o2.second().compareTo(o1.second()));
      for (int i = 0, size = answers.size(); i < size; i++) {
        Pair p = (Pair) answers.get(i);
        double psd = ((Double) p.second()).doubleValue();
        if (psd >= cutOff) {
          List lst = (List) p.first();
          String nd = (String) lst.get(0);
          String par = (String) lst.get(1);
          String gpar = (String) lst.get(2);
          String name = nd + "^" + par + "~" + gpar;
          splitters.add(name);
        }
      }
    }
  }


  /**
   * Calculate parent annotation statistics suitable for doing
   * selective parent splitting in the PCFGParser inside
   * FactoredParser.
   * <p>
   * Usage: {@code java edu.stanford.nlp.parser.lexparser.ParentAnnotationStats
   * [-tags] [-cutOff value] treebankPath}
   * <p>
   * With {@code -cutOff} the set of split categories for that cutoff is
   * printed; otherwise the full statistics report is printed.
   *
   * @param args Optional flags followed by the path to the Treebank
   */
  public static void main(String[] args) {
    boolean doTags = false;
    boolean useCutOff = false;
    double cutOff = 0.0;
    int i = 0;
    // The bounds check keeps an argument list that consists only of options
    // from running past the end of the array.
    while (i < args.length && args[i].startsWith("-")) {
      if (args[i].equals("-tags")) {
        doTags = true;
        i++;
      } else if (args[i].equals("-cutOff") && i + 1 < args.length) {
        useCutOff = true;
        cutOff = Double.parseDouble(args[i + 1]);
        i += 2;
      } else {
        System.err.println("Unknown option: " + args[i]);
        i++;
      }
    }
    if (i >= args.length) {
      // no treebank path was given
      System.out.println("Usage: java edu.stanford.nlp.parser.lexparser.ParentAnnotationStats [-tags] treebankPath");
      return;
    }

    Treebank treebank = new DiskTreebank(in -> new PennTreeReader(in, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer()));
    treebank.loadPath(args[i]);

    if (useCutOff) {
      Set<String> splitters = getSplitCategories(treebank, doTags, 0, cutOff, cutOff, null);
      System.out.println(splitters);
    } else {
      ParentAnnotationStats pas = new ParentAnnotationStats(null, doTags);
      treebank.apply(pas);
      pas.printStats();
    }
  }

  /**
   * Call this method to get the set of categories to split on.
   * It calculates parent annotation statistics suitable for doing
   * selective parent splitting in the PCFGParser inside
   * FactoredParser.
   * <p>
   * If tlp is non-null tlp.basicCategory() will be called on parent and
   * grandparent nodes.
   * <p>
   * This version just defaults some parameters: tags are included and the
   * same cutoff is used for phrasal categories and tags.
   *
   * @param t      the treebank to collect statistics from
   * @param cutOff minimum support * KL score for a category to be returned
   * @param tlp    language pack, or {@code null}
   * @return the names of the categories to split
   */
  public static Set<String> getSplitCategories(Treebank t, double cutOff, TreebankLanguagePack tlp) {
    return getSplitCategories(t, true, 0, cutOff, cutOff, tlp);
  }

  /**
   * Call this method to get the set of categories to split on.
   * It calculates parent annotation statistics suitable for doing
   * selective parent splitting in the PCFGParser inside
   * FactoredParser.
   * <p>
   * If tlp is non-null tlp.basicCategory() will be called on parent and
   * grandparent nodes.
   * <p>
   * Implementation note: each call collects its statistics in a freshly
   * created {@code ParentAnnotationStats} instance.
   *
   * @param t             the treebank to collect statistics from
   * @param doTags        whether preterminal (tag) nodes are also counted
   * @param algorithm     currently unused by this method
   * @param phrasalCutOff minimum support * KL score for phrasal categories
   * @param tagCutOff     minimum support * KL score for tag categories
   * @param tlp           language pack, or {@code null}
   * @return the names of the categories to split
   */
  public static Set<String> getSplitCategories(Treebank t, boolean doTags, int algorithm, double phrasalCutOff, double tagCutOff, TreebankLanguagePack tlp) {
    ParentAnnotationStats pas = new ParentAnnotationStats(tlp, doTags);
    t.apply(pas);
    Set<String> splitters = Generics.newHashSet();
    getSplitters(phrasalCutOff, pas.nodeRules, pas.pRules, pas.gPRules, splitters);
    getSplitters(tagCutOff, pas.tagNodeRules, pas.tagPRules, pas.tagGPRules, splitters);
    return splitters;
  }

  /**
   * This is hardwired to calculate the split categories from English
   * Penn Treebank sections 2-21 with a default cutoff of 300 (as used
   * in ACL03PCFG).  It was added to help upgrade code in cases where no
   * Treebank was available and a pre-stored list was being used.
   *
   * @param treebankRoot path from which the treebank files are loaded
   * @return the names of the categories to split
   */
  public static Set<String> getEnglishSplitCategories(String treebankRoot) {
    TreebankLangParserParams tlpParams = new EnglishTreebankParserParams();
    Treebank trees = tlpParams.memoryTreebank();
    // file numbers 200-2199 select the sections named above
    trees.loadPath(treebankRoot, new NumberRangeFileFilter(200, 2199, true));
    return getSplitCategories(trees, 300.0, tlpParams.treebankLanguagePack());
  }

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy