All Downloads are FREE. Search and download functionalities are using the official Maven repository.

dlm.model.Storvik.scala Maven / Gradle / Ivy

The newest version!
package dlm.core.model

import breeze.linalg.{DenseVector, diag}
import breeze.stats.distributions._
import breeze.numerics.exp
import cats.implicits._
import cats.Traverse
import scala.language.higherKinds
import spire.syntax.cfor._

/**
  * State of the Storvik filter at a single time point. Every Vector-valued
  * field holds one entry per particle.
  *
  * @param time the time of the observation associated with this latent state
  *   (the filter's initial state is created with time 0.0)
  * @param state the particle cloud representing the posterior state
  * @param params one set of DLM parameters per particle; the v and w matrices
  *   are diagonal, with entries drawn from statsV and statsW respectively
  * @param statsV for each particle, the inverse-gamma sufficient statistics
  *   from which the diagonal of the observation covariance matrix v is drawn
  * @param statsW for each particle, the inverse-gamma sufficient statistics
  *   from which the diagonal of the system covariance matrix w is drawn
  * @param ess the effective sample size calculated from the normalised
  *   weights at this step (the filter's initial state is created with 0)
  */
case class StorvikState(time: Double,
                        state: Vector[DenseVector[Double]],
                        params: Vector[DlmParameters],
                        statsV: Vector[Vector[InverseGamma]],
                        statsW: Vector[Vector[InverseGamma]],
                        ess: Int)

object StorvikFilter {

  /**
    * Build the initial state of the Storvik filter at time 0.0
    *
    * State particles are sampled from the Gaussian prior with mean p.m0 and
    * covariance p.c0. Every particle starts from the same prior sufficient
    * statistics, and an initial parameter set is drawn from them for each particle.
    *
    * @param model the DGLM (not used by this function, kept for interface compatibility)
    * @param p template parameters supplying m0, c0 and the dimensions of v and w
    * @param ys the observations (not used by this function, kept for interface compatibility)
    * @param n the number of particles
    * @param priorV the prior for each diagonal element of the observation covariance matrix
    * @param priorW the prior for each diagonal element of the system covariance matrix
    * @return the initial filter state, with time 0.0 and effective sample size 0
    */
  def initialiseState[T[_]: Traverse](model: Dglm,
                                      p: DlmParameters,
                                      ys: T[Data],
                                      n: Int,
                                      priorV: InverseGamma,
                                      priorW: InverseGamma): StorvikState = {

    val x0 = MultivariateGaussian(p.m0, p.c0).sample(n).toVector
    // one statistic per diagonal element of v: the count must follow the
    // dimension of v, not of w, otherwise the drawn v has the wrong size
    // whenever the observation and state dimensions differ
    val vStats = Vector.fill(n)(Vector.fill(p.v.cols)(priorV))
    val wStats = Vector.fill(n)(Vector.fill(p.w.cols)(priorW))
    val p0 = drawParams(vStats, wStats, p)

    StorvikState(0.0, x0, p0, vStats, wStats, 0)
  }

  /**
    * Fold one state transition into the sufficient statistics for the
    * system covariance matrix
    *
    * The transition residual x1 - g(dt) * x0 is squared element-wise and
    * divided by dt. Each statistic then gains 0.5 in shape and half of the
    * matching scaled squared residual in scale.
    *
    * @param dt the time difference between observations
    * @param mod the model supplying the system evolution matrix g
    * @param stats the previous value of the sufficient statistics
    * @param x0 the previous value of the state particle
    * @param x1 the current value of the state particle
    * @return the updated sufficient statistics; statistics and residuals are
    * paired with zip, so the result is as long as the shorter of the two
    */
  def updateStatsW(dt: Double,
                   mod: Dglm,
                   stats: Vector[InverseGamma],
                   x0: DenseVector[Double],
                   x1: DenseVector[Double]): Vector[InverseGamma] = {

    val residual = x1 - mod.g(dt) * x0
    val scaledSquares = (residual *:* residual) / dt

    stats.zip(scaledSquares.data).map {
      case (previous, sq) =>
        InverseGamma(previous.shape + 0.5, previous.scale + 0.5 * sq)
    }
  }

  /**
    * Draw one parameter set per particle from the sufficient statistics
    *
    * For each particle a value is drawn from every statistic; the draws become
    * the diagonals of v and w in a copy of p, all other fields of p are retained.
    *
    * @param statsV for each particle, the statistics of the diagonal of v
    * @param statsW for each particle, the statistics of the diagonal of w
    * @param p the template parameters to copy
    * @return one parameter set for each pair of entries in statsV and statsW
    */
  def drawParams(statsV: Vector[Vector[InverseGamma]],
                 statsW: Vector[Vector[InverseGamma]],
                 p: DlmParameters) = {

    statsV.zip(statsW).map {
      case (obsStats, sysStats) =>
        // draw the entries of v before those of w, for each particle in turn
        val vDiagonal = DenseVector(obsStats.map(_.draw).toArray)
        val wDiagonal = DenseVector(sysStats.map(_.draw).toArray)
        p.copy(v = diag(vDiagonal), w = diag(wDiagonal))
    }
  }

  /**
    * Fold one observation into the sufficient statistics for the
    * observation covariance matrix
    *
    * The residual y - f(time).t * x is squared element-wise. Each statistic
    * then gains 0.5 in shape and half of the matching squared residual in scale.
    *
    * @param time the time of the observation
    * @param mod the model supplying the observation matrix f
    * @param stats the previous value of the sufficient statistics
    * @param x the current value of the state particle
    * @param y the observation; NOTE(review): step passes the flattened
    * observation, confirm its length still matches f(time).t * x when
    * some entries of the observation are missing
    * @return the updated sufficient statistics; statistics and residuals are
    * paired with zip, so the result is as long as the shorter of the two
    */
  def updateStatsV(time: Double,
                   mod: Dglm,
                   stats: Vector[InverseGamma],
                   x: DenseVector[Double],
                   y: DenseVector[Double]): Vector[InverseGamma] = {

    val residual = y - mod.f(time).t * x
    val squares = residual *:* residual

    stats.zip(squares.data).map {
      case (previous, sq) =>
        InverseGamma(previous.shape + 0.5, previous.scale + 0.5 * sq)
    }
  }

  /**
    * Advance each state particle by drawing from the state transition
    * distribution under that particle's own parameters
    *
    * @param mod the model
    * @param dt the time difference between observations
    * @param ps the parameters, one set per particle
    * @param xs the state particles at the previous time
    * @return the advanced particles, one for each (particle, parameters) pair
    */
  def advanceState(mod: Dglm,
                   dt: Double,
                   ps: Vector[DlmParameters],
                   xs: Vector[DenseVector[Double]]) = {

    xs.zip(ps).map { pair =>
      val (particle, theta) = pair
      Dglm.stepState(mod, theta)(particle, dt).draw
    }
  }

  /**
    * Perform a single step of the Storvik filter
    *
    * Parameters are drawn for every particle from the current sufficient
    * statistics, the state particles are advanced and weighted against the
    * new observation, and the sufficient statistics are updated. The particles,
    * parameters and statistics are resampled together when the effective
    * sample size falls below n0.
    *
    * @param mod the model
    * @param n0 the resampling threshold, resampling is performed when the
    * effective sample size is less than this value
    * @param s the filter state at the previous observation time
    * @param d the observation to assimilate, with its time
    * @return the filter state at time d.time
    */
  def step(mod: Dglm, n0: Int)(s: StorvikState, d: Data): StorvikState = {

    val y = KalmanFilter.flattenObs(d.observation)
    val dt = d.time - s.time

    // propose parameters from a known distribution, using the vector of sufficient statistics
    // (the head of the previous parameters only serves as a template for the remaining fields)
    val params = drawParams(s.statsV, s.statsW, s.params.head)

    // advance the state particles, according to p(x(t) | x(t-1), psi)
    val x1 = advanceState(mod, dt, params, s.state)

    // calculate the new weights, evaluating p(y(t) | x(t), psi)
    val weights = calcWeights(mod, d.time, x1, d.observation, params)

    // avoid underflow by subtracting the largest weight before exponentiating
    val maxWeight = weights.max
    val expWeight = weights map (a => exp(a - maxWeight))

    // normalise the weights to sum to one
    val normWeights = ParticleFilter.normaliseWeights(expWeight)

    // calculate the effective sample size
    val ess = ParticleFilter.effectiveSampleSize(normWeights)

    // update the sufficient statistics for every particle BEFORE resampling,
    // while x1(i) is still aligned with its own ancestor s.state(i) and its
    // own statistics; resampling afterwards keeps each lineage consistent
    val statsW = (s.state, x1, s.statsW).zipped.map {
      case (x0, x, stats) => updateStatsW(dt, mod, stats, x0, x)
    }

    val statsV = (x1, s.statsV).zipped.map {
      case (x, stats) => updateStatsV(d.time, mod, stats, x, y)
    }

    // resample if the effective sample size is less than n0
    if (ess < n0) {
      val indices = ParticleFilter.multinomialResample(
        expWeight.indices.toVector,
        expWeight)

      // the same indices select the state, parameters and statistics
      StorvikState(d.time,
                   indices map (x1(_)),
                   indices map (params(_)),
                   indices map (statsV(_)),
                   indices map (statsW(_)),
                   ess)
    } else {
      StorvikState(d.time, x1, params, statsV, statsW, ess)
    }
  }

  /**
    * Run the Storvik filter over a time series, writing every filter state
    * into a pre-allocated array
    *
    * @param model the model
    * @param ys the observations, in the order they are to be processed
    * @param p template parameters used to initialise the filter
    * @param priorV the prior for each diagonal element of the observation covariance matrix
    * @param priorW the prior for each diagonal element of the system covariance matrix
    * @param n the number of particles
    * @param n0 the resampling threshold for the effective sample size
    * @return an array of length ys.length + 1, the initial state followed by
    * the filter state after each observation
    */
  def filterArrayTs(model: Dglm,
                    ys: Vector[Data],
                    p: DlmParameters,
                    priorV: InverseGamma,
                    priorW: InverseGamma,
                    n: Int,
                    n0: Int) = {

    val states = new Array[StorvikState](ys.length + 1)
    states(0) = initialiseState(model, p, ys, n, priorV, priorW)

    // observation t moves the filter from slot t to slot t + 1
    cfor(0)(_ < ys.length, _ + 1) { t =>
      states(t + 1) = step(model, n0)(states(t), ys(t))
    }

    states
  }

  /**
    * Run the Storvik filter over a time series
    *
    * @param model the model
    * @param ys the observations, in the order they are to be processed
    * @param p template parameters used to initialise the filter
    * @param priorV the prior for each diagonal element of the observation covariance matrix
    * @param priorW the prior for each diagonal element of the system covariance matrix
    * @param n the number of particles
    * @param n0 the resampling threshold for the effective sample size
    * @return the initial state followed by the filter state after each observation
    */
  def filterTs(model: Dglm,
               ys: Vector[Data],
               p: DlmParameters,
               priorV: InverseGamma,
               priorW: InverseGamma,
               n: Int,
               n0: Int): Vector[StorvikState] = {

    val initial = initialiseState(model, p, ys, n, priorV, priorW)

    ys.scanLeft(initial) { (previous, observation) =>
      step(model, n0)(previous, observation)
    }
  }

  /**
    * Evaluate the conditional likelihood of an observation for every particle
    *
    * The observation matrix and each particle's observation covariance are
    * first passed through KalmanFilter.missingF and KalmanFilter.missingV
    * together with y, and the likelihood is evaluated against the flattened
    * observation.
    *
    * @param mod the model
    * @param time the time of the observation
    * @param state the state particles
    * @param y the observation, possibly containing missing entries
    * @param ps the parameters, one set per particle
    * @return one weight for each (particle, parameters) pair;
    * NOTE(review): step subtracts the maximum and exponentiates these values,
    * so they are presumably log-likelihoods, confirm against conditionalLikelihood
    */
  def calcWeights(mod: Dglm,
                  time: Double,
                  state: Vector[DenseVector[Double]],
                  y: DenseVector[Option[Double]],
                  ps: Vector[DlmParameters]) = {

    val fm = KalmanFilter.missingF(mod.f, time, y)
    // the flattened observation does not depend on the particle, compute it once
    val observed = KalmanFilter.flattenObs(y)

    (state zip ps).map {
      case (x, p) =>
        val vm = KalmanFilter.missingV(p.v, y)
        mod.conditionalLikelihood(vm)(fm.t * x, observed)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy