All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hedgehog.random.MersenneTwister64.scala Maven / Gradle / Ivy

The newest version!
/**
 * This is copied from a scalaprops module here:
 *
 * https://github.com/scalaprops/scalaprops/blob/f5b9f2edf5987676194f269cedbbfb405bc9d9b0/gen/src/main/scala/scalaprops/MersenneTwister64.scala
 *
 * Copyright 2015 scalaprops contributors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 **/
package hedgehog.random

import java.nio.ByteBuffer
import java.util.Arrays

/**
 * A functional-style 64-bit Mersenne Twister generator.
 *
 * An instance never mutates its state vector: every draw returns the successor
 * generator together with the generated value.
 *
 * @param mt0  state vector of `N` (312) words
 * @param mti0 index of the next word of `mt0` to emit; any value `>= N` means
 *             the state vector has to be regenerated before the next draw
 */
final class MersenneTwister64 private(private val mt0: Array[Long], private val mti0: Int = 313) { // N + 1 = 313

  import MersenneTwister64._

  /** Two generators are equal when both the index and the state vector match. */
  override def equals(other: Any): Boolean =
    other match {
      case that: MersenneTwister64 => this === that
      case _ => false
    }

  /**
   * Hash consistent with [[equals]]: combines the state vector and the index.
   *
   * Hashing only the index would make every freshly seeded generator collide.
   */
  override def hashCode: Int =
    31 * Arrays.hashCode(mt0) + mti0

  /** Typed equality: same index and element-wise equal state vector. */
  def ===(that: MersenneTwister64): Boolean =
    (this.mti0 == that.mti0) && Arrays.equals(this.mt0, that.mt0)

  /** The generator obtained after drawing (and discarding) one `Long`. */
  def next: MersenneTwister64 =
    nextLong._1

  /**
   * Serialises the full state: the `N` state words followed by the index,
   * written through a `ByteBuffer` (`BYTES` = N * 8 + 4 bytes in total).
   */
  def getSeedBytes(): Array[Byte] = {
    val bytes = new Array[Byte](BYTES)
    val bb = ByteBuffer.wrap(bytes)

    var i = 0
    while (i < N) {
      bb.putLong(mt0(i))
      i += 1
    }
    bb.putInt(mti0)
    bytes
  }

  /**
   * Builds a generator from bytes laid out as produced by [[getSeedBytes]].
   *
   * Input shorter than `BYTES` is zero-padded; anything beyond `BYTES` is
   * ignored. The decoded index is used as-is (it is not range-checked here).
   */
  def setSeedBytes(bytes: Array[Byte]): MersenneTwister64 = {
    // Every slot is overwritten below, so a fresh array is sufficient.
    val mt = new Array[Long](N)
    val bs = if (bytes.length < BYTES) Arrays.copyOf(bytes, BYTES) else bytes
    val bb = ByteBuffer.wrap(bs)
    var i = 0
    while (i < N) {
      mt(i) = bb.getLong()
      i += 1
    }
    val mti = bb.getInt()
    new MersenneTwister64(mt, mti)
  }

  // TODO improve
  /** Currently ignores `n` and simply advances the generator by one draw. */
  def reseed(n: Long): MersenneTwister64 = next

  /**
   * Draws the next 64-bit value.
   *
   * @return the successor generator and the generated value
   */
  def nextLong: (MersenneTwister64, Long) = {
    var mti = mti0
    var x = 0L

    // When the current block of N words is used up, regenerate the whole
    // state. The array is cloned first so this instance stays untouched.
    val mt1 = if (mti >= N) {
      val mt = mt0.clone()
      var kk = 0

      while (kk < N_M) {
        x = (mt(kk) & UpperMask) | (mt(kk + 1) & LowerMask)
        mt(kk) = mt(kk + M) ^ (x >>> 1) ^ mag01(x)
        kk += 1
      }

      while (kk < N_1) {
        x = (mt(kk) & UpperMask) | (mt(kk + 1) & LowerMask)
        // M_N is negative: this reads a word already regenerated above.
        mt(kk) = mt(kk + M_N) ^ (x >>> 1) ^ mag01(x)
        kk += 1
      }

      // Last word wraps around to the first one.
      x = (mt(N_1) & UpperMask) | (mt(0) & LowerMask)
      mt(N_1) = mt(M_1) ^ (x >>> 1) ^ mag01(x)

      mti = 0
      mt
    } else {
      // No regeneration needed: the state vector is shared with the successor.
      mt0
    }

    x = mt1(mti)
    mti += 1

    // Tempering
    x ^= (x >>> 29) & 0x5555555555555555L
    x ^= (x  << 17) & 0x71D67FFFEDA60000L
    x ^= (x  << 37) & 0xFFF7EEE000000000L
    x ^= (x >>> 43)

    (new MersenneTwister64(mt1, mti), x)
  }

  /** Draws an `Int` taken from the upper 32 bits of the next `Long`. */
  def nextInt: (MersenneTwister64, Int) = {
    val (r, n) = nextLong
    (r, (n >>> 32).toInt)
  }

  // Generates a random Double in the interval [0, 1)
  /**
   * Combines the top 27 bits of one `Int` draw and the top 26 bits of a second
   * one into a 53-bit numerator, then divides by 2^53.
   */
  def nextDouble: (MersenneTwister64, Double) = {
    val x = nextInt
    val a: Long = (x._2.toLong & 0xffffffffL) >>> 5
    val y = x._1.nextInt
    val b: Long = (y._2.toLong & 0xffffffffL) >>> 6
    val r = (a * 67108864.0 + b) / 9007199254740992.0 // (a * 2^26 + b) / 2^53
    (y._1, r)
  }

  /** Renders the complete state vector followed by the index. */
  override def toString: String = {
    mt0.mkString("MersenneTwister64(Array(", ",", s"), ${mti0})")
  }
}

/** Constants, seeding routines and factories for [[MersenneTwister64]]. */
object MersenneTwister64 {

  private final val UpperMask = 0xFFFFFFFF80000000L // upper 33 bits; same bit pattern as Int.MinValue.toLong
  private final val LowerMask = 0x7FFFFFFFL         // lower 31 bits; = Int.MaxValue = ~UpperMask

  private final val N = 312
  private final val M = 156

  private final val N_M = N - M
  private final val N_1 = N - 1

  private final val M_N = M - N
  private final val M_1 = M - 1

  // Size of a serialised state: N words of 8 bytes plus a 4-byte index.
  private final val BYTES = N * 8 + 4

  // NOTE(review): the MT19937-64 reference implementation defines MATRIX_A as
  // 0xB5026F5AA96619E9; the literal below has one hex digit fewer. It is left
  // untouched because changing it would alter every generated sequence --
  // confirm whether the deviation is intentional.
  @inline private def mag01(x: Long): Long =
    if ((x & 1) == 0) 0L else 0xB5026F5AA96619EL

  /** Generator seeded from a single `Long`, using the constructor's default index. */
  def standard(seed: Long): MersenneTwister64 =
    new MersenneTwister64(seedFromLong(N, seed))

  /** Generator seeded from a single `Long`; the first draw regenerates the state. */
  def fromSeed(seed: Long): MersenneTwister64 =
    fromSeedArray(seedFromLong(N, seed), N + 1)

  /**
   * Generator seeded from an array of `Long`s.
   *
   * @throws IllegalArgumentException if `arr` is empty (there is nothing to mix in)
   */
  def fromArray(arr: Array[Long]): MersenneTwister64 = {
    require(arr.nonEmpty, "seed array must not be empty")
    fromSeedArray(seedFromArray(N, arr), N + 1)
  }

  /**
   * Wraps an already prepared state vector.
   *
   * The array is used as-is (it is not copied), so callers must not mutate it
   * afterwards.
   *
   * @param mt  state vector; must contain exactly `N` (312) words
   * @param mti index of the next word to emit; `>= N` forces regeneration
   */
  def fromSeedArray(mt: Array[Long], mti: Int): MersenneTwister64 = {
    assert(mt.length == N, s"state vector must have exactly $N elements")
    new MersenneTwister64(mt, mti)
  }

  /**
   * Generator seeded from raw bytes, read as big-endian `Long`s.
   *
   * For 8 bytes or more, `bytes.length / 8` whole words are used (trailing
   * bytes that do not fill a word are ignored). Shorter input, including the
   * empty array, is packed into a single word instead of failing on an empty
   * seed array.
   */
  def fromBytes(bytes: Array[Byte]): MersenneTwister64 =
    fromArray(longsFromBytes(bytes, java.lang.Math.max(1, bytes.length / 8)))

  /** 32-bit seeding recurrence: expands `seed` into `length` `Int`s. */
  def seedFromInt(length: Int, seed: Int): Array[Int] = {
    val a = new Array[Int](length)
    a(0) = seed

    var i = 1
    while (i < length) {
      val x = a(i - 1)
      a(i) = 1812433253 * (x ^ (x >>> 30)) + i
      i += 1
    }

    a
  }

  /** 64-bit seeding recurrence: expands `seed` into `length` `Long`s. */
  private def seedFromLong(length: Int, seed: Long): Array[Long] = {
    val a = new Array[Long](length)
    a(0) = seed

    var i = 1
    while (i < length) {
      val x = a(i - 1)
      a(i) = 6364136223846793005L * (x ^ (x >>> 62)) + i
      i += 1
    }

    a
  }

  /**
   * Mixes an arbitrary-length, non-empty key into a state vector that was
   * first initialised from the fixed seed 19650218.
   *
   * NOTE(review): in Scala `^` binds more loosely than `+` and `-`, so the two
   * mixing expressions below evaluate as `a(i) ^ (product + ...)`. The
   * MT19937-64 reference computes `(mt[i] ^ product) + ...`. The parentheses
   * only make the existing evaluation order explicit; the results are
   * unchanged, since altering them would change every array-seeded sequence.
   */
  private def seedFromArray(length: Int, seed: Array[Long]): Array[Long] = {
    val a = seedFromLong(length, 19650218)
    val length_1 = length - 1

    var i = 1
    var j = 0
    var k = java.lang.Math.max(length, seed.length)

    while (k != 0) {
      val x = a(i - 1)
      a(i) = a(i) ^ (((x ^ (x >>> 62)) * 3935559000370003845L) + seed(j) + j)
      i += 1
      j += 1

      if (i >= length) {
        a(0) = a(length_1)
        i = 1
      }

      // Cycle through the key as often as needed.
      if (j >= seed.length) {
        j = 0
      }
      k -= 1
    }

    k = length - 1
    while (k != 0) {
      val x = a(i - 1)
      a(i) = a(i) ^ (((x ^ (x >>> 62)) * 2862933555777941757L) - i)
      i += 1

      if (i >= length) {
        a(0) = a(length_1)
        i = 1
      }

      k -= 1
    }

    // Force the most significant bit so the state is never all zeros.
    a(0) = 1L << 63
    a
  }


  /**
   * Reads one `Long`; with fewer than 8 bytes left, packs the remaining bytes
   * big-endian into the low-order end of the result.
   */
  private[this] def longFromByteBuffer(bb: ByteBuffer): Long =
    if (bb.remaining() >= 8) {
      bb.getLong()
    } else {
      var n = 0L
      // Mask each byte: a negative byte would otherwise sign-extend and
      // overwrite the bits accumulated so far.
      while (bb.remaining() > 0) n = (n << 8) | (bb.get() & 0xffL)
      n
    }

  /** Decodes up to `n` `Long`s from `bytes`; missing words stay zero. */
  private def longsFromBytes(bytes: Array[Byte], n: Int): Array[Long] =
    longsFromByteBuffer(ByteBuffer.wrap(bytes), n)

  private[this] def longsFromByteBuffer(bb: ByteBuffer, n: Int): Array[Long] = {
    val out = new Array[Long](n)
    var i = 0
    while (i < n && bb.remaining() >= 8) {
      out(i) = bb.getLong()
      i += 1
    }
    // A trailing partial word is only consumed when a slot is still free.
    if (i < n && bb.remaining() > 0) out(i) = longFromByteBuffer(bb)
    out
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy