All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.ie.EntityCachingAbstractSequencePriorBIO Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.ie;
import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.sequences.ListeningSequenceModel;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.ling.CoreAnnotations;

import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * This class keeps track of all labeled entities and updates the
 * its list whenever the label at a point gets changed.  This allows
 * you to not have to regenerate the list every time, which can be quite
 * inefficient.
 *
 * @author Mengqiu Wang
 **/
public abstract class EntityCachingAbstractSequencePriorBIO  implements ListeningSequenceModel  {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(EntityCachingAbstractSequencePriorBIO.class);

  protected int[] sequence;
  protected final int backgroundSymbol;
  protected final int numClasses;
  protected final int[] possibleValues;
  protected final Index classIndex;
  protected final Index tagIndex;
  private final List wordDoc;

  public EntityCachingAbstractSequencePriorBIO(String backgroundSymbol, Index classIndex, Index tagIndex, List doc) {
    this.classIndex = classIndex;
    this.tagIndex = tagIndex;
    this.backgroundSymbol = classIndex.indexOf(backgroundSymbol);
    this.numClasses = classIndex.size();
    this.possibleValues = new int[numClasses];
    for (int i=0; i(doc.size());
    for (IN w: doc) {
      wordDoc.add(w.get(CoreAnnotations.TextAnnotation.class));
    }
  }

  private boolean VERBOSE = false;

  EntityBIO[] entities;

  @Override
  public int leftWindow() {
    return Integer.MAX_VALUE; // not Markovian!
  }

  @Override
  public int rightWindow() {
    return Integer.MAX_VALUE; // not Markovian!
  }

  @Override
  public int[] getPossibleValues(int position) {
    return possibleValues;
  }

  @Override
  public double scoreOf(int[] sequence, int pos) {
    return scoresOf(sequence, pos)[sequence[pos]];
  }

  /**
   * @return the length of the sequence
   */
  @Override
  public int length() {
    return wordDoc.size();
  }

  /**
   * get the number of classes in the sequence model.
   */
  public int getNumClasses() {
    return classIndex.size();
  }

  public  double[] getConditionalDistribution (int[] sequence, int position) {
    double[] probs = scoresOf(sequence, position);
    ArrayMath.logNormalize(probs);
    probs = ArrayMath.exp(probs);
    //System.out.println(this);
    return probs;
  }

  @Override
  public  double[] scoresOf (int[] sequence, int position) {
    double[] probs = new double[numClasses];
    int origClass = sequence[position];
    int oldVal = origClass;
    // if (BisequenceEmpiricalNERPrior.debugIndices.indexOf(position) != -1)
    //  EmpiricalNERPriorBIO.DEBUG = true;
    for (int label = 0; label < numClasses; label++) {
      if (label != origClass) {
        sequence[position] = label;
        updateSequenceElement(sequence, position, oldVal);
        probs[label] = scoreOf(sequence);
        oldVal = label;
        // if (BisequenceEmpiricalNERPrior.debugIndices.indexOf(position) != -1)
        //   System.out.println(this);
      }
    }
    sequence[position] = origClass;
    updateSequenceElement(sequence, position, oldVal);
    probs[origClass] = scoreOf(sequence);
    // EmpiricalNERPriorBIO.DEBUG = false;
    return probs;
  }

  @Override
  public void setInitialSequence(int[] initialSequence) {
    this.sequence = initialSequence;
    entities = new EntityBIO[initialSequence.length];
    // Arrays.fill(entities, null);  // not needed; Java arrays zero initialized
    for (int i = 0; i < initialSequence.length; i++) {
      if (initialSequence[i] != backgroundSymbol) {
        String rawTag = classIndex.get(sequence[i]);
        String[] parts = rawTag.split("-");
        //TODO(mengqiu) this needs to be updated, so that initial can be I as well
        if (parts[0].equals("B")) { // B-
          EntityBIO entity = extractEntity(initialSequence, i, parts[1]);
          addEntityToEntitiesArray(entity);
          i += entity.words.size() - 1;
        }
      }
    }
  }

  private void addEntityToEntitiesArray(EntityBIO entity) {
    for (int j = entity.startPosition; j < entity.startPosition + entity.words.size(); j++) {
      entities[j] = entity;
    }
  }

  /**
   * extracts the entity starting at the given position
   * and adds it to the entity list.  returns the index
   * of the last element in the entity (not index+1)
   **/
  public EntityBIO extractEntity(int[] sequence, int position, String tag) {
    EntityBIO entity = new EntityBIO();
    entity.type = tagIndex.indexOf(tag);
    entity.startPosition = position;
    entity.words = new ArrayList<>();
    entity.words.add(wordDoc.get(position));
    int pos = position + 1;
    for ( ; pos < sequence.length; pos++) {
      String rawTag = classIndex.get(sequence[pos]);
      String[] parts = rawTag.split("-");
      if (parts[0].equals("I") && parts[1].equals(tag)) {
      	String word = wordDoc.get(pos);
        entity.words.add(word);
      } else {
        break;
      }
    }
    entity.otherOccurrences = otherOccurrences(entity);
    return entity;
  }

  /**
   * finds other locations in the sequence where the sequence of
   * words in this entity occurs.
   */
  public int[] otherOccurrences(EntityBIO entity){
    List other = new ArrayList<>();
    for (int i = 0; i < wordDoc.size(); i++) {
      if (i == entity.startPosition) { continue; }
      if (matches(entity, i)) {
        other.add(Integer.valueOf(i));
      }
    }
    return toArray(other);
  }

  public static int[] toArray(List list) {
    int[] arr = new int[list.size()];
    for (int i = 0; i < arr.length; i++) {
      arr[i] = list.get(i);
    }
    return arr;
  }

  public boolean matches(EntityBIO entity, int position) {
  	String word = wordDoc.get(position);
    if (word.equalsIgnoreCase(entity.words.get(0))) {
      for (int j = 1; j < entity.words.size(); j++) {
        if (position + j >= wordDoc.size()) {
          return false;
        }
        String nextWord = wordDoc.get(position+j);
        if (!nextWord.equalsIgnoreCase(entity.words.get(j))) {
          return false;
        }
      }
      return true;
    }
    return false;
  }

  @Override
  public void updateSequenceElement(int[] sequence, int position, int oldVal) {
    this.sequence = sequence;

    if (sequence[position] == oldVal)
      return;

    if (VERBOSE) log.info("changing position "+position+" from " +classIndex.get(oldVal)+" to "+classIndex.get(sequence[position]));

    if (sequence[position] == backgroundSymbol) { // new tag is O
      String oldRawTag = classIndex.get(oldVal);
      String[] oldParts = oldRawTag.split("-");
      if (oldParts[0].equals("B")) { // old tag was a B, current entity definitely affected, also check next one
        EntityBIO entity = entities[position];
        if (entity == null)
          throw new RuntimeException("oldTag starts with B, entity at position should not be null");
        // remove entities for all words affected by this entity
        for (int i=0; i < entity.words.size(); i++) {
          entities[position+i] = null;
        }
      } else { // old tag was a I, check previous one
        if (entities[position] != null) { // this was part of an entity, shortened
          if (VERBOSE) log.info("splitting off prev entity");
          EntityBIO oldEntity = entities[position];
          int oldLen = oldEntity.words.size();
          int offset = position - oldEntity.startPosition;
          List newWords = new ArrayList<>();
          for (int i=0; i 0)
            log.info("position:" + position +", entities[position-1] = " + entities[position-1].toString(tagIndex));
        } // otherwise, non-entity part I-xxx -> O, no enitty affected
      }
    } else {
      String rawTag = classIndex.get(sequence[position]);
      String[] parts = rawTag.split("-");
      if (parts[0].equals("B")) { // new tag is B
        if (oldVal == backgroundSymbol) { // start a new entity, may merge with the next word
          EntityBIO entity = extractEntity(sequence, position, parts[1]);
          addEntityToEntitiesArray(entity);
        } else {
          String oldRawTag = classIndex.get(oldVal);
          String[] oldParts = oldRawTag.split("-");
          if (oldParts[0].equals("B")) { // was a different B-xxx
            EntityBIO oldEntity = entities[position];
            if (oldEntity.words.size() > 1) { // remove all old entity, add new singleton
              for (int i=0; i< oldEntity.words.size(); i++)
                entities[position+i] = null;
              EntityBIO entity = extractEntity(sequence, position, parts[1]);
              addEntityToEntitiesArray(entity);
            } else { // extract entity
              EntityBIO entity = extractEntity(sequence, position, parts[1]);
              addEntityToEntitiesArray(entity);
            }
          } else { // was I
            EntityBIO oldEntity = entities[position];
            if (oldEntity != null) {// break old entity
              int oldLen = oldEntity.words.size();
              int offset = position - oldEntity.startPosition;
              List newWords = new ArrayList<>();
              for (int i=0; i 0) {
            if (entities[position-1] != null) {
              String oldTag = tagIndex.get(entities[position-1].type);
              EntityBIO entity = extractEntity(sequence, position-1-entities[position-1].words.size()+1, oldTag);
              addEntityToEntitiesArray(entity);
            }
          }
        } else {
          String oldRawTag = classIndex.get(oldVal);
          String[] oldParts = oldRawTag.split("-");
          if (oldParts[0].equals("B")) { // was a B, clean the B entity first, then check if previous is an entity
            EntityBIO oldEntity = entities[position];
            for (int i=0; i 0) {
              if (entities[position-1] != null) {
                String oldTag = tagIndex.get(entities[position-1].type);
                if (VERBOSE)
                  log.info("position:" + position +", entities[position-1] = " + entities[position-1].toString(tagIndex));
                EntityBIO entity = extractEntity(sequence, position-1-entities[position-1].words.size()+1, oldTag);
                addEntityToEntitiesArray(entity);
              }
            }
          } else { // was a differnt I-xxx,
            if (entities[position] != null) { // shorten the previous one, remove any additional parts
              EntityBIO oldEntity = entities[position];
              int oldLen = oldEntity.words.size();
              int offset = position - oldEntity.startPosition;
              List newWords = new ArrayList<>();
              for (int i=0; i 0) {
                if (entities[position-1] != null) {
                  String oldTag = tagIndex.get(entities[position-1].type);
                  EntityBIO entity = extractEntity(sequence, position-1-entities[position-1].words.size()+1, oldTag);
                  addEntityToEntitiesArray(entity);
                }
              }
            }
          }
        }
      }
    }
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < entities.length; i++) {
      sb.append(i);
      sb.append('\t');
      String word = wordDoc.get(i);
      sb.append(word);
      sb.append('\t');
      sb.append(classIndex.get(sequence[i]));
      if (entities[i] != null) {
        sb.append('\t');
        sb.append(entities[i].toString(tagIndex));
      }
      sb.append('\n');
    }
    return sb.toString();
  }

  public String toString(int pos) {
    StringBuilder sb = new StringBuilder();
    for (int i = Math.max(0, pos - 3); i < Math.min(entities.length, pos + 3); i++) {
      sb.append(i);
      sb.append('\t');
      String word = wordDoc.get(i);
      sb.append(word);
      sb.append('\t');
      sb.append(classIndex.get(sequence[i]));
      if (entities[i] != null) {
        sb.append('\t');
        sb.append(entities[i].toString(tagIndex));
      }
      sb.append('\n');
    }
    return sb.toString();
  }
}

/**
 * A mutable record of one entity in a labeled sequence: where it starts,
 * which words it spans, its type, and where else the same words occur.
 */
class EntityBIO {
  /** Index of the entity's first word in the sequence. */
  public int startPosition;
  /** The words of the entity, in order. */
  public List<String> words;
  /** Index of the entity's tag in the tag index. */
  public int type;

  /**
   * the beginning index of other locations where this sequence of
   * words appears.
   */
  public int[] otherOccurrences;

  /**
   * Describes this entity as its quoted, space-joined words followed by its
   * start position, its tag and its other occurrences.
   *
   * @param tagIndex index used to turn {@link #type} into a tag
   * @return the description
   */
  public String toString(Index<String> tagIndex) {
    StringBuilder sb = new StringBuilder();
    sb.append('"');
    sb.append(StringUtils.join(words, " "));
    sb.append("\" start: ");
    sb.append(startPosition);
    sb.append(" type: ");
    sb.append(tagIndex.get(type));
    sb.append(" other_occurrences: ");
    sb.append(Arrays.toString(otherOccurrences));
    return sb.toString();
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy