shithub: mc

ref: 16b894430a5b71f4956ee305229bbbffd13605d0
dir: /lib/std/test/bigint.myr/

View raw version
use std
use testr

/*
 * Expression tree describing a bigint computation. Interior nodes name
 * an operation and point at their operand subtrees; `Val leaves hold the
 * text of a number. The tree is interpreted by eval() in this file.
 */
type cmd = union
	`Add (cmd#, cmd#)	/* evaluated with std.bigadd */
	`Sub (cmd#, cmd#)	/* evaluated with std.bigsub */
	`Mul (cmd#, cmd#)	/* evaluated with std.bigmul */
	`Div (cmd#, cmd#)	/* evaluated with std.bigdiv */
	`Mod (cmd#, cmd#)	/* evaluated with std.bigmod */
	`Shl (cmd#, cmd#)	/* evaluated with std.bigshl */
	`Shr (cmd#, cmd#)	/* evaluated with std.bigshr */
	`Modpow (cmd#, cmd#, cmd#)	/* operands are handed to std.bigmodpow in this order */
	`Val byte[:]	/* number literal, parsed with std.bigparse */
;;

/*
 * Program entry point: builds the table of named test cases and hands
 * it to the testr runner. The table order is the order given to testr.
 */
const main = {
	var suite = [
		[.name = "smoke-test", .fn = smoketest],
		[.name = "matches-small", .fn = matchsmall],
		[.name = "comparisons", .fn = comparisons],
		[.name = "format-zero", .fn = fmtzero],
		[.name = "division", .fn = smokediv],
		[.name = "modulo", .fn = smokemod],
		[.name = "add-negatives", .fn = addneg],
		[.name = "sub-negatives", .fn = subneg],
		[.name = "bigdivmod", .fn = somebigdivmod],
	][:]

	testr.run(suite)
}

/*
 * Smoke test: chains mul, mul, add, sub and div on small operands and
 * checks the formatted result, then checks that a bigint built from an
 * int64 holding 4294967296 formats correctly.
 *
 * ct is the test context that check results are reported to.
 */
const smoketest = {ct
	var a, b, c, d, e, f
	var buf : byte[64], n

	/* a few combined ops */
	a = std.mkbigint(1234)
	b = std.mkbigint(0x7fffffff)
	c = std.mkbigint(7919)
	d = std.mkbigint(113051)
	e = std.mkbigint(11)
	f = std.mkbigint((4294967296 : int64))

	/* each op is applied to a, which accumulates the running result */
	std.bigmul(a, b)
	std.bigmul(a, b)
	std.bigadd(a, c)
	std.bigsub(a, d)
	std.bigdiv(a, e)

	std.bigfree(b)
	std.bigfree(c)
	std.bigfree(d)
	std.bigfree(e)

	/*
	 * Once the digits are in buf the bigint is no longer needed, so it
	 * is released before the check rather than being left allocated.
	 */
	n = std.bigbfmt(buf[:], a, 0)
	std.bigfree(a)
	testr.check(ct, std.eq(buf[:n], "517347321949036993306"), "simple smoke test failed")

	n = std.bigbfmt(buf[:], f, 0)
	std.bigfree(f)
	testr.check(ct, std.eq(buf[:n], "4294967296"), "smoke test failed for 2^32 case")
}

/*
 * Cross-checks bigint arithmetic against native integer arithmetic for
 * every ordered pair drawn from a small set of values.
 *
 * For each pair (a, b): add, sub and mul are always compared; div is
 * compared when b != 0; mod is compared only when both a and b are
 * positive.
 *
 * c is the test context that check results are reported to.
 */
const matchsmall = {c
	var nums = [ -5, -3, -2, -1, 0, 1, 2, 4, 8, 10 ][:]
	for a : nums
		for b : nums
			var p = std.mkbigint(a)
			var q = std.mkbigint(b)

			/*
			 * r* holds the expected value computed natively; s* is a
			 * copy of p that the bigint op is applied to. The copies are
			 * freed only after the check, since the failure message
			 * formats them.
			 */
			var radd = std.mkbigint(a + b)
			var sadd = std.bigdup(p)
			std.bigadd(sadd, q)
			testr.check(c, std.bigeq(radd, sadd), "{} + {} != {} (was {})", a, b, a + b, sadd)
			std.bigfree(sadd)

			var rsub = std.mkbigint(a - b)
			var ssub = std.bigdup(p)
			std.bigsub(ssub, q)
			testr.check(c, std.bigeq(rsub, ssub), "{} - {} != {} (was {})", a, b, a - b, ssub)
			std.bigfree(ssub)

			var rmul = std.mkbigint(a * b)
			var smul = std.bigdup(p)
			std.bigmul(smul, q)
			testr.check(c, std.bigeq(rmul, smul), "{} * {} != {} (was {})", a, b, a * b, smul)
			std.bigfree(smul)


			if b != 0
				if a > 0 && b > 0
					var rmod = std.mkbigint(a % b)
					var smod = std.bigdup(p)
					std.bigmod(smod, q)
					testr.check(c, std.bigeq(rmod, smod), "{} % {} != {} (was {})", a, b, a % b, smod)
					std.bigfree(rmod)
					std.bigfree(smod)
				;;

				var rdiv = std.mkbigint(a / b)
				var sdiv = std.bigdup(p)
				std.bigdiv(sdiv, q)
				testr.check(c, std.bigeq(rdiv, sdiv), "{} / {} != {} (was {})", a, b, a / b, sdiv)

				std.bigfree(rdiv)
				std.bigfree(sdiv)
			;;

			std.bigfree(p)
			std.bigfree(q)
			std.bigfree(radd)
			std.bigfree(rsub)
			std.bigfree(rmul)
		;;
	;;
}

/*
 * Checks std.bigcmp ordering on two pairs of parsed values: one pair
 * where the first operand is smaller, one where it is larger.
 *
 * c is the test context that check results are reported to.
 */
const comparisons = {c
	var lhs, rhs

	/* operands differ only in their leading digit; lhs is the smaller */
	lhs = try(std.bigparse("1234_5678_1234_6789_6666_7777_8888"))
	rhs = try(std.bigparse("2234_5678_1234_6789_6666_7777_8888"))
	testr.check(c, std.bigcmp(lhs, rhs) == `std.Before, "{} should be < {}", lhs, rhs)
	std.bigfree(lhs)
	std.bigfree(rhs)

	/* operands differ only in their last two digits; lhs is the larger */
	lhs = try(std.bigparse("36028797018963964"))
	rhs = try(std.bigparse("36028797018963958"))
	testr.check(c, std.bigcmp(lhs, rhs) == `std.After, "{} should be > {}", lhs, rhs)
	std.bigfree(lhs)
	std.bigfree(rhs)
}

/*
 * Checks that the literal "0" survives a parse/format round trip.
 *
 * c is the test context that check results are reported to.
 */
const fmtzero = {c
	var zero : cmd#

	/* make sure we format '0' correctly */
	zero = std.mk(`Val "0")
	run(c, zero, "0")
}

/*
 * Smoke test for division: evaluates each dividend / divisor pair
 * through run() and compares against the expected quotient text.
 *
 * c is the test context that check results are reported to.
 */
const smokediv = {c
	var cases : (byte[:], byte[:], byte[:])[:] = [
		/* (dividend, divisor, expected quotient) */
		("1234_5678_1234_6789_6666_7777_8888", "1234_5678_1234_6789_6666_7777", "10000"),
		("0xffff_1234_1234_1234_1234", "0xf010_1234_2314", "4580035496"),
		("5192296858534810493479828944327220", "75557863709417659441940", "68719476751"),
		("75557863709417659441940", "5192296858534810493479828944327220", "0"),
	][:]

	for (num, den, want) : cases
		run(c, std.mk(`Div (std.mk(`Val num), std.mk(`Val den))), want)
	;;
}

/*
 * Smoke test for mod and modpow: one plain `Mod case followed by a
 * table of `Modpow cases, each evaluated through run() and compared
 * against the expected result text.
 *
 * c is the test context that check results are reported to.
 */
const smokemod = {c
	var dividend : cmd#, divisor : cmd#
	var powcases : (byte[:], byte[:], byte[:], byte[:])[:] = [
		/* (first, second, third `Modpow operand, expected result) */
		("1", "3", "2", "1"),
		("5192296858534810493479828944327220", "75557863709417659441940", "755578", "49054"),
		("2393", "2", "6737", "6736"),
		("6193257528475266832463188301662235", "6157075615645799356061575607567581", "12314151231291598712123151215135163", "1540381241336817586803754632242117"),
		("7220", "755578", "75557863709417659441940", "27076504425474791131220"),
	][:]

	/* plain modulo */
	dividend = std.mk(`Val "5192296858534810493479828944327220")
	divisor = std.mk(`Val "75557863709417659441940")
	run(c, std.mk(`Mod (dividend, divisor)), "257025710597479990280")

	/* modpow cases, in table order */
	for (p, q, r, want) : powcases
		run(c, std.mk(`Modpow (std.mk(`Val p), std.mk(`Val q), std.mk(`Val r))), want)
	;;
}

/*
 * Addition with negative operands: each row is evaluated as lhs + rhs
 * through run() and compared against the expected result text.
 *
 * c is the test context that check results are reported to.
 */
const addneg = {c
	var cases : (byte[:], byte[:], byte[:])[:] = [
		/* (lhs, rhs, expected sum) */
		("-1", "1", "0"),
		("4", "-2", "2"),
		("3", "-6", "-3"),
		("-4", "8", "4"),
		("-10", "5", "-5"),
		("-6", "-12", "-18"),
		("7", "-7", "0"),
	][:]

	for (lhs, rhs, want) : cases
		run(c, std.mk(`Add (std.mk(`Val lhs), std.mk(`Val rhs))), want)
	;;
}

/*
 * Subtraction with negative operands: each row is evaluated as
 * lhs - rhs through run() and compared against the expected result text.
 *
 * c is the test context that check results are reported to.
 */
const subneg = {c
	var cases : (byte[:], byte[:], byte[:])[:] = [
		/* (lhs, rhs, expected difference) */
		("-1", "1", "-2"),
		("1", "-1", "2"),
		("4", "-2", "6"),
		("3", "-6", "9"),
		("-4", "8", "-12"),
		("-10", "5", "-15"),
		("-6", "-12", "6"),
		("-14", "-7", "-7"),
		("-8", "-8", "0"),
	][:]

	for (lhs, rhs, want) : cases
		run(c, std.mk(`Sub (std.mk(`Val lhs), std.mk(`Val rhs))), want)
	;;
}

/*
 * Evaluates the expression tree e, formats the resulting bigint and
 * checks that the text equals res.
 *
 * c   - test context that the check result is reported to
 * e   - expression tree to evaluate
 * res - expected formatted result
 */
const run = {c : testr.ctx#, e : cmd#, res : byte[:]
	var buf : byte[4096]
	var v, n

	v = eval(e)
	n = std.bigbfmt(buf[:], v, 0)
	/*
	 * The digits now live in buf, so the bigint can be released here;
	 * the original never freed it.
	 */
	std.bigfree(v)
	testr.check(c, std.eq(buf[:n], res), "{} != {}", buf[:n], res)
}

/*
 * Recursively evaluates the expression tree e and returns the resulting
 * bigint, which the caller is expected to free.
 *
 * Binary nodes are delegated to binop() with the matching std.big*
 * function; `Val leaves are parsed with std.bigparse (a parse failure
 * is fatal via try()); `Modpow evaluates its three operands in order
 * and passes them to std.bigmodpow.
 */
const eval = {e : cmd#
	var a, b, c	/* scratch vars */
	var r		/* modpow result */

	match e#
	| `Add (x, y):	-> binop("+", std.bigadd, x, y)
	| `Sub (x, y):	-> binop("-", std.bigsub, x, y)
	| `Mul (x, y):	-> binop("*", std.bigmul, x, y)
	| `Div (x, y):	-> binop("/", std.bigdiv, x, y)
	| `Mod (x, y):	-> binop("%", std.bigmod, x, y)
	| `Shl (x, y):	-> binop("<<", std.bigshl, x, y)
	| `Shr (x, y):	-> binop(">>", std.bigshr, x, y)
	| `Val x:
		/* the parsed value is returned as-is; no scratch formatting */
		-> try(std.bigparse(x))
	| `Modpow (x, y, z):
		a = eval(x)
		b = eval(y)
		c = eval(z)
		r = std.bigmodpow(a, b, c)
		/*
		 * Release the second and third operands, which were previously
		 * leaked.
		 * NOTE(review): assumes std.bigmodpow returns its result via the
		 * first operand (or a fresh value), never via b or c -- confirm
		 * against libstd.
		 */
		std.bigfree(b)
		std.bigfree(c)
		-> r
	;;
}

/*
 * Evaluates both operand subtrees, applies op to the two results and
 * returns the left-hand value.
 *
 * name - operator symbol supplied by the caller; not used in the body
 * op   - function called as op(lhs, rhs); its return value is ignored
 * x, y - operand subtrees, evaluated left to right
 *
 * The right-hand value is freed here; the left-hand value is returned
 * to the caller.
 */
const binop = {name, op, x, y
	var lhs = eval(x)
	var rhs = eval(y)

	op(lhs, rhs)
	std.bigfree(rhs)
	-> lhs
}

/*
 * Unwraps an option: returns the payload of `std.Some, and terminates
 * the program through std.die on `std.None.
 */
generic try = {x : std.option(@a)
	match x
	| `std.None:	std.die("failed to get val")
	| `std.Some val:	-> val
	;;
}

/*
 * Table-driven test for std.bigdivmod: for each row, parses the
 * operands and the expected quotient and remainder, calls
 * std.bigdivmod and compares both results with std.bigeq.
 *
 * c is the test context that check results are reported to.
 */
const somebigdivmod = {c
	var inputs : (byte[:], byte[:], byte[:], byte[:])[:] = [
		/* ( a,  b,  a/b,  a%b),  */
		("6185103187", 	"3519152458", 	"1", 	"2665950729"),
		("6980766832", 	"3087864937", 	"2", 	"805036958"),
		("4991855479", 	"6447381549", 	"0", 	"4991855479"),
		("6119892968", 	"8374603717", 	"0", 	"6119892968"),
		("7442542736", 	"5989044918", 	"1", 	"1453497818"),
		("4580939187", 	"4661174670", 	"0", 	"4580939187"),
		("4935237814", 	"3140079933", 	"1", 	"1795157881"),
		("1010425087", 	"446887574", 	"2", 	"116649939"),
		("1892124804", 	"4011235814", 	"0", 	"1892124804"),
		("2457899196", 	"1885658848", 	"1", 	"572240348"),
		("262030672", 	"1329267171", 	"0", 	"262030672"),
		("8430016289", 	"6215302855", 	"1", 	"2214713434"),
		("3347587073", 	"3855445301", 	"0", 	"3347587073"),
		("8014059310", 	"4622364950", 	"1", 	"3391694360"),
		("2934470708", 	"6439081239", 	"0", 	"2934470708"),
		("420917856", 	"5020252044", 	"0", 	"420917856"),
		("3862040257", 	"4601227994", 	"0", 	"3862040257"),
		("4439858911", 	"4387180962", 	"1", 	"52677949"),
		("165706876", 	"7452574618", 	"0", 	"165706876"),
		("268686047", 	"3047271875", 	"0", 	"268686047"),
		("83484791", 	"4899367309", 	"0", 	"83484791"),
		("1242378916", 	"4909326080", 	"0", 	"1242378916"),
		("980437482", 	"2305509838", 	"0", 	"980437482"),
		("118982655", 	"5051748984", 	"0", 	"118982655"),
		("8521162809", 	"2888108066", 	"2", 	"2744946677"),
		("4600797553", 	"7579789288", 	"0", 	"4600797553"),
		("5533774720", 	"3320264831", 	"1", 	"2213509889"),
	][:]

	for (aS, bS, qS, rS) : inputs
		var a     = try(std.bigparse(aS))
		var b     = try(std.bigparse(bS))
		var q_exp = try(std.bigparse(qS))
		var r_exp = try(std.bigparse(rS))

		var q_act, r_act
		(q_act, r_act) = std.bigdivmod(a, b)
		testr.check(c, std.bigeq(q_exp, q_act), "expected {} / {} to be {}, was {}", a, b, q_exp, q_act)
		testr.check(c, std.bigeq(r_exp, r_act), "expected {} % {} to be {}, was {}", a, b, r_exp, r_act)

		/*
		 * Release everything created for this row; previously all six
		 * values were left allocated on every iteration. Freed only
		 * after the checks because the failure messages format them.
		 * NOTE(review): assumes std.bigdivmod returns values distinct
		 * from its arguments -- confirm against libstd.
		 */
		std.bigfree(a)
		std.bigfree(b)
		std.bigfree(q_exp)
		std.bigfree(r_exp)
		std.bigfree(q_act)
		std.bigfree(r_act)
	;;
}