All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.ie.crf.CRFCliqueTree Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.ie.crf; 
import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.ListeningSequenceModel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.GeneralizedCounter;
import edu.stanford.nlp.util.Index;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Builds a CliqueTree (an array of FactorTable) and does message passing
 * inference along it.
 *
 * @param  The type of the label (usually String in our uses)
 * @author Jenny Finkel
 */
public class CRFCliqueTree implements ListeningSequenceModel  {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(CRFCliqueTree.class);

  private final FactorTable[] factorTables;
  private final double z; // norm constant
  private final Index classIndex;
  private final E backgroundSymbol;
  private final int backgroundIndex;
  // the window size, which is also the clique size
  private final int windowSize;
  // the number of possible classes for each label
  private final int numClasses;
  private final int[] possibleValues;

  /** Initialize a clique tree. */
  public CRFCliqueTree(FactorTable[] factorTables, Index classIndex, E backgroundSymbol) {
    this(factorTables, classIndex, backgroundSymbol, factorTables[0].totalMass());
  }

  /** This extra constructor was added to support the CRFCliqueTreeForPartialLabels. */
  CRFCliqueTree(FactorTable[] factorTables, Index classIndex, E backgroundSymbol, double z) {
    this.factorTables = factorTables;
    this.z = z;
    this.classIndex = classIndex;
    this.backgroundSymbol = backgroundSymbol;
    backgroundIndex = classIndex.indexOf(backgroundSymbol);
    windowSize = factorTables[0].windowSize();
    numClasses = classIndex.size();
    possibleValues = new int[numClasses];
    for (int i = 0; i < numClasses; i++) {
      possibleValues[i] = i;
    }

    // Debug only
    // System.out.println("CRFCliqueTree constructed::numClasses: " +
    // numClasses);
  }

  public FactorTable[] getFactorTables() {
    return this.factorTables;
  }

  public Index classIndex() {
    return classIndex;
  }

  // SEQUENCE MODEL METHODS

  @Override
  public int length() {
    return factorTables.length;
  }

  @Override
  public int leftWindow() {
    return windowSize;
  }

  @Override
  public int rightWindow() {
    return 0;
  }

  @Override
  public int[] getPossibleValues(int position) {
    return possibleValues;
  }

  @Override
  public double scoreOf(int[] sequence, int pos) {
    return scoresOf(sequence, pos)[sequence[pos]];
  }

  /**
   * Computes the unnormalized log conditional distribution over values of the
   * element at position pos in the sequence, conditioned on the values of the
   * elements in all other positions of the provided sequence.
   *
   * @param sequence
   *          the sequence containing the rest of the values to condition on
   * @param position
   *          the position of the element to give a distribution for
   * @return an array of type double, representing a probability distribution;
   *         sums to 1.0
   */
  @Override
  public double[] scoresOf(int[] sequence, int position) {
    if (position >= factorTables.length) throw new RuntimeException("Index out of bounds: " + position);
    // DecimalFormat nf = new DecimalFormat("#0.000");
    // if (position>0 && position= length()) {
      nextLength = length() - position - 1;
    }
    FactorTable nextFactorTable = factorTables[position + nextLength];
    if (nextLength != windowSize - 1) {
      for (int j = 0; j < windowSize - 1 - nextLength; j++) {
        nextFactorTable = nextFactorTable.sumOutFront();
      }
    }
    if (nextLength == 0) { // we are asking about the prob of no sequence
      Arrays.fill(probNextGivenThis, 1.0);
    } else {
      int[] next = new int[nextLength];
      System.arraycopy(sequence, position + 1, next, 0, nextLength);
      for (int label = 0; label < numClasses; label++) {
        // ask the factor table such that pos is the first position in the
        // window
        // probNextGivenThis[label] =
        // factorTables[position+nextLength].conditionalLogProbGivenFirst(label,
        // next);
        // probNextGivenThis[label] =
        // nextFactorTable.conditionalLogProbGivenFirst(label, next);
        probNextGivenThis[label] = nextFactorTable.unnormalizedConditionalLogProbGivenFirst(label, next);
      }
    }

    // pointwise multiply
    return ArrayMath.pairwiseAdd(probThisGivenPrev, probNextGivenThis);
  }

  /**
   * Returns the log probability of this sequence given the CRF. Does so by
   * computing the marginal of the first windowSize tags, and then computing the
   * conditional probability for the rest of them, conditioned on the previous
   * tags.
   *
   * @param sequence The sequence to compute a score for
   * @return the score for the sequence
   */
  @Override
  public double scoreOf(int[] sequence) {

    int[] given = new int[window() - 1];
    Arrays.fill(given, classIndex.indexOf(backgroundSymbol));
    double logProb = 0.0;
    for (int i = 0, length = length(); i < length; i++) {
      int label = sequence[i];
      logProb += condLogProbGivenPrevious(i, label, given);
      System.arraycopy(given, 1, given, 0, given.length - 1);
      given[given.length - 1] = label;
    }
    return logProb;
  }

  // OTHER

  public int window() {
    return windowSize;
  }

  public int getNumClasses() {
    return numClasses;
  }

  public double totalMass() {
    return z;
  }

  public int backgroundIndex() {
    return backgroundIndex;
  }

  public E backgroundSymbol() {
    return backgroundSymbol;
  }

  //
  // MARGINAL PROB OF TAG AT SINGLE POSITION
  //

  public double[][] logProbTable() {
    double[][] result = new double[length()][classIndex.size()];
    for (int i = 0; i < length(); i++) {
      result[i] = new double[classIndex.size()];
      for (int j = 0; j < classIndex.size(); j++) {
        result[i][j] = logProb(i, j);
      }
    }

    return result;
  }

  /*
  * TODO(mengqiu) this function is buggy, should make sure label converts properly into int[] in cases where it's not 0-order label
  */
  public double logProbStartPos() {
    double u = factorTables[0].unnormalizedLogProbFront(backgroundIndex);
    return u - z;
  }

  public double logProb(int position, int label) {
    double u = factorTables[position].unnormalizedLogProbEnd(label);
    return u - z;
  }

  public double prob(int position, int label) {
    return Math.exp(logProb(position, label));
  }

  public double logProb(int position, E label) {
    return logProb(position, classIndex.indexOf(label));
  }

  public double prob(int position, E label) {
    return Math.exp(logProb(position, label));
  }

  public double[] probsToDoubleArr(int position) {
    double[] probs = new double[classIndex.size()];
    for (int i = 0, sz = classIndex.size(); i < sz; i++) {
      probs[i] = prob(position, i);
    }
    return probs;
  }

  public double[] logProbsToDoubleArr(int position) {
    double[] probs = new double[classIndex.size()];
    for (int i = 0, sz = classIndex.size(); i < sz; i++) {
      probs[i] = logProb(position, i);
    }
    return probs;
  }

  public Counter probs(int position) {
    Counter c = new ClassicCounter<>();
    for (int i = 0, sz = classIndex.size(); i < sz; i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, prob(position, i));
    }
    return c;
  }

  public Counter logProbs(int position) {
    Counter c = new ClassicCounter<>();
    for (int i = 0, sz = classIndex.size(); i < sz; i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, logProb(position, i));
    }
    return c;
  }

  //
  // MARGINAL PROBS OF TAGS AT MULTIPLE POSITIONS
  //

  /**
   * returns the log probability for the given labels (indexed using
   * classIndex), where the last label corresponds to the label at the specified
   * position. For instance if you called logProb(5, {1,2,3}) it will return the
   * marginal log prob that the label at position 3 is 1, the label at position
   * 4 is 2 and the label at position 5 is 3.
   */
  public double logProb(int position, int[] labels) {
    if (labels.length < windowSize) {
      return factorTables[position].unnormalizedLogProbEnd(labels) - z;
    } else if (labels.length == windowSize) {
      return factorTables[position].unnormalizedLogProb(labels) - z;
    } else {
      int[] l = new int[windowSize];
      System.arraycopy(labels, 0, l, 0, l.length);
      int position1 = position - labels.length + windowSize;
      double p = factorTables[position1].unnormalizedLogProb(l) - z;
      l = new int[windowSize - 1];
      System.arraycopy(labels, 1, l, 0, l.length);
      position1++;
      for (int i = windowSize; i < labels.length; i++) {
        p += condLogProbGivenPrevious(position1++, labels[i], l);
        System.arraycopy(l, 1, l, 0, l.length - 1);
        l[windowSize - 2] = labels[i];
      }
      return p;
    }
  }

  /**
   * Returns the probability for the given labels (indexed using classIndex),
   * where the last label corresponds to the label at the specified position.
   * For instance if you called prob(5, {1,2,3}) it will return the marginal
   * prob that the label at position 3 is 1, the label at position 4 is 2 and
   * the label at position 5 is 3.
   */
  public double prob(int position, int[] labels) {
    return Math.exp(logProb(position, labels));
  }

  /**
   * returns the log probability for the given labels, where the last label
   * corresponds to the label at the specified position. For instance if you
   * called logProb(5, {"O", "PER", "ORG"}) it will return the marginal log prob
   * that the label at position 3 is "O", the label at position 4 is "PER" and
   * the label at position 5 is "ORG".
   */
  public double logProb(int position, E[] labels) {
    return logProb(position, objectArrayToIntArray(labels));
  }

  /**
   * returns the probability for the given labels, where the last label
   * corresponds to the label at the specified position. For instance if you
   * called logProb(5, {"O", "PER", "ORG"}) it will return the marginal prob
   * that the label at position 3 is "O", the label at position 4 is "PER" and
   * the label at position 5 is "ORG".
   */
  public double prob(int position, E[] labels) {
    return Math.exp(logProb(position, labels));
  }

  public GeneralizedCounter logProbs(int position, int window) {
    GeneralizedCounter gc = new GeneralizedCounter<>(window);
    int[] labels = new int[window];
    // cdm july 2005: below array initialization isn't necessary: JLS (3rd ed.)
    // 4.12.5
    // Arrays.fill(labels, 0);

    OUTER: while (true) {
      List labelsList = intArrayToListE(labels);
      gc.incrementCount(labelsList, logProb(position, labels));
      for (int i = 0; i < labels.length; i++) {
        labels[i]++;
        if (labels[i] < numClasses) {
          break;
        }
        if (i == labels.length - 1) {
          break OUTER;
        }
        labels[i] = 0;
      }
    }
    return gc;
  }

  public GeneralizedCounter probs(int position, int window) {
    GeneralizedCounter gc = new GeneralizedCounter<>(window);
    int[] labels = new int[window];
    // cdm july 2005: below array initialization isn't necessary: JLS (3rd ed.)
    // 4.12.5
    // Arrays.fill(labels, 0);

    OUTER: while (true) {
      List labelsList = intArrayToListE(labels);
      gc.incrementCount(labelsList, prob(position, labels));
      for (int i = 0; i < labels.length; i++) {
        labels[i]++;
        if (labels[i] < numClasses) {
          break;
        }
        if (i == labels.length - 1) {
          break OUTER;
        }
        labels[i] = 0;
      }
    }
    return gc;
  }

  //
  // HELPER METHODS
  //

  private int[] objectArrayToIntArray(E[] os) {
    int[] is = new int[os.length];
    for (int i = 0; i < os.length; i++) {
      is[i] = classIndex.indexOf(os[i]);
    }
    return is;
  }

  private List intArrayToListE(int[] is) {
    List os = new ArrayList<>(is.length);
    for (int i : is) {
      os.add(classIndex.get(i));
    }
    return os;
  }

  /**
   * Gives the probability of a tag at a single position conditioned on a
   * sequence of previous labels.
   *
   * @param position Index in sequence
   * @param label Label of item at index
   * @param prevLabels Indices of labels in previous positions
   * @return conditional log probability
   */
  public double condLogProbGivenPrevious(int position, int label, int[] prevLabels) {
    if (prevLabels.length + 1 == windowSize) {
      return factorTables[position].conditionalLogProbGivenPrevious(prevLabels, label);
    } else if (prevLabels.length + 1 < windowSize) {
      FactorTable ft = factorTables[position].sumOutFront();
      while (ft.windowSize() > prevLabels.length + 1) {
        ft = ft.sumOutFront();
      }
      return ft.conditionalLogProbGivenPrevious(prevLabels, label);
    } else {
      int[] p = new int[windowSize - 1];
      System.arraycopy(prevLabels, prevLabels.length - p.length, p, 0, p.length);
      return factorTables[position].conditionalLogProbGivenPrevious(p, label);
    }
  }

  public double condLogProbGivenPrevious(int position, E label, E[] prevLabels) {
    return condLogProbGivenPrevious(position, classIndex.indexOf(label), objectArrayToIntArray(prevLabels));
  }

  public double condProbGivenPrevious(int position, int label, int[] prevLabels) {
    return Math.exp(condLogProbGivenPrevious(position, label, prevLabels));
  }

  public double condProbGivenPrevious(int position, E label, E[] prevLabels) {
    return Math.exp(condLogProbGivenPrevious(position, label, prevLabels));
  }

  public Counter condLogProbsGivenPrevious(int position, int[] prevlabels) {
    Counter c = new ClassicCounter<>();
    for (int i = 0, sz = classIndex.size(); i < sz; i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, condLogProbGivenPrevious(position, i, prevlabels));
    }
    return c;
  }

  public Counter condLogProbsGivenPrevious(int position, E[] prevlabels) {
    Counter c = new ClassicCounter<>();
    for (int i = 0, sz = classIndex.size(); i < sz; i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, condLogProbGivenPrevious(position, label, prevlabels));
    }
    return c;
  }

  //
  // PROB OF TAG AT SINGLE POSITION CONDITIONED ON FOLLOWING SEQUENCE OF LABELS
  //

  public double condLogProbGivenNext(int position, int label, int[] nextLabels) {
    position = position + nextLabels.length;
    if (nextLabels.length + 1 == windowSize) {
      return factorTables[position].conditionalLogProbGivenNext(nextLabels, label);
    } else if (nextLabels.length + 1 < windowSize) {
      FactorTable ft = factorTables[position].sumOutFront();
      while (ft.windowSize() > nextLabels.length + 1) {
        ft = ft.sumOutFront();
      }
      return ft.conditionalLogProbGivenPrevious(nextLabels, label);
    } else {
      int[] p = new int[windowSize - 1];
      System.arraycopy(nextLabels, 0, p, 0, p.length);
      return factorTables[position].conditionalLogProbGivenPrevious(p, label);
    }
  }

  public double condLogProbGivenNext(int position, E label, E[] nextLabels) {
    return condLogProbGivenNext(position, classIndex.indexOf(label), objectArrayToIntArray(nextLabels));
  }

  public double condProbGivenNext(int position, int label, int[] nextLabels) {
    return Math.exp(condLogProbGivenNext(position, label, nextLabels));
  }

  public double condProbGivenNext(int position, E label, E[] nextLabels) {
    return Math.exp(condLogProbGivenNext(position, label, nextLabels));
  }

  public Counter condLogProbsGivenNext(int position, int[] nextlabels) {
    Counter c = new ClassicCounter<>();
    for (int i = 0, sz = classIndex.size(); i < sz; i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, condLogProbGivenNext(position, i, nextlabels));
    }
    return c;
  }

  public Counter condLogProbsGivenNext(int position, E[] nextlabels) {
    Counter c = new ClassicCounter<>();
    for (int i = 0, sz = classIndex.size(); i < sz; i++) {
      E label = classIndex.get(i);
      c.incrementCount(label, condLogProbGivenNext(position, label, nextlabels));
    }
    return c;
  }

  //
  // PROB OF TAG AT SINGLE POSITION CONDITIONED ON PREVIOUS AND FOLLOWING
  // SEQUENCE OF LABELS
  //

  // public double condProbGivenPreviousAndNext(int position, int label, int[]
  // prevLabels, int[] nextLabels) {

  // }



  //
  // JOINT CONDITIONAL PROBS
  //
  /**
   * @return a new CRFCliqueTree for the weights on the data
   */
  public static  CRFCliqueTree getCalibratedCliqueTree(int[][][] data, List> labelIndices,
      int numClasses, Index classIndex, E backgroundSymbol, CliquePotentialFunction cliquePotentialFunc, double[][][] featureVals) {

    FactorTable[] factorTables = new FactorTable[data.length];
    FactorTable[] messages = new FactorTable[data.length - 1];

    for (int i = 0; i < data.length; i++) {
      double[][] featureValByCliqueSize = null;
      if (featureVals != null)
        featureValByCliqueSize = featureVals[i];
      factorTables[i] = getFactorTable(data[i], labelIndices, numClasses, cliquePotentialFunc, featureValByCliqueSize, i);

      // log.info("before calibration,FT["+i+"] = " + factorTables[i].toProbString());

      if (i > 0) {
        messages[i - 1] = factorTables[i - 1].sumOutFront();
        // log.info("forward message, message["+(i-1)+"] = " + messages[i-1].toProbString());
        factorTables[i].multiplyInFront(messages[i - 1]);
        // log.info("after forward calibration, FT["+i+"] = " + factorTables[i].toProbString());
      }
    }

    for (int i = factorTables.length - 2; i >= 0; i--) {
      FactorTable summedOut = factorTables[i + 1].sumOutEnd();
      summedOut.divideBy(messages[i]);
      // log.info("backward summedOut, summedOut= " + summedOut.toProbString());
      factorTables[i].multiplyInEnd(summedOut);
      // log.info("after backward calibration, FT["+i+"] = " + factorTables[i].toProbString());
    }

    return new CRFCliqueTree<>(factorTables, classIndex, backgroundSymbol);
  }

  /**
   * This function assumes a LinearCliquePotentialFunction is used for wrapping the weights
   * @return a new CRFCliqueTree for the weights on the data
   */
  public static  CRFCliqueTree getCalibratedCliqueTree(double[] weights, double wscale, int[][] weightIndices,
      int[][][] data, List> labelIndices, int numClasses, Index classIndex, E backgroundSymbol) {

    FactorTable[] factorTables = new FactorTable[data.length];
    FactorTable[] messages = new FactorTable[data.length - 1];

    for (int i = 0; i < data.length; i++) {

      factorTables[i] = getFactorTable(weights, wscale, weightIndices, data[i], labelIndices, numClasses);

      if (i > 0) {
        messages[i - 1] = factorTables[i - 1].sumOutFront();
        factorTables[i].multiplyInFront(messages[i - 1]);
      }
    }

    for (int i = factorTables.length - 2; i >= 0; i--) {

      FactorTable summedOut = factorTables[i + 1].sumOutEnd();
      summedOut.divideBy(messages[i]);
      factorTables[i].multiplyInEnd(summedOut);
    }

    return new CRFCliqueTree<>(factorTables, classIndex, backgroundSymbol);
  }

  private static FactorTable getFactorTable(double[] weights, double wScale, int[][] weightIndices, int[][] data,
      List> labelIndices, int numClasses) {

    FactorTable factorTable = null;

    for (int j = 0, sz = labelIndices.size(); j < sz; j++) {
      Index labelIndex = labelIndices.get(j);
      FactorTable ft = new FactorTable(numClasses, j + 1);

      // ... and each possible labeling for that clique
      for (int k = 0, liSize = labelIndex.size(); k < liSize; k++) {
        int[] label = labelIndex.get(k).getLabel();
        double weight = 0.0;
        for (int m = 0; m < data[j].length; m++) {
          int wi = weightIndices[data[j][m]][k];
          weight += wScale * weights[wi];
        }
        // try{
        ft.setValue(label, weight);
        // } catch (Exception e) {
        // System.out.println("CRFCliqueTree::getFactorTable");
        // System.out.println("NumClasses: " + numClasses + " j+1: " + (j+1));
        // System.out.println("k: " + k+" label: " +label+" labelIndexSize: " +
        // labelIndex.size());
        // throw new RunTimeException(e.toString());
        // }

      }
      if (j > 0) {
        ft.multiplyInEnd(factorTable);
      }
      factorTable = ft;

    }

    return factorTable;
  }

  // static FactorTable getFactorTable(double[][] weights, int[][] data, List> labelIndices, int numClasses, int posInSent) {
  //   CliquePotentialFunction cliquePotentialFunc = new LinearCliquePotentialFunction(weights);
  //   return getFactorTable(data, labelIndices, numClasses, cliquePotentialFunc, null, posInSent);
  // }

  static FactorTable getFactorTable(int[][] data, List> labelIndices, int numClasses,
      CliquePotentialFunction cliquePotentialFunc, double[][] featureValByCliqueSize, int posInSent) {
    FactorTable factorTable = null;

    for (int j = 0, sz = labelIndices.size(); j < sz; j++) {
      Index labelIndex = labelIndices.get(j);
      FactorTable ft = new FactorTable(numClasses, j + 1);
      double[] featureVal = null;
      if (featureValByCliqueSize != null)
        featureVal = featureValByCliqueSize[j];

      // ... and each possible labeling for that clique
      for (int k = 0, liSize = labelIndex.size(); k < liSize; k++) {
        int[] label = labelIndex.get(k).getLabel();
        double cliquePotential = cliquePotentialFunc.computeCliquePotential(j+1, k, data[j], featureVal, posInSent);
        // for (int m = 0; m < data[j].length; m++) {
        //   weight += weights[data[j][m]][k];
        // }
        // try{
        ft.setValue(label, cliquePotential);
        // } catch (Exception e) {
        // System.out.println("CRFCliqueTree::getFactorTable");
        // System.out.println("NumClasses: " + numClasses + " j+1: " + (j+1));
        // System.out.println("k: " + k+" label: " +label+" labelIndexSize: " +
        // labelIndex.size());
        // throw new RunTimeException(e.toString());
        // }

      }
      if (j > 0) {
        ft.multiplyInEnd(factorTable);
      }
      factorTable = ft;

    }

    return factorTable;
  }


  // SEQUENCE MODEL METHODS

  /**
   * Computes the distribution over values of the element at position pos in the
   * sequence, conditioned on the values of the elements in all other positions
   * of the provided sequence.
   *
   * @param sequence
   *          the sequence containing the rest of the values to condition on
   * @param position
   *          the position of the element to give a distribution for
   * @return an array of type double, representing a probability distribution;
   *         sums to 1.0
   */
  public double[] getConditionalDistribution(int[] sequence, int position) {
    double[] result = scoresOf(sequence, position);
    ArrayMath.logNormalize(result);
    // System.out.println("marginal:          " + ArrayMath.toString(marginal,
    // nf));
    // System.out.println("conditional:       " + ArrayMath.toString(result,
    // nf));
    result = ArrayMath.exp(result);
    // System.out.println("conditional:       " + ArrayMath.toString(result,
    // nf));
    return result;
  }

  /**
   * Informs this sequence model that the value of the element at position pos
   * has changed. This allows this sequence model to update its internal model
   * if desired.
   */
  @Override
  public void updateSequenceElement(int[] sequence, int pos, int oldVal) {
    // do nothing; we don't change this model
  }

  /**
   * Informs this sequence model that the value of the whole sequence is
   * initialized to sequence
   */
  @Override
  public void setInitialSequence(int[] sequence) {
    // do nothing
  }

  /**
   * @return the number of possible values for each element; it is assumed to be
   *         the same for the element at each position
   */
  public int getNumValues() {
    return numClasses;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy