shithub: mc

ref: b31626248ea64c23d2020846e2cac1c751348b61
dir: /lib/std/rand.myr/

View raw version
use "alloc"
use "assert"
use "die"
use "extremum"
use "mk"
use "now"
use "putint"
use "types"

pkg std =
	/* generator state; the struct itself is defined below the package block */
	type rng

	/* mksrng returns a generator whose state is derived from `seed`; freerng releases one */
	const mksrng	: (seed : uint32 -> rng#)
	const freerng	: (rng : rng# -> void)

	/*
	The next three use the shared global generator `_rng`.
	rand yields lo plus a value in [0, abs(hi - lo)), randnum yields
	a raw value cast to @a, and randbytes fills `buf` with generated bytes.
	*/
	generic rand	: (lo : @a, hi : @a -> @a) :: numeric,integral @a
	generic randnum	: (-> @a) :: numeric,integral @a
	const randbytes	: (buf : byte[:] -> void)

	/* the same operations, drawing from a caller-supplied generator */
	generic rngrand	: (rng : rng#, lo : @a, hi : @a -> @a) ::numeric,integral @a
	generic rngrandnum	: (rng : rng# -> @a) :: numeric,integral @a
	const rngrandbytes	: (rng : rng#, buf : byte[:] -> void)
;;

/*
Generator state: two 64-bit words. rngrandnum reads both, returns
their sum, and writes back the scrambled successors.
*/
type rng = struct
	s0 : uint64
	s1 : uint64
;;

/* shared generator behind rand, randnum and randbytes; its state words are set in __init__ */
var _rng

/* rngrand applied to the shared global generator */
generic rand = {lo, hi
	-> rngrand(&_rng, lo, hi)
}
/* rngrandnum applied to the shared global generator */
generic randnum = {
	-> rngrandnum(&_rng)
}
/* rngrandbytes applied to the shared global generator */
const randbytes = {buf
	-> rngrandbytes(&_rng, buf)
}

/*
Package initializer: seeds both state words of the global generator
from now(), so the default stream depends on when the program started.
*/
const __init__ = {
	/*
	NOTE(review): two back-to-back now() calls presumably return equal
	or nearly equal values, so s0 and s1 start out strongly correlated;
	confirm whether that is acceptable for the intended uses.
	*/
	_rng.s0 = (now() : uint64)
	_rng.s1 = (now() : uint64)
}

/*
Creates a generator whose state is derived from `seed`: the low 16 bits
go to s0 and the high 16 bits to s1. The result comes from std.mk and is
meant to be released with freerng.

Two fixes over the naive split:
  - shifts bind tighter than `&`, so the upper half must be extracted
    as (seed >> 16) & 0xffff; writing seed&0xffff0000>>16 masks with
    0xffff and silently drops the high bits.
  - an all-zero state is a fixed point of rngrandnum (it would return 0
    forever), so seed 0 is mapped to an arbitrary non-zero constant.
*/
const mksrng = {seed : uint32 -> rng#
	var lo : uint64
	var hi : uint64

	lo = (seed & 0xffff : uint64)
	hi = ((seed >> 16) & 0xffff : uint64)
	if lo == 0 && hi == 0
		lo = 0x9e3779b97f4a7c15
	;;
	-> std.mk([
		.s0=lo,
		.s1=hi
	])
}

/* releases a generator by handing the pointer to std.free */
const freerng = {r; std.free(r)}

/*
Returns lo plus a uniformly chosen offset in [0, abs(hi - lo)), drawing
raw values from `rng`. For lo < hi that is the half-open range [lo, hi).
If lo == hi the range is empty and lo is returned without consuming
any random numbers.

Uniformity comes from rejection sampling: raw values are masked down to
[0, max], and anything at or above `lim` (the largest multiple of span
that fits) is thrown away, so every residue mod span is equally likely.
*/
generic rngrand = {rng, lo, hi
	var span, lim, val, max

	span = abs(hi - lo)
	/* a zero span would divide by zero below */
	if span == 0
		-> lo
	;;

	max = ~0
	/* if ~0 is negative, we have a signed value with a different max */
	if max < 0
		max = (1 << (8*sizeof(@a)-1)) - 1
	;;

	lim = (max/span)*span
	val = (rngrandnum(rng) & max)
	/* accept only [0, lim): letting val == lim through would favour offset 0 */
	while val >= lim
		val = (rngrandnum(rng) & max)
	;;
	-> val % span + lo
}

/*
Fills every byte of `buf` with output from `rng`.

Whole 8-byte chunks each get a fresh 64-bit value written little-endian
with putle64; the remaining 0..7 bytes are peeled off one more value,
lowest byte first. The chunk loop must cover all buf.len/8 chunks:
the tail loop only has 8 bytes of entropy to give, and anything it
writes beyond that would be zero.
*/
const rngrandbytes = {rng, buf
	var n, r : uint64

	n = 0
	while n + 8 <= buf.len
		r = rngrandnum(rng)
		putle64(buf[n:n+8], r)
		n += 8
	;;

	/* fewer than 8 bytes remain from here on */
	r = rngrandnum(rng)
	for ; n != buf.len; n++
		buf[n] = (r & 0xff : byte)
		r >>= 8
	;;
}


/*
Produces the next raw value from `rng` with the xoroshiro128+
generator of David Blackman and Sebastiano Vigna, then casts it
to the requested integer type.

The output is the sum of the two state words; the state is then
advanced with the 55/14/36 rotate-and-shift step.

See http://xoroshiro.di.unimi.it/ for details.
*/
generic rngrandnum = {rng -> @a :: numeric,integral @a
	var a, b, sum

	a = rng.s0
	b = rng.s1
	sum = a + b

	b = b ^ a
	/* rotl(a, 55) ^ b ^ (b << 14) */
	rng.s0 = ((a << 55) | (a >> 9)) ^ b ^ (b << 14)
	/* rotl(b, 36) */
	rng.s1 = (b << 36) | (b >> 28)
	-> (sum : @a)
}