All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.ie.machinereading.BasicRelationExtractor Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.ie.machinereading;

import java.io.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

import edu.stanford.nlp.classify.*;
import edu.stanford.nlp.ie.machinereading.structure.*;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ArgumentParser.Option;

public class BasicRelationExtractor implements Extractor {

  private static final long serialVersionUID = 2606577772115897869L;

  private static final Logger logger = Logger.getLogger(BasicRelationExtractor.class.getName());

  protected LinearClassifier classifier;

  @Option(name="featureCountThreshold", gloss="feature count threshold to apply to dataset")
  public int featureCountThreshold = 2;

  @Option(name="featureFactory", gloss="Feature factory for the relation extractor")
  public RelationFeatureFactory featureFactory;
  /**
   * strength of the prior on the linear classifier (passed to LinearClassifierFactory) or the C constant if relationExtractorClassifierType=svm
   */
  @Option(name="sigma", gloss="strength of the prior on the linear classifier (passed to LinearClassifierFactory) or the C constant if relationExtractorClassifierType=svm")
  public double sigma = 1.0;

  /**
   * which classifier to use (can be 'linear' or 'svm')
   */
  public String relationExtractorClassifierType = "linear";
  

  /**
   * If true, it creates automatically negative examples by generating all combinations between EntityMentions in a sentence
   * This is the common behavior, but for some domain (i.e., KBP) it must disabled. In these domains, the negative relation examples are created in the reader
   */
  protected boolean createUnrelatedRelations;

  /** Verifies that predicted labels are compatible with the relation arguments */
  private LabelValidator validator;

  protected RelationMentionFactory relationMentionFactory;

  public void setValidator(LabelValidator lv) { validator = lv; }
  public void setRelationExtractorClassifierType(String s) { relationExtractorClassifierType = s; }
  public void setFeatureCountThreshold(int i) {featureCountThreshold = i; }
  public void setSigma(double d) { sigma = d; }

  public BasicRelationExtractor(RelationFeatureFactory featureFac, Boolean createUnrelatedRelations, RelationMentionFactory factory) {
    featureFactory = featureFac;
    this.createUnrelatedRelations = createUnrelatedRelations;
    this.relationMentionFactory = factory;

    logger.setLevel(Level.INFO);
  }

  public void setCreateUnrelatedRelations(boolean b) {
    createUnrelatedRelations = b;
  }

  public static BasicRelationExtractor load(String modelPath) throws IOException, ClassNotFoundException {
    return IOUtils.readObjectFromURLOrClasspathOrFileSystem(modelPath);
  }

  @Override
  public void save(String modelpath) throws IOException {
    // make sure modelpath directory exists
    int lastSlash = modelpath.lastIndexOf(File.separator);
    if(lastSlash > 0){
      String path = modelpath.substring(0, lastSlash);
      File f = new File(path);
      if (! f.exists()) {
        f.mkdirs();
      }
    }

    FileOutputStream fos = new FileOutputStream(modelpath);
    ObjectOutputStream out = new ObjectOutputStream(fos);
    out.writeObject(this);
    out.close();
  }

  /**
   * Train on a list of ExtractionSentence containing labeled RelationMention objects
   */
  @Override
  public void train(Annotation sentences) {
    // Train a single multi-class classifier
    GeneralDataset trainSet = createDataset(sentences);
    trainMulticlass(trainSet);
  }

  public void trainMulticlass(GeneralDataset trainSet) {
    if (relationExtractorClassifierType.equalsIgnoreCase("linear")) {
      LinearClassifierFactory lcFactory = new LinearClassifierFactory<>(1e-4, false, sigma);
      lcFactory.setVerbose(false);
      // use in-place SGD instead of QN. this is faster but much worse!
      // lcFactory.useInPlaceStochasticGradientDescent(-1, -1, 1.0);
      // use a hybrid minimizer: start with in-place SGD, continue with QN
      // lcFactory.useHybridMinimizerWithInPlaceSGD(50, -1, sigma);
      classifier = lcFactory.trainClassifier(trainSet);
    } else if (relationExtractorClassifierType.equalsIgnoreCase("svm")) {
      SVMLightClassifierFactory svmFactory = new SVMLightClassifierFactory<>();
      svmFactory.setC(sigma);
      classifier = svmFactory.trainClassifier(trainSet);
    } else {
      throw new RuntimeException("Invalid classifier type: " + relationExtractorClassifierType);
    }
    if (logger.isLoggable(Level.FINE)) {
      reportWeights(classifier, null);
    }
  }

  protected static void reportWeights(LinearClassifier classifier, String classLabel) {
    if (classLabel != null) logger.fine("CLASSIFIER WEIGHTS FOR LABEL " + classLabel);
    Map> labelsToFeatureWeights = classifier.weightsAsMapOfCounters();
    List labels = new ArrayList<>(labelsToFeatureWeights.keySet());
    Collections.sort(labels);
    for (String label: labels) {
      Counter featWeights = labelsToFeatureWeights.get(label);
      List> sorted = Counters.toSortedListWithCounts(featWeights);
      StringBuilder bos = new StringBuilder();
      bos.append("WEIGHTS FOR LABEL ").append(label).append(':');
      for (Pair feat: sorted) {
        bos.append(' ').append(feat.first()).append(':').append(feat.second()+"\n");
      }
      logger.fine(bos.toString());
    }
  }

  protected String classOf(Datum datum, ExtractionObject rel) {
    Counter probs = classifier.probabilityOf(datum);
    List> sortedProbs = Counters.toDescendingMagnitudeSortedListWithCounts(probs);
    double nrProb = probs.getCount(RelationMention.UNRELATED);
    for(Pair choice: sortedProbs){
      if(choice.first.equals(RelationMention.UNRELATED)) return choice.first;
      if(nrProb >= choice.second) return RelationMention.UNRELATED; // no prediction, all probs have the same value
      if(compatibleLabel(choice.first, rel)) return choice.first;
    }
    return RelationMention.UNRELATED;
  }

  private boolean compatibleLabel(String label, ExtractionObject rel) {
    if(rel == null) return true;
    if(validator != null) return validator.validLabel(label, rel);
    return true;
  }

  protected Counter probabilityOf(Datum testDatum) {
    return classifier.probabilityOf(testDatum);
  }

  protected void justificationOf(Datum testDatum, PrintWriter pw, String label) {
    classifier.justificationOf(testDatum, pw);
  }

  /**
   * Predict a relation for each pair of entities in the sentence; including relations of type unrelated.
   * This creates new RelationMention objects!
   */
  protected List extractAllRelations(CoreMap sentence) {
    List extractions = new ArrayList<>();

    List cands = null;
    if(createUnrelatedRelations){
      // creates all possible relations between all entities in the sentence
      cands = AnnotationUtils.getAllUnrelatedRelations(relationMentionFactory, sentence, false);
    } else {
      // just take the candidates produced by the reader (in KBP)
      cands = sentence.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
      if(cands == null){
        cands = new ArrayList<>();
      }
    }

    // the actual classification takes place here!
    for (RelationMention rel : cands) {
      Datum testDatum = createDatum(rel);
      String label = classOf(testDatum, rel);
      Counter probs = probabilityOf(testDatum);
      double prob = probs.getCount(label);
      StringWriter sw = new StringWriter();
      PrintWriter pw = new PrintWriter(sw);
      if (logger.isLoggable(Level.INFO)) {
        justificationOf(testDatum, pw, label);
      }
      logger.info("Current sentence: " + AnnotationUtils.tokensAndNELabelsToString(rel.getArg(0).getSentence()) + "\n"
              + "Classifying relation: " + rel + "\n"
              + "JUSTIFICATION for label GOLD:" + rel.getType() + " SYS:" + label + " (prob:" + prob + "):\n"
              + sw.toString());
      logger.info("Justification done.");
      RelationMention relation = relationMentionFactory.constructRelationMention(
              rel.getObjectId(),
              sentence,
              rel.getExtent(),
              label,
              null,
              rel.getArgs(),
              probs);
      extractions.add(relation);

      if(! relation.getType().equals(rel.getType())){
        logger.info("Classification: found different type " + relation.getType() + " for relation: " + rel);
        logger.info("The predicted relation is: " + relation);
        logger.info("Current sentence: " + AnnotationUtils.tokensAndNELabelsToString(rel.getArg(0).getSentence()));
      } else{
        logger.info("Classification: found similar type " + relation.getType() + " for relation: " + rel);
        logger.info("The predicted relation is: " + relation);
        logger.info("Current sentence: " + AnnotationUtils.tokensAndNELabelsToString(rel.getArg(0).getSentence()));
      }
    }
    return extractions;
  }

  public List annotateMulticlass(List> testDatums) {
    List predictedLabels = new ArrayList<>();

    for (Datum testDatum: testDatums) {
      String label = classOf(testDatum, null);
      Counter probs = probabilityOf(testDatum);
      double prob = probs.getCount(label);

      StringWriter sw = new StringWriter();
      PrintWriter pw = new PrintWriter(sw);
      if (logger.isLoggable(Level.FINE)) {
        justificationOf(testDatum, pw, label);
      }
      logger.fine("JUSTIFICATION for label GOLD:" + testDatum.label() + " SYS:" + label + " (prob:" + prob + "):\n"
              + sw.toString() + "\nJustification done.");
      predictedLabels.add(label);

      if(! testDatum.label().equals(label)){
        logger.info("Classification: found different type " + label + " for relation: " + testDatum);
      } else{
        logger.info("Classification: found similar type " + label + " for relation: " + testDatum);
      }
    }

    return predictedLabels;
  }

  public void annotateSentence(CoreMap sentence) {
    // this stores all relation mentions generated by this extractor
    List relations = new ArrayList<>();

    // extractAllRelations creates new objects for every predicted relation
    for (RelationMention rel : extractAllRelations(sentence)) {
      // add all relations. potentially useful for a joint model
      // if (! RelationMention.isUnrelatedLabel(rel.getType()))
      relations.add(rel);

    }

    // caution: this removes the old list of relation mentions!
    for (RelationMention r: relations) {
      if (! r.getType().equals(RelationMention.UNRELATED)) {
        logger.fine("Found positive relation in annotateSentence: " + r);
      }
    }
    sentence.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, relations);
  }

  @Override
  public void annotate(Annotation dataset) {
    for (CoreMap sentence : dataset.get(CoreAnnotations.SentencesAnnotation.class)){
      annotateSentence(sentence);
    }
  }

  protected GeneralDataset createDataset(Annotation corpus) {
    GeneralDataset dataset = new RVFDataset<>();

    for (CoreMap sentence : corpus.get(CoreAnnotations.SentencesAnnotation.class)) {
      for (RelationMention rel : AnnotationUtils.getAllRelations(relationMentionFactory, sentence, createUnrelatedRelations)) {
        dataset.add(createDatum(rel));
      }
    }

    dataset.applyFeatureCountThreshold(featureCountThreshold);
    return dataset;
  }

  protected Datum createDatum(RelationMention rel) {
    assert(featureFactory != null);
    return featureFactory.createDatum(rel);
  }

  protected Datum createDatum(RelationMention rel, String label) {
    assert(featureFactory != null);
    Datum datum = featureFactory.createDatum(rel, label);
    return datum;
  }

  @Override
  public void setLoggerLevel(Level level) {
    logger.setLevel(level);
  }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy