All Downloads are FREE. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.parser.dvparser.DVParser Maven / Gradle / Ivy

Go to download

Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.

There is a newer version: 3.9.2
Show newest version
package edu.stanford.nlp.parser.dvparser;

import java.io.FileFilter;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Random;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TrainOptions;
import edu.stanford.nlp.trees.CompositeTreeTransformer;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.Timing;

/**
 * @author John Bauer & Richard Socher
 */
public class DVParser {
  DVModel dvModel;
  LexicalizedParser parser;
  Options op;

  /**
   * Returns the {@link Options} object held by this parser.
   *
   * @return the value of the {@code op} field
   */
  public Options getOp() {
    return this.op;
  }

  /**
   * Returns the {@link DVModel} held by this parser.
   *
   * @return the value of the {@code dvModel} field
   */
  DVModel getDVModel() {
    return this.dvModel;
  }

  private static final NumberFormat NF = new DecimalFormat("0.00");
  private static final NumberFormat FILENAME = new DecimalFormat("0000");

  /**
   * Reparses the words of {@code tree} with {@code parser} and returns the
   * k-best PCFG parses, optionally passed through {@code transformer}.
   *
   * @param parser      the parser used to reparse the sentence
   * @param dvKBest     how many hypotheses to request from the parser
   * @param tree        the tree whose yield is reparsed; its last word is
   *                    dropped before parsing (see the comment below)
   * @param transformer applied to every hypothesis when not {@code null}
   * @return the hypotheses in the order the parser returned them, or
   *         {@code null} when the tree yields at most one word or the
   *         parser fails on the sentence
   */
  static public List<Tree> getTopParsesForOneTree(LexicalizedParser parser, int dvKBest, Tree tree,
                                                  TreeTransformer transformer) {
    ParserQuery pq = parser.parserQuery();
    List<Word> sentence = tree.yieldWords();
    // Since the trees are binarized and otherwise manipulated, we
    // need to chop off the last word in order to remove the end of
    // sentence symbol
    if (sentence.size() <= 1) {
      return null;
    }
    sentence = sentence.subList(0, sentence.size() - 1);
    if (!pq.parse(sentence)) {
      System.err.println("Failed to use the given parser to reparse sentence \"" + sentence + "\"");
      return null;
    }
    List<Tree> parses = new ArrayList<Tree>();
    List<ScoredObject<Tree>> bestKParses = pq.getKBestPCFGParses(dvKBest);
    for (ScoredObject<Tree> so : bestKParses) {
      Tree result = so.object();
      if (transformer != null) {
        result = transformer.transformTree(result);
      }
      parses.add(result);
    }
    return parses;
  }

  static IdentityHashMap> getTopParses(LexicalizedParser parser, Options op,
                                                        Collection trees, TreeTransformer transformer,
                                                        boolean outputUpdates) {
    IdentityHashMap> topParses = new IdentityHashMap>();
    for (Tree tree : trees) {
      List parses = getTopParsesForOneTree(parser, op.trainOptions.dvKBest, tree, transformer);
      topParses.put(tree, parses);
      if (outputUpdates && topParses.size() % 10 == 0) {
        System.err.println("Processed " + topParses.size() + " trees");
      }
    }
    if (outputUpdates) {
      System.err.println("Finished processing " + topParses.size() + " trees");
    }
    return topParses;
  }

  IdentityHashMap> getTopParses(List trees, TreeTransformer transformer) {
    return getTopParses(parser, op, trees, transformer, false);
  }

  /**
   * Trains the model held by this parser on {@code sentences}.
   * <br>
   * For each of {@code op.trainOptions.trainingIterations} passes the
   * sentences are shuffled with {@code dvModel.rand} and split into batches
   * of {@code op.trainOptions.batchSize} trees; the final batch of a pass
   * also takes any leftover trees, so every tree is used exactly once per
   * pass.  Each batch is handed to {@link #executeOneTrainingBatch}.
   * Every {@code debugOutputFrequency} batches the current model is
   * evaluated on {@code testTreebank} (if given), saved (if
   * {@code modelPath} is given) and a CHECKPOINT line is logged and
   * appended to {@code resultsRecordPath} (if given).  Training stops early
   * once {@code maxTrainTimeSeconds} is positive and has been exceeded.
   *
   * @param sentences         the gold training trees
   * @param compressedParses  cached parse hypotheses, passed on to
   *                          {@link #executeOneTrainingBatch} for each batch
   * @param testTreebank      treebank used for checkpoint evaluation; may be
   *                          {@code null}
   * @param modelPath         where to save checkpoint models; may be
   *                          {@code null}.  When it ends in ".ser.gz" the
   *                          checkpoint number and label F1 are inserted
   *                          before that suffix
   * @param resultsRecordPath file to append checkpoint status lines to; may
   *                          be {@code null}
   * @throws IOException              if the results record cannot be written
   * @throws IllegalArgumentException if the configured batch size is not
   *                                  positive
   */
  public void train(List<Tree> sentences, IdentityHashMap<Tree, byte[]> compressedParses, Treebank testTreebank, String modelPath, String resultsRecordPath) throws IOException {
    // process:
    //   we come up with a cost and a derivative for the model
    //   we always use the gold tree as the example to train towards
    //   every time through, we will look at the top N trees from
    //     the LexicalizedParser and pick the best one according to
    //     our model (at the start, this is essentially random)
    // we use QN to minimize the cost function for the model
    // to do this minimization, we turn all of the matrices in the
    //   DVModel into one big Theta, which is the set of variables to
    //   be optimized by the QN.

    final int batchSize = op.trainOptions.batchSize;
    if (batchSize <= 0) {
      throw new IllegalArgumentException("batchSize must be positive, but was " + batchSize);
    }

    Timing timing = new Timing();
    // 1000L forces the multiplication into long arithmetic so a large
    // number of seconds cannot overflow before being widened
    long maxTrainTimeMillis = op.trainOptions.maxTrainTimeSeconds * 1000L;
    int batchCount = 0;
    int debugCycle = 0;
    double bestLabelF1 = 0.0;

    if (op.trainOptions.useContextWords) {
      for (Tree tree : sentences) {
        Trees.convertToCoreLabels(tree);
        tree.setSpans();
      }
    }

    // for AdaGrad
    double[] sumGradSquare = new double[dvModel.totalParamSize()];
    Arrays.fill(sumGradSquare, 1.0);

    // Only full batches are counted; the last one absorbs the remainder.
    // Counting one extra batch would either produce an empty batch (when
    // the size is an exact multiple) or revisit trees that the previous
    // batch already absorbed.
    int numBatches = Math.max(1, sentences.size() / batchSize);
    System.err.println("Training on " + sentences.size() + " trees in " + numBatches + " batches");
    System.err.println("Times through each training batch: " + op.trainOptions.trainingIterations);
    System.err.println("QN iterations per batch: " + op.trainOptions.qnIterationsPerBatch);
    for (int iter = 0; iter < op.trainOptions.trainingIterations; ++iter) {
      List<Tree> shuffledSentences = new ArrayList<Tree>(sentences);
      Collections.shuffle(shuffledSentences, dvModel.rand);
      for (int batch = 0; batch < numBatches; ++batch) {
        ++batchCount;
        // This did not help performance
        //System.err.println("Setting AdaGrad's sum of squares to 1...");
        //Arrays.fill(sumGradSquare, 1.0);

        System.err.println("======================================");
        System.err.println("Iteration " + iter + " batch " + batch);

        // Each batch will be of the specified batch size, except the
        // last batch will include any leftover trees at the end of
        // the list
        int startTree = batch * batchSize;
        int endTree = (batch == numBatches - 1) ? shuffledSentences.size() : startTree + batchSize;

        executeOneTrainingBatch(shuffledSentences.subList(startTree, endTree), compressedParses, sumGradSquare);

        long totalElapsed = timing.report();
        System.err.println("Finished iteration " + iter + " batch " + batch + "; total training time " + totalElapsed + " ms");

        if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
          // no need to debug output, we're done now
          break;
        }

        if (op.trainOptions.debugOutputFrequency > 0 && batchCount % op.trainOptions.debugOutputFrequency == 0) {
          System.err.println("Finished " + batchCount + " total batches, running evaluation cycle");
          // Time for debugging output!
          double tagF1 = 0.0;
          double labelF1 = 0.0;
          if (testTreebank != null) {
            EvaluateTreebank evaluator = new EvaluateTreebank(attachModelToLexicalizedParser());
            evaluator.testOnTreebank(testTreebank);
            labelF1 = evaluator.getLBScore();
            tagF1 = evaluator.getTagScore();
            if (labelF1 > bestLabelF1) {
              bestLabelF1 = labelF1;
            }
            System.err.println("Best label f1 on dev set so far: " + NF.format(bestLabelF1));
          }

          String tempName = null;
          if (modelPath != null) {
            tempName = modelPath;
            if (modelPath.endsWith(".ser.gz")) {
              tempName = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(labelF1) + ".ser.gz";
            }
            saveModel(tempName);
          }

          String statusLine = ("CHECKPOINT:" +
                               " iteration " + iter +
                               " batch " + batch +
                               " labelF1 " + NF.format(labelF1) +
                               " tagF1 " + NF.format(tagF1) +
                               " bestLabelF1 " + NF.format(bestLabelF1) +
                               " model " + tempName +
                               op.trainOptions +
                               " word vectors: " + op.lexOptions.wordVectorFile +
                               " numHid: " + op.lexOptions.numHid);
          System.err.println(statusLine);
          if (resultsRecordPath != null) {
            FileWriter fout = new FileWriter(resultsRecordPath, true); // append
            try {
              fout.write(statusLine);
              fout.write("\n");
            } finally {
              // close even when a write fails so the file handle is not leaked
              fout.close();
            }
          }

          ++debugCycle;
        }
      }
      long totalElapsed = timing.report();

      if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
        // no need to debug output, we're done now
        System.err.println("Max training time exceeded, exiting");
        break;
      }
    }
  }

  static final int MINIMIZER = 3;

  /**
   * Runs the optimizer selected by {@code MINIMIZER} over one batch of
   * training trees.
   * <br>
   * The compressed parses for the batch are converted to trees with
   * {@code CacheParseHypotheses.convertToTrees}, a
   * {@code DVParserCostAndGradient} is built over them, and the parameter
   * vector obtained from {@code dvModel.paramsToVector()} is updated.
   * <br>
   * NOTE(review): this listing is damaged.  Generic type arguments have been
   * stripped throughout, and the source text is cut off in the middle of the
   * AdaGrad loop below, so the end of this method is not visible here.
   *
   * @param trainingBatch    the gold trees in this batch
   * @param compressedParses cached parse hypotheses, passed to
   *                         {@code CacheParseHypotheses.convertToTrees}
   * @param sumGradSquare    passed in by the caller for AdaGrad; its use
   *                         falls inside the missing text
   */
  public void executeOneTrainingBatch(List trainingBatch, IdentityHashMap compressedParses, double[] sumGradSquare) {
    Timing convertTiming = new Timing();
    convertTiming.doing("Converting trees");
    IdentityHashMap> topParses = CacheParseHypotheses.convertToTrees(trainingBatch, compressedParses, op.trainOptions.trainingThreads);
    convertTiming.done();

    DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(trainingBatch, topParses, dvModel, op);
    double[] theta = dvModel.paramsToVector();

    //maxFuncIter = 10;
    // 1: QNMinimizer, 2: SGD
    // (case 3 below is AdaGrad)
    switch (MINIMIZER) {
    case (1): {
      // Quasi-Newton minimization, limited to qnIterationsPerBatch iterations
      QNMinimizer qn = new QNMinimizer(op.trainOptions.qnEstimates, true);
      qn.useMinPackSearch();
      qn.useDiagonalScaling();
      qn.terminateOnAverageImprovement(true);
      qn.terminateOnNumericalZero(true);
      qn.terminateOnRelativeNorm(true);

      theta = qn.minimize(gcFunc, op.trainOptions.qnTolerance, theta, op.trainOptions.qnIterationsPerBatch);
      break;
    }
    case 2:{
      //Minimizer smd = new SGDMinimizer();    	double tol = 1e-4;    	theta = smd.minimize(gcFunc,tol,theta,op.trainOptions.qnIterationsPerBatch);
      // lastCost and firstTime are only used by the commented-out
      // diagnostics below
      double lastCost = 0, currCost = 0;
      boolean firstTime = true;
      for(int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++){
        //gcFunc.calculate(theta);
        double[] grad = gcFunc.derivativeAt(theta);
        currCost = gcFunc.valueAt(theta);
        System.err.println("batch cost: " + currCost);
        //    		if(!firstTime){
        //    			if(currCost > lastCost){
        //    				System.out.println("HOW IS FUNCTION VALUE INCREASING????!!! ... still updating theta");
        //    			}
        //    			if(Math.abs(currCost - lastCost) < 0.0001){
        //    				System.out.println("function value is not decreasing. stop");
        //    			}
        //    		}else{
        //    			firstTime = false;
        //    		}
        lastCost = currCost;
        // presumably theta += (-learningRate) * grad, updated in place --
        // confirm against ArrayMath.addMultInPlace
        ArrayMath.addMultInPlace(theta, grad, -1*op.trainOptions.learningRate);
      }
      break;
    }
    case 3:{
      // AdaGrad
      double eps = 1e-3;
      double currCost = 0;
      for(int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++){
        double[] gradf = gcFunc.derivativeAt(theta);
        currCost = gcFunc.valueAt(theta);
        System.err.println("batch cost: " + currCost);
        // NOTE(review): the next line is corrupted.  Everything between
        // "feature <" and the parameter list of a later method was lost, so
        // the rest of this loop, the end of this method and whatever followed
        // it are missing.  Restore from the original source before compiling.
        for (int feature =0; feature sentences, IdentityHashMap compressedParses) {
    // NOTE(review): the statements below appear to be the body of a
    // gradient-check method whose header was fused into the line above.
    // They convert the compressed parses to trees, build a cost function
    // over all given sentences and return the result of its gradientCheck.
    System.err.println("Gradient check: converting " + sentences.size() + " compressed trees");
    IdentityHashMap> topParses = CacheParseHypotheses.convertToTrees(sentences, compressedParses, op.trainOptions.trainingThreads);
    System.err.println("Done converting trees");
    DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(sentences, topParses, dvModel, op);
    return gcFunc.gradientCheck(1000, 50, dvModel.paramsToVector());
  }

  /**
   * Builds the tree transformer used for training trees by delegating to
   * {@code LexicalizedParser.buildTrainTransformer}.
   *
   * @param op the options handed to the underlying builder
   * @return the transformer produced by {@code LexicalizedParser}
   */
  public static TreeTransformer buildTrainTransformer(Options op) {
    return LexicalizedParser.buildTrainTransformer(op);
  }

  /**
   * Creates a copy of this object's {@code LexicalizedParser} and sets the
   * copy's reranker to a new {@code DVModelReranker} built from the current
   * {@code dvModel}.
   *
   * @return the copied parser with the reranker attached
   */
  public LexicalizedParser attachModelToLexicalizedParser() {
    LexicalizedParser copy = LexicalizedParser.copyLexicalizedParser(parser);
    copy.reranker = new DVModelReranker(dvModel);
    return copy;
  }

  /**
   * Serializes the current model to {@code filename}, logging to stderr
   * before and after.  The object written is the parser returned by
   * {@link #attachModelToLexicalizedParser()}.
   *
   * @param filename destination path for the serialized parser
   */
  public void saveModel(String filename) {
    System.err.println("Saving serialized model to " + filename);
    attachModelToLexicalizedParser().saveParserToSerialized(filename);
    System.err.println("... done");
  }

  /**
   * Reads a serialized {@code DVParser} from {@code filename} and applies
   * {@code args} to its options.
   *
   * @param filename path of the serialized parser
   * @param args     option flags passed to {@code op.setOptions}
   * @return the deserialized parser
   * @throws RuntimeIOException wrapping any {@code IOException} or
   *                            {@code ClassNotFoundException} raised while
   *                            reading
   */
  public static DVParser loadModel(String filename, String[] args) {
    System.err.println("Loading serialized model from " + filename);
    final DVParser loaded;
    try {
      loaded = IOUtils.readObjectFromFile(filename);
      loaded.op.setOptions(args);
    } catch (IOException ioe) {
      throw new RuntimeIOException(ioe);
    } catch (ClassNotFoundException cnfe) {
      throw new RuntimeIOException(cnfe);
    }
    System.err.println("... done");
    return loaded;
  }

  /**
   * Extracts the {@code DVModel} from a parser whose reranker is a
   * {@code DVModelReranker}.
   *
   * @param parser the parser to inspect
   * @return the model held by the parser's reranker
   * @throws IllegalArgumentException if the parser's reranker is not a
   *                                  {@code DVModelReranker}
   */
  public static DVModel getModelFromLexicalizedParser(LexicalizedParser parser) {
    if (parser.reranker instanceof DVModelReranker) {
      return ((DVModelReranker) parser.reranker).getModel();
    }
    throw new IllegalArgumentException("This parser does not contain a DVModel reranker");
  }

  /**
   * Prints the command line usage summary to stderr, one option per line.
   */
  public static void help() {
    final String[] usageLines = {
      "Options supplied by this file:",
      "  -model : When training, the name of the model to save.  Otherwise, the name of the model to load.",
      "  -parser : When training, the LexicalizedParser to use as the base model.",
      "  -cachedTrees : The name of the file containing a treebank with cached parses.  See CacheParseHypotheses.java",
      "  -treebank  [filter]: A treebank to use instead of cachedTrees.  Trees will be reparsed.  Slow.",
      "  -testTreebank  [filter]: A treebank for testing the model.",
      "  -train: Run training over the treebank, testing on the testTreebank.",
      "  -continueTraining : The name of a file to continue training.",
      "  -nofilter: Rules for the parser will not be filtered based on the training treebank.",
      "  -runGradientCheck: Run a gradient check.",
      "  -resultsRecord: A file for recording info on intermediate results",
      // blank separator line between the two option groups
      "",
      "Options overlapping the parser:",
      "  -trainingThreads : How many threads to use when training.",
      "  -dvKBest : How many hypotheses to use from the underlying parser.",
      "  -trainingIterations : When training, how many times to go through the train set.",
      "  -regCost : How large of a cost to put on regularization.",
      "  -batchSize : How many trees to use in each batch of the training.",
      "  -qnIterationsPerBatch : How many steps to take per batch.",
      "  -qnEstimates : Parameter for qn optimization.",
      "  -qnTolerance : Tolerance for early exit when optimizing a batch.",
      "  -debugOutputFrequency : How frequently to score a model when training and write out intermediate models.",
      "  -maxTrainTimeSeconds : How long to train before terminating.",
      "  -randomSeed : A starting point for the random number generator.  Setting this should lead to repeatable results, even taking into account randomness.  Otherwise, a new random seed will be picked.",
      "  -wordVectorFile : A filename to load word vectors from.",
      "  -numHid: The size of the matrices.  In most circumstances, should be set to the size of the word vectors.",
      "  -learningRate: The rate of optimization when training",
      "  -deltaMargin: How much we punish trees for being incorrect when training",
      "  -(no)unknownNumberVector: Whether or not to use a word vector for unknown numbers",
      "  -(no)unknownDashedWordVectors: Whether or not to split unknown dashed words",
      "  -(no)unknownCapsVector: Whether or not to use a word vector for unknown words with capitals",
      "  -dvSimplifiedModel: Use a greatly dumbed down DVModel",
      "  -scalingForInit: How much to scale matrices when creating a new DVModel",
      "  -baseParserWeight: A weight to give the original LexicalizedParser when testing (0.2 seems to work well for English)",
      "  -unkWord: The vector representing unknown word in the word vectors file",
      "  -transformMatrixType: A couple different methods for initializing transform matrices",
      "  -(no)trainWordVectors: whether or not to train the word vectors along with the matrices.  True by default",
    };
    for (String usageLine : usageLines) {
      System.err.println(usageLine);
    }
  }

  /**
   * An example command line for training a new parser:
   * 
* nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories > /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2>&1 & */ public static void main(String[] args) throws IOException, ClassNotFoundException { if (args.length == 0) { help(); System.exit(2); } System.err.println("Running DVParser with arguments:"); for (int i = 0; i < args.length; ++i) { System.err.print(" " + args[i]); } System.err.println(); String parserPath = null; String trainTreebankPath = null; FileFilter trainTreebankFilter = null; String cachedTrainTreesPath = null; boolean runGradientCheck = false; boolean runTraining = false; String testTreebankPath = null; FileFilter testTreebankFilter = null; String initialModelPath = null; String modelPath = null; boolean filter = true; String resultsRecordPath = null; List unusedArgs = new ArrayList(); // These parameters can be null or 0 if the model was not // serialized with the new parameters. Setting the options at the // command line will override these defaults. 
// TODO: if/when we integrate back into the main branch and // rebuild models, we can get rid of this List argsWithDefaults = new ArrayList(Arrays.asList(new String[] { "-wordVectorFile", Options.LexOptions.DEFAULT_WORD_VECTOR_FILE, "-dvKBest", Integer.toString(TrainOptions.DEFAULT_K_BEST), "-batchSize", Integer.toString(TrainOptions.DEFAULT_BATCH_SIZE), "-trainingIterations", Integer.toString(TrainOptions.DEFAULT_TRAINING_ITERATIONS), "-qnIterationsPerBatch", Integer.toString(TrainOptions.DEFAULT_QN_ITERATIONS_PER_BATCH), "-regCost", Double.toString(TrainOptions.DEFAULT_REGCOST), "-learningRate", Double.toString(TrainOptions.DEFAULT_LEARNING_RATE), "-deltaMargin", Double.toString(TrainOptions.DEFAULT_DELTA_MARGIN), "-unknownNumberVector", "-unknownDashedWordVectors", "-unknownCapsVector", "-unknownchinesepercentvector", "-unknownchinesenumbervector", "-unknownchineseyearvector", "-unkWord", "*UNK*", "-transformMatrixType", "DIAGONAL", "-scalingForInit", Double.toString(TrainOptions.DEFAULT_SCALING_FOR_INIT), "-trainWordVectors", } )); argsWithDefaults.addAll(Arrays.asList(args)); args = argsWithDefaults.toArray(new String[argsWithDefaults.size()]); for (int argIndex = 0; argIndex < args.length; ) { if (args[argIndex].equalsIgnoreCase("-parser")) { parserPath = args[argIndex + 1]; argIndex += 2; } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) { Pair treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank"); argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1; testTreebankPath = treebankDescription.first(); testTreebankFilter = treebankDescription.second(); } else if (args[argIndex].equalsIgnoreCase("-treebank")) { Pair treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-treebank"); argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1; trainTreebankPath = treebankDescription.first(); trainTreebankFilter = treebankDescription.second(); } else if 
(args[argIndex].equalsIgnoreCase("-cachedTrees")) { cachedTrainTreesPath = args[argIndex + 1]; argIndex += 2; } else if (args[argIndex].equalsIgnoreCase("-runGradientCheck")) { runGradientCheck = true; argIndex++; } else if (args[argIndex].equalsIgnoreCase("-train")) { runTraining = true; argIndex++; } else if (args[argIndex].equalsIgnoreCase("-model")) { modelPath = args[argIndex + 1]; argIndex += 2; } else if (args[argIndex].equalsIgnoreCase("-nofilter")) { filter = false; argIndex++; } else if (args[argIndex].equalsIgnoreCase("-continueTraining")) { runTraining = true; filter = false; initialModelPath = args[argIndex + 1]; argIndex += 2; } else if (args[argIndex].equalsIgnoreCase("-resultsRecord")) { resultsRecordPath = args[argIndex + 1]; argIndex += 2; } else { unusedArgs.add(args[argIndex++]); } } if (parserPath == null && modelPath == null) { throw new IllegalArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model"); } if (!runTraining && modelPath == null && !runGradientCheck) { throw new IllegalArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model"); } String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]); DVParser dvparser = null; LexicalizedParser lexparser = null; if (initialModelPath != null) { lexparser = LexicalizedParser.loadModel(initialModelPath, newArgs); DVModel model = getModelFromLexicalizedParser(lexparser); dvparser = new DVParser(model, lexparser); } else if (runTraining || runGradientCheck) { lexparser = LexicalizedParser.loadModel(parserPath, newArgs); dvparser = new DVParser(lexparser); } else if (modelPath != null) { lexparser = LexicalizedParser.loadModel(modelPath, newArgs); DVModel model = getModelFromLexicalizedParser(lexparser); dvparser = new DVParser(model, lexparser); } List trainSentences = new ArrayList(); IdentityHashMap trainCompressedParses = Generics.newIdentityHashMap(); if (cachedTrainTreesPath 
!= null) { for (String path : cachedTrainTreesPath.split(",")) { List> cache = IOUtils.readObjectFromFile(path); for (Pair pair : cache) { trainSentences.add(pair.first()); trainCompressedParses.put(pair.first(), pair.second()); } System.err.println("Read in " + cache.size() + " trees from " + path); } } if (trainTreebankPath != null) { // TODO: make the transformer a member of the model? TreeTransformer transformer = buildTrainTransformer(dvparser.getOp()); Treebank treebank = dvparser.getOp().tlpParams.memoryTreebank();; treebank.loadPath(trainTreebankPath, trainTreebankFilter); treebank = treebank.transform(transformer); System.err.println("Read in " + treebank.size() + " trees from " + trainTreebankPath); CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser); CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer); for (Tree tree : treebank) { trainSentences.add(tree); trainCompressedParses.put(tree, processor.process(tree).second); //System.out.println(tree); } System.err.println("Finished parsing " + treebank.size() + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each"); } if ((runTraining || runGradientCheck) && filter) { System.err.println("Filtering rules for the given training set"); dvparser.dvModel.setRulesForTrainingSet(trainSentences, trainCompressedParses); System.err.println("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.size() + " word vectors"); } //dvparser.dvModel.printAllMatrices(); Treebank testTreebank = null; if (testTreebankPath != null) { System.err.println("Reading in trees from " + testTreebankPath); if (testTreebankFilter != null) { System.err.println("Filtering on " + testTreebankFilter); } testTreebank = dvparser.getOp().tlpParams.memoryTreebank();; 
testTreebank.loadPath(testTreebankPath, testTreebankFilter); System.err.println("Read in " + testTreebank.size() + " trees for testing"); } // runGradientCheck= true; if (runGradientCheck) { System.err.println("Running gradient check on " + trainSentences.size() + " trees"); dvparser.runGradientCheck(trainSentences, trainCompressedParses); } if (runTraining) { System.err.println("Training the RNN parser"); System.err.println("Current train options: " + dvparser.getOp().trainOptions); dvparser.train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath); if (modelPath != null) { dvparser.saveModel(modelPath); } } if (testTreebankPath != null) { EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.attachModelToLexicalizedParser()); evaluator.testOnTreebank(testTreebank); } System.err.println("Successfully ran DVParser"); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy