All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.sequences.ExactBestSequenceFinder Maven / Gradle / Ivy

package edu.stanford.nlp.sequences;

import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.RuntimeInterruptedException;

import java.util.Arrays;


/**
 * A class capable of computing the best sequence given a SequenceModel.
 * Uses the Viterbi algorithm.
 *
 * @author Dan Klein
 * @author Teg Grenager ([email protected])
 */
public class ExactBestSequenceFinder implements BestSequenceFinder {

  private static final boolean DEBUG = false;

  public static Pair bestSequenceWithLinearConstraints(SequenceModel ts, double[][] linearConstraints) {
    return bestSequence(ts, linearConstraints);
  }

  /**
   * Runs the Viterbi algorithm on the sequence model given by the TagScorer
   * in order to find the best sequence.
   *
   * @param ts The SequenceModel to be used for scoring
   * @return An array containing the int tags of the best sequence
   */
  @Override
  public int[] bestSequence(SequenceModel ts) {
    return bestSequence(ts, null).first();
  }

  private static Pair bestSequence(SequenceModel ts, double[][] linearConstraints) {
    // Set up tag options
    int length = ts.length();
    int leftWindow = ts.leftWindow();
    int rightWindow = ts.rightWindow();
    int padLength = length + leftWindow + rightWindow;
    if (linearConstraints != null && linearConstraints.length != padLength)
      throw new RuntimeException("linearConstraints.length (" +  linearConstraints.length + ") does not match padLength (" + padLength + ") of SequenceModel" + ", length=="+length+", leftW="+leftWindow+", rightW="+rightWindow);
    int[][] tags = new int[padLength][];
    int[] tagNum = new int[padLength];
    if (DEBUG) { System.err.println("Doing bestSequence length " + length + "; leftWin " + leftWindow + "; rightWin " + rightWindow + "; padLength " + padLength); }
    for (int pos = 0; pos < padLength; pos++) {
      if (Thread.interrupted()) {  // Allow interrupting
        throw new RuntimeInterruptedException();
      }
      tags[pos] = ts.getPossibleValues(pos);
      tagNum[pos] = tags[pos].length;
      if (DEBUG) { System.err.println("There are " + tagNum[pos] + " values at position " + pos + ": " + Arrays.toString(tags[pos])); }
    }

    int[] tempTags = new int[padLength];

    // Set up product space sizes
    int[] productSizes = new int[padLength];

    int curProduct = 1;
    for (int i = 0; i < leftWindow + rightWindow; i++) {
      curProduct *= tagNum[i];
    }
    for (int pos = leftWindow + rightWindow; pos < padLength; pos++) {
      if (Thread.interrupted()) {  // Allow interrupting
        throw new RuntimeInterruptedException();
      }
      if (pos > leftWindow + rightWindow) {
        curProduct /= tagNum[pos - leftWindow - rightWindow - 1]; // shift off
      }
      curProduct *= tagNum[pos]; // shift on
      productSizes[pos - rightWindow] = curProduct;
    }

    // Score all of each window's options
    double[][] windowScore = new double[padLength][];
    for (int pos = leftWindow; pos < leftWindow + length; pos++) {
      if (Thread.interrupted()) {  // Allow interrupting
        throw new RuntimeInterruptedException();
      }
      if (DEBUG) { System.err.println("scoring word " + pos + " / " + (leftWindow + length) + ", productSizes =  " + productSizes[pos] + ", tagNum = " + tagNum[pos] + "..."); }
      windowScore[pos] = new double[productSizes[pos]];
      Arrays.fill(tempTags, tags[0][0]);
      if (DEBUG) { System.err.println("windowScore[" + pos + "] has size (productSizes[pos]) " + windowScore[pos].length); }

      for (int product = 0; product < productSizes[pos]; product++) {
        if (Thread.interrupted()) {  // Allow interrupting
          throw new RuntimeInterruptedException();
        }
        int p = product;
        int shift = 1;
        for (int curPos = pos + rightWindow; curPos >= pos - leftWindow; curPos--) {
          tempTags[curPos] = tags[curPos][p % tagNum[curPos]];
          p /= tagNum[curPos];
          if (curPos > pos) {
            shift *= tagNum[curPos];
          }
        }

        // Here now you get ts.scoresOf() for all classifications at a position at once, whereas the old code called ts.scoreOf() on each item.
        // CDM May 2007: The way this is done gives incorrect results if there are repeated values in the values of ts.getPossibleValues(pos) -- in particular if the first value of the array is repeated later.  I tried replacing it with the modulo version, but that only worked for left-to-right, not bidirectional inference, but I still think that if you sorted things out, you should be able to do it with modulos and the result would be conceptually simpler and robust to repeated values.  But in the meantime, I fixed the POS tagger to not give repeated values (which was a bug in the tagger).
        if (tempTags[pos] == tags[pos][0]) {
          // get all tags at once
          double[] scores = ts.scoresOf(tempTags, pos);
          if (DEBUG) { System.err.println("Matched at array index [product] " + product + "; tempTags[pos] == tags[pos][0] == " + tempTags[pos]); }
          if (DEBUG) { System.err.println("For pos " + pos + " scores.length is " + scores.length + "; tagNum[pos] = " + tagNum[pos] + "; windowScore[pos].length = " + windowScore[pos].length); }
          if (DEBUG) { System.err.println("scores: " + Arrays.toString(scores)); }
          // fill in the relevant windowScores
          for (int t = 0; t < tagNum[pos]; t++) {
            if (DEBUG) { System.err.println("Setting value of windowScore[" + pos + "][" + product + "+" + t + "*" + shift + "] = " + scores[t]); }
            windowScore[pos][product + t * shift] = scores[t];
          }
        }
      }
    }

    // Set up score and backtrace arrays
    double[][] score = new double[padLength][];
    int[][] trace = new int[padLength][];
    for (int pos = 0; pos < padLength; pos++) {
      if (Thread.interrupted()) {  // Allow interrupting
        throw new RuntimeInterruptedException();
      }
      score[pos] = new double[productSizes[pos]];
      trace[pos] = new int[productSizes[pos]];
    }

    // Do forward Viterbi algorithm

    // loop over the classification spot
    //System.err.println();
    for (int pos = leftWindow; pos < length + leftWindow; pos++) {
      //System.err.print(".");
      // loop over window product types
      for (int product = 0; product < productSizes[pos]; product++) {
        if (Thread.interrupted()) {  // Allow interrupting
          throw new RuntimeInterruptedException();
        }
        // check for initial spot
        if (pos == leftWindow) {
          // no predecessor type
          score[pos][product] = windowScore[pos][product];
          if (linearConstraints != null) {
            if (DEBUG) {
              if (linearConstraints[pos][product % tagNum[pos]] != 0) {
                System.err.println("Applying linear constraints=" + linearConstraints[pos][product % tagNum[pos]] + " to preScore="+ windowScore[pos][product] + " at pos="+pos+" for tag="+(product % tagNum[pos]));
              }
            }
            score[pos][product] += linearConstraints[pos][product % tagNum[pos]];
          }
          trace[pos][product] = -1;
        } else {
          // loop over possible predecessor types
          score[pos][product] = Double.NEGATIVE_INFINITY;
          trace[pos][product] = -1;
          int sharedProduct = product / tagNum[pos + rightWindow];
          int factor = productSizes[pos] / tagNum[pos + rightWindow];
          for (int newTagNum = 0; newTagNum < tagNum[pos - leftWindow - 1]; newTagNum++) {
            int predProduct = newTagNum * factor + sharedProduct;
            double predScore = score[pos - 1][predProduct] + windowScore[pos][product];

            if (linearConstraints != null) {
              if (DEBUG) {
                if (pos == 2 && linearConstraints[pos][product % tagNum[pos]] != 0) {
                  System.err.println("Applying linear constraints=" + linearConstraints[pos][product % tagNum[pos]] + " to preScore="+ predScore + " at pos="+pos+" for tag="+(product % tagNum[pos]));
                  System.err.println("predScore:" + predScore + " = score["+(pos - 1)+"]["+predProduct+"]:" + score[pos - 1][predProduct] + " + windowScore["+pos+"]["+product+"]:" + windowScore[pos][product]);
                }
              }
              predScore += linearConstraints[pos][product % tagNum[pos]];
            }

            if (predScore > score[pos][product]) {
              score[pos][product] = predScore;
              trace[pos][product] = predProduct;
            }
          }
        }
      }
    }

    // Project the actual tag sequence
    double bestFinalScore = Double.NEGATIVE_INFINITY;
    int bestCurrentProduct = -1;
    for (int product = 0; product < productSizes[leftWindow + length - 1]; product++) {
      if (score[leftWindow + length - 1][product] > bestFinalScore) {
        bestCurrentProduct = product;
        bestFinalScore = score[leftWindow + length - 1][product];
      }
    }
    int lastProduct = bestCurrentProduct;
    for (int last = padLength - 1; last >= length - 1 && last >= 0; last--) {
      tempTags[last] = tags[last][lastProduct % tagNum[last]];
      lastProduct /= tagNum[last];
    }
    for (int pos = leftWindow + length - 2; pos >= leftWindow; pos--) {
      int bestNextProduct = bestCurrentProduct;
      bestCurrentProduct = trace[pos + 1][bestNextProduct];
      tempTags[pos - leftWindow] = tags[pos - leftWindow][bestCurrentProduct / (productSizes[pos] / tagNum[pos - leftWindow])];
    }
    return new Pair<>(tempTags, bestFinalScore);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy