All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.parser.lexparser.GrammarCompactor Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher-level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.parser.lexparser; 
import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.fsm.TransducerGraph;
import edu.stanford.nlp.fsm.TransducerGraph.Arc;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Distribution;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;

import java.io.*;
import java.util.*;
import java.util.Map.Entry;

/**
 * @author Teg Grenager ([email protected])
 */
public abstract class GrammarCompactor  {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(GrammarCompactor.class);

  // so that the grammar remembers its graphs after compacting them
  Set compactedGraphs;

  public static final Object RAW_COUNTS = new Object();
  public static final Object NORMALIZED_LOG_PROBABILITIES = new Object();

  public Object outputType = RAW_COUNTS; // default value

  protected Index stateIndex;
  protected Index newStateIndex;

  // String rawBaseDir = "raw";
  // String compactedBaseDir = "compacted";
  // boolean writeToFile = false;
  protected Distribution inputPrior;
  private static final String END = "END";
  private static final String EPSILON = "EPSILON";
  protected boolean verbose = false;

  protected final Options op;

  public GrammarCompactor(Options op) {
    this.op = op;
  }


  protected abstract TransducerGraph doCompaction(TransducerGraph graph, List> trainPaths, List> testPaths);

  public Triple, UnaryGrammar, BinaryGrammar> compactGrammar(Pair grammar, Index originalStateIndex) {
    return compactGrammar(grammar, Generics.>>newHashMap(), Generics.>>newHashMap(), originalStateIndex);
  }

  /**
   * Compacts the grammar specified by the Pair.
   *
   * @param grammar       a Pair of grammars, ordered UnaryGrammar BinaryGrammar.
   * @param allTrainPaths a Map from String passive constituents to Lists of paths
   * @param allTestPaths  a Map from String passive constituents to Lists of paths
   * @return a Pair of grammars, ordered UnaryGrammar BinaryGrammar.
   */
  public Triple, UnaryGrammar, BinaryGrammar> compactGrammar(Pair grammar, Map>> allTrainPaths, Map>> allTestPaths, Index originalStateIndex) {
    inputPrior = computeInputPrior(allTrainPaths); // computed once for the whole grammar
    // BinaryGrammar bg = grammar.second;
    this.stateIndex = originalStateIndex;
    List> trainPaths, testPaths;
    Set unaryRules = Generics.newHashSet();
    Set binaryRules = Generics.newHashSet();
    Map graphs = convertGrammarToGraphs(grammar, unaryRules, binaryRules);
    compactedGraphs = Generics.newHashSet();
    if (verbose) {
      System.out.println("There are " + graphs.size() + " categories to compact.");
    }
    int i = 0;
    for (Iterator> graphIter = graphs.entrySet().iterator(); graphIter.hasNext();) {
      Map.Entry entry = graphIter.next();
      String cat = entry.getKey();
      TransducerGraph graph = entry.getValue();
      if (verbose) {
        System.out.println("About to compact grammar for " + cat + " with numNodes=" + graph.getNodes().size());
      }
      trainPaths = allTrainPaths.remove(cat);// to save memory
      if (trainPaths == null) {
        trainPaths = new ArrayList<>();
      }
      testPaths = allTestPaths.remove(cat);// to save memory
      if (testPaths == null) {
        testPaths = new ArrayList<>();
      }
      TransducerGraph compactedGraph = doCompaction(graph, trainPaths, testPaths);
      i++;
      if (verbose) {
        System.out.println(i + ". Compacted grammar for " + cat + " from " + graph.getArcs().size() + " arcs to " + compactedGraph.getArcs().size() + " arcs.");
      }
      graphIter.remove(); // to save memory, remove the last thing
      compactedGraphs.add(compactedGraph);
    }
    Pair ugbg = convertGraphsToGrammar(compactedGraphs, unaryRules, binaryRules);
    return new Triple<>(newStateIndex, ugbg.first(), ugbg.second());
  }

  protected static Distribution computeInputPrior(Map>> allTrainPaths) {
    ClassicCounter result = new ClassicCounter<>();
    for (List> pathList : allTrainPaths.values()) {
      for (List path : pathList) {
        for (String input : path) {
          result.incrementCount(input);
        }
      }
    }
    return Distribution.laplaceSmoothedDistribution(result, result.size() * 2, 0.5);
  }

  private double smartNegate(double output) {
    if (outputType == NORMALIZED_LOG_PROBABILITIES) {
      return -output;
    }
    return output;
  }

  public static boolean writeFile(TransducerGraph graph, String dir, String name) {
    try {
      File baseDir = new File(dir);
      if (baseDir.exists()) {
        if (!baseDir.isDirectory()) {
          return false;
        }
      } else {
        if (!baseDir.mkdirs()) {
          return false;
        }
      }
      File file = new File(baseDir, name + ".dot");
      PrintWriter w;
      try {
        w = new PrintWriter(new FileWriter(file));
        String dotString = graph.asDOTString();
        w.print(dotString);
        w.flush();
        w.close();
      } catch (FileNotFoundException e) {
        log.info("Failed to open file in writeToDOTfile: " + file);
        return false;
      } catch (IOException e) {
        log.info("Failed to open file in writeToDOTfile: " + file);
        return false;
      }
      return true;
    } catch (Exception e) {
      e.printStackTrace();
      return false;
    }
  }

  /**
   *
   */
  protected Map convertGrammarToGraphs(Pair grammar, Set unaryRules, Set binaryRules) {
    int numRules = 0;
    UnaryGrammar ug = grammar.first;
    BinaryGrammar bg = grammar.second;
    Map graphs = Generics.newHashMap();
    // go through the BinaryGrammar and add everything
    for (BinaryRule rule : bg) {
      numRules++;
      boolean wasAdded = addOneBinaryRule(rule, graphs);
      if (!wasAdded)
      // add it for later, since we don't make graphs for these
      {
        binaryRules.add(rule);
      }
    }
    // now we need to use the UnaryGrammar to
    // add start and end Arcs to the graphs
    for (UnaryRule rule : ug) {
      numRules++;
      boolean wasAdded = addOneUnaryRule(rule, graphs);
      if (!wasAdded)
      // add it for later, since we don't make graphs for these
      {
        unaryRules.add(rule);
      }
    }
    if (verbose) {
      System.out.println("Number of raw rules: " + numRules);
      System.out.println("Number of raw states: " + stateIndex.size());
    }
    return graphs;
  }

  protected static TransducerGraph getGraphFromMap(Map m, String o) {
    TransducerGraph graph = m.get(o);
    if (graph == null) {
      graph = new TransducerGraph();
      graph.setEndNode(o);
      m.put(o, graph);
    }
    return graph;
  }

  protected static String getTopCategoryOfSyntheticState(String s) {
    if (s.charAt(0) != '@') {
      return null;
    }
    int bar = s.indexOf('|');
    if (bar < 0) {
      throw new RuntimeException("Grammar format error. Expected bar in state name: " + s);
    }
    String topcat = s.substring(1, bar);
    return topcat;
  }

  protected boolean addOneUnaryRule(UnaryRule rule, Map graphs) {
    String parentString = stateIndex.get(rule.parent);
    String childString = stateIndex.get(rule.child);
    if (isSyntheticState(parentString)) {
      String topcat = getTopCategoryOfSyntheticState(parentString);
      TransducerGraph graph = getGraphFromMap(graphs, topcat);
      Double output = new Double(smartNegate(rule.score()));
      graph.addArc(graph.getStartNode(), parentString, childString, output);
      return true;
    } else if (isSyntheticState(childString)) {
      // need to add Arc from synthetic state to endState
      TransducerGraph graph = getGraphFromMap(graphs, parentString);
      Double output = new Double(smartNegate(rule.score()));
      graph.addArc(childString, parentString, END, output); // parentString should the the same as endState
      graph.setEndNode(parentString);
      return true;
    } else {
      return false;
    }
  }

  protected boolean addOneBinaryRule(BinaryRule rule, Map graphs) {
    // parent has to be synthetic in BinaryRule
    String parentString = stateIndex.get(rule.parent);
    String leftString = stateIndex.get(rule.leftChild);
    String rightString = stateIndex.get(rule.rightChild);
    String source, target, input;
    String bracket = null;
    if (op.trainOptions.markFinalStates) {
      bracket = parentString.substring(parentString.length() - 1, parentString.length());
    }
    // the below test is not necessary with left to right grammars
    if (isSyntheticState(leftString)) {
      source = leftString;
      input = rightString + (bracket == null ? ">" : bracket);
    } else if (isSyntheticState(rightString)) {
      source = rightString;
      input = leftString + (bracket == null ? "<" : bracket);
    } else {
      // we don't know what to do with this rule
      return false;
    }
    target = parentString;
    Double output = new Double(smartNegate(rule.score())); // makes it a real  0 <= k <= infty
    String topcat = getTopCategoryOfSyntheticState(source);
    if (topcat == null) {
      throw new RuntimeException("can't have null topcat");
    }
    TransducerGraph graph = getGraphFromMap(graphs, topcat);
    graph.addArc(source, target, input, output);
    return true;
  }

  protected static boolean isSyntheticState(String state) {
    return state.charAt(0) == '@';
  }


