All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.stripe.rainier.sampler.DualAvg.scala Maven / Gradle / Ivy

The newest version!
package com.stripe.rainier.sampler

class DualAvgTuner(delta: Double) extends StepSizeTuner {
  var da: DualAvg = _

  def initialize(params: Array[Double], lf: LeapFrog): Double = {
    val stepSize0 = findReasonableStepSize(params, lf, IdentityMassMatrix)
    da = DualAvg(delta, stepSize0)
    stepSize0
  }

  def update(logAcceptanceProb: Double): Double = {
    da.update(logAcceptanceProb)
    da.stepSize
  }

  def reset(): Double = {
    val ss = stepSize
    da = DualAvg(delta, ss)
    ss
  }

  def stepSize: Double = {
    da.finalStepSize
  }

  private def findReasonableStepSize(params: Array[Double],
                                     lf: LeapFrog,
                                     mass: MassMatrix): Double = {
    var stepSize = 1.0
    var logAcceptanceProb =
      lf.tryStepping(params, stepSize, mass)
    val exponent = if (logAcceptanceProb > Math.log(0.5)) { 1.0 } else { -1.0 }
    val doubleOrHalf = Math.pow(2, exponent)
    while (stepSize != 0.0 && (exponent * logAcceptanceProb > -exponent * Math
             .log(2))) {
      stepSize *= doubleOrHalf
      logAcceptanceProb = lf.tryStepping(params, stepSize, mass)
    }
    stepSize
  }
}

final class DualAvg(
    delta: Double,
    var logStepSize: Double,
    var logStepSizeBar: Double,
    var avgError: Double,
    var iteration: Int,
    shrinkageTarget: Double,
    stepSizeUpdateDenom: Double = 0.05,
    acceptanceProbUpdateDenom: Int = 10,
    decayRate: Double = 0.75
) {
  def stepSize: Double = Math.exp(logStepSize)
  def finalStepSize: Double = Math.exp(logStepSizeBar)

  def update(logAcceptanceProb: Double): Unit = {
    val newAcceptanceProb = Math.exp(logAcceptanceProb)
    iteration = iteration + 1
    val avgErrorMultiplier =
      1.0 / (iteration.toDouble + acceptanceProbUpdateDenom)
    val stepSizeMultiplier = Math.pow(iteration.toDouble, -decayRate)

    avgError = (
      (1.0 - avgErrorMultiplier) * avgError
        + (avgErrorMultiplier * (delta - newAcceptanceProb))
    )

    logStepSize = (
      shrinkageTarget
        - (avgError * Math.sqrt(iteration.toDouble) / stepSizeUpdateDenom)
    )

    logStepSizeBar = (stepSizeMultiplier * logStepSize
      + (1.0 - stepSizeMultiplier) * logStepSizeBar)
  }
}

private object DualAvg {
  def apply(delta: Double, stepSize: Double): DualAvg =
    new DualAvg(
      delta = delta,
      logStepSize = Math.log(stepSize),
      logStepSizeBar = 0.0,
      avgError = 0.0,
      iteration = 0,
      shrinkageTarget = Math.log(10 * stepSize)
    )
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy