All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.stripe.rainier.sampler.HMC.scala Maven / Gradle / Ivy

The newest version!
package com.stripe.rainier.sampler

class HMCSampler(nSteps: Int) extends Sampler {
  def initialize(params: Array[Double], lf: LeapFrog)(implicit rng: RNG) = ()

  def warmup(params: Array[Double],
             lf: LeapFrog,
             stepSize: Double,
             mass: MassMatrix)(implicit rng: RNG): Double = {
    lf.startIteration(params, mass)
    lf.takeSteps(nSteps, stepSize, mass)
    lf.finishIteration(params, mass)
  }

  def run(params: Array[Double],
          lf: LeapFrog,
          stepSize: Double,
          mass: MassMatrix)(implicit rng: RNG): Unit = {
    lf.startIteration(params, mass)
    lf.takeSteps(nSteps, stepSize, mass)
    lf.finishIteration(params, mass)
    ()
  }
}

object HMC {
  def apply(warmIt: Int, it: Int, nSteps: Int): SamplerConfig =
    new DefaultConfig {
      override val warmupIterations = warmIt
      override val iterations = it
      override def sampler() = new HMCSampler(nSteps)
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy