shithub: mc

Download patch

ref: f9f93d1e447873ca3e5fa6c542eb34e8dd8d4b71
parent: 185f780a03fbfbb4655b7c07b3ac147980cede2d
author: Ori Bernstein <ori@eigenstate.org>
date: Sat Apr 7 20:59:24 EDT 2018

Constant time modpow.

--- a/lib/crypto/ct.myr
+++ b/lib/crypto/ct.myr
@@ -53,7 +53,7 @@
 generic ne = {a, b
 	const nshift = 8*sizeof(@t) - 1
 	var q = a ^ b
-	-> ((q | -q) >> nshift)^1
+	-> (q | -q) >> nshift	/* top bit of q | -q is set iff q != 0, so this yields 1 iff a != b (for unsigned @t) */
 }
 
 generic mux = {c, a, b
--- a/lib/crypto/ctbig.myr
+++ b/lib/crypto/ctbig.myr
@@ -1,4 +1,5 @@
 use std
+use iter
 
 use "ct"
 
@@ -25,7 +26,7 @@
 	const ctsub	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
 	const ctmul	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
 	//const ctdivmod	: (q : ctbig#, u : ctbig#, a : ctbig#, b : ctbig# -> void)
-	//const ctmodpow	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
+	const ctmodpow	: (r : ctbig#, a : ctbig#, b : ctbig#, m : ctbig# -> void)
 
 	const ctiszero	: (v : ctbig# -> bool)
 	const cteq	: (a : ctbig#, b : ctbig# -> bool)
@@ -35,6 +36,9 @@
 	const ctlt	: (a : ctbig#, b : ctbig# -> bool)
 	const ctle	: (a : ctbig#, b : ctbig# -> bool)
 
+	/* for testing */
+	const growmod	: (r : ctbig#, a : ctbig#, k : uint32, m : ctbig# -> void)
+
 	impl std.equatable ctbig#
 ;;
 
@@ -59,8 +63,8 @@
 	var ct : ctbig#
 
 	ct = std.vanext(ap)
-	for d : ct.dig
-		std.sbfmt(sb, "{w=8,p=0,x}", d)
+	for d : iter.byreverse(ct.dig)
+		std.sbfmt(sb, "{w=8,p=0,x}.", d)
 	;;
 }
 
@@ -89,6 +93,13 @@
 	])
 }
 
+/* Allocates a copy of v that owns its own digit slice (release with ctfree). */
+const ctdup = {v
+	var dig
+	dig = std.sldup(v.dig)
+	-> std.mk([.nbit=v.nbit, .dig=dig])
+}
+
 const ct2big = {ct
 	-> std.mk([
 		.sign=1,
@@ -155,6 +166,10 @@
 }
 
 const ctadd = {r, a, b
+	ctaddcc(r, a, b, 1)	/* ctl = 1: always store the sum a + b into r */
+}
+
+const ctaddcc = {r, a, b, ctl
 	var v, i, carry
 
 	checksz(a, b)
@@ -163,12 +178,16 @@
 	carry = 0
 	for i = 0; i < a.dig.len; i++
 		v = (a.dig[i] : uint64) + (b.dig[i] : uint64) + carry;
-		r.dig[i] = (v  : uint32)
+		r.dig[i] = mux(ctl, (v  : uint32), r.dig[i])
 		carry = v >> 32
 	;;
 }
 
 const ctsub = {r, a, b
+	ctsubcc(r, a, b, 1)	/* ctl = 1: always store the difference a - b into r */
+}
+
+const ctsubcc = {r, a, b, ctl
 	var borrow, v, i
 
 	checksz(a, b)
@@ -178,10 +197,10 @@
 	for i = 0; i < a.dig.len; i++
 		v = (a.dig[i] : uint64) - (b.dig[i] : uint64) - borrow
 		borrow = (v & (1<<63)) >> 63
-		v = mux(borrow, v + Base, v)
-		r.dig[i] = (v  : uint32)
+		r.dig[i] = mux(ctl, (v  : uint32), r.dig[i])
 	;;
 	clip(r)
+	-> borrow
 }
 
 const ctmul = {r, a, b
@@ -215,6 +234,186 @@
 	clip(r)
 }
 
+/*
+ * Returns the index of the most significant nonzero digit of n
+ * (0 if n is zero), visiting every digit regardless of its value.
+ */
+const topfull = {n : ctbig#
+	var top
+
+	top = 0
+	for var i = 0; i < n.dig.len; i++
+		top = mux(ne(n.dig[i], 0), i, top)
+	;;
+	-> top
+}
+
+/*
+ * Computes r = (a * 2**32 + k) mod m; expects a < m and a "full" m (high bit of its top digit set).
+ */
+const growmod = {r, a, k, m
+	var a0, a1, b0, hi, g, q, tb, e
+	var chf, clow, under, over
+	var cc : uint64
+
+	checksz(a, m)
+	std.assert(a.dig.len > 1, "bad modulus")
+	std.assert(a.nbit % 32 == 0, "ragged sizes not yet supported")
+	//std.assert(m.dig[m.dig.len - 1] & (1 << 31) != 0, "top of mod not set")
+
+	a0 = (a.dig[m.dig.len - 1] : uint64)
+	a1 = (a.dig[m.dig.len - 2] : uint64)
+	b0 = (m.dig[m.dig.len - 1] : uint64)
+
+	/*
+	 * Save the top digit before shifting: it falls off the end of r,
+	 * which keeps the digit count the same (and checksz() happy),
+	 * but it is still needed to correct the quotient estimate below.
+	 */
+	hi = a.dig[a.dig.len - 1]
+
+	/* Shift a up one digit (multiply by 2**32) and add k. */
+	std.slcp(r.dig[1:], a.dig[:a.dig.len-1])
+	r.dig[0] = k
+	g = (((a0 << 32) + a1) / b0 : uint32)
+	e = eq(a0, b0)	/* unshifted top digits equal: the 64/32 quotient would not fit in 32 bits */
+	q = mux((e : uint32), 0xffffffff, mux(eq(g, 0), 0, g - 1));
+	/* NOTE(review): the '/' above is a hardware divide; confirm it is constant time on target CPUs. */
+	cc = 0;
+	tb = 1;
+	for var u = 0; u < r.dig.len; u++
+		var mw, zw, xw, nxw
+		var zl : uint64
+
+		mw = m.dig[u];
+		zl = (mw : uint64) * (q : uint64) + cc
+		cc = zl >> 32
+		zw = (zl : uint32)
+		xw = r.dig[u]
+		nxw = xw - zw;
+		cc += (gt(nxw, xw) : uint64)
+		r.dig[u] = nxw;
+		tb = mux(eq(nxw, mw), tb, gt(nxw, mw));
+	;;
+
+	/*
+	 * q may be off by one in either direction:
+	 *  - If we underestimated, either cc < hi, or cc == hi && tb != 0.
+	 *  - If we overestimated, cc > hi.
+	 *  - Otherwise, we got it exactly right.
+	 *
+	 * If we overestimated, we need to add 'm' back once. If we
+	 * underestimated, we need to subtract it once more.
+	 */
+	chf = (cc >> 32 : uint32)
+	clow = (cc >> 0 : uint32)
+	over = chf | gt(clow, hi);
+	under = ~over & (tb | (~chf & lt(clow, hi)));
+	ctaddcc(r, r, m, over);
+	ctsubcc(r, r, m, under);
+
+}
+
+const tomonty = {r, x, m
+	checksz(x, r)
+	checksz(x, m)
+	/* r = x * 2**(32 * m.dig.len) mod m (Montgomery form): one growmod digit-shift per digit of m */
+	std.slcp(r.dig, x.dig)
+	for var i = 0; i < m.dig.len; i++
+		growmod(r, r, 0, m)
+	;;
+}
+
+const ccopy = {r, v, ctl	/* r.dig[i] = mux(ctl, v.dig[i], r.dig[i]): copies v into r when ctl is 1, with no branch on ctl */
+	checksz(r, v)
+	for var i = 0; i < r.dig.len; i++
+		r.dig[i] = mux(ctl, v.dig[i], r.dig[i])
+	;;
+}
+
+const muladd = {a, b, k	/* a * b + k computed in uint64; cannot overflow when a, b, k are 32-bit */
+	-> (a : uint64) * (b : uint64) + (k : uint64)
+}
+
+const montymul = {r : ctbig#, x : ctbig#, y : ctbig#, m : ctbig#, m0i : uint32
+	var dh : uint64
+	var s
+	/* Montgomery product r = x * y * 2**(-32 * len) mod m, assuming x, y < m and m0i = ninv32(m.dig[0]); r is zeroed first, so it must not alias x, y or m */
+	checksz(x, y)
+	checksz(x, m)
+	checksz(x, r)
+
+	std.slfill(r.dig, 0)
+	dh = 0
+	for var u = 0; u < x.dig.len; u++
+		var f : uint32, xu : uint32
+		var r1 : uint64, r2 : uint64, zh : uint64
+
+		xu = x.dig[u]
+		f = (r.dig[0] + x.dig[u] * y.dig[0]) * m0i;	/* chosen so the low digit of r + xu*y + f*m is zero mod 2**32 */
+		r1 = 0;
+		r2 = 0;
+		for var v = 0; v < y.dig.len; v++
+			var z : uint64
+			var t : uint32
+
+			z = muladd(xu, y.dig[v], r.dig[v]) + r1
+			r1 = z >> 32
+			t = (z : uint32)
+			z = muladd(f, m.dig[v], t) + r2
+			r2 = z >> 32
+			if v != 0	/* digit 0 is zero by choice of f; storing at v - 1 divides by 2**32 */
+				r.dig[v - 1] = (z : uint32)
+			;;
+		;;
+		zh = dh + r1 + r2;
+		r.dig[r.dig.len - 1] = (zh : uint32)
+		dh = zh >> 32;
+	;;
+
+	/*
+	 * r may still be greater than m at that point; notably, the
+	 * 'dh' word may be non-zero.
+	 */
+	s = ne(dh, 0) | (ctge(r, m) : uint64)
+	ctsubcc(r, r, m, (s : uint32))
+}
+
+const ninv32 = {x
+	var y
+	/* Newton steps y *= 2 - y * x refine y toward 1/x mod 2**32 (x odd); returns -1/x mod 2**32, or 0 when x is even */
+	y = 2 - x
+	y *= 2 - y * x
+	y *= 2 - y * x
+	y *= 2 - y * x
+	y *= 2 - y * x
+	-> mux(x & 1, -y, 0)
+}
+
+const ctmodpow = {r, a, e, m
+	var t1, t2, m0i, ctl, k, d
+	/* Computes r = a**e mod m; each exponent bit is applied via ccopy, never a branch */
+
+	t1 = ctdup(a)
+	t2 = ctzero(a.nbit)
+	m0i = ninv32(m.dig[0])
+	/* t1 holds a**(2**k) in Montgomery form; r starts at 1 and stays in normal form */
+	tomonty(t1, a, m);
+	std.slfill(r.dig, 0);
+	r.dig[0] = 1;
+	for var i = 0; i < e.nbit; i++
+		k = (i : uint32)
+		d = e.dig[k >> 5]	/* digits are little-endian: dig[0] holds bits 0..31 */
+		ctl = (d >> (k & 0x1f)) & 1
+		montymul(t2, r, t1, m, m0i)
+		ccopy(r, t2, ctl);
+		montymul(t2, t1, t1, m, m0i);
+		std.slcp(t1.dig, t2.dig);
+	;;
+	ctfree(t1)
+	ctfree(t2)
+}
+
 const ctiszero = {a
 	var z, zz
 
@@ -227,18 +426,14 @@
 }
 
 const cteq = {a, b
-	var z, d, e
+	var diff	/* OR of per-digit differences; zero iff a == b */
 
 	checksz(a, b)
-
-	e = 1
+	diff = 0
 	for var i = 0; i < a.dig.len; i++
-		z = a.dig[i] - b.dig[i]
-		/* z != 0 ? 0 : 1 */
-		d = mux(z, 0, 1)
-		e = mux(e, d, 0)
+		diff |= a.dig[i] ^ b.dig[i]
 	;;
-	-> (e : bool)
+	-> (not(diff) : bool)
 }
 
 const ctne = {a, b
@@ -249,17 +444,7 @@
 }
 
 const ctgt = {a, b
-	var e, d, g
-
-	checksz(a, b)
-
-	g = 0
-	for var i = 0; i < a.dig.len; i++
-		e = not(a.dig[i] - b.dig[i])
-		d = gt(a.dig[i], b.dig[i])
-		g = mux(e, g, d) 
-	;;
-	-> (g : bool)
+	-> (ctsubcc(b, b, a, 0) : bool)	/* a > b iff b - a borrows; ctl = 0 discards the difference (NOTE(review): clip(b) still runs) */
 }
 
 const ctge = {a, b
@@ -270,17 +455,7 @@
 }
 
 const ctlt = {a, b
-	var e, d, l
-
-	checksz(a, b)
-
-	l = 0
-	for var i = 0; i < a.dig.len; i++
-		e = not(a.dig[i] - b.dig[i])
-		d = gt(a.dig[i], b.dig[i])
-		l = mux(e, l, d) 
-	;;
-	-> (l : bool)
+	-> (ctsubcc(a, a, b, 0) : bool)	/* a < b iff a - b borrows; ctl = 0 discards the difference (NOTE(review): clip(a) still runs) */
 }
 
 const ctle = {a, b
--- a/lib/crypto/test/ctbig.myr
+++ b/lib/crypto/test/ctbig.myr
@@ -9,60 +9,118 @@
 	testr.run([
 		/* normal */
 		[.name="add", .fn={ctx
-			do(ctx, crypto.ctadd, Nbit,
+			do2(ctx, crypto.ctadd, Nbit,
 				"5192296858610368357189246603769160",
 				"5192296858534810493479828944327220", 
 				"75557863709417659441940")
 		}],
 		[.name="sub", .fn={ctx
-			do(ctx, crypto.ctsub, Nbit,
+			do2(ctx, crypto.ctsub, Nbit,
 				"5192296858459252629770411284885280",
 				"5192296858534810493479828944327220", 
 				"75557863709417659441940")
 		}],
 		[.name="mul", .fn={ctx
-			do(ctx, crypto.ctmul, Nbit,
+			do2(ctx, crypto.ctmul, Nbit,
 				"392318858376010676506814412592879878824393346033951606800",
 				"5192296858534810493479828944327220", 
 				"75557863709417659441940")
 		}],
-		
+		[.name="growmod", .fn={ctx
+			do2(ctx, growmod0, Nbit,
+				"259016584597313952181375284077740334036",
+				"137304361882109849168381018424069802644",
+				"279268927326277818181333274586733399084")
+			}
+		],
+		/* comparisons */
+		[.name="lt-less", .fn={ctx
+			dobool(ctx, crypto.ctlt, Nbit,
+				true,
+				"137304361882109849168381018424069802644",
+				"279268927326277818181333274586733399084")
+			}
+		],
+		[.name="lt-equal", .fn={ctx
+			dobool(ctx, crypto.ctlt, Nbit,
+				false,
+				"137304361882109849168381018424069802644",
+				"137304361882109849168381018424069802644")
+			}
+		],
+		[.name="lt-greater", .fn={ctx
+			dobool(ctx, crypto.ctlt, Nbit,
+				false,
+				"279268927326277818181333274586733399084",
+				"137304361882109849168381018424069802644")
+			}
+		],
+		[.name="gt-less", .fn={ctx
+			dobool(ctx, crypto.ctgt, Nbit,
+				false,
+				"137304361882109849168381018424069802644",
+				"279268927326277818181333274586733399084")
+			}
+		],
+		[.name="gt-equal", .fn={ctx
+			dobool(ctx, crypto.ctgt, Nbit,
+				false,
+				"137304361882109849168381018424069802644",
+				"137304361882109849168381018424069802644")
+			}
+		],
+		[.name="gt-greater", .fn={ctx
+			dobool(ctx, crypto.ctgt, Nbit,
+				true,
+				"279268927326277818181333274586733399084",
+				"137304361882109849168381018424069802644")
+			}
+		],
+
+		[.name="growmodsmall", .fn={ctx
+			do2(ctx, growmod0, Nbit,
+				"30064771072",
+				"7",
+				"279268927326277818181333274586733399084")
+			}
+		],
 		[.name="addfunky", .fn={ctx
-			do(ctx, crypto.ctadd, Nfunky,
+			do2(ctx, crypto.ctadd, Nfunky,
 				"75540728658750274549064",
 				"5192296858534810493479828944327220", 
 				"75557863709417659441940")
 		}],
 		[.name="subfunky", .fn={ctx
-			do(ctx, crypto.ctsub, Nfunky,
+			do2(ctx, crypto.ctsub, Nfunky,
 				"528887911047229543018272",
 				"5192296858534810493479828944327220", 
 				"75557863709417659441940")
 		}],
 		[.name="mulfunky", .fn={ctx
-			do(ctx, crypto.ctmul, Nfunky,
+			do2(ctx, crypto.ctmul, Nfunky,
 				"434472066238453871708176",
 				"5192296858534810493479828944327220", 
 				"75557863709417659441940")
 		}],
 		//[.name="div", .fn={ctx
-		//	do(ctx, div,
+		//	do2(ctx, div,
 		//		"75557863709417659441940",
 		//		"392318858376010676506814412592879878824393346033951606800",
 		//		"5192296858534810493479828944327220")
 		//}],
 		//[.name="mod", .fn={ctx
-		//	do(ctx, mod,
+		//	do2(ctx, mod,
 		//		"75557863709417659441940",
 		//		"392318858376010676506814412592879878824393346033951606800",
 		//		"5192296858534810493479828944327220")
 		//}],
-		//[.name="modpow", .fn={ctx
-		//	r = do(ctx, crypto.ctsub,
-		//		"5192296858459252629770411284885280"
-		//		"5192296858534810493479828944327220", 
-		//		"75557863709417659441940")
-		//}],
+		[.name="modpow", .fn={ctx
+			do3(ctx, crypto.ctmodpow, Nbit,
+				"1231231254019581241243091223098123",
+				"1231231254019581241243091223098123",
+				"1",
+				"238513807008428752753137056878245001837")
+		}],
 
 	][:])
 }
@@ -80,8 +138,29 @@
 //	z = crypto.ctzero(a.nbit)
 //	crypto.ctdivmod(z, r, a, b)
 //}
-//
-const do = {ctx, op, nbit, estr, astr, bstr
+
+const growmod0 = {r, a, b	/* adapts growmod to the (r, a, b) operator shape used with do2: k = 0, b is the modulus */
+	crypto.growmod(r, a, 0, b)
+}
+
+const dobool : (ctx : testr.ctx#, op : (a : crypto.ctbig#, b : crypto.ctbig# -> bool), nbit : std.size, e : bool, astr : byte[:], bstr : byte[:] -> void) = {ctx, op, nbit, e, astr, bstr
+	var a, ai, b, bi
+
+	/* Parse both operands into nbit-wide ctbigs, then check that op(a, b) == e. */
+	ai = std.get(std.bigparse(astr))
+	bi = std.get(std.bigparse(bstr))
+	a = crypto.big2ct(ai, nbit)
+	b = crypto.big2ct(bi, nbit)
+
+	std.bigfree(ai)
+	std.bigfree(bi)
+	testr.eq(ctx, op(a, b), e)
+
+	crypto.ctfree(a)
+	crypto.ctfree(b)
+}
+
+const do2 = {ctx, op, nbit, estr, astr, bstr
 	var r, a, ai, b, bi, e, ei
 
 	r = crypto.ctzero(nbit)
@@ -104,6 +183,35 @@
 	crypto.ctfree(e)
 	crypto.ctfree(a)
 	crypto.ctfree(b)
+}
+
+
+const do3 = {ctx, op, nbit, estr, astr, bstr, cstr
+	var r, a, ai, b, bi, c, ci, e, ei
+	/* Run op(r, a, b, c) on nbit-wide ctbigs parsed from the strings and check r == estr. */
+	r = crypto.ctzero(nbit)
+	ei = std.get(std.bigparse(estr))
+	ai = std.get(std.bigparse(astr))
+	bi = std.get(std.bigparse(bstr))
+	ci = std.get(std.bigparse(cstr))
+	e = crypto.big2ct(ei, nbit)
+	a = crypto.big2ct(ai, nbit)
+	b = crypto.big2ct(bi, nbit)
+	c = crypto.big2ct(ci, nbit)
+
+	std.bigfree(ei)
+	std.bigfree(ai)
+	std.bigfree(bi)
+	std.bigfree(ci)
+
+	op(r, a, b, c)
+	testr.eq(ctx, r, e)
+
+	crypto.ctfree(r)
+	crypto.ctfree(e)
+	crypto.ctfree(a)
+	crypto.ctfree(b)
+	crypto.ctfree(c)
 }
 
 
--- a/lib/std/hashfuncs.myr
+++ b/lib/std/hashfuncs.myr
@@ -18,6 +18,12 @@
 		}
 	;;
 
+	impl equatable bool =	/* bools compare with the builtin == */
+		eq = {a, b
+			-> a == b
+		}
+	;;
+
 	impl equatable @a :: integral,numeric @a =
 		eq = {a, b
 			-> a == b