ref: 3ecd5cbac39da3cc30cc1c7abd514968fca069a9
dir: /lib/std/bigint.myr/
use "alloc"
use "chartype"
use "cmp"
use "die"
use "extremum"
use "hasprefix"
use "option"
use "slcp"
use "sldup"
use "slfill"
use "slpush"
use "types"
use "utf"
use "errno"
use "memops"

pkg std =
	type bigint = struct
		dig	: uint32[:]	/* little endian, no leading zeros. */
		sign	: int		/* -1 for -ve, 0 for zero, 1 for +ve. */
	;;

	/* administrivia */
	generic mkbigint	: (v : @a::(numeric,integral) -> bigint#)
	const bigfree	: (a : bigint# -> void)
	const bigdup	: (a : bigint# -> bigint#)
	const bigassign	: (d : bigint#, s : bigint# -> bigint#)
	const bigmove	: (d : bigint#, s : bigint# -> bigint#)
	const bigparse	: (s : byte[:] -> option(bigint#))
	const bigclear	: (a : bigint# -> bigint#)
	const bigbfmt	: (b : byte[:], a : bigint#, base : int -> size)
	/*
	const bigtoint	: (a : bigint#	-> @a::(numeric,integral))
	*/

	/* some useful predicates */
	const bigiszero	: (a : bigint# -> bool)
	const bigeq	: (a : bigint#, b : bigint# -> bool)
	generic bigeqi	: (a : bigint#, b : @a::(numeric,integral) -> bool)
	const bigcmp	: (a : bigint#, b : bigint# -> order)

	/* bigint*bigint -> bigint ops */
	const bigadd	: (a : bigint#, b : bigint# -> bigint#)
	const bigsub	: (a : bigint#, b : bigint# -> bigint#)
	const bigmul	: (a : bigint#, b : bigint# -> bigint#)
	const bigdiv	: (a : bigint#, b : bigint# -> bigint#)
	const bigmod	: (a : bigint#, b : bigint# -> bigint#)
	const bigdivmod	: (a : bigint#, b : bigint# -> (bigint#, bigint#))
	const bigshl	: (a : bigint#, b : bigint# -> bigint#)
	const bigshr	: (a : bigint#, b : bigint# -> bigint#)
	const bigmodpow	: (b : bigint#, e : bigint#, m : bigint# -> bigint#)
	/*
	const bigpow	: (a : bigint#, b : bigint# -> bigint#)
	*/

	/* bigint*int -> bigint ops */
	generic bigaddi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigsubi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigmuli	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigdivi	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigshli	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	generic bigshri	: (a : bigint#, b : @a::(integral,numeric) -> bigint#)
	/*
	const bigpowi	: (a : bigint#, b : uint64 -> bigint#)
	*/
;;

/* the radix of a single digit: 2^32 */
const Base = 0x100000000ul

/*
 Allocates a new bigint holding the value of the integer 'v'.
 The result must be released with bigfree().
 */
generic mkbigint = {v : @a::(integral,numeric)
	var a
	var val

	a = zalloc()

	if v < 0
		a.sign = -1
		v = -v
	elif v > 0
		a.sign = 1
	;;
	val = v castto(uint64)
	slpush(&a.dig, val castto(uint32))
	/* '>=', not '>': a value of exactly 2^32 needs its high digit too */
	if val >= Base
		slpush(&a.dig, (val/Base) castto(uint32))
	;;
	-> trim(a)
}

/* releases the digits of 'a' and the bigint struct itself */
const bigfree = {a
	slfree(a.dig)
	free(a)
}

/* returns a freshly allocated copy of 'a' */
const bigdup = {a
	-> bigassign(zalloc(), a)
}

/* d = s, copying the digits. Returns 'd'. */
const bigassign = {d, s
	/* self assignment would free the digits we are about to copy */
	if d == s
		-> d
	;;
	slfree(d.dig)
	d# = s#
	d.dig = sldup(s.dig)
	-> d
}

/*
 d = s, stealing the digits of 's'. Afterwards 's' is zero, and
 its struct is still owned by the caller. Returns 'd'.
 */
const bigmove = {d, s
	/* moving a value onto itself would free and then zero it */
	if d == s
		-> d
	;;
	slfree(d.dig)
	d# = s#
	s.dig = [][:]
	s.sign = 0
	-> d
}

/* sets 'v' to zero, releasing its digits. Returns 'v'. */
const bigclear = {v
	std.slfree(v.dig)
	v.sign = 0
	v.dig = [][:]
	-> v
}

/*
 Formats 'x' into 'buf' in the given base, returning the number of
 bytes written. A base of 0 means base 10. Dies on a base that is
 negative, 1, or greater than 36. No check is made that 'buf' is
 large enough.

 for now, just dump out something for debugging...
 */
const bigbfmt = {buf, x, base
	const digitchars = [
		'0','1','2','3','4','5','6','7','8','9',
		'a','b','c','d','e','f', 'g', 'h', 'i', 'j',
		'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
		'u', 'v', 'w', 'x', 'y', 'z']
	var v, val
	var n, i
	var tmp, rem, b

	/* base 1 can never reduce the value to zero, so it is rejected too */
	if base < 0 || base == 1 || base > 36
		die("invalid base in bigbfmt\n")
	;;

	if base == 0
		b = mkbigint(10)
	else
		b = mkbigint(base)
	;;
	n = 0
	val = bigdup(x)
	/* generate the digits in reverse order */
	while !bigiszero(val)
		(v, rem) = bigdivmod(val, b)
		if rem.dig.len > 0
			n += encode(buf[n:], digitchars[rem.dig[0]])
		else
			n += encode(buf[n:], '0')
		;;
		bigfree(val)
		bigfree(rem)
		val = v
	;;
	bigfree(val)
	bigfree(b)

	/* this is done last, so we get things right when we reverse the string */
	if x.sign == 0
		n += encode(buf[n:], '0')
	elif x.sign == -1
		n += encode(buf[n:], '-')
	;;

	/* we only generated ascii digits, so this works for reversing. */
	for i = 0; i < n/2; i++
		tmp = buf[i]
		buf[i] = buf[n - i - 1]
		buf[n - i - 1] = tmp
	;;
	-> n
}

/*
 Parses a non-negative integer from 'str'. The prefixes 0x/0X, 0o/0O
 and 0b/0B select base 16, 8 and 2; anything else is read as base 10.
 Underscores between digits are skipped. Returns `None if a character
 is not a valid digit in the selected base.
 */
const bigparse = {str
	var val : int, base
	var v, b
	var a

	if hasprefix(str, "0x") || hasprefix(str, "0X")
		base = 16
	elif hasprefix(str, "0o") || hasprefix(str, "0O")
		base = 8
	elif hasprefix(str, "0b") || hasprefix(str, "0B")
		base = 2
	else
		base = 10
	;;
	if base != 10
		str = str[2:]
	;;

	a = mkbigint(0)
	b = mkbigint(base)
	/*
	 efficiency hack: to save allocations,
	 just mutate v[0]. The value will always
	 fit in one digit.
	 */
	v = mkbigint(1)
	for c in std.bychar(str)
		if c == '_'
			continue
		;;
		val = charval(c, base)
		/* valid digits are 0..base-1, so val == base is an error as well */
		if val < 0 || val >= base
			bigfree(a)
			bigfree(b)
			bigfree(v)
			-> `None
		;;
		v.dig[0] = val castto(uint32)
		if val == 0
			v.sign = 0
		else
			v.sign = 1
		;;
		bigmul(a, b)
		bigadd(a, v)
	;;
	/* the scratch values are only needed while parsing */
	bigfree(b)
	bigfree(v)
	-> `Some a
}

/* true if 'v' has no digits, which is how zero is represented */
const bigiszero = {v
	-> v.dig.len == 0
}

/* true if 'a' and 'b' have the same sign and the same digits */
const bigeq = {a, b
	if a.sign != b.sign || a.dig.len != b.dig.len
		-> false
	;;
	for var i = 0; i < a.dig.len; i++
		if a.dig[i] != b.dig[i]
			-> false
		;;
	;;
	-> true
}

/* compares the bigint 'a' against the plain integer 'b' for equality */
generic bigeqi = {a, b
	var v
	var dig : uint32[2]

	bigdigit(&v, b < 0, b castto(uint64), dig[:])
	-> bigeq(a, &v)
}

/* signed comparison: orders 'a' relative to 'b' */
const bigcmp = {a, b
	var da, db, sa, sb

	sa = a.sign castto(int64)
	sb = b.sign castto(int64)
	if sa < sb
		-> `Before
	elif sa > sb
		-> `After
	elif a.dig.len < b.dig.len
		/* same sign: fewer digits is smaller if positive, larger if negative */
		-> signedorder(-sa)
	elif a.dig.len > b.dig.len
		-> signedorder(sa)
	else
		/* otherwise, the one with the first larger digit is bigger */
		for var i = a.dig.len; i > 0; i--
			da = a.dig[i - 1] castto(int64)
			db = b.dig[i - 1] castto(int64)
			if da != db
				-> signedorder(sa * (da - db))
			;;
		;;
	;;
	-> `Equal
}

/* maps a negative, zero, or positive integer to `Before, `Equal, or `After */
const signedorder = {sign
	if sign < 0
		-> `Before
	elif sign == 0
		-> `Equal
	else
		-> `After
	;;
}

/* compares the magnitudes of 'a' and 'b', ignoring their signs */
const ucmp = {a : bigint#, b : bigint# -> order
	if a.dig.len < b.dig.len
		-> `Before
	elif a.dig.len > b.dig.len
		-> `After
	;;
	/* same length: the most significant differing digit decides */
	for var i = a.dig.len; i > 0; i--
		if a.dig[i - 1] < b.dig[i - 1]
			-> `Before
		elif a.dig[i - 1] > b.dig[i - 1]
			-> `After
		;;
	;;
	-> `Equal
}

/* a += b. Only 'a' is modified. Returns 'a'. */
const bigadd = {a, b
	if a.sign == b.sign || a.sign == 0
		a.sign = b.sign
		-> uadd(a, b)
	elif b.sign == 0
		-> a
	else
		/*
		 opposite signs: the result is the difference of the
		 magnitudes, carrying the sign of the larger one.
		 */
		match ucmp(a, b)
		| `Before:
			a.sign = b.sign
			-> ursub(a, b)
		| `After:
			-> usub(a, b)
		| `Equal:
			-> bigclear(a)
		;;
	;;
}

/* adds two unsigned values together. */
const uadd = {a, b
	var v, i
	var carry
	var n

	carry = 0
	n = max(a.dig.len, b.dig.len)
	/* guaranteed to carry no more than one value */
	slzgrow(&a.dig, n + 1)
	for i = 0; i < n; i++
		v = (a.dig[i] castto(uint64)) + carry
		if i < b.dig.len
			v += (b.dig[i] castto(uint64))
		;;

		if v >= Base
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	/* i == n here: store the final carry in the extra digit */
	a.dig[i] += carry castto(uint32)
	-> trim(a)
}

/* a -= b. Only 'a' is modified. Returns 'a'. */
const bigsub = {a, b
	/* 0 - x = -x */
	if a.sign == 0
		bigassign(a, b)
		a.sign = -b.sign
		-> a
	/* x - 0 = x */
	elif b.sign == 0
		-> a
	elif a.sign != b.sign
		/* x - (-y) = x + y, and (-x) - y = -(x + y) */
		-> uadd(a, b)
	else
		/*
		 same sign: the result is the difference of the
		 magnitudes. It keeps the sign of 'a' when |a| > |b|,
		 and takes the opposite sign when |a| < |b|.
		 */
		match ucmp(a, b)
		| `Before:
			a.sign = -a.sign
			-> ursub(a, b)
		| `After:
			-> usub(a, b)
		| `Equal:
			-> bigclear(a)
		;;
	;;
	-> a
}

/* subtracts two unsigned values, where 'a' is strictly greater than 'b' */
const usub = {a, b
	var carry
	var v, i

	carry = 0
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] castto(int64)) - carry
		if i < b.dig.len
			v -= (b.dig[i] castto(int64))
		;;
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		/* the truncating cast wraps a negative value back into digit range */
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}

/*
 a = b - a on unsigned values, where 'b' is strictly greater
 than 'a'. 'b' is left untouched.
 */
const ursub = {a, b
	var carry
	var v, i

	carry = 0
	/* 'b' is the larger value, so it has at least as many digits as 'a' */
	slzgrow(&a.dig, b.dig.len)
	for i = 0; i < b.dig.len; i++
		v = (b.dig[i] castto(int64)) - (a.dig[i] castto(int64)) - carry
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}

/* a *= b */
const bigmul = {a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w

	if a.sign == 0 || b.sign == 0
		a.sign = 0
		slfree(a.dig)
		a.dig = [][:]
		-> a
	elif a.sign != b.sign
		a.sign = -1
	else
		a.sign = 1
	;;
	/* schoolbook multiplication, accumulated into a fresh buffer */
	w = slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = a.dig[i] castto(uint64)
			bj = b.dig[j] castto(uint64)
			wij = w[i+j] castto(uint64)
			t = ai * bj + wij + carry
			w[i + j] = (t castto(uint32))
			carry = t >> 32
		;;
		/* i == a.dig.len here: the row's final carry lands one digit up */
		w[i+j] = carry castto(uint32)
	;;
	slfree(a.dig)
	a.dig = w
	-> trim(a)
}

/* a /= b. Returns 'a'. */
const bigdiv = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(r)
	bigmove(a, q)
	/* bigmove took the digits; the emptied struct is still ours to free */
	free(q)
	-> a
}

/* a %= b. Returns 'a'. */
const bigmod = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(q)
	bigmove(a, r)
	/* bigmove took the digits; the emptied struct is still ours to free */
	free(r)
	-> a
}

/*
 Computes a / b and a % b, returning (quotient, remainder) as two
 freshly allocated values. Neither 'a' nor 'b' is modified. Dies
 if 'b' is zero.
 */
const bigdivmod = {a : bigint#, b : bigint# -> (bigint#, bigint#)
	/*
	 Implements bigint division using Algorithm D from
	 Knuth: Seminumerical algorithms, Section 4.3.1.
	 */
	var m : int64, n : int64
	var qhat, rhat, carry, shift
	var x, y, z, w, p, t /* temporaries */
	var b0, aj
	var u, v
	var i, j : int64
	var q

	if bigiszero(b)
		die("divide by zero\n")
	;;
	/* if b > a, we truncate to 0, with remainder 'a' */
	if a.dig.len < b.dig.len
		-> (mkbigint(0), bigdup(a))
	;;

	q = zalloc()
	q.dig = slzalloc(max(a.dig.len, b.dig.len) + 1)
	if a.sign != b.sign
		q.sign = -1
	else
		q.sign = 1
	;;

	/* handle single digit divisor separately: the knuth algorithm needs at least 2 digits. */
	if b.dig.len == 1
		carry = 0
		b0 = (b.dig[0] castto(uint64))
		for j = a.dig.len; j > 0; j--
			aj = (a.dig[j - 1] castto(uint64))
			q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
			carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
		;;
		q = trim(q)
		/*
		 The remainder is below b0, so it fits in a uint32. It must
		 not go through a signed 32 bit cast, which would turn any
		 remainder >= 2^31 into a negative value.

		 NOTE(review): this path always yields a non-negative
		 remainder, while the other paths keep the sign of 'a'.
		 Confirm which convention callers expect.
		 */
		-> (q, trim(mkbigint(carry castto(uint32))))
	;;

	u = bigdup(a)
	v = bigdup(b)
	m = u.dig.len
	n = v.dig.len

	/* normalize so that the top bit of the divisor's top digit is set */
	shift = nlz(v.dig[n - 1])
	bigshli(u, shift)
	bigshli(v, shift)
	slzgrow(&u.dig, u.dig.len + 1)

	for j = m - n; j >= 0; j--
		/* load a few temps */
		x = u.dig[j + n] castto(uint64)
		y = u.dig[j + n - 1] castto(uint64)
		z = v.dig[n - 1] castto(uint64)
		w = v.dig[n - 2] castto(uint64)
		t = u.dig[j + n - 2] castto(uint64)

		/* estimate qhat */
		qhat = (x*Base + y)/z
		rhat = (x*Base + y) - qhat*z
:divagain
		if qhat >= Base || (qhat * w) > (rhat*Base + t)
			qhat--
			rhat += z
			if rhat < Base
				goto divagain
			;;
		;;

		/* multiply and subtract */
		carry = 0
		for i = 0; i < n; i++
			p = qhat * (v.dig[i] castto(uint64))

			t = (u.dig[i+j] castto(uint64)) - carry - (p % Base)
			u.dig[i+j] = t castto(uint32)

			carry = (((p castto(int64)) >> 32) - ((t castto(int64)) >> 32)) castto(uint64)
		;;
		t = (u.dig[j + n] castto(uint64)) - carry
		u.dig[j + n] = t castto(uint32)
		q.dig[j] = qhat castto(uint32)
		/* adjust: qhat was one too large, so add the divisor back */
		if t castto(int64) < 0
			q.dig[j]--
			carry = 0
			for i = 0; i < n; i++
				t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
				u.dig[i+j] = t castto(uint32)
				carry = t >> 32
			;;
			u.dig[j+n] = u.dig[j+n] + (carry castto(uint32))
		;;
	;;
	/* undo the biasing for remainder */
	u = bigshri(u, shift)
	/* the normalized copy of the divisor is no longer needed */
	bigfree(v)
	-> (trim(q), trim(u))
}

/*
 computes b^e % m

 The result is stored in 'base', which is also the return value.
 'exp' is consumed: it is shifted down to zero along the way.
 */
const bigmodpow = {base, exp, mod
	var r

	r = mkbigint(1)
	/* square and multiply, from the low bit of the exponent up */
	while !bigiszero(exp)
		if (exp.dig[0] & 1) != 0
			bigmul(r, base)
			bigmod(r, mod)
		;;
		bigshri(exp, 1)
		bigmul(base, base)
		bigmod(base, mod)
	;;
	bigmove(base, r)
	/* bigmove took the digits; the emptied struct is still ours to free */
	free(r)
	-> base
}

/* returns the number of leading zeros */
const nlz = {a : uint32
	var n

	if a == 0
		-> 32
	;;
	/* binary search for the highest set bit */
	n = 0
	if a <= 0x0000ffff
		n += 16
		a <<= 16
	;;
	if a <= 0x00ffffff
		n += 8
		a <<= 8
	;;
	if a <= 0x0fffffff
		n += 4
		a <<= 4
	;;
	if a <= 0x3fffffff
		n += 2
		a <<= 2
	;;
	if a <= 0x7fffffff
		n += 1
		a <<= 1
	;;
	-> n
}

/* a <<= b. Dies if 'b' takes more than one digit. */
const bigshl = {a, b
	match b.dig.len
	| 0:	-> a
	| 1:	-> bigshli(a, b.dig[0] castto(uint64))
	| n:	die("shift by way too much\n")
	;;
}

/* a >>= b, unsigned. Dies if 'b' takes more than one digit. */
const bigshr = {a, b
	match b.dig.len
	| 0:	-> a
	| 1:	-> bigshri(a, b.dig[0] castto(uint64))
	| n:	die("shift by way too much\n")
	;;
}

/* a + b, b is integer. FIXME: actually make this a performance improvement */
generic bigaddi = {a, b
	var bigb : bigint
	var dig : uint32[2]

	bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
	bigadd(a, &bigb)
	-> a
}

/* a - b, b is integer. */
generic bigsubi = {a, b : @a::(numeric,integral)
	var bigb : bigint
	var dig : uint32[2]

	bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
	bigsub(a, &bigb)
	-> a
}

/* a * b, b is integer. */
generic bigmuli = {a, b
	var bigb : bigint
	var dig : uint32[2]

	bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
	bigmul(a, &bigb)
	-> a
}

/* a / b, b is integer. */
generic bigdivi = {a, b
	var bigb : bigint
	var dig : uint32[2]

	bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
	bigdiv(a, &bigb)
	-> a
}

/*
  a << s, with integer arg.
  logical left shift. any other type would be illogical.
 */
generic bigshli = {a, s : @a::(numeric,integral)
	var off, shift
	var t, carry

	iassert(s >= 0, "shift amount must be positive")
	/* split the shift into whole digits and leftover bits */
	off = (s castto(uint64)) / 32
	shift = (s castto(uint64)) % 32

	/* zero shifted by anything is zero */
	if a.sign == 0
		-> a
	;;

	slzgrow(&a.dig, 1 + a.dig.len + off castto(size))
	/* blit over the base values */
	for var i = a.dig.len; i > off; i--
		a.dig[i - 1] = a.dig[i - 1 - off]
	;;
	for var i = 0; i < off; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	carry = 0
	for var i = 0; i < a.dig.len; i++
		t = (a.dig[i] castto(uint64)) << shift
		a.dig[i] = (t | carry) castto(uint32)
		carry = t >> 32
	;;
	-> trim(a)
}

/* logical shift right, zero fills. sign remains untouched. */
generic bigshri = {a, s
	var off, shift
	var t, carry

	iassert(s >= 0, "shift amount must be positive")
	/* split the shift into whole digits and leftover bits */
	off = (s castto(uint64)) / 32
	shift = (s castto(uint64)) % 32

	/*
	 Shifting out every digit leaves zero. Clamp the digit
	 offset so that 'a.dig.len - off' below cannot wrap around
	 and index past the end of the digits.
	 */
	if off > a.dig.len
		off = a.dig.len
	;;

	/* blit over the base values */
	for var i = 0; i < a.dig.len - off; i++
		a.dig[i] = a.dig[i + off]
	;;
	for var i = a.dig.len - off; i < a.dig.len; i++
		a.dig[i] = 0
	;;
	/* and shift over by the remainder */
	carry = 0
	for var i = a.dig.len; i > 0; i--
		t = (a.dig[i - 1] castto(uint64))
		a.dig[i - 1] = (carry | (t >> shift)) castto(uint32)
		/* bits shifted out of this digit; the cast above drops the excess high bits */
		carry = t << (32 - shift)
	;;
	-> trim(a)
}

/*
 creates a bigint on the stack; should not be modified.

 'v' is filled in to use the caller provided storage 'dig', and
 holds the magnitude 'val' with the sign selected by 'isneg'.
 */
const bigdigit = {v, isneg : bool, val : uint64, dig
	v.sign = 1
	if isneg
		val = -val
		v.sign = -1
	;;
	if val == 0
		v.sign = 0
		v.dig = [][:]
	elif val < Base
		v.dig = dig[:1]
		v.dig[0] = val castto(uint32)
	else
		v.dig = dig
		v.dig[0] = val castto(uint32)
		v.dig[1] = (val >> 32) castto(uint32)
	;;
}

/*
 trims leading zeros

 Also fixes up the sign: a value left with no digits becomes zero,
 and a value that has digits but a zero sign becomes positive.
 */
const trim = {a
	var i

	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] != 0
			break
		;;
	;;
	slgrow(&a.dig, i)
	if i == 0
		a.sign = 0
	elif a.sign == 0
		a.sign = 1
	;;
	-> a
}