ref: cbf8e144771b88703180db26d74f1c65e45d84fe
dir: /lib/crypto/ctbig.myr/
use std
use iter

use "ct"

/*
 * Fixed-width big integers whose arithmetic is written in a branch-free
 * style on top of the helpers from "ct" (mux, eq, gt, ...). Every value
 * carries an explicit bit width; the width (and therefore the digit
 * count) is treated as public, the digit contents are not.
 */
pkg crypto =
	type ctbig = struct
		nbit	: std.size
		dig	: uint32[:] 	/* little endian; always ndig(nbit) digits, zero padded at the top */
	;;

	generic mkctbign	: (v : @a, nbit : std.size -> ctbig#) :: numeric,integral @a
	const ctzero	: (nbit : std.size -> ctbig#)
	const ctbytesle	: (v : ctbig# -> byte[:])
	const ctbytesbe	: (v : ctbig# -> byte[:])
	const mkctbigle	: (v : byte[:], nbit : std.size -> ctbig#)
	const mkctbigbe	: (v : byte[:], nbit : std.size -> ctbig#)

	const ctfree	: (v : ctbig# -> void)
	const ctbigdup	: (v : ctbig# -> ctbig#)
	pkglocal const ct2big	: (v : ctbig# -> std.bigint#)
	pkglocal const big2ct	: (v : std.bigint#, nbit : std.size -> ctbig#)

	/* arithmetic */
	pkglocal const ctadd	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	pkglocal const ctsub	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	pkglocal const ctmul	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	pkglocal const ctmodpow	: (r : ctbig#, a : ctbig#, b : ctbig#, m : ctbig# -> void)

	pkglocal const ctiszero	: (v : ctbig# -> bool)
	pkglocal const cteq	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctne	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctgt	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctge	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctlt	: (a : ctbig#, b : ctbig# -> bool)
	pkglocal const ctle	: (a : ctbig#, b : ctbig# -> bool)

	/* for testing */
	pkglocal const growmod	: (r : ctbig#, a : ctbig#, k : uint32, m : ctbig# -> void)
	pkglocal const clip	: (v : ctbig# -> ctbig#)

	impl std.equatable ctbig#
;;

/* number of bits held by one digit */
const Bits = 32
/* one past the largest digit value, i.e. 2**Bits */
const Base = 0x100000000ul

/* equality for the std.equatable trait is just cteq */
impl std.equatable ctbig# =
	eq = {a, b
		-> cteq(a, b)
	}
;;

/*
 * Registers ctfmt as the formatter for ctbig#. A throwaway zero-width
 * value is created only so that std.typeof has something to inspect.
 */
const __init__ = {
	var ct : ctbig#

	ct = ctzero(0)
	std.fmtinstall(std.typeof(ct), ctfmt)
	ctfree(ct)
}

/*
 * Formatter callback: prints the digits from most significant to least
 * significant, each as 8 zero-padded hex characters followed by '.'.
 */
const ctfmt = {sb, ap, opts
	var ct : ctbig#

	ct = std.vanext(ap)
	for d : iter.byreverse(ct.dig)
		std.sbfmt(sb, "{w=8,p=0,x}.", d)
	;;
}

/*
 * Builds an nbit-wide ctbig from a native integer. The value is widened
 * to uint64 first, so at most the low two digits are populated; the rest
 * are zero. Bits of v above nbit are dropped by clip().
 */
generic mkctbign = {v : @a, nbit : std.size :: integral,numeric @a
	var a
	var val

	a = std.zalloc()
	val = (v : uint64)

	a.nbit = nbit
	/* zeroed: digits beyond the second must not hold allocator garbage */
	a.dig = std.slzalloc(ndig(nbit))
	if nbit > 0
		a.dig[0] = (val : uint32)
	;;
	if nbit > 32
		a.dig[1] = (val >> 32 : uint32)
	;;
	-> clip(a)
}

/* Allocates an nbit-wide ctbig with all digits zero. Free with ctfree. */
const ctzero = {nbit
	-> std.mk([
		.nbit=nbit,
		.dig=std.slzalloc(ndig(nbit)),
	])
}

/* Internal alias of ctbigdup, kept so existing call sites still work. */
const ctdup = {v
	-> ctbigdup(v)
}

/*
 * Copies the digits of a ctbig into a freshly allocated std.bigint with
 * sign set to 1.
 * NOTE(review): the digits are copied verbatim, including zero padding
 * at the top; confirm std.bigint consumers accept that.
 */
const ct2big = {ct
	-> std.mk([
		.sign=1,
		.dig=std.sldup(ct.dig)
	])
}

/*
 * Converts a std.bigint into an nbit-wide ctbig. Digits that do not fit
 * in ndig(nbit) are discarded, missing digits are zero, and the top
 * digit is masked by clip(). The sign of the bigint is not consulted.
 */
const big2ct = {big, nbit
	var v, n, l

	n = ndig(nbit)
	l = std.min(n, big.dig.len)
	v = std.slzalloc(n)
	std.slcp(v[:l], big.dig[:l])
	-> clip(std.mk([
		.nbit=nbit,
		.dig=v,
	]))
}

/*
 * Builds an nbit-wide ctbig from little endian bytes. Whole groups of
 * four bytes become one digit each; a trailing group of 1-3 bytes fills
 * the low bytes of one more digit.
 */
const mkctbigle = {v, nbit
	var a, last, i, o, off

	/*
	  It's ok to depend on the length of v here: we can leak the
	  size of the numbers.
	 */
	o = 0
	a = std.slzalloc(ndig(nbit))
	for i = 0; i + 4 <= v.len; i += 4
		a[o++] = \
			((v[i + 0] : uint32) <<  0) | \
			((v[i + 1] : uint32) <<  8) | \
			((v[i + 2] : uint32) << 16) | \
			((v[i + 3] : uint32) << 24)
	;;

	if i != v.len
		last = 0
		for i; i < v.len; i++
			off = i & 0x3
			last |= (v[i] : uint32) << (8 *off)
		;;
		a[o++] = last
	;;
	-> clip(std.mk([.nbit=nbit, .dig=a]))
}

/*
 * Builds an nbit-wide ctbig from big endian bytes. The input is consumed
 * from its end (least significant bytes) in groups of four; a leading
 * group of 1-3 bytes is right-aligned in a zeroed scratch word first.
 */
const mkctbigbe = {v, nbit
	var a, i, o, tail : byte[4]

	/*
	  It's ok to depend on the length of v here: we can leak the
	  size of the numbers.
	 */
	o = 0
	a = std.slzalloc(ndig(nbit))
	for i = v.len ; i >= 4; i -= 4
		a[o++] = std.getbe32(v[i-4:i])
	;;

	if i != 0
		std.slfill(tail[:], 0)
		std.slcp(tail[4-i:], v[:i])
		a[o++] = std.getbe32(tail[:])
	;;
	-> clip(std.mk([.nbit=nbit, .dig=a]))
}

/*
 * Serializes v as (nbit + 7)/8 little endian bytes in a newly allocated
 * slice. Only the byte count (public) drives control flow.
 */
const ctbytesle = {v
	var d, i, n, o, ret, nfull

	o = 0
	n = (v.nbit + 7) / 8
	/* digits that contribute all four of their bytes */
	nfull = n / 4
	ret = std.slalloc(n)
	for i = 0; i < nfull; i++
		d = v.dig[i]
		ret[o++] = (d >>  0 : byte)
		ret[o++] = (d >>  8 : byte)
		ret[o++] = (d >> 16 : byte)
		ret[o++] = (d >> 24 : byte)
	;;

	/* a ragged top digit contributes only its low n % 4 bytes */
	if o != n
		d = v.dig[i]
		for ; o < n; o++
			ret[o] = (d : byte)
			d >>= 8
		;;
	;;
	-> ret
}

/*
 * Serializes v as (nbit + 7)/8 big endian bytes in a newly allocated
 * slice. Only the byte count (public) drives control flow.
 */
const ctbytesbe = {v : ctbig#
	var d : uint32, i, n, o, ret

	i = v.dig.len - 1
	o = 0
	n = (v.nbit + 7) / 8
	ret = std.slalloc(n)
	/* a ragged top digit contributes only its low n % 4 bytes */
	if (n & 0x3) != 0
		d = v.dig[i--]
		for var j = n & 0x3; j > 0; j--
			ret[o++] = (d >> (8*(j - 1 : uint32)) : byte)
		;;
	;;
	for ; i >= 0 ; i--
		d = v.dig[i]
		ret[o++] = (d >> 24 : byte)
		ret[o++] = (d >> 16 : byte)
		ret[o++] = (d >>  8 : byte)
		ret[o++] = (d >>  0 : byte)
	;;
	-> ret
}

/* Returns a newly allocated copy of v with its own digit buffer. */
const ctbigdup = {v
	-> std.mk([
		.nbit=v.nbit,
		.dig=std.sldup(v.dig),
	])
}

/* Releases the digit buffer of v and then v itself. */
const ctfree = {v
	std.slfree(v.dig)
	std.free(v)
}

/* r = a + b, truncated to nbit bits. All three must have the same size. */
const ctadd = {r, a, b
	ctaddcc(r, a, b, 1)
}

/*
 * Conditional add: the sum a + b is always computed, and each digit of r
 * is selected between the sum and its old value through mux(ctl, ...).
 * The final carry is discarded.
 */
const ctaddcc = {r, a, b, ctl
	var v, i, carry

	checksz(a, b)
	checksz(a, r)

	carry = 0
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] : uint64) + (b.dig[i] : uint64) + carry
		r.dig[i] = mux(ctl, (v : uint32), r.dig[i])
		carry = v >> 32
	;;
	clip(r)
}

/* r = a - b, truncated to nbit bits. All three must have the same size. */
const ctsub = {r, a, b
	ctsubcc(r, a, b, 1)
}

/*
 * Conditional subtract: the difference a - b is always computed, and each
 * digit of r is selected between the difference and its old value through
 * mux(ctl, ...). Returns the final borrow (1 when a < b, else 0)
 * regardless of ctl, which is what the comparison functions rely on.
 */
const ctsubcc = {r, a, b, ctl
	var borrow, v, i

	checksz(a, b)
	checksz(a, r)

	borrow = 0
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] : uint64) - (b.dig[i] : uint64) - borrow
		/* the top bit of the 64 bit difference is set iff we wrapped */
		borrow = (v & (1<<63)) >> 63
		r.dig[i] = mux(ctl, (v : uint32), r.dig[i])
	;;
	clip(r)
	-> borrow
}

/*
 * r = a * b, truncated to nbit bits. The full double-width product is
 * accumulated in a scratch buffer, which then replaces r's digit buffer,
 * so r may alias a and/or b.
 */
const ctmul = {r, a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w, n

	checksz(a, b)
	checksz(a, r)

	n = a.dig.len
	w = std.slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = (a.dig[i] : uint64)
			bj = (b.dig[j] : uint64)
			wij = (w[i+j] : uint64)
			t = ai * bj + wij + carry
			w[i+j] = (t : uint32)
			carry = t >> 32
		;;
		w[i + j] = (carry : uint32)
	;;

	/*
	 * r takes ownership of the truncated product. Its previous buffer
	 * is always released; the inputs were only read above, so this is
	 * safe even when r aliases a or b.
	 */
	std.slgrow(&w, n)
	std.slfree(r.dig)
	r.dig = w[:n]
	clip(r)
}

/*
 * Returns the top digit in the number that has
 * a bit set. This is useful for finding our division.
 * NOTE(review): relies on mux treating any nonzero control as true;
 * confirm against the "ct" implementation.
 */
const topfull = {n : ctbig#
	var top

	top = 0
	for var i = 0; i < n.dig.len; i++
		top = mux(n.dig[i], i, top)
	;;
	-> top
}

/*
 * Extracts the 32 bits of v starting at bit offset 'bit'. The offset is
 * public, so branching on its alignment is fine. When the offset is
 * digit-aligned the next digit is not read at all.
 */
const unalignedword = {v, bit
	var lo, hi, s, i

	s = (bit & 0x1f : uint32)
	i = (bit >> 5 : uint32)
	lo = v.dig[i]
	if s == 0
		hi = 0
	else
		hi = v.dig[i + 1]
	;;
	-> (lo >> s) | (hi << (32 - s))
}

/*
 * Computes r = (a * 2**32 + k) mod m.
 *
 * Requires a and m to be the same size, at least two digits long, with
 * m.nbit a multiple of 32 and the top bit of m set. r may alias a. The
 * structure follows the usual "estimate the quotient word, then fix up
 * by at most one m" scheme.
 */
const growmod = {r, a, k, m
	var a0, a1, b0, hi, g, q, tb, e
	var chf, clow, under, over
	var cc : uint64

	checksz(a, m)
	std.assert(a.dig.len > 1, "bad modulus\n")
	std.assert(m.dig[m.dig.len - 1] & (1 << 31) != 0, "top of mod not set: m={}, nbit={}\n", m, m.nbit)
	std.assert(m.nbit % 32 == 0, "ragged sizes not yet supported: a.nbit=={}\n", a.nbit)

	/* a0:a1 is the top 64 bits of a; b0 is the top 32 bits of m */
	a0 = (unalignedword(a, a.nbit - 32) : uint64) << 32
	a1 = (unalignedword(a, a.nbit - 64) : uint64) << 0
	b0 = (unalignedword(m, m.nbit - 32) : uint64)

	/*
	 * We hold the top digit here, so
	 * this keeps the number of digits the same, and
	 * as a result, keeps checksz() happy.
	 */
	hi = a.dig[a.dig.len - 1]

	/* Do the multiplication of x by 2**32 */
	std.slcp(r.dig[1:], a.dig[:a.dig.len-1])
	r.dig[0] = k
	g = ((a0 + a1) / b0 : uint32)
	/*
	 * When the top word of a equals the top word of m the quotient
	 * does not fit in 32 bits and g is garbage, so saturate q instead.
	 * a0 was shifted up by 32 above; undo that before comparing.
	 */
	e = eq(a0 >> 32, b0)
	q = mux((e : uint32), 0xffffffff, mux(eq(g, 0), 0, g - 1))

	/* subtract q*m from the shifted value, tracking carry and ordering */
	cc = 0
	tb = 1
	for var u = 0; u < r.dig.len; u++
		var mw, zw, xw, nxw
		var zl : uint64

		mw = m.dig[u]
		zl = (mw : uint64) * (q : uint64) + cc
		cc = zl >> 32
		zw = (zl : uint32)
		xw = r.dig[u]
		nxw = xw - zw
		cc += (gt(nxw, xw) : uint64)
		r.dig[u] = nxw
		tb = mux(eq(nxw, mw), tb, gt(nxw, mw))
	;;

	/*
	 * We can either underestimate or overestimate q,
	 *  - If we underestimated, either cc < hi, or cc == hi && tb != 0.
	 *  - If we overestimated, cc > hi.
	 *  - Otherwise, we got it exactly right.
	 *
	 * If we overestimated, we subtracted too much and need to add 'm'
	 * back once. If we underestimated, we need to subtract it once more.
	 */
	chf = (cc >> 32 : uint32)
	clow = (cc >> 0 : uint32)
	over = chf | gt(clow, hi)
	under = ~over & (tb | (~chf & lt(clow, hi)))
	ctaddcc(r, r, m, over)
	ctsubcc(r, r, m, under)
	clip(r)
}

/*
 * Converts x to Montgomery representation modulo m, storing it in r:
 * x is copied and then pushed through growmod once per digit of m, i.e.
 * multiplied by 2**(32 * m.dig.len) mod m.
 */
const tomonty = {r, x, m
	checksz(x, r)
	checksz(x, m)

	std.slcp(r.dig, x.dig)
	for var i = 0; i < m.dig.len; i++
		growmod(r, r, 0, m)
	;;
}

/* Conditional copy: every digit of r becomes mux(ctl, v digit, r digit). */
const ccopy = {r, v, ctl
	checksz(r, v)
	for var i = 0; i < r.dig.len; i++
		r.dig[i] = mux(ctl, v.dig[i], r.dig[i])
	;;
}

/* a*b + k evaluated in 64 bits */
const muladd = {a, b, k
	-> (a : uint64) * (b : uint64) + (k : uint64)
}

/*
 * Montgomery multiplication: r = x * y / 2**(32 * len) mod m, where m0i
 * is the value ninv32 computes from the low digit of m. r is zeroed
 * before x and y are read, so r must not alias either operand.
 */
const montymul = {r : ctbig#, x : ctbig#, y : ctbig#, m : ctbig#, m0i : uint32
	var dh : uint64
	var s

	checksz(x, y)
	checksz(x, m)
	checksz(x, r)

	std.slfill(r.dig, 0)
	dh = 0
	for var u = 0; u < x.dig.len; u++
		var f : uint32, xu : uint32
		var r1 : uint64, r2 : uint64, zh : uint64

		xu = x.dig[u]
		/* multiple of m that cancels the low digit of this row */
		f = (r.dig[0] + x.dig[u] * y.dig[0]) * m0i
		r1 = 0
		r2 = 0
		for var v = 0; v < y.dig.len; v++
			var z : uint64
			var t : uint32

			z = muladd(xu, y.dig[v], r.dig[v]) + r1
			r1 = z >> 32
			t = (z : uint32)
			z = muladd(f, m.dig[v], t) + r2
			r2 = z >> 32
			/* storing one digit lower is the division by 2**32 */
			if v != 0
				r.dig[v - 1] = (z : uint32)
			;;
		;;
		zh = dh + r1 + r2
		r.dig[r.dig.len - 1] = (zh : uint32)
		dh = zh >> 32
	;;

	/*
	 * r may still be greater than m at that point; notably, the
	 * 'dh' word may be non-zero.
	 */
	s = ne(dh, 0) | (ctge(r, m) : uint64)
	ctsubcc(r, r, m, (s : uint32))
}

/*
 * For odd x, iteratively refines y so that x*y == 1 mod 2**32 and
 * returns -y; returns 0 when x is even.
 */
const ninv32 = {x
	var y

	y = 2 - x
	y *= 2 - y * x
	y *= 2 - y * x
	y *= 2 - y * x
	y *= 2 - y * x
	-> mux(x & 1, -y, 0)
}

/*
 * r = a ** e mod m, by square-and-multiply over all e.nbit exponent bits
 * from least to most significant. Each step always performs both
 * Montgomery multiplications and selects the result with ccopy, so the
 * sequence of operations does not depend on the exponent bits.
 */
const ctmodpow = {r, a, e, m
	var t1, t2, m0i, ctl

	t1 = ctdup(a)
	t2 = ctzero(a.nbit)
	m0i = ninv32(m.dig[0])

	/* t1 holds successive squarings of a, in Montgomery form */
	tomonty(t1, a, m)
	/* r starts at plain 1; multiplying by a Montgomery operand keeps it plain */
	std.slfill(r.dig, 0)
	r.dig[0] = 1
	for var i = 0; i < e.nbit; i++
		ctl = (e.dig[i>>5] >> (i & 0x1f : uint32)) & 1
		montymul(t2, r, t1, m, m0i)
		ccopy(r, t2, ctl)
		montymul(t2, t1, t1, m, m0i)
		std.slcp(t1.dig, t2.dig)
	;;
	ctfree(t1)
	ctfree(t2)
}

/*
 * True when every digit of a is zero; all digits are always visited.
 * NOTE(review): relies on mux treating any nonzero control as true;
 * confirm against the "ct" implementation.
 */
const ctiszero = {a
	var z, zz

	z = 1
	for var i = 0; i < a.dig.len; i++
		zz = mux(a.dig[i], 0, 1)
		z = mux(zz, z, 0)
	;;
	-> (z : bool)
}

/*
 * True when a and b hold the same digits. The digit differences are
 * or-ed together so that every digit is always visited.
 */
const cteq = {a, b
	var nz

	checksz(a, b)
	nz = 0
	for var i = 0; i < a.dig.len; i++
		nz = nz | a.dig[i] - b.dig[i]
	;;
	-> (eq(nz, 0) : bool)
}

/* a != b */
const ctne = {a, b
	var v

	v = (cteq(a, b) : byte)
	-> (not(v) : bool)
}

/* a > b: the borrow out of b - a; ctl is 0 so b's digits are not replaced */
const ctgt = {a, b
	-> (ctsubcc(b, b, a, 0) : bool)
}

/* a >= b */
const ctge = {a, b
	var v

	v = (ctlt(a, b) : byte)
	-> (not(v) : bool)
}

/* a < b: the borrow out of a - b; ctl is 0 so a's digits are not replaced */
const ctlt = {a, b
	-> (ctsubcc(a, a, b, 0) : bool)
}

/* a <= b */
const ctle = {a, b
	var v

	v = (ctgt(a, b) : byte)
	-> (not(v) : bool)
}

/* number of 32 bit digits needed to hold nbit bits */
const ndig = {nbit
	-> (nbit + 8*sizeof(uint32) - 1)/(8*sizeof(uint32))
}

/* asserts that a and b have the same bit width and digit count */
const checksz = {a, b
	std.assert(a.nbit == b.nbit, "mismatched bit sizes")
	std.assert(a.dig.len == b.dig.len, "mismatched backing sizes")
}

/*
 * Masks off the bits of the top digit that lie above nbit, and returns
 * v for chaining. When nbit is a multiple of 32 the mask is all ones.
 * NOTE(review): relies on mux treating any nonzero control as true;
 * confirm against the "ct" implementation.
 */
const clip = {v
	var mask, edge : uint64

	/* a zero-width number has no top digit to mask; the size is public */
	if v.dig.len == 0
		-> v
	;;
	edge = (v.nbit : uint64) & (Bits - 1)
	mask = mux(edge, (1 << edge) - 1, ~0)
	v.dig[v.dig.len - 1] &= (mask : uint32)
	-> v
}