首页 > 代码库 > scala版本的梅森旋转随机数算法

scala版本的梅森旋转随机数算法

package xzxz

import scala.annotation.tailrec

/**
 * MT19937 Mersenne Twister pseudo-random number generator (32-bit variant).
 *
 * The 624-word state is stored as `Int`s. Multiplication/addition wrap modulo 2^32 and all
 * right shifts are unsigned (`>>>`), so the bit patterns match the reference `uint32_t`
 * implementation. For seed 5489 the first output is 3499211612 and the 10000th output is
 * 4123659995 (both as unsigned values), the same as `std::mt19937`.
 *
 * Not thread-safe: `nextInt()` and `resetSeed` mutate internal state without synchronization.
 *
 * @param seed initial seed; its 32 bits are used as an unsigned value
 */
class MersenneTwister(seed: Int) {

  /** Number of 32-bit words in the state vector (n). */
  private val SIZE: Int = 624

  /** Offset of the "middle" word used by the recurrence (m). */
  private val PERIOD: Int = 397

  /** Twist matrix coefficient (a), XORed in when the low bit of the combined word is set. */
  private val MATRIX_A: Int = 0x9908b0df

  /** Selects the most significant bit of a word. */
  private val UPPER_MASK: Int = 0x80000000

  /** Selects the 31 least significant bits of a word. */
  private val LOWER_MASK: Int = 0x7fffffff

  /** Generator state; fully regenerated once every SIZE outputs. */
  private val state = new Array[Int](SIZE)

  /** Index of the next state word to temper; 0 means the state must be twisted first. */
  private var index: Int = 0

  resetSeed(seed)

  /**
   * Regenerates ("twists") all SIZE state words in place.
   *
   * Word i is rebuilt from its own top bit, the low 31 bits of its successor, and the word
   * PERIOD positions ahead, with both neighbours taken cyclically. Because i increases,
   * neighbours that wrap around (i >= SIZE - PERIOD, and the successor of the last word)
   * already hold this round's new values, exactly as in the reference three-phase loop.
   */
  private def generateNumbers(): Unit = {
    @tailrec
    def twist(i: Int): Unit =
      if (i < SIZE) {
        val y = (state(i) & UPPER_MASK) | (state((i + 1) % SIZE) & LOWER_MASK)
        val mag = if ((y & 1) != 0) MATRIX_A else 0
        state(i) = state((i + PERIOD) % SIZE) ^ (y >>> 1) ^ mag
        twist(i + 1)
      }

    twist(0)
  }

  /**
   * Re-initializes the state from `seed` using the reference MT19937 seeding recurrence,
   * and restarts the output sequence (the next `nextInt()` call twists the new state).
   *
   * @param seed new seed; its 32 bits are used as an unsigned value
   */
  final def resetSeed(seed: Int): Unit = {
    state(0) = seed
    index = 0

    @tailrec
    def fill(i: Int): Unit =
      if (i < SIZE) {
        val prev = state(i - 1)
        // 0x6c078965 == 1812433253; Int arithmetic wraps modulo 2^32 like uint32_t.
        state(i) = 0x6c078965 * (prev ^ (prev >>> 30)) + i
        fill(i + 1)
      }

    fill(1)
  }

  /**
   * Returns a uniformly distributed Double in [0.0, 1.0) with 32 bits of resolution.
   *
   * The signed result of `nextInt()` is offset by 2^31 into [0, 2^32) and scaled by 2^-32.
   */
  final def nextFloat(): Double = (nextInt() + 2147483648.0) / 4294967296.0

  /**
   * Returns the next 32-bit output of the generator.
   *
   * The value is the reference unsigned output's bit pattern viewed as a signed `Int`;
   * use `nextInt() & 0xFFFFFFFFL` to obtain the unsigned value.
   */
  final def nextInt(): Int = {
    if (index == 0)
      generateNumbers()

    val raw = state(index)
    index = (index + 1) % SIZE

    // MT19937 tempering transform (shifts 11, 7, 15, 18 with masks b and c).
    val t1 = raw ^ (raw >>> 11)
    val t2 = t1 ^ ((t1 << 7) & 0x9d2c5680)
    val t3 = t2 ^ ((t2 << 15) & 0xefc60000)
    t3 ^ (t3 >>> 18)
  }
}

scala版本的梅森旋转随机数算法