All downloads are FREE. Search and download functionality uses the official Maven repository.

dk.bayes.infer.ep.GenericEP.scala Maven / Gradle / Ivy

The newest version!
package dk.bayes.infer.ep

import org.slf4j.LoggerFactory

import com.typesafe.scalalogging.slf4j.Logger

import dk.bayes.model.factor.api.Factor
import dk.bayes.model.factorgraph.FactorGraph
import dk.bayes.model.factorgraph.FactorNode

/**
 * Default implementation of the Expectation Propagation Bayesian Inference algorithm.
 *
 * @author Daniel Korzekwa
 *
 * @param factorGraph Factor graph on which evidence is set and from which marginals are read
 * @param threshold Calibration criterion: the maximum absolute difference between old and new corresponding messages on a factor graph.
 */
case class GenericEP(factorGraph: FactorGraph, threshold: Double = 0.00001) extends EP {

  // NOTE(review): not referenced anywhere in this class body; kept so the file's imports stay in use.
  private val logger = Logger(LoggerFactory.getLogger(getClass()))

  /**
   * Applies evidence for a single variable to the factor graph.
   *
   * Every factor node whose factor references `varId` has its factor replaced by
   * `factor.withEvidence(varId, varValue)`. Nodes that are not factor nodes, and factor
   * nodes whose factor does not reference `varId`, are left untouched.
   *
   * @param varId Identifier of the observed variable
   * @param varValue Observed value of the variable
   */
  def setEvidence(varId: Int, varValue: AnyVal): Unit = {
    for (node <- factorGraph.getNodes()) {
      node match {
        case factorNode: FactorNode if factorNode.getFactor().getVariableIds.contains(varId) =>
          factorNode.setFactor(factorNode.getFactor().withEvidence(varId, varValue))
        case _ => // not a factor node, or its factor does not involve varId
      }
    }
  }

  /**
   * Returns a marginal from the factor graph.
   *
   * With a single variable id, the result is the product of all messages arriving at that
   * variable node, one per gate (read from each gate's end gate).
   * With additional ids, the result is `factorMarginal()` of the factor node looked up by the
   * full id list (`variableId` first, followed by `variablesIds` in the given order).
   *
   * @param variableId Identifier of the (first) variable
   * @param variablesIds Identifiers of further variables; empty for a single-variable marginal
   * @return The marginal factor
   * @throws UnsupportedOperationException if a single-variable marginal is requested for a
   *         variable node that has no gates, i.e. there are no incoming messages to combine
   */
  def marginal(variableId: Int, variablesIds: Int*): Factor = {
    if (variablesIds.isEmpty) {
      val varNode = factorGraph.getVariableNode(variableId)
      val inMsgs = varNode.getGates().map(g => g.getEndGate.getMessage())

      // Fail with a descriptive message instead of the opaque "empty.reduceLeft",
      // keeping the exception type the original reduceLeft would have thrown.
      inMsgs.reduceLeftOption((msg1, msg2) => msg1 * msg2).getOrElse(
        throw new UnsupportedOperationException(
          s"Cannot compute marginal of variable $variableId: its variable node has no gates, so there are no incoming messages"))
    } else {
      val allVarIds = variableId :: variablesIds.toList
      factorGraph.getFactorNode(allVarIds).factorMarginal()
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy