ref: 3eac4fd0e131b25af66ac9e0952cafab6b0b965f
dir: /libstd/bigint.myr/
use "alloc.use"
use "cmp.use"
use "die.use"
use "extremum.use"
use "fmt.use"
use "hasprefix.use"
use "chartype.use"
use "option.use"
use "slcp.use"
use "sldup.use"
use "slfill.use"
use "slpush.use"
use "types.use"
use "utf.use"
pkg std =
/*
an arbitrary precision integer, stored as a sign plus a magnitude.
the magnitude is a slice of 32 bit digits, least significant first.
*/
type bigint = struct
dig : uint32[:] /* little endian, no leading zeros. */
sign : int /* -1 for -ve, 0 for zero, 1 for +ve. */
;;
/* administrivia */
/* mkbigint allocates a new bigint holding the value of the integer 'v' */
generic mkbigint : (v : @a::(numeric,integral) -> bigint#)
/* bigfree releases the digits and the bigint itself */
const bigfree : (a : bigint# -> void)
/* bigdup returns a newly allocated copy of 'a' */
const bigdup : (a : bigint# -> bigint#)
/* bigassign overwrites 'd' with a copy of 's', and returns 'd' */
const bigassign : (d : bigint#, s : bigint# -> bigint#)
/* bigparse returns `None if 's' contains a character that is not a digit in its base */
const bigparse : (s : byte[:] -> option(bigint#))
/* bigfmt writes the decimal text of 'a' into 'b', and returns the number of bytes written */
const bigfmt : (b : byte[:], a : bigint# -> size)
/* some useful predicates */
const bigiszero : (a : bigint# -> bool)
const bigcmp : (a : bigint#, b : bigint# -> order)
/* bigint*bigint -> bigint ops */
/*
these are written as in-place updates of 'a' (a += b, a -= b, ...).
bigdivmod is the exception: it returns a newly allocated
(quotient, remainder) pair.
*/
const bigadd : (a : bigint#, b : bigint# -> bigint#)
const bigsub : (a : bigint#, b : bigint# -> bigint#)
const bigmul : (a : bigint#, b : bigint# -> bigint#)
const bigdiv : (a : bigint#, b : bigint# -> bigint#)
const bigmod : (a : bigint#, b : bigint# -> bigint#)
const bigdivmod : (a : bigint#, b : bigint# -> (bigint#, bigint#))
const bigshl : (a : bigint#, b : bigint# -> bigint#)
const bigshr : (a : bigint#, b : bigint# -> bigint#)
/* bigint*int -> bigint ops */
/* the same operations, with the right hand side given as a machine integer */
const bigaddi : (a : bigint#, b : int64 -> bigint#)
const bigsubi : (a : bigint#, b : int64 -> bigint#)
const bigmuli : (a : bigint#, b : int64 -> bigint#)
const bigdivi : (a : bigint#, b : int64 -> bigint#)
const bigshli : (a : bigint#, b : uint64 -> bigint#)
const bigshri : (a : bigint#, b : uint64 -> bigint#)
;;
/* the radix of a single bigint digit: 2^32, one more than the largest uint32 value. */
const Base = 0x100000000ul
/*
 Allocates a new bigint with the value of the integer 'v'.
 The sign is recorded first, then the magnitude is peeled off
 one 32 bit digit at a time, least significant digit first.
 */
generic mkbigint = {v : @a::(integral,numeric)
	var big
	var rest

	big = zalloc()
	if v > 0
		big.sign = 1
	elif v < 0
		big.sign = -1
		v = -v
	;;

	/* only the magnitude is left in 'v' at this point */
	rest = v castto(uint64)
	while rest != 0
		big.dig = slpush(big.dig, (rest castto(uint32)))
		rest /= Base
	;;
	-> trim(big)
}
/* Releases a bigint: first the digit slice that it owns, then the struct itself. */
const bigfree = {a
	var digits

	digits = a.dig
	slfree(digits)
	free(a)
}
/* Returns a newly allocated bigint with the same digits and sign as 'a'. */
const bigdup = {a
	var copy

	/* start from a zeroed bigint, and let bigassign fill it in */
	copy = zalloc()
	-> bigassign(copy, a)
}
/*
 d = s: replaces the value of 'd' with a copy of 's', and returns 'd'.
 The old digits of 'd' are freed; 's' is not modified.
 */
const bigassign = {d, s
	var dig

	/*
	 duplicate before freeing: if 'd' and 's' are the same bigint,
	 freeing first would leave sldup reading memory that was
	 just released.
	 */
	dig = sldup(s.dig)
	slfree(d.dig)
	d.dig = dig
	d.sign = s.sign
	-> d
}
/*
 Writes the decimal representation of 'val' into 'buf', and returns
 the number of bytes written. 'val' itself is left untouched; the
 digits are produced from a scratch copy.

 For now, this is mostly useful for dumping values when debugging.
 */
const bigfmt = {buf, val
	const digitchars = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
	var cur, v
	var n, i
	var tmp
	var rem
	var ten

	n = 0
	if val.sign == 0
		n += encode(buf, '0')
		-> n
	;;

	/* the divisor is the same on every round, so only allocate it once */
	ten = mkbigint(10)
	cur = bigdup(val)
	/* generate the digits in reverse order */
	while !bigiszero(cur)
		(v, rem) = bigdivmod(cur, ten)
		if rem.dig.len > 0
			n += bfmt(buf[n:], "%c", digitchars[rem.dig[0]])
		else
			n += bfmt(buf[n:], "0")
		;;
		bigfree(cur)
		bigfree(rem)
		cur = v
	;;
	bigfree(cur)
	bigfree(ten)

	/*
	 the sign is emitted after the digits, so that the reversal
	 below moves it to the front of the number.
	 */
	if val.sign == -1
		n += encode(buf[n:], '-')
	;;

	/* we only generated ascii, so reversing bytewise works. ugh. */
	for i = 0; i < n/2; i++
		tmp = buf[i]
		buf[i] = buf[n - i - 1]
		buf[n - i - 1] = tmp
	;;
	-> n
}
/*
 Parses an unsigned integer from 'str'. A leading "0x", "0o" or "0b"
 (in either case) selects base 16, 8 or 2; anything else is read as
 base 10. Underscores between digits are ignored.

 Returns `Some of a newly allocated bigint, or `None if a character
 is not a valid digit in the selected base. A string with no digits
 parses as zero.
 */
const bigparse = {str
	var c, val, base
	var v, b
	var a

	if hasprefix(str, "0x") || hasprefix(str, "0X")
		base = 16
	elif hasprefix(str, "0o") || hasprefix(str, "0O")
		base = 8
	elif hasprefix(str, "0b") || hasprefix(str, "0B")
		base = 2
	else
		base = 10
	;;
	/*
	 every recognized prefix is two bytes long. it has to be dropped,
	 or the 'x', 'o' or 'b' would be rejected as an invalid digit.
	 */
	if base != 10
		str = str[2:]
	;;

	a = mkbigint(0)
	b = mkbigint(base castto(int32))
	/*
	 efficiency hack: to save allocations,
	 just mutate v[0]. The value will always
	 fit in one digit.
	 */
	v = mkbigint(1)
	while str.len != 0
		(c, str) = striter(str)
		if c == '_'
			continue
		;;
		val = charval(c, base)
		if val < 0
			/* the scratch values are ours, so release them on failure too */
			bigfree(a)
			bigfree(b)
			bigfree(v)
			-> `None
		;;
		v.dig[0] = val
		if val == 0
			v.sign = 0
		else
			v.sign = 1
		;;
		/* a = a*base + digit */
		bigmul(a, b)
		bigadd(a, v)
	;;
	bigfree(b)
	bigfree(v)
	-> `Some a
}
/*
 Reports whether 'v' is zero. Digits are kept without leading
 zeros, so zero is the value with no digits at all.
 */
const bigiszero = {v
	if v.dig.len != 0
		-> false
	;;
	-> true
}
/*
 Compares 'a' against 'b', returning `Before if a < b, `After if
 a > b, and `Equal if they have the same value.
 */
const bigcmp = {a, b
	var i

	/* differing signs settle the comparison on their own */
	if a.sign < b.sign
		-> `Before
	elif a.sign > b.sign
		-> `After
	;;

	/*
	 from here on the signs match, so the magnitudes decide. a larger
	 magnitude sorts later for positive values, and earlier for
	 negative ones, which is what signedorder encodes.

	 there are no leading zeros, so more digits means more magnitude.
	 */
	if a.dig.len > b.dig.len
		-> signedorder(a.sign)
	elif a.dig.len < b.dig.len
		-> signedorder(-a.sign)
	;;

	/* same length: the most significant differing digit decides */
	for i = a.dig.len; i > 0; i--
		if a.dig[i - 1] > b.dig[i - 1]
			-> signedorder(a.sign)
		elif a.dig[i - 1] < b.dig[i - 1]
			-> signedorder(-a.sign)
		;;
	;;
	-> `Equal
}
/*
 Maps a sign onto an ordering: anything negative yields `Before,
 while zero and positive values yield `After.
 */
const signedorder = {sign
	if sign >= 0
		-> `After
	;;
	-> `Before
}
/*
 a += b: adds 'b' into 'a' in place, and returns 'a'.
 'b' is never modified.
 */
const bigadd = {a, b
	var t, i, amag

	if a.sign == b.sign || a.sign == 0
		/* same sign, or 0 + b: the magnitudes add, under b's sign */
		a.sign = b.sign
		-> uadd(a, b)
	elif b.sign == 0
		-> a
	;;

	/*
	 the signs differ, so the result is the difference of the
	 magnitudes, with the sign of whichever magnitude is larger.
	 amag is > 0 if |a| > |b|, < 0 if |a| < |b|, and 0 if they match.
	 */
	amag = 0
	if a.dig.len > b.dig.len
		amag = 1
	elif a.dig.len < b.dig.len
		amag = -1
	else
		for i = a.dig.len; i > 0; i--
			if a.dig[i - 1] > b.dig[i - 1]
				amag = 1
				break
			elif a.dig[i - 1] < b.dig[i - 1]
				amag = -1
				break
			;;
		;;
	;;

	if amag > 0
		/* |a| - |b|, keeping the sign of 'a' */
		-> usub(a, b)
	elif amag < 0
		/*
		 |b| - |a|, under the sign of 'b'. usub works in place on its
		 first argument, so use a scratch copy of 'b', then move the
		 resulting digits into 'a'.
		 */
		t = bigdup(b)
		usub(t, a)
		slfree(a.dig)
		a.dig = t.dig
		a.sign = t.sign
		free(t)
		-> a
	else
		/* x + -x = 0 */
		slfree(a.dig)
		a.dig = [][:]
		a.sign = 0
		-> a
	;;
}
/*
 Adds the magnitude of 'b' into the magnitude of 'a', ignoring both
 signs, and returns 'a'. The two values may have different lengths.
 */
const uadd = {a, b
	var v, i
	var carry
	var n
	var bi

	carry = 0
	n = max(a.dig.len, b.dig.len)
	/* guaranteed to carry no more than one value */
	a.dig = slzgrow(a.dig, n + 1)
	for i = 0; i < n; i++
		/* 'b' may be shorter than 'a': its missing high digits count as zero */
		if i < b.dig.len
			bi = b.dig[i] castto(uint64)
		else
			bi = 0
		;;
		v = (a.dig[i] castto(uint64)) + bi + carry
		if v >= Base
			carry = 1
		else
			carry = 0
		;;
		/* the cast keeps the low 32 bits; the overflow lives in carry */
		a.dig[i] = v castto(uint32)
	;;
	/* i == n here, which is the spare digit that was added above */
	a.dig[i] += carry castto(uint32)
	-> trim(a)
}
/*
 a -= b: subtracts 'b' from 'a' in place, and returns 'a'.
 'b' is never modified.
 */
const bigsub = {a, b
	var t, i, amag

	if b.sign == 0
		/* x - 0 = x */
		-> a
	elif a.sign == 0
		/* 0 - x = -x */
		bigassign(a, b)
		a.sign = -b.sign
		-> a
	elif a.sign != b.sign
		/* x - (-y) = x + y: the magnitudes add, and 'a' keeps its sign */
		-> uadd(a, b)
	;;

	/*
	 the signs match, so the result is the difference of the
	 magnitudes. its sign flips if 'b' has the larger magnitude.
	 amag is > 0 if |a| > |b|, < 0 if |a| < |b|, and 0 if they match.
	 */
	amag = 0
	if a.dig.len > b.dig.len
		amag = 1
	elif a.dig.len < b.dig.len
		amag = -1
	else
		for i = a.dig.len; i > 0; i--
			if a.dig[i - 1] > b.dig[i - 1]
				amag = 1
				break
			elif a.dig[i - 1] < b.dig[i - 1]
				amag = -1
				break
			;;
		;;
	;;

	if amag > 0
		/* |a| - |b|, keeping the sign of 'a' */
		-> usub(a, b)
	elif amag < 0
		/*
		 |b| - |a|, with the sign flipped. usub works in place on its
		 first argument, so use a scratch copy of 'b', then move the
		 resulting digits into 'a'.
		 */
		t = bigdup(b)
		usub(t, a)
		slfree(a.dig)
		a.dig = t.dig
		a.sign = -t.sign
		free(t)
		-> a
	else
		/* x - x = 0 */
		slfree(a.dig)
		a.dig = [][:]
		a.sign = 0
		-> a
	;;
}
/*
 Subtracts the magnitude of 'b' from the magnitude of 'a', ignoring
 both signs, and returns 'a'. The magnitude of 'a' must be strictly
 greater than that of 'b'; 'b' may have fewer digits than 'a'.
 */
const usub = {a, b
	var carry
	var v, i
	var bi

	carry = 0
	for i = 0; i < a.dig.len; i++
		/* 'b' may be shorter than 'a': its missing high digits count as zero */
		if i < b.dig.len
			bi = b.dig[i] castto(int64)
		else
			bi = 0
		;;
		v = (a.dig[i] castto(int64)) - bi - carry
		/* a negative digit means we have to borrow from the next one up */
		if v < 0
			carry = 1
		else
			carry = 0
		;;
		/* the cast wraps a negative value back into digit range */
		a.dig[i] = v castto(uint32)
	;;
	-> trim(a)
}
/*
 a *= b: multiplies 'a' by 'b' in place, and returns 'a'.
 'b' is not modified. This is schoolbook multiplication: each digit
 of 'b' multiplies all of 'a', accumulating into a zeroed product 'w'.
 */
const bigmul = {a, b
	var i, j
	var ai, bj, wij
	var carry, t
	var w

	if a.sign == 0 || b.sign == 0
		a.sign = 0
		slfree(a.dig)
		a.dig = [][:]
		-> a
	elif a.sign != b.sign
		a.sign = -1
	else
		a.sign = 1
	;;
	/* the product has at most as many digits as both factors together */
	w = slzalloc(a.dig.len + b.dig.len)
	for j = 0; j < b.dig.len; j++
		/*
		 a zero digit of 'b' adds nothing to the product. nothing has
		 been stored above w[j + a.dig.len - 1] yet, and w started out
		 zeroed, so the whole row can simply be skipped.
		 */
		if b.dig[j] == 0
			continue
		;;
		carry = 0
		for i = 0; i < a.dig.len; i++
			ai = a.dig[i] castto(uint64)
			bj = b.dig[j] castto(uint64)
			wij = w[i+j] castto(uint64)
			/* cannot overflow: (2^32 - 1)^2 + 2*(2^32 - 1) < 2^64 */
			t = ai * bj + wij + carry
			w[i + j] = (t castto(uint32))
			carry = t >> 32
		;;
		/* i == a.dig.len here: the carry out of this row */
		w[i+j] = carry castto(uint32)
	;;
	slfree(a.dig)
	a.dig = w
	-> trim(a)
}
/*
 a /= b: replaces 'a' with the quotient of a / b, and returns 'a'.
 'b' is not modified. Dies if 'b' is zero, since bigdivmod does.
 */
const bigdiv = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(r)
	/* take over both the digits and the sign of the quotient */
	slfree(a.dig)
	a.dig = q.dig
	a.sign = q.sign
	/* the digits now belong to 'a', so only release the struct */
	free(q)
	-> a
}
/*
 a %= b: replaces 'a' with the remainder of a / b, and returns 'a'.
 'b' is not modified. Dies if 'b' is zero, since bigdivmod does.
 */
const bigmod = {a : bigint#, b : bigint# -> bigint#
	var q, r

	(q, r) = bigdivmod(a, b)
	bigfree(q)
	/* take over both the digits and the sign of the remainder */
	slfree(a.dig)
	a.dig = r.dig
	a.sign = r.sign
	/* the digits now belong to 'a', so only release the struct */
	free(r)
	-> a
}
/*
 Divides 'a' by 'b', returning a newly allocated (quotient, remainder)
 pair. Neither argument is modified. The quotient's magnitude is
 |a| / |b| rounded down, and it is negative when the signs of 'a' and
 'b' differ; a nonzero remainder carries the sign of 'a'.
 Dies if 'b' is zero.
 */
const bigdivmod = {a : bigint#, b : bigint# -> (bigint#, bigint#)
	/*
	 Implements bigint division using Algorithm D from
	 Knuth: Seminumerical algorithms, Section 4.3.1.
	 */
	var qhat, rhat, carry, shift
	var x, y, z, w, p, t	/* temporaries */
	var b0, aj
	var u, v
	var m : int64, n : int64
	var i, j : int64
	var q
	var r

	if bigiszero(b)
		die("divide by zero\n")
	;;
	/* if b > a, we truncate to 0, with remainder 'a' */
	if a.dig.len < b.dig.len
		-> (mkbigint(0), bigdup(a))
	;;

	q = zalloc()
	q.dig = slzalloc(max(a.dig.len, b.dig.len) + 1)
	if a.sign != b.sign
		q.sign = -1
	else
		q.sign = 1
	;;

	/* handle single digit divisor separately: the knuth algorithm needs at least 2 digits. */
	if b.dig.len == 1
		carry = 0
		b0 = (b.dig[0] castto(uint64))
		for j = a.dig.len; j > 0; j--
			aj = (a.dig[j - 1] castto(uint64))
			q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
			carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
		;;
		/*
		 the remainder is below b0, so it fits in one unsigned digit.
		 store it directly: going through a signed 32 bit value would
		 mangle remainders of 2^31 and up. it gets the sign of 'a',
		 like the remainders on the other paths; trim resets the sign
		 if the remainder turns out to be zero.
		 */
		r = zalloc()
		r.dig = slzalloc(1)
		r.dig[0] = carry castto(uint32)
		r.sign = a.sign
		-> (trim(q), trim(r))
	;;

	u = bigdup(a)
	v = bigdup(b)
	m = u.dig.len
	n = v.dig.len

	/* normalize, so that the top bit of the divisor's top digit is set */
	shift = nlz(v.dig[n - 1])
	bigshli(u, shift)
	bigshli(v, shift)
	/*
	 bigshli trims its result, but the loop below reads u.dig[m]
	 unconditionally. put the top digit back if it was trimmed away.
	 */
	if u.dig.len < a.dig.len + 1
		u.dig = slzgrow(u.dig, a.dig.len + 1)
	;;

	for j = m - n; j >= 0; j--
		/* load a few temps */
		x = u.dig[j + n] castto(uint64)
		y = u.dig[j + n - 1] castto(uint64)
		z = v.dig[n - 1] castto(uint64)
		w = v.dig[n - 2] castto(uint64)
		t = u.dig[j + n - 2] castto(uint64)

		/* estimate qhat */
		qhat = (x*Base + y)/z
		rhat = (x*Base + y) - (qhat * z)
:divagain
		if qhat >= Base || (qhat * w) > (rhat*Base + t)
			qhat--
			rhat += z
			if rhat < Base
				goto divagain
			;;
		;;

		/* multiply and subtract */
		carry = 0
		for i = 0; i < n; i++
			p = qhat * (v.dig[i] castto(uint64))
			t = (u.dig[i+j] castto(uint64)) - carry - (p % Base)
			u.dig[i+j] = t castto(uint32)
			/*
			 the high half of 'p' is an unsigned quantity, so shift it
			 before reinterpreting it as signed. 't' on the other hand
			 is shifted as a signed value, so that a borrow shows up
			 as a negative high half.
			 */
			carry = (((p >> 32) castto(int64)) - ((t castto(int64)) >> 32)) castto(uint64)
		;;
		x = u.dig[j + n] castto(uint64)
		t = x - carry
		u.dig[j + n] = t castto(uint32)
		q.dig[j] = qhat castto(uint32)

		/* adjust: qhat was one too large, so add the divisor back in */
		if x < carry
			q.dig[j]--
			carry = 0
			for i = 0; i < n; i++
				t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
				u.dig[i+j] = t castto(uint32)
				carry = t >> 32
			;;
			u.dig[j+n] = u.dig[j+n] + (carry castto(uint32))
		;;
	;;

	/* undo the biasing for remainder */
	bigshri(u, shift)
	/* the normalized divisor was only scratch space */
	bigfree(v)
	-> (trim(q), trim(u))
}
/*
 Counts the zero bits above the highest set bit of a 32 bit value;
 zero itself yields 32. This is a binary search: each step checks
 whether the top 16, 8, 4, 2 and finally 1 bits are all clear, and
 if so counts them and shifts them out.
 */
const nlz = {a : uint32
	var n

	if a == 0
		-> 32
	;;
	n = 0
	if (a >> 16) == 0
		n += 16
		a <<= 16
	;;
	if (a >> 24) == 0
		n += 8
		a <<= 8
	;;
	if (a >> 28) == 0
		n += 4
		a <<= 4
	;;
	if (a >> 30) == 0
		n += 2
		a <<= 2
	;;
	/* nothing reads 'a' past this point, so the last shift is not needed */
	if (a >> 31) == 0
		n += 1
	;;
	-> n
}
/*
 a <<= b, where the shift amount is itself a bigint.
 Only amounts that fit in a single digit are supported; anything
 larger dies. The sign of 'b' is not looked at.
 */
const bigshl = {a, b
	if b.dig.len == 0
		/* shifting by zero changes nothing */
		-> a
	elif b.dig.len == 1
		-> bigshli(a, b.dig[0] castto(uint64))
	else
		die("shift by way too much\n")
	;;
}
/*
 a >>= b, unsigned, where the shift amount is itself a bigint.
 Only amounts that fit in a single digit are supported; anything
 larger dies. The sign of 'b' is not looked at.
 */
const bigshr = {a, b
	if b.dig.len == 0
		/* shifting by zero changes nothing */
		-> a
	elif b.dig.len == 1
		-> bigshri(a, b.dig[0] castto(uint64))
	else
		die("shift by way too much\n")
	;;
}
/*
 a += b, where 'b' is a machine integer.
 FIXME: actually make this a performance improvement
 */
const bigaddi = {a, b
	var tmp

	/* promote the integer to a scratch bigint, and defer to bigadd */
	tmp = mkbigint(b)
	bigadd(a, tmp)
	bigfree(tmp)
	-> a
}
/* a -= b, where 'b' is a machine integer. */
const bigsubi = {a, b
	var tmp

	/* promote the integer to a scratch bigint, and defer to bigsub */
	tmp = mkbigint(b)
	bigsub(a, tmp)
	bigfree(tmp)
	-> a
}
/* a *= b, where 'b' is a machine integer. */
const bigmuli = {a, b
	var tmp

	/* promote the integer to a scratch bigint, and defer to bigmul */
	tmp = mkbigint(b)
	bigmul(a, tmp)
	bigfree(tmp)
	-> a
}
/* a /= b, where 'b' is a machine integer. */
const bigdivi = {a, b
	var tmp

	/* promote the integer to a scratch bigint, and defer to bigdiv */
	tmp = mkbigint(b)
	bigdiv(a, tmp)
	bigfree(tmp)
	-> a
}
/*
 a <<= s, with an integer shift amount.
 This is a logical left shift on the magnitude: whole digits are
 moved first, then the remaining bits are shifted across digits.
 */
const bigshli = {a, s
	var words, bits
	var wide, spill
	var i

	words = s/32
	bits = s % 32

	/* make room for the moved digits, plus one for the bits shifted out the top */
	a.dig = slzgrow(a.dig, 1 + a.dig.len + words castto(size))

	/* move whole digits up, starting at the top so nothing gets clobbered */
	i = a.dig.len
	while i > words
		a.dig[i - 1] = a.dig[i - 1 - words]
		i--
	;;
	/* the vacated low digits become zero */
	for i = 0; i < words; i++
		a.dig[i] = 0
	;;

	/* then shift by the leftover bits, spilling the overflow into the next digit */
	spill = 0
	for i = 0; i < a.dig.len; i++
		wide = (a.dig[i] castto(uint64)) << bits
		a.dig[i] = (wide | spill) castto(uint32)
		spill = wide >> 32
	;;
	-> trim(a)
}
/*
 a >>= s, with an integer shift amount.
 Logical shift right on the magnitude, zero filling from the top.
 The sign is left alone, unless every set bit is shifted out, in
 which case trim marks the result as zero.
 */
const bigshri = {a, s
	var off, shift
	var t, carry
	var i

	off = s/32
	shift = s % 32

	/*
	 shifting by at least the whole width leaves nothing behind. this
	 also keeps 'a.dig.len - off' below from wrapping around.
	 */
	if off >= a.dig.len
		for i = 0; i < a.dig.len; i++
			a.dig[i] = 0
		;;
		-> trim(a)
	;;

	/* blit over the base values */
	for i = 0; i < a.dig.len - off; i++
		a.dig[i] = a.dig[i + off]
	;;
	/* zero the top 'off' digits, which have just been vacated */
	for i = a.dig.len - off; i < a.dig.len; i++
		a.dig[i] = 0
	;;

	/* and shift over by the remainder */
	carry = 0
	for i = a.dig.len; i > 0; i--
		t = (a.dig[i - 1] castto(uint64))
		a.dig[i - 1] = (carry | (t >> shift)) castto(uint32)
		/* the low 'shift' bits of this digit become the top bits of the next one down */
		carry = t << (32 - shift)
	;;
	-> trim(a)
}
/*
 Restores the bigint invariants: drops the zero digits at the most
 significant end, and makes the sign agree with the digits. A value
 with no digits left gets sign 0; a nonzero value that was marked
 as zero becomes positive. Returns 'a'.
 */
const trim = {a
	var keep

	/* walk down from the top until a nonzero digit turns up */
	keep = a.dig.len
	while keep > 0
		if a.dig[keep - 1] != 0
			break
		;;
		keep--
	;;
	a.dig = slgrow(a.dig, keep)

	if keep == 0
		a.sign = 0
	elif a.sign == 0
		a.sign = 1
	;;
	-> a
}