ref: 77d584ee227d829e842efdc938dffa322665b22c
parent: f3265a99202d98f6f7e5969e34faae934ce754b9
author: S. Gilles <sgilles@math.umd.edu>
date: Tue Mar 20 05:15:36 EDT 2018
Implement fma64
--- a/lib/math/bld.sub
+++ b/lib/math/bld.sub
@@ -7,5 +7,8 @@
# summation
fpmath-sum-impl.myr
+ # fused-multiply-add
+ fpmath-fma-impl.myr
+
lib ../std:std
;;
--- /dev/null
+++ b/lib/math/fpmath-fma-impl.myr
@@ -1,0 +1,400 @@
+use std
+
+pkg math =
+ pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64) /* fused x * y + z */
+ pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32) /* flt32 counterpart */
+;;
+
+const exp_mask32 : uint32 = 0x7f800000 /* exponent field of a flt32: 0xff << 23 */
+const exp_mask64 : uint64 = 0x7ff0000000000000 /* exponent field of a flt64: 0x7ff << 52 */
+/* Interim: exact flt64 product, but the sum is rounded twice (flt64, then flt32) */
+pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
+ -> ((x : flt64) * (y : flt64) + (z : flt64) : flt32) /* rare near-ties may differ from a true fma */
+}
+/* Aims to compute x * y + z with a single rounding (to nearest, ties to even) */
+pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
+ var xn : bool, yn : bool, zn : bool
+ var xe : int64, ye : int64, ze : int64
+ var xs : uint64, ys : uint64, zs : uint64
+
+ var xb : uint64 = std.flt64bits(x)
+ var yb : uint64 = std.flt64bits(y)
+ var zb : uint64 = std.flt64bits(z)
+
+ /* NaN/inf x or y, or a zero factor: x * y is exact or special, so one rounding suffices */
+ if xb & exp_mask64 == exp_mask64 || yb & exp_mask64 == exp_mask64 || \
+ x == 0.0 || y == 0.0 /* not x * y == 0.0: an underflowed product can still sway rounding */
+ -> x * y + z
+ elif zb & exp_mask64 == exp_mask64
+ -> z
+ elif z == 0.0
+ -> x * y /* the product is nonzero, so rounding it alone also gets the sign right */
+ ;;
+
+ (xn, xe, xs) = std.flt64explode(x)
+ (yn, ye, ys) = std.flt64explode(y)
+ (zn, ze, zs) = std.flt64explode(z)
+ if xe == -1023 /* presumably flt64explode's subnormal exponent; the bits scale as -1022 */
+ xe = -1022
+ ;;
+ if ye == -1023
+ ye = -1022
+ ;;
+ if ze == -1023
+ ze = -1022
+ ;;
+
+ /* Exact product xs * ys as [ prod_h ][ prod_l ], built from 32-bit halves */
+ var xs_h : uint64 = xs >> 32
+ var ys_h : uint64 = ys >> 32
+ var xs_l : uint64 = xs & 0xffffffff
+ var ys_l : uint64 = ys & 0xffffffff
+
+ var t_l : uint64 = xs_l * ys_l
+ var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
+ var t_h : uint64 = xs_h * ys_h
+
+ var prod_l : uint64 = t_l + (t_m << 32)
+ var prod_h : uint64 = t_h + (t_m >> 32)
+ if t_l > prod_l
+ prod_h++
+ ;;
+
+ var prod_n = xn != yn
+ var prod_lastbit_e = (xe - 52) + (ye - 52)
+ var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
+ var prod_firstbit_e = prod_lastbit_e + prod_first1
+
+ var z_firstbit_e = ze
+ var z_lastbit_e = ze - 52
+ var z_first1 = 52
+
+ /* a subnormal z has no implicit leading bit: find its real leading 1 */
+ if (zb & exp_mask64 == 0)
+ z_first1 = find_first1_64(zs, z_first1)
+ z_firstbit_e = z_lastbit_e + z_first1
+ ;;
+
+ var res_n
+ var res_h = 0
+ var res_l = 0
+ var res_first1
+ var res_lastbit_e
+ var res_firstbit_e
+
+ if prod_n == zn
+ res_n = prod_n
+
+ /*
+ Same signs: magnitudes add. Place prod and z in one 128-bit
+ frame [ res_h ][ res_l ] whose bit 0 is worth 2^res_lastbit_e.
+ */
+ if prod_firstbit_e >= z_firstbit_e
+ /*
+ prod leads: use prod's frame and add z into it. NOTE(review):
+ z bits below bit 0 are dropped, not kept as a sticky bit.
+ */
+ res_lastbit_e = prod_lastbit_e
+ (res_h, res_l) = (prod_h, prod_l)
+ (res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
+ else
+ /*
+ z leads: z fills the bottom of res_h, and prod is added in
+ at the offset given by its exponent.
+ */
+ res_lastbit_e = z_lastbit_e - 64
+ res_h = zs
+ res_l = 0
+ if prod_lastbit_e >= res_lastbit_e + 64
+ /* prod lies wholly inside res_h, which needs a subnormal factor */
+ res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
+ elif prod_lastbit_e >= res_lastbit_e
+ res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
+ res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
+ res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
+ elif prod_lastbit_e + 64 >= res_lastbit_e
+ res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
+ var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
+ var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
+ res_l = l1 + l2
+ if res_l < l1
+ res_h++
+ ;;
+ elif prod_lastbit_e + 128 >= res_lastbit_e
+ res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
+ ;;
+ ;;
+ else
+ match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
+ | `std.Equal: -> 0.0
+ | `std.Before:
+ /* prod > z */
+ res_n = prod_n
+ res_lastbit_e = prod_lastbit_e
+ (res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
+ | `std.After:
+ /* z > prod */
+ res_n = zn
+ res_lastbit_e = z_lastbit_e - 64
+ (res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
+ (res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
+ ;;
+ ;;
+
+ res_first1 = 64 + find_first1_64(res_h, 55) /* 63 here means res_h is empty */
+ if res_first1 == 63
+ res_first1 = find_first1_64(res_l, 63)
+ ;;
+ res_firstbit_e = res_first1 + res_lastbit_e
+
+ /*
+ Finally, res_h and res_l are the high and low bits of
+ the result. They now need to be assembled into a flt64,
+ with subnormal and infinite results handled separately.
+ */
+ var res_s = 0
+ if res_firstbit_e <= -1023
+ /* Subnormal case: no exponent bits, so shift [ res_h ][ res_l ] */
+ /* right by sub_k to bring the bit worth 2^-1074 down to bit 0 */
+ var sub_k = -1074 - res_lastbit_e
+ if sub_k >= 64
+ res_s = shr(res_h, sub_k - 64)
+ elif sub_k >= 0
+ res_s = shl(res_h, 64 - sub_k) | shr(res_l, sub_k)
+ else
+ /* sub_k < 0 puts the leading 1 below bit 51, so res_h is 0 */
+ res_s = shl(res_l, -sub_k)
+ ;;
+
+ if need_round_away(res_h, res_l, sub_k)
+ res_s++
+ ;;
+
+ /* No need for exponents, they are all zero */
+ var res = res_s
+ if res_n
+ res |= (1 << 63)
+ ;;
+ -> std.flt64frombits(res)
+ ;;
+
+ if res_firstbit_e >= 1024
+ /* Infinity case */
+ if res_n
+ -> std.flt64frombits(0xfff0000000000000)
+ else
+ -> std.flt64frombits(0x7ff0000000000000)
+ ;;
+ ;;
+
+ if res_first1 - 52 >= 64 /* keep the 53 bits res_first1 .. res_first1 - 52 */
+ res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
+ if need_round_away(res_h, res_l, res_first1 - 52)
+ res_s++
+ ;;
+ elif res_first1 - 52 >= 0
+ res_s = shl(res_h, 64 - (res_first1 - 52))
+ res_s |= shr(res_l, res_first1 - 52)
+ if need_round_away(res_h, res_l, res_first1 - 52)
+ res_s++
+ ;;
+ else
+ res_s = shl(res_l, 52 - res_first1) /* under 53 bits, all in res_l: shift up */
+ ;;
+
+ /* A rounding increment may carry into bit 53: renormalize */
+ if res_s & (1 << 53) != 0
+ res_s >>= 1
+ res_firstbit_e++
+ if res_firstbit_e >= 1024
+ if res_n
+ -> std.flt64frombits(0xfff0000000000000)
+ else
+ -> std.flt64frombits(0x7ff0000000000000)
+ ;;
+ ;;
+ ;;
+
+ -> std.flt64assem(res_n, res_firstbit_e, res_s)
+}
+
+/*
+ >> and << that yield 0, rather than wrapping, for counts >= 64 (negative counts included, via the uint64 cast)
+ */
+const shr = {u : uint64, s : int64
+ if (s : uint64) < 64
+ -> u >> (s : uint64)
+ ;;
+ -> 0
+}
+
+const shl = {u : uint64, s : int64
+ if (s : uint64) < 64
+ -> u << (s : uint64)
+ ;;
+ -> 0
+}
+
+/*
+ Return [ h ][ l ] + (a << s) as a (high, low) pair. A negative s
+ shifts a right by -s, discarding bits that fall below l's bit 0;
+ s == 0 lines a up with l, giving [ h ][ l + a ] plus any carry.
+ */
+const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
+ if s >= 64
+ /* a lands entirely in the high word */
+ -> (h + shl(a, s - 64), l)
+ ;;
+
+ var hi = h
+ var lo_add
+ if s >= 0
+ hi += shr(a, 64 - s)
+ lo_add = shl(a, s)
+ else
+ lo_add = shr(a, -s)
+ ;;
+ /* add into the low word; wraparound means a carry into hi */
+ var lo = l + lo_add
+ if lo < l
+ hi++
+ ;;
+ -> (hi, lo)
+}
+
+/*
+ Return [ h ][ l ] - (a << s) as a (high, low) pair. A negative s
+ shifts a right by -s, discarding bits that fall below l's bit 0.
+ No check is made that (a << s) <= [ h ][ l ].
+ */
+const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
+ if s >= 64
+ -> (h - shl(a, s - 64), l)
+ ;;
+ var hi = h
+ var lo_sub
+ if s >= 0
+ hi -= shr(a, 64 - s)
+ lo_sub = shl(a, s)
+ else
+ lo_sub = shr(a, -s)
+ ;;
+ if lo_sub > l
+ hi-- /* borrow from the high word */
+ ;;
+ -> (hi, l - lo_sub)
+}
+/* Compare magnitudes of [ h ][ l ] and z: `std.Before if hl > z, `std.After if hl < z */
+const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
+ if hl_firstbit_e > z_firstbit_e
+ -> `std.Before
+ elif hl_firstbit_e < z_firstbit_e
+ -> `std.After
+ ;;
+ /* leading 1s have equal exponents: compare downward bit by bit, h first */
+ var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
+ var z_k : int64 = (z_firstbit_e - z_lastbit_e)
+ while h_k >= 0 && z_k >= 0
+ var h1 = h & shl(1, h_k) != 0
+ var z1 = z & shl(1, z_k) != 0
+ if h1 && !z1
+ -> `std.Before
+ elif !h1 && z1
+ -> `std.After
+ ;;
+ h_k--
+ z_k--
+ ;;
+ /* z ran out first: hl is larger iff bits h_k..0 of h, or l, hold a 1 */
+ if z_k < 0
+ if (h & shr((-1 : uint64), 63 - h_k) != 0) || (l != 0)
+ -> `std.Before
+ else
+ -> `std.Equal
+ ;;
+ ;;
+ /* continue in l: from bit 63 if h was used up, lower if hl's leading 1 is in l (h_k < -1) */
+ var l_k : int64 = 64 + h_k
+ while l_k >= 0 && z_k >= 0
+ var l1 = l & shl(1, l_k) != 0
+ var z1 = z & shl(1, z_k) != 0
+ if l1 && !z1
+ -> `std.Before
+ elif !l1 && z1
+ -> `std.After
+ ;;
+ l_k--
+ z_k--
+ ;;
+ /* one side ran out: the other is larger iff it still has a 1 at or below its index */
+ if (z_k < 0) && (l & shr((-1 : uint64), 63 - l_k) != 0)
+ -> `std.Before
+ elif (l_k < 0) && (z & shr((-1 : uint64), 63 - z_k) != 0)
+ -> `std.After
+ ;;
+
+ -> `std.Equal
+}
+
+/* Index of the highest set bit of b at or below position start, or -1 if none */
+const find_first1_64 = {b : uint64, start : int64
+ var j = start
+ while j >= 0
+ if b & shl(1, j) != 0
+ -> j
+ ;;
+ j--
+ ;;
+ -> -1
+}
+
+const find_first1_64_hl = {h, l, start
+ /* index of the leading 1 of [ h ][ l ] (h searched from bit start - 64, then all of l), or -1 */
+ var pos = find_first1_64(h, start - 64)
+ if pos < 0
+ -> find_first1_64(l, 63)
+ ;;
+ -> pos + 64
+}
+
+/*
+ For the 128-bit [ h ][ l ] (l's bit 0 is position 0), where
+ bitpos_last is the position of the lowest bit kept in the
+ result, decide whether its magnitude must be rounded up under
+ round-to-nearest-even. That holds when the first omitted bit
+ (bitpos_last - 1) is 1 and either
+
+ - some bit below it (bitpos_last - 2 down to 0) is 1, or
+
+ - all bits below it are 0 (a tie) and the kept part is odd.
+ */
+const need_round_away = {h : uint64, l : uint64, bitpos_last : int64
+ var first_omitted_is_1 = false
+ var nonzero_beyond = false
+ if bitpos_last > 64
+ first_omitted_is_1 = h & shl(1, bitpos_last - 1 - 64) != 0
+ nonzero_beyond = h & shr((-1 : uint64), 1 + 64 - (bitpos_last - 64)) != 0
+ nonzero_beyond = nonzero_beyond || (l != 0)
+ else
+ first_omitted_is_1 = l & shl(1, bitpos_last - 1) != 0
+ nonzero_beyond = l & shr((-1 : uint64), 1 + 64 - bitpos_last) != 0
+ ;;
+ /* each mask above covers exactly the bits below the first omitted one */
+ if !first_omitted_is_1
+ -> false
+ ;;
+
+ if nonzero_beyond
+ -> true
+ ;;
+ /* exact tie: round up only if the kept part is odd */
+ var hl_is_odd = false
+
+ if bitpos_last >= 64
+ hl_is_odd = h & shl(1, bitpos_last - 64) != 0
+ else
+ hl_is_odd = l & shl(1, bitpos_last) != 0
+ ;;
+
+ -> hl_is_odd
+}
--- a/lib/math/fpmath.myr
+++ b/lib/math/fpmath.myr
@@ -3,12 +3,15 @@
pkg math =
trait fpmath @f =
+ /* fpmath-fma-impl */
+ fma : (x : @f, y : @f, z : @f -> @f)
+
/* fpmath-trunc-impl */
trunc : (f : @f -> @f)
ceil : (f : @f -> @f)
floor : (f : @f -> @f)
- /* summation */
+ /* fpmath-sum-impl */
kahan_sum : (a : @f[:] -> @f)
priest_sum : (a : @f[:] -> @f)
;;
@@ -36,6 +39,8 @@
;;
impl fpmath flt32 =
+ fma = {x, y, z; -> fma32(x, y, z)}
+
trunc = {f; -> trunc32(f)}
floor = {f; -> floor32(f)}
ceil = {f; -> ceil32(f)}
@@ -45,6 +50,8 @@
;;
impl fpmath flt64 =
+ fma = {x, y, z; -> fma64(x, y, z)}
+
trunc = {f; -> trunc64(f)}
floor = {f; -> floor64(f)}
ceil = {f; -> ceil64(f)}
@@ -52,6 +59,9 @@
kahan_sum = {l; -> kahan_sum64(l) }
priest_sum = {l; -> priest_sum64(l) }
;;
+
+extern const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
+extern const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
extern const trunc32 : (f : flt32 -> flt32)
extern const floor32 : (f : flt32 -> flt32)
--- /dev/null
+++ b/lib/math/test/fpmath-fma-impl.myr
@@ -1,0 +1,46 @@
+use std
+use math
+use testr
+
+const main = {
+ testr.run([
+ [.fn = fma01, .name = "fma-01"], /* flt32 cases */
+ [.fn = fma02, .name = "fma-02"], /* flt64 cases */
+ ][:])
+}
+
+const fma01 = {c
+ /* No flt32 vectors yet; add them once fma32 is correctly rounded */
+}
+
+const fma02 = {c
+ /*
+ Rows are (x, y, z, expected fma(x, y, z)) as raw flt64 bits, mostly
+ found with fpmath-consensus (seed 1234); each covers a corner case.
+ */
+ var cases : (uint64, uint64, uint64, uint64)[:] = [
+ (0x0000000000000000, 0x0000000000000000, 0x0100000000000000, 0x0100000000000000),
+ (0x0000000000000000, 0x0000000000000000, 0x0200000000000000, 0x0200000000000000),
+ (0x00000000000009a4, 0x6900000000000002, 0x6834802690008002, 0x6834802690008002),
+ (0x49113e8d334802ab, 0x5c35d8c190aea62e, 0x6c1a8cc212dcb6e2, 0x6c1a8cc212dcb6e2),
+ (0x2b3e04e8f6266d83, 0xae84e20f62f99bda, 0xc9115a1ccea6ce27, 0xc9115a1ccea6ce27),
+ (0xa03ea9e9b09d932c, 0xded7bc19edcbf0c7, 0xbbc4c1f83b3f8f2e, 0x3f26be5f0c7b48e3),
+ (0xa5ec2141c1e6f339, 0xa2d80fc217f57b61, 0x00b3484b473ef1b8, 0x08d526cb86ee748d),
+ (0xccc6600ee88bb67c, 0xc692eeec9b51cf0f, 0xbf5f1ae3486401b0, 0x536a7a30857129db),
+ (0x5f9b9e449db17602, 0xbef22ae5b6a2b1c5, 0x6133e925e6bf8a12, 0x6133e925e6bf823b),
+ (0x7f851249841b6278, 0x3773388e53a375f4, 0x761c27fc2ffa57be, 0x7709506b0e99dc30),
+ (0x7c7cb20f3ca8af93, 0x800fd7f5cfd5baae, 0x14e4c09c9bb1e17e, 0xbc9c6a3fd0e58682),
+ (0xb5e8db2107f4463f, 0x614af740c0d7eb3b, 0xd7e3d25c4daa81e0, 0xd7e3d798d3ccdffb),
+ (0xae62c8be4cb45168, 0x90cc5236f3516c90, 0x0007f8b14f684558, 0x0007f9364eb1a815),
+ (0x5809f53e32a7e1ba, 0xcc647611ccaa5bf4, 0xdfbdb5c345ce7a56, 0xe480990da5526103),
+ ][:]
+
+ for (xb, yb, zb, want) : cases
+ var xf = std.flt64frombits(xb)
+ var yf = std.flt64frombits(yb)
+ var zf = std.flt64frombits(zb)
+ var got = math.fma(xf, yf, zf)
+ testr.check(c, got == std.flt64frombits(want),
+ "0x{b=16,w=16,p=0} * 0x{b=16,w=16,p=0} + 0x{b=16,w=16,p=0} should be 0x{b=16,w=16,p=0}, was 0x{b=16,w=16,p=0}", xb, yb, zb, want, std.flt64bits(got))
+ ;;
+}