  /**
   * @param graphs      a Map from String categories to TransducerGraph objects
   * @param unaryRules  is a Set of UnaryRule objects that we need to add
   * @param binaryRules is a Set of BinaryRule objects that we need to add
   * @return a new Pair of UnaryGrammar, BinaryGrammar
   */
  protected Pair convertGraphsToGrammar(Set graphs, Set unaryRules, Set binaryRules) {
    // first go through all the existing rules and number them with new numberer
    newStateIndex = new HashIndex<>();
    for (UnaryRule rule : unaryRules) {
      String parent = stateIndex.get(rule.parent);
      rule.parent = newStateIndex.addToIndex(parent);
      String child = stateIndex.get(rule.child);
      rule.child = newStateIndex.addToIndex(child);
    }
    for (BinaryRule rule : binaryRules) {
      String parent = stateIndex.get(rule.parent);
      rule.parent = newStateIndex.addToIndex(parent);
      String leftChild = stateIndex.get(rule.leftChild);
      rule.leftChild = newStateIndex.addToIndex(leftChild);
      String rightChild = stateIndex.get(rule.rightChild);
      rule.rightChild = newStateIndex.addToIndex(rightChild);
    }

    // now go through the graphs and add the rules
    for (TransducerGraph graph : graphs) {
      Object startNode = graph.getStartNode();
      for (Arc arc : graph.getArcs()) {
        // TODO: make sure these are the strings we're looking for
        String source = arc.getSourceNode().toString();
        String target = arc.getTargetNode().toString();
        Object input = arc.getInput();
        String inputString = input.toString();
        double output = ((Double) arc.getOutput()).doubleValue();
        if (source.equals(startNode)) {
          // make a UnaryRule
          UnaryRule ur = new UnaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(inputString), smartNegate(output));
          unaryRules.add(ur);
        } else if (inputString.equals(END) || inputString.equals(EPSILON)) {
          // make a UnaryRule
          UnaryRule ur = new UnaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(source), smartNegate(output));
          unaryRules.add(ur);
        } else {
          // make a BinaryRule
          // figure out whether the input was generated on the left or right
          int length = inputString.length();
          char leftOrRight = inputString.charAt(length - 1);
          inputString = inputString.substring(0, length - 1);
          BinaryRule br;
          if (leftOrRight == '<' || leftOrRight == '[') {
            br = new BinaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(inputString), newStateIndex.addToIndex(source), smartNegate(output));
          } else if (leftOrRight == '>' || leftOrRight == ']') {
            br = new BinaryRule(newStateIndex.addToIndex(target), newStateIndex.addToIndex(source), newStateIndex.addToIndex(inputString), smartNegate(output));
          } else {
            throw new RuntimeException("Arc input is in unexpected format: " + arc);
          }
          binaryRules.add(br);
        }
      }
    }
    // by now, the unaryRules and binaryRules Sets have old untouched and new rules with scores
    ClassicCounter symbolCounter = new ClassicCounter<>();
    if (outputType == RAW_COUNTS) {
      // now we take the sets of rules and turn them into grammars
      // the scores of the rules we are given are actually counts
      // so we count parent symbol occurrences
      for (UnaryRule rule : unaryRules) {
        symbolCounter.incrementCount(newStateIndex.get(rule.parent), rule.score);
      }
      for (BinaryRule rule : binaryRules) {
        symbolCounter.incrementCount(newStateIndex.get(rule.parent), rule.score);
      }
    }
    // now we put the rules in the grammars
    int numStates = newStateIndex.size();     // this should be smaller than last one
    int numRules = 0;
    UnaryGrammar ug = new UnaryGrammar(newStateIndex);
    BinaryGrammar bg = new BinaryGrammar(newStateIndex);
    for (UnaryRule rule : unaryRules) {
      if (outputType == RAW_COUNTS) {
        double count = symbolCounter.getCount(newStateIndex.get(rule.parent));
        rule.score = (float) Math.log(rule.score / count);
      }
      ug.addRule(rule);
      numRules++;
    }
    for (BinaryRule rule : binaryRules) {
      if (outputType == RAW_COUNTS) {
        double count = symbolCounter.getCount(newStateIndex.get(rule.parent));
        rule.score = (float) Math.log((rule.score - op.trainOptions.ruleDiscount) / count);
      }
      bg.addRule(rule);
      numRules++;
    }
    if (verbose) {
      System.out.println("Number of minimized rules: " + numRules);
      System.out.println("Number of minimized states: " + newStateIndex.size());
    }

    ug.purgeRules();
    bg.splitRules();
    return new Pair<>(ug, bg);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy