All downloads are free. Search and download functionality uses the official Maven repository.

edu.stanford.nlp.ie.crf.FloatFactorTable Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.ie.crf; 
import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.util.Index;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;


/**
 * Stores a factor table as a one dimensional array of floats.
 *
 * <p>Each entry corresponds to a sequence of {@code windowSize} labels, each label in
 * {@code [0, numClasses)}. The label sequence {@code [l_0, ..., l_{w-1}]} is stored at the
 * mixed-radix index {@code l_0 * numClasses^(w-1) + ... + l_{w-1}}, i.e. the first label is the
 * most significant digit. Values are treated as log-space scores: the table starts out filled
 * with negative infinity, masses are combined with log-sum-exp, and "multiplying" two factors
 * adds their entries.
 *
 * @author Jenny Finkel
 */
public class FloatFactorTable {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(FloatFactorTable.class);

  private final int numClasses;
  private final int windowSize;

  /** Log-space values, laid out as described in the class comment. */
  private final float[] table;

  /**
   * Creates a table over {@code windowSize} positions with {@code numClasses} labels each.
   * Every entry starts at negative infinity (log of zero).
   */
  public FloatFactorTable(int numClasses, int windowSize) {
    this.numClasses = numClasses;
    this.windowSize = windowSize;

    table = new float[SloppyMath.intPow(numClasses, windowSize)];
    Arrays.fill(table, Float.NEGATIVE_INFINITY);
  }

  /** Returns whether any entry of the table is NaN. */
  public boolean hasNaN() {
    return ArrayMath.hasNaN(table);
  }

  /**
   * Renders every label sequence together with its normalized probability, one per line.
   * The normalizer is computed once instead of once per entry.
   */
  public String toProbString() {
    float z = totalMass();
    StringBuilder sb = new StringBuilder("{\n");
    for (int i = 0; i < table.length; i++) {
      sb.append(Arrays.toString(toArray(i)))
          .append(": ")
          .append((float) Math.exp(table[i] - z))
          .append('\n');
    }
    return sb.append('}').toString();
  }

  /**
   * Renders every label sequence (as class objects looked up in {@code classIndex}) together
   * with its raw log-space value, one per line.
   */
  public String toString(Index<?> classIndex) {
    StringBuilder sb = new StringBuilder("{\n");
    for (int i = 0; i < table.length; i++) {
      sb.append(toString(toArray(i), classIndex))
          .append(": ")
          .append(table[i])
          .append('\n');
    }
    return sb.append('}').toString();
  }

  /** Renders every label sequence (as integer ids) with its raw log-space value, one per line. */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder("{\n");
    for (int i = 0; i < table.length; i++) {
      sb.append(Arrays.toString(toArray(i)))
          .append(": ")
          .append(table[i])
          .append('\n');
    }
    return sb.append('}').toString();
  }

  /** Maps each label id in {@code array} to its object in {@code classIndex} and prints the list. */
  private String toString(int[] array, Index<?> classIndex) {
    List<Object> labels = new ArrayList<>(array.length);
    for (int label : array) {
      labels.add(classIndex.get(label));
    }
    return labels.toString();
  }

  /** Decodes a flat table index into its {@code windowSize} labels (first label most significant). */
  private int[] toArray(int index) {
    int[] indices = new int[windowSize];
    for (int i = indices.length - 1; i >= 0; i--) {
      indices[i] = index % numClasses;
      index /= numClasses;
    }
    return indices;
  }

  /**
   * Encodes a label sequence as a base-{@code numClasses} number, first label most significant.
   * For a full-length sequence this is its flat table index; for a shorter one it is the index
   * of that sequence within the sub-table of its own length.
   */
  private int indexOf(int[] entry) {
    int index = 0;
    for (int label : entry) {
      index = index * numClasses + label;
    }
    return index;
  }

  /** Flat table index of the sequence {@code front} followed by the single label {@code end}. */
  private int indexOf(int[] front, int end) {
    return indexOf(front) * numClasses + end;
  }

  /**
   * Returns every flat table index whose last {@code entries.length} labels equal
   * {@code entries}, in increasing order.
   */
  private int[] indicesEnd(int[] entries) {
    int[] indices = new int[SloppyMath.intPow(numClasses, windowSize - entries.length)];
    // Consecutive free prefixes differ by one unit of the lowest prefix digit.
    int stride = SloppyMath.intPow(numClasses, entries.length);
    int suffix = indexOf(entries);
    for (int i = 0; i < indices.length; i++) {
      indices[i] = suffix + i * stride;
    }
    return indices;
  }

  /**
   * Returns every flat table index whose first {@code entries.length} labels equal
   * {@code entries}, in increasing order. These indices form one contiguous run. An empty
   * {@code entries} selects the whole table.
   */
  private int[] indicesFront(int[] entries) {
    int blockSize = SloppyMath.intPow(numClasses, windowSize - entries.length);
    // All completions of the fixed prefix occupy [prefix * blockSize, (prefix + 1) * blockSize).
    int start = indexOf(entries) * blockSize;
    int[] indices = new int[blockSize];
    for (int i = 0; i < blockSize; i++) {
      indices[i] = start + i;
    }
    return indices;
  }

  /** Log-sum-exp of the table values at the given flat indices. */
  private float logSumOf(int[] indices) {
    float[] masses = new float[indices.length];
    for (int i = 0; i < indices.length; i++) {
      masses[i] = table[indices[i]];
    }
    return ArrayMath.logSum(masses);
  }

  /** Number of label positions covered by this table. */
  public int windowSize() {
    return windowSize;
  }

  /** Number of distinct labels at each position. */
  public int numClasses() {
    return numClasses;
  }

  private int size() {
    return table.length;
  }

  /** Log-sum-exp over all entries, i.e. the log normalizer of the table. */
  public float totalMass() {
    return ArrayMath.logSum(table);
  }

  /** Raw log-space value stored for the full label sequence {@code label}. */
  public float unnormalizedLogProb(int[] label) {
    return getValue(label);
  }

  /** Normalized log probability of the full label sequence {@code label}. */
  public float logProb(int[] label) {
    return unnormalizedLogProb(label) - totalMass();
  }

  /** Normalized probability of the full label sequence {@code label}. */
  public float prob(int[] label) {
    return (float) Math.exp(logProb(label));
  }

  /**
   * Returns log P(last label = {@code of} | preceding labels = {@code given}).
   * {@code given} is at the beginning of the window and {@code of} is at the end.
   *
   * @param given labels for the first {@code windowSize - 1} positions
   * @param of label for the final position
   * @return conditional log probability
   * @throws IllegalArgumentException if {@code given.length != windowSize - 1}
   */
  public float conditionalLogProb(int[] given, int of) {
    if (given.length != windowSize - 1) {
      // A context of the wrong length would silently address the wrong slice of the table.
      throw new IllegalArgumentException("error computing conditional log prob: expected "
          + (windowSize - 1) + " given labels but got " + given.length);
    }
    float z = logSumOf(indicesFront(given));
    return table[indexOf(given, of)] - z;
  }

  /** Unnormalized log mass of all sequences whose leading labels equal {@code label}. */
  public float unnormalizedLogProbFront(int[] label) {
    return logSumOf(indicesFront(label));
  }

  /** Normalized log probability that the leading labels equal {@code label}. */
  public float logProbFront(int[] label) {
    return unnormalizedLogProbFront(label) - totalMass();
  }

  /** Unnormalized log mass of all sequences whose trailing labels equal {@code label}. */
  public float unnormalizedLogProbEnd(int[] label) {
    return logSumOf(indicesEnd(label));
  }

  /** Normalized log probability that the trailing labels equal {@code label}. */
  public float logProbEnd(int[] label) {
    return unnormalizedLogProbEnd(label) - totalMass();
  }

  /** Unnormalized log mass of all sequences whose last label is {@code label}. */
  public float unnormalizedLogProbEnd(int label) {
    return unnormalizedLogProbEnd(new int[] {label});
  }

  /** Normalized log probability that the last label is {@code label}. */
  public float logProbEnd(int label) {
    return unnormalizedLogProbEnd(label) - totalMass();
  }

  private float getValue(int index) {
    return table[index];
  }

  /** Raw stored value for the full label sequence {@code label}. */
  public float getValue(int[] label) {
    return getValue(indexOf(label));
  }

  private void setValue(int index, float value) {
    table[index] = value;
  }

  /** Overwrites the stored value for the full label sequence {@code label}. */
  public void setValue(int[] label, float value) {
    setValue(indexOf(label), value);
  }

  /** Adds {@code value} to the stored value for {@code label} (a multiply in log space). */
  public void incrementValue(int[] label, float value) {
    table[indexOf(label)] += value;
  }

  private void logIncrementValue(int index, float value) {
    table[index] = SloppyMath.logAdd(table[index], value);
  }

  /** Replaces the stored value {@code v} for {@code label} with {@code logAdd(v, value)}. */
  public void logIncrementValue(int[] label, float value) {
    logIncrementValue(indexOf(label), value);
  }

  /**
   * Adds {@code other}'s values into this table, aligning {@code other} with the leading
   * positions of this window. Assumes {@code other} has the same {@code numClasses} and a window
   * no larger than this one; this is not checked.
   */
  public void multiplyInFront(FloatFactorTable other) {
    // Dropping the trailing (windowSize - other.windowSize) digits yields other's index.
    int divisor = SloppyMath.intPow(numClasses, windowSize - other.windowSize());
    for (int i = 0; i < table.length; i++) {
      table[i] += other.getValue(i / divisor);
    }
  }

  /**
   * Adds {@code other}'s values into this table, aligning {@code other} with the trailing
   * positions of this window. Assumes {@code other} has the same {@code numClasses} and a window
   * no larger than this one; this is not checked.
   */
  public void multiplyInEnd(FloatFactorTable other) {
    // Keeping only the trailing other.windowSize digits yields other's index.
    int modulus = SloppyMath.intPow(numClasses, other.windowSize());
    for (int i = 0; i < table.length; i++) {
      table[i] += other.getValue(i % modulus);
    }
  }

  /** Returns a new table over the first {@code windowSize - 1} positions with the last summed out. */
  public FloatFactorTable sumOutEnd() {
    FloatFactorTable ft = new FloatFactorTable(numClasses, windowSize - 1);
    for (int i = 0; i < table.length; i++) {
      ft.logIncrementValue(i / numClasses, table[i]);
    }
    return ft;
  }

  /** Returns a new table over the last {@code windowSize - 1} positions with the first summed out. */
  public FloatFactorTable sumOutFront() {
    FloatFactorTable ft = new FloatFactorTable(numClasses, windowSize - 1);
    int mod = SloppyMath.intPow(numClasses, windowSize - 1);
    for (int i = 0; i < table.length; i++) {
      ft.logIncrementValue(i % mod, table[i]);
    }
    return ft;
  }

  /**
   * Subtracts {@code other}'s entries from this table element-wise (a divide in log space).
   * Entries where both tables hold negative infinity are left unchanged, because
   * {@code -inf - -inf} would produce NaN. Assumes {@code other} has the same shape.
   */
  public void divideBy(FloatFactorTable other) {
    for (int i = 0; i < table.length; i++) {
      if (table[i] != Float.NEGATIVE_INFINITY || other.table[i] != Float.NEGATIVE_INFINITY) {
        table[i] -= other.table[i];
      }
    }
  }

  /**
   * Small demo: fills a 6-class, window-3 table and prints it. Then, for each two-label
   * context, it logs the conditional distribution over the third label and prints that
   * distribution's total, which should be 1.
   */
  public static void main(String[] args) {
    FloatFactorTable ft = new FloatFactorTable(6, 3);
    for (int i = 0; i < 6; i++) {
      for (int j = 0; j < 6; j++) {
        for (int k = 0; k < 6; k++) {
          ft.setValue(new int[] {i, j, k}, (i * 4) + (j * 2) + k);
        }
      }
    }

    System.out.println(ft);

    for (int i = 0; i < 6; i++) {
      for (int j = 0; j < 6; j++) {
        int[] given = {i, j};
        double total = 0;
        for (int k = 0; k < 6; k++) {
          double p = Math.exp(ft.conditionalLogProb(given, k));
          total += p;
          log.info(k + "|" + i + "," + j + " : " + p);
        }
        System.out.println(total);
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy