All downloads are FREE. The search and download functionality uses the official Maven repository.

cc.mallet.grmm.inference.JunctionTreePropagation Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://mallet.cs.umass.edu/
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;


import java.util.Collection;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.io.Serializable;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;

/**
 * An implementation of Hugin-style propagation for junction trees.
 * This destructively modifies the junction tree so that its clique potentials
 * are the true marginals of the underlying graph.
 *
 * <p>
 * End users will not usually need to use this class directly. Use
 * JunctionTreeInferencer instead.
 *
 * <p>
 * This class is not an instance of Inferencer because it destructively
 * modifies the junction tree, which the Inferencer methods do not do to
 * factor graphs.
 *
* Created: Feb 1, 2006 * * @author " + child); for (Iterator it = jt.getChildren (child).iterator (); it.hasNext ();) { VarSet gchild = (VarSet) it.next (); collectEvidence (jt, child, gchild); } if (parent != null) { totalMessagesSent++; strategy.sendMessage (jt, child, parent); } } // top-down pass private void distributeEvidence (JunctionTree jt, VarSet parent) { for (Iterator it = jt.getChildren (parent).iterator (); it.hasNext ();) { VarSet child = (VarSet) it.next (); totalMessagesSent++; strategy.sendMessage (jt, parent, child); distributeEvidence (jt, child); } } private void propagate (JunctionTree jt) { VarSet root = (VarSet) jt.getRoot (); collectEvidence (jt, null, root); distributeEvidence (jt, root); } public Factor lookupMarginal (JunctionTree jt, VarSet varSet) { if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); } VarSet parent = jt.findParentCluster (varSet); if (parent == null) { throw new UnsupportedOperationException ("No parent cluster in " + jt + " for clique " + varSet); } Factor cpf = jt.getCPF (parent); if (logger.isLoggable (Level.FINER)) { logger.finer ("Lookup jt marginal: clique " + varSet + " cluster " + parent); logger.finest (" cpf " + cpf); } Factor marginal = strategy.extractBelief (cpf, varSet); marginal.normalize (); return marginal; } public Factor lookupMarginal (JunctionTree jt, Variable var) { if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); } VarSet parent = jt.findParentCluster (var); Factor cpf = jt.getCPF (parent); if (logger.isLoggable (Level.FINER)) { logger.finer ("Lookup jt marginal: var " + var + " cluster " + parent); logger.finest (" cpf " + cpf); } Factor marginal = strategy.extractBelief (cpf, new HashVarSet (new Variable[] { var })); marginal.normalize (); return marginal; } /////////////////////////////////////////////////////////////////////////// // MEESAGE STRATEGIES 
///////////////////////////////////////////////////////////////////////////

  /**
   * Implements a strategy pattern for message sending.  This allows sum-product
   * and max-product messages, e.g., to be different implementations of this strategy.
   */
  public interface MessageStrategy {

    /**
     * Sends a message from the clique FROM to TO in a junction tree.
     */
    public void sendMessage (JunctionTree jt, VarSet from, VarSet to);

    /**
     * Extracts the belief over VARSET from the clique potential CPF.
     */
    public Factor extractBelief (Factor cpf, VarSet varSet);
  }

  /**
   * Absorbs an already-computed separator message into the clique TO.
   * <p>
   * LAMBDA is normalized (to avoid underflow) and stored as the new potential of
   * the FROM--TO separator.  TO's potential is then replaced by
   * (TO * LAMBDA / previous separator potential), normalized.  The two message
   * strategies below share this update; they differ only in how LAMBDA is computed.
   *
   * @param jt     junction tree being propagated (modified in place)
   * @param from   clique sending the message
   * @param to     clique receiving the message
   * @param lambda message over the FROM--TO separator (normalized in place)
   */
  private static void absorbMessage (JunctionTree jt, VarSet from, VarSet to, Factor lambda)
  {
    // Read both potentials before the separator potential is overwritten below.
    Factor toCpf = jt.getCPF (to);
    Factor oldSepsetPot = jt.getSepsetPot (from, to);

    lambda.normalize ();
    jt.setSepsetPot (lambda, from, to);

    Factor updated = toCpf.multiply (lambda);
    updated.divideBy (oldSepsetPot);
    updated.normalize ();
    jt.setCPF (to, updated);
  }

  /**
   * Sum-product messages: the message from a clique is its potential
   * marginalized onto the separator, normalized to avoid underflow.
   */
  public static class SumProductMessageStrategy implements MessageStrategy, Serializable {

    /**
     * This sends a sum-product message, normalized to avoid
     * underflow.
     */
    public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
    {
      Collection sepset = jt.getSepset (from, to);
      Factor lambda = jt.getCPF (from).marginalize (sepset);
      absorbMessage (jt, from, to, lambda);
    }

    public Factor extractBelief (Factor cpf, VarSet varSet)
    {
      return cpf.marginalize (varSet);
    }

    // Serialization
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;

    private void writeObject (ObjectOutputStream out) throws IOException
    {
      out.defaultWriteObject ();
      out.writeInt (CURRENT_SERIAL_VERSION);
    }

    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
    {
      in.defaultReadObject ();
      in.readInt (); // version
    }
  }

  /**
   * Max-product messages: the message from a clique is its potential
   * max-reduced onto the separator via {@code extractMax}, then normalized.
   */
  public static class MaxProductMessageStrategy implements MessageStrategy, Serializable {

    /**
     * This sends a max-product message.
     */
    public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
    {
      Collection sepset = jt.getSepset (from, to);
      Factor lambda = jt.getCPF (from).extractMax (sepset);
      absorbMessage (jt, from, to, lambda);
    }

    public Factor extractBelief (Factor cpf, VarSet varSet)
    {
      return cpf.extractMax (varSet);
    }

    // Serialization
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;

    private void writeObject (ObjectOutputStream out) throws IOException
    {
      out.defaultWriteObject ();
      out.writeInt (CURRENT_SERIAL_VERSION);
    }

    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
    {
      in.defaultReadObject ();
      in.readInt (); // version
    }
  }

  // Serialization
  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
    out.writeInt (CURRENT_SERIAL_VERSION);
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
    in.readInt (); // version
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy