All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cc.mallet.grmm.inference.ResidualBP Maven / Gradle / Ivy

Go to download

MALLET is a Java-based package for statistical natural language processing, document classification, clustering, topic modeling, information extraction, and other machine learning applications to text.

The newest version!
/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;


import java.util.Random;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.Variable;

/**
 * A dynamic BP schedule where 
 * The loopy belief propagation algorithm for approximate inference in
 *  general graphical models.
 *
 * Created: Wed Nov  5 19:30:15 2003
 *
 * @author Charles Sutton
 * @version $Id: ResidualBP.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
 */
public class ResidualBP extends AbstractBeliefPropagation {

  public static final int DEFAULT_MAX_ITER = 1000;

  transient private int iterUsed;

  private int maxIter;

  private Random rand = new Random ();

  public int iterationsUsed () { return iterUsed; }

  public void setUseCaching (boolean useCaching) { this.useCaching = useCaching; }

  // Note this does not have the sophisticated Terminator interface
  // that we've got in TRP.
  public ResidualBP () { this (new SumProductMessageStrategy (), ResidualBP.DEFAULT_MAX_ITER); }

  public ResidualBP (int maxIter) { this (new SumProductMessageStrategy (), maxIter); }

  public ResidualBP (MessageStrategy messager, int maxIter)
  {
    super (messager);
    this.maxIter = maxIter;
  }

  public static Inferencer createForMaxProduct ()
  {
    return new ResidualBP (new MaxProductMessageStrategy (), ResidualBP.DEFAULT_MAX_ITER);
  }

  public ResidualBP setRand (Random rand)
  {
    this.rand = rand;
    return this;
  }

  public void computeMarginals (FactorGraph mdl)
  {
    super.initForGraph (mdl);

    int iter;
    for (iter = 0; iter < maxIter; iter++) {
      logger.finer ("***AsyncLoopyBP iteration "+iter);
      propagate (mdl);
      if (hasConverged ()) break;
      copyOldMessages ();
    }
    iterUsed = iter;
    if (iter >= maxIter) {
      logger.info ("***Loopy BP quitting: not converged after "+maxIter+" iterations.");
    } else {
      iterUsed++;  // there's an off-by-one b/c of location of above break
      logger.info ("***AsyncLoopyBP converged: "+iterUsed+" iterations");
    }

    doneWithGraph (mdl);
  }

  private void propagate (FactorGraph mdl)
  {
    // Send all messages in random order.
    ArrayList factors = new ArrayList (mdl.factors());
    Collections.shuffle (factors, rand);
    for (Iterator it = factors.iterator(); it.hasNext();) {
      Factor factor = (Factor) it.next();
      for (Iterator vit = factor.varSet ().iterator (); vit.hasNext ();) {
        Variable from = (Variable) vit.next ();
        sendMessage (mdl, from, factor);
      }
    }

    for (Iterator it = factors.iterator(); it.hasNext();) {
      Factor factor = (Factor) it.next();
      for (Iterator vit = factor.varSet ().iterator (); vit.hasNext ();) {
        Variable to = (Variable) vit.next ();
        sendMessage (mdl, factor, to);
      }
    }
  }

  // Serialization
  private static final long serialVersionUID = 1;

  // If seralization-incompatible changes are made to these classes,
  //  then smarts can be added to these methods for backward compatibility.
  private void writeObject (ObjectOutputStream out) throws IOException {
     out.defaultWriteObject ();
   }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
     in.defaultReadObject ();
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy