package edu.stanford.nlp.parser.shiftreduce;

import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import edu.stanford.nlp.parser.common.ParserConstraint;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;


public class PerceptronModel extends BaseModel  {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(PerceptronModel.class); // Serializable

  private float learningRate = 1.0f;

  /** Maps each feature to a Weight, which tracks a weight for every transition */
  Map<String, Weight> featureWeights;
  final FeatureFactory featureFactory;

  public PerceptronModel(ShiftReduceOptions op, Index<Transition> transitionIndex,
                         Set<String> knownStates, Set<String> rootStates, Set<String> rootOnlyStates) {
    super(op, transitionIndex, knownStates, rootStates, rootOnlyStates);
    this.featureWeights = Generics.newHashMap();

    String[] classes = op.featureFactoryClass.split(";");
    if (classes.length == 1) {
      this.featureFactory = ReflectionLoading.loadByReflection(classes[0]);
    } else {
      FeatureFactory[] factories = new FeatureFactory[classes.length];
      for (int i = 0; i < classes.length; ++i) {
        int paren = classes[i].indexOf('(');
        if (paren >= 0) {
          String arg = classes[i].substring(paren + 1, classes[i].length() - 1);
          factories[i] = ReflectionLoading.loadByReflection(classes[i].substring(0, paren), arg);
        } else {
          factories[i] = ReflectionLoading.loadByReflection(classes[i]);
        }
      }
      this.featureFactory = new CombinationFeatureFactory(factories);
    }
  }

  public PerceptronModel(PerceptronModel other) {
    super(other);
    this.featureFactory = other.featureFactory;

    this.featureWeights = Generics.newHashMap();
    for (String feature : other.featureWeights.keySet()) {
      featureWeights.put(feature, new Weight(other.featureWeights.get(feature)));
    }
  }

  private static final NumberFormat NF = new DecimalFormat("0.00");
  private static final NumberFormat FILENAME = new DecimalFormat("0000");

  public void averageScoredModels(Collection<ScoredObject<PerceptronModel>> scoredModels) {
    if (scoredModels.isEmpty()) {
      throw new IllegalArgumentException("Cannot average empty models");
    }

    log.info("Averaging " + scoredModels.size() + " models with scores");
    for (ScoredObject<PerceptronModel> model : scoredModels) {
      log.info(" " + NF.format(model.score()));
    }
    log.info();

    List<PerceptronModel> models = CollectionUtils.transformAsList(scoredModels, ScoredObject::object);
    averageModels(models);
  }

  public void averageModels(Collection<PerceptronModel> models) {
    if (models.isEmpty()) {
      throw new IllegalArgumentException("Cannot average empty models");
    }

    Set<String> features = Generics.newHashSet();
    for (PerceptronModel model : models) {
      for (String feature : model.featureWeights.keySet()) {
        features.add(feature);
      }
    }

    featureWeights = Generics.newHashMap();
    for (String feature : features) {
      featureWeights.put(feature, new Weight());
    }

    int numModels = models.size();
    for (String feature : features) {
      for (PerceptronModel model : models) {
        if (!model.featureWeights.containsKey(feature)) {
          continue;
        }
        featureWeights.get(feature).addScaled(model.featureWeights.get(feature), 1.0f / numModels);
      }
    }
  }
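
  // Worked example (illustrative numbers): averaging two models in which the
  // feature "S0T-NN" has weight vectors [0.5, -1.0] and [1.5, 0.0] yields
  // [1.0, -0.5].  A model that lacks the feature contributes nothing, which
  // is equivalent to contributing an all-zero weight vector.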

  /**
   * Iterate over the feature weight map.
   * For each feature, remove all transitions with a score of 0.
   * Any feature with no transitions left is then removed.
   */
  private void condenseFeatures() {
    Iterator<String> featureIt = featureWeights.keySet().iterator();
    while (featureIt.hasNext()) {
      String feature = featureIt.next();
      Weight weights = featureWeights.get(feature);
      weights.condense();
      if (weights.size() == 0) {
        featureIt.remove();
      }
    }
  }
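
  // For example, a Weight whose per-transition scores are conceptually
  // [0.0, 1.5, 0.0] condenses down to its single non-zero entry; a Weight
  // left with no entries at all causes its feature to be dropped entirely.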

  private void filterFeatures(Set<String> keep) {
    Iterator<String> featureIt = featureWeights.keySet().iterator();
    while (featureIt.hasNext()) {
      if (!keep.contains(featureIt.next())) {
        featureIt.remove();
      }
    }
  }


  /**
   * Output various statistics about the model
   */
  public void outputStats() {
    log.info("Number of known features: " + featureWeights.size());
    int numWeights = 0;
    for (Map.Entry<String, Weight> stringWeightEntry : featureWeights.entrySet()) {
      numWeights += stringWeightEntry.getValue().size();
    }
    log.info("Number of non-zero weights: " + numWeights);

    int wordLength = 0;
    for (String feature : featureWeights.keySet()) {
      wordLength += feature.length();
    }
    log.info("Total word length: " + wordLength);

    log.info("Number of transitions: " + transitionIndex.size());
  }

  /** Reconstruct the tag set that was used to train the model by decoding some of the features.
   *  This is slow and brittle, but it should work, provided "-" never appears in the tag set itself.
   */
  @Override
  Set<String> tagSet() {
    Set<String> tags = Generics.newHashSet();
    Pattern p1 = Pattern.compile("Q0TQ1T-([^-]+)-.*");
    Pattern p2 = Pattern.compile("S0T-(.*)");
    for (String feat : featureWeights.keySet()) {
      Matcher m1 = p1.matcher(feat);
      if (m1.matches()) {
        tags.add(m1.group(1));
      }
      Matcher m2 = p2.matcher(feat);
      if (m2.matches()) {
        tags.add(m2.group(1));
      }
    }
    // Add the end of sentence tag!
    // The SR model doesn't use it, but other models do and report it.
    // todo [cdm 2014]: Maybe we should reverse the convention here?!?
    tags.add(Tagger.EOS_TAG);
    return tags;
  }
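
  // For example, a feature named "S0T-NN" matches the second pattern and
  // contributes the tag "NN", while "Q0TQ1T-DT-NN" matches the first and
  // contributes "DT".  (Hypothetical feature names; the exact inventory
  // depends on the FeatureFactory used at training time.)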

  /** Convenience method: returns the single highest-scoring transition, without any ParserConstraints */
  private ScoredObject<Integer> findHighestScoringTransition(State state, List<String> features, boolean requireLegal) {
    Collection<ScoredObject<Integer>> transitions = findHighestScoringTransitions(state, features, requireLegal, 1, null);
    if (transitions.isEmpty()) {
      return null;
    }
    return transitions.iterator().next();
  }

  @Override
  public Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
    List<String> features = featureFactory.featurize(state);
    return findHighestScoringTransitions(state, features, requireLegal, numTransitions, constraints);
  }

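  // Scoring is a linear model over sparse indicator features: each candidate
  // transition t receives score(t) = the sum, over the state's features f, of
  // the t-th entry of featureWeights.get(f).  The numTransitions best legal
  // transitions are kept with a bounded min-heap: the ascending comparator
  // keeps the current worst at the head, so poll() evicts it.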
  private Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> features, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
    float[] scores = new float[transitionIndex.size()];
    for (String feature : features) {
      Weight weight = featureWeights.get(feature);
      if (weight == null) {
        // Features not in our index are ignored
        continue;
      }
      weight.score(scores);
    }

    PriorityQueue<ScoredObject<Integer>> queue = new PriorityQueue<>(numTransitions + 1, ScoredComparator.ASCENDING_COMPARATOR);
    for (int i = 0; i < scores.length; ++i) {
      if (!requireLegal || transitionIndex.get(i).isLegal(state, constraints)) {
        queue.add(new ScoredObject<>(i, scores[i]));
        if (queue.size() > numTransitions) {
          queue.poll();
        }
      }
    }

    return queue;
  }

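  // A single perceptron update: for every feature that fired, the weight of
  // the gold transition is raised by delta and the weight of the predicted
  // transition is lowered by delta (see the update loop in trainModel).  An
  // index of -1 marks "no transition on that side of the update", which the
  // weight update is expected to skip.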
  private static class Update {
    final List<String> features;
    final int goldTransition;
    final int predictedTransition;
    final float delta;

    Update(List<String> features, int goldTransition, int predictedTransition, float delta) {
      this.features = features;
      this.goldTransition = goldTransition;
      this.predictedTransition = predictedTransition;
      this.delta = delta;
    }
  }

  /**
   * Train on a single tree, accumulating any perceptron updates into
   * {@code updates}, and return the number of correct and wrong
   * transitions encountered as a (correct, wrong) pair.
   */
  private Pair<Integer, Integer> trainTree(int index, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle) {
    int numCorrect = 0;
    int numWrong = 0;

    Tree tree = binarizedTrees.get(index);

    ReorderingOracle reorderer = null;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE ||
        op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
      reorderer = new ReorderingOracle(op);
    }

    // TODO.  This training method seems to be working in that it
    // trains models just like the gold and early termination methods do.
    // However, it causes the feature space to go crazy.  Presumably
    // leaving out features with low weights or low frequencies would
    // significantly help with that.  Otherwise, not sure how to keep
    // it under control.
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
      State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
      while (!state.isFinished()) {
        List<String> features = featureFactory.featurize(state);
        ScoredObject<Integer> prediction = findHighestScoringTransition(state, features, true);
        if (prediction == null) {
          throw new AssertionError("Did not find a legal transition");
        }
        int predictedNum = prediction.object();
        Transition predicted = transitionIndex.get(predictedNum);
        OracleTransition gold = oracle.goldTransition(index, state);
        if (gold.isCorrect(predicted)) {
          numCorrect++;
          if (gold.transition != null && !gold.transition.equals(predicted)) {
            int transitionNum = transitionIndex.indexOf(gold.transition);
            if (transitionNum < 0) {
              // TODO: do we want to add unary transitions which are
              // only possible when the parser has gone off the rails?
              continue;
            }
            updates.add(new Update(features, transitionNum, -1, learningRate));
          }
        } else {
          numWrong++;
          int transitionNum = -1;
          if (gold.transition != null) {
            transitionNum = transitionIndex.indexOf(gold.transition);
            // TODO: this can theoretically result in a -1 gold
            // transition if the transition exists, but is a
            // CompoundUnaryTransition which only exists because the
            // parser is wrong.  Do we want to add those transitions?
          }
          updates.add(new Update(features, transitionNum, predictedNum, learningRate));
        }
        state = predicted.apply(state);
      }
    } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM ||
               op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
      if (op.trainOptions().beamSize <= 0) {
        throw new IllegalArgumentException("Illegal beam size " + op.trainOptions().beamSize);
      }
      List<Transition> transitions = Generics.newLinkedList(transitionLists.get(index));
      PriorityQueue<State> agenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
      State goldState = ShiftReduceParser.initialStateFromGoldTagTree(tree);
      agenda.add(goldState);
      // int transitionCount = 0;
      while (transitions.size() > 0) {
        Transition goldTransition = transitions.get(0);
        Transition highestScoringTransitionFromGoldState = null;
        double highestScoreFromGoldState = 0.0;
        PriorityQueue<State> newAgenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
        State highestScoringState = null;
        State highestCurrentState = null;
        for (State currentState : agenda) {
          boolean isGoldState = (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM &&
                                 goldState.areTransitionsEqual(currentState));

          List<String> features = featureFactory.featurize(currentState);
          Collection<ScoredObject<Integer>> stateTransitions = findHighestScoringTransitions(currentState, features, true, op.trainOptions().beamSize, null);
          for (ScoredObject<Integer> transition : stateTransitions) {
            State newState = transitionIndex.get(transition.object()).apply(currentState, transition.score());
            newAgenda.add(newState);
            if (newAgenda.size() > op.trainOptions().beamSize) {
              newAgenda.poll();
            }
            if (highestScoringState == null || highestScoringState.score() < newState.score()) {
              highestScoringState = newState;
              highestCurrentState = currentState;
            }
            if (isGoldState &&
                (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
              highestScoringTransitionFromGoldState = transitionIndex.get(transition.object());
              highestScoreFromGoldState = transition.score();
            }
          }
        }

        // This can happen if the REORDER_BEAM method backs itself
        // into a corner, such as transitioning to something that
        // can't have a FinalizeTransition applied.  This doesn't
        // happen for the BEAM method because in that case the correct
        // state (e.g. one with ROOT) isn't on the agenda, so it stops.
        if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
          break;
        }

        State newGoldState = goldTransition.apply(goldState, 0.0);

        // If the highest scoring state used the correct transition, no training is needed.
        // Otherwise, lower the weight of the last predicted transition and raise the weight of the correct one.
        if (!newGoldState.areTransitionsEqual(highestScoringState)) {
          ++numWrong;
          List<String> goldFeatures = featureFactory.featurize(goldState);
          int lastTransition = transitionIndex.indexOf(highestScoringState.transitions.peek());
          updates.add(new Update(featureFactory.featurize(highestCurrentState), -1, lastTransition, learningRate));
          updates.add(new Update(goldFeatures, transitionIndex.indexOf(goldTransition), -1, learningRate));

          if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
            // If the correct state has fallen off the agenda, break
            if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
              break;
            } else {
              transitions.remove(0);
            }
          } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
            if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
              if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
                break;
              }
              newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
              if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                break;
              }
            } else {
              transitions.remove(0);
            }
          }
        } else {
          ++numCorrect;
          transitions.remove(0);
        }

        goldState = newGoldState;
        agenda = newAgenda;
      }
    } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE ||
               op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION ||
               op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
      State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
      List<Transition> transitions = transitionLists.get(index);
      transitions = Generics.newLinkedList(transitions);
      boolean keepGoing = true;
      while (transitions.size() > 0 && keepGoing) {
        Transition transition = transitions.get(0);
        int transitionNum = transitionIndex.indexOf(transition);
        List<String> features = featureFactory.featurize(state);
        int predictedNum = findHighestScoringTransition(state, features, false).object();
        Transition predicted = transitionIndex.get(predictedNum);
        if (transitionNum == predictedNum) {
          transitions.remove(0);
          state = transition.apply(state);
          numCorrect++;
        } else {
          numWrong++;
          // TODO: allow weighted features, weighted training, etc
          updates.add(new Update(features, transitionNum, predictedNum, learningRate));
          switch (op.trainOptions().trainingMethod) {
          case EARLY_TERMINATION:
            keepGoing = false;
            break;
          case GOLD:
            transitions.remove(0);
            state = transition.apply(state);
            break;
          case REORDER_ORACLE:
            keepGoing = reorderer.reorder(state, predicted, transitions);
            if (keepGoing) {
              state = predicted.apply(state);
            }
            break;
          default:
            throw new IllegalArgumentException("Unexpected method " + op.trainOptions().trainingMethod);
          }
        }
      }
    }

    return Pair.makePair(numCorrect, numWrong);
  }

  private class TrainTreeProcessor implements ThreadsafeProcessor<Integer, Pair<Integer, Integer>> {
    List<Tree> binarizedTrees;
    List<List<Transition>> transitionLists;
    List<Update> updates; // this needs to be a synchronized list
    Oracle oracle;

    public TrainTreeProcessor(List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle) {
      this.binarizedTrees = binarizedTrees;
      this.transitionLists = transitionLists;
      this.updates = updates;
      this.oracle = oracle;
    }

    @Override
    public Pair<Integer, Integer> process(Integer index) {
      return trainTree(index, binarizedTrees, transitionLists, updates, oracle);
    }

    @Override
    public TrainTreeProcessor newInstance() {
      // already threadsafe
      return this;
    }
  }

  /**
   * Trains a batch of trees and returns the following: the list of
   * Update objects, the number of correct transitions, and the number
   * of wrong transitions.
   *
   * If the model is trained with multiple threads, it is expected
   * that a valid MulticoreWrapper is passed in to do the processing.
   * In that case, the processing is done on all of the trees without
   * updating any weights, which allows the results of multithreaded
   * training to be reproduced.
   */
  private Triple<List<Update>, Integer, Integer> trainBatch(List<Integer> indices, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle, MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper) {
    int numCorrect = 0;
    int numWrong = 0;
    if (op.trainOptions.trainingThreads == 1) {
      for (Integer index : indices) {
        Pair<Integer, Integer> count = trainTree(index, binarizedTrees, transitionLists, updates, oracle);
        numCorrect += count.first;
        numWrong += count.second;
      }
    } else {
      for (Integer index : indices) {
        wrapper.put(index);
      }
      // join(false) waits for the queued trees to finish without shutting
      // down the worker threads, so the wrapper can be reused for later batches
      wrapper.join(false);
      while (wrapper.peek()) {
        Pair<Integer, Integer> result = wrapper.poll();
        numCorrect += result.first;
        numWrong += result.second;
      }
    }
    return new Triple<>(updates, numCorrect, numWrong);
  }


  private void trainModel(String serializedPath, Tagger tagger, Random random, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, Treebank devTreebank, int nThreads, Set<String> allowedFeatures) {
    double bestScore = 0.0;
    int bestIteration = 0;
    PriorityQueue<ScoredObject<PerceptronModel>> bestModels = null;
    if (op.trainOptions().averagedModels > 0) {
      bestModels = new PriorityQueue<>(op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
    }

    List<Integer> indices = Generics.newArrayList();
    for (int i = 0; i < binarizedTrees.size(); ++i) {
      indices.add(i);
    }

    Oracle oracle = null;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
      oracle = new Oracle(binarizedTrees, op.compoundUnaries, rootStates);
    }

    List<Update> updates = Generics.newArrayList();
    MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper = null;
    if (nThreads != 1) {
      updates = Collections.synchronizedList(updates);
      wrapper = new MulticoreWrapper<>(op.trainOptions.trainingThreads, new TrainTreeProcessor(binarizedTrees, transitionLists, updates, oracle));
    }

    IntCounter<String> featureFrequencies = null;
    if (op.trainOptions().featureFrequencyCutoff > 1) {
      featureFrequencies = new IntCounter<>();
    }

    for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration) {
      Timing trainingTimer = new Timing();
      int numCorrect = 0;
      int numWrong = 0;
      Collections.shuffle(indices, random);
      for (int start = 0; start < indices.size(); start += op.trainOptions.batchSize) {
        int end = Math.min(start + op.trainOptions.batchSize, indices.size());
        Triple<List<Update>, Integer, Integer> result = trainBatch(indices.subList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);

        numCorrect += result.second;
        numWrong += result.third;

        for (Update update : result.first) {
          for (String feature : update.features) {
            if (allowedFeatures != null && !allowedFeatures.contains(feature)) {
              continue;
            }
            Weight weights = featureWeights.get(feature);
            if (weights == null) {
              weights = new Weight();
              featureWeights.put(feature, weights);
            }
            weights.updateWeight(update.goldTransition, update.delta);
            weights.updateWeight(update.predictedTransition, -update.delta);

            if (featureFrequencies != null) {
              featureFrequencies.incrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ? 2 : 1);
            }
          }
        }
        updates.clear();
      }
      trainingTimer.done("Iteration " + iteration);
      log.info("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
      outputStats();


      double labelF1 = 0.0;
      if (devTreebank != null) {
        EvaluateTreebank evaluator = new EvaluateTreebank(op, null, new ShiftReduceParser(op, this), tagger);
        evaluator.testOnTreebank(devTreebank);
        labelF1 = evaluator.getLBScore();
        log.info("Label F1 after " + iteration + " iterations: " + labelF1);

        if (labelF1 > bestScore) {
          log.info("New best dev score (previous best " + bestScore + ")");
          bestScore = labelF1;
          bestIteration = iteration;
        } else {
          log.info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
          if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit)) {
            log.info("Failed to improve for too long, stopping training");
            break;
          }
        }
        log.info();

        if (bestModels != null) {
          bestModels.add(new ScoredObject<>(new PerceptronModel(this), labelF1));
          if (bestModels.size() > op.trainOptions().averagedModels) {
            bestModels.poll();
          }
        }
      }
      if (op.trainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0) {
        // strip the ".ser.gz" suffix (7 characters), then tag with the iteration and dev F1
        String tempName = serializedPath.substring(0, serializedPath.length() - 7) + "-" + FILENAME.format(iteration) + "-" + NF.format(labelF1) + ".ser.gz";
        ShiftReduceParser temp = new ShiftReduceParser(op, this);
        temp.saveModel(tempName);
        // TODO: we could save a cutoff version of the model,
        // especially if we also get a dev set number for it, but that
        // might be overkill
      }

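      // Decay the learning rate every 10 iterations; e.g. (illustrative
      // numbers) with a decay factor of 0.8, a rate of 1.0f becomes 0.8f
      // after iteration 10 and 0.64f after iteration 20.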
      if (iteration % 10 == 0 && op.trainOptions().decayLearningRate > 0.0) {
        learningRate *= op.trainOptions().decayLearningRate;
      }
    } // end for iterations

    if (wrapper != null) {
      wrapper.join();
    }

    if (bestModels != null) {
      if (op.trainOptions().cvAveragedModels && devTreebank != null) {
        List<ScoredObject<PerceptronModel>> models = Generics.newArrayList();
        while (bestModels.size() > 0) {
          models.add(bestModels.poll());
        }
        Collections.reverse(models);
        double bestF1 = 0.0;
        int bestSize = 0;
        for (int i = 1; i <= models.size(); ++i) {
          log.info("Testing with " + i + " models averaged together");
          // TODO: this is kind of ugly, would prefer a separate object
          averageScoredModels(models.subList(0, i));
          ShiftReduceParser temp = new ShiftReduceParser(op, this);
          EvaluateTreebank evaluator = new EvaluateTreebank(temp.getOp(), null, temp, tagger);
          evaluator.testOnTreebank(devTreebank);
          double labelF1 = evaluator.getLBScore();
          log.info("Label F1 for " + i + " models: " + labelF1);
          if (labelF1 > bestF1) {
            bestF1 = labelF1;
            bestSize = i;
          }
        }
        averageScoredModels(models.subList(0, bestSize));
      } else {
        averageScoredModels(bestModels);
      }
    }

    // TODO: perhaps we should filter the features and then get dev
    // set scores.  That way we can merge the models which are best
    // after filtering.
    if (featureFrequencies != null) {
      filterFeatures(featureFrequencies.keysAbove(op.trainOptions().featureFrequencyCutoff));
    }

    condenseFeatures();
  }


  /**
   * Trains the model on the given treebank, using devTreebank as
   * a dev set.  If op.retrainAfterCutoff is set, training is rerun
   * after the first pass using only a limited set of features.
   */
  @Override
  public void trainModel(String serializedPath, Tagger tagger, Random random, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, Treebank devTreebank, int nThreads) {
    if (op.trainOptions().retrainAfterCutoff && op.trainOptions().featureFrequencyCutoff > 0) {
      // strip the ".ser.gz" suffix (7 characters) before appending the temp tag
      String tempName = serializedPath.substring(0, serializedPath.length() - 7) + "-" + "temp.ser.gz";
      trainModel(tempName, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads, null);
      ShiftReduceParser temp = new ShiftReduceParser(op, this);
      temp.saveModel(tempName);
      Set<String> features = featureWeights.keySet();
      featureWeights = Generics.newHashMap();
      trainModel(serializedPath, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads, features);
    } else {
      trainModel(serializedPath, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads, null);
    }
  }

  private static final long serialVersionUID = 1;

}
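
/*
 * Usage sketch (illustrative only; everything outside this file is assumed):
 *
 *   ShiftReduceOptions op = ...;   // parser and training options
 *   PerceptronModel model = new PerceptronModel(op, transitionIndex,
 *                                               knownStates, rootStates, rootOnlyStates);
 *   model.trainModel(serializedPath, tagger, random, binarizedTrees,
 *                    transitionLists, devTreebank, nThreads);
 *   ShiftReduceParser parser = new ShiftReduceParser(op, model);  // as constructed above
 */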



