
cc.mallet.grmm.inference.TRP Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mallet Show documentation
Show all versions of mallet Show documentation
MALLET is a Java-based package for statistical natural language processing,
document classification, clustering, topic modeling, information extraction,
and other machine learning applications to text.
The newest version!
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;
import gnu.trove.THashSet;
import gnu.trove.THashMap;
import gnu.trove.TIntObjectHashMap;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.util.*;
import java.io.*;
import org._3pq.jgrapht.UndirectedGraph;
import org._3pq.jgrapht.Graph;
import org._3pq.jgrapht.Edge;
import org._3pq.jgrapht.traverse.BreadthFirstIterator;
import org._3pq.jgrapht.graph.SimpleGraph;
import org.jdom.Document;
import org.jdom.JDOMException;
import org.jdom.Element;
import org.jdom.input.SAXBuilder;
import cc.mallet.grmm.types.*;
import cc.mallet.util.MalletLogger;
/**
* Implementation of Wainwright's TRP schedule for loopy BP
* in general graphical models.
*
* @author Charles Sutton
* @version $Id: TRP.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
public class TRP extends AbstractBeliefPropagation {
// Logger shared by TRP and its nested terminator classes.
private static Logger logger = MalletLogger.getLogger (TRP.class.getName ());
// Debug switch: when true, AlmostRandomTreeFactory prints every tree it builds to stdout.
private static final boolean reportSpanningTrees = false;
// Strategy that supplies the tree for each iteration; initForGraph installs a default when null.
private TreeFactory factory;
// Decides when computeMarginals stops iterating; initForGraph installs a default when null.
private TerminationCondition terminator;
// Randomness source used by AlmostRandomTreeFactory when shuffling factors.
private Random random = new Random ();
/* Make sure that we've included all edges before we terminate. */
// Maps a factor's index in the current graph to an Integer count of how many trees included it.
transient private TIntObjectHashMap factorTouched;
// Flag reported by isConverged(); reset to false by initForGraph.
transient private boolean hasConverged;
// When non-null, dumpForIter writes per-iteration messages, beliefs and trees into this directory.
transient private File verboseOutputDirectory = null;
/**
* Creates a TRP inferencer with no tree factory and no termination
* condition set; initForGraph installs defaults for both.
*/
public TRP ()
{
this (null, null);
}
/**
* Creates a TRP inferencer that draws its trees from the given factory.
* The termination condition is left unset and defaulted by initForGraph.
*
* @param f tree-generation strategy; may be null
*/
public TRP (TreeFactory f)
{
this (f, null);
}
/**
* Creates a TRP inferencer with the given termination condition.
* The tree factory is left unset and defaulted by initForGraph.
*
* @param cond termination condition; may be null
*/
public TRP (TerminationCondition cond)
{
this (null, cond);
}
/**
* Creates a TRP inferencer with an explicit tree factory and termination
* condition. Either argument may be null, in which case initForGraph
* substitutes a default the first time inference is run.
*
* @param f tree-generation strategy
* @param cond termination condition
*/
public TRP (TreeFactory f, TerminationCondition cond)
{
factory = f;
terminator = cond;
}
/**
 * Builds a TRP inferencer whose messager is a new
 * MaxProductMessageStrategy. Tree factory and terminator are left unset.
 *
 * @return the newly configured inferencer
 */
public static TRP createForMaxProduct ()
{
  TRP maxProduct = new TRP ();
  MaxProductMessageStrategy strategy = new MaxProductMessageStrategy ();
  maxProduct.setMessager (strategy);
  return maxProduct;
}
// Accessors
/**
 * Replaces the termination condition consulted by computeMarginals.
 *
 * @param cond the new condition; when null, initForGraph installs a
 *             DefaultConvergenceTerminator on the next run
 * @return this inferencer, so calls can be chained
 */
public TRP setTerminator (TerminationCondition cond)
{
  this.terminator = cond;
  return this;
}
/**
* Replaces the strategy used to generate the tree for each iteration.
*
* @param factory the new tree factory; when null, initForGraph installs
*                an AlmostRandomTreeFactory on the next run
* @return this inferencer, so calls can be chained
*/
public TRP setFactory (TreeFactory factory)
{
this.factory = factory;
return this;
}
// xxx should this be static?
/**
 * Replaces this inferencer's random generator with a new one created
 * from the given seed.
 *
 * @param seed seed passed to the Random constructor
 */
public void setRandomSeed (long seed)
{
  this.random = new Random (seed);
}
/**
* Sets the directory into which dumpForIter writes per-iteration debug
* files. Passing null disables the dumps.
*
* @param verboseOutputDirectory target directory, or null
*/
public void setVerboseOutputDirectory (File verboseOutputDirectory)
{
this.verboseOutputDirectory = verboseOutputDirectory;
}
/** Returns the current value of this inferencer's hasConverged flag, which initForGraph resets to false. */
public boolean isConverged () { return hasConverged; }
/**
 * Prepares this inferencer for a run over the given graph: delegates to
 * the superclass, clears the per-factor touch counts and the convergence
 * flag, and makes sure a tree factory and a (freshly reset) termination
 * condition are in place.
 *
 * @param m the factor graph about to be processed
 */
protected void initForGraph (FactorGraph m)
{
  super.initForGraph (m);

  factorTouched = new TIntObjectHashMap (m.numVariables ());
  hasConverged = false;

  if (factory == null) {
    factory = new AlmostRandomTreeFactory ();
  }

  if (terminator != null) {
    // A user-supplied terminator may carry state from a previous run.
    terminator.reset ();
  } else {
    terminator = new DefaultConvergenceTerminator ();
  }
}
/**
 * Converts an acyclic graph into a rooted Tree. An arbitrary vertex (the
 * first returned by the vertex set's iterator) becomes the root; the
 * remaining vertices are attached while walking the graph breadth-first
 * from that root.
 *
 * @param g graph to convert; must contain at least one vertex
 * @return the tree rooted at the chosen vertex
 * @throws RuntimeException if the graph has no vertices
 */
private static cc.mallet.grmm.types.Tree graphToTree (Graph g) throws Exception
{
  // Perhaps handle gracefully?? -cas
  Set vertices = g.vertexSet ();
  if (vertices.size () <= 0) {
    throw new RuntimeException ("Empty graph.");
  }

  Tree tree = new cc.mallet.grmm.types.Tree ();
  Object root = vertices.iterator ().next ();
  tree.add (root);

  Iterator bfs = new BreadthFirstIterator (g, root);
  while (bfs.hasNext ()) {
    Object current = bfs.next ();
    Iterator incident = g.edgesOf (current).iterator ();
    while (incident.hasNext ()) {
      Edge edge = (Edge) incident.next ();
      Object neighbor = edge.oppositeVertex (current);
      // The edge back to our own parent is already in the tree; skip it.
      if (tree.getParent (current) == neighbor) {
        continue;
      }
      tree.addNode (current, neighbor);
      assert tree.getParent (neighbor) == current;
    }
  }
  return tree;
}
/**
* Interface for tree-generation strategies for TRP.
*
* TRP works by repeatedly doing exact inference over spanning trees
* of the original graph. But the trees can be chosen arbitrarily.
* In fact, they don't need to be spanning trees; any acyclic
* substructure will do. Users of TRP can tell it which strategy
* to use by passing in an implementation of TreeFactory.
*/
public interface TreeFactory extends Serializable {
// Called once per TRP iteration with the graph being processed; returns the tree to propagate over.
public cc.mallet.grmm.types.Tree nextTree (FactorGraph mdl);
}
// This works around what appears to be a bug in OpenJGraph
// connected sets.
/**
 * Minimal disjoint-set structure over arbitrary objects. Every object
 * starts in its own singleton set (created lazily on first lookup); a
 * union copies one set into the other and repoints the moved members.
 */
private static class SimpleUnionFind {

  /** Maps each known object to the Set instance that currently contains it. */
  private Map obj2set = new THashMap ();

  /** Returns the set containing obj, registering a new singleton set if obj is unknown. */
  private Set findSet (Object obj)
  {
    Set owner = (Set) obj2set.get (obj);
    if (owner == null) {
      owner = new THashSet ();
      owner.add (obj);
      obj2set.put (obj, owner);
    }
    return owner;
  }

  /** Merges the set of obj2 into the set of obj1. */
  private void union (Object obj1, Object obj2)
  {
    Set survivor = findSet (obj1);
    Set absorbed = findSet (obj2);
    survivor.addAll (absorbed);
    Iterator members = absorbed.iterator ();
    while (members.hasNext ()) {
      obj2set.put (members.next (), survivor);
    }
  }

  /**
   * Tests whether all variables of varSet lie in pairwise distinct sets.
   *
   * @return false as soon as two variables share a set, true otherwise
   */
  public boolean noPairConnected (VarSet varSet)
  {
    for (int i = 0; i < varSet.size (); i++) {
      for (int j = i + 1; j < varSet.size (); j++) {
        Variable first = varSet.get (i);
        Variable second = varSet.get (j);
        // Identity comparison: members of one component share one Set instance.
        boolean sameComponent = findSet (first) == findSet (second);
        if (sameComponent) {
          return false;
        }
      }
    }
    return true;
  }

  /** Joins the factor and every variable in its var set into one component. */
  public void unionAll (Factor factor)
  {
    VarSet scope = factor.varSet ();
    int idx = 0;
    while (idx < scope.size ()) {
      union (scope.get (idx), factor);
      idx++;
    }
  }
}
/**
 * Always adds edges that have not been touched, after that
 * adds random edges.
 *
 * This is an inner class: it reads the enclosing TRP's random generator
 * and its per-factor touch counts, and it increments those counts for
 * every factor placed in a generated tree.
 */
public class AlmostRandomTreeFactory implements TreeFactory {

  /**
   * Builds a random acyclic subgraph of fullGraph, preferring factors
   * that no earlier tree has used, and returns it as a rooted tree.
   *
   * @param fullGraph the factor graph to draw factors and variables from
   * @return the generated tree
   * @throws RuntimeException if tree construction fails; runtime
   *         exceptions propagate unchanged and checked exceptions are wrapped
   */
  public Tree nextTree (FactorGraph fullGraph)
  {
    try {
      List goodEdges = chooseAcyclicFactors (fullGraph);

      for (Iterator it = goodEdges.iterator (); it.hasNext ();) {
        Factor factor = (Factor) it.next ();
        touchFactor (factor);
      }

      Tree tree = graphToTree (buildGraph (fullGraph, goodEdges));

      if (reportSpanningTrees) {
        System.out.println ("********* SPANNING TREE *************");
        System.out.println (tree.dumpToString ());
        System.out.println ("********* END TREE *************");
      }

      return tree;
    } catch (RuntimeException e) {
      // Already unchecked: rethrow as-is rather than printing and double-wrapping.
      throw e;
    } catch (Exception e) {
      throw new RuntimeException (e);
    }
  }

  /** Shuffles the factors and greedily keeps those that do not close a cycle, untouched factors first. */
  private List chooseAcyclicFactors (FactorGraph fullGraph)
  {
    SimpleUnionFind unionFind = new SimpleUnionFind ();
    ArrayList edges = new ArrayList (fullGraph.factors ());
    ArrayList goodEdges = new ArrayList (fullGraph.numVariables ());
    Collections.shuffle (edges, random);

    // First add all edges that haven't been used so far
    for (Iterator it = edges.iterator (); it.hasNext ();) {
      Factor factor = (Factor) it.next ();
      VarSet varSet = factor.varSet ();
      if (!isFactorTouched (factor) && unionFind.noPairConnected (varSet)) {
        goodEdges.add (factor);
        unionFind.unionAll (factor);
        // Remove so the second pass does not consider this factor again.
        it.remove ();
      }
    }

    // Now add as many other edges as possible
    for (Iterator it = edges.iterator (); it.hasNext ();) {
      Factor factor = (Factor) it.next ();
      VarSet varSet = factor.varSet ();
      if (unionFind.noPairConnected (varSet)) {
        goodEdges.add (factor);
        unionFind.unionAll (factor);
      }
    }

    return goodEdges;
  }

  /** Builds a graph with one vertex per variable and per chosen factor, linking each factor to its variables. */
  private UndirectedGraph buildGraph (FactorGraph fullGraph, List goodEdges)
  {
    UndirectedGraph g = new SimpleGraph ();
    for (Iterator it = fullGraph.variablesIterator (); it.hasNext ();) {
      Variable var = (Variable) it.next ();
      g.addVertex (var);
    }
    for (Iterator it = goodEdges.iterator (); it.hasNext ();) {
      Factor factor = (Factor) it.next ();
      g.addVertex (factor);
      for (Iterator vit = factor.varSet ().iterator (); vit.hasNext ();) {
        Variable var = (Variable) vit.next ();
        g.addEdge (factor, var);
      }
    }
    return g;
  }

  private static final long serialVersionUID = -7461763414516915264L;
}
/**
 * Generates spanning trees cyclically from a predefined collection.
 *
 * Trees can be supplied directly, or parsed from XML documents whose root
 * element has, as its first child, the root node of the tree. Nodes are
 * either VAR elements (attribute NAME) or FACTOR elements (attribute VARS,
 * a whitespace-separated list of variable names); nesting gives the
 * parent/child structure.
 */
static public class TreeListFactory implements TreeFactory {

  private List lst;

  // Iterators are not Serializable, while TreeFactory is; keep the cursor
  // out of the serialized form and rebuild it lazily in nextTree.
  private transient Iterator it;

  /** @param l list of Tree objects to cycle through; the list is not copied */
  public TreeListFactory (List l)
  {
    lst = l;
    it = lst.iterator ();
  }

  /** @param arr trees to cycle through, in array order */
  public TreeListFactory (cc.mallet.grmm.types.Tree[] arr)
  {
    lst = new ArrayList (java.util.Arrays.asList (arr));
    it = lst.iterator ();
  }

  /**
   * Parses one tree from each reader. Each parsed tree is also echoed to
   * standard output (behavior retained from the original implementation).
   *
   * @param fg graph used to resolve variable names and factors
   * @param readerList List of Reader objects, each yielding an XML tree document
   * @throws RuntimeException wrapping any JDOMException or IOException
   */
  public static TreeListFactory makeFromReaders (FactorGraph fg, List readerList)
  {
    List treeList = new ArrayList ();
    for (Iterator it = readerList.iterator (); it.hasNext ();) {
      try {
        Reader reader = (Reader) it.next ();
        Tree tree = treeFromDocument (fg, new SAXBuilder ().build (reader));
        System.out.println (tree.dumpToString ());
        treeList.add (tree);
      } catch (JDOMException e) {
        throw new RuntimeException (e);
      } catch (IOException e) {
        throw new RuntimeException (e);
      }
    }
    return new TreeListFactory (treeList);
  }

  /**
   * @param fg graph used to resolve variable names and factors
   * @param fileList List of File objects. Each file should be an XML document describing a tree.
   * @throws RuntimeException wrapping any JDOMException or IOException
   */
  public static TreeListFactory readFromFiles (FactorGraph fg, List fileList)
  {
    List treeList = new ArrayList ();
    for (Iterator it = fileList.iterator (); it.hasNext ();) {
      try {
        File treeFile = (File) it.next ();
        treeList.add (treeFromDocument (fg, new SAXBuilder ().build (treeFile)));
      } catch (JDOMException e) {
        throw new RuntimeException (e);
      } catch (IOException e) {
        throw new RuntimeException (e);
      }
    }
    return new TreeListFactory (treeList);
  }

  /** Converts a parsed document into a tree, starting at the first child of the document's root element. */
  private static Tree treeFromDocument (FactorGraph fg, Document doc)
  {
    Element treeElt = doc.getRootElement ();
    Element rootElt = (Element) treeElt.getChildren ().get (0);
    return readTreeRec (fg, rootElt);
  }

  /** Recursively builds the subtree rooted at elt, children first. */
  private static Tree readTreeRec (FactorGraph fg, Element elt)
  {
    List subtrees = new ArrayList ();
    for (Iterator it = elt.getChildren ().iterator (); it.hasNext ();) {
      Element child = (Element) it.next ();
      Tree subtree = readTreeRec (fg, child);
      subtrees.add (subtree);
    }

    Object parent = objFromElt (fg, elt);
    return Tree.makeFromSubtree (parent, subtrees);
  }

  /**
   * Resolves an XML node to the graph object it names: a Variable for a
   * VAR element, or the factor over the listed variables for a FACTOR element.
   *
   * @throws RuntimeException if the element is neither VAR nor FACTOR
   */
  private static Object objFromElt (FactorGraph fg, Element elt)
  {
    String type = elt.getName ();
    if (type.equals ("VAR")) {
      String vname = elt.getAttributeValue ("NAME");
      return fg.findVariable (vname);
    } else if (type.equals("FACTOR")) {
      String varSetStr = elt.getAttributeValue ("VARS");
      // Trim first: split() on a string with leading whitespace would
      // otherwise produce an empty first name.
      String[] vnames = varSetStr.trim ().split ("\\s+");
      Variable[] vars = new Variable [vnames.length];
      for (int i = 0; i < vnames.length; i++) {
        vars[i] = fg.findVariable (vnames[i]);
      }
      return fg.factorOf (new HashVarSet (vars));
    } else {
      throw new RuntimeException ("Can't figure out element "+elt);
    }
  }

  /**
   * Returns the next tree from the list, wrapping around to the first
   * tree once the end is reached. The mdl argument is not used.
   */
  public cc.mallet.grmm.types.Tree nextTree (FactorGraph mdl)
  {
    // If no more trees (or no cursor yet, e.g. after deserialization), rewind.
    if (it == null || !it.hasNext ()) {
      it = lst.iterator ();
    }
    return (cc.mallet.grmm.types.Tree) it.next ();
  }

}
// Termination conditions
// will this need to be subclassed from outside? Will such
// subclasses need access to the private state of TRP?
// Contract used by computeMarginals: shouldContinue is evaluated before
// every iteration, and reset is called by initForGraph on a pre-existing
// terminator at the start of a run.
static public interface TerminationCondition extends Cloneable, Serializable {
// This takes the instances of trp as a parameter so that if a
// TRP instance is cloned, and the terminator copied over, it
// will still work.
public boolean shouldContinue (TRP trp);
// Clears any state accumulated during a previous run.
public void reset ();
// boy do I hate Java cloning
// Redeclared public so TRP.clone can copy a terminator through the interface.
public Object clone () throws CloneNotSupportedException;
}
/**
 * Termination condition that permits a fixed number of iterations:
 * shouldContinue returns true for the first max calls after a reset and
 * false from then on.
 */
static public class IterationTerminator implements TerminationCondition {

  /** Number of shouldContinue calls since the last reset. */
  int current;
  /** Maximum number of iterations to allow. */
  int max;

  public void reset () { current = 0; }

  /** @param m maximum number of iterations to allow */
  public IterationTerminator (int m)
  {
    max = m;
    reset ();
  }

  /**
   * Counts this call and reports whether the iteration budget still
   * allows another iteration.
   *
   * @param trp ignored
   * @return true while at most max calls have been made since the last reset
   */
  public boolean shouldContinue (TRP trp)
  {
    current++;
    boolean withinBudget = current <= max;
    // Log only when we really stop; previously the message was also
    // emitted one call early, while true was still being returned.
    if (!withinBudget) {
      logger.finest ("***TRP quitting: Iteration " + current + " >= " + max);
    }
    return withinBudget;
  }

  public Object clone () throws CloneNotSupportedException
  {
    return super.clone ();
  }
}
//xxx Original note: "Delta is currently ignored." This class does pass
// delta to TRP's hasConverged(double); whether that method honors it is
// not visible from here.
/**
 * Termination condition that stops once the inferencer reports
 * convergence for the configured threshold.
 */
public static class ConvergenceTerminator implements TerminationCondition {

  /** Threshold handed to the inferencer's convergence check. */
  double delta = 0.01;

  public ConvergenceTerminator () {}

  /** @param delta threshold handed to the inferencer's convergence check */
  public ConvergenceTerminator (double delta) { this.delta = delta; }

  /** No state to clear. */
  public void reset ()
  {
  }

  /**
   * Asks the inferencer whether it has converged, records the answer so
   * that TRP.isConverged() reflects it, snapshots the current messages
   * for the next comparison, and continues only if not yet converged.
   *
   * @param trp the inferencer being run
   * @return true if inference should continue
   */
  public boolean shouldContinue (TRP trp)
  {
    boolean converged = trp.hasConverged (delta);
    // The private flag was otherwise never set to true, which left
    // isConverged() permanently false.
    trp.hasConverged = converged;
    trp.copyOldMessages ();
    return !converged;
  }

  public Object clone () throws CloneNotSupportedException
  {
    return super.clone ();
  }

}
/**
 * Default stopping rule: keep iterating until every factor has appeared
 * in at least one tree and the convergence check passes, but never beyond
 * a maximum iteration count (1000 with the no-argument constructor).
 */
public static class DefaultConvergenceTerminator implements TerminationCondition {

  ConvergenceTerminator cterminator;
  IterationTerminator iterminator;
  /** Warning text logged when the iteration cap is hit. */
  String msg;

  /** Uses a convergence threshold of 0.001 and a cap of 1000 iterations. */
  public DefaultConvergenceTerminator () { this (0.001, 1000); }

  /**
   * @param delta threshold for the convergence check
   * @param maxIter hard cap on the number of iterations
   */
  public DefaultConvergenceTerminator (double delta, int maxIter)
  {
    this.cterminator = new ConvergenceTerminator (delta);
    this.iterminator = new IterationTerminator (maxIter);
    this.msg = "***TRP quitting: over " + maxIter + " iterations";
  }

  /** Resets both underlying terminators. */
  public void reset ()
  {
    iterminator.reset ();
    cterminator.reset ();
  }

  /**
   * Stops when the iteration cap is exceeded (with a warning, plus a
   * second warning if some factor was never used); otherwise continues
   * while any factor is still untouched, and defers to the convergence
   * check once all have been used.
   */
  public boolean shouldContinue (TRP trp)
  {
    boolean untouchedRemain = !trp.allEdgesTouched ();

    boolean withinBudget = iterminator.shouldContinue (trp);
    if (withinBudget) {
      // Short-circuit: the convergence check only runs once all edges are used.
      return untouchedRemain || cterminator.shouldContinue (trp);
    }

    logger.warning (msg);
    if (untouchedRemain) {
      logger.warning ("***TRP warning: Not all edges used!");
    }
    return false;
  }

  /** Copies this terminator together with both underlying terminators. */
  public Object clone () throws CloneNotSupportedException
  {
    DefaultConvergenceTerminator copy = (DefaultConvergenceTerminator) super.clone ();
    copy.iterminator = (IterationTerminator) iterminator.clone ();
    copy.cterminator = (ConvergenceTerminator) cterminator.clone ();
    return copy;
  }

}
// And now, the heart of TRP:
/**
 * Runs TRP on the given graph: repeatedly obtains a tree from the
 * factory and propagates messages over it until the termination
 * condition says to stop. The iteration count is stored in iterUsed.
 *
 * @param m the factor graph to run inference on
 */
public void computeMarginals (FactorGraph m)
{
  resetMessagesSentAtStart ();
  initForGraph (m);

  int iter = 0;
  while (terminator.shouldContinue (this)) {
    // The log line carries the zero-based index; dump files use the incremented count.
    logger.finer ("TRP iteration " + iter);
    iter++;

    cc.mallet.grmm.types.Tree currentTree = factory.nextTree (m);
    propagate (currentTree);
    dumpForIter (iter, currentTree);
  }

  iterUsed = iter;
  logger.info ("TRP used " + iter + " iterations.");

  doneWithGraph (m);
}
/**
 * Writes debug output for one iteration into the verbose output
 * directory, if one is set: the current messages (iterN.txt), the
 * current beliefs (beliefsN.txt) and the tree just used (treeN.txt).
 * I/O failures are logged and otherwise ignored, so that debugging
 * output can never abort inference.
 *
 * @param iter iteration number used in the file names
 * @param tree the tree propagated over in this iteration
 */
private void dumpForIter (int iter, Tree tree)
{
  if (verboseOutputDirectory == null) {
    return;
  }

  try {
    // output messages
    FileWriter writer = new FileWriter (new File (verboseOutputDirectory, "iter" + iter + ".txt"));
    try {
      dump (new PrintWriter (writer, true));
    } finally {
      // Close even if dumping throws, so the file handle is not leaked.
      writer.close ();
    }

    FileWriter bfWriter = new FileWriter (new File (verboseOutputDirectory, "beliefs" + iter + ".txt"));
    try {
      dumpBeliefs (new PrintWriter (bfWriter, true));
    } finally {
      bfWriter.close ();
    }

    // output spanning tree
    FileWriter treeWriter = new FileWriter (new File (verboseOutputDirectory, "tree" + iter + ".txt"));
    try {
      treeWriter.write (tree.toString ());
      treeWriter.write ("\n");
    } finally {
      treeWriter.close ();
    }
  } catch (IOException e) {
    logger.log (Level.WARNING, "TRP could not write verbose output for iteration " + iter, e);
  }
}
/**
 * Prints the marginal of every variable in the current graph, in index
 * order, each followed by a blank line.
 *
 * @param writer destination for the dump
 */
private void dumpBeliefs (PrintWriter writer)
{
  int idx = 0;
  while (idx < mdlCurrent.numVariables ()) {
    Variable variable = mdlCurrent.get (idx);
    Factor marginal = lookupMarginal (variable);
    String text = marginal.dumpToString ();
    writer.println (text);
    writer.println ();
    idx++;
  }
}
/**
* Performs one two-pass sweep over the given tree: first messages from
* children toward the root (lambdaPropagation), then messages from the
* root toward the children (piPropagation).
*/
private void propagate (cc.mallet.grmm.types.Tree tree)
{
Object root = tree.getRoot ();
lambdaPropagation (tree, root);
piPropagation (tree, root);
}
/**
 * Sends BP messages starting from children to parents. This version is
 * iterative: it walks the tree breadth-first with an explicit work list
 * instead of recursing.
 *
 * @param tree tree to propagate over
 * @param root root of the tree; it sends no message in this pass
 */
private void lambdaPropagation (cc.mallet.grmm.types.Tree tree, Object root)
{
  LinkedList pending = new LinkedList (tree.getChildren (root));
  LinkedList sendOrder = new LinkedList ();

  // Breadth-first walk; pushing at the front reverses the visit order,
  // so every node ends up ahead of its own parent.
  while (!pending.isEmpty ()) {
    Object node = pending.removeFirst ();
    pending.addAll (tree.getChildren (node));
    sendOrder.addFirst (node);
  }

  // sendOrder now holds every node except the root, children before
  // parents. Send the messages.
  Iterator nodes = sendOrder.iterator ();
  while (nodes.hasNext ()) {
    Object child = nodes.next ();
    sendMessage (mdlCurrent, child, tree.getParent (child));
  }
}
/**
 * Sends BP messages starting from parents to children. This version is
 * iterative: it walks the tree breadth-first from the root with an
 * explicit queue instead of recursing.
 *
 * @param tree tree to propagate over
 * @param root node to start from
 */
private void piPropagation (cc.mallet.grmm.types.Tree tree, Object root)
{
  LinkedList queue = new LinkedList ();
  queue.add (root);

  // The queue starts non-empty, so a do/while visits exactly the same nodes.
  do {
    Object sender = queue.removeFirst ();
    Iterator kids = tree.getChildren (sender).iterator ();
    while (kids.hasNext ()) {
      Object receiver = kids.next ();
      sendMessage (mdlCurrent, sender, receiver);
      queue.add (receiver);
    }
  } while (!queue.isEmpty ());
}
/**
 * Dispatches a message between two tree nodes to the typed sendMessage
 * overload. Despite the parameter names, "parent" is simply the sender
 * and "child" the receiver: lambdaPropagation passes a tree child as the
 * sender. If the sender is neither a Factor nor a Variable, nothing is sent.
 *
 * @param fg graph the message belongs to
 * @param parent sending node
 * @param child receiving node
 */
private void sendMessage (FactorGraph fg, Object parent, Object child)
{
  if (logger.isLoggable (Level.FINER)) {
    logger.finer ("Sending message: "+parent+" --> "+child);
  }

  if (parent instanceof Factor) {
    sendMessage (fg, (Factor) parent, (Variable) child);
    return;
  }

  if (parent instanceof Variable) {
    sendMessage (fg, (Variable) parent, (Factor) child);
  }
}
/**
 * Reports whether every factor of the current graph has a non-zero
 * touch count. Logs the index of the first untouched factor found.
 *
 * @return true if no factor has a touch count of zero
 */
private boolean allEdgesTouched ()
{
  for (Iterator factors = mdlCurrent.factorsIterator (); factors.hasNext ();) {
    Factor factor = (Factor) factors.next ();
    int idx = mdlCurrent.getIndex (factor);
    if (getNumTouches (idx) != 0) {
      continue;
    }
    logger.finest ("***TRP continuing: factor " + idx
                   + " not touched.");
    return false;
  }
  return true;
}
/** Increments the touch count of the given factor, keyed by its index in the current graph. */
private void touchFactor (Factor factor)
{
  incrementTouches (mdlCurrent.getIndex (factor));
}
/** Returns true if the given factor's touch count is greater than zero. */
private boolean isFactorTouched (Factor factor)
{
  int factorIdx = mdlCurrent.getIndex (factor);
  int touches = getNumTouches (factorIdx);
  return touches > 0;
}
/**
 * Looks up the touch count stored for a factor index.
 *
 * @param idx1 index of the factor in the current graph
 * @return the stored count, or 0 if none is stored
 */
private int getNumTouches (int idx1)
{
  Integer stored = (Integer) factorTouched.get (idx1);
  if (stored == null) {
    return 0;
  }
  return stored.intValue ();
}
/**
 * Raises the touch count stored for a factor index by one.
 *
 * @param idx1 index of the factor in the current graph
 */
private void incrementTouches (int idx1)
{
  int updated = getNumTouches (idx1) + 1;
  factorTouched.put (idx1, new Integer (updated));
}
/**
* Not supported: always throws.
*
* @throws UnsupportedOperationException always
*/
public Factor query (DirectedModel m, Variable var)
{
throw new UnsupportedOperationException
("GRMM doesn't yet do directed models.");
}
//xxx could get moved up to AbstractInferencer, if mdlCurrent did.
/**
 * Builds an assignment over the current graph by picking, for each
 * variable independently, the argmax of its marginal. Every marginal is
 * cast to TableFactor.
 *
 * @return assignment holding the per-variable argmax outcomes
 */
public Assignment bestAssignment ()
{
  int numVars = mdlCurrent.numVariables ();
  int[] outcomes = new int [numVars];

  int vi = 0;
  while (vi < outcomes.length) {
    Variable variable = mdlCurrent.get (vi);
    TableFactor marginal = (TableFactor) lookupMarginal (variable);
    outcomes[vi] = marginal.argmax ();
    vi++;
  }

  return new Assignment (mdlCurrent, outcomes);
}
/**
 * Copies this inferencer. The termination condition is deep-copied so
 * the two instances do not share iteration state. A default
 * AlmostRandomTreeFactory is re-created for the copy, because as an
 * inner class it is bound to the TRP instance that created it: a shared
 * factory would keep reading and updating the original's touch counts
 * and random generator instead of the copy's. Any other factory is
 * shared with the original.
 *
 * @return the copy
 * @throws RuntimeException if the superclass refuses to clone
 */
public Object clone ()
{
  try {
    TRP dup = (TRP) super.clone ();
    if (terminator != null) {
      dup.terminator = (TerminationCondition) terminator.clone ();
    }
    // Exact class match only, so user subclasses are not silently replaced.
    if (factory != null && factory.getClass () == AlmostRandomTreeFactory.class) {
      dup.factory = dup.new AlmostRandomTreeFactory ();
    }
    return dup;
  } catch (CloneNotSupportedException e) {
    // should never happen
    throw new RuntimeException (e);
  }
}
// Serialization
// Explicit stream version identifier for TRP's serialized form.
private static final long serialVersionUID = 1;
// If serialization-incompatible changes are made to these classes,
// then smarts can be added to these methods for backward compatibility.
/**
* Serialization hook; currently just writes the default field set
* (transient fields are skipped).
*/
private void writeObject (ObjectOutputStream out) throws IOException
{
out.defaultWriteObject ();
}
/**
* Deserialization hook; currently just reads the default field set.
* Transient fields are left at their Java default values.
*/
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
{
in.defaultReadObject ();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy