All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.parser.dvparser.CacheParseHypotheses Maven / Gradle / Ivy

Go to download

Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.

There is a newer version: 3.9.2
Show newest version
package edu.stanford.nlp.parser.dvparser;

import java.io.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.trees.BasicCategoryTreeTransformer;
import edu.stanford.nlp.trees.LabeledScoredTreeReaderFactory;
import edu.stanford.nlp.trees.SynchronizedTreeTransformer;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreeNormalizer;
import edu.stanford.nlp.trees.TreeReaderFactory;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.ErasureUtils;
import java.util.function.Predicate;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;

public class CacheParseHypotheses {

  static final TreeReaderFactory trf = new LabeledScoredTreeReaderFactory(CoreLabel.factory(), new TreeNormalizer());

  final BasicCategoryTreeTransformer treeBasicCategories;
  final public Predicate treeFilter;

  public CacheParseHypotheses(LexicalizedParser parser) {
    treeBasicCategories = new BasicCategoryTreeTransformer(parser.treebankLanguagePack());
    treeFilter = new FilterConfusingRules(parser);
  }

  public byte[] convertToBytes(List input) {
    try {
      ByteArrayOutputStream bos = new ByteArrayOutputStream();
      GZIPOutputStream gos = new GZIPOutputStream(bos);
      ObjectOutputStream oos = new ObjectOutputStream(gos);
      List transformed = CollectionUtils.transformAsList(input, treeBasicCategories);
      List filtered = CollectionUtils.filterAsList(transformed, treeFilter);
      oos.writeObject(filtered.size());
      for (Tree tree : filtered) {
        oos.writeObject(tree.toString());
      }
      oos.close();
      gos.close();
      bos.close();
      return bos.toByteArray();
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  }

  public IdentityHashMap convertToBytes(IdentityHashMap> uncompressed) {
    IdentityHashMap compressed = Generics.newIdentityHashMap();
    for (Map.Entry> entry : uncompressed.entrySet()) {
      compressed.put(entry.getKey(), convertToBytes(entry.getValue()));
    }
    return compressed;
  }

  public static List convertToTrees(byte[] input) {
    try {
      List output = new ArrayList<>();
      ByteArrayInputStream bis = new ByteArrayInputStream(input);
      GZIPInputStream gis = new GZIPInputStream(bis);
      ObjectInputStream ois = new ObjectInputStream(gis);
      int size = ErasureUtils.uncheckedCast(ois.readObject());
      for (int i = 0; i < size; ++i) {
        String rawTree = ErasureUtils.uncheckedCast(ois.readObject());
        Tree tree = Tree.valueOf(rawTree, trf);
        tree.setSpans();
        output.add(tree);
      }
      ois.close();
      gis.close();
      bis.close();
      return output;
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    } catch (ClassNotFoundException e) {
      throw new RuntimeException(e);
    }
  }

  public static IdentityHashMap> convertToTrees(IdentityHashMap compressed, int numThreads) {
    return convertToTrees(compressed.keySet(), compressed, numThreads);
  }

  static class DecompressionProcessor implements ThreadsafeProcessor> {
    @Override
    public List process(byte[] input) {
      return convertToTrees(input);
    }

    @Override
    public ThreadsafeProcessor> newInstance() {
      // should be threadsafe
      return this;
    }
  }

  public static IdentityHashMap> convertToTrees(Collection keys, IdentityHashMap compressed,
                                                                 int numThreads) {
    IdentityHashMap> uncompressed = Generics.newIdentityHashMap();
    MulticoreWrapper> wrapper = new MulticoreWrapper<>(numThreads, new DecompressionProcessor());
    for (Tree tree : keys) {
      wrapper.put(compressed.get(tree));
    }
    for (Tree tree : keys) {
      if (!wrapper.peek()) {
        wrapper.join();
      }
      uncompressed.put(tree, wrapper.poll());
    }
    return uncompressed;
  }


  static class CacheProcessor implements ThreadsafeProcessor> {
    CacheParseHypotheses cacher;
    LexicalizedParser parser;
    int dvKBest;
    TreeTransformer transformer;

    public CacheProcessor(CacheParseHypotheses cacher, LexicalizedParser parser, int dvKBest, TreeTransformer transformer) {
      this.cacher = cacher;
      this.parser = parser;
      this.dvKBest = dvKBest;
      this.transformer = transformer;
    }

    @Override
    public Pair process(Tree tree) {
      List topParses = DVParser.getTopParsesForOneTree(parser, dvKBest, tree, transformer);
      // this block is a test to make sure the conversion code is working...
      List converted = cacher.convertToTrees(cacher.convertToBytes(topParses));
      List simplified = CollectionUtils.transformAsList(topParses, cacher.treeBasicCategories);
      simplified = CollectionUtils.filterAsList(simplified, cacher.treeFilter);
      if (simplified.size() != topParses.size()) {
        System.err.println("Filtered " + (topParses.size() - simplified.size()) + " trees");
        if (simplified.size() == 0) {
          System.err.println(" WARNING: filtered all trees for " + tree);
        }
      }
      if (!simplified.equals(converted)) {
        if (converted.size() != simplified.size()) {
          throw new AssertionError("horrible error: tree sizes not equal, " + converted.size() + " vs " + simplified.size());
        }
        for (int i = 0; i < converted.size(); ++i) {
          if (!simplified.get(i).equals(converted.get(i))) {
            System.out.println("=============================");
            System.out.println(simplified.get(i));
            System.out.println("=============================");
            System.out.println(converted.get(i));
            System.out.println("=============================");
            throw new AssertionError("horrible error: tree " + i + " not equal for base tree " + tree);
          }
        }
      }
      return Pair.makePair(tree, cacher.convertToBytes(topParses));
    }

    @Override
    public ThreadsafeProcessor> newInstance() {
      // should be threadsafe
      return this;
    }
  }


  /**
   * An example of a command line is
   * 
* java -mx1g edu.stanford.nlp.parser.dvparser.CacheParseHypotheses -model /scr/horatio/dvparser/wsjPCFG.nocompact.simple.ser.gz -output cached9.simple.ser.gz -treebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-202 *
* java -mx4g edu.stanford.nlp.parser.dvparser.CacheParseHypotheses -model ~/scr/dvparser/wsjPCFG.nocompact.simple.ser.gz -output cached.train.simple.ser.gz -treebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-2199 -numThreads 6 *
* java -mx4g edu.stanford.nlp.parser.dvparser.CacheParseHypotheses -model ~/scr/dvparser/chinese/xinhuaPCFG.ser.gz -output cached.xinhua.train.ser.gz -treebank /afs/ir/data/linguistic-data/Chinese-Treebank/6/data/utf8/bracketed 026-270,301-499,600-999 */ public static void main(String[] args) throws IOException { String parserModel = null; String output = null; List> treebanks = Generics.newArrayList(); int dvKBest = 200; int numThreads = 1; for (int argIndex = 0; argIndex < args.length; ) { if (args[argIndex].equalsIgnoreCase("-dvKBest")) { dvKBest = Integer.valueOf(args[argIndex + 1]); argIndex += 2; continue; } if (args[argIndex].equalsIgnoreCase("-parser") || args[argIndex].equals("-model")) { parserModel = args[argIndex + 1]; argIndex += 2; continue; } if (args[argIndex].equalsIgnoreCase("-output")) { output = args[argIndex + 1]; argIndex += 2; continue; } if (args[argIndex].equalsIgnoreCase("-treebank")) { Pair treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-treebank"); argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1; treebanks.add(treebankDescription); continue; } if (args[argIndex].equalsIgnoreCase("-numThreads")) { numThreads = Integer.valueOf(args[argIndex + 1]); argIndex += 2; continue; } throw new IllegalArgumentException("Unknown argument " + args[argIndex]); } if (parserModel == null) { throw new IllegalArgumentException("Need to supply a parser model with -model"); } if (output == null) { throw new IllegalArgumentException("Need to supply an output filename with -output"); } if (treebanks.size() == 0) { throw new IllegalArgumentException("Need to supply a treebank with -treebank"); } System.err.println("Writing output to " + output); System.err.println("Loading parser model " + parserModel); System.err.println("Writing " + dvKBest + " hypothesis trees for each tree"); LexicalizedParser parser = LexicalizedParser.loadModel(parserModel, "-dvKBest", Integer.toString(dvKBest)); CacheParseHypotheses cacher = new 
CacheParseHypotheses(parser); TreeTransformer transformer = DVParser.buildTrainTransformer(parser.getOp()); List sentences = new ArrayList<>(); for (Pair description : treebanks) { System.err.println("Reading trees from " + description.first); Treebank treebank = parser.getOp().tlpParams.memoryTreebank(); treebank.loadPath(description.first, description.second); treebank = treebank.transform(transformer); sentences.addAll(treebank); } System.err.println("Processing " + sentences.size() + " trees"); List> cache = Generics.newArrayList(); transformer = new SynchronizedTreeTransformer(transformer); MulticoreWrapper> wrapper = new MulticoreWrapper<>(numThreads, new CacheProcessor(cacher, parser, dvKBest, transformer)); for (Tree tree : sentences) { wrapper.put(tree); while (wrapper.peek()) { cache.add(wrapper.poll()); if (cache.size() % 10 == 0) { System.out.println("Processed " + cache.size() + " trees"); } } } wrapper.join(); while (wrapper.peek()) { cache.add(wrapper.poll()); if (cache.size() % 10 == 0) { System.out.println("Processed " + cache.size() + " trees"); } } System.out.println("Finished processing " + cache.size() + " trees"); IOUtils.writeObjectToFile(cache, output); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy