shithub: mc

ref: debd13efe12eaca072229e9859aaa6df8708a09f
dir: /lib/crypto/ctbig.myr/

View raw version
use std

use "ct"

/*
  Public interface of the constant-time bignum code.

  A ctbig is a fixed-width unsigned integer: nbit is the declared width,
  and dig holds the value as 32-bit digits.
 */
pkg crypto =
	type ctbig = struct
		nbit	: std.size
		/*
		  little endian 32-bit digits. The constructors in this file
		  allocate a fixed, zero padded digit count derived from nbit,
		  so high digits may be zero.
		 */
		dig	: uint32[:]
	;;

	/* constructors and conversions */
	generic mkctbign 	: (v : @a, nbit : std.size -> ctbig#) :: numeric,integral @a
	const mkctbigle	: (v : byte[:], nbit : std.size -> ctbig#)
	//const mkctbigbe	: (v : byte[:], nbit : std.size -> ctbig#)

	const ctfree	: (v : ctbig# -> void)
	const ctbigdup	: (v : ctbig# -> ctbig#)
	const ctlike	: (v : ctbig# -> ctbig#)
	const ct2big	: (v : ctbig# -> std.bigint#)
	/* the second argument is a width in bits, not a digit count */
	const big2ct	: (v : std.bigint#, nbit : std.size -> ctbig#)

	/* arithmetic: r receives the result; all operands must share a width */
	const ctadd	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	const ctsub	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	const ctmul	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)
	//const ctdivmod	: (r : ctbig#, m : ctbig#, a : ctbig#, b : ctbig# -> void)
	//const ctmodpow	: (r : ctbig#, a : ctbig#, b : ctbig# -> void)

	/* predicates */
	const ctiszero	: (v : ctbig# -> bool)
	const cteq	: (a : ctbig#, b : ctbig# -> bool)
	const ctne	: (a : ctbig#, b : ctbig# -> bool)
	const ctgt	: (a : ctbig#, b : ctbig# -> bool)
	const ctge	: (a : ctbig#, b : ctbig# -> bool)
	const ctlt	: (a : ctbig#, b : ctbig# -> bool)
	const ctle	: (a : ctbig#, b : ctbig# -> bool)
;;

/* 2^32: one more than the largest value a single uint32 digit can hold. */
const Base = 0x100000000ul

/*
  Allocates a ctbig that is nbit bits wide and initializes it from the
  integer v.

  v is converted to a uint64; its low 32 bits go into digit 0 (when
  nbit > 0) and the next 32 bits into digit 1 (when nbit > 32). All
  remaining digits are zero.

  The caller owns the result and releases it with ctfree().
 */
generic mkctbign = {v : @a, nbit : std.size :: integral,numeric @a
	var a
	var val

	a = std.zalloc()

	val = (v : uint64)
	a.nbit = nbit
	/*
	  Must be the zeroing allocator: only the low one or two digits are
	  written below, and the rest have to read as zero rather than as
	  whatever the allocator handed back.
	 */
	a.dig = std.slzalloc(ndig(nbit))
	if nbit > 0
		a.dig[0] = (val : uint32)
	;;
	if nbit > 32
		a.dig[1] = (val >> 32 : uint32)
	;;
	-> a
}

/*
  Converts a ctbig into a freshly allocated std.bigint. The digits are
  duplicated, so the result does not alias ct. The sign field is always
  set to 1.
 */
const ct2big = {ct
	var digits

	digits = std.sldup(ct.dig)
	-> std.mk([.sign=1, .dig=digits])
}

/*
  Converts the std.bigint ct into a freshly allocated ctbig that is nbit
  bits wide.

  At most ndig(nbit) digits are copied from ct; if ct has more digits
  than that, the high ones are dropped, and if it has fewer, the result
  is zero padded. The sign of ct is not consulted.
 */
const big2ct = {ct, nbit
	var v, n, l

	n = ndig(nbit)
	l = std.min(n, ct.dig.len)
	v = std.slzalloc(n)
	/*
	  Copy between slices of the same length: the destination may be
	  longer than the source, so only its first l digits are targeted.
	 */
	std.slcp(v[:l], ct.dig[:l])
	-> std.mk([
		.nbit=nbit,
		.dig=v,
	])
}

/*
  Builds a ctbig that is nbit bits wide from the little endian byte
  string v.

  Each group of four bytes becomes one 32-bit digit, lowest byte first.
  Any one to three leftover bytes are packed into a final partial digit.
  Digits not covered by v stay zero.
 */
const mkctbigle = {v, nbit
	var a, last, i, o

	/*
	  It's ok to depend on the length of v here: we can leak the
	  size of the numbers.
	 */
	o = 0
	a = std.slzalloc(ndig(nbit))
	for i = 0; i + 4 <= v.len; i += 4
		/* widen each byte before shifting, or the high bits fall off */
		a[o++] = \
			((v[i + 0] : uint32) <<  0) | \
			((v[i + 1] : uint32) <<  8) | \
			((v[i + 2] : uint32) << 16) | \
			((v[i + 3] : uint32) << 24)
	;;

	/*
	  Only emit a trailing digit when there are leftover bytes, so an
	  input that exactly fills the digits does not write one past them.
	 */
	if i < v.len
		last = 0
		for i; i < v.len; i++
			/* byte i lands at position (i mod 4) within the digit */
			last |= (v[i] : uint32) << (8 * ((i & 0x3) : uint32))
		;;
		a[o++] = last
	;;
	-> std.mk([.nbit=nbit, .dig=a])
}

/*
  Allocates a new ctbig with the same bit width and digit count as v,
  with every digit set to zero.
 */
const ctlike = {v
	var zeros

	zeros = std.slzalloc(v.dig.len)
	-> std.mk([.nbit=v.nbit, .dig=zeros])
}

/*
  Returns a newly allocated copy of v: same bit width, and a duplicate
  of its digit slice.
 */
const ctbigdup = {v
	var digits

	digits = std.sldup(v.dig)
	-> std.mk([.nbit=v.nbit, .dig=digits])
}

/*
  Releases a ctbig: first its digit slice, then the struct itself.
 */
const ctfree = {v
	var digits

	digits = v.dig
	std.slfree(digits)
	std.free(v)
}

/*
  r = a + b, digit by digit, with the carry propagated upwards.

  a, b and r must all have the same bit width and digit count; this is
  asserted. The carry out of the top digit is discarded. The loop count
  depends only on the digit count, not on the digit values.
 */
const ctadd = {r, a, b
	var v, i, carry

	checksz(a, b)
	checksz(a, r)

	carry = 0
	/* checksz guarantees a, b and r all have a.dig.len digits */
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] : uint64) + (b.dig[i] : uint64) + carry
		r.dig[i] = (v  : uint32)
		/* whatever overflowed the low 32 bits feeds the next digit */
		carry = v >> 32
	;;
}

/*
  r = a - b, digit by digit, with the borrow propagated upwards.

  a, b and r must all have the same bit width and digit count; this is
  asserted by checksz. The borrow out of the top digit is discarded.
 */
const ctsub = {r, a, b
	var borrow, v, i

	checksz(a, b)
	checksz(a, r)

	borrow = 0
	for i = 0; i < a.dig.len; i++
		/* done in 64 bits: a result that went below zero wraps and sets bit 63 */
		v = (a.dig[i] : uint64) - (b.dig[i] : uint64) - borrow
		/* bit 63 of the difference, moved down to a 0/1 borrow flag */
		borrow = (v & (1<<63)) >> 63
		/*
		  presumably mux(c, x, y) yields x when c is 1 and y when c
		  is 0 -- confirm against the ct library. Either choice has the
		  same low 32 bits, which is all that is stored below.
		 */
		v = mux(borrow, v + Base, v)
		r.dig[i] = (v  : uint32)
	;;
}

/*
  r = a * b, truncated to the digit count of the operands.

  a, b and r must all have the same bit width and digit count; this is
  asserted. The product is accumulated by schoolbook multiplication in a
  double width scratch buffer, which is then cut down to the operand
  width and installed as the digit slice of r. The previous digit slice
  of r is freed; r may be the same object as a and/or b.
 */
const ctmul = {r, a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w, n

	checksz(a, b)
	checksz(a, r)

	/* remembered up front: a may alias r, whose digits are freed below */
	n = a.dig.len
	w  = std.slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = (a.dig[i]  : uint64)
			bj = (b.dig[j]  : uint64)
			wij = (w[i+j]  : uint64)
			t = ai * bj + wij + carry
			w[i+j] = (t  : uint32)
			carry = t >> 32
		;;
		/* i == a.dig.len here: store the carry out of this row */
		w[i + j] = (carry  : uint32)
	;;
	std.slgrow(&w, n)
	/*
	  The old digits of r are replaced wholesale, so they are released
	  unconditionally. Doing this without comparing r against a or b
	  avoids leaking the buffer when r is a distinct object, and is only
	  done after the last read of a and b.
	 */
	std.slfree(r.dig)
	r.dig = w[:n]
}

//const ctmodpow = {res, a, b
//	/* find rinv, mprime */
//	
//	/* convert to monty space */
//
//	/* do the modpow */
//
//	/* and come back */
//}

/*
  Returns true when every digit of a is zero. Every digit is visited
  regardless of its value; there is no early exit.
 */
const ctiszero = {a
	var z, zz

	/* z stays 1 for as long as every digit seen so far was zero */
	z = 1
	for var i = 0; i < a.dig.len; i++
		/*
		  zz is meant to be 1 when this digit is zero and 0 otherwise.
		  NOTE(review): the selector passed to mux here is a raw digit
		  rather than a 0/1 flag -- confirm that mux in the ct library
		  treats any nonzero selector as set.
		 */
		zz = mux(a.dig[i], 0, 1)
		/* keep z if this digit was zero, otherwise force it to 0 */
		z = mux(zz, z, 0)
	;;
	-> (z : bool)
}

/*
  Returns true when a and b hold the same value.

  a and b must have the same bit width and digit count; this is
  asserted. Every digit pair is visited regardless of the values; there
  is no early exit.
 */
const cteq = {a, b
	var z, d, e

	checksz(a, b)

	/* e stays 1 for as long as every digit pair seen so far matched */
	e = 1
	for var i = 0; i < a.dig.len; i++
		/* z is zero exactly when the two digits are equal */
		z = a.dig[i] - b.dig[i]
		/*
		  d is 1 when the digits match and 0 when they differ.
		  NOTE(review): as in ctiszero, this passes a raw value as the
		  mux selector -- confirm that mux in the ct library treats any
		  nonzero selector as set.
		 */
		d = mux(z, 0, 1)
		/* once a mismatch clears e it can never be set again */
		e = mux(e, d, 0)
	;;
	-> (e : bool)
}

/*
  Returns the negation of cteq(a, b), computed on the 0/1 byte form of
  the result with not() instead of a branch.
 */
const ctne = {a, b
	-> (not((cteq(a, b) : byte)) : bool)
}

/*
  Returns true when a > b.

  a and b must have the same bit width and digit count; this is asserted
  by checksz. Digits are scanned from least to most significant, and
  every digit pair is visited; there is no early exit.
 */
const ctgt = {a, b
	var e, d, g

	checksz(a, b)

	g = 0
	for var i = 0; i < a.dig.len; i++
		/*
		  e is meant to be 1 when the two digits are equal.
		  NOTE(review): this relies on not() mapping zero to 1 and any
		  nonzero value to 0 -- confirm against the ct library.
		 */
		e = not(a.dig[i] - b.dig[i])
		d = gt(a.dig[i], b.dig[i])
		/*
		  equal digits keep the verdict from the lower digits; unequal
		  digits override it, so the most significant difference wins.
		 */
		g = mux(e, g, d) 
	;;
	-> (g : bool)
}

/*
  Returns the negation of ctlt(a, b), computed on the 0/1 byte form of
  the result with not() instead of a branch.
 */
const ctge = {a, b
	-> (not((ctlt(a, b) : byte)) : bool)
}

/*
  Returns true when a < b.

  a and b must have the same bit width and digit count; this is
  asserted. Digits are scanned from least to most significant, and every
  digit pair is visited; there is no early exit.
 */
const ctlt = {a, b
	var e, d, l

	checksz(a, b)

	l = 0
	for var i = 0; i < a.dig.len; i++
		/*
		  e is meant to be 1 when the two digits are equal.
		  NOTE(review): this relies on not() mapping zero to 1 and any
		  nonzero value to 0 -- confirm against the ct library.
		 */
		e = not(a.dig[i] - b.dig[i])
		/* a digit of a is below b's exactly when b's is above a's */
		d = gt(b.dig[i], a.dig[i])
		/*
		  equal digits keep the verdict from the lower digits; unequal
		  digits override it, so the most significant difference wins.
		 */
		l = mux(e, l, d)
	;;
	-> (l : bool)
}

/*
  Returns the negation of ctgt(a, b), computed on the 0/1 byte form of
  the result with not() instead of a branch.
 */
const ctle = {a, b
	-> (not((ctgt(a, b) : byte)) : bool)
}

/*
  Number of 32-bit digits needed to hold nbit bits, rounded up.

  The divisor is the digit width in bits (8*sizeof(uint32)), matching
  the rounding term in the numerator.
 */
const ndig = {nbit
	-> (nbit + 8*sizeof(uint32) - 1)/(8*sizeof(uint32))
}

/*
  Asserts that a and b are compatible operands: the same declared bit
  width and the same number of backing digits.
 */
const checksz = {a, b
	var samebits, samelen

	samebits = a.nbit == b.nbit
	samelen = a.dig.len == b.dig.len
	std.assert(samebits, "mismatched bit sizes")
	std.assert(samelen, "mismatched backing sizes")
}