All downloads are free. The search and download features use the official Maven repository.

edu.stanford.nlp.parser.nndep.DependencyParser Maven / Gradle / Ivy

Go to download

Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.

There is a newer version: 3.9.2
Show newest version
package edu.stanford.nlp.parser.nndep;

import edu.stanford.nlp.international.Language;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.process.DocumentPreprocessor;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.tagger.maxent.MaxentTagger;
import edu.stanford.nlp.trees.EnglishGrammaticalRelations;
import edu.stanford.nlp.trees.EnglishGrammaticalStructure;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.trees.GrammaticalStructure;
import edu.stanford.nlp.trees.TreeGraphNode;
import edu.stanford.nlp.trees.TypedDependency;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalRelations;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalStructure;
import edu.stanford.nlp.trees.international.pennchinese.ChineseGrammaticalRelations;
import edu.stanford.nlp.trees.international.pennchinese.ChineseGrammaticalStructure;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.util.logging.Redwood;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.*;

import static java.util.stream.Collectors.toList;

/**
 * This class defines a transition-based dependency parser which makes
 * use of a classifier powered by a neural network. The neural network
 * accepts distributed representation inputs: dense, continuous
 * representations of words, their part of speech tags, and the labels
 * which connect words in a partial dependency parse.
 *
 * 

* This is an implementation of the method described in * *

* Danqi Chen and Christopher Manning. A Fast and Accurate Dependency * Parser Using Neural Networks. In EMNLP 2014. *
* *

* The parser can also be used from the command line to train models and to parse text. * New models can be trained from the command line; see the {@link #main} method * for details on training options. The parser can parse either plain text files or * CoNLL-X format files and output * CoNLL-X format predictions; again see {@link #main} for available options. * (The options available for things like tokenization and sentence splitting * in this class are not as extensive as and not necessarily consistent with * the options of other classes like {@code LexicalizedParser} and {@code StanfordCoreNLP}. * *

* This parser can also be used programmatically. The easiest way to * prepare the parser with a pre-trained model is to call * {@link #loadFromModelFile(String)}. Then call * {@link #predict(edu.stanford.nlp.util.CoreMap)} on the returned * parser instance in order to get new parses. * * @author Danqi Chen ([email protected]) * @author Jon Gauthier */ public class DependencyParser { /** A logger for this class */ private static final Redwood.RedwoodChannels log = Redwood.channels(DependencyParser.class); public static final String DEFAULT_MODEL = "edu/stanford/nlp/models/parser/nndep/english_UD.gz"; /** * Words, parts of speech, and dependency relation labels which were * observed in our corpus / stored in the model * * @see #genDictionaries(java.util.List, java.util.List) */ private List knownWords, knownPos, knownLabels; /** Return the set of part-of-speech tags of this parser. We normalize it a bit to help it match what * other parsers use. * * @return Set of POS tags */ public Set getPosSet() { Set foo = Generics.newHashSet(knownPos); // Don't really understand why these ones are there, but remove them. [CDM 2016] foo.remove("-NULL-"); foo.remove("-UNKNOWN-"); foo.remove("-ROOT-"); // but our other models do include an EOS tag foo.add(".$$."); return Collections.unmodifiableSet(foo); } /** * Mapping from word / POS / dependency relation label to integer ID */ private Map wordIDs, posIDs, labelIDs; private List preComputed; /** * Given a particular parser configuration, this classifier will * predict the best transition to make next. * * The {@link edu.stanford.nlp.parser.nndep.Classifier} class * handles both training and inference. */ private Classifier classifier; private ParsingSystem system; private final Config config; /** * Language used to generate * {@link edu.stanford.nlp.trees.GrammaticalRelation} instances. 
*/ private final Language language; DependencyParser() { this(new Properties()); } public DependencyParser(Properties properties) { config = new Config(properties); // Convert Languages.Language instance to // GrammaticalLanguage.Language this.language = config.language; } /** * Get an integer ID for the given word. This ID can be used to index * into the embeddings {@link Classifier#E}. * * @return An ID for the given word, or an ID referring to a generic * "unknown" word if the word is unknown */ public int getWordID(String s) { return wordIDs.containsKey(s) ? wordIDs.get(s) : wordIDs.get(Config.UNKNOWN); } public int getPosID(String s) { return posIDs.containsKey(s) ? posIDs.get(s) : posIDs.get(Config.UNKNOWN); } public int getLabelID(String s) { return labelIDs.get(s); } public List getFeatures(Configuration c) { // Presize the arrays for very slight speed gain. Hardcoded, but so is the current feature list. List fWord = new ArrayList<>(18); List fPos = new ArrayList<>(18); List fLabel = new ArrayList<>(12); for (int j = 2; j >= 0; --j) { int index = c.getStack(j); fWord.add(getWordID(c.getWord(index))); fPos.add(getPosID(c.getPOS(index))); } for (int j = 0; j <= 2; ++j) { int index = c.getBuffer(j); fWord.add(getWordID(c.getWord(index))); fPos.add(getPosID(c.getPOS(index))); } for (int j = 0; j <= 1; ++j) { int k = c.getStack(j); int index = c.getLeftChild(k); fWord.add(getWordID(c.getWord(index))); fPos.add(getPosID(c.getPOS(index))); fLabel.add(getLabelID(c.getLabel(index))); index = c.getRightChild(k); fWord.add(getWordID(c.getWord(index))); fPos.add(getPosID(c.getPOS(index))); fLabel.add(getLabelID(c.getLabel(index))); index = c.getLeftChild(k, 2); fWord.add(getWordID(c.getWord(index))); fPos.add(getPosID(c.getPOS(index))); fLabel.add(getLabelID(c.getLabel(index))); index = c.getRightChild(k, 2); fWord.add(getWordID(c.getWord(index))); fPos.add(getPosID(c.getPOS(index))); fLabel.add(getLabelID(c.getLabel(index))); index = c.getLeftChild(c.getLeftChild(k)); 
fWord.add(getWordID(c.getWord(index))); fPos.add(getPosID(c.getPOS(index))); fLabel.add(getLabelID(c.getLabel(index))); index = c.getRightChild(c.getRightChild(k)); fWord.add(getWordID(c.getWord(index))); fPos.add(getPosID(c.getPOS(index))); fLabel.add(getLabelID(c.getLabel(index))); } List feature = new ArrayList<>(48); feature.addAll(fWord); feature.addAll(fPos); feature.addAll(fLabel); return feature; } private static final int POS_OFFSET = 18; private static final int DEP_OFFSET = 36; private static final int STACK_OFFSET = 6; private static final int STACK_NUMBER = 6; private int[] getFeatureArray(Configuration c) { int[] feature = new int[config.numTokens]; // positions 0-17 hold fWord, 18-35 hold fPos, 36-47 hold fLabel for (int j = 2; j >= 0; --j) { int index = c.getStack(j); feature[2-j] = getWordID(c.getWord(index)); feature[POS_OFFSET + (2-j)] = getPosID(c.getPOS(index)); } for (int j = 0; j <= 2; ++j) { int index = c.getBuffer(j); feature[3 + j] = getWordID(c.getWord(index)); feature[POS_OFFSET + 3 + j] = getPosID(c.getPOS(index)); } for (int j = 0; j <= 1; ++j) { int k = c.getStack(j); int index = c.getLeftChild(k); feature[STACK_OFFSET + j * STACK_NUMBER] = getWordID(c.getWord(index)); feature[POS_OFFSET + STACK_OFFSET + j * STACK_NUMBER] = getPosID(c.getPOS(index)); feature[DEP_OFFSET + j * STACK_NUMBER] = getLabelID(c.getLabel(index)); index = c.getRightChild(k); feature[STACK_OFFSET + j * STACK_NUMBER + 1] = getWordID(c.getWord(index)); feature[POS_OFFSET + STACK_OFFSET + j * STACK_NUMBER + 1] = getPosID(c.getPOS(index)); feature[DEP_OFFSET + j * STACK_NUMBER + 1] = getLabelID(c.getLabel(index)); index = c.getLeftChild(k, 2); feature[STACK_OFFSET + j * STACK_NUMBER + 2] = getWordID(c.getWord(index)); feature[POS_OFFSET + STACK_OFFSET + j * STACK_NUMBER + 2] = getPosID(c.getPOS(index)); feature[DEP_OFFSET + j * STACK_NUMBER + 2] = getLabelID(c.getLabel(index)); index = c.getRightChild(k, 2); feature[STACK_OFFSET + j * STACK_NUMBER + 3] = 
getWordID(c.getWord(index)); feature[POS_OFFSET + STACK_OFFSET + j * STACK_NUMBER + 3] = getPosID(c.getPOS(index)); feature[DEP_OFFSET + j * STACK_NUMBER + 3] = getLabelID(c.getLabel(index)); index = c.getLeftChild(c.getLeftChild(k)); feature[STACK_OFFSET + j * STACK_NUMBER + 4] = getWordID(c.getWord(index)); feature[POS_OFFSET + STACK_OFFSET + j * STACK_NUMBER + 4] = getPosID(c.getPOS(index)); feature[DEP_OFFSET + j * STACK_NUMBER + 4] = getLabelID(c.getLabel(index)); index = c.getRightChild(c.getRightChild(k)); feature[STACK_OFFSET + j * STACK_NUMBER + 5] = getWordID(c.getWord(index)); feature[POS_OFFSET + STACK_OFFSET + j * STACK_NUMBER + 5] = getPosID(c.getPOS(index)); feature[DEP_OFFSET + j * STACK_NUMBER + 5] = getLabelID(c.getLabel(index)); } return feature; } public Dataset genTrainExamples(List sents, List trees) { int numTrans = system.numTransitions(); Dataset ret = new Dataset(config.numTokens, numTrans); Counter tokPosCount = new IntCounter<>(); log.info(Config.SEPARATOR); log.info("Generate training examples..."); for (int i = 0; i < sents.size(); ++i) { if (i > 0) { if (i % 1000 == 0) log.info(i + " "); if (i % 10000 == 0 || i == sents.size() - 1) log.info(); } if (trees.get(i).isProjective()) { Configuration c = system.initialConfiguration(sents.get(i)); while (!system.isTerminal(c)) { String oracle = system.getOracle(c, trees.get(i)); List feature = getFeatures(c); List label = new ArrayList<>(); for (int j = 0; j < numTrans; ++j) { String str = system.transitions.get(j); if (str.equals(oracle)) label.add(1); else if (system.canApply(c, str)) label.add(0); else label.add(-1); } ret.addExample(feature, label); for (int j = 0; j < feature.size(); ++j) tokPosCount.incrementCount(feature.get(j) * feature.size() + j); system.apply(c, oracle); } } } log.info("#Train Examples: " + ret.n); List sortedTokens = Counters.toSortedList(tokPosCount, false); preComputed = new ArrayList<>(sortedTokens.subList(0, Math.min(config.numPreComputed, 
sortedTokens.size()))); return ret; } /** * Generate unique integer IDs for all known words / part-of-speech * tags / dependency relation labels. * * All three of the aforementioned types are assigned IDs from a * continuous range of integers; all IDs 0 <= ID < n_w are word IDs, * all IDs n_w <= ID < n_w + n_pos are POS tag IDs, and so on. */ private void generateIDs() { wordIDs = new HashMap<>(); posIDs = new HashMap<>(); labelIDs = new HashMap<>(); int index = 0; for (String word : knownWords) wordIDs.put(word, (index++)); for (String pos : knownPos) posIDs.put(pos, (index++)); for (String label : knownLabels) labelIDs.put(label, (index++)); } /** * Scan a corpus and store all words, part-of-speech tags, and * dependency relation labels observed. Prepare other structures * which support word / POS / label lookup at train- / run-time. */ private void genDictionaries(List sents, List trees) { // Collect all words (!), etc. in lists, tacking on one sentence // after the other List word = new ArrayList<>(); List pos = new ArrayList<>(); List label = new ArrayList<>(); for (CoreMap sentence : sents) { List tokens = sentence.get(CoreAnnotations.TokensAnnotation.class); for (CoreLabel token : tokens) { word.add(token.word()); pos.add(token.tag()); } } String rootLabel = null; for (DependencyTree tree : trees) for (int k = 1; k <= tree.n; ++k) if (tree.getHead(k) == 0) rootLabel = tree.getLabel(k); else label.add(tree.getLabel(k)); // Generate "dictionaries," possibly with frequency cutoff knownWords = Util.generateDict(word, config.wordCutOff); knownPos = Util.generateDict(pos); knownLabels = Util.generateDict(label); knownLabels.add(0, rootLabel); // Avoid the case that rootLabel equals to one of the other labels for (int k = 1; k < knownLabels.size(); ++ k) if (knownLabels.get(k).equals(rootLabel)) { knownLabels.remove(k); break; } knownWords.add(0, Config.UNKNOWN); knownWords.add(1, Config.NULL); knownWords.add(2, Config.ROOT); knownPos.add(0, Config.UNKNOWN); 
knownPos.add(1, Config.NULL); knownPos.add(2, Config.ROOT); knownLabels.add(0, Config.NULL); generateIDs(); log.info(Config.SEPARATOR); log.info("#Word: " + knownWords.size()); log.info("#POS:" + knownPos.size()); log.info("#Label: " + knownLabels.size()); } public void writeModelFile(String modelFile) { try { double[][] W1 = classifier.getW1(); double[] b1 = classifier.getb1(); double[][] W2 = classifier.getW2(); double[][] E = classifier.getE(); Writer output = IOUtils.getPrintWriter(modelFile); output.write("language=" + language.toString() + "\n"); output.write("tlp=" + config.tlp.getClass().getCanonicalName() + "\n"); output.write("dict=" + knownWords.size() + "\n"); output.write("pos=" + knownPos.size() + "\n"); output.write("label=" + knownLabels.size() + "\n"); output.write("embeddingSize=" + E[0].length + "\n"); output.write("hiddenSize=" + b1.length + "\n"); output.write("numTokens=" + (W1[0].length / E[0].length) + "\n"); output.write("preComputed=" + preComputed.size() + "\n"); int index = 0; // First write word / POS / label embeddings for (String word : knownWords) { index = writeEmbedding(E[index], output, index, word); } for (String pos : knownPos) { index = writeEmbedding(E[index], output, index, pos); } for (String label : knownLabels) { index = writeEmbedding(E[index], output, index, label); } // Now write classifier weights for (int j = 0; j < W1[0].length; ++j) for (int i = 0; i < W1.length; ++i) { output.write(String.valueOf(W1[i][j])); if (i == W1.length - 1) output.write("\n"); else output.write(" "); } for (int i = 0; i < b1.length; ++i) { output.write(String.valueOf(b1[i])); if (i == b1.length - 1) output.write("\n"); else output.write(" "); } for (int j = 0; j < W2[0].length; ++j) for (int i = 0; i < W2.length; ++i) { output.write(String.valueOf(W2[i][j])); if (i == W2.length - 1) output.write("\n"); else output.write(" "); } // Finish with pre-computation info for (int i = 0; i < preComputed.size(); ++i) { 
output.write(String.valueOf(preComputed.get(i))); if ((i + 1) % 100 == 0 || i == preComputed.size() - 1) output.write("\n"); else output.write(" "); } output.close(); } catch (IOException e) { throw new RuntimeIOException(e); } } private static int writeEmbedding(double[] doubles, Writer output, int index, String word) throws IOException { output.write(word); for (double aDouble : doubles) { output.write(" " + aDouble); } output.write("\n"); index = index + 1; return index; } /** * Convenience method; see {@link #loadFromModelFile(String, java.util.Properties)}. * * @see #loadFromModelFile(String, java.util.Properties) */ public static DependencyParser loadFromModelFile(String modelFile) { return loadFromModelFile(modelFile, null); } /** * Load a saved parser model. * * @param modelFile Path to serialized model (may be GZipped) * @param extraProperties Extra test-time properties not already associated with model (may be null) * * @return Loaded and initialized (see {@link #initialize(boolean)} model */ public static DependencyParser loadFromModelFile(String modelFile, Properties extraProperties) { DependencyParser parser = extraProperties == null ? new DependencyParser() : new DependencyParser(extraProperties); parser.loadModelFile(modelFile, false); return parser; } /** Load a parser model file, printing out some messages about the grammar in the file. * * @param modelFile The file (classpath resource, etc.) to load the model from. 
*/ public void loadModelFile(String modelFile) { loadModelFile(modelFile, true); } /** helper to check if the model file is new format or not * * @param firstLine the first line of the model file * @return true if this is a new format model file */ private static boolean isModelNewFormat(String firstLine) { return firstLine.startsWith("language="); } private void loadModelFile(String modelFile, boolean verbose) { Timing t = new Timing(); try (BufferedReader input = IOUtils.readerFromString(modelFile)) { log.info("Loading depparse model: " + modelFile + " ... "); String s; // first line in newer saved models is language, legacy models don't store this s = input.readLine(); // check if language was stored if (isModelNewFormat(s)) { // set up language config.language = Config.getLanguage(s.substring(9, s.length() - 1)); // set up tlp s = input.readLine(); String tlpCanonicalName = s.substring(4, s.length()); try { config.tlp = ReflectionLoading.loadByReflection(tlpCanonicalName); log.info("Loaded TreebankLanguagePack: " + tlpCanonicalName); } catch (Exception e) { log.warn("Error: Failed to load TreebankLanguagePack: " + tlpCanonicalName); } s = input.readLine(); } int nDict = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int nPOS = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int nLabel = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int eSize = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int hSize = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int nTokens = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int nPreComputed = Integer.parseInt(s.substring(s.indexOf('=') + 1)); knownWords = new ArrayList<>(); knownPos = new ArrayList<>(); knownLabels = new ArrayList<>(); double[][] E = new double[nDict + nPOS + nLabel][eSize]; String[] splits; int index = 0; for (int k = 0; k < nDict; ++k) { s = 
input.readLine(); splits = s.split(" "); knownWords.add(splits[0]); for (int i = 0; i < eSize; ++i) E[index][i] = Double.parseDouble(splits[i + 1]); index = index + 1; } for (int k = 0; k < nPOS; ++k) { s = input.readLine(); splits = s.split(" "); knownPos.add(splits[0]); for (int i = 0; i < eSize; ++i) E[index][i] = Double.parseDouble(splits[i + 1]); index = index + 1; } for (int k = 0; k < nLabel; ++k) { s = input.readLine(); splits = s.split(" "); knownLabels.add(splits[0]); for (int i = 0; i < eSize; ++i) E[index][i] = Double.parseDouble(splits[i + 1]); index = index + 1; } generateIDs(); double[][] W1 = new double[hSize][eSize * nTokens]; for (int j = 0; j < W1[0].length; ++j) { s = input.readLine(); splits = s.split(" "); for (int i = 0; i < W1.length; ++i) W1[i][j] = Double.parseDouble(splits[i]); } double[] b1 = new double[hSize]; s = input.readLine(); splits = s.split(" "); for (int i = 0; i < b1.length; ++i) b1[i] = Double.parseDouble(splits[i]); double[][] W2 = new double[nLabel * 2 - 1][hSize]; for (int j = 0; j < W2[0].length; ++j) { s = input.readLine(); splits = s.split(" "); for (int i = 0; i < W2.length; ++i) W2[i][j] = Double.parseDouble(splits[i]); } preComputed = new ArrayList<>(); while (preComputed.size() < nPreComputed) { s = input.readLine(); splits = s.split(" "); for (String split : splits) { preComputed.add(Integer.parseInt(split)); } } config.hiddenSize = hSize; config.embeddingSize = eSize; classifier = new Classifier(config, E, W1, b1, W2, preComputed); } catch (IOException e) { throw new RuntimeIOException(e); } // initialize the loaded parser initialize(verbose); t.done(log, "Initializing dependency parser"); } // TODO this should be a function which returns the embeddings array + embedID // otherwise the class needlessly carries around the extra baggage of `embeddings` // (never again used) for the entire training process private double[][] readEmbedFile(String embedFile, Map embedID) { double[][] embeddings = null; if (embedFile != 
null) { try (BufferedReader input = IOUtils.readerFromString(embedFile)) { List lines = new ArrayList<>(); for (String s; (s = input.readLine()) != null; ) { lines.add(s); } int nWords = lines.size(); String[] splits = lines.get(0).split("\\s+"); int dim = splits.length - 1; embeddings = new double[nWords][dim]; log.info("Embedding File " + embedFile + ": #Words = " + nWords + ", dim = " + dim); if (dim != config.embeddingSize) throw new IllegalArgumentException("The dimension of embedding file does not match config.embeddingSize"); for (int i = 0; i < lines.size(); ++i) { splits = lines.get(i).split("\\s+"); embedID.put(splits[0], i); for (int j = 0; j < dim; ++j) embeddings[i][j] = Double.parseDouble(splits[j + 1]); } } catch (IOException e) { throw new RuntimeIOException(e); } embeddings = Util.scaling(embeddings, 0, 1.0); } return embeddings; } /** * Train a new dependency parser model. * * @param trainFile Training data * @param devFile Development data (used for regular UAS evaluation * of model) * @param modelFile String to which model should be saved * @param embedFile File containing word embeddings for words used in * training corpus */ public void train(String trainFile, String devFile, String modelFile, String embedFile, String preModel) { log.info("Train File: " + trainFile); log.info("Dev File: " + devFile); log.info("Model File: " + modelFile); log.info("Embedding File: " + embedFile); log.info("Pre-trained Model File: " + preModel); List trainSents = new ArrayList<>(); List trainTrees = new ArrayList<>(); Util.loadConllFile(trainFile, trainSents, trainTrees, config.unlabeled, config.cPOS); Util.printTreeStats("Train", trainTrees); List devSents = new ArrayList<>(); List devTrees = new ArrayList<>(); if (devFile != null) { Util.loadConllFile(devFile, devSents, devTrees, config.unlabeled, config.cPOS); Util.printTreeStats("Dev", devTrees); } genDictionaries(trainSents, trainTrees); //NOTE: remove -NULL-, and the pass it to ParsingSystem List lDict = 
new ArrayList<>(knownLabels); lDict.remove(0); system = new ArcStandard(config.tlp, lDict, true); // Initialize a classifier; prepare for training setupClassifierForTraining(trainSents, trainTrees, embedFile, preModel); log.info(Config.SEPARATOR); config.printParameters(); long startTime = System.currentTimeMillis(); // Track the best UAS performance we've seen. double bestUAS = 0; for (int iter = 0; iter < config.maxIter; ++iter) { log.info("##### Iteration " + iter); Classifier.Cost cost = classifier.computeCostFunction(config.batchSize, config.regParameter, config.dropProb); log.info("Cost = " + cost.getCost() + ", Correct(%) = " + cost.getPercentCorrect()); classifier.takeAdaGradientStep(cost, config.adaAlpha, config.adaEps); log.info("Elapsed Time: " + (System.currentTimeMillis() - startTime) / 1000.0 + " (s)"); // UAS evaluation if (devFile != null && iter % config.evalPerIter == 0) { // Redo precomputation with updated weights. This is only // necessary because we're updating weights -- for normal // prediction, we just do this once in #initialize classifier.preCompute(); List predicted = devSents.stream().map(this::predictInner).collect(toList()); double uas = config.noPunc ? system.getUASnoPunc(devSents, predicted, devTrees) : system.getUAS(devSents, predicted, devTrees); log.info("UAS: " + uas); if (config.saveIntermediate && uas > bestUAS) { log.info("Exceeds best previous UAS of %f. Saving model file.%n", bestUAS); bestUAS = uas; writeModelFile(modelFile); } } // Clear gradients if (config.clearGradientsPerIter > 0 && iter % config.clearGradientsPerIter == 0) { log.info("Clearing gradient histories.."); classifier.clearGradientHistories(); } } classifier.finalizeTraining(); if (devFile != null) { // Do final UAS evaluation and save if final model beats the // best intermediate one List predicted = devSents.stream().map(this::predictInner).collect(toList()); double uas = config.noPunc ? 
system.getUASnoPunc(devSents, predicted, devTrees) : system.getUAS(devSents, predicted, devTrees); if (uas > bestUAS) { log.info(String.format("Final model UAS: %f%n", uas)); log.info(String.format("Exceeds best previous UAS of %f. Saving model file..%n", bestUAS)); writeModelFile(modelFile); } } else { writeModelFile(modelFile); } } /** * @see #train(String, String, String, String, String) */ public void train(String trainFile, String devFile, String modelFile, String embedFile) { train(trainFile, devFile, modelFile, embedFile, null); } /** * @see #train(String, String, String, String) */ public void train(String trainFile, String devFile, String modelFile) { train(trainFile, devFile, modelFile, null); } /** * @see #train(String, String, String) */ public void train(String trainFile, String modelFile) { train(trainFile, null, modelFile); } /** * Prepare a classifier for training with the given dataset. */ private void setupClassifierForTraining(List trainSents, List trainTrees, String embedFile, String preModel) { double[][] E = new double[knownWords.size() + knownPos.size() + knownLabels.size()][config.embeddingSize]; double[][] W1 = new double[config.hiddenSize][config.embeddingSize * config.numTokens]; double[] b1 = new double[config.hiddenSize]; double[][] W2 = new double[system.numTransitions()][config.hiddenSize]; // Randomly initialize weight matrices / vectors Random random = Util.getRandom(); for (int i = 0; i < W1.length; ++i) for (int j = 0; j < W1[i].length; ++j) W1[i][j] = random.nextDouble() * 2 * config.initRange - config.initRange; for (int i = 0; i < b1.length; ++i) b1[i] = random.nextDouble() * 2 * config.initRange - config.initRange; for (int i = 0; i < W2.length; ++i) for (int j = 0; j < W2[i].length; ++j) W2[i][j] = random.nextDouble() * 2 * config.initRange - config.initRange; // Read embeddings into `embedID`, `embeddings` Map embedID = new HashMap<>(); double[][] embeddings = readEmbedFile(embedFile, embedID); // Try to match loaded 
embeddings with words in dictionary int foundEmbed = 0; for (int i = 0; i < E.length; ++i) { int index = -1; if (i < knownWords.size()) { String str = knownWords.get(i); //NOTE: exact match first, and then try lower case.. if (embedID.containsKey(str)) index = embedID.get(str); else if (embedID.containsKey(str.toLowerCase())) index = embedID.get(str.toLowerCase()); } if (index >= 0) { ++foundEmbed; System.arraycopy(embeddings[index], 0, E[i], 0, E[i].length); } else { for (int j = 0; j < E[i].length; ++j) //E[i][j] = random.nextDouble() * config.initRange * 2 - config.initRange; //E[i][j] = random.nextDouble() * 0.2 - 0.1; //E[i][j] = random.nextGaussian() * Math.sqrt(0.1); E[i][j] = random.nextDouble() * 0.02 - 0.01; } } log.info("Found embeddings: " + foundEmbed + " / " + knownWords.size()); if (preModel != null) { try (BufferedReader input = IOUtils.readerFromString(preModel)) { log.info("Loading pre-trained model file: " + preModel + " ... "); String s; s = input.readLine(); int nDict = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int nPOS = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int nLabel = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int eSize = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int hSize = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); int nTokens = Integer.parseInt(s.substring(s.indexOf('=') + 1)); s = input.readLine(); String[] splits; for (int k = 0; k < nDict; ++k) { s = input.readLine(); splits = s.split(" "); if (wordIDs.containsKey(splits[0]) && eSize == config.embeddingSize) { int index = getWordID(splits[0]); for (int i = 0; i < eSize; ++i) E[index][i] = Double.parseDouble(splits[i + 1]); } } for (int k = 0; k < nPOS; ++k) { s = input.readLine(); splits = s.split(" "); if (posIDs.containsKey(splits[0]) && eSize == config.embeddingSize) { int index = getPosID(splits[0]); for (int i = 0; i < eSize; 
++i) E[index][i] = Double.parseDouble(splits[i + 1]); } } for (int k = 0; k < nLabel; ++k) { s = input.readLine(); splits = s.split(" "); if (labelIDs.containsKey(splits[0]) && eSize == config.embeddingSize) { int index = getLabelID(splits[0]); for (int i = 0; i < eSize; ++i) E[index][i] = Double.parseDouble(splits[i + 1]); } } boolean copyLayer1 = hSize == config.hiddenSize && config.embeddingSize == eSize && config.numTokens == nTokens; if (copyLayer1) { log.info("Copying parameters W1 && b1..."); } for (int j = 0; j < eSize * nTokens; ++j) { s = input.readLine(); if (copyLayer1) { splits = s.split(" "); for (int i = 0; i < hSize; ++i) W1[i][j] = Double.parseDouble(splits[i]); } } s = input.readLine(); if (copyLayer1) { splits = s.split(" "); for (int i = 0; i < hSize; ++i) b1[i] = Double.parseDouble(splits[i]); } boolean copyLayer2 = (nLabel * 2 - 1 == system.numTransitions()) && hSize == config.hiddenSize; if (copyLayer2) log.info("Copying parameters W2..."); for (int j = 0; j < hSize; ++j) { s = input.readLine(); if (copyLayer2) { splits = s.split(" "); for (int i = 0; i < nLabel * 2 - 1; ++i) W2[i][j] = Double.parseDouble(splits[i]); } } } catch (IOException e) { throw new RuntimeIOException(e); } } Dataset trainSet = genTrainExamples(trainSents, trainTrees); classifier = new Classifier(config, trainSet, E, W1, b1, W2, preComputed); } /** * Determine the dependency parse of the given sentence. *

* This "inner" method returns a structure unique to this package; use {@link #predict(edu.stanford.nlp.util.CoreMap)} * for general parsing purposes. */ private DependencyTree predictInner(CoreMap sentence) { int numTrans = system.numTransitions(); Configuration c = system.initialConfiguration(sentence); while (!system.isTerminal(c)) { if (Thread.interrupted()) { // Allow interrupting throw new RuntimeInterruptedException(); } double[] scores = classifier.computeScores(getFeatureArray(c)); double optScore = Double.NEGATIVE_INFINITY; String optTrans = null; for (int j = 0; j < numTrans; ++j) { if (scores[j] > optScore && system.canApply(c, system.transitions.get(j))) { optScore = scores[j]; optTrans = system.transitions.get(j); } } system.apply(c, optTrans); } return c.tree; } /** * Determine the dependency parse of the given sentence using the loaded model. * You must first load a parser before calling this method. * * @throws java.lang.IllegalStateException If parser has not yet been loaded and initialized * (see {@link #initialize(boolean)} */ public GrammaticalStructure predict(CoreMap sentence) { if (system == null) throw new IllegalStateException("Parser has not been " + "loaded and initialized; first load a model."); DependencyTree result = predictInner(sentence); // The rest of this method is just busy-work to convert the // package-local representation into a CoreNLP-standard // GrammaticalStructure. List tokens = sentence.get(CoreAnnotations.TokensAnnotation.class); List dependencies = new ArrayList<>(); IndexedWord root = new IndexedWord(new Word("ROOT")); root.set(CoreAnnotations.IndexAnnotation.class, 0); for (int i = 1; i <= result.n; i++) { int head = result.getHead(i); String label = result.getLabel(i); IndexedWord thisWord = new IndexedWord(tokens.get(i - 1)); IndexedWord headWord = head == 0 ? root : new IndexedWord(tokens.get(head - 1)); GrammaticalRelation relation = head == 0 ? 
GrammaticalRelation.ROOT : makeGrammaticalRelation(label); dependencies.add(new TypedDependency(relation, headWord, thisWord)); } // Build GrammaticalStructure // TODO ideally submodule should just return GrammaticalStructure TreeGraphNode rootNode = new TreeGraphNode(root); return makeGrammaticalStructure(dependencies, rootNode); } private GrammaticalRelation makeGrammaticalRelation(String label) { GrammaticalRelation stored; switch (language) { case English: stored = EnglishGrammaticalRelations.shortNameToGRel.get(label); if (stored != null) return stored; break; case UniversalEnglish: stored = UniversalEnglishGrammaticalRelations.shortNameToGRel.get(label); if (stored != null) return stored; break; case Chinese: stored = ChineseGrammaticalRelations.shortNameToGRel.get(label); if (stored != null) return stored; break; } return new GrammaticalRelation(language, label, null, GrammaticalRelation.DEPENDENT); } private GrammaticalStructure makeGrammaticalStructure(List dependencies, TreeGraphNode rootNode) { switch (language) { case English: return new EnglishGrammaticalStructure(dependencies, rootNode); case UniversalEnglish: return new UniversalEnglishGrammaticalStructure(dependencies, rootNode); case Chinese: return new ChineseGrammaticalStructure(dependencies, rootNode); // TODO suboptimal: default to UniversalEnglishGrammaticalStructure return default: return new UniversalEnglishGrammaticalStructure(dependencies, rootNode); } } /** * Convenience method for {@link #predict(edu.stanford.nlp.util.CoreMap)}. The tokens of the provided sentence must * also have tag annotations (the parser requires part-of-speech tags). 
* * @see #predict(edu.stanford.nlp.util.CoreMap) */ public GrammaticalStructure predict(List sentence) { CoreLabel sentenceLabel = new CoreLabel(); List tokens = new ArrayList<>(); int i = 1; for (HasWord wd : sentence) { CoreLabel label; if (wd instanceof CoreLabel) { label = (CoreLabel) wd; if (label.tag() == null) throw new IllegalArgumentException("Parser requires words " + "with part-of-speech tag annotations"); } else { label = new CoreLabel(); label.setValue(wd.word()); label.setWord(wd.word()); if (!(wd instanceof HasTag)) throw new IllegalArgumentException("Parser requires words " + "with part-of-speech tag annotations"); label.setTag(((HasTag) wd).tag()); } label.setIndex(i); i++; tokens.add(label); } sentenceLabel.set(CoreAnnotations.TokensAnnotation.class, tokens); return predict(sentenceLabel); } //TODO: support sentence-only files as input /** Run the parser in the modelFile on a testFile and perhaps save output. * * @param testFile File to parse. In CoNLL-X format. Assumed to have gold answers included. * @param outFile File to write results to in CoNLL-X format. 
If null, no output is written * @return The LAS score on the dataset */ public double testCoNLL(String testFile, String outFile) { log.info("Test File: " + testFile); Timing timer = new Timing(); List testSents = new ArrayList<>(); List testTrees = new ArrayList<>(); Util.loadConllFile(testFile, testSents, testTrees, config.unlabeled, config.cPOS); // count how much to parse int numWords = 0; int numOOVWords = 0; int numSentences = 0; for (CoreMap testSent : testSents) { numSentences += 1; List tokens = testSent.get(CoreAnnotations.TokensAnnotation.class); for (CoreLabel token : tokens) { String word = token.word(); numWords += 1; if (!wordIDs.containsKey(word)) numOOVWords += 1; } } log.info(String.format("OOV Words: %d / %d = %.2f%%\n", numOOVWords, numWords, numOOVWords * 100.0 / numWords)); List predicted = testSents.stream().map(this::predictInner).collect(toList()); Map result = system.evaluate(testSents, predicted, testTrees); double uas = config.noPunc ? result.get("UASnoPunc") : result.get("UAS"); double las = config.noPunc ? 
result.get("LASnoPunc") : result.get("LAS"); log.info(String.format("UAS = %.4f%n", uas)); log.info(String.format("LAS = %.4f%n", las)); long millis = timer.stop(); double wordspersec = numWords / (((double) millis) / 1000); double sentspersec = numSentences / (((double) millis) / 1000); log.info(String.format("%s parsed %d words in %d sentences in %.1fs at %.1f w/s, %.1f sent/s.%n", StringUtils.getShortClassName(this), numWords, numSentences, millis / 1000.0, wordspersec, sentspersec)); if (outFile != null) { Util.writeConllFile(outFile, testSents, predicted); } return las; } private void parseTextFile(BufferedReader input, PrintWriter output) { DocumentPreprocessor preprocessor = new DocumentPreprocessor(input); preprocessor.setSentenceFinalPuncWords(config.tlp.sentenceFinalPunctuationWords()); preprocessor.setEscaper(config.escaper); preprocessor.setSentenceDelimiter(config.sentenceDelimiter); preprocessor.setTokenizerFactory(config.tlp.getTokenizerFactory()); Timing timer = new Timing(); MaxentTagger tagger = new MaxentTagger(config.tagger); List> tagged = new ArrayList<>(); for (List sentence : preprocessor) { tagged.add(tagger.tagSentence(sentence)); } log.info(String.format("Tagging completed in %.2f sec.%n", timer.stop() / 1000.0)); timer.start(); int numSentences = 0; for (List taggedSentence : tagged) { GrammaticalStructure parse = predict(taggedSentence); Collection deps = parse.typedDependencies(); for (TypedDependency dep : deps) output.println(dep); output.println(); numSentences++; } long millis = timer.stop(); double seconds = millis / 1000.0; log.info(String.format("Parsed %d sentences in %.2f seconds (%.2f sents/sec).%n", numSentences, seconds, numSentences / seconds)); } /** * Prepare for parsing after a model has been loaded. 
*/ private void initialize(boolean verbose) { if (knownLabels == null) throw new IllegalStateException("Model has not been loaded or trained"); // NOTE: remove -NULL-, and then pass the label set to the ParsingSystem List lDict = new ArrayList<>(knownLabels); lDict.remove(0); system = new ArcStandard(config.tlp, lDict, verbose); // Pre-compute matrix multiplications if (config.numPreComputed > 0) { classifier.preCompute(); } } /** * Explicitly specifies the number of arguments expected with * particular command line options. */ private static final Map numArgs = new HashMap<>(); static { numArgs.put("textFile", 1); numArgs.put("outFile", 1); } /** * A main program for training, testing and using the parser. * *

   * <p>
   * You can use this program to train new parsers from treebank data,
   * evaluate on test treebank data, or parse raw text input.
   *
   * <p>
   * Sample usages:
   * <ul>
   *   <li>
   *     Train a parser with CoNLL treebank data:
   *     {@code java edu.stanford.nlp.parser.nndep.DependencyParser -trainFile trainPath -devFile devPath -embedFile wordEmbeddingFile -embeddingSize wordEmbeddingDimensionality -model modelOutputFile.txt.gz}
   *   </li>
   *   <li>
   *     Parse raw text from a file:
   *     {@code java edu.stanford.nlp.parser.nndep.DependencyParser -model modelOutputFile.txt.gz -textFile rawTextToParse -outFile dependenciesOutputFile.txt}
   *   </li>
   *   <li>
   *     Parse raw text from standard input, writing to standard output:
   *     {@code java edu.stanford.nlp.parser.nndep.DependencyParser -model modelOutputFile.txt.gz -textFile - -outFile -}
   *   </li>
   * </ul>
   *
   * <p>
   * See below for more information on all of these training / test options and more.
   *

   * <p>
   * Input / output options:
   * <table>
   *   <tr><th>Option</th><th>Required for training</th><th>Required for testing / parsing</th><th>Description</th></tr>
   *   <tr><td>&#8209;devFile</td><td>Optional</td><td>No</td><td>Path to a development-set treebank in CoNLL-X format. If provided, the parser is evaluated on it during training (see &#8209;evalPerIter and &#8209;saveIntermediate).</td></tr>
   *   <tr><td>&#8209;embedFile</td><td>Optional (highly recommended!)</td><td>No</td><td>A word embedding file, containing distributed representations of English words. Each line of the provided file should contain a single word followed by the elements of the corresponding word embedding (space-delimited). It is not absolutely necessary that all words in the treebank be covered by this embedding file, though the parser's performance will generally improve if you are able to provide better embeddings for more words.</td></tr>
   *   <tr><td>&#8209;model</td><td>Yes</td><td>Yes</td><td>Path to a model file. If the path ends in .gz, the model will be read as a Gzipped model file. During training, we write to this path; at test time we read a pre-trained model from this path.</td></tr>
   *   <tr><td>&#8209;textFile</td><td>No</td><td>Yes (or testFile)</td><td>Path to a plaintext file containing sentences to be parsed.</td></tr>
   *   <tr><td>&#8209;testFile</td><td>No</td><td>Yes (or textFile)</td><td>Path to a test-set treebank in CoNLL-X format for final evaluation of the parser.</td></tr>
   *   <tr><td>&#8209;trainFile</td><td>Yes</td><td>No</td><td>Path to a training treebank in CoNLL-X format.</td></tr>
   * </table>
   *
   * <p>
   * Training options:
   * <table>
   *   <tr><th>Option</th><th>Default</th><th>Description</th></tr>
   *   <tr><td>&#8209;adaAlpha</td><td>0.01</td><td>Global learning rate for AdaGrad training</td></tr>
   *   <tr><td>&#8209;adaEps</td><td>1e-6</td><td>Epsilon value added to the denominator of AdaGrad update expression for numerical stability</td></tr>
   *   <tr><td>&#8209;batchSize</td><td>10000</td><td>Size of mini-batch used for training</td></tr>
   *   <tr><td>&#8209;clearGradientsPerIter</td><td>0</td><td>Clear AdaGrad gradient histories every <em>n</em> iterations. If zero, no gradient clearing is performed.</td></tr>
   *   <tr><td>&#8209;dropProb</td><td>0.5</td><td>Dropout probability. For each training example we randomly choose some number of units to disable in the neural network classifier. This parameter controls the proportion of units "dropped out."</td></tr>
   *   <tr><td>&#8209;embeddingSize</td><td>50</td><td>Dimensionality of word embeddings provided</td></tr>
   *   <tr><td>&#8209;evalPerIter</td><td>100</td><td>Run full UAS (unlabeled attachment score) evaluation every time we finish this number of iterations. (Only valid if a development treebank is provided with &#8209;devFile.)</td></tr>
   *   <tr><td>&#8209;hiddenSize</td><td>200</td><td>Dimensionality of hidden layer in neural network classifier</td></tr>
   *   <tr><td>&#8209;initRange</td><td>0.01</td><td>Bounds of range within which weight matrix elements should be initialized. Each element is drawn from a uniform distribution over the range [-initRange, initRange].</td></tr>
   *   <tr><td>&#8209;maxIter</td><td>20000</td><td>Number of training iterations to complete before stopping and saving the final model.</td></tr>
   *   <tr><td>&#8209;numPreComputed</td><td>100000</td><td>The parser pre-computes hidden-layer unit activations for particular input words at both training and testing time in order to speed up feedforward computation in the neural network. This parameter determines the number of words for which we compute hidden-layer activations.</td></tr>
   *   <tr><td>&#8209;regParameter</td><td>1e-8</td><td>Regularization parameter for training</td></tr>
   *   <tr><td>&#8209;saveIntermediate</td><td>true</td><td>If true, continually save the model version which gets the highest UAS value on the dev set. (Only valid if a development treebank is provided with &#8209;devFile.)</td></tr>
   *   <tr><td>&#8209;trainingThreads</td><td>1</td><td>Number of threads to use during training. Note that depending on training batch size, it may be unwise to simply choose the maximum number of threads for your machine. On our 16-core test machines: a batch size of 10,000 runs fastest with around 6 threads; a batch size of 100,000 runs best with around 10 threads.</td></tr>
   *   <tr><td>&#8209;wordCutOff</td><td>1</td><td>The parser can optionally ignore rare words by simply choosing an arbitrary "unknown" feature representation for words that appear with frequency less than <em>n</em> in the corpus. This <em>n</em> is controlled by the wordCutOff parameter.</td></tr>
   * </table>
   *
   * <p>
   * Runtime parsing options:
   * <table>
   *   <tr><th>Option</th><th>Default</th><th>Description</th></tr>
   *   <tr><td>&#8209;escaper</td><td>N/A</td><td>Only applicable for testing with -textFile. If provided, use this word-escaper when parsing raw sentences. (Should be a fully-qualified class name like edu.stanford.nlp.trees.international.arabic.ATBEscaper.)</td></tr>
   *   <tr><td>&#8209;numPreComputed</td><td>100000</td><td>The parser pre-computes hidden-layer unit activations for particular input words at both training and testing time in order to speed up feedforward computation in the neural network. This parameter determines the number of words for which we compute hidden-layer activations.</td></tr>
   *   <tr><td>&#8209;sentenceDelimiter</td><td>N/A</td><td>Only applicable for testing with -textFile. If provided, assume that the given textFile has already been sentence-split, and that sentences are separated by this delimiter.</td></tr>
   *   <tr><td>&#8209;tagger.model</td><td>edu/stanford/nlp/models/pos-tagger/english-left3words/english-left3words-distsim.tagger</td><td>Only applicable for testing with -textFile. Path to a part-of-speech tagger to use to pre-tag the raw sentences before parsing.</td></tr>
   * </table>
*/ public static void main(String[] args) { Properties props = StringUtils.argsToProperties(args, numArgs); DependencyParser parser = new DependencyParser(props); // Train with CoNLL-X data if (props.containsKey("trainFile")) { parser.train(props.getProperty("trainFile"), props.getProperty("devFile"), props.getProperty("model"), props.getProperty("embedFile"), props.getProperty("preModel")); } boolean loaded = false; // Test with CoNLL-X data if (props.containsKey("testFile")) { parser.loadModelFile(props.getProperty("model")); loaded = true; parser.testCoNLL(props.getProperty("testFile"), props.getProperty("outFile")); } // Parse raw text data if (props.containsKey("textFile")) { if (!loaded) { parser.loadModelFile(props.getProperty("model")); loaded = true; } String encoding = parser.config.tlp.getEncoding(); String inputFilename = props.getProperty("textFile"); BufferedReader input; try { input = inputFilename.equals("-") ? IOUtils.readerFromStdin(encoding) : IOUtils.readerFromString(inputFilename, encoding); } catch (IOException e) { throw new RuntimeIOException("No input file provided (use -textFile)", e); } String outputFilename = props.getProperty("outFile"); PrintWriter output; try { output = outputFilename == null || outputFilename.equals("-") ? IOUtils.encodedOutputStreamPrintWriter(System.out, encoding, true) : IOUtils.getPrintWriter(outputFilename, encoding); } catch (IOException e) { throw new RuntimeIOException("Error opening output file", e); } parser.parseTextFile(input, output); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy