All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.coref.statistical.PairwiseModelTrainer Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools that take raw English text as input and give the base forms of words, their parts of speech, and whether they are names of companies, people, etc.; normalize dates, times, and numeric quantities; mark up the structure of sentences in terms of phrases and word dependencies; and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher-level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.coref.statistical;

import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import edu.stanford.nlp.coref.statistical.MaxMarginMentionRanker.ErrorType;

import edu.stanford.nlp.coref.data.Dictionaries.MentionType;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;

/**
 * Class for training coreference models
 * @author Kevin Clark
 */
public class PairwiseModelTrainer {
  public static void trainRanking(PairwiseModel model) throws Exception {
    Redwood.log("scoref-train", "Reading compression...");
    Compressor compressor = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.compressorFile);

    Redwood.log("scoref-train", "Reading train data...");
    List trainDocuments = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.extractedFeaturesFile);

    Redwood.log("scoref-train", "Training...");
    for (int i = 0; i < model.getNumEpochs(); i++) {
      Collections.shuffle(trainDocuments);
      int j = 0;
      for (DocumentExamples doc : trainDocuments) {
        j++;
        Redwood.log("scoref-train", "On epoch: " + i + " / " + model.getNumEpochs()
            + ", document: " + j + " / " + trainDocuments.size());
        Map> mentionToPotentialAntecedents = new HashMap<>();
        for (Example e : doc.examples) {
          int mention = e.mentionId2;
          List potentialAntecedents = mentionToPotentialAntecedents.get(mention);
          if (potentialAntecedents == null) {
            potentialAntecedents = new ArrayList<>();
            mentionToPotentialAntecedents.put(mention, potentialAntecedents);
          }
          potentialAntecedents.add(e);
        }

        List> examples = new ArrayList<>(
            mentionToPotentialAntecedents.values());
        Collections.shuffle(examples);
        for (List es : examples) {
          if (es.size() == 0) {
            continue;
          }

          if (model instanceof MaxMarginMentionRanker) {
            MaxMarginMentionRanker ranker = (MaxMarginMentionRanker) model;
            boolean noAntecedent = es.stream().allMatch(e -> e.label == 0);
            es.add(new Example(es.get(0), noAntecedent));

            double maxPositiveScore = -Double.MAX_VALUE;
            Example maxScoringPositive = null;
            for (Example e : es) {
              double score = model.predict(e, doc.mentionFeatures, compressor);
              if (e.label == 1) {
                assert(!noAntecedent ^ e.isNewLink());
                if (score > maxPositiveScore) {
                  maxPositiveScore = score;
                  maxScoringPositive = e;
                }
              }
            }
            assert(maxScoringPositive != null);

            double maxNegativeScore = -Double.MAX_VALUE;
            Example maxScoringNegative = null;
            ErrorType maxScoringEt = null;
            for (Example e : es) {
              double score = model.predict(e, doc.mentionFeatures, compressor);
              if (e.label != 1) {
                assert(!(noAntecedent && e.isNewLink()));
                ErrorType et = ErrorType.WL;
                if (noAntecedent && !e.isNewLink()) {
                  et = ErrorType.FL;
                } else if (!noAntecedent && e.isNewLink()) {
                  if (e.mentionType2 == MentionType.PRONOMINAL) {
                    et = ErrorType.FN_PRON;
                  } else {
                    et = ErrorType.FN;
                  }
                }

                if (ranker.multiplicativeCost) {
                  score = ranker.costs[et.id] * (1 - maxPositiveScore + score);
                } else {
                  score += ranker.costs[et.id];
                }
                if (score > maxNegativeScore) {
                  maxNegativeScore = score;
                  maxScoringNegative = e;
                  maxScoringEt = et;
                }
              }
            }
            assert(maxScoringNegative != null);

            ranker.learn(maxScoringPositive, maxScoringNegative,
                doc.mentionFeatures, compressor, maxScoringEt);
          } else {
            double maxPositiveScore = -Double.MAX_VALUE;
            double maxNegativeScore = -Double.MAX_VALUE;
            Example maxScoringPositive = null;
            Example maxScoringNegative = null;
            for (Example e : es) {
              double score = model.predict(e, doc.mentionFeatures, compressor);
              if (e.label == 1) {
                if (score > maxPositiveScore) {
                  maxPositiveScore = score;
                  maxScoringPositive = e;
                }
              } else {
                if (score > maxNegativeScore) {
                  maxNegativeScore = score;
                  maxScoringNegative = e;
                }
              }
            }
            model.learn(maxScoringPositive, maxScoringNegative,
                doc.mentionFeatures, compressor, 1);
          }
        }
      }
    }

    Redwood.log("scoref-train", "Writing models...");
    model.writeModel();
  }

  public static List>>
      getAnaphoricityExamples(List documents) {
    int p = 0;
    int t = 0;

    List>> examples = new ArrayList<>();
    while (!documents.isEmpty()) {
      DocumentExamples doc = documents.remove(documents.size() - 1);
      Map areAnaphoric = new HashMap<>();
      for (Example e : doc.examples) {
        Boolean isAnaphoric = areAnaphoric.get(e.mentionId2);
        if (isAnaphoric == null) {
          areAnaphoric.put(e.mentionId2, false);
        }
        if (e.label == 1) {
          areAnaphoric.put(e.mentionId2, true);
        }
      }

      for (Map.Entry e : areAnaphoric.entrySet()) {
        if (e.getValue()) {
          p++;
        }
        t++;
      }

      for (Example e : doc.examples) {
        Boolean isAnaphoric = areAnaphoric.get(e.mentionId2);
        if (isAnaphoric != null) {
          areAnaphoric.remove(e.mentionId2);
          examples.add(new Pair<>(new Example(e, isAnaphoric), doc.mentionFeatures));
        }
      }
    }

    Redwood.log("scoref-train", "Num anaphoricity examples " + p + " positive, " + t + " total");

    return examples;
  }

  public static List>> getExamples(
      List documents) {
    List>> examples = new ArrayList<>();
    while (!documents.isEmpty()) {
      DocumentExamples doc = documents.remove(documents.size() - 1);
      Map mentionFeatures = doc.mentionFeatures;
      for (Example e : doc.examples) {
        examples.add(new Pair<>(e, mentionFeatures));
      }
    }
    return examples;
  }

  public static void trainClassification(PairwiseModel model, boolean anaphoricityModel)
      throws Exception {
    int numTrainingExamples = model.getNumTrainingExamples();

    Redwood.log("scoref-train", "Reading compression...");
    Compressor compressor = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.compressorFile);

    Redwood.log("scoref-train", "Reading train data...");
    List trainDocuments = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.extractedFeaturesFile);

    Redwood.log("scoref-train", "Building train set...");
    List>> allExamples = anaphoricityModel
        ? getAnaphoricityExamples(trainDocuments) : getExamples(trainDocuments);

    Redwood.log("scoref-train", "Training...");
    Random random = new Random(0);
    int i = 0;
    boolean stopTraining = false;
    while (!stopTraining) {
      Collections.shuffle(allExamples, random);
      for (Pair> pair : allExamples) {
        if (i++ > numTrainingExamples) {
          stopTraining = true;
          break;
        }
        if (i % 10000 == 0) {
          Redwood.log("scoref-train", String.format("On train example %d/%d = %.2f%%",
              i, numTrainingExamples, 100.0 * i / numTrainingExamples));
        }
        model.learn(pair.first, pair.second, compressor);
      }
    }

    Redwood.log("scoref-train", "Writing models...");
    model.writeModel();
  }

  public static void test(PairwiseModel model, String predictionsName,
      boolean anaphoricityModel) throws Exception {
    Redwood.log("scoref-train", "Reading compression...");
    Compressor compressor = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.compressorFile);

    Redwood.log("scoref-train", "Reading test data...");
    List testDocuments = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.extractedFeaturesFile);

    Redwood.log("scoref-train", "Building test set...");
    List>> allExamples = anaphoricityModel
        ? getAnaphoricityExamples(testDocuments) : getExamples(testDocuments);

    Redwood.log("scoref-train", "Testing...");
    PrintWriter writer = new PrintWriter(model.getDefaultOutputPath() + predictionsName);
    Map>> scores = new HashMap<>();
    writeScores(allExamples, compressor, model, writer, scores);
    if (model instanceof MaxMarginMentionRanker) {
      writer.close();
      writer = new PrintWriter(model.getDefaultOutputPath() + predictionsName + "_anaphoricity");
      testDocuments = IOUtils.readObjectFromFile(
          StatisticalCorefTrainer.extractedFeaturesFile);
      allExamples = getAnaphoricityExamples(testDocuments);
      writeScores(allExamples, compressor, model, writer, scores);
    }
    IOUtils.writeObjectToFile(scores, model.getDefaultOutputPath() + predictionsName + ".ser");
    writer.close();
  }

  public static void writeScores(List>> examples,
      Compressor compressor, PairwiseModel model, PrintWriter writer,
      Map>> scores) {
    int i  = 0;
    for (Pair> pair : examples) {
      if (i++ % 10000 == 0) {
        Redwood.log("scoref-train", String.format("On test example %d/%d = %.2f%%",
            i, examples.size(), 100.0 * i / examples.size()));
      }
      Example example = pair.first;
      Map mentionFeatures = pair.second;
      double p = model.predict(example, mentionFeatures, compressor);
      writer.println(example.docId + " " + example.mentionId1 + ","
          + example.mentionId2 + " "  + p + " " + example.label);
      Counter> docScores = scores.get(example.docId);
      if (docScores == null) {
        docScores = new ClassicCounter<>();
        scores.put(example.docId, docScores);
      }
      docScores.incrementCount(new Pair<>(example.mentionId1, example.mentionId2), p);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy