All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.coref.statistical.Clusterer Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher-level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.coref.statistical;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc;
import edu.stanford.nlp.coref.statistical.EvalUtils.B3Evaluator;
import edu.stanford.nlp.coref.statistical.EvalUtils.Evaluator;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;

/**
 * System for building up coreference clusters incrementally, merging a pair of clusters each step.
 * Trained with a variant of the SEARN imitation learning algorithm.
 * @author Kevin Clark
 */
public class Clusterer {
  private static final boolean USE_CLASSIFICATION = true;
  private static final boolean USE_RANKING = true;
  private static final boolean LEFT_TO_RIGHT = false;
  private static final boolean EXACT_LOSS = false;
  private static final double MUC_WEIGHT = 0.25;
  private static final double EXPERT_DECAY = 0.0;
  private static final double LEARNING_RATE = 0.05;
  private static final int BUFFER_SIZE_MULTIPLIER = 20;
  private static final int MAX_DOCS = 1000;
  private static final int RETRAIN_ITERATIONS = 100;
  private static final int NUM_EPOCHS = 15;
  private static final int EVAL_FREQUENCY = 1;

  private static final int MIN_PAIRS = 10;
  private static final double MIN_PAIRWISE_SCORE = 0.15;
  private static final int EARLY_STOP_THRESHOLD = 1000;
  private static final double EARLY_STOP_VAL = 1500 / 0.2;

  public static int currentDocId = 0;
  public static int isTraining = 1;

  private final ClustererClassifier classifier;
  private final Random random;

  public Clusterer() {
    random = new Random(0);
    classifier = new ClustererClassifier(LEARNING_RATE);
  }

  public Clusterer(String modelPath) {
    random = new Random(0);
    classifier = new ClustererClassifier(modelPath, LEARNING_RATE);
  }

  public List> getClusterMerges(ClustererDoc doc) {
    List> merges = new ArrayList<>();
    State currentState = new State(doc);
    while (!currentState.isComplete()) {
      Pair currentPair =
          currentState.mentionPairs.get(currentState.currentIndex);
      if (currentState.doBestAction(classifier)) {
        merges.add(currentPair);
      }
    }
    return merges;
  }

  public void doTraining(String modelName) {
    classifier.setWeight("bias", -0.3);
    classifier.setWeight("anaphorSeen", -1);
    classifier.setWeight("max-ranking", 1);
    classifier.setWeight("bias-single", -0.3);
    classifier.setWeight("anaphorSeen-single", -1);
    classifier.setWeight("max-ranking-single", 1);

    String outputPath = StatisticalCorefTrainer.clusteringModelsPath +
        modelName + "/";
    File outDir = new File(outputPath);
    if (!outDir.exists()) {
      outDir.mkdir();
    }

    PrintWriter progressWriter;
    List trainDocs;
    try {
      PrintWriter configWriter = new PrintWriter(outputPath + "config", "UTF-8");
      configWriter.print(StatisticalCorefTrainer.fieldValues(this));
      configWriter.close();
      progressWriter = new PrintWriter(outputPath + "progress", "UTF-8");

      Redwood.log("scoref.train", "Loading training data");
      StatisticalCorefTrainer.setDataPath("dev");
      trainDocs = ClustererDataLoader.loadDocuments(MAX_DOCS);
    } catch (Exception e) {
      throw new RuntimeException("Error setting up training", e);
    }

    double bestTrainScore = 0;
    List>> examples = new ArrayList<>();
    for (int iteration = 0; iteration < RETRAIN_ITERATIONS; iteration++) {
      Redwood.log("scoref.train", "ITERATION " + iteration);
      classifier.printWeightVector(null);
      Redwood.log("scoref.train", "");
      try {
        classifier.writeWeights(outputPath + "model");
        classifier.printWeightVector(IOUtils.getPrintWriter(outputPath + "weights"));
      } catch (Exception e) {
        throw new RuntimeException();
      }

      long start = System.currentTimeMillis();
      Collections.shuffle(trainDocs, random);

      examples = examples.subList(Math.max(0, examples.size()
          - BUFFER_SIZE_MULTIPLIER * trainDocs.size()), examples.size());
      trainPolicy(examples);

      if (iteration % EVAL_FREQUENCY == 0) {
        double trainScore = evaluatePolicy(trainDocs, true);
        if (trainScore > bestTrainScore) {
          bestTrainScore = trainScore;
          writeModel("best", outputPath);
        }

        if (iteration % 10 == 0) {
          writeModel("iter_" + iteration, outputPath);
        }
        writeModel("last", outputPath);

        double timeElapsed = (System.currentTimeMillis() - start) / 1000.0;
        double ffhr = State.ffHits / (double) (State.ffHits + State.ffMisses);
        double shr = State.sHits / (double) (State.sHits + State.sMisses);
        double fhr = featuresCacheHits /
            (double) (featuresCacheHits + featuresCacheMisses);
        Redwood.log("scoref.train", modelName);
        Redwood.log("scoref.train", String.format("Best train: %.4f", bestTrainScore));
        Redwood.log("scoref.train", String.format("Time elapsed: %.2f", timeElapsed));
        Redwood.log("scoref.train", String.format("Cost hit rate: %.4f", ffhr));
        Redwood.log("scoref.train", String.format("Score hit rate: %.4f", shr));
        Redwood.log("scoref.train", String.format("Features hit rate: %.4f", fhr));
        Redwood.log("scoref.train", "");

        progressWriter.write(iteration + " " + trainScore + " "
            + " " + timeElapsed + " " + ffhr + " " + shr
            + " " + fhr + "\n");
        progressWriter.flush();
      }

      for (ClustererDoc trainDoc : trainDocs) {
        examples.add(runPolicy(trainDoc, Math.pow(EXPERT_DECAY,
                (iteration + 1))));
      }
    }

    progressWriter.close();
  }

  private void writeModel(String name, String modelPath) {
    try {
      classifier.writeWeights(modelPath + name + "_model.ser");
      classifier.printWeightVector(
          IOUtils.getPrintWriter(modelPath + name + "_weights"));
    } catch (Exception e) {
      throw new RuntimeException();
    }
  }

  private void trainPolicy(List>> examples) {
    List> flattenedExamples = new ArrayList<>();
    examples.stream().forEach(flattenedExamples::addAll);

    for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
      Collections.shuffle(flattenedExamples, random);
      flattenedExamples.forEach(classifier::learn);
    }

    double totalCost = flattenedExamples.stream()
        .mapToDouble(e -> classifier.bestAction(e).cost).sum();
    Redwood.log("scoref.train",
        String.format("Training cost: %.4f", 100 * totalCost / flattenedExamples.size()));
  }

  private double evaluatePolicy(List docs, boolean training) {
    isTraining = 0;
    B3Evaluator evaluator = new B3Evaluator();
    for (ClustererDoc doc : docs) {
      State currentState = new State(doc);
      while (!currentState.isComplete()) {
        currentState.doBestAction(classifier);
      }
      currentState.updateEvaluator(evaluator);
    }
    isTraining = 1;

    double score = evaluator.getF1();
    Redwood.log("scoref.train", String.format("B3 F1 score on %s: %.4f",
        training ? "train" : "validate", score));
    return score;
  }

  private List> runPolicy(ClustererDoc doc, double beta) {
    List> examples = new ArrayList<>();
    State currentState = new State(doc);
    while (!currentState.isComplete()) {
      Pair actions = currentState.getActions(classifier);
      if (actions == null) {
        continue;
      }
      examples.add(actions);

      boolean useExpert = random.nextDouble() < beta;
      double action1Score = useExpert ? -actions.first.cost :
        classifier.weightFeatureProduct(actions.first.features);
      double action2Score = useExpert ? -actions.second.cost :
        classifier.weightFeatureProduct(actions.second.features);
      currentState.doAction(action1Score >= action2Score);
    }

    return examples;
  }

  private static class GlobalFeatures {
    public boolean anaphorSeen;
    public int currentIndex;
    public int size;
    public double docSize;
  }

  private static class State {
    private static int sHits;
    private static int sMisses;
    private static int ffHits;
    private static int ffMisses;

    private final Map hashedScores;
    private final Map hashedCosts;

    private final ClustererDoc doc;
    private final List clusters;
    private final Map mentionToCluster;
    private final List> mentionPairs;
    private final List globalFeatures;

    private int currentIndex;
    private Cluster c1;
    private Cluster c2;
    private long hash;

    public State(ClustererDoc doc) {
      currentDocId = doc.id;
      this.doc = doc;
      this.hashedScores = new HashMap<>();
      this.hashedCosts = new HashMap<>();
      this.clusters = new ArrayList<>();
      this.hash = 0;

      mentionToCluster = new HashMap<>();
      for (int m : doc.mentions) {
        Cluster c = new Cluster(m);
        clusters.add(c);
        mentionToCluster.put(m, c);
        hash ^= c.hash * 7;
      }

      List> allPairs = new ArrayList<>(doc.classificationScores.keySet());

      Counter> scores =
          USE_RANKING ? doc.rankingScores : doc.classificationScores;
      Collections.sort(allPairs, (p1, p2) -> {
        double diff = scores.getCount(p2) - scores.getCount(p1);
        return diff == 0 ? 0 : (int) Math.signum(diff);
      });

      int i = 0;
      for (i = 0; i < allPairs.size(); i++) {
        double score = scores.getCount(allPairs.get(i));
        if (score < MIN_PAIRWISE_SCORE && i > MIN_PAIRS) {
          break;
        }
        if (i >= EARLY_STOP_THRESHOLD && i / score > EARLY_STOP_VAL) {
          break;
        }
      }
      mentionPairs = allPairs.subList(0, i);
      if (LEFT_TO_RIGHT) {
        Collections.sort(mentionPairs, (p1, p2) -> {
          if (p1.second.equals(p2.second)) {
            double diff = scores.getCount(p2) - scores.getCount(p1);
            return diff == 0 ? 0 : (int) Math.signum(diff);
          }
          return doc.mentionIndices.get(p1.second)
              < doc.mentionIndices.get(p2.second) ? -1 : 1;
        });
        for (int j = 0; j < mentionPairs.size(); j++) {
          Pair p1 = mentionPairs.get(j);
          for (int k = j + 1; k < mentionPairs.size(); k++) {
            Pair p2 = mentionPairs.get(k);
            assert(doc.mentionIndices.get(p1.second)
                <= doc.mentionIndices.get(p2.second));
          }
        }
      }

      Counter seenAnaphors = new ClassicCounter<>();
      Counter seenAntecedents = new ClassicCounter<>();
      globalFeatures = new ArrayList<>();
      for (int j = 0; j < allPairs.size(); j++) {
        Pair mentionPair = allPairs.get(j);
        GlobalFeatures gf = new GlobalFeatures();
        gf.currentIndex = j;
        gf.anaphorSeen = seenAnaphors.containsKey(mentionPair.second);
        gf.size = mentionPairs.size();
        gf.docSize = doc.mentions.size() / 300.0;
        globalFeatures.add(gf);

        seenAnaphors.incrementCount(mentionPair.second);
        seenAntecedents.incrementCount(mentionPair.first);
      }

      currentIndex = 0;
      setClusters();
    }

    public State(State state) {
      this.hashedScores = state.hashedScores;
      this.hashedCosts = state.hashedCosts;

      this.doc = state.doc;
      this.hash = state.hash;
      this.mentionPairs = state.mentionPairs;
      this.currentIndex = state.currentIndex;
      this.globalFeatures = state.globalFeatures;

      this.clusters = new ArrayList<>();
      this.mentionToCluster = new HashMap<>();
      for (Cluster c : state.clusters) {
        Cluster copy = new Cluster(c);
        clusters.add(copy);
        for (int m : copy.mentions) {
          mentionToCluster.put(m, copy);
        }
      }

      setClusters();
    }

    public void setClusters() {
      Pair currentPair = mentionPairs.get(currentIndex);
      c1 = mentionToCluster.get(currentPair.first);
      c2 = mentionToCluster.get(currentPair.second);
    }

    public void doAction(boolean isMerge) {
      if (isMerge) {
        if (c2.size() > c1.size()) {
          Cluster tmp = c1;
          c1 = c2;
          c2 = tmp;
        }

        hash ^= 7 * c1.hash;
        hash ^= 7 * c2.hash;

        c1.merge(c2);
        for (int m : c2.mentions) {
          mentionToCluster.put(m, c1);
        }
        clusters.remove(c2);

        hash ^= 7 * c1.hash;
      }
      currentIndex++;
      if (!isComplete()) {
        setClusters();
      }
      while (c1 == c2) {
        currentIndex++;
        if (isComplete()) {
          break;
        }
        setClusters();
      }
    }

    public boolean doBestAction(ClustererClassifier classifier) {
      Boolean doMerge = hashedScores.get(new MergeKey(c1, c2, currentIndex));
      if (doMerge == null) {
        Counter features = getFeatures(doc, c1, c2,
            globalFeatures.get(currentIndex));
        doMerge = classifier.weightFeatureProduct(features) > 0;
        hashedScores.put(new MergeKey(c1, c2, currentIndex), doMerge);
        sMisses += isTraining;
      } else {
        sHits += isTraining;
      }

      doAction(doMerge);
      return doMerge;
    }

    public boolean isComplete() {
      return currentIndex >= mentionPairs.size();
    }

    public double getFinalCost(ClustererClassifier classifier) {
      while(EXACT_LOSS && !isComplete()) {
        if (hashedCosts.containsKey(hash)) {
          ffHits += isTraining;;
          return hashedCosts.get(hash);
        }
        doBestAction(classifier);
      }
      ffMisses += isTraining;

      double cost = EvalUtils.getCombinedF1(MUC_WEIGHT, doc.goldClusters, clusters,
          doc.mentionToGold, mentionToCluster);
      hashedCosts.put(hash, cost);
      return cost;
    }

    public void updateEvaluator(Evaluator evaluator) {
      evaluator.update(doc.goldClusters, clusters, doc.mentionToGold, mentionToCluster);
    }

    public Pair getActions(ClustererClassifier classifier) {
      Counter mergeFeatures = getFeatures(doc, c1, c2,
          globalFeatures.get(currentIndex));
      double mergeScore = Math.exp(classifier.weightFeatureProduct(mergeFeatures));
      hashedScores.put(new MergeKey(c1, c2, currentIndex), mergeScore > 0.5);

      State merge = new State(this);
      merge.doAction(true);
      double mergeB3 = merge.getFinalCost(classifier);

      State noMerge = new State(this);
      noMerge.doAction(false);
      double noMergeB3 = noMerge.getFinalCost(classifier);

      double weight = doc.mentions.size() / 100.0;
      double maxB3 = Math.max(mergeB3, noMergeB3);
      return new Pair<>(
              new CandidateAction(mergeFeatures, weight * (maxB3 - mergeB3)),
              new CandidateAction(new ClassicCounter<>(), weight * (maxB3 - noMergeB3)));
    }
  }

  private static class MergeKey {
    private final int hash;

    public MergeKey(Cluster c1, Cluster c2, int ind) {
      hash = (int)(c1.hash ^ c2.hash) + (2003 * ind) + currentDocId;
    }

    @Override
    public int hashCode() {
      return hash;
    }

    @Override
    public boolean equals(Object o) {
      return ((MergeKey) o).hash == hash;
    }
  }

  public static class Cluster {
    private static final Map, Long> MENTION_HASHES = new HashMap<>();
    private static final Random RANDOM = new Random(0);

    public final List mentions;
    public long hash;

    public Cluster(int m) {
      mentions = new ArrayList<>();
      mentions.add(m);
      hash = getMentionHash(m);
    }

    public Cluster(Cluster c) {
      mentions = new ArrayList<>(c.mentions);
      hash = c.hash;
    }

    public void merge(Cluster c) {
      mentions.addAll(c.mentions);
      hash ^= c.hash;
    }

    public int size() {
      return mentions.size();
    }

    public long getHash() {
      return hash;
    }

    private static long getMentionHash(int m) {
      Pair pair = new Pair<>(m, currentDocId);
      Long hash = MENTION_HASHES.get(pair);
      if (hash == null) {
        hash = RANDOM.nextLong();
        MENTION_HASHES.put(pair, hash);
      }
      return hash;
    }
  }

  private static int featuresCacheHits;
  private static int featuresCacheMisses;
  private static Map featuresCache = new HashMap<>();
  private static Compressor compressor = new Compressor<>();

  private static Counter getFeatures(ClustererDoc doc, Pair mentionPair,
      Counter> scores) {
    Counter features = new ClassicCounter<>();
    if (!scores.containsKey(mentionPair)) {
      mentionPair = new Pair<>(mentionPair.second, mentionPair.first);
    }
    double score = scores.getCount(mentionPair);
    features.incrementCount("max", score);
    return features;
  }

  private static Counter getFeatures(ClustererDoc doc,
      List> mentionPairs, Counter> scores) {
    Counter features = new ClassicCounter<>();

    double maxScore = 0;
    double minScore = 1;
    Counter totals = new ClassicCounter<>();
    Counter totalsLog = new ClassicCounter<>();
    Counter counts = new ClassicCounter<>();
    for (Pair mentionPair : mentionPairs) {
      if (!scores.containsKey(mentionPair)) {
        mentionPair = new Pair<>(mentionPair.second, mentionPair.first);
      }
      double score = scores.getCount(mentionPair);
      double logScore = cappedLog(score);

      String mt1 = doc.mentionTypes.get(mentionPair.first);
      String mt2 = doc.mentionTypes.get(mentionPair.second);
      mt1 = mt1.equals("PRONOMINAL") ? "PRONOMINAL" : "NON_PRONOMINAL";
      mt2 = mt2.equals("PRONOMINAL") ? "PRONOMINAL" : "NON_PRONOMINAL";
      String conj = "_" + mt1 + "_" + mt2;

      maxScore = Math.max(maxScore, score);
      minScore = Math.min(minScore, score);

      totals.incrementCount("", score);
      totalsLog.incrementCount("", logScore);
      counts.incrementCount("");

      totals.incrementCount(conj, score);
      totalsLog.incrementCount(conj, logScore);
      counts.incrementCount(conj);
    }

    features.incrementCount("max", maxScore);
    features.incrementCount("min", minScore);
    for (String key : counts.keySet()) {
      features.incrementCount("avg" + key, totals.getCount(key) / mentionPairs.size());
      features.incrementCount("avgLog" + key, totalsLog.getCount(key) / mentionPairs.size());
    }

    return features;
  }

  private static int earliestMention(Cluster c, ClustererDoc doc) {
    int earliest = -1;
    for (int m : c.mentions) {
      int pos = doc.mentionIndices.get(m);
      if (earliest == -1 || pos < doc.mentionIndices.get(earliest)) {
        earliest = m;
      }
    }
    return earliest;
  }

  private static Counter getFeatures(ClustererDoc doc, Cluster c1, Cluster c2, GlobalFeatures gf) {
    MergeKey key = new MergeKey(c1, c2, gf.currentIndex);
    CompressedFeatureVector cfv = featuresCache.get(key);
    Counter features = cfv == null ? null : compressor.uncompress(cfv);
    if (features != null) {
      featuresCacheHits += isTraining;
      return features;
    }
    featuresCacheMisses += isTraining;

    features = new ClassicCounter<>();
    if (gf.anaphorSeen) {
      features.incrementCount("anaphorSeen");
    }
    features.incrementCount("docSize", gf.docSize);
    features.incrementCount("percentComplete", gf.currentIndex / (double) gf.size);
    features.incrementCount("bias", 1.0);

    int earliest1 = earliestMention(c1, doc);
    int earliest2 = earliestMention(c2, doc);
    if (doc.mentionIndices.get(earliest1) > doc.mentionIndices.get(earliest2)) {
      int tmp = earliest1;
      earliest1 = earliest2;
      earliest2 = tmp;
    }
    features.incrementCount("anaphoricity", doc.anaphoricityScores.getCount(earliest2));

    if (c1.mentions.size() == 1 && c2.mentions.size() == 1) {
      Pair mentionPair = new Pair<>(c1.mentions.get(0),
          c2.mentions.get(0));

      if (USE_CLASSIFICATION) {
        features.addAll(addSuffix(getFeatures(doc, mentionPair, doc.classificationScores),
            "-classification"));
      }
      if (USE_RANKING) {
        features.addAll(addSuffix(getFeatures(doc, mentionPair, doc.rankingScores),
            "-ranking"));
      }

      features = addSuffix(features, "-single");
    } else {
      List> between = new ArrayList<>();
      for (int m1 : c1.mentions) {
        for (int m2 : c2.mentions) {
          between.add(new Pair<>(m1, m2));
        }
      }

      if (USE_CLASSIFICATION) {
        features.addAll(addSuffix(getFeatures(doc, between, doc.classificationScores),
            "-classification"));
      }
      if (USE_RANKING) {
        features.addAll(addSuffix(getFeatures(doc, between, doc.rankingScores),
            "-ranking"));
      }
    }

    featuresCache.put(key, compressor.compress(features));
    return features;
  }

  private static Counter addSuffix(Counter features, String suffix) {
    Counter withSuffix = new ClassicCounter<>();
    for (Map.Entry e : features.entrySet()) {
      withSuffix.incrementCount(e.getKey() + suffix, e.getValue());
    }
    return withSuffix;
  }

  private static double cappedLog(double x) {
    return Math.log(Math.max(x, 1e-8));
  }

  private static class ClustererClassifier extends SimpleLinearClassifier {
    public ClustererClassifier(double learningRate) {
      super(SimpleLinearClassifier.risk(),
          SimpleLinearClassifier.constant(learningRate),
          0);
    }

    public ClustererClassifier(String modelFile, double learningRate) {
      super(SimpleLinearClassifier.risk(),
          SimpleLinearClassifier.constant(learningRate),
          0,
          modelFile);
    }

    public CandidateAction bestAction(Pair actions) {
      return weightFeatureProduct(actions.first.features) >
          weightFeatureProduct(actions.second.features) ? actions.first : actions.second;
    }

    public void learn(Pair actions) {
      CandidateAction goodAction = actions.first;
      CandidateAction badAction = actions.second;
      if (badAction.cost == 0) {
        CandidateAction tmp = goodAction;
        goodAction = badAction;
        badAction = tmp;
      }
      Counter features = new ClassicCounter<>(goodAction.features);
      for (Map.Entry e : badAction.features.entrySet()) {
        features.decrementCount(e.getKey(), e.getValue());
      }
      learn(features, 0, badAction.cost);
    }
  }

  private static class CandidateAction {
    public final Counter features;
    public final double cost;

    public CandidateAction(Counter features, double cost) {
      this.features = features;
      this.cost = cost;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy