All downloads are free. Search and download functionality uses the official Maven repository.

dlm.model.Wishart.scala Maven / Gradle / Ivy

The newest version!
package dlm.core.model

import breeze.stats.distributions._
import breeze.linalg._
import breeze.numerics._

/**
  * Wishart distribution over d x d matrices.
  *
  * The density is proportional to
  * `|x|^((n - d - 1) / 2) * exp(-trace(scale^-1 * x) / 2)`.
  *
  * @param n     degrees of freedom
  * @param scale scale matrix; its Cholesky factor is computed at construction,
  *              so construction fails if `cholesky(scale)` fails
  * @param rand  source of randomness used when sampling
  */
case class Wishart(n: Double, scale: DenseMatrix[Double])(
    implicit rand: RandBasis = Rand)
    extends ContinuousDistr[DenseMatrix[Double]]
    with Moments[DenseMatrix[Double], DenseMatrix[Double]] {

  /** Dimension of the distribution, taken from the column count of `scale`. */
  val d: Int = scale.cols

  /**
    * Log of the d-dimensional multivariate gamma function,
    * `log(pi) * d * (d - 1) / 4 + sum_{i = 1..d} lgamma(x + (1 - i) / 2)`.
    *
    * @param x the argument of the multivariate gamma function
    */
  def lgmultivariategamma(x: Double): Double = {
    val logGammaTerms = (1 to d).map(i => lgamma(x + (1 - i) * 0.5))
    log(math.Pi) * d * (d - 1) * 0.25 + logGammaTerms.sum
  }

  /**
    * Log of the normalising constant Z of the density:
    * `(n / 2) * log|scale| + (n * d / 2) * log(2) + log Gamma_d(n / 2)`.
    *
    * ContinuousDistr computes `logPdf(x) = unnormalizedLogPdf(x) - logNormalizer`,
    * so this has to be +log(Z). Returning -log(Z) here would add the constant
    * instead of subtracting it and give a density that does not integrate to one.
    */
  def logNormalizer: Double = {
    val logDetScale = logdet(scale)._2
    logDetScale * 0.5 * n + n * d * 0.5 * log(2) + lgmultivariategamma(n * 0.5)
  }

  /**
    * Log density at `x` without the normalising constant:
    * `((n - d - 1) / 2) * log|x| - trace(scale^-1 * x) / 2`.
    *
    * @param x the matrix at which to evaluate the density
    */
  def unnormalizedLogPdf(x: DenseMatrix[Double]): Double = {
    // scale \ x solves scale * y = x, avoiding an explicit inverse of scale
    logdet(x)._2 * (n - d - 1) * 0.5 - 0.5 * trace(scale \ x)
  }

  // Cholesky factor of the scale matrix, shared by both samplers
  private val l = cholesky(scale)

  /** Draw a single matrix; delegates to [[drawBartlett]]. */
  def draw(): DenseMatrix[Double] = drawBartlett()

  /**
    * Sample the lower-triangular factor of the Bartlett decomposition.
    * https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition
    *
    * Using zero-based indices, the diagonal entry of row i is the square root
    * of a ChiSquared(n - i) draw, entries below the diagonal are standard
    * Gaussian draws and entries above the diagonal are zero.
    */
  def bartlettDecomp(): DenseMatrix[Double] = {
    DenseMatrix.tabulate(d, d) {
      case (i, j) =>
        // Both variates are drawn for every cell, chi-squared first, exactly as
        // before, so a seeded RandBasis keeps producing the same matrices.
        // The bound names deliberately do not reuse `n` (degrees of freedom).
        (for {
          chiSq <- ChiSquared(n - i)
          z <- Gaussian(0, 1)
          entry = if (i == j) sqrt(chiSq) else if (i > j) z else 0.0
        } yield entry).draw()
    }
  }

  /**
    * Draw from the Wishart using the Bartlett decomposition:
    * `l * a * a.t * l.t`, with `a` from [[bartlettDecomp]] and `l` the
    * Cholesky factor of `scale`.
    */
  def drawBartlett(): DenseMatrix[Double] = {
    val a = bartlettDecomp()

    l * a * a.t * l.t
  }

  /**
    * Draw from the Wishart using the definition of the Wishart distribution:
    * the sum of `n.toInt` outer products of vectors `l * z`, where each `z`
    * has independent standard Gaussian entries.
    *
    * The degrees of freedom are truncated to an integer, and the final
    * `reduce` throws on an empty collection, i.e. when `n.toInt` is below one.
    */
  def drawNaive(): DenseMatrix[Double] = {
    val xi = Vector.fill(n.toInt)(
      l * DenseVector.rand(scale.cols, rand.gaussian(0, 1)))
    xi.map(x => x * x.t).reduce(_ + _)
  }

  /** Not implemented; calling this throws `NotImplementedError`. */
  def entropy: Double = ???

  /** Mean of the distribution, `n * scale`. */
  def mean: breeze.linalg.DenseMatrix[Double] = n * scale

  /**
    * Mode of the distribution, `(n - d - 1) * scale`.
    * NOTE(review): this formula is the mode only for n >= d + 1; not checked here.
    */
  def mode: breeze.linalg.DenseMatrix[Double] = (n - d - 1) * scale

  /** Not implemented; calling this throws `NotImplementedError`. */
  def variance: breeze.linalg.DenseMatrix[Double] = ???
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy