
cc.mallet.grmm.inference.AbstractBeliefPropagation Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
package cc.mallet.grmm.inference;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.util.List;
import java.util.Iterator;
import java.io.*;
import cc.mallet.grmm.types.*;
import cc.mallet.util.MalletLogger;
/**
* Abstract base class for implementations of belief propagation for general factor graphs.
* This class manages arrays of messages, computing beliefs from messages, and convergence
* thresholds.
*
* How individual messages are sent (e.g., sum-product, max-product, etc.) is managed
* by instances of the interface {@link MessageStrategy}. Concrete subclasses decide
* which order to send messages in.
*
* @author Charles Sutton
* @version $Id: AbstractBeliefPropagation.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
public abstract class AbstractBeliefPropagation extends AbstractInferencer {
// Logger shared by this class and its nested message strategies.
protected static Logger logger = MalletLogger.getLogger (AbstractBeliefPropagation.class.getName ());
// When true, hasConverged prints the message differences it observes to System.err.
private static final boolean diagnoseConvergence = false;
// Whether lookupMarginal(Variable) normalizes the belief it computes.
protected boolean normalizeBeliefs = true;
// Number of messages sent through sendMessage by all instances (static counter).
static private int totalMessagesSent = 0;
// Number of messages sent through sendMessage by this instance; transient, so not serialized.
transient private int myMessagesSent = 0;
// Value of myMessagesSent recorded by the last call to resetMessagesSentAtStart.
transient private int messagesSentAtStart = 0;
// Default convergence threshold: the largest absolute change allowed in any message value.
private double threshold = 0.00001;
// When true, messages are read from and stored into the FactorGraph's inference cache.
protected boolean useCaching = false;
// Strategy that computes and stores individual messages.
private MessageStrategy messager;
// Value returned by iterationsUsed(); never assigned in this class -- presumably set by subclasses.
protected transient int iterUsed;
/**
 * Creates an inferencer whose messages are sent by a default-constructed
 * {@link SumProductMessageStrategy}.
 */
protected AbstractBeliefPropagation ()
{
  this (new SumProductMessageStrategy ());
}
/**
 * Creates an inferencer that sends messages using the given strategy.
 *
 * @param messager strategy used to compute and store individual messages
 */
protected AbstractBeliefPropagation (MessageStrategy messager)
{
  this.messager = messager;
}
/**
 * Returns the strategy currently used to send messages.
 */
public MessageStrategy getMessager ()
{
  return this.messager;
}
/**
 * Replaces the strategy used to send messages.
 *
 * @param messager the new message strategy
 * @return this inferencer, so that calls can be chained
 */
public AbstractBeliefPropagation setMessager (MessageStrategy messager)
{
  this.messager = messager;
  return this;
}
/**
 * Returns the number of messages sent so far by all BP inferencers
 * (the counter is static, so it is shared by every instance).
 */
public static int getTotalMessagesSent ()
{
  return totalMessagesSent;
}
/**
 * Returns the number of messages this particular inferencer has sent
 * since it was created.
 */
public int getMessagesSent ()
{
  return myMessagesSent;
}
/**
 * Returns the number of messages this inferencer has sent since the last
 * call to {@link #resetMessagesSentAtStart()}.
 */
public int getMessagesUsedLastTime ()
{
  int sentSinceReset = myMessagesSent - messagesSentAtStart;
  return sentSinceReset;
}
/**
 * Records the current per-instance message count as the baseline that
 * {@link #getMessagesUsedLastTime()} measures from.
 */
protected void resetMessagesSentAtStart ()
{
  this.messagesSentAtStart = this.myMessagesSent;
}
/**
* Current messages. Elsewhere in this class they are looked up as
* messages.get (from, to), i.e. by the sending node and the receiving node.
*/
transient private MessageArray messages;
// Copy of the messages from before the current round of updates (see copyOldMessages).
// hasConverged compares it against messages; initOldMessages fills in both the
// variable --> factor and the factor --> variable directions.
transient private MessageArray oldMessages;
// Cache of single-variable beliefs, indexed by variable index; filled lazily by lookupMarginal(Variable).
transient private Factor[] bel;
// The factor graph most recently passed to initForGraph.
protected transient FactorGraph mdlCurrent;
/**
 * Loads the current messages from the inference cache that the graph keeps
 * for this inferencer class.
 */
private void retrieveCachedMessages (FactorGraph m)
{
  Object cached = m.getInferenceCache (getClass ());
  messages = (MessageArray) cached;
}
/**
 * Stores the current messages in the graph's inference cache, keyed by this
 * inferencer's class.
 */
private void cacheMessages (FactorGraph m)
{
  m.setInferenceCache (this.getClass (), this.messages);
}
/**
 * Drops the reference to the previous-iteration messages.
 */
private void clearOldMessages ()
{
  this.oldMessages = null;
}
/**
 * Replaces the previous-iteration messages with a duplicate of the current
 * messages.
 */
final protected void copyOldMessages ()
{
  // Release the previous snapshot before building the new one.
  clearOldMessages ();
  MessageArray snapshot = messages.duplicate ();
  oldMessages = snapshot;
}
/**
 * Tests for convergence using this inferencer's own threshold.
 *
 * @return the result of {@link #hasConverged(double)} for the stored threshold
 */
final protected boolean hasConverged ()
{
  double ownThreshold = this.threshold;
  return hasConverged (ownThreshold);
}
/**
 * Tests whether the current messages differ from the previous-iteration
 * messages by no more than {@code threshold} in every entry.
 * <p>
 * Each non-null old message is compared, assignment by assignment, with the
 * current message for the same (from, to) pair. The method returns as soon
 * as one absolute difference exceeds the threshold.
 *
 * @param threshold largest absolute difference that still counts as converged
 * @return false if any compared value changed by more than {@code threshold},
 *         true otherwise
 */
final protected boolean hasConverged (double threshold)
{
  // Largest difference seen so far; only reported when diagnoseConvergence is on.
  double largestDiff = Double.NEGATIVE_INFINITY;
  Factor largestDiffOld = null;
  Factor largestDiffNew = null;

  MessageArray.Iterator msgIt = oldMessages.iterator ();
  while (msgIt.hasNext ()) {
    Factor oldMsg = (Factor) msgIt.next ();
    Object from = msgIt.from ();
    Object to = msgIt.to ();
    Factor newMsg = messages.get (from, to);

    if (oldMsg == null) {
      continue;
    }

    assert newMsg != null
        : "Message went from nonnull to null " + from + " --> " + to;

    java.util.Iterator assnIt = oldMsg.assignmentIterator ();
    while (assnIt.hasNext ()) {
      Assignment assn = (Assignment) assnIt.next ();
      double before = oldMsg.value (assn);
      double after = newMsg.value (assn);
      double diff = Math.abs (before - after);

      if (diff > threshold) {
        if (diagnoseConvergence) {
          System.err.println ("*** Not converged: Difference of : " + diff + " from " + oldMsg + " --> " + newMsg);
        }
        return false;
      }

      if (diff > largestDiff) {
        largestDiff = diff;
        largestDiffOld = oldMsg;
        largestDiffNew = newMsg;
      }
    }
  }

  if (diagnoseConvergence) {
    System.err.println (
        "*** CONVERGED: Max absolute difference : " + largestDiff + " from " + largestDiffOld + " --> " + largestDiffNew);
  }
  return true;
}
/**
 * Sets up the previous-iteration messages for a graph.
 * <p>
 * When caching is enabled and the graph has a cache entry for this
 * inferencer class, the cached messages become the current messages and
 * are copied into the old messages. Otherwise, for every factor and each
 * variable in its var set, a new {@code TableFactor} over that variable is
 * stored for both the variable-to-factor and factor-to-variable direction.
 */
private void initOldMessages (FactorGraph fg)
{
  oldMessages = new MessageArray (fg);

  boolean cacheAvailable = useCaching && fg.getInferenceCache (getClass ()) != null;
  if (cacheAvailable) {
    logger.info ("AsyncLoopyBP: Reusing previous marginals");
    retrieveCachedMessages (fg);
    copyOldMessages ();
    return;
  }

  java.util.Iterator factorIt = fg.factorsIterator ();
  while (factorIt.hasNext ()) {
    Factor factor = (Factor) factorIt.next ();
    VarSet scope = factor.varSet ();
    java.util.Iterator varIt = scope.iterator ();
    while (varIt.hasNext ()) {
      Variable var = (Variable) varIt.next ();
      oldMessages.put (var, factor, new TableFactor (var));
      oldMessages.put (factor, var, new TableFactor (var));
    }
  }
}
// Neither read nor written anywhere in this class -- presumably maintained by subclasses; TODO confirm.
transient protected int assignedVertexPtls[];
/**
 * Prepares this inferencer to run on the given factor graph.
 * <p>
 * Remembers the graph, allocates an empty belief cache with one slot per
 * variable, obtains the current messages (from the graph's inference cache
 * when caching is enabled and an entry exists, otherwise a new
 * {@code MessageArray}), initializes the old messages, and hands both
 * message arrays to the message strategy.
 *
 * @param mdl the factor graph that inference is about to run on
 */
protected void initForGraph (FactorGraph mdl)
{
  mdlCurrent = mdl;
  bel = new Factor [mdl.numVariables ()];

  Object cached = mdl.getInferenceCache (getClass ());
  boolean reuseCached = useCaching && (cached != null);
  if (reuseCached) {
    messages = (MessageArray) cached;
  } else {
    // Note: no self-messages for vertex potentials are set up here; an
    // earlier version of that setup was disabled.
    messages = new MessageArray (mdl);
  }

  initOldMessages (mdl);
  messager.setMessageArray (messages, oldMessages);
}
/**
 * Sends the message from a variable to a factor through the message
 * strategy, counting it in both the global and per-instance counters.
 */
protected void sendMessage (FactorGraph mdl, Variable from, Factor to)
{
  totalMessagesSent += 1;
  myMessagesSent += 1;
  messager.sendMessage (mdl, from, to);
}
/**
 * Sends the message from a factor to a variable through the message
 * strategy, counting it in both the global and per-instance counters.
 */
protected void sendMessage (FactorGraph mdl, Factor from, Variable to)
{
  totalMessagesSent += 1;
  myMessagesSent += 1;
  messager.sendMessage (mdl, from, to);
}
/**
 * Cleans up after inference on a graph: releases the old messages and,
 * when caching is enabled, stores the current messages in the graph's
 * inference cache.
 */
protected void doneWithGraph (FactorGraph mdl)
{
  // The old messages are only needed while iterating; drop them to free memory.
  clearOldMessages ();
  if (useCaching) {
    cacheMessages (mdl);
  }
}
/**
 * Returns the current value of the {@code iterUsed} counter.
 */
public int iterationsUsed ()
{
  return this.iterUsed;
}
/**
* Strategy that decides how an individual message is computed and stored.
* The enclosing inferencer calls setMessageArray from initForGraph and then
* delegates every sendMessage call to the strategy.
*/
public interface MessageStrategy {
// Supplies the current and previous-iteration message arrays the strategy should work on.
void setMessageArray (MessageArray msgs, MessageArray oldMsgs);
// Computes and stores the message from factor `from` to variable `to`.
void sendMessage (FactorGraph mdl, Factor from, Variable to);
// Computes and stores the message from variable `from` to factor `to`.
void sendMessage (FactorGraph mdl, Variable from, Factor to);
// As implemented by AbstractMessageStrategy: multiplies into `product` every message sent
// to node index `idx` except the one from index `excludeMsgFrom`, and returns the result;
// a null `product` is replaced by a new factor over the variable at `idx`.
Factor msgProduct (Factor product, int idx, int excludeMsgFrom);
}
/**
 * Base class for message strategies. It holds the message arrays and
 * implements the product of incoming messages; subclasses supply the two
 * {@code sendMessage} methods.
 */
public abstract static class AbstractMessageStrategy implements MessageStrategy {

  /** Messages of the current iteration. */
  protected MessageArray messages;

  /** Messages of the previous iteration. */
  protected MessageArray oldMessages;

  /**
   * Installs the message arrays that later calls read from and write to.
   */
  public void setMessageArray (MessageArray msgs, MessageArray oldMsgs)
  {
    this.messages = msgs;
    this.oldMessages = oldMsgs;
  }

  /**
   * Multiplies into a factor every message sent to node {@code idx},
   * except the one whose sender index equals {@code excludeMsgFrom}.
   *
   * @param product factor to multiply into (modified in place); if null, a
   *                new factor over the variable at {@code idx} is used
   * @param idx index of the node whose incoming messages are multiplied
   * @param excludeMsgFrom sender index whose message is skipped
   * @return the factor that was multiplied into
   */
  public Factor msgProduct (Factor product, int idx, int excludeMsgFrom)
  {
    Factor result = product;
    if (result == null) {
      result = createEmptyFactorForVar (idx);
    }

    MessageArray.ToMsgsIterator incomingIt = messages.toMessagesIterator (idx);
    while (incomingIt.hasNext ()) {
      incomingIt.next ();
      int senderIdx = incomingIt.currentFromIdx ();
      Factor incoming = incomingIt.currentMessage ();
      if (senderIdx == excludeMsgFrom) {
        continue;
      }
      result.multiplyBy (incoming);
    }
    return result;
  }

  /**
   * Creates a new factor over the variable at {@code idx}: a
   * {@code LogTableFactor} when the message array is in log space,
   * otherwise a {@code TableFactor}.
   */
  private Factor createEmptyFactorForVar (int idx)
  {
    Variable var = (Variable) messages.idx2obj (idx);
    if (messages.isInLogSpace ()) {
      return new LogTableFactor (var);
    }
    return new TableFactor (var);
  }
}
/**
 * Message strategy that computes sum-product messages, with optional
 * damping of factor-to-variable messages.
 * <p>
 * With a damping value {@code d < 1.0}, a factor-to-variable message is
 * stored as {@code d * newMsg + (1 - d) * oldMsg}, where {@code oldMsg} is
 * the normalized message from the previous iteration. With the default
 * {@code d = 1.0} the new message is stored unchanged.
 */
public static class SumProductMessageStrategy extends AbstractMessageStrategy implements Serializable {

  /** Weight given to the newly computed message; values below 1.0 enable damping. */
  private double damping = 1.0;

  /** Creates a strategy without damping. */
  public SumProductMessageStrategy ()
  {
  }

  /**
   * Creates a strategy with the given damping weight.
   *
   * @param damping weight of the new message in the damped update
   */
  public SumProductMessageStrategy (double damping)
  {
    this.damping = damping;
  }

  /**
   * Sends the factor-to-variable message: the factor is multiplied by all
   * messages coming into it except the one from {@code to}, marginalized
   * onto {@code to}, normalized, and stored (damped if configured).
   */
  public void sendMessage (FactorGraph mdl, Factor from, Variable to)
  {
    int fromIdx = messages.getIndex (from);
    int toIdx = messages.getIndex (to);

    // Work on a copy so the factor itself is not modified.
    Factor product = from.duplicate ();
    msgProduct (product, fromIdx, toIdx);

    Factor msg = product.marginalize (to);
    msg.normalize ();

    // Trace output is both guarded by and emitted at FINEST, so it no
    // longer appears as INFO-level records when tracing is switched on.
    if (logger.isLoggable (Level.FINEST)) {
      logger.finest ("MSG "+from+" --> "+to);
      logger.finest ("FACTOR: "+from.dumpToString());
      logger.finest ("MSG: "+msg.dumpToString ());
      logger.finest ("END MSG "+from+" --> "+to);
    }

    assert msg.varSet ().size () == 1;
    assert msg.varSet ().contains (to);

    makeDampedUpdate (fromIdx, toIdx, msg);
  }

  /**
   * Sends the variable-to-factor message: the normalized product of all
   * messages coming into the variable except the one from {@code to}.
   * No damping is applied in this direction.
   */
  public void sendMessage (FactorGraph mdl, Variable from, Factor to)
  {
    int fromIdx = messages.getIndex (from);
    int toIdx = messages.getIndex (to);

    Factor msg = msgProduct (null, fromIdx, toIdx);
    msg.normalize ();

    assert msg.varSet ().size () == 1;
    assert msg.varSet ().contains (from);

    messages.put (fromIdx, toIdx, msg);
  }

  /**
   * Stores {@code msg} as the message from {@code fromIdx} to
   * {@code toIdx}, first blending it with the previous iteration's message
   * when damping is enabled and such a message exists.
   */
  private void makeDampedUpdate (int fromIdx, int toIdx, Factor msg)
  {
    if (damping < 1.0) {
      Factor oldMsg = oldMessages.get (fromIdx, toIdx);
      if (oldMsg != null) {
        // Scale a normalized copy of the old message by (1 - damping) ...
        AbstractTableFactor oldTbl = (AbstractTableFactor) oldMsg.duplicate ();
        oldTbl.normalize ();
        oldTbl.timesEquals (1 - damping);

        // ... and add it to the new message scaled by damping (in place).
        AbstractTableFactor tbl = (AbstractTableFactor) msg;
        tbl.timesEquals (damping);
        tbl.plusEquals (oldTbl);
        msg = tbl;
      }
    }
    messages.put (fromIdx, toIdx, msg);
  }

  // Serialization

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 2;

  /**
   * Writes the default fields, then the serial version number, then the
   * damping value.
   */
  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
    out.writeInt (CURRENT_SERIAL_VERSION);
    out.writeDouble (damping);
  }

  /**
   * Reads the default fields and the serial version number; the damping
   * value is read only from streams written with version 2 or later.
   */
  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
    int version = in.readInt ();
    if (2 <= version) {
      damping = in.readDouble ();
    }
  }
}
/**
 * Message strategy that computes max-product messages: factor-to-variable
 * messages are obtained with {@code extractMax} rather than
 * {@code marginalize}.
 */
public static class MaxProductMessageStrategy extends AbstractMessageStrategy implements Serializable {

  /**
   * Sends the factor-to-variable message: the factor is multiplied by all
   * messages coming into it except the one from {@code to}, maximized onto
   * {@code to}, normalized, and stored.
   */
  public void sendMessage (FactorGraph mdl, Factor from, Variable to)
  {
    int senderIdx = messages.getIndex (from);
    int receiverIdx = messages.getIndex (to);

    // Work on a copy so the factor itself is not modified.
    Factor combined = from.duplicate ();
    msgProduct (combined, senderIdx, receiverIdx);

    Factor outgoing = combined.extractMax (to);
    outgoing.normalize ();

    assert outgoing.varSet ().size () == 1;
    assert outgoing.varSet ().contains (to);

    messages.put (senderIdx, receiverIdx, outgoing);
  }

  /**
   * Sends the variable-to-factor message: the normalized product of all
   * messages coming into the variable except the one from {@code to}.
   */
  public void sendMessage (FactorGraph mdl, Variable from, Factor to)
  {
    int senderIdx = messages.getIndex (from);
    int receiverIdx = messages.getIndex (to);

    Factor outgoing = msgProduct (null, senderIdx, receiverIdx);
    outgoing.normalize ();

    assert outgoing.varSet ().size () == 1;
    assert outgoing.varSet ().contains (from);

    messages.put (senderIdx, receiverIdx, outgoing);
  }

  // Serialization

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  /** Writes the default fields followed by the serial version number. */
  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
    out.writeInt (CURRENT_SERIAL_VERSION);
  }

  /** Reads the default fields, then reads and discards the serial version number. */
  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
    in.readInt ();
  }
}
/**
 * Returns the belief over a single variable, computed as the product of
 * all messages sent to that variable. The result is normalized when
 * {@code normalizeBeliefs} is set, and is cached so that later calls for
 * the same variable return the same factor.
 *
 * @param var variable whose marginal is wanted
 * @return the (cached) belief over {@code var}
 * @throws IllegalArgumentException if the current factor graph's index for
 *         {@code var} is negative or not below the number of belief slots
 */
public Factor lookupMarginal (Variable var)
{
  int idx = mdlCurrent.getIndex (var);
  // Valid slots are 0 .. bel.length - 1. The upper bound must be tested
  // with >=, otherwise idx == bel.length reaches bel[idx] below and fails
  // with ArrayIndexOutOfBoundsException instead of this exception.
  if ((idx < 0) || (idx >= bel.length)) {
    throw new IllegalArgumentException ("Cannot find variable "+var+" in factor graph "+mdlCurrent);
  }

  if (bel[idx] == null) {
    // Integer.MIN_VALUE is passed as the sender to exclude, so that every
    // incoming message is included in the product.
    Factor marg = messager.msgProduct (null, idx, Integer.MIN_VALUE);
    if (normalizeBeliefs) {
      marg.normalize ();
    }
    assert marg.varSet ().size () == 1
        :"Invalid marginal for var " + var + ": " + marg;
    assert marg.varSet ().contains (var)
        :"Invalid marginal for var " + var + ": " + marg;
    bel[idx] = marg;
  }

  return bel[idx];
}
/**
 * Dumps the current messages by delegating to the message array's own
 * {@code dump ()}.
 */
public void dump ()
{
  this.messages.dump ();
}
/**
 * Prints the global count of messages sent to {@code System.err}.
 */
public void reportTime ()
{
  String report = "AbstractBeliefPropagation: Total messages sent = "+totalMessagesSent;
  System.err.println (report);
}
/**
 * Dumps the current messages to the given writer by delegating to the
 * message array.
 *
 * @param writer destination for the dump
 */
public void dump (PrintWriter writer)
{
  this.messages.dump (writer);
}
// }}}
// Serialization
// Explicit stream version identifier for this class.
private static final long serialVersionUID = 1;
// If serialization-incompatible changes are made to these classes,
// then smarts can be added to these methods for backward compatibility.
// Currently writes only the default (non-transient, non-static) fields.
private void writeObject (ObjectOutputStream out) throws IOException
{
out.defaultWriteObject ();
}
// Currently reads only the default fields; transient fields are not restored here.
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
{
in.defaultReadObject ();
}
/**
 * Returns the belief over a set of variables.
 * <p>
 * A single-variable set is answered by the single-variable lookup.
 * Otherwise the belief is built from the factors that the current graph
 * returns from {@code allFactorsOf} for the set.
 *
 * @param c the variables whose joint marginal is wanted
 * @return the belief over {@code c}
 * @throws UnsupportedOperationException if {@code c} has more than one
 *         variable and the graph returns no factors for it
 */
public Factor lookupMarginal (VarSet c)
{
  if (c.size () == 1) {
    return lookupMarginal (c.get (0));
  }

  List factors = mdlCurrent.allFactorsOf (c);
  if (factors.isEmpty ()) {
    throw new UnsupportedOperationException
        ("Cannot compute marginal of " + c + ": Must be either a single variable or a factor in the graph.");
  }
  return lookupMarginal (c, factors);
}
/**
 * Computes the belief over a variable set from the given factors: the
 * product of all the factors, multiplied by every available
 * variable-to-factor message for each (variable in {@code vs}, factor)
 * pair, and then normalized.
 */
private Factor lookupMarginal (VarSet vs, List factors)
{
  Factor belief = Factors.multiplyAll (factors);

  Iterator factorIt = factors.iterator ();
  while (factorIt.hasNext ()) {
    Factor factor = (Factor) factorIt.next ();
    java.util.Iterator varIt = vs.iterator ();
    while (varIt.hasNext ()) {
      Variable var = (Variable) varIt.next ();
      Factor incoming = messages.get (var, factor);
      // If the inferencer was stopped early, there may be no message.
      if (incoming == null) {
        continue;
      }
      belief.multiplyBy (incoming);
    }
  }

  belief.normalize ();
  return belief;
}
/**
 * Returns the log of the joint value of an assignment, computed from the
 * beliefs using the BP factorization
 * {@code prod_s p(x_s)^-(deg(s)-1) * prod_c p(x_c)}.
 *
 * @param assn assignment to evaluate
 * @return sum of the log clique beliefs minus the degree-weighted log
 *         variable beliefs
 */
public double lookupLogJoint (Assignment assn)
{
  double logJoint = 0.0;

  // Variable terms: prod_s (p(x_s))^-(deg(s)-1)
  java.util.Iterator varIt = mdlCurrent.variablesIterator ();
  while (varIt.hasNext ()) {
    Variable var = (Variable) varIt.next ();
    Factor belief = lookupMarginal (var);
    int degree = mdlCurrent.getDegree (var);
    // Degree-1 variables contribute nothing and are skipped outright.
    // The same expression also works correctly for degree 0.
    if (degree == 1) {
      continue;
    }
    logJoint -= (degree - 1) * belief.logValue (assn);
  }

  // Clique terms: prod_{c} p(x_C)
  java.util.Iterator cliqueIt = mdlCurrent.varSetIterator ();
  while (cliqueIt.hasNext ()) {
    VarSet clique = (VarSet) cliqueIt.next ();
    Factor cliqueBelief = lookupMarginal (clique);
    logJoint += cliqueBelief.logValue (assn);
  }

  return logJoint;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy