// Artifact: com.github.keenon.loglinear.inference.TableFactor (Maven: loglinear)
// "A PGM library that's easy to use, and runs fast"
package com.github.keenon.loglinear.inference;
import com.github.keenon.loglinear.model.NDArray;
import com.github.keenon.loglinear.model.ConcatVector;
import com.github.keenon.loglinear.model.GraphicalModel;
import com.github.keenon.loglinear.model.NDArrayDoubles;
import org.apache.commons.math3.util.FastMath;
import java.io.IOException;
import java.util.*;
import java.util.function.BiFunction;
/**
* Created by keenon on 8/11/15.
*
* Holds a factor populated by doubles that knows how to do all the important operations for PGM inference. Internally,
* these are just different flavors of two basic data-flow operations:
*
* - Factor product
* - Factor marginalization
*
* The output here is different ways to grow and shrink factors that turn out to be useful for downstream uses in PGMs.
* Basically, we care about message passing, as that will be the primary operation.
*
* Everything is represented as log-linear, because the primary use for TableFactor is in CliqueTree, and that is
* intended for use with log-linear models.
*/
public class TableFactor extends NDArrayDoubles {
public int[] neighborIndices;
/**
* Construct a TableFactor for inference within a model. This just copies the important bits from the model factor,
* and replaces the ConcatVectorTable with an internal datastructure that has done all the dotproducts with the
* weights out, and so stores only doubles.
*
* Each element of the table is given by: t_i = exp(f_i*w)
*
* @param weights the vector to dot product with every element of the factor table
* @param factor the feature factor to be multiplied in
*/
public TableFactor(ConcatVector weights, GraphicalModel.Factor factor) {
super(factor.featuresTable.getDimensions());
this.neighborIndices = factor.neigborIndices;
// Calculate the factor residents by dot product with the weights
// OPTIMIZATION:
// Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
// we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
Iterator fastPassByReferenceIterator = factor.featuresTable.fastPassByReferenceIterator();
int[] assignment = fastPassByReferenceIterator.next();
while (true) {
setAssignmentLogValue(assignment, factor.featuresTable.getAssignmentValue(assignment).get().dotProduct(weights));
// This mutates the assignment[] array, rather than creating a new one
if (fastPassByReferenceIterator.hasNext()) fastPassByReferenceIterator.next();
else break;
}
}
/**
* Fast approximation of the exp() function
* @param val the value to be exponentiated
* @return the value we want
*/
public static double exp(double val) {
final long tmp = (long) (1512775 * val + 1072632447);
return Double.longBitsToDouble(tmp << 32);
}
public static final boolean USE_EXP_APPROX = false;
/**
* Construct a TableFactor for inference within a model. This is the same as the other constructor, except that the
* table is observed out before any unnecessary dot products are done out, so hopefully we dramatically reduce the
* number of computations required to calculate the resulting table.
*
* Each element of the table is given by: t_i = exp(f_i*w)
*
* @param weights the vector to dot product with every element of the factor table
* @param factor the feature factor to be multiplied in
*/
public TableFactor(ConcatVector weights, GraphicalModel.Factor factor, int[] observations) {
super();
assert(observations.length == factor.neigborIndices.length);
int size = 0;
for (int observation : observations) if (observation == -1) size++;
neighborIndices = new int[size];
dimensions = new int[size];
int[] forwardPointers = new int[size];
int[] factorAssignment = new int[factor.neigborIndices.length];
int cursor = 0;
for (int i = 0; i < factor.neigborIndices.length; i++) {
if (observations[i] == -1) {
neighborIndices[cursor] = factor.neigborIndices[i];
dimensions[cursor] = factor.featuresTable.getDimensions()[i];
forwardPointers[cursor] = i;
cursor++;
}
else factorAssignment[i] = observations[i];
}
assert(cursor == size);
values = new double[combinatorialNeighborStatesCount()];
for (int[] assn : this) {
for (int i = 0; i < assn.length; i++) {
factorAssignment[forwardPointers[i]] = assn[i];
}
setAssignmentLogValue(assn, factor.featuresTable.getAssignmentValue(factorAssignment).get().dotProduct(weights));
}
}
/**
* Remove a variable by observing it at a certain value, return a new factor without that variable.
*
* @param variable the variable to be observed
* @param value the value the variable takes when observed
* @return a new factor with 'variable' in it
*/
public TableFactor observe(int variable, final int value) {
return marginalize(variable, 0, (marginalizedVariableValue, assignment) -> {
if (marginalizedVariableValue == value) {
return (old,n) -> {
// This would mean that we're observing something with 0 probability, which will wonk up downstream
// stuff
// assert(n != 0);
return n;
};
}
else {
return (old,n) -> old;
}
});
}
/**
* Returns the summed marginals for each element in the factor. These are represented in log space, and are summed
* using the numerically stable variant, even though it's slightly slower.
*
* @return an array of doubles one-to-one with variable states for each variable
*/
public double[][] getSummedMarginals() {
double[][] results = new double[neighborIndices.length][];
for (int i = 0; i < neighborIndices.length; i++) {
results[i] = new double[getDimensions()[i]];
}
double[][] maxValues = new double[neighborIndices.length][];
for (int i = 0; i < neighborIndices.length; i++) {
maxValues[i] = new double[getDimensions()[i]];
for (int j = 0; j < maxValues[i].length; j++) maxValues[i][j] = Double.NEGATIVE_INFINITY;
}
// Get max values
// OPTIMIZATION:
// Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
// we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
Iterator fastPassByReferenceIterator = fastPassByReferenceIterator();
int[] assignment = fastPassByReferenceIterator.next();
while (true) {
double v = getAssignmentLogValue(assignment);
for (int i = 0; i < neighborIndices.length; i++) {
if (maxValues[i][assignment[i]] < v) maxValues[i][assignment[i]] = v;
}
// This mutates the resultAssignment[] array, rather than creating a new one
if (fastPassByReferenceIterator.hasNext()) {
fastPassByReferenceIterator.next();
}
else break;
}
// Do the summation
// OPTIMIZATION:
// Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
// we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
Iterator secondFastPassByReferenceIterator = fastPassByReferenceIterator();
assignment = secondFastPassByReferenceIterator.next();
while (true) {
double v = getAssignmentLogValue(assignment);
for (int i = 0; i < neighborIndices.length; i++) {
if (USE_EXP_APPROX) {
results[i][assignment[i]] += exp(v - maxValues[i][assignment[i]]);
}
else {
results[i][assignment[i]] += FastMath.exp(v - maxValues[i][assignment[i]]);
}
}
// This mutates the resultAssignment[] array, rather than creating a new one
if (secondFastPassByReferenceIterator.hasNext()) {
secondFastPassByReferenceIterator.next();
}
else break;
}
// normalize results, and move to linear space
for (int i = 0; i < neighborIndices.length; i++) {
double sum = 0.0;
for (int j = 0; j < results[i].length; j++) {
if (USE_EXP_APPROX) {
results[i][j] = exp(maxValues[i][j]) * results[i][j];
}
else {
results[i][j] = FastMath.exp(maxValues[i][j]) * results[i][j];
}
sum += results[i][j];
}
if (Double.isInfinite(sum)) {
for (int j = 0; j < results[i].length; j++) {
results[i][j] = 1.0 / results[i].length;
}
}
else {
for (int j = 0; j < results[i].length; j++) {
results[i][j] /= sum;
}
}
}
return results;
}
/**
* Convenience function to max out all but one variable, and return the marginal array.
*
* @return an array of doubles one-to-one with variable states for each variable
*/
public double[][] getMaxedMarginals() {
double[][] maxValues = new double[neighborIndices.length][];
for (int i = 0; i < neighborIndices.length; i++) {
maxValues[i] = new double[getDimensions()[i]];
for (int j = 0; j < maxValues[i].length; j++) maxValues[i][j] = Double.NEGATIVE_INFINITY;
}
// Get max values
// OPTIMIZATION:
// Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
// we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
Iterator fastPassByReferenceIterator = fastPassByReferenceIterator();
int[] assignment = fastPassByReferenceIterator.next();
while (true) {
double v = getAssignmentLogValue(assignment);
for (int i = 0; i < neighborIndices.length; i++) {
if (maxValues[i][assignment[i]] < v) maxValues[i][assignment[i]] = v;
}
// This mutates the resultAssignment[] array, rather than creating a new one
if (fastPassByReferenceIterator.hasNext()) {
fastPassByReferenceIterator.next();
}
else break;
}
for (int i = 0; i < neighborIndices.length; i++) {
normalizeLogArr(maxValues[i]);
}
return maxValues;
}
/**
* Marginalize out a variable by taking the max value.
*
* @param variable the variable to be maxed out.
* @return a table factor that will contain the largest value of the variable being marginalized out.
*/
public TableFactor maxOut(int variable) {
return marginalize(variable, Double.NEGATIVE_INFINITY, (marginalizedVariableValue, assignment) -> Math::max);
}
/**
* Marginalize out a variable by taking a sum.
*
* @param variable the variable to be summed out
* @return a factor with variable removed
*/
public TableFactor sumOut(int variable) {
// OPTIMIZATION: This is by far the most common case, for linear chain inference, and is worth making fast
// We can use closed loops, and not bother with using the basic iterator to loop through indices.
// If this special case doesn't trip, we fall back to the standard (but slower) algorithm for the general case
if (getDimensions().length == 2) {
if (neighborIndices[0] == variable) {
TableFactor marginalized = new TableFactor(new int[]{neighborIndices[1]}, new int[]{getDimensions()[1]});
for (int i = 0; i < marginalized.values.length; i++) marginalized.values[i] = 0;
// We use the stable log-sum-exp trick here, so first we calculate the max
double[] max = new double[getDimensions()[1]];
for (int j = 0; j < getDimensions()[1]; j++) {
max[j] = Double.NEGATIVE_INFINITY;
}
for (int i = 0; i < getDimensions()[0]; i++) {
int k = i * getDimensions()[1];
for (int j = 0; j < getDimensions()[1]; j++) {
int index = k + j;
if (values[index] > max[j]) {
max[j] = values[index];
}
}
}
// Then we take the sum, minus the max
for (int i = 0; i < getDimensions()[0]; i++) {
int k = i * getDimensions()[1];
for (int j = 0; j < getDimensions()[1]; j++) {
int index = k + j;
if (Double.isFinite(max[j])) {
if (USE_EXP_APPROX) {
marginalized.values[j] += exp(values[index] - max[j]);
}
else {
marginalized.values[j] += FastMath.exp(values[index] - max[j]);
}
}
}
}
// And now we exponentiate, and add back in the values
for (int j = 0; j < getDimensions()[1]; j++) {
if (Double.isFinite(max[j])) {
marginalized.values[j] = max[j] + FastMath.log(marginalized.values[j]);
}
else {
marginalized.values[j] = max[j];
}
}
return marginalized;
}
else {
assert(neighborIndices[1] == variable);
TableFactor marginalized = new TableFactor(new int[]{neighborIndices[0]}, new int[]{getDimensions()[0]});
for (int i = 0; i < marginalized.values.length; i++) marginalized.values[i] = 0;
// We use the stable log-sum-exp trick here, so first we calculate the max
double[] max = new double[getDimensions()[0]];
for (int i = 0; i < getDimensions()[0]; i++) {
max[i] = Double.NEGATIVE_INFINITY;
}
for (int i = 0; i < getDimensions()[0]; i++) {
int k = i * getDimensions()[1];
for (int j = 0; j < getDimensions()[1]; j++) {
int index = k + j;
if (values[index] > max[i]) {
max[i] = values[index];
}
}
}
// Then we take the sum, minus the max
for (int i = 0; i < getDimensions()[0]; i++) {
int k = i * getDimensions()[1];
for (int j = 0; j < getDimensions()[1]; j++) {
int index = k + j;
if (Double.isFinite(max[i])) {
if (USE_EXP_APPROX) {
marginalized.values[i] += exp(values[index] - max[i]);
}
else {
marginalized.values[i] += FastMath.exp(values[index] - max[i]);
}
}
}
}
// And now we exponentiate, and add back in the values
for (int i = 0; i < getDimensions()[0]; i++) {
if (Double.isFinite(max[i])) {
marginalized.values[i] = max[i] + FastMath.log(marginalized.values[i]);
}
else {
marginalized.values[i] = max[i];
}
}
return marginalized;
}
}
else {
// This is a little tricky because we need to use the stable log-sum-exp trick on top of our marginalize
// dataflow operation.
// First we calculate all the max values to use as pivots to prevent overflow
TableFactor maxValues = maxOut(variable);
// Then we do the sum against an offset from the pivots
TableFactor marginalized = marginalize(variable, 0, (marginalizedVariableValue, assignment) -> (a, b) -> a + (USE_EXP_APPROX ? exp(b - maxValues.getAssignmentLogValue(assignment)) : FastMath.exp(b - maxValues.getAssignmentLogValue(assignment))));
// Then we factor the max values back in, and
for (int[] assignment : marginalized) {
marginalized.setAssignmentLogValue(assignment, maxValues.getAssignmentLogValue(assignment) + FastMath.log(marginalized.getAssignmentLogValue(assignment)));
}
return marginalized;
}
}
/**
* Product two factors, taking the multiplication at the intersections.
* @param other the other factor to be multiplied
* @return a factor containing the union of both variable sets
*/
public TableFactor multiply(TableFactor other) {
// Calculate the result domain
List domain = new ArrayList<>();
List otherDomain = new ArrayList<>();
List resultDomain = new ArrayList<>();
for (int n : neighborIndices) {
domain.add(n);
resultDomain.add(n);
}
for (int n : other.neighborIndices) {
otherDomain.add(n);
if (!resultDomain.contains(n)) resultDomain.add(n);
}
// Create result TableFactor
int[] resultNeighborIndices = new int[resultDomain.size()];
int[] resultDimensions = new int[resultNeighborIndices.length];
for (int i = 0; i < resultDomain.size(); i++) {
int var = resultDomain.get(i);
resultNeighborIndices[i] = var;
// assert consistency about variable size, we can't have the same variable with two different sizes
assert((getVariableSize(var) == 0 && other.getVariableSize(var) > 0) ||
(getVariableSize(var) > 0 && other.getVariableSize(var) == 0) ||
(getVariableSize(var) == other.getVariableSize(var)));
resultDimensions[i] = Math.max(getVariableSize(resultDomain.get(i)), other.getVariableSize(resultDomain.get(i)));
}
TableFactor result = new TableFactor(resultNeighborIndices, resultDimensions);
// OPTIMIZATION:
// If we're a factor of size 2 receiving a message of size 1, then we can optimize that pretty heavily
// We could just use the general algorithm at the end of this set of special cases, but this is the fastest way
if (otherDomain.size() == 1 && (resultDomain.size() == domain.size()) && domain.size() == 2) {
int msgVar = otherDomain.get(0);
int msgIndex = resultDomain.indexOf(msgVar);
if (msgIndex == 0) {
for (int i = 0; i < resultDimensions[0]; i++) {
double d = other.values[i];
int k = i * resultDimensions[1];
for (int j = 0; j < resultDimensions[1]; j++) {
int index = k + j;
result.values[index] = values[index] + d;
}
}
}
else if (msgIndex == 1) {
for (int i = 0; i < resultDimensions[0]; i++) {
int k = i * resultDimensions[1];
for (int j = 0; j < resultDimensions[1]; j++) {
int index = k + j;
result.values[index] = values[index] + other.values[j];
}
}
}
}
// OPTIMIZATION:
// The special case where we're a message of size 1, and the other factor is receiving the message, and of size 2
else if (domain.size() == 1 && (resultDomain.size() == otherDomain.size()) && resultDomain.size() == 2) {
return other.multiply(this);
}
// Otherwise we follow the big comprehensive, slow general purpose algorithm
else {
// Calculate back-pointers from the result domain indices to original indices
int[] mapping = new int[result.neighborIndices.length];
int[] otherMapping = new int[result.neighborIndices.length];
for (int i = 0; i < result.neighborIndices.length; i++) {
mapping[i] = domain.indexOf(result.neighborIndices[i]);
otherMapping[i] = otherDomain.indexOf(result.neighborIndices[i]);
}
// Do the actual joining operation between the two tables, applying 'join' for each result element.
int[] assignment = new int[neighborIndices.length];
int[] otherAssignment = new int[other.neighborIndices.length];
// OPTIMIZATION:
// Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
// we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
Iterator fastPassByReferenceIterator = result.fastPassByReferenceIterator();
int[] resultAssignment = fastPassByReferenceIterator.next();
while (true) {
// Set the assignment arrays correctly
for (int i = 0; i < resultAssignment.length; i++) {
if (mapping[i] != -1) assignment[mapping[i]] = resultAssignment[i];
if (otherMapping[i] != -1) otherAssignment[otherMapping[i]] = resultAssignment[i];
}
result.setAssignmentLogValue(resultAssignment, getAssignmentLogValue(assignment) + other.getAssignmentLogValue(otherAssignment));
// This mutates the resultAssignment[] array, rather than creating a new one
if (fastPassByReferenceIterator.hasNext()) fastPassByReferenceIterator.next();
else break;
}
}
return result;
}
/**
* This is useful for calculating the partition function, and is exposed here because when implemented internally
* we can do a much more numerically stable summation.
*
* @return the sum of all values for all assignments to the TableFactor
*/
public double valueSum() {
// We want the exp(log-sum-exp), for stability
// This rearranges to exp(a)*(sum-exp)
double max = 0.0;
for (int[] assignment : this) {
double v = getAssignmentLogValue(assignment);
if (v > max) {
max = v;
}
}
double sumExp = 0.0;
for (int[] assignment : this) {
sumExp += FastMath.exp(getAssignmentLogValue(assignment) - max);
}
return sumExp * FastMath.exp(max);
}
/**
* Just a pass through to the NDArray version, plus a FastMath.exp to ensure that to the outside world the TableFactor
* doesn't look like it's in log-space
*
* @param assignment a list of variable settings, in the same order as the neighbors array of the factor
* @return the value of the assignment
*/
@Override
public double getAssignmentValue(int[] assignment) {
double d = super.getAssignmentValue(assignment);
// if (d == null) d = Double.NEGATIVE_INFINITY;
return FastMath.exp(d);
}
/**
* Just a pass through to the NDArray version, plus a FastMath.log to ensure that to the outside world the TableFactor
* doesn't look like it's in log-space
*
* @param assignment a list of variable settings, in the same order as the neighbors array of the factor
* @param value the value to put into the factor table
*/
@Override
public void setAssignmentValue(int[] assignment, double value) {
super.setAssignmentValue(assignment, FastMath.log(value));
}
////////////////////////////////////////////////////////////////////////////
// PRIVATE IMPLEMENTATION
////////////////////////////////////////////////////////////////////////////
private double getAssignmentLogValue(int[] assignment) {
return super.getAssignmentValue(assignment);
}
private void setAssignmentLogValue(int[] assignment, double value) {
super.setAssignmentValue(assignment, value);
}
/**
* Marginalizes out a variable by applying an associative join operation for each possible assignment to the
* marginalized variable.
*
* @param variable the variable (by 'name', not offset into neighborIndices)
* @param startingValue associativeJoin is basically a foldr over a table, and this is the initialization
* @param curriedFoldr the associative function to use when applying the join operation, taking first the
* assignment to the value being marginalized, and then a foldr operation
* @return a new TableFactor that doesn't contain 'variable', where values were gotten through associative
* marginalization.
*/
private TableFactor marginalize(int variable, double startingValue, BiFunction> curriedFoldr) {
// Can't marginalize the last variable
assert(getDimensions().length > 1);
// Calculate the result domain
List resultDomain = new ArrayList<>();
for (int n : neighborIndices) {
if (n != variable) {
resultDomain.add(n);
}
}
// Create result TableFactor
int[] resultNeighborIndices = new int[resultDomain.size()];
int[] resultDimensions = new int[resultNeighborIndices.length];
for (int i = 0; i < resultDomain.size(); i++) {
int var = resultDomain.get(i);
resultNeighborIndices[i] = var;
resultDimensions[i] = getVariableSize(var);
}
TableFactor result = new TableFactor(resultNeighborIndices, resultDimensions);
// Calculate forward-pointers from the old domain to new domain
int[] mapping = new int[neighborIndices.length];
for (int i = 0; i < neighborIndices.length; i++) {
mapping[i] = resultDomain.indexOf(neighborIndices[i]);
}
// Initialize
for (int[] assignment : result) {
result.setAssignmentLogValue(assignment, startingValue);
}
// Do the actual fold into the result
int[] resultAssignment = new int[result.neighborIndices.length];
int marginalizedVariableValue = 0;
// OPTIMIZATION:
// Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
// we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
Iterator fastPassByReferenceIterator = fastPassByReferenceIterator();
int[] assignment = fastPassByReferenceIterator.next();
while (true) {
// Set the assignment arrays correctly
for (int i = 0; i < assignment.length; i++) {
if (mapping[i] != -1) resultAssignment[mapping[i]] = assignment[i];
else marginalizedVariableValue = assignment[i];
}
result.setAssignmentLogValue(resultAssignment, curriedFoldr.apply(marginalizedVariableValue, resultAssignment)
.apply(result.getAssignmentLogValue(resultAssignment), getAssignmentLogValue(assignment)));
if (fastPassByReferenceIterator.hasNext()) fastPassByReferenceIterator.next();
else break;
}
return result;
}
/**
* Address a variable by index to get it's size. Basically just a convenience function.
*
* @param variable the name, not index into neighbors, of the variable in question
* @return the size of the factor along this dimension
*/
private int getVariableSize(int variable) {
for (int i = 0; i < neighborIndices.length; i++) {
if (neighborIndices[i] == variable) return getDimensions()[i];
}
return 0;
}
/**
* Super basic in-place array normalization
* @param arr the array to normalize
*/
private void normalizeLogArr(double[] arr) {
// Find the log-scale normalization value
double max = Double.NEGATIVE_INFINITY;
for (double d : arr) {
if (d > max) max = d;
}
double expSum = 0.0;
for (double d : arr) {
expSum += FastMath.exp(d-max);
}
double logSumExp = max + FastMath.log(expSum);
if (Double.isInfinite(logSumExp)) {
// Just put in uniform probabilities if we are normalizing all 0s
for (int i = 0; i < arr.length; i++) {
arr[i] = 1.0 / arr.length;
}
}
else {
// Normalize in log-scale before exponentiation, to help with stability
for (int i = 0; i < arr.length; i++) {
arr[i] = FastMath.exp(arr[i] - logSumExp);
}
}
}
/**
* FOR PRIVATE USE AND TESTING ONLY
*/
TableFactor(int[] neighborIndices, int[] dimensions) {
super(dimensions);
this.neighborIndices = neighborIndices;
for (int i = 0; i < values.length; i++) {
values[i] = Double.NEGATIVE_INFINITY;
}
}
@SuppressWarnings("*")
private boolean assertsEnabled() {
boolean assertsEnabled = false;
assert(assertsEnabled = true); // intentional side effect
return assertsEnabled;
}
}
// © 2015 - 2025 Weber Informatics LLC