All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.parser.shiftreduce.ShiftReduceParserQuery Maven / Gradle / Ivy

package edu.stanford.nlp.parser.shiftreduce;


import java.io.PrintWriter;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.Sentence;
import edu.stanford.nlp.parser.KBestViterbiParser;
import edu.stanford.nlp.parser.common.ParserConstraint;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.lexparser.Debinarizer;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.tregex.TregexPattern;
import edu.stanford.nlp.trees.tregex.tsurgeon.Tsurgeon;
import edu.stanford.nlp.trees.tregex.tsurgeon.TsurgeonPattern;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.RuntimeInterruptedException;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;

public class ShiftReduceParserQuery implements ParserQuery {
  Debinarizer debinarizer = new Debinarizer(false);

  List originalSentence;
  private State initialState, finalState;
  Tree debinarized;

  boolean success;
  boolean unparsable;

  private List bestParses;

  final ShiftReduceParser parser;

  List constraints = null;

  public ShiftReduceParserQuery(ShiftReduceParser parser) {
    this.parser = parser;
  }

  @Override
  public boolean parse(List sentence) {
    this.originalSentence = sentence;
    initialState = ShiftReduceParser.initialStateFromTaggedSentence(sentence);
    return parseInternal();
  }

  public boolean parse(Tree tree) {
    this.originalSentence = tree.yieldHasWord();
    initialState = ShiftReduceParser.initialStateFromGoldTagTree(tree);
    return parseInternal();
  }

  // TODO: we are assuming that sentence final punctuation always has
  // either . or PU as the tag.
  private static TregexPattern rearrangeFinalPunctuationTregex =
    TregexPattern.compile("__ !> __ <- (__=top <- (__ <<- (/[.]|PU/=punc < /[.!?。!?]/ ?> (__=single <: =punc))))");

  private static TsurgeonPattern rearrangeFinalPunctuationTsurgeon =
    Tsurgeon.parseOperation("[move punc >-1 top] [if exists single prune single]");

  private boolean parseInternal() {
    final int maxBeamSize = Math.max(parser.op.testOptions().beamSize, 1);

    success = true;
    unparsable = false;
    PriorityQueue beam = new PriorityQueue<>(maxBeamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
    beam.add(initialState);
    // TODO: don't construct as many PriorityQueues
    while (beam.size() > 0) {
      if (Thread.interrupted()) { // Allow interrupting the parser
        throw new RuntimeInterruptedException();
      }
      // System.err.println("================================================");
      // System.err.println("Current beam:");
      // System.err.println(beam);
      PriorityQueue oldBeam = beam;
      beam = new PriorityQueue<>(maxBeamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
      State bestState = null;
      for (State state : oldBeam) {
        if (Thread.interrupted()) {  // Allow interrupting the parser
          throw new RuntimeInterruptedException();
        }
        Collection> predictedTransitions = parser.model.findHighestScoringTransitions(state, true, maxBeamSize, constraints);
        // System.err.println("Examining state: " + state);
        for (ScoredObject predictedTransition : predictedTransitions) {
          Transition transition = parser.model.transitionIndex.get(predictedTransition.object());
          State newState = transition.apply(state, predictedTransition.score());
          // System.err.println("  Transition: " + transition + " (" + predictedTransition.score() + ")");
          if (bestState == null || bestState.score() < newState.score()) {
            bestState = newState;
          }
          beam.add(newState);
          if (beam.size() > maxBeamSize) {
            beam.poll();
          }
        }
      }
      if (beam.size() == 0) {
        // Oops, time for some fallback plan
        // This can happen with the set of constraints given by the original paper
        // For example, one particular French model had a situation where it would reach
        //   @Ssub @Ssub .
        // without a left(Ssub) transition, so finishing the parse was impossible.
        // This will probably result in a bad parse, but at least it
        // will result in some sort of parse.
        for (State state : oldBeam) {
          Transition transition = parser.model.findEmergencyTransition(state, constraints);
          if (transition != null) {
            State newState = transition.apply(state);
            if (bestState == null || bestState.score() < newState.score()) {
              bestState = newState;
            }
            beam.add(newState);
          }
        }
      }

      // bestState == null only happens when we have failed to make progress, so quit
      // If the bestState is finished, we are done
      if (bestState == null || bestState.isFinished()) {
        break;
      }
    }
    if (beam.size() == 0) {
      success = false;
      unparsable = true;
      debinarized = null;
      finalState = null;
      bestParses = Collections.emptyList();
    } else {
      // TODO: filter out beam elements that aren't finished
      bestParses = Generics.newArrayList(beam);
      Collections.sort(bestParses, beam.comparator());
      Collections.reverse(bestParses);
      finalState = bestParses.get(0);
      debinarized = debinarizer.transformTree(finalState.stack.peek());
      debinarized = Tsurgeon.processPattern(rearrangeFinalPunctuationTregex, rearrangeFinalPunctuationTsurgeon, debinarized);
    }
    return success;
  }

  /**
   * TODO: if we add anything interesting to report, we should report it here
   */
  @Override
  public boolean parseAndReport(List sentence, PrintWriter pwErr) {
    boolean success = parse(sentence);
    //System.err.println(getBestTransitionSequence());
    //System.err.println(getBestBinarizedParse());
    return success;
  }

  public Tree getBestBinarizedParse() {
    return finalState.stack.peek();
  }

  public List getBestTransitionSequence() {
    return finalState.transitions.asList();
  }

  @Override
  public double getPCFGScore() {
    return finalState.score;
  }

  @Override
  public Tree getBestParse() {
    return debinarized;
  }

  /** TODO: can we get away with not calling this PCFG? */
  @Override
  public Tree getBestPCFGParse() {
    return debinarized;
  }

  @Override
  public Tree getBestDependencyParse(boolean debinarize) {
    return null;
  }

  @Override
  public Tree getBestFactoredParse() {
    return null;
  }

  /** TODO: if this is a beam, return all equal parses */
  @Override
  public List> getBestPCFGParses() {
    ScoredObject parse = new ScoredObject<>(debinarized, finalState.score);
    return Collections.singletonList(parse);
  }

  @Override
  public boolean hasFactoredParse() {
    return false;
  }

  /** TODO: return more if this used a beam */
  @Override
  public List> getKBestPCFGParses(int kbestPCFG) {
    ScoredObject parse = new ScoredObject<>(debinarized, finalState.score);
    return Collections.singletonList(parse);
  }

  @Override
  public List> getKGoodFactoredParses(int kbest) {
    throw new UnsupportedOperationException();
  }

  @Override
  public KBestViterbiParser getPCFGParser() {
    // TODO: find some way to treat this as a KBestViterbiParser?
    return null;
  }

  @Override
  public KBestViterbiParser getDependencyParser() {
    return null;
  }

  @Override
  public KBestViterbiParser getFactoredParser() {
    return null;
  }

  @Override
  public void setConstraints(List constraints) {
    this.constraints = constraints;
  }

  @Override
  public boolean saidMemMessage() {
    return false;
  }

  @Override
  public boolean parseSucceeded() {
    return success;
  }

  /** TODO: skip sentences which are too long */
  @Override
  public boolean parseSkipped() {
    return false;
  }

  @Override
  public boolean parseFallback() {
    return false;
  }

  /** TODO: add memory handling? */
  @Override
  public boolean parseNoMemory() {
    return false;
  }

  @Override
  public boolean parseUnparsable() {
    return unparsable;
  }

  @Override
  public List originalSentence() {
    return originalSentence;
  }

  /**
   * TODO: clearly this should be a default method in ParserQuery once Java 8 comes out
   */
  @Override
  public void restoreOriginalWords(Tree tree) {
    if (originalSentence == null || tree == null) {
      return;
    }
    List leaves = tree.getLeaves();
    if (leaves.size() != originalSentence.size()) {
      throw new IllegalStateException("originalWords and sentence of different sizes: " + originalSentence.size() + " vs. " + leaves.size() +
                                      "\n Orig: " + Sentence.listToString(originalSentence) +
                                      "\n Pars: " + Sentence.listToString(leaves));
    }
    // TODO: get rid of this cast
    Iterator wordsIterator = (Iterator) originalSentence.iterator();
    for (Tree leaf : leaves) {
      leaf.setLabel(wordsIterator.next());
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy