edu.stanford.nlp.ie.EntityCachingAbstractSequencePriorBIO Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of stanford-corenlp Show documentation
Show all versions of stanford-corenlp Show documentation
Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.
package edu.stanford.nlp.ie;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.sequences.ListeningSequenceModel;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.ling.CoreAnnotations;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
/**
* This class keeps track of all labeled entities and updates
* its list whenever the label at a point gets changed. This allows
* you to not have to regenerate the list every time, which can be quite
* inefficient.
*
* @author Mengqiu Wang
**/
public abstract class EntityCachingAbstractSequencePriorBIO implements ListeningSequenceModel {
/** A logger for this class */
private static final Redwood.RedwoodChannels log = Redwood.channels(EntityCachingAbstractSequencePriorBIO.class);
/** The label sequence most recently passed to setInitialSequence or updateSequenceElement. */
protected int[] sequence;
/** Position of the background symbol within {@link #classIndex}. */
protected final int backgroundSymbol;
/** Number of entries in {@link #classIndex} at construction time. */
protected final int numClasses;
/** The array returned by getPossibleValues for every position; has numClasses entries. */
protected final int[] possibleValues;
/** Maps class ids to raw label strings; non-background labels are split on "-" into a B/I prefix and a tag. */
protected final Index<String> classIndex;
/** Maps entity type ids to tag strings (the part of a label after the "-"). */
protected final Index<String> tagIndex;
/** The text of every token in the document, in order. */
private final List<String> wordDoc;
/**
 * Creates the prior for one document and caches the text of each token.
 *
 * @param backgroundSymbol label of the background class; its position in classIndex is stored
 * @param classIndex index of the class labels
 * @param tagIndex index of the entity type tags
 * @param doc the document tokens; the TextAnnotation of each one is copied into the word list
 */
public EntityCachingAbstractSequencePriorBIO(String backgroundSymbol, Index<String> classIndex, Index<String> tagIndex, List<? extends CoreMap> doc) {
  this.classIndex = classIndex;
  this.tagIndex = tagIndex;
  this.backgroundSymbol = classIndex.indexOf(backgroundSymbol);
  this.numClasses = classIndex.size();
  this.possibleValues = new int[numClasses];
  // NOTE(review): this loop and the wordDoc initialisation were swallowed by the
  // text extraction (everything between '<' and '>' was dropped) and are restored
  // here; confirm against the upstream source.
  for (int i = 0; i < numClasses; i++) {
    possibleValues[i] = i;
  }
  this.wordDoc = new ArrayList<>(doc.size());
  // Iterating as CoreMap keeps this constructor compilable whether or not the class
  // declares a token type parameter; presumably upstream uses IN extends CoreMap.
  for (CoreMap w : doc) {
    wordDoc.add(w.get(CoreAnnotations.TextAnnotation.class));
  }
}
// When true, updateSequenceElement logs the label changes it processes.
private boolean VERBOSE = false;
// entities[i] is the cached entity covering position i, or null when none is cached there.
EntityBIO[] entities;
/**
 * @return {@code Integer.MAX_VALUE}; this model is not Markovian
 */
@Override
public int leftWindow() {
  final int unbounded = Integer.MAX_VALUE; // not Markovian!
  return unbounded;
}
/**
 * @return {@code Integer.MAX_VALUE}; this model is not Markovian
 */
@Override
public int rightWindow() {
  final int unbounded = Integer.MAX_VALUE; // not Markovian!
  return unbounded;
}
/**
 * Returns the candidate labels for a position. The same shared array is
 * returned for every position; the argument is not consulted.
 *
 * @param position ignored
 * @return the internal possibleValues array (not a copy)
 */
@Override
public int[] getPossibleValues(int position) {
  return this.possibleValues;
}
/**
 * Scores the label currently assigned at {@code pos}.
 *
 * @param sequence the label sequence
 * @param pos the position to score
 * @return the entry of {@code scoresOf(sequence, pos)} for the label at {@code pos}
 */
@Override
public double scoreOf(int[] sequence, int pos) {
  // Compute all scores first, then pick the one for the label in place.
  final double[] allScores = scoresOf(sequence, pos);
  final int currentLabel = sequence[pos];
  return allScores[currentLabel];
}
/**
 * @return the number of tokens in the cached document, i.e. the sequence length
 */
@Override
public int length() {
  final int tokenCount = wordDoc.size();
  return tokenCount;
}
/**
 * Gets the number of classes in the sequence model.
 *
 * @return the current size of the class index
 */
public int getNumClasses() {
  final int classCount = classIndex.size();
  return classCount;
}
/**
 * Turns the scores at one position into a distribution: the scores are
 * log-normalized in place and then exponentiated.
 *
 * @param sequence the label sequence
 * @param position the position whose label distribution is wanted
 * @return the result of exponentiating the log-normalized scores
 */
public double[] getConditionalDistribution (int[] sequence, int position) {
  final double[] logScores = scoresOf(sequence, position);
  ArrayMath.logNormalize(logScores);
  return ArrayMath.exp(logScores);
}
/**
 * Scores every label at {@code position}. Each alternative label is written
 * into {@code sequence} in turn, the entity cache is updated, and the whole
 * sequence is scored; finally the original label is restored and scored too.
 *
 * @param sequence the label sequence; modified during the call and restored before returning
 * @param position the position whose label is varied
 * @return an array with one score per class
 */
@Override
public double[] scoresOf (int[] sequence, int position) {
  final double[] scores = new double[numClasses];
  final int originalLabel = sequence[position];
  // Label in place before the next write; updateSequenceElement needs it.
  int previousLabel = originalLabel;
  for (int candidate = 0; candidate < numClasses; candidate++) {
    if (candidate == originalLabel) {
      continue; // the original label is scored last, after it is restored
    }
    sequence[position] = candidate;
    updateSequenceElement(sequence, position, previousLabel);
    scores[candidate] = scoreOf(sequence);
    previousLabel = candidate;
  }
  // Put the original label back so the caller's sequence and the cache are unchanged.
  sequence[position] = originalLabel;
  updateSequenceElement(sequence, position, previousLabel);
  scores[originalLabel] = scoreOf(sequence);
  return scores;
}
/**
 * Stores the initial labeling and rebuilds the entity cache from scratch.
 * Only labels whose prefix is "B" open an entity; positions covered by an
 * extracted entity are skipped.
 *
 * @param initialSequence the starting label sequence; kept by reference
 */
@Override
public void setInitialSequence(int[] initialSequence) {
  this.sequence = initialSequence;
  entities = new EntityBIO[initialSequence.length]; // all slots start out null
  int i = 0;
  while (i < initialSequence.length) {
    int advance = 1;
    if (initialSequence[i] != backgroundSymbol) {
      String rawTag = classIndex.get(sequence[i]);
      String[] parts = rawTag.split("-");
      //TODO(mengqiu) this needs to be updated, so that initial can be I as well
      if (parts[0].equals("B")) { // B-
        EntityBIO found = extractEntity(initialSequence, i, parts[1]);
        addEntityToEntitiesArray(found);
        // jump past every word the entity covers
        advance = found.words.size();
      }
    }
    i += advance;
  }
}
/**
 * Points every slot of the entities array covered by {@code entity}
 * (from its start position, one slot per word) at that entity.
 */
private void addEntityToEntitiesArray(EntityBIO entity) {
  int slot = entity.startPosition;
  while (slot < entity.startPosition + entity.words.size()) {
    entities[slot] = entity;
    slot++;
  }
}
/**
 * Builds the entity that starts at the given position. The entity takes the
 * word at {@code position} plus every directly following word whose label is
 * "I-" followed by {@code tag}. The entity is returned, not stored; its other
 * occurrences in the document are computed before returning.
 *
 * @param sequence the label sequence to read
 * @param position index of the entity's first word
 * @param tag the entity type tag (looked up in the tag index)
 * @return the newly built entity
 **/
public EntityBIO extractEntity(int[] sequence, int position, String tag) {
  EntityBIO result = new EntityBIO();
  result.type = tagIndex.indexOf(tag);
  result.startPosition = position;
  result.words = new ArrayList<>();
  result.words.add(wordDoc.get(position));
  int next = position + 1;
  while (next < sequence.length) {
    String rawTag = classIndex.get(sequence[next]);
    String[] parts = rawTag.split("-");
    boolean continuesEntity = parts[0].equals("I") && parts[1].equals(tag);
    if (!continuesEntity) {
      break;
    }
    String word = wordDoc.get(next);
    result.words.add(word);
    next++;
  }
  result.otherOccurrences = otherOccurrences(result);
  return result;
}
/**
 * Finds other locations in the document where the sequence of
 * words in this entity occurs (compared case-insensitively via matches).
 *
 * @param entity the entity whose words are searched for
 * @return the start indices of all matches, in increasing order, excluding
 *         the entity's own start position
 */
public int[] otherOccurrences(EntityBIO entity){
  // Typed list instead of a raw List, so the element type is checked.
  List<Integer> other = new ArrayList<>();
  for (int i = 0; i < wordDoc.size(); i++) {
    if (i == entity.startPosition) { continue; }
    if (matches(entity, i)) {
      other.add(i);
    }
  }
  return toArray(other);
}
/**
 * Copies a list of integers into a primitive array, preserving order.
 *
 * @param list the values to copy; must not contain null elements
 * @return a new array holding the list's values
 */
public static int[] toArray(List<Integer> list) {
  // The parameter is typed so that list.get(i) can be unboxed into the int array;
  // with a raw List the assignment below does not type-check.
  int[] arr = new int[list.size()];
  for (int i = 0; i < arr.length; i++) {
    arr[i] = list.get(i);
  }
  return arr;
}
/**
 * Checks whether the entity's words appear in the document starting at
 * {@code position}, ignoring case.
 *
 * @param entity the entity whose words are compared
 * @param position document index at which the comparison starts
 * @return true when every word of the entity matches; false on the first
 *         mismatch or when the document ends before the entity does
 */
public boolean matches(EntityBIO entity, int position) {
  String word = wordDoc.get(position);
  if (!word.equalsIgnoreCase(entity.words.get(0))) {
    return false;
  }
  for (int j = 1; j < entity.words.size(); j++) {
    int docIndex = position + j;
    if (docIndex >= wordDoc.size()) {
      return false; // ran off the end of the document
    }
    String nextWord = wordDoc.get(docIndex);
    if (!nextWord.equalsIgnoreCase(entity.words.get(j))) {
      return false;
    }
  }
  return true;
}
/**
 * Brings the cached entities array in line with a single label change.
 * The caller has already written the new label into sequence[position];
 * oldVal is the label that was there before. Nothing is done when the
 * label did not actually change.
 *
 * The branches are keyed on the new label (background, "B"-prefixed, or
 * otherwise) and then on the old label's prefix.
 *
 * NOTE(review): four lines below begin with "for (int i=0; i" and then jump
 * to unrelated text. They appear to have been truncated by the text
 * extraction (everything between a '<' and the next '>' was dropped), so the
 * method is not valid Java as shown and parts of its logic are missing.
 * Restore from the upstream source before relying on it.
 *
 * @param sequence the label sequence after the change; kept by reference
 * @param position the index whose label changed
 * @param oldVal the label previously at position
 * @throws RuntimeException if the new label is the background symbol, the old label
 *         started with "B", and no entity is cached at position
 */
@Override
public void updateSequenceElement(int[] sequence, int position, int oldVal) {
this.sequence = sequence;
// unchanged label: the cache is already correct
if (sequence[position] == oldVal)
return;
if (VERBOSE) log.info("changing position "+position+" from " +classIndex.get(oldVal)+" to "+classIndex.get(sequence[position]));
if (sequence[position] == backgroundSymbol) { // new tag is O
String oldRawTag = classIndex.get(oldVal);
String[] oldParts = oldRawTag.split("-");
if (oldParts[0].equals("B")) { // old tag was a B, current entity definitely affected, also check next one
EntityBIO entity = entities[position];
if (entity == null)
throw new RuntimeException("oldTag starts with B, entity at position should not be null");
// remove entities for all words affected by this entity
for (int i=0; i < entity.words.size(); i++) {
entities[position+i] = null;
}
} else { // old tag was a I, check previous one
if (entities[position] != null) { // this was part of an entity, shortened
if (VERBOSE) log.info("splitting off prev entity");
EntityBIO oldEntity = entities[position];
int oldLen = oldEntity.words.size();
int offset = position - oldEntity.startPosition;
List newWords = new ArrayList<>();
// NOTE(review): truncated line; the loop header, its body and the start of the
// following condition are missing here.
for (int i=0; i 0)
log.info("position:" + position +", entities[position-1] = " + entities[position-1].toString(tagIndex));
} // otherwise, non-entity part I-xxx -> O, no entity affected
}
} else {
String rawTag = classIndex.get(sequence[position]);
String[] parts = rawTag.split("-");
if (parts[0].equals("B")) { // new tag is B
if (oldVal == backgroundSymbol) { // start a new entity, may merge with the next word
EntityBIO entity = extractEntity(sequence, position, parts[1]);
addEntityToEntitiesArray(entity);
} else {
String oldRawTag = classIndex.get(oldVal);
String[] oldParts = oldRawTag.split("-");
if (oldParts[0].equals("B")) { // was a different B-xxx
EntityBIO oldEntity = entities[position];
if (oldEntity.words.size() > 1) { // remove all old entity, add new singleton
for (int i=0; i< oldEntity.words.size(); i++)
entities[position+i] = null;
EntityBIO entity = extractEntity(sequence, position, parts[1]);
addEntityToEntitiesArray(entity);
} else { // extract entity
EntityBIO entity = extractEntity(sequence, position, parts[1]);
addEntityToEntitiesArray(entity);
}
} else { // was I
EntityBIO oldEntity = entities[position];
if (oldEntity != null) {// break old entity
int oldLen = oldEntity.words.size();
int offset = position - oldEntity.startPosition;
List newWords = new ArrayList<>();
// NOTE(review): truncated line; the code between this loop header and the
// "> 0" test that follows is missing, so the branch structure from here on
// cannot be read reliably.
for (int i=0; i 0) {
if (entities[position-1] != null) {
String oldTag = tagIndex.get(entities[position-1].type);
// re-extract the entity ending at position-1, starting from its first word
EntityBIO entity = extractEntity(sequence, position-1-entities[position-1].words.size()+1, oldTag);
addEntityToEntitiesArray(entity);
}
}
} else {
String oldRawTag = classIndex.get(oldVal);
String[] oldParts = oldRawTag.split("-");
if (oldParts[0].equals("B")) { // was a B, clean the B entity first, then check if previous is an entity
EntityBIO oldEntity = entities[position];
// NOTE(review): truncated line; the clean-up loop and the start of the
// following condition are missing here.
for (int i=0; i 0) {
if (entities[position-1] != null) {
String oldTag = tagIndex.get(entities[position-1].type);
if (VERBOSE)
log.info("position:" + position +", entities[position-1] = " + entities[position-1].toString(tagIndex));
// re-extract the entity ending at position-1, starting from its first word
EntityBIO entity = extractEntity(sequence, position-1-entities[position-1].words.size()+1, oldTag);
addEntityToEntitiesArray(entity);
}
}
} else { // was a different I-xxx,
if (entities[position] != null) { // shorten the previous one, remove any additional parts
EntityBIO oldEntity = entities[position];
int oldLen = oldEntity.words.size();
int offset = position - oldEntity.startPosition;
List newWords = new ArrayList<>();
// NOTE(review): truncated line; the loop header, its body and the start of the
// following condition are missing here.
for (int i=0; i 0) {
if (entities[position-1] != null) {
String oldTag = tagIndex.get(entities[position-1].type);
// re-extract the entity ending at position-1, starting from its first word
EntityBIO entity = extractEntity(sequence, position-1-entities[position-1].words.size()+1, oldTag);
addEntityToEntitiesArray(entity);
}
}
}
}
}
}
}
}
/**
 * Renders the whole document, one line per position: index, word and label
 * separated by tabs, followed by the cached entity when there is one.
 */
@Override
public String toString() {
  StringBuilder out = new StringBuilder();
  int row = 0;
  for (EntityBIO cached : entities) {
    out.append(row)
        .append('\t')
        .append(wordDoc.get(row))
        .append('\t')
        .append(classIndex.get(sequence[row]));
    if (cached != null) {
      out.append('\t').append(cached.toString(tagIndex));
    }
    out.append('\n');
    row++;
  }
  return out.toString();
}
/**
 * Renders the positions around {@code pos} in the same tab-separated layout
 * as {@link #toString()}. The window runs from {@code pos - 3} (inclusive) to
 * {@code pos + 3} (exclusive), clipped to the bounds of the entities array.
 *
 * @param pos the position the window is placed around
 * @return one line per position in the window
 */
public String toString(int pos) {
  final int from = Math.max(0, pos - 3);
  final int to = Math.min(entities.length, pos + 3);
  StringBuilder out = new StringBuilder();
  for (int row = from; row < to; row++) {
    out.append(row)
        .append('\t')
        .append(wordDoc.get(row))
        .append('\t')
        .append(classIndex.get(sequence[row]));
    EntityBIO cached = entities[row];
    if (cached != null) {
      out.append('\t').append(cached.toString(tagIndex));
    }
    out.append('\n');
  }
  return out.toString();
}
}
/**
 * A cached entity: a run of consecutive words that share one entity type,
 * together with the other places in the document where the same words occur.
 */
class EntityBIO {
  /** Index of the entity's first word. */
  public int startPosition;
  /** The entity's words, in order. Typed so callers can use the elements as strings without casts. */
  public List<String> words;
  /** Entity type, as an id into the tag index passed to {@link #toString(Index)}. */
  public int type;
  /**
   * the beginning index of other locations where this sequence of
   * words appears.
   */
  public int[] otherOccurrences;

  /**
   * Renders the entity as its space-joined words in quotes, followed by its
   * start position, its type name and its other occurrences.
   *
   * @param tagIndex index used to turn the numeric type into its tag string
   * @return a one-line description of this entity
   */
  public String toString(Index<String> tagIndex) {
    StringBuilder sb = new StringBuilder();
    sb.append('"').append(StringUtils.join(words, " "));
    sb.append("\" start: ").append(startPosition);
    sb.append(" type: ").append(tagIndex.get(type));
    sb.append(" other_occurrences: ").append(Arrays.toString(otherOccurrences));
    return sb.toString();
  }
}