All Downloads are FREE. Search and download functionalities are using the official Maven repository.

mgo.tools.RejectionSampler.scala Maven / Gradle / Ivy

package mgo.tools

/*
 * Copyright (C) 2021 Romain Reuillon
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

object RejectionSampler {

  /**
   * Running tally behind a Monte Carlo estimate of how often the predicate succeeds.
   *
   * @param test number of draws recorded so far (incremented by both `success` and `fail`)
   * @param pass number of recorded draws that were counted as successes
   */
  case class State(test: Long = 0L, pass: Long = 0L) {
    /**
     * Ratio `test / pass`, evaluated in floating point. No exception is raised when
     * `pass` is 0: the result is simply not a finite number in that case.
     */
    def inverseProbability() = test.toDouble / pass.toDouble
  }

  /** Records one more draw that was accepted: both counters advance by one. */
  def success(state: State) = state.copy(test = state.test + 1, pass = state.pass + 1)

  /** Records one more draw that was rejected: only the `test` counter advances. */
  def fail(state: State) = state.copy(test = state.test + 1)

  /** True while the tally contains no accepted draw. */
  def noSuccess(state: State) = state.pass == 0L

}

import RejectionSampler._
import scala.annotation.tailrec
/**
 * Rejection sampler driven by a proposal generator and an acceptance predicate,
 * threading a [[RejectionSampler.State]] tally through every draw.
 *
 * @param _sample draws one candidate: the point together with its density wrapped in a `Lazy`
 * @param accept  predicate deciding whether a drawn point is kept
 */
class RejectionSampler(_sample: () => (Vector[Double], Lazy[Double]), val accept: Vector[Double] => Boolean) {

  /**
   * Draws `n` candidates and records each one as a success or a failure in the tally.
   * The drawn points are discarded and their densities are never read.
   *
   * @param n     number of candidates to draw; nothing is drawn when `n <= 0`
   * @param state tally to start from
   * @return the tally after the `n` draws
   */
  def warmup(n: Int, state: State = State()): State = {
    // The recursion lives in a local function because `warmup` itself is overridable:
    // scalac does not eliminate self-calls of non-final methods, so recursing on
    // `warmup` directly would use one stack frame per draw and overflow for large `n`.
    @tailrec def loop(remaining: Int, current: State): State =
      if (remaining > 0) {
        val (x, _) = _sample()
        loop(remaining - 1, if (accept(x)) success(current) else fail(current))
      } else current

    loop(n, state)
  }

  /**
   * Draws candidates until one is accepted. Every rejected draw is recorded as a failure,
   * the accepted one as a success. Does not return until `accept` holds for some draw.
   *
   * @param state tally to start from
   * @return the updated tally, the accepted point, and its density divided by
   *         `inverseProbability()` of the updated tally
   */
  def sample(state: State = State()): (State, (Vector[Double], Double)) = {
    @tailrec def sample0(state: State): (State, (Vector[Double], Double)) = {
      val (x, density) = _sample()
      if (!accept(x)) {
        // if the sample is rejected, resample and keep the failure in the state
        sample0(fail(state))
      } else {
        val newState = success(state)
        // if the sample is accepted, return the state, the sample pattern and the adjusted density;
        // the density is only forced here, never for rejected draws
        (newState, (x, density.value / newState.inverseProbability()))
      }
    }

    sample0(state)
  }

  /**
   * Collects `n` accepted samples by calling `sample` repeatedly, carrying the tally
   * from one call to the next.
   *
   * @param n     number of accepted samples still to collect; nothing is drawn when `n <= 0`
   * @param state tally to start from
   * @param res   samples collected so far, most recent first (accumulator for the recursion)
   * @return the final tally and the collected samples in the order they were drawn
   */
  @tailrec final def sampleVector(n: Int, state: State = State(), res: List[(Vector[Double], Double)] = List()): (State, Vector[(Vector[Double], Double)]) = {
    if (n > 0) {
      val (newState, newSample) = sample(state)
      // prepend while accumulating; the order is restored once at the end
      sampleVector(n - 1, newState, newSample :: res)
    } else (state, res.reverse.toVector)
  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy