cc.mallet.grmm.inference.JunctionTreePropagation Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
The newest version!
/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://mallet.cs.umass.edu/
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;
import java.util.Collection;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.io.Serializable;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;
/**
* An implementation of Hugin-style propagation for junction trees.
* This destructively modifies the junction tree so that its clique potentials
* are the true marginals of the underlying graph.
*
* End users will not usually need to use this class directly. Use
* JunctionTreeInferencer instead.
*
* This class is not an instance of Inferencer because it destructively
* modifies the junction tree, which the Inferencer methods do not do to
* factor graphs.
*
* Created: Feb 1, 2006
*
* @author " + child);
for (Iterator it = jt.getChildren (child).iterator (); it.hasNext ();) {
VarSet gchild = (VarSet) it.next ();
collectEvidence (jt, child, gchild);
}
if (parent != null) {
totalMessagesSent++;
strategy.sendMessage (jt, child, parent);
}
}
// top-down pass
/**
 * Pushes messages downward from {@code parent}: every child clique first
 * receives a message from {@code parent}, after which the same procedure is
 * applied recursively to that child's own subtree.  Each message sent is
 * counted in {@code totalMessagesSent}.
 *
 * @param jt     junction tree whose cliques are being updated
 * @param parent clique whose children should receive messages
 */
private void distributeEvidence (JunctionTree jt, VarSet parent)
{
  Iterator children = jt.getChildren (parent).iterator ();
  while (children.hasNext ()) {
    VarSet next = (VarSet) children.next ();
    // Count and deliver the parent-to-child message before descending,
    // so the child is updated before it sends to its own children.
    totalMessagesSent++;
    strategy.sendMessage (jt, parent, next);
    distributeEvidence (jt, next);
  }
}
/**
 * Runs one complete round of message passing over {@code jt}: a
 * {@code collectEvidence} pass invoked on the root, followed by a
 * {@code distributeEvidence} pass starting from the root.
 *
 * @param jt junction tree to propagate over; its potentials are modified
 */
private void propagate (JunctionTree jt)
{
  final VarSet top = (VarSet) jt.getRoot ();

  // The root has no parent, hence the null parent argument.
  collectEvidence (jt, null, top);
  distributeEvidence (jt, top);
}
/**
 * Looks up the belief over the variables in {@code varSet} from the
 * potentials stored in {@code jt}.
 * <p>
 * The cluster returned by {@code jt.findParentCluster (varSet)} supplies the
 * clique potential; the current message strategy reduces that potential to
 * {@code varSet}, and the result is normalized before being returned.
 *
 * @param jt     junction tree holding the clique potentials
 * @param varSet variables whose joint belief is requested
 * @return the normalized belief over {@code varSet}
 * @throws IllegalStateException         if {@code jt} is null
 * @throws UnsupportedOperationException if {@code jt} has no parent cluster
 *                                       for {@code varSet}
 */
public Factor lookupMarginal (JunctionTree jt, VarSet varSet)
{
  if (jt == null) {
    throw new IllegalStateException ("Call computeMarginals() first.");
  }

  final VarSet cluster = jt.findParentCluster (varSet);
  if (cluster == null) {
    String msg = "No parent cluster in " + jt + " for clique " + varSet;
    throw new UnsupportedOperationException (msg);
  }

  final Factor potential = jt.getCPF (cluster);
  if (logger.isLoggable (Level.FINER)) {
    logger.finer ("Lookup jt marginal: clique " + varSet + " cluster " + cluster);
    logger.finest (" cpf " + potential);
  }

  final Factor result = strategy.extractBelief (potential, varSet);
  result.normalize ();
  return result;
}
/**
 * Looks up the belief over the single variable {@code var} from the
 * potentials stored in {@code jt}.
 * <p>
 * The cluster returned by {@code jt.findParentCluster (var)} supplies the
 * clique potential; the current message strategy reduces that potential to
 * {@code var}, and the result is normalized before being returned.
 *
 * @param jt  junction tree holding the clique potentials
 * @param var variable whose belief is requested
 * @return the normalized belief over {@code var}
 * @throws IllegalStateException         if {@code jt} is null
 * @throws UnsupportedOperationException if {@code jt} has no parent cluster
 *                                       for {@code var}
 */
public Factor lookupMarginal (JunctionTree jt, Variable var)
{
  if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); }
  VarSet parent = jt.findParentCluster (var);
  // Fail with an explicit message, as the VarSet overload does, instead of
  // passing a null cluster on to getCPF.
  if (parent == null) {
    throw new UnsupportedOperationException
      ("No parent cluster in " + jt + " for variable " + var);
  }
  Factor cpf = jt.getCPF (parent);
  if (logger.isLoggable (Level.FINER)) {
    logger.finer ("Lookup jt marginal: var " + var + " cluster " + parent);
    logger.finest (" cpf " + cpf);
  }
  Factor marginal = strategy.extractBelief (cpf, new HashVarSet (new Variable[] { var }));
  marginal.normalize ();
  return marginal;
}
///////////////////////////////////////////////////////////////////////////
// MESSAGE STRATEGIES
///////////////////////////////////////////////////////////////////////////
/**
 * Strategy object that decides how a message between two cliques is computed
 * and how a belief is read out of a clique potential.  Keeping this behind an
 * interface lets different schemes, e.g. sum-product and max-product, be
 * supplied as separate implementations.
 */
public interface MessageStrategy {

  /**
   * Sends a message from the clique {@code from} to the clique {@code to}
   * in the junction tree {@code jt}.
   *
   * @param jt   junction tree containing both cliques
   * @param from clique the message originates from
   * @param to   clique that receives the message
   */
  void sendMessage (JunctionTree jt, VarSet from, VarSet to);

  /**
   * Reduces the clique potential {@code cpf} to a belief over {@code varSet}.
   *
   * @param cpf    clique potential to read the belief from
   * @param varSet variables the returned belief should range over
   * @return the belief over {@code varSet}
   */
  Factor extractBelief (Factor cpf, VarSet varSet);
}
/**
 * Message strategy that builds each message by marginalizing the sending
 * clique's potential onto the sepset shared with the receiver.
 */
public static class SumProductMessageStrategy implements MessageStrategy, Serializable {

  // Serialization
  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  /**
   * Sends a sum-product message from {@code from} to {@code to}.  Both the
   * message and the updated receiver potential are normalized to avoid
   * underflow.
   */
  public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
  {
    Collection separator = jt.getSepset (from, to);
    Factor senderPot = jt.getCPF (from);
    Factor receiverPot = jt.getCPF (to);
    Factor previousMsg = jt.getSepsetPot (from, to);

    // New message: sender potential marginalized onto the sepset.
    Factor message = senderPot.marginalize (separator);
    message.normalize ();
    jt.setSepsetPot (message, from, to);

    // Receiver absorbs the new message and has the previous sepset
    // potential divided back out, so the old message is not counted twice.
    Factor updated = receiverPot.multiply (message);
    updated.divideBy (previousMsg);
    updated.normalize ();
    jt.setCPF (to, updated);
  }

  /**
   * Returns {@code cpf} marginalized onto {@code varSet}.
   */
  public Factor extractBelief (Factor cpf, VarSet varSet)
  {
    return cpf.marginalize (varSet);
  }

  /** Writes the default fields followed by the serial-format version. */
  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
    out.writeInt (CURRENT_SERIAL_VERSION);
  }

  /** Reads the default fields, then consumes the stored version number. */
  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
    in.readInt (); // version; read to stay in step with writeObject, value unused
  }
}
/**
 * Message strategy that builds each message by calling {@code extractMax}
 * on the sending clique's potential over the sepset shared with the
 * receiver.
 */
public static class MaxProductMessageStrategy implements MessageStrategy, Serializable {

  // Serialization
  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  /**
   * Sends a max-product message from {@code from} to {@code to}.  Both the
   * message and the updated receiver potential are normalized.
   */
  public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
  {
    Collection separator = jt.getSepset (from, to);
    Factor senderPot = jt.getCPF (from);
    Factor receiverPot = jt.getCPF (to);
    Factor previousMsg = jt.getSepsetPot (from, to);

    // New message: extractMax of the sender potential over the sepset.
    Factor message = senderPot.extractMax (separator);
    message.normalize ();
    jt.setSepsetPot (message, from, to);

    // Receiver absorbs the new message and has the previous sepset
    // potential divided back out, so the old message is not counted twice.
    Factor updated = receiverPot.multiply (message);
    updated.divideBy (previousMsg);
    updated.normalize ();
    jt.setCPF (to, updated);
  }

  /**
   * Returns the result of {@code cpf.extractMax (varSet)}.
   */
  public Factor extractBelief (Factor cpf, VarSet varSet)
  {
    return cpf.extractMax (varSet);
  }

  /** Writes the default fields followed by the serial-format version. */
  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
    out.writeInt (CURRENT_SERIAL_VERSION);
  }

  /** Reads the default fields, then consumes the stored version number. */
  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
    in.readInt (); // version; read to stay in step with writeObject, value unused
  }
}
// Serialization
// Class version identifier used by Java serialization for this class.
private static final long serialVersionUID = 1;
// Format version that writeObject appends to the stream after the default
// fields; readObject reads the value back and currently ignores it.
// NOTE(review): the identifier is spelled "CUURENT" as declared; it is left
// unchanged here because writeObject refers to it by this name.
private static final int CUURENT_SERIAL_VERSION = 1;
/**
 * Serialization hook: writes this object's default serialized fields and
 * then appends the serial-format version number as an int.
 *
 * @param out stream the object is being written to
 * @throws IOException if writing to {@code out} fails
 */
private void writeObject (ObjectOutputStream out) throws IOException
{
  final int version = CUURENT_SERIAL_VERSION;

  out.defaultWriteObject ();
  out.writeInt (version);
}
/**
 * Serialization hook: restores this object's default serialized fields and
 * then consumes the int version number that follows them in the stream.
 * The version value itself is not inspected.
 *
 * @param in stream the object is being read from
 * @throws IOException            if reading from {@code in} fails
 * @throws ClassNotFoundException if a class needed by the default read
 *                                cannot be found
 */
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
{
  in.defaultReadObject ();

  // Skip over the stored version so the stream position stays in step
  // with what writeObject produced.
  in.readInt ();
}
}