
edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser


Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.
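
Typical usage, closely following the ShiftReduceDemo that ships with CoreNLP, is to load a serialized shift-reduce model, tag each sentence with a POS tagger, and hand the tagged words to the parser. The sketch below assumes the standard English model and tagger paths from the CoreNLP models jar; treat the paths and the demo class name as placeholders for your own setup.

    import java.io.StringReader;
    import java.util.List;

    import edu.stanford.nlp.ling.HasWord;
    import edu.stanford.nlp.ling.TaggedWord;
    import edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser;
    import edu.stanford.nlp.process.DocumentPreprocessor;
    import edu.stanford.nlp.tagger.maxent.MaxentTagger;
    import edu.stanford.nlp.trees.Tree;

    public class ShiftReduceUsageSketch {
      public static void main(String[] args) {
        // Placeholder model paths; adjust to the models jar on your classpath.
        ShiftReduceParser parser = ShiftReduceParser.loadModel(
            "edu/stanford/nlp/models/srparser/englishSR.ser.gz");
        MaxentTagger tagger = new MaxentTagger(
            "edu/stanford/nlp/models/pos-tagger/english-left3words/english-left3words-distsim.tagger");

        // Tokenize and split the raw text into sentences.
        DocumentPreprocessor tokenizer = new DocumentPreprocessor(
            new StringReader("The quick brown fox jumps over the lazy dog."));
        for (List<HasWord> sentence : tokenizer) {
          List<TaggedWord> tagged = tagger.tagSentence(sentence); // the SR parser requires POS tags
          Tree tree = parser.apply(tagged);                       // constituency parse tree
          System.out.println(tree);
        }
      }
    }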

// Stanford Parser -- a probabilistic lexicalized NL CFG parser
// Copyright (c) 2002 - 2014 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see http://www.gnu.org/licenses/ .
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 2A
//    Stanford CA 94305-9020
//    USA
//    [email protected]
//    https://nlp.stanford.edu/software/srparser.html

package edu.stanford.nlp.parser.shiftreduce;

import java.io.FileFilter;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserGrammar;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.common.ParserUtils;
import edu.stanford.nlp.parser.lexparser.BinaryHeadFinder;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.parser.lexparser.TreeBinarizer;
import edu.stanford.nlp.parser.metrics.ParserQueryEval;
import edu.stanford.nlp.parser.metrics.Eval;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.trees.BasicCategoryTreeTransformer;
import edu.stanford.nlp.trees.CompositeTreeTransformer;
import edu.stanford.nlp.trees.HeadFinder;
import edu.stanford.nlp.trees.LabeledScoredTreeNode;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.trees.TreeCoreAnnotations;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.util.ArrayUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;


/**
 * A shift-reduce constituency parser.
 * Overview and description available at
 * https://nlp.stanford.edu/software/srparser.html
 *
 * @author John Bauer
 */
public class ShiftReduceParser extends ParserGrammar implements Serializable  {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(ShiftReduceParser.class);

  final ShiftReduceOptions op;

  BaseModel model;

  public ShiftReduceParser(ShiftReduceOptions op) {
    this(op, null);
  }

  public ShiftReduceParser(ShiftReduceOptions op, BaseModel model) {
    this.op = op;
    this.model = model;
  }

  /*
  private void readObject(ObjectInputStream in)
    throws IOException, ClassNotFoundException
  {
    ObjectInputStream.GetField fields = in.readFields();
    op = ErasureUtils.uncheckedCast(fields.get("op", null));

    Index<Transition> transitionIndex = ErasureUtils.uncheckedCast(fields.get("transitionIndex", null));
    Set<String> knownStates = ErasureUtils.uncheckedCast(fields.get("knownStates", null));
    Set<String> rootStates = ErasureUtils.uncheckedCast(fields.get("rootStates", null));
    Set<String> rootOnlyStates = ErasureUtils.uncheckedCast(fields.get("rootOnlyStates", null));

    FeatureFactory featureFactory = ErasureUtils.uncheckedCast(fields.get("featureFactory", null));
    Map<String, Weight> featureWeights = ErasureUtils.uncheckedCast(fields.get("featureWeights", null));
    this.model = new PerceptronModel(op, transitionIndex, knownStates, rootStates, rootOnlyStates, featureFactory, featureWeights);
  }
  */

  @Override
  public Options getOp() {
    return op;
  }

  @Override
  public TreebankLangParserParams getTLPParams() {
    return op.tlpParams;
  }

  @Override
  public TreebankLanguagePack treebankLanguagePack() {
    return getTLPParams().treebankLanguagePack();
  }

  private static final String[] BEAM_FLAGS = { "-beamSize", "4" };

  @Override
  public String[] defaultCoreNLPFlags() {
    if (op.trainOptions().beamSize > 1) {
      return ArrayUtils.concatenate(getTLPParams().defaultCoreNLPFlags(), BEAM_FLAGS);
    } else {
      // TODO: this may result in some options which are useless for
      // this model, such as -retainTmpSubcategories
      return getTLPParams().defaultCoreNLPFlags();
    }
  }

  /**
   * Return an unmodifiable Set containing the known states (including binarization)
   */
  public Set<String> knownStates() {
    return Collections.unmodifiableSet(model.knownStates);
  }

  /** Return the Set of POS tags used in the model. */
  public Set<String> tagSet() {
    return model.tagSet();
  }

  @Override
  public boolean requiresTags() {
    return true;
  }

  @Override
  public ParserQuery parserQuery() {
    return new ShiftReduceParserQuery(this);
  }

  @Override
  public Tree parse(String sentence) {
    if (!getOp().testOptions.preTag) {
      throw new UnsupportedOperationException("Can only parse raw text if a tagger is specified, as the ShiftReduceParser cannot produce its own tags");
    }
    return super.parse(sentence);
  }

  @Override
  public Tree parse(List<? extends HasWord> sentence) {
    ShiftReduceParserQuery pq = new ShiftReduceParserQuery(this);
    if (pq.parse(sentence)) {
      return pq.getBestParse();
    }
    return ParserUtils.xTree(sentence);
  }
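
  // For example, a sketch of calling parse with pre-tagged input (variables here
  // are illustrative; any words implementing HasWord and HasTag will do, such as
  // the TaggedWord list produced by a POS tagger):
  //
  //   List<TaggedWord> tagged = tagger.tagSentence(tokens);
  //   Tree tree = parser.parse(tagged);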


  /** TODO: add an eval which measures transition accuracy? */
  @Override
  public List<Eval> getExtraEvals() {
    return Collections.emptyList();
  }

  @Override
  public List<ParserQueryEval> getParserQueryEvals() {
    if (op.testOptions().recordBinarized == null && op.testOptions().recordDebinarized == null) {
      return Collections.emptyList();
    }
    List<ParserQueryEval> evals = Generics.newArrayList();
    if (op.testOptions().recordBinarized != null) {
      evals.add(new TreeRecorder(TreeRecorder.Mode.BINARIZED, op.testOptions().recordBinarized));
    }
    if (op.testOptions().recordDebinarized != null) {
      evals.add(new TreeRecorder(TreeRecorder.Mode.DEBINARIZED, op.testOptions().recordDebinarized));
    }
    return evals;
  }

  public static State initialStateFromGoldTagTree(Tree tree) {
    return initialStateFromTaggedSentence(tree.taggedYield());
  }

  public static State initialStateFromTaggedSentence(List<? extends HasWord> words) {
    List<Tree> preterminals = Generics.newArrayList();
    for (int index = 0; index < words.size(); ++index) {
      HasWord hw = words.get(index);

      CoreLabel wordLabel;
      String tag;
      if (hw instanceof CoreLabel) {
        wordLabel = (CoreLabel) hw;
        tag = wordLabel.tag();
      } else {
        wordLabel = new CoreLabel();
        wordLabel.setValue(hw.word());
        wordLabel.setWord(hw.word());
        if (!(hw instanceof HasTag)) {
          throw new IllegalArgumentException("Expected tagged words");
        }
        tag = ((HasTag) hw).tag();
        wordLabel.setTag(tag);
      }
      if (tag == null) {
        throw new IllegalArgumentException("Input word not tagged");
      }
      CoreLabel tagLabel = new CoreLabel();
      tagLabel.setValue(tag);

      // Index from 1.  Tools downstream from the parser expect that.
      // Internally this parser uses the index, so we have to
      // overwrite incorrect indices if the label is already indexed.
      wordLabel.setIndex(index + 1);
      tagLabel.setIndex(index + 1);

      LabeledScoredTreeNode wordNode = new LabeledScoredTreeNode(wordLabel);
      LabeledScoredTreeNode tagNode = new LabeledScoredTreeNode(tagLabel);
      tagNode.addChild(wordNode);

      // TODO: can we get away with not setting these on the wordLabel?
      wordLabel.set(TreeCoreAnnotations.HeadWordLabelAnnotation.class, wordLabel);
      wordLabel.set(TreeCoreAnnotations.HeadTagLabelAnnotation.class, tagLabel);
      tagLabel.set(TreeCoreAnnotations.HeadWordLabelAnnotation.class, wordLabel);
      tagLabel.set(TreeCoreAnnotations.HeadTagLabelAnnotation.class, tagLabel);

      preterminals.add(tagNode);
    }
    return new State(preterminals);
  }
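
  // For example, a sketch with hand-built tags (values purely illustrative):
  //
  //   List<TaggedWord> sentence = Arrays.asList(
  //       new TaggedWord("The", "DT"), new TaggedWord("dog", "NN"), new TaggedWord("barks", "VBZ"));
  //   State start = initialStateFromTaggedSentence(sentence);
  //
  // Each word becomes a preterminal tree (tag over word), indexed from 1, and the
  // resulting State queues those preterminals for shift-reduce parsing.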

  public static ShiftReduceOptions buildTrainingOptions(String tlppClass, String[] args) {
    ShiftReduceOptions op = new ShiftReduceOptions();
    op.setOptions("-forceTags", "-debugOutputFrequency", "1", "-quietEvaluation");
    if (tlppClass != null) {
      op.tlpParams = ReflectionLoading.loadByReflection(tlppClass);
    }
    op.setOptions(args);

    if (op.trainOptions.randomSeed == 0) {
      op.trainOptions.randomSeed = System.nanoTime();
      log.info("Random seed not set by options, using " + op.trainOptions.randomSeed);
    }
    return op;
  }
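
  // For example, a sketch of setting up English training options (the tlpp class
  // name is the standard English TreebankLangParserParams; the remaining flags
  // would come from the training command line):
  //
  //   ShiftReduceOptions op = buildTrainingOptions(
  //       "edu.stanford.nlp.parser.lexparser.EnglishTreebankParserParams", args);
  //   ShiftReduceParser parser = new ShiftReduceParser(op);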

  public Treebank readTreebank(String treebankPath, FileFilter treebankFilter) {
    log.info("Loading trees from " + treebankPath);
    Treebank treebank = op.tlpParams.memoryTreebank();
    treebank.loadPath(treebankPath, treebankFilter);
    log.info("Read in " + treebank.size() + " trees from " + treebankPath);
    return treebank;
  }

  public List<Tree> readBinarizedTreebank(String treebankPath, FileFilter treebankFilter) {
    Treebank treebank = readTreebank(treebankPath, treebankFilter);
    List<Tree> binarized = binarizeTreebank(treebank, op);
    log.info("Converted trees to binarized format");
    return binarized;
  }

  public static List<Tree> binarizeTreebank(Treebank treebank, Options op) {
    TreeBinarizer binarizer = TreeBinarizer.simpleTreeBinarizer(op.tlpParams.headFinder(), op.tlpParams.treebankLanguagePack());
    BasicCategoryTreeTransformer basicTransformer = new BasicCategoryTreeTransformer(op.langpack());
    CompositeTreeTransformer transformer = new CompositeTreeTransformer();
    transformer.addTransformer(binarizer);
    transformer.addTransformer(basicTransformer);

    treebank = treebank.transform(transformer);

    HeadFinder binaryHeadFinder = new BinaryHeadFinder(op.tlpParams.headFinder());
    List<Tree> binarizedTrees = Generics.newArrayList();
    for (Tree tree : treebank) {
      Trees.convertToCoreLabels(tree);
      tree.percolateHeadAnnotations(binaryHeadFinder);
      // Index from 1.  Tools downstream expect indices from 1, and the
      // parser relies on these indices internally, so we index the
      // leaves from 1 here.
      tree.indexLeaves(1, true);
      binarizedTrees.add(tree);
    }
    return binarizedTrees;
  }

  public static Set<String> findKnownStates(List<Tree> binarizedTrees) {
    Set<String> knownStates = Generics.newHashSet();
    for (Tree tree : binarizedTrees) {
      findKnownStates(tree, knownStates);
    }
    return Collections.unmodifiableSet(knownStates);
  }

  public static void findKnownStates(Tree tree, Set<String> knownStates) {
    if (tree.isLeaf() || tree.isPreTerminal()) {
      return;
    }
    if (!ShiftReduceUtils.isTemporary(tree)) {
      knownStates.add(tree.value());
    }
    for (Tree child : tree.children()) {
      findKnownStates(child, knownStates);
    }
  }


  // TODO: factor out the retagging?
  public static void redoTags(Tree tree, Tagger tagger) {
    List<Word> words = tree.yieldWords();
    List<TaggedWord> tagged = tagger.apply(words);
    List



