All downloads are free. Search and download functionality uses the official Maven repository.
Please wait; this can take a few minutes...
Downloading a project requires significant resources. Please understand that we need to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
breeze.inference.ExpectationPropagation.scala Maven / Gradle / Ivy
package breeze.inference
import breeze.linalg._
import breeze.numerics._
import breeze.stats.distributions.{Dirichlet, Bernoulli, Gaussian}
/**
*
* @author dlwh
*/
class ExpectationPropagation[F,Q](project: (Q,F)=>(Q,Double), criterion: Double = 1E-4)(implicit qFactor: Q <:
val State(f_~, q, _, partitions) = state
val fi = f(i)
val fi_~ = f_~(i)
val q_\ = q / fi_~
val (new_q, new_partition) = project(q_\ , fi)
val newF_~ = f_~.updated(i,new_q / q_\)
State(newF_~, new_q, prior, partitions.updated(i, new_partition))
}
val hasNext = (cur.q eq lastQ) || !next.q.isConvergedTo(cur.q, criterion)
consumed = !hasNext
cur = next
hasNext
}
def next() = {
if(consumed) hasNext
consumed = true
cur
}
}
it
}
}
/**
 * Runnable demo of [[ExpectationPropagation]].
 *
 * Draws 5000 samples from a two-component mixture: a Bernoulli(`prop`) indicator
 * selects between Gaussian(0, 3) and Gaussian(`mean`, 3). It then runs EP with a
 * Dirichlet-style [[ApproxTerm]] over the two mixing proportions, printing the
 * log-partition and current approximation for up to 20 states.
 *
 * NOTE(review): this object extends `App`, so its vals (`prop`, `mean`, `data`, ...)
 * are only initialised when `main` runs (DelayedInit). Calling members such as
 * `likelihood` from elsewhere before that would observe uninitialised fields.
 */
object ExpectationPropagation extends App {
  /** Probability that the Bernoulli indicator is true, i.e. that a sample uses mean `mean`. */
  val prop = 0.9
  /** Mean of the second mixture component; the first component has mean 0. */
  val mean = 2

  /** Generative model: a ~ Bernoulli(prop), x ~ Gaussian(I(a) * mean, 3). */
  val gen = for {
    a <- new Bernoulli(prop)
    x <- Gaussian(I(a) * mean, 3)
  } yield x

  val data = gen.sample(5000)

  /**
   * Approximate factor made of a log-scale offset `s` and a 2-vector `b`.
   * Multiplying / dividing two terms adds / subtracts both parts componentwise.
   */
  case class ApproxTerm(s: Double = 0.0, b: DenseVector[Double] = DenseVector.zeros(2)) extends Factor[ApproxTerm] { f1 =>
    /** `s + lbeta(b)`. */
    def logPartition = s + breeze.numerics.lbeta(b)

    /** Adds `f` to the offset `s`; `b` is unchanged. */
    def *(f: Double) = copy(s = s + f)

    def *(f2: ApproxTerm) = {
      ApproxTerm(f1.s + f2.s, f1.b + f2.b)
    }

    def /(f2: ApproxTerm) = {
      ApproxTerm(f1.s - f2.s, f1.b - f2.b)
    }

    /** Placeholder: always returns 0.0. */
    def apply(a: Double) = {
      0.0 // TODO
    }

    /** True when `b` moved by at most `diff` in L2 norm; the offset `s` is not compared. */
    def isConvergedTo(f: ApproxTerm, diff: Double) = {
      (b - f.b).norm(2) <= diff
    }
  }

  /**
   * Per-component likelihood of `x`, in this order:
   * density under Gaussian(0, 3), then density under Gaussian(mean, 3).
   */
  def likelihood(x: Double): DenseVector[Double] = {
    DenseVector(Gaussian(0, 3).pdf(x), Gaussian(mean, 3).pdf(x))
  }

  /**
   * Iteratively seeks a vector `a` with digamma(a_k) - digamma(sum(a)) = target_k,
   * starting from a copy of `old`. The input vector `old` is not modified.
   *
   * Each outer iteration freezes digamma(sum(a)). It then takes Newton steps on the
   * componentwise equation digamma(a_k) = target_k + digamma(sum(a)). The derivative
   * is a forward difference with step `h`.
   *
   * @param old         starting point
   * @param target      desired values of digamma(a_k) - digamma(sum(a))
   * @param outerIters  number of outer iterations (default 20)
   * @param newtonIters Newton steps per outer iteration (default 5)
   * @param h           forward-difference step for the derivative (default 1E-4)
   * @return the final iterate
   */
  def solve(old: DenseVector[Double], target: DenseVector[Double],
            outerIters: Int = 20, newtonIters: Int = 5, h: Double = 1E-4) = {
    val guess = copy(old)
    for (_ <- 0 until outerIters) {
      // The sum term is held fixed while the componentwise equations are solved.
      val shifted = target + digamma(guess.sum)
      for (_ <- 0 until newtonIters) {
        // NOTE(review): nothing keeps `guess` positive; an overshooting step would
        // drive digamma to NaN on the next iteration.
        guess -= ((digamma(guess) - shifted) :/ ((digamma(guess + h) - digamma(guess)) / h))
      }
    }
    guess
  }

  /**
   * EP projection for a single observation `x` against the cavity term `q`.
   *
   * The target is E[log θ_k] of the tilted distribution proportional to
   * Dir(θ; q.b) · (likelihood(x) · θ), treating q.b as Dirichlet concentrations.
   * `solve` then finds the parameter vector matching that target.
   *
   * @return the new approximate term, paired with the log of
   *         `likelihood(x) dot normalize(q.b, 1)`.
   */
  def project(q: ApproxTerm, x: Double): (ApproxTerm, Double) = {
    val like = likelihood(x)
    val total = q.b.sum
    val weights = normalize(q.b, 1)
    val target = digamma(q.b) - digamma(total) + (like / (like dot q.b)) - 1 / total
    val normalizer: Double = like dot weights
    val mle = solve(q.b, target)
    assert(!normalizer.isNaN, (mle, q.b, like, weights))
    // NOTE(review): q.b is used as concentrations in `target`, but the new term's
    // offset uses lbeta(mle + 1.0). Confirm whether `b` holds alpha or alpha - 1.
    ApproxTerm(-lbeta(mle + 1.0), mle) -> math.log(normalizer)
  }

  val ep = new ExpectationPropagation({project _}, 1E-8)
  for (state <- ep.inference(ApproxTerm(0.0, DenseVector.ones(2)), data, Array.fill(data.length)(ApproxTerm())) take 20) {
    // Explicit tuple instead of relying on deprecated argument auto-tupling.
    println((state.logPartition, state.q))
    assert(!state.logPartition.isNaN, state.q.s + " " + state.q.b)
  }
}