All downloads are free. The search and download functionalities use the official Maven repository.

edu.stanford.nlp.naturalli.ClauseSplitter Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.naturalli;

import edu.stanford.nlp.classify.*;
import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.ie.util.RelationTriple;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.*;

import java.io.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.stream.Stream;
import java.util.zip.GZIPOutputStream;

import edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem.*;
import edu.stanford.nlp.util.logging.Redwood;

import static edu.stanford.nlp.util.logging.Redwood.Util.*;

/**
 * Just a convenience alias for a clause splitting search problem factory.
 * Mostly here to form a nice parallel with {@link edu.stanford.nlp.naturalli.ForwardEntailer}.
 *
 * @author Gabor Angeli
 */
public interface ClauseSplitter extends BiFunction  {

  /** A logger for this class */
  Redwood.RedwoodChannels log = Redwood.channels(ClauseSplitter.class);

  enum ClauseClassifierLabel {
    CLAUSE_SPLIT(2),
    CLAUSE_INTERM(1),
    NOT_A_CLAUSE(0);
    public final byte index;
    ClauseClassifierLabel(int val) {
      this.index = (byte) val;
    }
    /** Seriously, why would Java not have this by default? */
    @Override
    public String toString() {
      return this.name();
    }
    @SuppressWarnings("unused")
    public static ClauseClassifierLabel fromIndex(int index) {
      switch (index) {
        case 0:
          return NOT_A_CLAUSE;
        case 1:
          return CLAUSE_INTERM;
        case 2:
          return CLAUSE_SPLIT;
        default:
          throw new IllegalArgumentException("Not a valid index: " + index);
      }
    }
  }


  /**
   * Train a clause searcher factory. That is, train a classifier for which arcs should be
   * new clauses.
   *
   * @param trainingData The training data. This is a stream of triples of:
   *                     
    *
  1. The sentence containing a known extraction.
  2. *
  3. The span of the subject in the sentence, as a token span.
  4. *
  5. The span of the object in the sentence, as a token span.
  6. *
* @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}. * @param trainingDataDump The path to save the training data, as a set of labeled featurized datums. * @param featurizer The featurizer to use for this classifier. * * @return A factory for creating searchers from a given dependency tree. */ static ClauseSplitter train( Stream>>> trainingData, Optional modelPath, Optional trainingDataDump, Featurizer featurizer) { // Parse options LinearClassifierFactory factory = new LinearClassifierFactory<>(); // Generally useful objects OpenIE openie = new OpenIE(PropertiesUtils.asProperties( "splitter.nomodel", "true", "optimizefor", "GENERAL" )); WeightedDataset dataset = new WeightedDataset<>(); AtomicInteger numExamplesProcessed = new AtomicInteger(0); final Optional datasetDumpWriter = trainingDataDump.map(file -> { try { return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(trainingDataDump.get())))); } catch (IOException e) { throw new RuntimeIOException(e); } }); // Step 1: Loop over data forceTrack("Training inference"); trainingData.forEach(rawExample -> { // Parse training datum CoreMap sentence = rawExample.first; Collection> spans = rawExample.second; List tokens = sentence.get(CoreAnnotations.TokensAnnotation.class); SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class); // Create raw clause searcher (no classifier) ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true); // Run search problem.search(fragmentAndScore -> { // Parse the search callback List> features = fragmentAndScore.second; SentenceFragment fragment = fragmentAndScore.third.get(); // Search for extractions Set extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment))); Trilean correct = Trilean.FALSE; RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) { // Clean up the guesses Span 
subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index()); Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index()); for (Pair candidateGold : spans) { Span subjectSpan = candidateGold.first; Span objectSpan = candidateGold.second; // Check if it matches if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan)) ) { correct = Trilean.TRUE; break RELATION_TRIPLE_LOOP; } else if (Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess) || Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess)) { if (!correct.isTrue()) { correct = Trilean.TRUE; break RELATION_TRIPLE_LOOP; } } else { if (!correct.isTrue()) { correct = Trilean.UNKNOWN; break RELATION_TRIPLE_LOOP; } } } } // Process the datum if (!features.isEmpty()) { // Convert the path to datums List, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>(); if (correct.isTrue()) { // If this is a "true" path, add the k-1 decisions as INTERM and the last decision as a SPLIT for (int i = 0; i < features.size(); ++i) { if (i == features.size() - 1) { decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT)); } else { decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM)); } } } else if (correct.isFalse()) { // If this is a "false" path, then we know at least the last decision was bad. 
decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE)); } else if (correct.isUnknown()) { // If this is an "unknown" path, only add it if it was the result of vanilla splits // (check if it is a sequence of simple splits) boolean isSimpleSplit = false; for (Counter feats : features) { if (featurizer.isSimpleSplit(feats)) { isSimpleSplit = true; break; } } // (if so, add it as if it were a True example) if (isSimpleSplit) { for (int i = 0; i < features.size(); ++i) { if (i == features.size() - 1) { decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT)); } else { decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM)); } } } } // Add the datums for (Pair, ClauseClassifierLabel> decision : decisionsToAddAsDatums) { // (create datum) RVFDatum datum = new RVFDatum<>(decision.first); datum.setLabel(decision.second); // (dump datum to debug log) if (datasetDumpWriter.isPresent()) { datasetDumpWriter.get().println(decision.second + "\t" + StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";")); } // (add datum to dataset) dataset.add(datum); } } return true; }, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000); // Debug info if (numExamplesProcessed.incrementAndGet() % 100 == 0) { log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums"); } }); endTrack("Training inference"); // Close the file if (datasetDumpWriter.isPresent()) { datasetDumpWriter.get().close(); } // Step 2: Train classifier forceTrack("Training"); Classifier fullClassifier = factory.trainClassifier(dataset); endTrack("Training"); if (modelPath.isPresent()) { Pair, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer); try { IOUtils.writeObjectToFile(toSave, modelPath.get()); log("SUCCESS: wrote model to " + 
modelPath.get().getPath()); } catch (IOException e) { log("ERROR: failed to save model to path: " + modelPath.get().getPath()); err(e); } } // Step 3: Check accuracy of classifier forceTrack("Training accuracy"); dataset.randomize(42L); Util.dumpAccuracy(fullClassifier, dataset); endTrack("Training accuracy"); int numFolds = 5; forceTrack(numFolds + " fold cross-validation"); for (int fold = 0; fold < numFolds; ++fold) { forceTrack("Fold " + (fold + 1)); forceTrack("Training"); Pair, GeneralDataset> foldData = dataset.splitOutFold(fold, numFolds); Classifier classifier = factory.trainClassifier(foldData.first); endTrack("Training"); forceTrack("Test"); Util.dumpAccuracy(classifier, foldData.second); endTrack("Test"); endTrack("Fold " + (fold + 1)); } endTrack(numFolds + " fold cross-validation"); // Step 5: return factory return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer)); } static ClauseSplitter train( Stream>>> trainingData, File modelPath, File trainingDataDump) { return train(trainingData, Optional.of(modelPath), Optional.of(trainingDataDump), ClauseSplitterSearchProblem.DEFAULT_FEATURIZER); } /** * Load a factory model from a given path. This can be trained with * {@link ClauseSplitter#train(Stream, Optional, Optional, Featurizer)}. * * @return A function taking a dependency tree, and returning a clause searcher. */ static ClauseSplitter load(String serializedModel) throws IOException { try { long start = System.currentTimeMillis(); Pair, Featurizer> data = IOUtils.readObjectFromURLOrClasspathOrFileSystem(serializedModel); ClauseSplitter rtn = (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(data.first), Optional.of(data.second)); log.info("Loading clause splitter from " + serializedModel + " ... 
done [" + Redwood.formatTimeDifference(System.currentTimeMillis() - start) + "]"); return rtn; } catch (ClassNotFoundException e) { throw new IllegalStateException("Invalid model at path: " + serializedModel, e); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy