All downloads are free. The search and download features use the official Maven repository.

com.stripe.rainier.sampler.EHMC.scala Maven / Gradle / Ivy

The newest version!
package com.stripe.rainier.sampler

/**
  * Sampler that chooses the number of leapfrog steps per iteration by drawing
  * from a buffer of previously observed trajectory lengths.
  *
  * During warmup, some iterations "count" steps: the trajectory is advanced one
  * step at a time until `lf.isUTurn` reports a U-turn (or `maxSteps` is hit) and
  * the resulting step count is stored in `steps`. All other iterations, and
  * every `run` iteration, take a number of steps sampled from `steps`.
  *
  * @param maxSteps upper bound on the number of single steps taken while counting
  * @param minSteps number of steps the counting iteration itself ends up at:
  *                 the state after the `minSteps`-th step is snapshotted and
  *                 restored once counting finishes (or the trajectory is
  *                 extended to `minSteps` if counting stopped earlier)
  * @param bufSize  size passed to the `RingBuffer` holding observed step counts
  * @param pCount   once the buffer is full, a warmup iteration counts steps when
  *                 `rng.standardUniform < pCount`
  */
class EHMCSampler(maxSteps: Int,
                  minSteps: Int = 1,
                  bufSize: Int = 100,
                  pCount: Double = 0.1)
    extends Sampler {
  // Observed trajectory lengths; written by countSteps, read via sample().
  val steps = new RingBuffer(bufSize)
  // Scratch space for lf.snapshot / lf.restore; allocated in initialize.
  var buf: Array[Double] = _

  /** Allocates the snapshot buffer with `lf.inputOutputSize` elements. */
  def initialize(params: Array[Double], lf: LeapFrog)(
      implicit rng: RNG): Unit = {
    buf = new Array[Double](lf.inputOutputSize)
  }

  /**
    * Runs one warmup iteration: either counts steps (recording the count in
    * `steps`) or takes a number of steps sampled from `steps`.
    *
    * @return the value returned by `lf.finishIteration(params, mass)`
    */
  def warmup(params: Array[Double],
             lf: LeapFrog,
             stepSize: Double,
             mass: MassMatrix)(implicit rng: RNG): Double = {
    lf.startIteration(params, mass)
    if (shouldCountSteps()) {
      countSteps(params, lf, stepSize, mass)
    } else {
      val n = steps.sample().toInt
      lf.takeSteps(n, stepSize, mass)
    }
    lf.finishIteration(params, mass)
  }

  // Always count until the buffer has filled; afterwards only when the
  // uniform draw falls below pCount.
  private def shouldCountSteps()(implicit rng: RNG): Boolean =
    !steps.full || rng.standardUniform < pCount

  /**
    * Steps one at a time until `lf.isUTurn(params)` is true or `maxSteps` steps
    * have been taken, then records the number of steps taken in `steps`.
    *
    * The trajectory is left at `minSteps` steps: if counting stopped before
    * `minSteps`, the missing steps are taken; otherwise the snapshot captured
    * right after the `minSteps`-th step is restored.
    */
  private def countSteps(params: Array[Double],
                         lf: LeapFrog,
                         stepSize: Double,
                         mass: MassMatrix): Unit = {
    var l = 0
    // The snapshot is only captured when l reaches minSteps after a step, so
    // for minSteps <= 0 it never happens. Track it so we never restore from a
    // buffer that snapshot() did not fill during this call.
    var snapshotTaken = false
    while (!lf.isUTurn(params) && l < maxSteps) {
      l += 1
      lf.takeSteps(1, stepSize, mass)
      if (l == minSteps) {
        lf.snapshot(buf)
        snapshotTaken = true
      }
    }
    if (l < minSteps) {
      // Counting stopped early: extend the trajectory up to minSteps.
      lf.takeSteps(minSteps - l, stepSize, mass)
    } else if (snapshotTaken) {
      // Rewind to the state captured after the minSteps-th step.
      lf.restore(buf)
    }
    // With minSteps <= 0 and no snapshot, the trajectory stays where counting
    // stopped instead of being overwritten with unwritten buffer contents.

    steps.add(l.toDouble)
  }

  /**
    * Runs one sampling iteration, taking a number of steps sampled from
    * `steps`. The result of `lf.finishIteration` is discarded.
    */
  def run(params: Array[Double],
          lf: LeapFrog,
          stepSize: Double,
          mass: MassMatrix)(implicit rng: RNG): Unit = {
    lf.startIteration(params, mass)
    val n = steps.sample().toInt
    lf.takeSteps(n, stepSize, mass)
    lf.finishIteration(params, mass)
    ()
  }
}

/** Factory for sampler configurations backed by [[EHMCSampler]]. */
object EHMC {

  /**
    * Builds a `DefaultConfig` whose `sampler()` creates a new `EHMCSampler`
    * on each call.
    *
    * @param warmIt     value used for `warmupIterations`
    * @param it         value used for `iterations`
    * @param minSteps   forwarded as the sampler's `minSteps`
    * @param numLengths forwarded as the sampler's `bufSize`
    */
  def apply(warmIt: Int,
            it: Int,
            minSteps: Int = 1,
            numLengths: Int = 100): SamplerConfig = {
    // Fixed settings that are not exposed through this factory.
    val stepCap = 1000
    val countProbability = 0.1

    new DefaultConfig {
      override val warmupIterations = warmIt
      override val iterations = it
      override def sampler() =
        new EHMCSampler(maxSteps = stepCap,
                        minSteps = minSteps,
                        bufSize = numLengths,
                        pCount = countProbability)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy