
cc.mallet.grmm.inference.JunctionTreePropagation Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of jcore-mallet-2.0.9 Show documentation
Show all versions of jcore-mallet-2.0.9 Show documentation
MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.
The newest version!
/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://mallet.cs.umass.edu/
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;
import java.util.Collection;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.io.Serializable;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;
/**
* An implementation of Hugin-style propagation for junction trees.
* This destructively modifies the junction tree so that its clique potentials
* are the true marginals of the underlying graph.
*
* End users will not usually need to use this class directly. Use
* JunctionTreeInferencer instead.
*
* This class is not an instance of Inferencer because it destructively
* modifies the junction tree, which the Inferencer methods do not do to
* factor graphs.
*
* Created: Feb 1, 2006
*
* @author " + child);
for (Iterator it = jt.getChildren (child).iterator (); it.hasNext ();) {
VarSet gchild = (VarSet) it.next ();
collectEvidence (jt, child, gchild);
}
if (parent != null) {
totalMessagesSent++;
strategy.sendMessage (jt, child, parent);
}
}
// top-down pass
/**
 * Pushes messages away from <tt>parent</tt> into the subtree below it.
 * For every child of <tt>parent</tt>, one message is sent from
 * <tt>parent</tt> to that child (counted in <tt>totalMessagesSent</tt>),
 * and the child's own subtree is then processed recursively before the
 * next sibling is visited.
 *
 * @param jt     junction tree whose potentials are updated by the strategy
 * @param parent clique whose children receive messages
 */
private void distributeEvidence (JunctionTree jt, VarSet parent)
{
  Iterator children = jt.getChildren (parent).iterator ();
  while (children.hasNext ()) {
    VarSet receiver = (VarSet) children.next ();
    // The message must reach the child before the child forwards
    // anything to its own children.
    totalMessagesSent++;
    strategy.sendMessage (jt, parent, receiver);
    distributeEvidence (jt, receiver);
  }
}
/**
 * Runs one full round of message passing over the junction tree:
 * first <tt>collectEvidence</tt> starting at the root (with no parent),
 * then <tt>distributeEvidence</tt> from the root back down.
 *
 * @param jt junction tree to propagate over; it is modified by the
 *           message strategy
 */
private void propagate (JunctionTree jt)
{
  final VarSet treeRoot = (VarSet) jt.getRoot ();

  // The root has no parent, hence the null second argument.
  collectEvidence (jt, null, treeRoot);
  distributeEvidence (jt, treeRoot);
}
/**
 * Reads the belief over a set of variables out of the junction tree.
 * The cluster returned by <tt>jt.findParentCluster (varSet)</tt> is looked
 * up, its potential is reduced to <tt>varSet</tt> by the current message
 * strategy, and the result is normalized before being returned.
 *
 * @param jt     junction tree to query
 * @param varSet clique whose belief is wanted
 * @return the normalized belief over <tt>varSet</tt>
 * @throws IllegalStateException if <tt>jt</tt> is null
 * @throws UnsupportedOperationException if the tree reports no parent
 *         cluster for <tt>varSet</tt>
 */
public Factor lookupMarginal (JunctionTree jt, VarSet varSet)
{
  if (jt == null) {
    throw new IllegalStateException ("Call computeMarginals() first.");
  }

  final VarSet cluster = jt.findParentCluster (varSet);
  if (cluster == null) {
    throw new UnsupportedOperationException
        ("No parent cluster in " + jt + " for clique " + varSet);
  }

  final Factor clusterPotential = jt.getCPF (cluster);
  if (logger.isLoggable (Level.FINER)) {
    logger.finer ("Lookup jt marginal: clique " + varSet + " cluster " + cluster);
    logger.finest (" cpf " + clusterPotential);
  }

  final Factor belief = strategy.extractBelief (clusterPotential, varSet);
  belief.normalize ();
  return belief;
}
/**
 * Reads the belief over a single variable out of the junction tree.
 * The cluster returned by <tt>jt.findParentCluster (var)</tt> is looked up,
 * its potential is reduced to <tt>var</tt> by the current message strategy,
 * and the result is normalized before being returned.
 *
 * @param jt  junction tree to query
 * @param var variable whose belief is wanted
 * @return the normalized belief over <tt>var</tt>
 * @throws IllegalStateException if <tt>jt</tt> is null
 * @throws UnsupportedOperationException if the tree reports no parent
 *         cluster for <tt>var</tt>
 */
public Factor lookupMarginal (JunctionTree jt, Variable var)
{
  if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); }
  VarSet parent = jt.findParentCluster (var);
  // Fail with a descriptive error, as the VarSet overload does, instead of
  // handing a null cluster to getCPF.
  if (parent == null) {
    throw new UnsupportedOperationException
        ("No parent cluster in " + jt + " for variable " + var);
  }
  Factor cpf = jt.getCPF (parent);
  if (logger.isLoggable (Level.FINER)) {
    logger.finer ("Lookup jt marginal: var " + var + " cluster " + parent);
    logger.finest (" cpf " + cpf);
  }
  // extractBelief takes a VarSet, so wrap the single variable.
  Factor marginal = strategy.extractBelief (cpf, new HashVarSet (new Variable[] { var }));
  marginal.normalize ();
  return marginal;
}
///////////////////////////////////////////////////////////////////////////
// MESSAGE STRATEGIES
///////////////////////////////////////////////////////////////////////////
/**
 * Strategy interface for the two operations that differ between
 * propagation variants: sending a message between neighbouring cliques and
 * reducing a clique potential to a belief. Sum-product and max-product
 * propagation, e.g., are separate implementations of this interface.
 */
public interface MessageStrategy {

  /**
   * Sends a message from the clique FROM to TO in a junction tree.
   *
   * @param jt   junction tree holding the clique and sepset potentials
   * @param from sending clique
   * @param to   receiving clique
   */
  void sendMessage (JunctionTree jt, VarSet from, VarSet to);

  /**
   * Reduces a clique potential to a belief over the given variables.
   *
   * @param cpf    clique potential to reduce
   * @param varSet variables the returned factor should range over
   * @return the belief over <tt>varSet</tt>
   */
  Factor extractBelief (Factor cpf, VarSet varSet);
}
/**
 * Message strategy for sum-product propagation: messages and beliefs are
 * obtained by marginalizing clique potentials.
 */
public static class SumProductMessageStrategy implements MessageStrategy, Serializable {

  // Serialization
  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  /**
   * This sends a sum-product message, normalized to avoid
   * underflow.
   *
   * @param jt   junction tree whose sepset potential and receiving clique
   *             potential are replaced
   * @param from sending clique
   * @param to   receiving clique
   */
  public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
  {
    // Fetch everything up front; the previous sepset potential is still
    // needed after the new one has been stored.
    final Collection separatorVars = jt.getSepset (from, to);
    final Factor senderPotential = jt.getCPF (from);
    final Factor receiverPotential = jt.getCPF (to);
    final Factor previousSepsetPot = jt.getSepsetPot (from, to);

    // New message: sender potential marginalized onto the separator.
    final Factor message = senderPotential.marginalize (separatorVars);
    message.normalize ();
    jt.setSepsetPot (message, from, to);

    // Receiver update: multiply in the new message, divide out the old one.
    final Factor refreshed = receiverPotential.multiply (message);
    refreshed.divideBy (previousSepsetPot);
    refreshed.normalize ();
    jt.setCPF (to, refreshed);
  }

  /**
   * Marginalizes the clique potential onto <tt>varSet</tt>.
   *
   * @param cpf    clique potential
   * @param varSet variables to keep
   * @return the result of <tt>cpf.marginalize (varSet)</tt>
   */
  public Factor extractBelief (Factor cpf, VarSet varSet)
  {
    final Factor belief = cpf.marginalize (varSet);
    return belief;
  }

  /** Writes the default fields followed by the serial format version. */
  private void writeObject (ObjectOutputStream stream) throws IOException
  {
    stream.defaultWriteObject ();
    stream.writeInt (CURRENT_SERIAL_VERSION);
  }

  /** Reads the default fields, then consumes the serial format version. */
  private void readObject (ObjectInputStream stream) throws IOException, ClassNotFoundException
  {
    stream.defaultReadObject ();
    // The version is read only to advance the stream; its value is unused.
    stream.readInt ();
  }
}
/**
 * Message strategy for max-product propagation: messages and beliefs are
 * obtained with <tt>extractMax</tt> instead of marginalization.
 */
public static class MaxProductMessageStrategy implements MessageStrategy, Serializable {

  // Serialization
  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  /**
   * This sends a max-product message.
   *
   * @param jt   junction tree whose sepset potential and receiving clique
   *             potential are replaced
   * @param from sending clique
   * @param to   receiving clique
   */
  public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
  {
    // Fetch everything up front; the previous sepset potential is still
    // needed after the new one has been stored.
    final Collection separatorVars = jt.getSepset (from, to);
    final Factor senderPotential = jt.getCPF (from);
    final Factor receiverPotential = jt.getCPF (to);
    final Factor previousSepsetPot = jt.getSepsetPot (from, to);

    // New message: sender potential maximized onto the separator.
    final Factor message = senderPotential.extractMax (separatorVars);
    message.normalize ();
    jt.setSepsetPot (message, from, to);

    // Receiver update: multiply in the new message, divide out the old one.
    final Factor refreshed = receiverPotential.multiply (message);
    refreshed.divideBy (previousSepsetPot);
    refreshed.normalize ();
    jt.setCPF (to, refreshed);
  }

  /**
   * Applies <tt>extractMax</tt> to the clique potential for <tt>varSet</tt>.
   *
   * @param cpf    clique potential
   * @param varSet variables to keep
   * @return the result of <tt>cpf.extractMax (varSet)</tt>
   */
  public Factor extractBelief (Factor cpf, VarSet varSet)
  {
    final Factor belief = cpf.extractMax (varSet);
    return belief;
  }

  /** Writes the default fields followed by the serial format version. */
  private void writeObject (ObjectOutputStream stream) throws IOException
  {
    stream.defaultWriteObject ();
    stream.writeInt (CURRENT_SERIAL_VERSION);
  }

  /** Reads the default fields, then consumes the serial format version. */
  private void readObject (ObjectInputStream stream) throws IOException, ClassNotFoundException
  {
    stream.defaultReadObject ();
    // The version is read only to advance the stream; its value is unused.
    stream.readInt ();
  }
}
// Serialization support for the enclosing class
private static final long serialVersionUID = 1;
private static final int CUURENT_SERIAL_VERSION = 1;

/**
 * Writes the default serializable fields, then appends the serial format
 * version as an int.
 *
 * @param stream stream being written to
 * @throws IOException if the underlying stream fails
 */
private void writeObject (ObjectOutputStream stream) throws IOException
{
  stream.defaultWriteObject ();
  stream.writeInt (CUURENT_SERIAL_VERSION);
}

/**
 * Reads the default serializable fields, then consumes the serial format
 * version written by <tt>writeObject</tt>.
 *
 * @param stream stream being read from
 * @throws IOException if the underlying stream fails
 * @throws ClassNotFoundException if a serialized class cannot be resolved
 */
private void readObject (ObjectInputStream stream) throws IOException, ClassNotFoundException
{
  stream.defaultReadObject ();
  // The version is read only to advance the stream; its value is unused.
  stream.readInt ();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy