All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dk.bayes.infer.epnaivebayes.EPNaiveBayesFactorGraph.scala Maven / Gradle / Ivy

The newest version!
package dk.bayes.infer.epnaivebayes

import com.typesafe.scalalogging.slf4j.LazyLogging

import scala.annotation.tailrec
import dk.bayes.dsl.factor.DoubleFactor
import dk.bayes.dsl.factor.SingleFactor
import dk.bayes.math.numericops._
import dk.bayes.math.gaussian.canonical.DenseCanonicalGaussian
import dk.bayes.math.gaussian.canonical.SparseCanonicalGaussian
import dk.bayes.math.gaussian.canonical.SparseCanonicalGaussian
import dk.bayes.math.gaussian.Gaussian

/**
 * Computes posterior of X for a naive bayes net. Variables: X, Y1|X, Y2|X,...Yn|X
 *
 * It run Expectation Propagation algorithm. http://en.wikipedia.org/wiki/Expectation_propagation
 *
 * @param bn
 * @param paralllelMessagePassing If true then messages between X variable and Y variables are sent in parallel
 *
 * @author Daniel Korzekwa
 */
case class EPNaiveBayesFactorGraph[X](prior: SingleFactor[X], likelihoods: Seq[DoubleFactor[X, _]], paralllelMessagePassing: Boolean = false)(implicit val multOp: multOp[X], val divideOp: divideOp[X], val isIdentical: isIdentical[X]) extends LazyLogging {

  private var msgsUp: Seq[X] = likelihoods.map(l => l.initFactorMsgUp)
  private var posterior = multOp(prior.factorMsgDown, multOp(msgsUp: _*))

  def getPosterior(): X = posterior

  def calibrate(maxIter: Int = 100, threshold: Double = 1e-6) {

    @tailrec
    def calibrateIter(currPosterior: X, iterNum: Int) {
      if (iterNum >= maxIter) {      
        logger.warn(s"Factor graph did not converge in less than ${maxIter} iterations. Prior=%s, Posterior=%s".format(prior, posterior))
        return
      }
      if (paralllelMessagePassing) sendMsgsParallel() else sendMsgsSerial()

      if (isIdentical(posterior, currPosterior, threshold)) return
      else calibrateIter(posterior, iterNum + 1)
    }

    calibrateIter(posterior, 1)
  }

  private def sendMsgsParallel() {

    msgsUp = msgsUp.zip(likelihoods).map {
      case (currMsgUp, llh) =>

        val newMsgUp = llh.calcYFactorMsgUp(posterior, currMsgUp) match {
          case Some(msg) => msg
          case None      => currMsgUp
        }

        newMsgUp
    }

    posterior = multOp(prior.factorMsgDown, multOp(msgsUp: _*))
  }

  private def sendMsgsSerial() {

    msgsUp = msgsUp.zip(likelihoods).map {
      case (currMsgUp, llh) =>

        val newMsgUp = llh.calcYFactorMsgUp(posterior, currMsgUp) match {
          case Some(msg) => {
            val cavity = divideOp(posterior, currMsgUp)
            posterior = multOp(cavity, msg)
            msg
          }
          case None => currMsgUp
        }

        newMsgUp
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy