All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dlm.model.ParticleGibbs.scala Maven / Gradle / Ivy

The newest version!
package dlm.core.model

import breeze.linalg.DenseVector
import breeze.stats.distributions.{Multinomial, Rand}
import cats.implicits._
import math.{exp, log}
import breeze.stats.mean
import ParticleFilter._

/**
  * Snapshot of the particle Gibbs filter after folding over a prefix of the data
  *
  * @param conditionedState the conditioned latent trajectory, keyed by time
  * @param states the particle history: one inner vector per processed time
  * point, each holding (time, latent state) pairs for the free particles
  * @param weights the (unnormalised) weights of the most recent particles
  * @param ll the accumulated pseudo log-likelihood
  */
case class PgState(
    conditionedState: Map[Double, DenseVector[Double]],
    states: Vector[Vector[(Double, DenseVector[Double])]],
    weights: Vector[Double],
    ll: Double
)

/**
  * Particle Gibbs Sampler for a Dynamic Generalised Linear Model (Dglm)
  *
  * @param n the total number of particles, counting the single conditioned
  * particle; n - 1 free particles are propagated at every time step
  */
case class ParticleGibbs(n: Int) {

  /**
    * Sample an initial conditioned trajectory by running a particle filter
    * and drawing one path using the particle weights at the final time
    * @param model a Dglm model
    * @param p the parameters of the model
    * @param ys a vector of observations
    * @return a map from time to the sampled latent state at that time
    */
  def initialConditionedState(
      model: Dglm,
      p: DlmParameters,
      ys: Vector[Data]): Map[Double, DenseVector[Double]] = {
    // n0 is passed straight to ParticleFilter; presumably the ESS threshold
    // below which it resamples (TODO confirm against ParticleFilter)
    val n0 = n / 5
    val st = ParticleFilter(n, n0, multinomialResample).filter(model, ys, p)
    val ws = st.last.weights
    val states = st.map(d => d.state.map((d.time, _)))

    ParticleGibbs.sampleState(states, ws).draw.toMap
  }

  /**
    * Build the starting state of the filter: n - 1 particles drawn from the
    * initial state distribution N(m0, c0), plus a freshly sampled
    * conditioned trajectory
    * @param model a Dglm model
    * @param p the parameters of the model
    * @param ys a vector of observations
    * @return the initial filter state with zero log-likelihood
    */
  def initialiseState(model: Dglm,
                      p: DlmParameters,
                      ys: Vector[Data]): PgState = {

    // initial time: the earliest observation time, capped above at 0.0
    val t0 = ys.foldLeft(0.0)((t0, d) => math.min(t0, d.time))
    val x0 = MultivariateGaussianSvd(p.m0, p.c0)
      .sample(n - 1)
      .map(x => (t0, x))
      .toVector

    // sample the first conditioned state
    val conditionedState = initialConditionedState(model, p, ys)

    PgState(conditionedState, Vector(x0), Vector.fill(n - 1)(1.0 / n), 0.0)
  }

  /**
    * A single step of the particle Gibbs filter
    * @param mod a Dglm model
    * @param p the parameters of the model
    * @param s the filter state after the previous observation
    * @param d the current observation
    * @return the filter state after incorporating d
    */
  def step(mod: Dglm, p: DlmParameters)(s: PgState, d: Data): PgState = {

    val y = KalmanFilter.flattenObs(d.observation)

    // new time points are appended to `s.states` with :+, so the most
    // recent particles live at the END of the history; both the elapsed
    // time and the particles to propagate must come from there
    val latest = s.states.last
    val dt = d.time - latest.head._1

    // resample the n - 1 most recent free particles using the previous weights
    val resampledX = multinomialResample(latest, s.weights)

    // advance the resampled particles to the current time
    val x1 =
      advanceState(dt, resampledX.map(_._2), mod, p).draw.map(x => (d.time, x))

    // NOTE(review): the conditioned particle below only contributes to the
    // likelihood estimate; it is never stored in `states`, and earlier time
    // points are not reordered when resampling, so the columns of `states`
    // are not genealogical paths. Confirm this is the intended algorithm.
    if (y.data.isEmpty) {
      // no observation: keep the likelihood, weight every advanced particle
      // equally (exactly one weight per particle in x1)
      PgState(s.conditionedState,
              s.states :+ x1,
              Vector.fill(x1.size)(1.0 / x1.size),
              s.ll)
    } else {
      // prepend the conditioned state at the current time to the advanced
      // particles, falling back to the first advanced particle when the
      // conditioned trajectory has no value at d.time
      val cond: (Double, DenseVector[Double]) =
        (d.time, s.conditionedState.getOrElse(d.time, x1.head._2))
      val x = cond +: x1

      // log-weights of all n particles (conditioned + free) at time d.time
      val w = calcWeights(mod, d.time, x.map(_._2), d.observation, p)

      // update the pseudo log-likelihood; subtracting the max log-weight
      // before exponentiating avoids underflow, and it is added back here
      val max = w.max
      val w1 = w map (a => exp(a - max))
      val ll = s.ll + max + log(mean(w1))

      // drop the conditioned particle's weight so weights align with x1
      PgState(s.conditionedState, s.states :+ x1, w1.tail, ll)
    }
  }

  /**
    * Perform the particle Gibbs filter over all observations
    * @param model a Dglm model
    * @param ys a vector of observations
    * @param p the parameters of the model
    * @return the final filter state: the particle history, the weights at
    * the final time and the accumulated pseudo log-likelihood
    */
  def filter(model: Dglm, ys: Vector[Data], p: DlmParameters): PgState = {

    val init = initialiseState(model, p, ys)
    ys.foldLeft(init)(step(model, p))
  }
}

object ParticleGibbs {

  /**
    * Using the weights at the final time T (the end of all observations),
    * sample a single path from the collection of particle paths
    * @param states the particle history: the outer vector has one entry per
    * time point, each inner vector holds the N particles at that time, so
    * after transposing, entry k is the k-th path
    * @param weights particle weights at time T, one per path
    * @return a distribution over single paths
    */
  def sampleState(
      states: Vector[Vector[(Double, DenseVector[Double])]],
      weights: Vector[Double]): Rand[Vector[(Double, DenseVector[Double])]] = {
    // transpose once up front rather than on every draw of the returned Rand
    val paths = states.transpose
    Multinomial(DenseVector(weights.toArray)).map(k => paths(k))
  }

  /**
    * Run the particle Gibbs filter and sample a single state path from it
    * @param n the number of particles to use in the particle filter
    * @param mod a Dglm model
    * @param ys a vector of observations
    * @param p the parameters of the model
    * @return a distribution over (pseudo log-likelihood, sampled path) pairs
    */
  def sample(n: Int, mod: Dglm, ys: Vector[Data], p: DlmParameters)
    : Rand[(Double, Vector[(Double, DenseVector[Double])])] = {
    val filtered = ParticleGibbs(n).filter(mod, ys, p)
    sampleState(filtered.states, filtered.weights) map ((filtered.ll, _))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy