All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.parser.dvparser.DVParser Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.parser.dvparser; 
import edu.stanford.nlp.util.logging.Redwood;

import java.io.FileFilter;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Random;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TrainOptions;
import edu.stanford.nlp.trees.CompositeTreeTransformer;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.Timing;

/**
 * @author John Bauer & Richard Socher
 */
public class DVParser  {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(DVParser.class);
  DVModel dvModel;
  LexicalizedParser parser;
  Options op;

  public Options getOp() {
    return op;
  }

  DVModel getDVModel() {
    return dvModel;
  }

  private static final NumberFormat NF = new DecimalFormat("0.00");
  private static final NumberFormat FILENAME = new DecimalFormat("0000");

  static public List getTopParsesForOneTree(LexicalizedParser parser, int dvKBest, Tree tree,
                                                  TreeTransformer transformer) {
    ParserQuery pq = parser.parserQuery();
    List sentence = tree.yieldWords();
    // Since the trees are binarized and otherwise manipulated, we
    // need to chop off the last word in order to remove the end of
    // sentence symbol
    if (sentence.size() <= 1) {
      return null;
    }
    sentence = sentence.subList(0, sentence.size() - 1);
    if (!pq.parse(sentence)) {
      log.info("Failed to use the given parser to reparse sentence \"" + sentence + "\"");
      return null;
    }
    List parses = new ArrayList<>();
    List> bestKParses = pq.getKBestPCFGParses(dvKBest);
    for (ScoredObject so : bestKParses) {
      Tree result = so.object();
      if (transformer != null) {
        result = transformer.transformTree(result);
      }
      parses.add(result);
    }
    return parses;
  }

  static IdentityHashMap> getTopParses(LexicalizedParser parser, Options op,
                                                        Collection trees, TreeTransformer transformer,
                                                        boolean outputUpdates) {
    IdentityHashMap> topParses = new IdentityHashMap<>();
    for (Tree tree : trees) {
      List parses = getTopParsesForOneTree(parser, op.trainOptions.dvKBest, tree, transformer);
      topParses.put(tree, parses);
      if (outputUpdates && topParses.size() % 10 == 0) {
        log.info("Processed " + topParses.size() + " trees");
      }
    }
    if (outputUpdates) {
      log.info("Finished processing " + topParses.size() + " trees");
    }
    return topParses;
  }

  IdentityHashMap> getTopParses(List trees, TreeTransformer transformer) {
    return getTopParses(parser, op, trees, transformer, false);
  }

  public void train(List sentences, IdentityHashMap compressedParses, Treebank testTreebank, String modelPath, String resultsRecordPath) throws IOException {
    // process:
    //   we come up with a cost and a derivative for the model
    //   we always use the gold tree as the example to train towards
    //   every time through, we will look at the top N trees from
    //     the LexicalizedParser and pick the best one according to
    //     our model (at the start, this is essentially random)
    // we use QN to minimize the cost function for the model
    // to do this minimization, we turn all of the matrices in the
    //   DVModel into one big Theta, which is the set of variables to
    //   be optimized by the QN.

    Timing timing = new Timing();
    long maxTrainTimeMillis = op.trainOptions.maxTrainTimeSeconds * 1000;
    int batchCount = 0;
    int debugCycle = 0;
    double bestLabelF1 = 0.0;

    if (op.trainOptions.useContextWords) {
      for (Tree tree : sentences) {
        Trees.convertToCoreLabels(tree);
        tree.setSpans();
      }
    }

    // for AdaGrad
    double[] sumGradSquare = new double[dvModel.totalParamSize()];
    Arrays.fill(sumGradSquare, 1.0);

    int numBatches = sentences.size() / op.trainOptions.batchSize + 1;
    log.info("Training on " + sentences.size() + " trees in " + numBatches + " batches");
    log.info("Times through each training batch: " + op.trainOptions.trainingIterations);
    log.info("QN iterations per batch: " + op.trainOptions.qnIterationsPerBatch);
    for (int iter = 0; iter < op.trainOptions.trainingIterations; ++iter) {
      List shuffledSentences = new ArrayList<>(sentences);
      Collections.shuffle(shuffledSentences, dvModel.rand);
      for (int batch = 0; batch < numBatches; ++batch) {
        ++batchCount;
        // This did not help performance
        //log.info("Setting AdaGrad's sum of squares to 1...");
        //Arrays.fill(sumGradSquare, 1.0);

        log.info("======================================");
        log.info("Iteration " + iter + " batch " + batch);

        // Each batch will be of the specified batch size, except the
        // last batch will include any leftover trees at the end of
        // the list
        int startTree = batch * op.trainOptions.batchSize;
        int endTree = (batch + 1) * op.trainOptions.batchSize;
        if (endTree > shuffledSentences.size()) {
          endTree = shuffledSentences.size();
        }

        executeOneTrainingBatch(shuffledSentences.subList(startTree, endTree), compressedParses, sumGradSquare);

        long totalElapsed = timing.report();
        log.info("Finished iteration " + iter + " batch " + batch + "; total training time " + totalElapsed + " ms");

        if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
          // no need to debug output, we're done now
          break;
        }

        if (op.trainOptions.debugOutputFrequency > 0 && batchCount % op.trainOptions.debugOutputFrequency == 0) {
          log.info("Finished " + batchCount + " total batches, running evaluation cycle");
          // Time for debugging output!
          double tagF1 = 0.0;
          double labelF1 = 0.0;
          if (testTreebank != null) {
            EvaluateTreebank evaluator = new EvaluateTreebank(attachModelToLexicalizedParser());
            evaluator.testOnTreebank(testTreebank);
            labelF1 = evaluator.getLBScore();
            tagF1 = evaluator.getTagScore();
            if (labelF1 > bestLabelF1) {
              bestLabelF1 = labelF1;
            }
            log.info("Best label f1 on dev set so far: " + NF.format(bestLabelF1));
          }

          String tempName = null;
          if (modelPath != null) {
            tempName = modelPath;
            if (modelPath.endsWith(".ser.gz")) {
              tempName = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(labelF1) + ".ser.gz";
            }
            saveModel(tempName);
          }

          String statusLine = ("CHECKPOINT:" +
                               " iteration " + iter +
                               " batch " + batch +
                               " labelF1 " + NF.format(labelF1) +
                               " tagF1 " + NF.format(tagF1) +
                               " bestLabelF1 " + NF.format(bestLabelF1) +
                               " model " + tempName +
                               op.trainOptions +
                               " word vectors: " + op.lexOptions.wordVectorFile +
                               " numHid: " + op.lexOptions.numHid);
          log.info(statusLine);
          if (resultsRecordPath != null) {
            FileWriter fout = new FileWriter(resultsRecordPath, true); // append
            fout.write(statusLine);
            fout.write("\n");
            fout.close();
          }

          ++debugCycle;
        }
      }
      long totalElapsed = timing.report();

      if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
        // no need to debug output, we're done now
        log.info("Max training time exceeded, exiting");
        break;
      }
    }
  }

  static final int MINIMIZER = 3;

  public void executeOneTrainingBatch(List trainingBatch, IdentityHashMap compressedParses, double[] sumGradSquare) {
    Timing convertTiming = new Timing();
    convertTiming.doing("Converting trees");
    IdentityHashMap> topParses = CacheParseHypotheses.convertToTrees(trainingBatch, compressedParses, op.trainOptions.trainingThreads);
    convertTiming.done();

    DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(trainingBatch, topParses, dvModel, op);
    double[] theta = dvModel.paramsToVector();

    //maxFuncIter = 10;
    // 1: QNMinimizer, 2: SGD
    switch (MINIMIZER) {
    case (1): {
      QNMinimizer qn = new QNMinimizer(op.trainOptions.qnEstimates, true);
      qn.useMinPackSearch();
      qn.useDiagonalScaling();
      qn.terminateOnAverageImprovement(true);
      qn.terminateOnNumericalZero(true);
      qn.terminateOnRelativeNorm(true);

      theta = qn.minimize(gcFunc, op.trainOptions.qnTolerance, theta, op.trainOptions.qnIterationsPerBatch);
      break;
    }
    case 2:{
      //Minimizer smd = new SGDMinimizer();    	double tol = 1e-4;    	theta = smd.minimize(gcFunc,tol,theta,op.trainOptions.qnIterationsPerBatch);
      double lastCost = 0, currCost = 0;
      boolean firstTime = true;
      for(int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++){
        //gcFunc.calculate(theta);
        double[] grad = gcFunc.derivativeAt(theta);
        currCost = gcFunc.valueAt(theta);
        log.info("batch cost: " + currCost);
        //    		if(!firstTime){
        //    			if(currCost > lastCost){
        //    				System.out.println("HOW IS FUNCTION VALUE INCREASING????!!! ... still updating theta");
        //    			}
        //    			if(Math.abs(currCost - lastCost) < 0.0001){
        //    				System.out.println("function value is not decreasing. stop");
        //    			}
        //    		}else{
        //    			firstTime = false;
        //    		}
        lastCost = currCost;
        ArrayMath.addMultInPlace(theta, grad, -1*op.trainOptions.learningRate);
      }
      break;
    }
    case 3:{
      // AdaGrad
      double eps = 1e-3;
      double currCost = 0;
      for(int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++){
        double[] gradf = gcFunc.derivativeAt(theta);
        currCost = gcFunc.valueAt(theta);
        log.info("batch cost: " + currCost);
        for (int feature =0; feature sentences, IdentityHashMap compressedParses) {
    log.info("Gradient check: converting " + sentences.size() + " compressed trees");
    IdentityHashMap> topParses = CacheParseHypotheses.convertToTrees(sentences, compressedParses, op.trainOptions.trainingThreads);
    log.info("Done converting trees");
    DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(sentences, topParses, dvModel, op);
    return gcFunc.gradientCheck(1000, 50, dvModel.paramsToVector());
  }

  public static TreeTransformer buildTrainTransformer(Options op) {
    CompositeTreeTransformer transformer = LexicalizedParser.buildTrainTransformer(op);
    return transformer;
  }

  public LexicalizedParser attachModelToLexicalizedParser() {
    LexicalizedParser newParser = LexicalizedParser.copyLexicalizedParser(parser);
    DVModelReranker reranker = new DVModelReranker(dvModel);
    newParser.reranker = reranker;
    return newParser;
  }

  public void saveModel(String filename) {
    log.info("Saving serialized model to " + filename);
    LexicalizedParser newParser = attachModelToLexicalizedParser();
    newParser.saveParserToSerialized(filename);
    log.info("... done");
  }

  public static DVParser loadModel(String filename, String[] args) {
    log.info("Loading serialized model from " + filename);
    DVParser dvparser;
    try {
      dvparser = IOUtils.readObjectFromURLOrClasspathOrFileSystem(filename);
      dvparser.op.setOptions(args);
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    } catch (ClassNotFoundException e) {
      throw new RuntimeIOException(e);
    }
    log.info("... done");
    return dvparser;
  }

  public static DVModel getModelFromLexicalizedParser(LexicalizedParser parser) {
    if (!(parser.reranker instanceof DVModelReranker)) {
      throw new IllegalArgumentException("This parser does not contain a DVModel reranker");
    }
    DVModelReranker reranker = (DVModelReranker) parser.reranker;
    return reranker.getModel();
  }

  public static void help() {
    log.info("Options supplied by this file:");
    log.info("  -model : When training, the name of the model to save.  Otherwise, the name of the model to load.");
    log.info("  -parser : When training, the LexicalizedParser to use as the base model.");
    log.info("  -cachedTrees : The name of the file containing a treebank with cached parses.  See CacheParseHypotheses.java");
    log.info("  -treebank  [filter]: A treebank to use instead of cachedTrees.  Trees will be reparsed.  Slow.");
    log.info("  -testTreebank  [filter]: A treebank for testing the model.");
    log.info("  -train: Run training over the treebank, testing on the testTreebank.");
    log.info("  -continueTraining : The name of a file to continue training.");
    log.info("  -nofilter: Rules for the parser will not be filtered based on the training treebank.");
    log.info("  -runGradientCheck: Run a gradient check.");
    log.info("  -resultsRecord: A file for recording info on intermediate results");
    log.info();
    log.info("Options overlapping the parser:");
    log.info("  -trainingThreads : How many threads to use when training.");
    log.info("  -dvKBest : How many hypotheses to use from the underlying parser.");
    log.info("  -trainingIterations : When training, how many times to go through the train set.");
    log.info("  -regCost : How large of a cost to put on regularization.");
    log.info("  -batchSize : How many trees to use in each batch of the training.");
    log.info("  -qnIterationsPerBatch : How many steps to take per batch.");
    log.info("  -qnEstimates : Parameter for qn optimization.");
    log.info("  -qnTolerance : Tolerance for early exit when optimizing a batch.");
    log.info("  -debugOutputFrequency : How frequently to score a model when training and write out intermediate models.");
    log.info("  -maxTrainTimeSeconds : How long to train before terminating.");
    log.info("  -randomSeed : A starting point for the random number generator.  Setting this should lead to repeatable results, even taking into account randomness.  Otherwise, a new random seed will be picked.");
    log.info("  -wordVectorFile : A filename to load word vectors from.");
    log.info("  -numHid: The size of the matrices.  In most circumstances, should be set to the size of the word vectors.");
    log.info("  -learningRate: The rate of optimization when training");
    log.info("  -deltaMargin: How much we punish trees for being incorrect when training");
    log.info("  -(no)unknownNumberVector: Whether or not to use a word vector for unknown numbers");
    log.info("  -(no)unknownDashedWordVectors: Whether or not to split unknown dashed words");
    log.info("  -(no)unknownCapsVector: Whether or not to use a word vector for unknown words with capitals");
    log.info("  -dvSimplifiedModel: Use a greatly dumbed down DVModel");
    log.info("  -scalingForInit: How much to scale matrices when creating a new DVModel");
    log.info("  -baseParserWeight: A weight to give the original LexicalizedParser when testing (0.2 seems to work well for English)");
    log.info("  -unkWord: The vector representing unknown word in the word vectors file");
    log.info("  -transformMatrixType: A couple different methods for initializing transform matrices");
    log.info("  -(no)trainWordVectors: whether or not to train the word vectors along with the matrices.  True by default");
  }

  /**
   * An example command line for training a new parser:
   * 
* nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories > /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2>&1 & */ public static void main(String[] args) throws IOException, ClassNotFoundException { if (args.length == 0) { help(); System.exit(2); } log.info("Running DVParser with arguments:"); for (String arg : args) { log.info(" " + arg); } log.info(); String parserPath = null; String trainTreebankPath = null; FileFilter trainTreebankFilter = null; String cachedTrainTreesPath = null; boolean runGradientCheck = false; boolean runTraining = false; String testTreebankPath = null; FileFilter testTreebankFilter = null; String initialModelPath = null; String modelPath = null; boolean filter = true; String resultsRecordPath = null; List unusedArgs = new ArrayList<>(); // These parameters can be null or 0 if the model was not // serialized with the new parameters. Setting the options at the // command line will override these defaults. // TODO: if/when we integrate back into the main branch and // rebuild models, we can get rid of this List argsWithDefaults = new ArrayList<>(Arrays.asList(new String[]{ "-wordVectorFile", Options.LexOptions.DEFAULT_WORD_VECTOR_FILE, "-dvKBest", Integer.toString(TrainOptions.DEFAULT_K_BEST), "-batchSize", Integer.toString(TrainOptions.DEFAULT_BATCH_SIZE), "-trainingIterations", Integer.toString(TrainOptions.DEFAULT_TRAINING_ITERATIONS), "-qnIterationsPerBatch", Integer.toString(TrainOptions.DEFAULT_QN_ITERATIONS_PER_BATCH), "-regCost", Double.toString(TrainOptions.DEFAULT_REGCOST), "-learningRate", Double.toString(TrainOptions.DEFAULT_LEARNING_RATE), "-deltaMargin", Double.toString(TrainOptions.DEFAULT_DELTA_MARGIN), "-unknownNumberVector", "-unknownDashedWordVectors", "-unknownCapsVector", "-unknownchinesepercentvector", "-unknownchinesenumbervector", "-unknownchineseyearvector", "-unkWord", "*UNK*", "-transformMatrixType", "DIAGONAL", "-scalingForInit", Double.toString(TrainOptions.DEFAULT_SCALING_FOR_INIT), "-trainWordVectors", })); argsWithDefaults.addAll(Arrays.asList(args)); args = argsWithDefaults.toArray(new String[argsWithDefaults.size()]); for (int argIndex = 0; argIndex < args.length; ) { if (args[argIndex].equalsIgnoreCase("-parser")) { parserPath = args[argIndex + 1]; argIndex += 2; } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) { Pair treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank"); argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1; testTreebankPath = treebankDescription.first(); testTreebankFilter = treebankDescription.second(); } else if (args[argIndex].equalsIgnoreCase("-treebank")) { Pair treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-treebank"); argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1; trainTreebankPath = treebankDescription.first(); trainTreebankFilter = treebankDescription.second(); } else if (args[argIndex].equalsIgnoreCase("-cachedTrees")) { cachedTrainTreesPath = args[argIndex + 1]; argIndex += 2; } else if (args[argIndex].equalsIgnoreCase("-runGradientCheck")) { runGradientCheck = true; argIndex++; } else if (args[argIndex].equalsIgnoreCase("-train")) { runTraining = true; argIndex++; } else if (args[argIndex].equalsIgnoreCase("-model")) { modelPath = args[argIndex + 1]; argIndex += 2; } else if (args[argIndex].equalsIgnoreCase("-nofilter")) { filter = false; argIndex++; } else if (args[argIndex].equalsIgnoreCase("-continueTraining")) { runTraining = true; filter = false; initialModelPath = args[argIndex + 1]; argIndex += 2; } else if (args[argIndex].equalsIgnoreCase("-resultsRecord")) { resultsRecordPath = args[argIndex + 1]; argIndex += 2; } else { unusedArgs.add(args[argIndex++]); } } if (parserPath == null && modelPath == null) { throw new IllegalArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model"); } if (!runTraining && modelPath == null && !runGradientCheck) { throw new IllegalArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model"); } String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]); DVParser dvparser = null; LexicalizedParser lexparser = null; if (initialModelPath != null) { lexparser = LexicalizedParser.loadModel(initialModelPath, newArgs); DVModel model = getModelFromLexicalizedParser(lexparser); dvparser = new DVParser(model, lexparser); } else if (runTraining || runGradientCheck) { lexparser = LexicalizedParser.loadModel(parserPath, newArgs); dvparser = new DVParser(lexparser); } else if (modelPath != null) { lexparser = LexicalizedParser.loadModel(modelPath, newArgs); DVModel model = getModelFromLexicalizedParser(lexparser); dvparser = new DVParser(model, lexparser); } List trainSentences = new ArrayList<>(); IdentityHashMap trainCompressedParses = Generics.newIdentityHashMap(); if (cachedTrainTreesPath != null) { for (String path : cachedTrainTreesPath.split(",")) { List> cache = IOUtils.readObjectFromFile(path); for (Pair pair : cache) { trainSentences.add(pair.first()); trainCompressedParses.put(pair.first(), pair.second()); } log.info("Read in " + cache.size() + " trees from " + path); } } if (trainTreebankPath != null) { // TODO: make the transformer a member of the model? TreeTransformer transformer = buildTrainTransformer(dvparser.getOp()); Treebank treebank = dvparser.getOp().tlpParams.memoryTreebank();; treebank.loadPath(trainTreebankPath, trainTreebankFilter); treebank = treebank.transform(transformer); log.info("Read in " + treebank.size() + " trees from " + trainTreebankPath); CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser); CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer); for (Tree tree : treebank) { trainSentences.add(tree); trainCompressedParses.put(tree, processor.process(tree).second); //System.out.println(tree); } log.info("Finished parsing " + treebank.size() + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each"); } if ((runTraining || runGradientCheck) && filter) { log.info("Filtering rules for the given training set"); dvparser.dvModel.setRulesForTrainingSet(trainSentences, trainCompressedParses); log.info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.size() + " word vectors"); } //dvparser.dvModel.printAllMatrices(); Treebank testTreebank = null; if (testTreebankPath != null) { log.info("Reading in trees from " + testTreebankPath); if (testTreebankFilter != null) { log.info("Filtering on " + testTreebankFilter); } testTreebank = dvparser.getOp().tlpParams.memoryTreebank();; testTreebank.loadPath(testTreebankPath, testTreebankFilter); log.info("Read in " + testTreebank.size() + " trees for testing"); } // runGradientCheck= true; if (runGradientCheck) { log.info("Running gradient check on " + trainSentences.size() + " trees"); dvparser.runGradientCheck(trainSentences, trainCompressedParses); } if (runTraining) { log.info("Training the RNN parser"); log.info("Current train options: " + dvparser.getOp().trainOptions); dvparser.train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath); if (modelPath != null) { dvparser.saveModel(modelPath); } } if (testTreebankPath != null) { EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.attachModelToLexicalizedParser()); evaluator.testOnTreebank(testTreebank); } log.info("Successfully ran DVParser"); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy