edu.stanford.nlp.naturalli.ClauseSplitter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of stanford-corenlp Show documentation
Show all versions of stanford-corenlp Show documentation
Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.
package edu.stanford.nlp.naturalli;
import edu.stanford.nlp.classify.*;
import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.ie.util.RelationTriple;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.*;
import java.io.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.stream.Stream;
import java.util.zip.GZIPOutputStream;
import edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem.*;
import edu.stanford.nlp.util.logging.Redwood;
import static edu.stanford.nlp.util.logging.Redwood.Util.*;
/**
* Just a convenience alias for a clause splitting search problem factory.
* Mostly here to form a nice parallel with {@link edu.stanford.nlp.naturalli.ForwardEntailer}.
*
* @author Gabor Angeli
*/
public interface ClauseSplitter extends BiFunction {
/** A logger for this class */
Redwood.RedwoodChannels log = Redwood.channels(ClauseSplitter.class);
enum ClauseClassifierLabel {
CLAUSE_SPLIT(2),
CLAUSE_INTERM(1),
NOT_A_CLAUSE(0);
public final byte index;
ClauseClassifierLabel(int val) {
this.index = (byte) val;
}
/** Seriously, why would Java not have this by default? */
@Override
public String toString() {
return this.name();
}
@SuppressWarnings("unused")
public static ClauseClassifierLabel fromIndex(int index) {
switch (index) {
case 0:
return NOT_A_CLAUSE;
case 1:
return CLAUSE_INTERM;
case 2:
return CLAUSE_SPLIT;
default:
throw new IllegalArgumentException("Not a valid index: " + index);
}
}
}
/**
* Train a clause searcher factory. That is, train a classifier for which arcs should be
* new clauses.
*
* @param trainingData The training data. This is a stream of triples of:
*
* - The sentence containing a known extraction.
* - The span of the subject in the sentence, as a token span.
* - The span of the object in the sentence, as a token span.
*
* @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
* @param trainingDataDump The path to save the training data, as a set of labeled featurized datums.
* @param featurizer The featurizer to use for this classifier.
*
* @return A factory for creating searchers from a given dependency tree.
*/
static ClauseSplitter train(
Stream>>> trainingData,
Optional modelPath,
Optional trainingDataDump,
Featurizer featurizer) {
// Parse options
LinearClassifierFactory factory = new LinearClassifierFactory<>();
// Generally useful objects
OpenIE openie = new OpenIE(PropertiesUtils.asProperties(
"splitter.nomodel", "true",
"optimizefor", "GENERAL"
));
WeightedDataset dataset = new WeightedDataset<>();
AtomicInteger numExamplesProcessed = new AtomicInteger(0);
final Optional datasetDumpWriter = trainingDataDump.map(file -> {
try {
return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(trainingDataDump.get()))));
} catch (IOException e) {
throw new RuntimeIOException(e);
}
});
// Step 1: Loop over data
forceTrack("Training inference");
trainingData.forEach(rawExample -> {
// Parse training datum
CoreMap sentence = rawExample.first;
Collection> spans = rawExample.second;
List tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
// Create raw clause searcher (no classifier)
ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
// Run search
problem.search(fragmentAndScore -> {
// Parse the search callback
List> features = fragmentAndScore.second;
SentenceFragment fragment = fragmentAndScore.third.get();
// Search for extractions
Set extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
Trilean correct = Trilean.FALSE;
RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
// Clean up the guesses
Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
for (Pair candidateGold : spans) {
Span subjectSpan = candidateGold.first;
Span objectSpan = candidateGold.second;
// Check if it matches
if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) ||
(subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))
) {
correct = Trilean.TRUE;
break RELATION_TRIPLE_LOOP;
} else if (Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess) ||
Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess)) {
if (!correct.isTrue()) {
correct = Trilean.TRUE;
break RELATION_TRIPLE_LOOP;
}
} else {
if (!correct.isTrue()) {
correct = Trilean.UNKNOWN;
break RELATION_TRIPLE_LOOP;
}
}
}
}
// Process the datum
if (!features.isEmpty()) {
// Convert the path to datums
List, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
if (correct.isTrue()) {
// If this is a "true" path, add the k-1 decisions as INTERM and the last decision as a SPLIT
for (int i = 0; i < features.size(); ++i) {
if (i == features.size() - 1) {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
} else {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
}
}
} else if (correct.isFalse()) {
// If this is a "false" path, then we know at least the last decision was bad.
decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
} else if (correct.isUnknown()) {
// If this is an "unknown" path, only add it if it was the result of vanilla splits
// (check if it is a sequence of simple splits)
boolean isSimpleSplit = false;
for (Counter feats : features) {
if (featurizer.isSimpleSplit(feats)) {
isSimpleSplit = true;
break;
}
}
// (if so, add it as if it were a True example)
if (isSimpleSplit) {
for (int i = 0; i < features.size(); ++i) {
if (i == features.size() - 1) {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
} else {
decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
}
}
}
}
// Add the datums
for (Pair, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
// (create datum)
RVFDatum datum = new RVFDatum<>(decision.first);
datum.setLabel(decision.second);
// (dump datum to debug log)
if (datasetDumpWriter.isPresent()) {
datasetDumpWriter.get().println(decision.second + "\t" +
StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
}
// (add datum to dataset)
dataset.add(datum);
}
}
return true;
}, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);
// Debug info
if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
}
});
endTrack("Training inference");
// Close the file
if (datasetDumpWriter.isPresent()) {
datasetDumpWriter.get().close();
}
// Step 2: Train classifier
forceTrack("Training");
Classifier fullClassifier = factory.trainClassifier(dataset);
endTrack("Training");
if (modelPath.isPresent()) {
Pair, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
try {
IOUtils.writeObjectToFile(toSave, modelPath.get());
log("SUCCESS: wrote model to " + modelPath.get().getPath());
} catch (IOException e) {
log("ERROR: failed to save model to path: " + modelPath.get().getPath());
err(e);
}
}
// Step 3: Check accuracy of classifier
forceTrack("Training accuracy");
dataset.randomize(42L);
Util.dumpAccuracy(fullClassifier, dataset);
endTrack("Training accuracy");
int numFolds = 5;
forceTrack(numFolds + " fold cross-validation");
for (int fold = 0; fold < numFolds; ++fold) {
forceTrack("Fold " + (fold + 1));
forceTrack("Training");
Pair, GeneralDataset> foldData = dataset.splitOutFold(fold, numFolds);
Classifier classifier = factory.trainClassifier(foldData.first);
endTrack("Training");
forceTrack("Test");
Util.dumpAccuracy(classifier, foldData.second);
endTrack("Test");
endTrack("Fold " + (fold + 1));
}
endTrack(numFolds + " fold cross-validation");
// Step 5: return factory
return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
}
static ClauseSplitter train(
Stream>>> trainingData,
File modelPath,
File trainingDataDump) {
return train(trainingData, Optional.of(modelPath), Optional.of(trainingDataDump), ClauseSplitterSearchProblem.DEFAULT_FEATURIZER);
}
/**
* Load a factory model from a given path. This can be trained with
* {@link ClauseSplitter#train(Stream, Optional, Optional, Featurizer)}.
*
* @return A function taking a dependency tree, and returning a clause searcher.
*/
static ClauseSplitter load(String serializedModel) throws IOException {
try {
long start = System.currentTimeMillis();
Pair, Featurizer> data = IOUtils.readObjectFromURLOrClasspathOrFileSystem(serializedModel);
ClauseSplitter rtn = (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(data.first), Optional.of(data.second));
log.info("Loading clause splitter from " + serializedModel + " ... done [" +
Redwood.formatTimeDifference(System.currentTimeMillis() - start) + "]");
return rtn;
} catch (ClassNotFoundException e) {
throw new IllegalStateException("Invalid model at path: " + serializedModel, e);
}
}
}