All downloads are free. Search and download functionality uses the official Maven repository.

com.stripe.rainier.sampler.Stats.scala Maven / Gradle / Ivy

The newest version!
package com.stripe.rainier.sampler

/**
 * Mutable container for sampler bookkeeping: running counters, rolling
 * windows of recent per-iteration measurements (each a [[RingBuffer]] of
 * capacity `n`), and the energy accumulators used by [[bfmi]].
 *
 * @param n capacity of every rolling window created here
 */
class Stats(n: Int) {
  // Running counters, updated in place by callers.
  var gradientEvaluations: Long = 0L
  var iterations: Int = 0
  var divergences: Int = 0

  // Rolling windows; each keeps only the most recent `n` values added to it.
  val gradientTimes: RingBuffer = new RingBuffer(n)
  val iterationTimes: RingBuffer = new RingBuffer(n)
  val stepSizes: RingBuffer = new RingBuffer(n)
  val acceptanceRates: RingBuffer = new RingBuffer(n)
  val gradsPerIteration: RingBuffer = new RingBuffer(n)

  // Energy accumulators combined by `bfmi`.
  val energyVariance: VarianceEstimator = new VarianceEstimator(1)
  // NOTE(review): name suggests a running sum of squared energy transitions — confirm at the update site.
  var energyTransitions2: Double = 0.0

  /**
   * Ratio of `energyTransitions2` to the first entry of `energyVariance.raw`.
   *
   * NOTE(review): presumably the Bayesian fraction of missing information
   * (E-BFMI); whether numerator and denominator are normalised by the same
   * count depends on the callers that update them — confirm.
   */
  def bfmi: Double = {
    val energyVar = energyVariance.raw(0)
    energyTransitions2 / energyVar
  }
}

/**
 * Fixed-capacity circular buffer of `Double`s that keeps the `size` most
 * recently added values, overwriting the oldest once it has wrapped.
 *
 * Valid values always occupy slots `0 until count`: before the first wrap they
 * are written left to right, and after it every slot is valid.
 *
 * @param size maximum number of values retained; must be positive for `add` to succeed
 */
class RingBuffer(size: Int) {

  /** Set once `size` values have been added, i.e. every slot holds a real value. */
  var full = false

  // Slot the next `add` writes into. Before the buffer is full this equals the
  // number of stored values; afterwards it is the index of the oldest value.
  private var next = 0
  private val buf = new Array[Double](size)

  /** Number of valid values currently stored. */
  private def count: Int = if (full) size else next

  /** Appends `value`, overwriting the oldest stored value if the buffer is full. */
  def add(value: Double): Unit = {
    // Write first, then advance: the previous version advanced first, which left
    // slot 0 unwritten (a spurious 0.0) until the first wrap.
    buf(next) = value
    next += 1
    if (next == size) {
      full = true
      next = 0
    }
  }

  /** Most recently added value, or 0.0 (the array default) if nothing has been added. */
  def last: Double = buf((next - 1 + size) % size)

  /** Stored values in insertion order, oldest first. */
  def toList: List[Double] =
    if (full)
      buf.drop(next).toList ++ buf.take(next).toList // oldest value sits at `next`
    else
      buf.take(next).toList

  /**
   * Returns one stored value chosen by `rng.int`.
   *
   * NOTE(review): assumes `rng.int(k)` yields an index in `[0, k)` — confirm against RNG.
   * On an empty buffer this still calls `rng.int(1)` and returns slot 0 (0.0),
   * matching the previous behavior.
   */
  def sample()(implicit rng: RNG): Double =
    buf(rng.int(math.max(count, 1)))

  /** Arithmetic mean of the stored values; NaN when the buffer is empty. */
  def mean: Double = {
    val k = count
    var sum = 0.0
    var j = 0
    // Plain while loop avoids allocating an intermediate collection.
    while (j < k) {
      sum += buf(j)
      j += 1
    }
    sum / k.toDouble
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy