All downloads are free. Search and download functionality uses the official Maven repository.

schrodinger.kernel.testkit.RandomEq.scala Maven / Gradle / Ivy

/*
 * Copyright 2021 Arman Bilge
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package schrodinger.kernel
package testkit

import cats.Eq
import cats.Monad
import cats.laws.discipline.ExhaustiveCheck
import cats.syntax.all.*
import schrodinger.math.LogDouble
import schrodinger.math.special.gamma
import scala.collection.mutable

/**
 * Settings for the statistical equality checks in this file.
 *
 * @param replicates
 *   how many times a random variable is replicated to build one simulation result
 * @param eqvThreshold
 *   two simulation results compare equal when the computed belief `p` that they are
 *   equidistributed is strictly greater than this value
 * @param neqvThreshold
 *   two simulation results compare unequal when `1 - p` is strictly greater than this value
 */
final case class Confidence(
    replicates: Int,
    eqvThreshold: Double,
    neqvThreshold: Double,
)

object PseudoRandomEq:
  /**
   * Derives an `Eq[F[A]]` by turning each random variable into a function from a seed `S` to
   * a simulated batch of `confidence.replicates` samples, and comparing those functions.
   *
   * The function comparison is supplied by the imported `catsLawsEqForFn1Exhaustive`
   * instance; `seeds` and `eq` are not referenced by name here and are presumably consumed
   * by that instance during implicit resolution.
   */
  def apply[F[_]: Monad, G[_], S, A: Eq](using
      pseudo: PseudoRandom.Aux[F, G, S],
      seeds: ExhaustiveCheck[S],
      confidence: Confidence,
      eq: Eq[G[SimulationResult[A]]],
  ): Eq[F[A]] =
    import cats.laws.discipline.eq.catsLawsEqForFn1Exhaustive

    // For a given random variable, the simulation of one replicated batch from each seed.
    def batchFrom(rv: F[A]): S => G[SimulationResult[A]] = seed =>
      val batch = rv.replicateA(confidence.replicates).map(samples => SimulationResult(samples))
      batch.simulate(seed)

    Eq.by[F[A], S => G[SimulationResult[A]]](batchFrom)

/**
 * A batch of samples collected from one simulation run.
 *
 * @param samples
 *   the sampled values, in the order they were produced
 */
final case class SimulationResult[A](
    samples: List[A],
)

object SimulationResult:
  /**
   * Statistical equality of two simulation results.
   *
   * The distinct values seen in either sample list become categories. If there is exactly one
   * category the results are equal. Otherwise the per-category counts of both lists are fed
   * to `equidistributedBelief` with a prior of all ones, and the resulting belief `p` decides:
   * equal if `p > eqvThreshold`, unequal if `1 - p > neqvThreshold`, and an
   * `EqUndecidableException` is thrown when neither holds.
   */
  given [A: Eq](using confidence: Confidence): Eq[SimulationResult[A]] =
    (lhs, rhs) =>
      val xs = lhs.samples
      val ys = rhs.samples

      // Distinct values in first-seen order; only an `Eq[A]` is available, so each
      // candidate is checked against everything recorded so far.
      val categories = mutable.ArrayBuffer.empty[A]
      def record(sample: A): Unit =
        if categories.forall(_ =!= sample) then categories += sample
      xs.foreach(record)
      ys.foreach(record)

      if categories.size == 1 then true
      else
        val k = categories.size
        val xcounts = new Array[Int](k)
        val ycounts = new Array[Int](k)
        for i <- categories.indices do
          val category = categories(i)
          xcounts(i) = xs.count(_ === category)
          ycounts(i) = ys.count(_ === category)

        val belief = equidistributedBelief(xcounts, ycounts, Array.fill(k)(1.0))

        if belief > confidence.eqvThreshold then true
        else if (1 - belief) > confidence.neqvThreshold then false
        else throw new EqUndecidableException

  /**
   * Belief that two count vectors were drawn from the same categorical distribution.
   *
   * Computes `m1 / (m1 + m2)`, where `m1` is the joint marginal of both trials under one
   * shared distribution and `m2` is the product of their individual marginals.
   *
   * @param trial1
   *   per-category counts of the first sample
   * @param trial2
   *   per-category counts of the second sample
   * @param dirichletPrior
   *   per-category Dirichlet concentration parameters
   */
  private def equidistributedBelief(
      trial1: Array[Int],
      trial2: Array[Int],
      dirichletPrior: Array[Double],
  ): Double =
    val pooled = dirichletMultinomialLogPmf(trial1, trial2, dirichletPrior)
    val first = dirichletMultinomialLogPmf(trial1, dirichletPrior)
    val second = dirichletMultinomialLogPmf(trial2, dirichletPrior)
    val separate = first * second
    (pooled / (pooled + separate)).real

  /**
   * Dirichlet-multinomial pmf of a single count vector `x` under concentration `alpha`,
   * returned as a `LogDouble`.
   */
  private def dirichletMultinomialLogPmf(x: Array[Int], alpha: Array[Double]): LogDouble =
    val alphaTotal = sum(alpha)
    val trials = sum(x)
    val normalizer = gamma(alphaTotal) * gamma(trials + 1.0) / gamma(trials + alphaTotal)
    // Multiply in one factor per category, in index order.
    x.indices.foldLeft(normalizer) { (pmf, k) =>
      val factor = gamma(x(k) + alpha(k)) / gamma(alpha(k)) / gamma(x(k) + 1.0)
      pmf * factor
    }

  /**
   * Joint Dirichlet-multinomial pmf of two count vectors `x1` and `x2` that share one
   * underlying distribution with concentration `alpha`, returned as a `LogDouble`.
   */
  private def dirichletMultinomialLogPmf(
      x1: Array[Int],
      x2: Array[Int],
      alpha: Array[Double],
  ): LogDouble =
    val alphaTotal = sum(alpha)
    val trials1 = sum(x1)
    val trials2 = sum(x2)
    val trials = trials1 + trials2
    val normalizer =
      gamma(alphaTotal) * gamma(trials1 + 1.0) * gamma(trials2 + 1.0) / gamma(trials + alphaTotal)
    // Multiply in one factor per category, in index order.
    x1.indices.foldLeft(normalizer) { (pmf, k) =>
      val factor = gamma(x1(k) + x2(k) + alpha(k)) / gamma(alpha(k)) /
        gamma(x1(k) + 1.0) / gamma(x2(k) + 1.0)
      pmf * factor
    }

  /** Sum of an `Int` array, accumulated left to right starting from 0. */
  private def sum(x: Array[Int]): Int =
    x.foldLeft(0)(_ + _)

  /** Sum of a `Double` array, accumulated left to right starting from 0.0. */
  private def sum(x: Array[Double]): Double =
    x.foldLeft(0.0)(_ + _)




© 2015 - 2025 Weber Informatics LLC | Privacy Policy