ref: ba9f79ecd79a5d24b74f714cc8c70204d5bc26c5
dir: /libstd/bigint.myr/
use "alloc.use"
use "chartype.use"
use "cmp.use"
use "die.use"
use "extremum.use"
use "hasprefix.use"
use "option.use"
use "slcp.use"
use "sldup.use"
use "slfill.use"
use "slpush.use"
use "types.use"
use "utf.use"
pkg std =
/*
arbitrary precision integer: a sign plus a magnitude held as a
slice of 32 bit digits.
*/
type bigint = struct
dig : uint32[:] /* little endian, no leading zeros. */
sign : int /* -1 for -ve, 0 for zero, 1 for +ve. */
;;
/* administrivia: creation, destruction, copying, parsing and formatting */
generic mkbigint : (v : @a::(numeric,integral) -> bigint#)
const bigfree : (a : bigint# -> void)
const bigdup : (a : bigint# -> bigint#)
const bigassign : (d : bigint#, s : bigint# -> bigint#)
const bigparse : (s : byte[:] -> option(bigint#))
const bigbfmt : (b : byte[:], a : bigint#, base : int -> size)
const bigfmt : (a : bigint#, base : int -> byte[:])
/*
not implemented yet:
const bigtoint : (a : bigint# -> @a::(numeric,integral))
*/
/* some useful predicates */
const bigiszero : (a : bigint# -> bool)
const bigeq : (a : bigint#, b : bigint# -> bool)
generic bigeqi : (a : bigint#, b : @a::(numeric,integral) -> bool)
const bigcmp : (a : bigint#, b : bigint# -> order)
/*
bigint*bigint -> bigint ops.
add, sub, mul, div, mod, shl and shr store their result into 'a'
and return it; bigdivmod returns a freshly allocated
(quotient, remainder) pair instead.
*/
const bigadd : (a : bigint#, b : bigint# -> bigint#)
const bigsub : (a : bigint#, b : bigint# -> bigint#)
const bigmul : (a : bigint#, b : bigint# -> bigint#)
const bigdiv : (a : bigint#, b : bigint# -> bigint#)
const bigmod : (a : bigint#, b : bigint# -> bigint#)
const bigdivmod : (a : bigint#, b : bigint# -> (bigint#, bigint#))
const bigshl : (a : bigint#, b : bigint# -> bigint#)
const bigshr : (a : bigint#, b : bigint# -> bigint#)
/*
not implemented yet:
const bigpow : (a : bigint#, b : bigint# -> bigint#)
*/
/* bigint*int -> bigint ops: same as above, with a machine integer on the right */
generic bigaddi : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
generic bigsubi : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
generic bigmuli : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
generic bigdivi : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
generic bigshli : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
generic bigshri : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
/*
not implemented yet:
const bigpowi : (a : bigint#, b : uint64 -> bigint#)
*/
;;
/* radix of a single bigint digit: 2^32, one more than the largest uint32 */
const Base = 0x100000000ul
/*
 Allocates a new bigint holding the value of the machine integer 'v'.
 The caller owns the result and releases it with bigfree().
 */
generic mkbigint = {v : @a::(integral,numeric)
	var a
	var val

	a = zalloc()
	/*
	 Widen first and negate afterwards: negating in the source type
	 overflows for the most negative value of a narrow signed type,
	 while negating the sign extended uint64 always yields the magnitude.
	 */
	val = v castto(uint64)
	if v < 0
		a.sign = -1
		val = -val
	elif v > 0
		a.sign = 1
	;;
	a.dig = slpush([][:], val castto(uint32))
	/* '>=', not '>': exactly 2^32 has a zero low digit and a high digit of 1 */
	if val >= Base
		a.dig = slpush(a.dig, (val/Base) castto(uint32))
	;;
	/* drops the digit pushed for zero and normalizes the sign */
	-> trim(a)
}
/* releases a heap allocated bigint: first its digit slice, then the header. */
const bigfree = {a
	var digs

	digs = a.dig
	slfree(digs)
	free(a)
}
/* returns a freshly allocated copy of 'a'; the caller frees it with bigfree(). */
const bigdup = {a
	var copy

	/* a zeroed header has an empty digit slice, so bigassign can free it safely */
	copy = zalloc()
	-> bigassign(copy, a)
}
/*
 Overwrites 'd' with a copy of the value of 's' and returns 'd'.
 The old digits of 'd' are released.
 */
const bigassign = {d, s
	var dig

	/*
	 duplicate before freeing: when d and s are the same bigint,
	 freeing first would leave sldup reading released memory.
	 */
	dig = sldup(s.dig)
	slfree(d.dig)
	d.dig = dig
	d.sign = s.sign
	-> d
}
/*
 Formats 'a' in the given base into a newly allocated buffer and
 returns the slice of it that was filled in. The base is interpreted
 by bigbfmt().
 */
const bigfmt = {a, base
	var buf
	var n

	/*
	 allocate a buffer guaranteed to be big enough for any base.
	 the smallest base produces the most characters: in base 2
	 every 32 bit digit expands to 32 characters. sizing for
	 base 10 (about 10 characters per digit) overflows for bases
	 below 10.

	 the extra 3 covers the '-' sign and the lone '0' written
	 for a zero value, with a byte to spare.
	 */
	buf = slalloc(3 + a.dig.len * 32)
	n = bigbfmt(buf, a, base)
	-> buf[:n]
}
/*
 Writes the textual form of 'x' into 'buf' and returns the number of
 bytes written. 'base' may be 2 through 36, or 0 to select base 10.
 Digits above 9 are written as lower case letters, and negative
 values get a leading '-'.

 No bounds checks are done here: 'buf' must already be large enough
 for the output.
 */
const bigbfmt = {buf, x, base
	const digitchars = [
		'0','1','2','3','4','5','6','7','8','9',
		'a','b','c','d','e','f', 'g', 'h', 'i', 'j',
		'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
		't', 'u', 'v', 'w', 'x', 'y', 'z']
	var v, val
	var n, i
	var tmp, rem, b

	/*
	 base 1 must be rejected too: dividing by 1 never shrinks the
	 value, so the digit loop below would never terminate.
	 */
	if base < 0 || base == 1 || base > 36
		die("invalid base in bigbfmt\n")
	elif base == 0
		b = mkbigint(10)
	else
		b = mkbigint(base)
	;;

	n = 0
	val = bigdup(x)
	/* generate the digits in reverse order */
	while !bigiszero(val)
		(v, rem) = bigdivmod(val, b)
		/* a zero remainder has no digits at all */
		if rem.dig.len > 0
			n += encode(buf[n:], digitchars[rem.dig[0]])
		else
			n += encode(buf[n:], '0')
		;;
		bigfree(val)
		bigfree(rem)
		val = v
	;;
	bigfree(val)
	bigfree(b)

	/* this is done last, so we get things right when we reverse the string */
	if x.sign == 0
		n += encode(buf[n:], '0')
	elif x.sign == -1
		n += encode(buf[n:], '-')
	;;

	/* we only generated ascii digits, so this works for reversing. */
	for i = 0; i < n/2; i++
		tmp = buf[i]
		buf[i] = buf[n - i - 1]
		buf[n - i - 1] = tmp
	;;
	-> n
}
/*
 Parses an unsigned integer literal into a new bigint.

 A "0x", "0o" or "0b" prefix (either case) selects base 16, 8 or 2;
 anything else is read as base 10. '_' characters are skipped.
 Returns `None as soon as a character is not a digit of the selected
 base, and `Some of the value otherwise. A string with no digits
 parses as zero. No sign character is handled here.
 */
const bigparse = {str
	var c, val : int, base
	var v, b
	var a

	if hasprefix(str, "0x") || hasprefix(str, "0X")
		base = 16
	elif hasprefix(str, "0o") || hasprefix(str, "0O")
		base = 8
	elif hasprefix(str, "0b") || hasprefix(str, "0B")
		base = 2
	else
		base = 10
	;;
	/* every non-decimal base was selected by a two byte prefix */
	if base != 10
		str = str[2:]
	;;

	a = mkbigint(0)
	b = mkbigint(base)
	/*
	 efficiency hack: to save allocations,
	 just mutate v[0]. The value will always
	 fit in one digit.
	 */
	v = mkbigint(1)
	while str.len != 0
		(c, str) = striter(str)
		if c == '_'
			continue
		;;
		val = charval(c, base)
		/* valid digit values are 0 .. base - 1, so 'base' itself is rejected */
		if val < 0 || val >= base
			bigfree(a)
			bigfree(b)
			bigfree(v)
			-> `None
		;;
		v.dig[0] = val castto(uint32)
		/* keep the sign in step with the digit; v is deliberately left untrimmed */
		if val == 0
			v.sign = 0
		else
			v.sign = 1
		;;
		/* a = a*base + digit */
		bigmul(a, b)
		bigadd(a, v)
	;;
	/* the temporaries are only needed while accumulating */
	bigfree(b)
	bigfree(v)
	-> `Some a
}
/* true when 'v' is zero, which is stored as an empty digit slice. */
const bigiszero = {v
	-> 0 == v.dig.len
}
/* true when 'a' and 'b' have the same sign and identical digits. */
const bigeq = {a, b
	var k

	/* cheap rejections first */
	if a.sign != b.sign
		-> false
	;;
	if a.dig.len != b.dig.len
		-> false
	;;
	for k = 0; k < a.dig.len; k++
		if a.dig[k] != b.dig[k]
			-> false
		;;
	;;
	-> true
}
/*
 compares 'a' against the machine integer 'b' by building a
 temporary stack bigint for 'b' and deferring to bigeq().
 */
generic bigeqi = {a, b
	var rhs
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	-> bigeq(a, &rhs)
}
/*
 Signed three way comparison: returns `Before when a < b, `After when
 a > b and `Equal otherwise. Relies on both values being free of
 leading zero digits.
 */
const bigcmp = {a, b
	var i
	var da, db, sa, sb

	sa = a.sign castto(int64)
	sb = b.sign castto(int64)
	if sa < sb
		-> `Before
	elif sa > sb
		-> `After
	/* same sign: more digits means a larger magnitude, which orders backwards for negatives */
	elif a.dig.len < b.dig.len
		-> signedorder(-sa)
	elif a.dig.len > b.dig.len
		-> signedorder(sa)
	else
		/*
		 otherwise, the one with the first larger digit is bigger.
		 only a differing digit decides the result; equal digits
		 must fall through to the next lower one.
		 */
		for i = a.dig.len; i > 0; i--
			da = a.dig[i - 1] castto(int64)
			db = b.dig[i - 1] castto(int64)
			if da != db
				-> signedorder(sa * (da - db))
			;;
		;;
	;;
	-> `Equal
}
/* maps a signed difference onto an order: negative, zero, positive. */
const signedorder = {sign
	if sign == 0
		-> `Equal
	elif sign < 0
		-> `Before
	else
		-> `After
	;;
}
/*
 compares the magnitudes of 'a' and 'b', ignoring their signs.
 relies on neither value having leading zero digits.
 */
const bigmagcmp = {a : bigint#, b : bigint# -> order
	var i

	if a.dig.len < b.dig.len
		-> `Before
	elif a.dig.len > b.dig.len
		-> `After
	;;
	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] < b.dig[i - 1]
			-> `Before
		elif a.dig[i - 1] > b.dig[i - 1]
			-> `After
		;;
	;;
	-> `Equal
}

/*
 a += b. The result is stored in 'a', which is also returned;
 'b' is never modified.
 */
const bigadd = {a, b
	if a.sign == b.sign || a.sign == 0
		a.sign = b.sign
		-> uadd(a, b)
	elif b.sign == 0
		-> a
	else
		/*
		 opposite signs: subtract the smaller magnitude from the
		 larger one. the result takes the sign of the operand
		 with the larger magnitude.
		 */
		match bigmagcmp(a, b)
		| `Before:
			a.sign = b.sign
			-> urevsub(a, b)
		| `After:
			-> usub(a, b)
		| `Equal:
			/* x + -x: usub leaves only zeros, and trim resets the sign */
			-> usub(a, b)
		;;
	;;
}

/* adds the magnitude of 'b' into 'a', leaving the sign of 'a' alone. */
const uadd = {a, b
	var v, i
	var carry
	var n

	carry = 0
	n = max(a.dig.len, b.dig.len)
	/* guaranteed to carry no more than one value */
	a.dig = slzgrow(a.dig, n + 1)
	for i = 0; i < n; i++
		v = (a.dig[i] castto(uint64)) + carry
		if i < b.dig.len
			v += (b.dig[i] castto(uint64))
		;;
		if v >= Base
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	/* i == n here: the spare top digit takes the final carry */
	a.dig[i] += carry castto(uint32)
	-> trim(a)
}

/*
 a -= b. The result is stored in 'a', which is also returned;
 'b' is never modified.
 */
const bigsub = {a, b
	/* x - 0 = x. checked first so that 0 - 0 never has to copy anything */
	if b.sign == 0
		-> a
	/* 0 - x = -x */
	elif a.sign == 0
		bigassign(a, b)
		a.sign = -b.sign
		-> a
	/* opposite signs: the magnitudes add, and the sign of 'a' is kept */
	elif a.sign != b.sign
		-> uadd(a, b)
	else
		/* same sign: the magnitudes subtract */
		match bigmagcmp(a, b)
		| `Before:
			/* |b| is larger, so the result crosses zero and flips sign */
			a.sign = -a.sign
			-> urevsub(a, b)
		| `After:
			-> usub(a, b)
		| `Equal:
			/* x - x: usub leaves only zeros, and trim resets the sign */
			-> usub(a, b)
		;;
	;;
	-> a
}

/*
 subtracts the magnitude of 'b' from 'a' in place. the magnitude of
 'a' must be greater than or equal to that of 'b'.
 */
const usub = {a, b
	var carry
	var v, i

	carry = 0
	for i = 0; i < a.dig.len; i++
		v = (a.dig[i] castto(int64)) - carry
		if i < b.dig.len
			v -= (b.dig[i] castto(int64))
		;;
		/* a negative digit borrows from the next one up */
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}

/*
 stores |b| - |a| into the digits of 'a'. the magnitude of 'b' must
 be strictly greater than that of 'a'. the caller sets the sign of
 'a' before calling.
 */
const urevsub = {a : bigint#, b : bigint# -> bigint#
	var carry
	var v, i

	carry = 0
	/* make room for every digit of the larger operand */
	a.dig = slzgrow(a.dig, b.dig.len)
	for i = 0; i < b.dig.len; i++
		v = (b.dig[i] castto(int64)) - (a.dig[i] castto(int64)) - carry
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}
/*
 a *= b, by schoolbook long multiplication. the product is built in
 a scratch slice which then replaces the digits of 'a'.
 */
const bigmul = {a, b
	var col, row
	var lhs, rhs, acc
	var hi, prod
	var out

	/* anything times zero is zero */
	if a.sign == 0 || b.sign == 0
		a.sign = 0
		slfree(a.dig)
		a.dig = [][:]
		-> a
	;;
	if a.sign == b.sign
		a.sign = 1
	else
		a.sign = -1
	;;

	/* the product never needs more digits than both operands together */
	out = slzalloc(a.dig.len + b.dig.len)
	for row = 0; row < b.dig.len; row++
		rhs = b.dig[row] castto(uint64)
		hi = 0
		for col = 0; col < a.dig.len; col++
			lhs = a.dig[col] castto(uint64)
			acc = out[col + row] castto(uint64)
			prod = lhs * rhs + acc + hi
			out[col + row] = (prod castto(uint32))
			hi = prod >> 32
		;;
		/* col == a.dig.len here: the row's final carry lands one digit up */
		out[col + row] = hi castto(uint32)
	;;
	slfree(a.dig)
	a.dig = out
	-> trim(a)
}
/* a /= b: replaces 'a' with the quotient from bigdivmod() and returns it. */
const bigdiv = {a : bigint#, b : bigint# -> bigint#
	var quot, rest

	(quot, rest) = bigdivmod(a, b)
	/* take over the quotient's digits; only its header is left to free */
	slfree(a.dig)
	a.dig = quot.dig
	a.sign = quot.sign
	free(quot)
	bigfree(rest)
	-> a
}
/* a %= b: replaces 'a' with the remainder from bigdivmod() and returns it. */
const bigmod = {a : bigint#, b : bigint# -> bigint#
	var quot, rest

	(quot, rest) = bigdivmod(a, b)
	/* take over the remainder's digits; only its header is left to free */
	slfree(a.dig)
	a.dig = rest.dig
	a.sign = rest.sign
	free(rest)
	bigfree(quot)
	-> a
}
/*
 Divides 'a' by 'b', returning a newly allocated (quotient, remainder)
 pair. Neither argument is modified. Dies when 'b' is zero.

 The quotient is truncated towards zero, and a nonzero remainder
 carries the sign of 'a', so that a == quotient*b + remainder.
 */
const bigdivmod = {a : bigint#, b : bigint# -> (bigint#, bigint#)
	/*
	 Implements bigint division using Algorithm D from
	 Knuth: Seminumerical algorithms, Section 4.3.1.
	 */
	var qhat, rhat, carry, shift
	var x, y, z, w, p, t /* temporaries */
	var b0, aj
	var u, v
	var m : int64, n : int64
	var i, j : int64
	var q, r

	if bigiszero(b)
		die("divide by zero\n")
	;;
	/* fewer digits means |a| < |b|: we truncate to 0, with remainder 'a' */
	if a.dig.len < b.dig.len
		-> (mkbigint(0), bigdup(a))
	;;

	q = zalloc()
	q.dig = slzalloc(max(a.dig.len, b.dig.len) + 1)
	if a.sign != b.sign
		q.sign = -1
	else
		q.sign = 1
	;;

	/* handle single digit divisor separately: the knuth algorithm needs at least 2 digits. */
	if b.dig.len == 1
		carry = 0
		b0 = (b.dig[0] castto(uint64))
		for j = a.dig.len; j > 0; j--
			aj = (a.dig[j - 1] castto(uint64))
			q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
			carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
		;;
		q = trim(q)
		/*
		 the remainder is below b0, so it fits one unsigned digit.
		 it is stored as a uint32: going through a signed 32 bit
		 value would corrupt remainders of 2^31 and above.
		 it takes the sign of 'a', like the other paths; trim
		 resets the sign if the remainder turns out to be zero.
		 */
		r = zalloc()
		r.dig = slpush([][:], carry castto(uint32))
		r.sign = a.sign
		-> (q, trim(r))
	;;

	u = bigdup(a)
	v = bigdup(b)
	m = u.dig.len
	n = v.dig.len

	/* normalize so that the top bit of the divisor's top digit is set */
	shift = nlz(v.dig[n - 1])
	bigshli(u, shift)
	bigshli(v, shift)
	/* the loop below reads u.dig[m], so make sure that digit exists */
	u.dig = slzgrow(u.dig, u.dig.len + 1)

	for j = m - n; j >= 0; j--
		/* load a few temps */
		x = u.dig[j + n] castto(uint64)
		y = u.dig[j + n - 1] castto(uint64)
		z = v.dig[n - 1] castto(uint64)
		w = v.dig[n - 2] castto(uint64)
		t = u.dig[j + n - 2] castto(uint64)

		/* estimate qhat */
		qhat = (x*Base + y)/z
		rhat = (x*Base + y) - (qhat * z)
:divagain
		if qhat >= Base || (qhat * w) > (rhat*Base + t)
			qhat--
			rhat += z
			if rhat < Base
				goto divagain
			;;
		;;

		/* multiply and subtract */
		carry = 0
		for i = 0; i < n; i++
			p = qhat * (v.dig[i] castto(uint64))
			t = (u.dig[i+j] castto(uint64)) - carry - (p % Base)
			u.dig[i+j] = t castto(uint32)
			/*
			 the high half of p is taken with an unsigned shift:
			 p can exceed 2^63, and shifting it as a signed value
			 would smear the sign bit into the borrow. t really is
			 signed here, its high half is either 0 or -1.
			 */
			carry = (((p >> 32) castto(int64)) - ((t castto(int64)) >> 32)) castto(uint64)
		;;
		x = u.dig[j + n] castto(uint64)
		t = x - carry
		u.dig[j + n] = t castto(uint32)
		q.dig[j] = qhat castto(uint32)

		/* adjust: qhat was one too large, so add the divisor back */
		if x < carry
			q.dig[j]--
			carry = 0
			for i = 0; i < n; i++
				t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
				u.dig[i+j] = t castto(uint32)
				carry = t >> 32
			;;
			u.dig[j+n] = u.dig[j+n] + (carry castto(uint32))
		;;
	;;

	/* undo the biasing for remainder */
	bigshri(u, shift)
	/* the normalized divisor was only a working copy */
	bigfree(v)
	-> (trim(q), trim(u))
}
/* counts the zero bits above the highest set bit of 'a'; 32 when a is 0. */
const nlz = {a : uint32
	var n

	if a == 0
		-> 32
	;;
	/* walk the value up until its top bit is set, counting the steps */
	n = 0
	while (a & 0x80000000) == 0
		a <<= 1
		n++
	;;
	-> n
}
/*
 shifts 'a' left by the bigint amount 'b'. only amounts that fit in a
 single digit are supported; anything larger dies.
 */
const bigshl = {a, b
	if b.dig.len == 0
		-> a
	elif b.dig.len == 1
		-> bigshli(a, b.dig[0] castto(uint64))
	else
		die("shift by way too much\n")
	;;
}
/*
 shifts 'a' right by the bigint amount 'b', unsigned. only amounts
 that fit in a single digit are supported; anything larger dies.
 */
const bigshr = {a, b
	if b.dig.len == 0
		-> a
	elif b.dig.len == 1
		-> bigshri(a, b.dig[0] castto(uint64))
	else
		die("shift by way too much\n")
	;;
}
/*
 a += b, where b is a machine integer.
 FIXME: actually make this a performance improvement; for now the
 integer is wrapped in a temporary stack bigint.
 */
generic bigaddi = {a, b
	var rhs : bigint
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigadd(a, &rhs)
	-> a
}
/* a -= b, where b is a machine integer wrapped in a temporary stack bigint. */
generic bigsubi = {a, b : @a::(numeric,integral)
	var rhs : bigint
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigsub(a, &rhs)
	-> a
}
/* a *= b, where b is a machine integer wrapped in a temporary stack bigint. */
generic bigmuli = {a, b
	var rhs : bigint
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigmul(a, &rhs)
	-> a
}
/* a /= b, where b is a machine integer wrapped in a temporary stack bigint. */
generic bigdivi = {a, b
	var rhs : bigint
	var store : uint32[2]

	bigdigit(&rhs, b < 0, b castto(uint64), store[:])
	bigdiv(a, &rhs)
	-> a
}
/*
 a <<= s, with an integer shift amount.
 the shift is logical: the magnitude moves up and the sign is kept.
 */
generic bigshli = {a, s : @a::(numeric,integral)
	var words, bits
	var wide, spill
	var k

	assert(s >= 0, "shift amount must be positive")
	/* split the amount into whole digits and leftover bits */
	words = (s castto(uint64)) / 32
	bits = (s castto(uint64)) % 32

	/* zero shifted by anything is zero */
	if a.sign == 0
		-> a
	;;

	/* room for the whole digits, plus one for bits spilling off the top */
	a.dig = slzgrow(a.dig, 1 + a.dig.len + words castto(size))
	/* move the digits up, from the top down so nothing is overwritten early */
	for k = a.dig.len; k > words; k--
		a.dig[k - 1] = a.dig[k - 1 - words]
	;;
	for k = 0; k < words; k++
		a.dig[k] = 0
	;;

	/* then shift by the leftover bits, carrying overflow into the next digit */
	spill = 0
	for k = 0; k < a.dig.len; k++
		wide = (a.dig[k] castto(uint64)) << bits
		a.dig[k] = (wide | spill) castto(uint32)
		spill = wide >> 32
	;;
	-> trim(a)
}
/*
 a >>= s, with an integer shift amount.
 logical shift right, zero fills. the sign is left alone unless the
 result is zero, in which case trim() resets it.
 */
generic bigshri = {a, s
	var off, shift
	var t, carry
	var i

	assert(s >= 0, "shift amount must be positive")
	off = (s castto(uint64)) / 32
	shift = (s castto(uint64)) % 32

	/*
	 every digit is shifted out: the result is zero. this has to be
	 handled up front, since 'a.dig.len - off' below would wrap
	 around for an offset larger than the digit count.
	 */
	if off >= a.dig.len
		for i = 0; i < a.dig.len; i++
			a.dig[i] = 0
		;;
		-> trim(a)
	;;

	/* blit over the base values */
	for i = 0; i < a.dig.len - off; i++
		a.dig[i] = a.dig[i + off]
	;;
	for i = a.dig.len - off; i < a.dig.len; i++
		a.dig[i] = 0
	;;

	/* and shift over by the remainder, from the top digit down */
	carry = 0
	for i = a.dig.len; i > 0; i--
		t = (a.dig[i - 1] castto(uint64))
		a.dig[i - 1] = (carry | (t >> shift)) castto(uint32)
		/* the bits shifted out become the top bits of the next lower digit */
		carry = t << (32 - shift)
	;;
	-> trim(a)
}
/*
 fills in the caller provided bigint 'v' from a machine integer, using
 'dig' as its digit storage. 'val' holds the integer's bits and 'isneg'
 says whether they are to be read as a negative value. the result
 borrows 'dig', so it should not be modified.
 */
const bigdigit = {v, isneg : bool, val : uint64, dig
	/* work with the magnitude from here on */
	if isneg
		val = -val
		v.sign = -1
	else
		v.sign = 1
	;;

	if val >= Base
		/* needs both digits */
		v.dig = dig
		v.dig[0] = val castto(uint32)
		v.dig[1] = (val >> 32) castto(uint32)
	elif val != 0
		v.dig = dig[:1]
		v.dig[0] = val castto(uint32)
	else
		/* zero has no digits and no sign */
		v.sign = 0
		v.dig = [][:]
	;;
}
/*
 restores the bigint invariants: drops zero digits from the top and
 makes the sign agree with the digits that are left.
 */
const trim = {a
	var n

	/* find the length without the zero digits at the top */
	n = a.dig.len
	while n > 0
		if a.dig[n - 1] != 0
			break
		;;
		n--
	;;
	a.dig = slgrow(a.dig, n)

	if n == 0
		/* no digits left: the value is zero */
		a.sign = 0
	elif a.sign == 0
		/* digits but no sign recorded: treat the value as positive */
		a.sign = 1
	;;
	-> a
}