shithub: mc

Download patch

ref: 473ce973528106273540303dc11f4726ee870c3c
parent: 731dd8851a19c11f91d01bdc8c469c77d653a4ea
author: Ori Bernstein <ori@eigenstate.org>
date: Sat Mar 24 21:53:12 EDT 2018

Add constant time bigint ops.

--- a/lib/crypto/ctbig.myr
+++ b/lib/crypto/ctbig.myr
@@ -20,11 +20,12 @@
 	const ct2big	: (v : ctbig# -> std.bigint#)
 	const big2ct	: (v : std.bigint#, nbit : std.size -> ctbig#)
 
+	/* arithmetic */
 	const ctadd	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
 	const ctsub	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
 	const ctmul	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
-	//const ctdivmod	: (r : ctbig#, m : ctbig#, a : ctbig#, b : ctbig# -> void)
-	const ctmodpow	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
+	//const ctdivmod	: (q : ctbig#, u : ctbig#, a : ctbig#, b : ctbig# -> void)
+	//const ctmodpow	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
 
 	const ctiszero	: (v : ctbig# -> bool)
 	const cteq	: (a : ctbig#, b : ctbig# -> bool)
@@ -33,11 +34,36 @@
 	const ctge	: (a : ctbig#, b : ctbig# -> bool)
 	const ctlt	: (a : ctbig#, b : ctbig# -> bool)
 	const ctle	: (a : ctbig#, b : ctbig# -> bool)
+
+	impl std.equatable ctbig#
 ;;
 
 const Bits = 32
 const Base = 0x100000000ul
 
+impl std.equatable ctbig# =	/* equality is delegated to the cteq comparison */
+	eq = {x, y
+		-> cteq(x, y)
+	}
+;;
+
+const __init__ = {	/* install ctfmt as the std.fmt formatter for ctbig# */
+	var proto : ctbig#
+
+	proto = ctzero(0)	/* throwaway value: only its type is used below */
+	std.fmtinstall(std.typeof(proto), ctfmt)
+	ctfree(proto)
+}
+
+const ctfmt = {sb, ap, opts	/* format a ctbig# as zero-padded hex, most-significant limb first */
+	var ct : ctbig#
+
+	ct = std.vanext(ap)
+	for var i = ct.dig.len; i > 0; i--	/* dig[0] is the low limb (carries run upward in ctadd) */
+		std.sbfmt(sb, "{w=8,p=0,x}", ct.dig[i - 1])
+	;;
+}
+
 generic mkctbign = {v : @a, nbit : std.size :: integral,numeric @a
 	var a
 	var val
@@ -59,7 +85,7 @@
 const ctzero = {nbit
 	-> std.mk([
 		.nbit=nbit,
-		.dig=std.slalloc(ndig(nbit)),
+		.dig=std.slzalloc(ndig(nbit)),	/* zero-filled limbs, so the returned value is 0 */
 	])
 }
 
@@ -70,13 +96,13 @@
 	])
 }
 
-const big2ct = {ct, nbit
+const big2ct = {big, nbit
 	var v, n, l
 
 	n = ndig(nbit)
-	l = std.min(n, ct.dig.len)
+	l = std.min(n, big.dig.len)
 	v = std.slzalloc(n)
-	std.slcp(v, ct.dig[:l])
+	std.slcp(v[:l], big.dig[:l])
 	-> clip(std.mk([
 		.nbit=nbit,
 		.dig=v,
@@ -129,17 +155,16 @@
 }
 
 const ctadd = {r, a, b
-	var v, i, carry, n
+	var v, i, carry	/* NOTE: the carry out of the top limb is dropped (result wraps) */
 
 	checksz(a, b)
 	checksz(a, r)
 
 	carry = 0
-	n = max(a.dig.len, b.dig.len)
-	for i = 0; i < n; i++
+	for i = 0; i < a.dig.len; i++	/* one pass over every limb; checksz above presumably enforces equal sizes */
 		v = (a.dig[i] : uint64) + (b.dig[i] : uint64) + carry;
 		r.dig[i] = (v  : uint32)
-		carry >>= 32
+		carry = v >> 32	/* high word of the 64-bit limb sum carries into the next limb */
 	;;
 }
 
@@ -190,16 +215,6 @@
 	clip(r)
 }
 
-const ctmodpow = {res, a, b
-	/* find rinv, mprime */
-	
-	/* convert to monty space */
-
-	/* do the modpow */
-
-	/* and come back */
-}
-
 const ctiszero = {a
 	var z, zz
 
@@ -219,7 +234,8 @@
 	e = 1
 	for var i = 0; i < a.dig.len; i++
 		z = a.dig[i] - b.dig[i]
-		d = mux(z, 1, 0)
+		/* z != 0 ? 0 : 1 */
+		d = mux(z, 0, 1)
 		e = mux(e, d, 0)
 	;;
 	-> (e : bool)
@@ -275,7 +291,7 @@
 }
 
 const ndig = {nbit
-	-> (nbit + 8*sizeof(uint32) - 1)/sizeof(uint32)
+	-> (nbit + 8*sizeof(uint32) - 1)/(8*sizeof(uint32))	/* ceil(nbit / 32): uint32 limbs needed for nbit bits */
 }
 
 const checksz = {a, b
@@ -284,11 +300,41 @@
 }
 
 const clip = {v
-	var mask, edge
-	
-
-	edge = v.nbit & (Bits - 1)
-	mask = (1 << edge) - 1
-	v.dig[v.dig.len - 1] &= (mask : uint32)
+	var edge = v.nbit & (Bits - 1)	/* bits used in the top limb; 0 means the top limb is full */
+	/* branch is on the declared width, not the value; it skips the mask == 0 case that wiped full top limbs */
+	if edge != 0
+		v.dig[v.dig.len - 1] &= (((1 << edge) - 1) : uint32)
+	;;
 	-> v
 }
+
+const nlz = {a : uint32	/* number of leading zero bits in a; 32 when a == 0 */
+	var n	/* NOTE(review): branches on a, so not constant time; confirm callers only pass public values */
+
+	if a == 0
+		-> 32
+	;;
+	n = 0
+	if a <= 0x0000ffff	/* binary search: if the top 16 bits are clear, count them and shift them out */
+		n += 16
+		a <<= 16
+	;;
+	if a <= 0x00ffffff	/* then the same test on the top 8, 4, 2 and 1 bits in turn */
+		n += 8
+		a <<= 8
+	;;
+	if a <= 0x0fffffff
+		n += 4
+		a <<= 4
+	;;
+	if a <= 0x3fffffff
+		n += 2
+		a <<= 2
+	;;
+	if a <= 0x7fffffff
+		n += 1
+		a <<= 1	/* dead store: a is not read again */
+	;;
+	-> n
+}
+
--- /dev/null
+++ b/lib/crypto/test/ctbig.myr
@@ -1,0 +1,88 @@
+use std
+use crypto
+use testr
+
+const Nbit = 256	/* wide enough for the ~188-bit product used by the mul/div/mod cases */
+
+const main = {
+	testr.run([
+		[.name="add", .fn={ctx
+			do(ctx, crypto.ctadd,
+				"5192296858610368357189246603769160",
+				"5192296858534810493479828944327220",
+				"75557863709417659441940")
+		}],
+		[.name="sub", .fn={ctx
+			do(ctx, crypto.ctsub,
+				"5192296858459252629770411284885280",
+				"5192296858534810493479828944327220",
+				"75557863709417659441940")
+		}],
+		[.name="mul", .fn={ctx
+			do(ctx, crypto.ctmul,
+				"392318858376010676506814412592879878824393346033951606800",
+				"5192296858534810493479828944327220",
+				"75557863709417659441940")
+		}],
+		[.name="div", .fn={ctx	/* NOTE(review): needs crypto.ctdivmod, still commented out of the ctbig exports */
+			do(ctx, div,
+				"75557863709417659441940",
+				"392318858376010676506814412592879878824393346033951606800",
+				"5192296858534810493479828944327220")
+		}],
+		[.name="mod", .fn={ctx	/* dividend is an exact multiple of the divisor (see mul), so the remainder is 0 */
+			do(ctx, mod,
+				"0",
+				"392318858376010676506814412592879878824393346033951606800",
+				"5192296858534810493479828944327220")
+		}],
+		//[.name="modpow", .fn={ctx
+		//	r = do(ctx, crypto.ctsub,
+		//		"5192296858459252629770411284885280",
+		//		"5192296858534810493479828944327220",
+		//		"75557863709417659441940")
+		//}],
+
+	][:])
+}
+
+const div = {r, a, b	/* r = a / b; ctdivmod's remainder output goes to a scratch value */
+	var z = crypto.ctzero(a.nbit)
+
+	crypto.ctdivmod(r, z, a, b)
+	crypto.ctfree(z)	/* release the scratch remainder */
+}
+
+const mod = {r, a, b	/* r = a mod b; ctdivmod's quotient output goes to a scratch value */
+	var z = crypto.ctzero(a.nbit)
+
+	crypto.ctdivmod(z, r, a, b)
+	crypto.ctfree(z)	/* release the scratch quotient */
+}
+
+const do = {ctx, op, estr, astr, bstr	/* run op(r, a, b) on parsed decimal operands and check r against estr */
+	var res, want, lhs, rhs, wantbig, lhsbig, rhsbig
+
+	res = crypto.ctzero(Nbit)
+	wantbig = std.get(std.bigparse(estr))
+	lhsbig = std.get(std.bigparse(astr))
+	rhsbig = std.get(std.bigparse(bstr))
+	want = crypto.big2ct(wantbig, Nbit)
+	lhs = crypto.big2ct(lhsbig, Nbit)
+	rhs = crypto.big2ct(rhsbig, Nbit)
+
+	std.bigfree(wantbig)
+	std.bigfree(lhsbig)
+	std.bigfree(rhsbig)
+
+	op(res, lhs, rhs)
+
+	testr.eq(ctx, res, want)
+
+	crypto.ctfree(res)
+	crypto.ctfree(want)
+	crypto.ctfree(lhs)
+	crypto.ctfree(rhs)
+}
+
+
--- a/lib/std/bigint.myr
+++ b/lib/std/bigint.myr
@@ -741,7 +741,6 @@
 	;;
 	/* undo the biasing for remainder */
 	bigshri(u, shift)
-	trim(q)
 	bigfree(v)
 	-> (trim(q), trim(u))
 }