All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.stanford.nlp.ie.crf.CRFCliqueTree Maven / Gradle / Ivy

package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.SequenceListener;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.GeneralizedCounter;
import edu.stanford.nlp.util.Index;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Builds a CliqueTree (an array of FactorTable) and does message passing
 * inference along it.
 *
 * @param  The type of the label (usually String in our uses)
 * @author Jenny Finkel
 */
public class CRFCliqueTree implements SequenceModel, SequenceListener {

  protected final FactorTable[] factorTables;
  protected final double z; // norm constant
  protected final Index classIndex;
  private final E backgroundSymbol;
  private final int backgroundIndex;
  // the window size, which is also the clique size
  protected final int windowSize;
  // the number of possible classes for each label
  private final int numClasses;
  private final int[] possibleValues;

  /** Initialize a clique tree */
  public CRFCliqueTree(FactorTable[] factorTables, Index classIndex, E backgroundSymbol) {
    this(factorTables, classIndex, backgroundSymbol, factorTables[0].totalMass());
  }

  /** This extra constructor was added to support the CRFCliqueTreeForPartialLabels */
  CRFCliqueTree(FactorTable[] factorTables, Index classIndex, E backgroundSymbol, double z) {
    this.factorTables = factorTables;
    this.z = z;
    this.classIndex = classIndex;
    this.backgroundSymbol = backgroundSymbol;
    backgroundIndex = classIndex.indexOf(backgroundSymbol);
    windowSize = factorTables[0].windowSize();
    numClasses = classIndex.size();
    possibleValues = new int[numClasses];
    for (int i = 0; i < numClasses; i++) {
      possibleValues[i] = i;
    }

    // Debug only
    // System.out.println("CRFCliqueTree constructed::numClasses: " +
    // numClasses);
  }

  public FactorTable[] getFactorTables() {
    return this.factorTables;
  }

  public Index classIndex() {
    return classIndex;
  }

  // SEQUENCE MODEL METHODS

  public int length() {
    return factorTables.length;
  }

  public int leftWindow() {
    return windowSize;
  }

  public int rightWindow() {
    return 0;
  }

  public int[] getPossibleValues(int position) {
    return possibleValues;
  }

  public double scoreOf(int[] sequence, int pos) {
    return scoresOf(sequence, pos)[sequence[pos]];
  }

  /**
   * Computes the unnormalized log conditional distribution over values of the
   * element at position pos in the sequence, conditioned on the values of the
   * elements in all other positions of the provided sequence.
   *
   * @param sequence
   *          the sequence containing the rest of the values to condition on
   * @param position
   *          the position of the element to give a distribution for
   * @return an array of type double, representing a probability distribution;
   *         sums to 1.0
   */
  public double[] scoresOf(int[] sequence, int position) {
    if (position >= factorTables.length) throw new RuntimeException("Index out of bounds: " + position);
    // DecimalFormat nf = new DecimalFormat("#0.000");
    // if (position>0 && position= length()) {
      nextLength = length() - position - 1;
    }
    FactorTable nextFactorTable = factorTables[position + nextLength];
    if (nextLength != windowSize - 1) {
      for (int j = 0; j < windowSize - 1 - nextLength; j++) {
        nextFactorTable = nextFactorTable.sumOutFront();
      }
    }
    if (nextLength == 0) { // we are asking about the prob of no sequence
      Arrays.fill(probNextGivenThis, 1.0);
    } else {
      int[] next = new int[nextLength];
      System.arraycopy(sequence, position + 1, next, 0, nextLength);
      for (int label = 0; label < numClasses; label++) {
        // ask the factor table such that pos is the first position in the
        // window
        // probNextGivenThis[label] =
        // factorTables[position+nextLength].conditionalLogProbGivenFirst(label,
        // next);
        // probNextGivenThis[label] =
        // nextFactorTable.conditionalLogProbGivenFirst(label, next);
        probNextGivenThis[label] = nextFactorTable.unnormalizedConditionalLogProbGivenFirst(label, next);
      }
    }

    // pointwise multiply
    return ArrayMath.pairwiseAdd(probThisGivenPrev, probNextGivenThis);
  }

  /**
   * Returns the log probability of this sequence given the CRF. Does so by
   * computing the marginal of the first windowSize tags, and then computing the
   * conditional probability for the rest of them, conditioned on the previous
   * tags.
   *
   * @param sequence
   *          the sequence to compute a score for
   * @return the score for the sequence
   */
  public double scoreOf(int[] sequence) {

    int[] given = new int[window() - 1];
    Arrays.fill(given, classIndex.indexOf(backgroundSymbol));
    double logProb = 0;
    for (int i = 0; i < length(); i++) {
      int label = sequence[i];
      logProb += condLogProbGivenPrevious(i, label, given);
      System.arraycopy(given, 1, given, 0, given.length - 1);
      given[given.length - 1] = label;
    }
    return logProb;
  }

  // OTHER

  public int window() {
    return windowSize;
  }

  public int getNumClasses() {
    return numClasses;
  }

  public double totalMass() {
    return z;
  }

  public int backgroundIndex() {
    return backgroundIndex;
  }

  public E backgroundSymbol() {
    return backgroundSymbol;
  }

  //
  // MARGINAL PROB OF TAG AT SINGLE POSITION
  //

  public double logProb(int position, int label) {
    double u = factorTables[position].unnormalizedLogProbEnd(label);
    return u - z;
  }

  public double prob(int position, int label) {
    return Math.exp(logProb(position, label));
  }

  public double logProb(int position, E label) {
    return logProb(position, classIndex.indexOf(label));
  }

  public double prob(int position, E label) {
    return Math.exp(logProb(position, label));
  }

  public Counter probs(int position) {
    Counter c = new ClassicCounter();
    for (int i = 0; i < classIndex.size(); i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, prob(position, i));
    }
    return c;
  }

  public Counter logProbs(int position) {
    Counter c = new ClassicCounter();
    for (int i = 0; i < classIndex.size(); i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, logProb(position, i));
    }
    return c;
  }

  //
  // MARGINAL PROBS OF TAGS AT MULTIPLE POSITIONS
  //

  /**
   * returns the log probability for the given labels (indexed using
   * classIndex), where the last label corresponds to the label at the specified
   * position. For instance if you called logProb(5, {1,2,3}) it will return the
   * marginal log prob that the label at position 3 is 1, the label at position
   * 4 is 2 and the label at position 5 is 3.
   */
  public double logProb(int position, int[] labels) {
    if (labels.length < windowSize) {
      return factorTables[position].unnormalizedLogProbEnd(labels) - z;
    } else if (labels.length == windowSize) {
      return factorTables[position].unnormalizedLogProb(labels) - z;
    } else {
      int[] l = new int[windowSize];
      System.arraycopy(labels, 0, l, 0, l.length);
      int position1 = position - labels.length + windowSize;
      double p = factorTables[position1].unnormalizedLogProb(l) - z;
      l = new int[windowSize - 1];
      System.arraycopy(labels, 1, l, 0, l.length);
      position1++;
      for (int i = windowSize; i < labels.length; i++) {
        p += condLogProbGivenPrevious(position1++, labels[i], l);
        System.arraycopy(l, 1, l, 0, l.length - 1);
        l[windowSize - 2] = labels[i];
      }
      return p;
    }
  }

  /**
   * Returns the probability for the given labels (indexed using classIndex),
   * where the last label corresponds to the label at the specified position.
   * For instance if you called prob(5, {1,2,3}) it will return the marginal
   * prob that the label at position 3 is 1, the label at position 4 is 2 and
   * the label at position 5 is 3.
   */
  public double prob(int position, int[] labels) {
    return Math.exp(logProb(position, labels));
  }

  /**
   * returns the log probability for the given labels, where the last label
   * corresponds to the label at the specified position. For instance if you
   * called logProb(5, {"O", "PER", "ORG"}) it will return the marginal log prob
   * that the label at position 3 is "O", the label at position 4 is "PER" and
   * the label at position 5 is "ORG".
   */
  public double logProb(int position, E[] labels) {
    return logProb(position, objectArrayToIntArray(labels));
  }

  /**
   * returns the probability for the given labels, where the last label
   * corresponds to the label at the specified position. For instance if you
   * called logProb(5, {"O", "PER", "ORG"}) it will return the marginal prob
   * that the label at position 3 is "O", the label at position 4 is "PER" and
   * the label at position 5 is "ORG".
   */
  public double prob(int position, E[] labels) {
    return Math.exp(logProb(position, labels));
  }

  public GeneralizedCounter logProbs(int position, int window) {
    GeneralizedCounter gc = new GeneralizedCounter(window);
    int[] labels = new int[window];
    // cdm july 2005: below array initialization isn't necessary: JLS (3rd ed.)
    // 4.12.5
    // Arrays.fill(labels, 0);

    OUTER: while (true) {
      List labelsList = intArrayToListE(labels);
      gc.incrementCount(labelsList, logProb(position, labels));
      for (int i = 0; i < labels.length; i++) {
        labels[i]++;
        if (labels[i] < numClasses) {
          break;
        }
        if (i == labels.length - 1) {
          break OUTER;
        }
        labels[i] = 0;
      }
    }
    return gc;
  }

  public GeneralizedCounter probs(int position, int window) {
    GeneralizedCounter gc = new GeneralizedCounter(window);
    int[] labels = new int[window];
    // cdm july 2005: below array initialization isn't necessary: JLS (3rd ed.)
    // 4.12.5
    // Arrays.fill(labels, 0);

    OUTER: while (true) {
      List labelsList = intArrayToListE(labels);
      gc.incrementCount(labelsList, prob(position, labels));
      for (int i = 0; i < labels.length; i++) {
        labels[i]++;
        if (labels[i] < numClasses) {
          break;
        }
        if (i == labels.length - 1) {
          break OUTER;
        }
        labels[i] = 0;
      }
    }
    return gc;
  }

  //
  // HELPER METHODS
  //

  private int[] objectArrayToIntArray(E[] os) {
    int[] is = new int[os.length];
    for (int i = 0; i < os.length; i++) {
      is[i] = classIndex.indexOf(os[i]);
    }
    return is;
  }

  private List intArrayToListE(int[] is) {
    List os = new ArrayList(is.length);
    for (int i = 0; i < is.length; i++) {
      os.add(classIndex.get(is[i]));
    }
    return os;
  }

  /**
   * Gives the probability of a tag at a single position conditioned on a
   * sequence of previous labels.
   *
   * @param position
   *          Index in sequence
   * @param label
   *          Label of item at index
   * @param prevLabels
   * @return conditional log probability
   */
  public double condLogProbGivenPrevious(int position, int label, int[] prevLabels) {
    if (prevLabels.length + 1 == windowSize) {
      return factorTables[position].conditionalLogProbGivenPrevious(prevLabels, label);
    } else if (prevLabels.length + 1 < windowSize) {
      FactorTable ft = factorTables[position].sumOutFront();
      while (ft.windowSize() > prevLabels.length + 1) {
        ft = ft.sumOutFront();
      }
      return ft.conditionalLogProbGivenPrevious(prevLabels, label);
    } else {
      int[] p = new int[windowSize - 1];
      System.arraycopy(prevLabels, prevLabels.length - p.length, p, 0, p.length);
      return factorTables[position].conditionalLogProbGivenPrevious(p, label);
    }
  }

  public double condLogProbGivenPrevious(int position, E label, E[] prevLabels) {
    return condLogProbGivenPrevious(position, classIndex.indexOf(label), objectArrayToIntArray(prevLabels));
  }

  public double condProbGivenPrevious(int position, int label, int[] prevLabels) {
    return Math.exp(condLogProbGivenPrevious(position, label, prevLabels));
  }

  public double condProbGivenPrevious(int position, E label, E[] prevLabels) {
    return Math.exp(condLogProbGivenPrevious(position, label, prevLabels));
  }

  public Counter condLogProbsGivenPrevious(int position, int[] prevlabels) {
    Counter c = new ClassicCounter();
    for (int i = 0; i < classIndex.size(); i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, condLogProbGivenPrevious(position, i, prevlabels));
    }
    return c;
  }

  public Counter condLogProbsGivenPrevious(int position, E[] prevlabels) {
    Counter c = new ClassicCounter();
    for (int i = 0; i < classIndex.size(); i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, condLogProbGivenPrevious(position, label, prevlabels));
    }
    return c;
  }

  //
  // PROB OF TAG AT SINGLE POSITION CONDITIONED ON FOLLOWING SEQUENCE OF LABELS
  //

  public double condLogProbGivenNext(int position, int label, int[] nextLabels) {
    position = position + nextLabels.length;
    if (nextLabels.length + 1 == windowSize) {
      return factorTables[position].conditionalLogProbGivenNext(nextLabels, label);
    } else if (nextLabels.length + 1 < windowSize) {
      FactorTable ft = factorTables[position].sumOutFront();
      while (ft.windowSize() > nextLabels.length + 1) {
        ft = ft.sumOutFront();
      }
      return ft.conditionalLogProbGivenPrevious(nextLabels, label);
    } else {
      int[] p = new int[windowSize - 1];
      System.arraycopy(nextLabels, 0, p, 0, p.length);
      return factorTables[position].conditionalLogProbGivenPrevious(p, label);
    }
  }

  public double condLogProbGivenNext(int position, E label, E[] nextLabels) {
    return condLogProbGivenNext(position, classIndex.indexOf(label), objectArrayToIntArray(nextLabels));
  }

  public double condProbGivenNext(int position, int label, int[] nextLabels) {
    return Math.exp(condLogProbGivenNext(position, label, nextLabels));
  }

  public double condProbGivenNext(int position, E label, E[] nextLabels) {
    return Math.exp(condLogProbGivenNext(position, label, nextLabels));
  }

  public Counter condLogProbsGivenNext(int position, int[] nextlabels) {
    Counter c = new ClassicCounter();
    for (int i = 0; i < classIndex.size(); i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, condLogProbGivenNext(position, i, nextlabels));
    }
    return c;
  }

  public Counter condLogProbsGivenNext(int position, E[] nextlabels) {
    Counter c = new ClassicCounter();
    for (int i = 0; i < classIndex.size(); i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, condLogProbGivenNext(position, label, nextlabels));
    }
    return c;
  }

  //
  // PROB OF TAG AT SINGLE POSITION CONDITIONED ON PREVIOUS AND FOLLOWING
  // SEQUENCE OF LABELS
  //

  // public double condProbGivenPreviousAndNext(int position, int label, int[]
  // prevLabels, int[] nextLabels) {

  // }



  //
  // JOINT CONDITIONAL PROBS
  //

  /**
   * @return a new CRFCliqueTree for the weights on the data
   */
  public static  CRFCliqueTree getCalibratedCliqueTree(double[][] weights, int[][][] data, Index[] labelIndices,
      int numClasses, Index classIndex, E backgroundSymbol) {

    FactorTable[] factorTables = new FactorTable[data.length];
    FactorTable[] messages = new FactorTable[data.length - 1];

    for (int i = 0; i < data.length; i++) {
      factorTables[i] = getFactorTable(weights, data[i], labelIndices, numClasses);

      if (i > 0) {
        messages[i - 1] = factorTables[i - 1].sumOutFront();
        factorTables[i].multiplyInFront(messages[i - 1]);
      }
    }

    for (int i = factorTables.length - 2; i >= 0; i--) {
      FactorTable summedOut = factorTables[i + 1].sumOutEnd();
      summedOut.divideBy(messages[i]);
      factorTables[i].multiplyInEnd(summedOut);
    }

    return new CRFCliqueTree(factorTables, classIndex, backgroundSymbol);
  }

  /**
   * @return a new CRFCliqueTree for the weights on the data
   */
  public static  CRFCliqueTree getCalibratedCliqueTree(double[] weights, double wscale, int[][] weightIndices,
      int[][][] data, Index[] labelIndices, int numClasses, Index classIndex, E backgroundSymbol) {

    FactorTable[] factorTables = new FactorTable[data.length];
    FactorTable[] messages = new FactorTable[data.length - 1];

    for (int i = 0; i < data.length; i++) {

      factorTables[i] = getFactorTable(weights, wscale, weightIndices, data[i], labelIndices, numClasses);

      if (i > 0) {
        messages[i - 1] = factorTables[i - 1].sumOutFront();
        factorTables[i].multiplyInFront(messages[i - 1]);
      }
    }

    for (int i = factorTables.length - 2; i >= 0; i--) {

      FactorTable summedOut = factorTables[i + 1].sumOutEnd();
      summedOut.divideBy(messages[i]);
      factorTables[i].multiplyInEnd(summedOut);
    }

    return new CRFCliqueTree(factorTables, classIndex, backgroundSymbol);
  }

  private static FactorTable getFactorTable(double[] weights, double wscale, int[][] weightIndices, int[][] data,
      Index[] labelIndices, int numClasses) {

    FactorTable factorTable = null;

    for (int j = 0; j < labelIndices.length; j++) {
      Index labelIndex = labelIndices[j];
      FactorTable ft = new FactorTable(numClasses, j + 1);

      // ... and each possible labeling for that clique
      for (int k = 0, liSize = labelIndex.size(); k < liSize; k++) {
        int[] label = ((CRFLabel) labelIndex.get(k)).getLabel();
        double weight = 0.0;
        for (int m = 0; m < data[j].length; m++) {
          int wi = weightIndices[data[j][m]][k];
          weight += wscale * weights[wi];
        }
        // try{
        ft.setValue(label, weight);
        // } catch (Exception e) {
        // System.out.println("CRFCliqueTree::getFactorTable");
        // System.out.println("NumClasses: " + numClasses + " j+1: " + (j+1));
        // System.out.println("k: " + k+" label: " +label+" labelIndexSize: " +
        // labelIndex.size());
        // throw new RunTimeException(e.toString());
        // }

      }
      if (j > 0) {
        ft.multiplyInEnd(factorTable);
      }
      factorTable = ft;

    }

    return factorTable;
  }

  public static FactorTable getFactorTable(double[][] weights, int[][] data, Index[] labelIndices, int numClasses) {

    FactorTable factorTable = null;

    for (int j = 0; j < labelIndices.length; j++) {
      Index labelIndex = labelIndices[j];
      FactorTable ft = new FactorTable(numClasses, j + 1);

      // ... and each possible labeling for that clique
      for (int k = 0, liSize = labelIndex.size(); k < liSize; k++) {
        int[] label = ((CRFLabel) labelIndex.get(k)).getLabel();
        double weight = 0.0;
        for (int m = 0; m < data[j].length; m++) {
          weight += weights[data[j][m]][k];
        }
        // try{
        ft.setValue(label, weight);
        // } catch (Exception e) {
        // System.out.println("CRFCliqueTree::getFactorTable");
        // System.out.println("NumClasses: " + numClasses + " j+1: " + (j+1));
        // System.out.println("k: " + k+" label: " +label+" labelIndexSize: " +
        // labelIndex.size());
        // throw new RunTimeException(e.toString());
        // }

      }
      if (j > 0) {
        ft.multiplyInEnd(factorTable);
      }
      factorTable = ft;

    }

    return factorTable;
  }

  // SEQUENCE MODEL METHODS

  /**
   * Computes the distribution over values of the element at position pos in the
   * sequence, conditioned on the values of the elements in all other positions
   * of the provided sequence.
   *
   * @param sequence
   *          the sequence containing the rest of the values to condition on
   * @param position
   *          the position of the element to give a distribution for
   * @return an array of type double, representing a probability distribution;
   *         sums to 1.0
   */
  public double[] getConditionalDistribution(int[] sequence, int position) {
    double[] result = scoresOf(sequence, position);
    ArrayMath.logNormalize(result);
    // System.out.println("marginal:          " + ArrayMath.toString(marginal,
    // nf));
    // System.out.println("conditional:       " + ArrayMath.toString(result,
    // nf));
    result = ArrayMath.exp(result);
    // System.out.println("conditional:       " + ArrayMath.toString(result,
    // nf));
    return result;
  }

  /**
   * Informs this sequence model that the value of the element at position pos
   * has changed. This allows this sequence model to update its internal model
   * if desired.
   *
   */
  public void updateSequenceElement(int[] sequence, int pos, int oldVal) {
    // do nothing; we don't change this model
  }

  /**
   * Informs this sequence model that the value of the whole sequence is
   * initialized to sequence
   *
   */
  public void setInitialSequence(int[] sequence) {
    // do nothing
  }

  /**
   * @return the number of possible values for each element; it is assumed to be
   *         the same for the element at each position
   */
  public int getNumValues() {
    return numClasses;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy