ref: 3c881e8b133876fc8cff92106616bb93166bf113
dir: /lib/math/fpmath-fma-impl.myr/
use std

pkg math =
	pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
	pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
;;

/* Masks selecting the biased-exponent field of a flt32 / flt64 bit pattern */
const exp_mask32 : uint32 = 0xff << 23
const exp_mask64 : uint64 = 0x7ff << 52

/*
   Fused multiply-add in single precision: x * y + z, rounded once.

   The operands are widened to flt64. Two 24-bit significands give a
   product of at most 48 bits, so xd * yd is exact in double precision.
   The only rounding before the final conversion is the addition
   prod + zd. The rest of this function guards against that rounding
   combining with the flt64 -> flt32 conversion into a double rounding.
 */
pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
	var xn, yn
	(xn, _, _) = std.flt32explode(x)
	(yn, _, _) = std.flt32explode(y)
	var xd : flt64 = flt64fromflt32(x)
	var yd : flt64 = flt64fromflt32(y)
	var zd : flt64 = flt64fromflt32(z)
	var prod : flt64 = xd * yd
	var pn, pe, ps
	(pn, pe, ps) = std.flt64explode(prod)
	/* explode reports subnormals with exponent -1023; their scale is that of -1022 */
	if pe == -1023
		pe = -1022
	;;
	if pn != (xn != yn)
		/* In case of NaNs, sign might not have been preserved */
		pn = (xn != yn)
		prod = std.flt64assem(pn, pe, ps)
	;;

	var r : flt64 = prod + zd
	var rn, re, rs
	(rn, re, rs) = std.flt64explode(r)

	/*
	   At this point, r is probably the correct answer. The only issue is
	   double rounding. Suppose r was itself produced by rounding the
	   addition, and truncating r to a flt32 then rounds again
	   (ties-to-even) in the same direction. Then we would be
	   over-rounding. That can only happen when r sits exactly on a flt32
	   halfway point: the 29 bits below flt32 precision are a 1 followed
	   by zeros.
	 */
	if re == 1024 || rs & 0x1fffffff != 0x10000000
		-> flt32fromflt64(r)
	;;

	/* We can check if rounding was performed by undoing the addition */
	if flt32fromflt64(r - prod) == z
		-> flt32fromflt64(r)
	;;

	/*
	   At this point, there's definitely about to be a rounding error.
	   To decide, compute prod + z truncated toward zero (rt).

	   If rt equals r, then r has not been rounded away from zero. The
	   exact sum therefore lies beyond the halfway point, so nudging r
	   outward by one flt64 ulp makes the final conversion round away
	   from zero, as it should.

	   Otherwise r was already rounded away from zero. Converting rt,
	   which lies just inside the halfway point, rounds the right way.
	 */
	var zn, ze, zs
	(zn, ze, zs) = std.flt64explode(zd)
	if ze == -1023
		ze = -1022
	;;
	var rtn, rte, rts
	if pe >= ze && pn == zn
		(rtn, rte, rts) = (pn, pe, ps)
		rts += shr(zs, pe - ze)
	elif pe > ze || (pe == ze && ps > zs)
		(rtn, rte, rts) = (pn, pe, ps)
		rts -= shr(zs, pe - ze)
		/* bits of zs shifted out of the subtrahend: truncation must go one lower */
		if shr((-1 : uint64), 64 - std.min(64, (pe - ze))) & zs != 0
			rts--
		;;
	elif pe < ze && pn == zn
		(rtn, rte, rts) = (zn, ze, zs)
		rts += shr(ps, ze - pe)
	else
		(rtn, rte, rts) = (zn, ze, zs)
		rts -= shr(ps, ze - pe)
		if shr((-1 : uint64), 64 - std.min(64, (ze - pe))) & ps != 0
			rts--
		;;
	;;

	/*
	   Renormalize a carry out of the 53-bit significand.
	   NOTE(review): after a subtraction rts could in principle lose its
	   leading bit (drop below 1 << 52). That case is not renormalized
	   here. Confirm cancellation cannot reach this path.
	 */
	if rts & (1 << 53) != 0
		rts >>= 1
		rte++
	;;

	if rn == rtn && rs == rts && re == rte
		rts++
		if rts & (1 << 53) != 0
			rts >>= 1
			rte++
		;;
	;;
	-> flt32fromflt64(std.flt64assem(rtn, rte, rts))
}

/*
   Widen a flt32 to the flt64 with the same value.

   Infinities stay infinities. NaNs stay NaNs, with their fraction bits
   moved to the top of the flt64 fraction, so a quiet NaN stays quiet.
   Nonzero flt32 subnormals become normal flt64s, since the flt64
   exponent range is much wider.
 */
const flt64fromflt32 = {f : flt32
	var n, e, s
	(n, e, s) = std.flt32explode(f)
	var xs : uint64 = (s : uint64)
	var xe : int64 = (e : int64)

	if e == 128
		/*
		   Shift the significand into flt64 position exactly as the
		   normal case below does. Passed unshifted, the flt32 fraction
		   (and any implicit bit reported by explode) would land in the
		   low fraction bits. That can turn an infinity into a NaN and
		   moves a NaN's quiet bit (flt32 bit 22) away from flt64 bit 51.
		 */
		-> std.flt64assem(n, 1024, xs << (52 - 23))
	elif e == -127
		/*
		   All subnormals in single precision (except 0.0s) can be upgraded
		   to double precision, since the exponent range is so much wider.
		 */
		var first1 = find_first1_64(xs, 23)
		if first1 < 0
			-> std.flt64assem(n, -1023, 0)
		;;
		/* move the leading 1 to the implicit-bit position and rescale */
		xs = xs << (52 - (first1 : uint64))
		xe = -126 - (23 - first1)
		-> std.flt64assem(n, xe, xs)
	;;

	-> std.flt64assem(n, xe, xs << (52 - 23))
}

/*
   Narrow a flt64 to a flt32, rounding to nearest with ties to even.

   Values too large for flt32 become infinities of the same sign. NaNs
   become quiet NaNs that keep the top fraction bits. Values below the
   flt32 normal range are rounded onto the flt32 subnormal grid, and a
   subnormal that rounds up to 2^-126 becomes that normal number.
 */
const flt32fromflt64 = {f : flt64
	var n : bool, e : int64, s : uint64
	(n, e, s) = std.flt64explode(f)
	var ts : uint32
	var te : int32 = (e : int32)

	if e >= 128
		/* exponent 1024 is the Inf/NaN encoding; NaN has a nonzero fraction */
		if e == 1024 && s & ((1 << 52) - 1) != 0
			/* NaN: keep the top payload bits and force the quiet bit */
			-> std.flt32assem(n, 128, ((s >> (52 - 23)) : uint32) | (1 << 22))
		else
			/* infinity, or a finite value too large for a flt32 */
			-> std.flt32assem(n, 128, 0)
		;;
	;;

	if e >= -126
		/* normal */
		ts = ((s >> (52 - 23)) : uint32)
		if need_round_away(0, s, 52 - 23)
			ts++
			/* rounding carried out of the 24-bit significand */
			if ts & (1 << 24) != 0
				ts >>= 1
				te++
			;;
		;;
		-> std.flt32assem(n, te, ts)
	;;

	/* subnormal already, will have to go to 0 */
	if e == -1023
		-> std.flt32assem(n, -127, 0)
	;;

	/* subnormal (at least, it will be) */
	te = -127
	var shift : int64 = (52 - 23) + (-126 - e)
	var ts1 = shr(s, shift)
	ts = (ts1 : uint32)
	if need_round_away(0, s, shift)
		ts++
		/*
		   Rounding up from the largest subnormal reaches 1 << 23, which is
		   the smallest normal number. Assembled with exponent -127, that
		   bit would be dropped and the result would be 0.
		 */
		if ts & (1 << 23) != 0
			te = -126
		;;
	;;
	-> std.flt32assem(n, te, ts)
}

/*
   Fused multiply-add in double precision: x * y + z with a single
   rounding to nearest, ties to even.

   The exact product of the two 53-bit significands (up to 106 bits)
   is held in two uint64s, [ prod_h ][ prod_l ]. z is aligned against it
   in one 128-bit window [ res_h ][ res_l ], the sum or difference is
   formed exactly, and the result is then rounded once, either to 53
   bits or onto the subnormal grid.
 */
pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
	var xn : bool, yn : bool, zn : bool
	var xe : int64, ye : int64, ze : int64
	var xs : uint64, ys : uint64, zs : uint64

	var xb : uint64 = std.flt64bits(x)
	var yb : uint64 = std.flt64bits(y)
	var zb : uint64 = std.flt64bits(z)

	/* check for both NaNs and infinities */
	if xb & exp_mask64 == exp_mask64 || \
	   yb & exp_mask64 == exp_mask64
		-> x * y + z
	elif z == 0.0 || z == -0.0 || x == 0.0 || y == 0.0
		/*
		   Only genuinely zero operands may take the plain x * y + z
		   shortcut. A nonzero product that merely underflows when rounded
		   on its own must still be added to z exactly. For example,
		   2^-500 * 2^-575 + 2^-1074 must round to 2^-1073, not to 2^-1074.
		 */
		-> x * y + z
	elif zb & exp_mask64 == exp_mask64
		-> z
	;;

	(xn, xe, xs) = std.flt64explode(x)
	(yn, ye, ys) = std.flt64explode(y)
	(zn, ze, zs) = std.flt64explode(z)
	/* explode reports subnormals with exponent -1023; their scale is that of -1022 */
	if xe == -1023
		xe = -1022
	;;
	if ye == -1023
		ye = -1022
	;;
	if ze == -1023
		ze = -1022
	;;

	/* Keep product in high/low uint64s: schoolbook multiply on 32-bit halves */
	var xs_h : uint64 = xs >> 32
	var ys_h : uint64 = ys >> 32
	var xs_l : uint64 = xs & 0xffffffff
	var ys_l : uint64 = ys & 0xffffffff

	var t_l : uint64 = xs_l * ys_l
	var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
	var t_h : uint64 = xs_h * ys_h

	var prod_l : uint64 = t_l + (t_m << 32)
	var prod_h : uint64 = t_h + (t_m >> 32)
	/* carry out of the low word */
	if t_l > prod_l
		prod_h++
	;;

	/*
	   *_lastbit_e:  exponent (power of two) of bit 0 of the integer
	   *_first1:     index of the leading 1 bit
	   *_firstbit_e: exponent of that leading 1
	 */
	var prod_n = xn != yn
	var prod_lastbit_e = (xe - 52) + (ye - 52)
	var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
	var prod_firstbit_e = prod_lastbit_e + prod_first1

	var z_firstbit_e = ze
	var z_lastbit_e = ze - 52
	var z_first1 = 52

	/* subnormals could throw firstbit_e calculations out of whack */
	if (zb & exp_mask64 == 0)
		z_first1 = find_first1_64(zs, z_first1)
		z_firstbit_e = z_lastbit_e + z_first1
	;;

	var res_n
	var res_h = 0
	var res_l = 0
	var res_first1
	var res_lastbit_e
	var res_firstbit_e

	if prod_n == zn
		res_n = prod_n

		/*
		   Lay prod and z out in one 128-bit window, anchored at the lowest
		   bit of the operand whose leading bit is higher, then add.
		 */
		if prod_firstbit_e >= z_firstbit_e
			/*
			    [ prod_h ][ prod_l ]
			         [ z...
			 */
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = (prod_h, prod_l)
			(res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
		else
			/*
			           [ prod_h ][ prod_l ]
			    [ z...
			   z occupies res_h, with its lowest bit at window bit 64.
			 */
			res_lastbit_e = z_lastbit_e - 64
			res_h = zs
			res_l = 0
			if prod_lastbit_e >= res_lastbit_e + 64
				/* In this situation, prod must be extremely subnormal */
				res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
			elif prod_lastbit_e >= res_lastbit_e
				res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
				res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
				res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
			elif prod_lastbit_e + 64 >= res_lastbit_e
				res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
				var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
				var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
				res_l = l1 + l2
				if res_l < l1
					res_h++
				;;
			elif prod_lastbit_e + 128 >= res_lastbit_e
				res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
			;;
		;;
	else
		/* opposite signs: subtract the smaller magnitude from the larger */
		match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
		| `std.Equal: -> 0.0
		| `std.Before:
			/* prod > z */
			res_n = prod_n
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
		| `std.After:
			/* z > prod */
			res_n = zn
			res_lastbit_e = z_lastbit_e - 64
			(res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
			(res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
		;;
	;;

	res_first1 = 64 + find_first1_64(res_h, 55)
	if res_first1 == 63
		/* res_h was zero: the leading bit is in res_l */
		res_first1 = find_first1_64(res_l, 63)
	;;
	res_firstbit_e = res_first1 + res_lastbit_e

	/*
	   Finally, res_h and res_l are the high and low bits of the result.
	   They now need to be assembled into a flt64. Subnormals and
	   infinities could be a problem.
	 */
	var res_s = 0
	if res_firstbit_e <= -1023
		/*
		   Subnormal case: extract the bits with exponents -1074 and up.
		   Note 12 - 1022 == -1074 + 64.
		 */
		if res_lastbit_e + 128 < 12 - 1022
			res_s = shr(res_h, 12 - 1022 - (res_lastbit_e + 128))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		elif res_lastbit_e + 64 < 12 - 1022
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		else
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
		;;

		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
			res_s++
		;;

		/*
		   No need for exponents, they are all zero. If the increment
		   reached 1 << 52, that bit lands in the exponent field and
		   encodes the smallest normal number.
		 */
		var res = res_s
		if res_n
			res |= (1 << 63)
		;;
		-> std.flt64frombits(res)
	;;

	if res_firstbit_e >= 1024
		/* Infinity case */
		if res_n
			-> std.flt64frombits(0xfff0000000000000)
		else
			-> std.flt64frombits(0x7ff0000000000000)
		;;
	;;

	/* Take the 53 bits starting at the leading 1 and round at the bit below them */
	if res_first1 - 52 >= 64
		res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	elif res_first1 - 52 >= 0
		res_s = shl(res_h, 64 - (res_first1 - 52))
		res_s |= shr(res_l, res_first1 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	else
		/*
		   Fewer than 53 significant bits, all in res_l (e.g. after heavy
		   cancellation): shift them up into place; nothing is discarded,
		   so there is nothing to round.
		 */
		res_s = shl(res_l, 52 - res_first1)
	;;

	/* The res_s++s might have carried into bit 53 */
	if res_s & (1 << 53) != 0
		res_s >>= 1
		res_firstbit_e++
		if res_firstbit_e >= 1024
			if res_n
				-> std.flt64frombits(0xfff0000000000000)
			else
				-> std.flt64frombits(0x7ff0000000000000)
			;;
		;;
	;;

	-> std.flt64assem(res_n, res_firstbit_e, res_s)
}

/* >> and <<, but without wrapping when the shift is >= 64 */
const shr : (u : uint64, s : int64 -> uint64) = {u : uint64, s : int64
	/* a negative s becomes huge as uint64, so it also yields 0 */
	if (s : uint64) >= 64
		-> 0
	else
		-> u >> (s : uint64)
	;;
}

const shl : (u : uint64, s : int64 -> uint64) = {u : uint64, s : int64
	if (s : uint64) >= 64
		-> 0
	else
		-> u << (s : uint64)
	;;
}

/*
   Add (a << s) to [ h ][ l ], where if s < 0 then a corresponding
   right-shift is used. This is aligned such that if s == 0, then the
   result is [ h ][ l + a ]. Bits shifted below l's bit 0 are dropped.
 */
const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	if s >= 64
		-> (h + shl(a, s - 64), l)
	elif s >= 0
		var new_h = h + shr(a, 64 - s)
		var sa = shl(a, s)
		var new_l = l + sa
		if new_l < l
			new_h++
		;;
		-> (new_h, new_l)
	else
		var new_h = h
		var sa = shr(a, -s)
		var new_l = l + sa
		if new_l < l
			new_h++
		;;
		-> (new_h, new_l)
	;;
}

/* As above, but subtract (a << s) */
const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	if s >= 64
		-> (h - shl(a, s - 64), l)
	elif s >= 0
		var new_h = h - shr(a, 64 - s)
		var sa = shl(a, s)
		var new_l = l - sa
		if sa > l
			new_h--
		;;
		-> (new_h, new_l)
	else
		var new_h = h
		var sa = shr(a, -s)
		var new_l = l - sa
		if sa > l
			new_h--
		;;
		-> (new_h, new_l)
	;;
}

/*
   Compare the magnitudes of [ h ][ l ] and z. Each operand is described
   by the exponents of its leading 1 (*_firstbit_e) and of its bit 0
   (*_lastbit_e).
   Returns `std.Before if [ h ][ l ] is larger, `std.After if z is larger,
   and `std.Equal if they are equal.
 */
const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
	if hl_firstbit_e > z_firstbit_e
		-> `std.Before
	elif hl_firstbit_e < z_firstbit_e
		-> `std.After
	;;

	/* h_k, z_k (and l_k below): index of the next bit to compare */
	var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
	var z_k : int64 = (z_firstbit_e - z_lastbit_e)
	while h_k >= 0 && z_k >= 0
		var h1 = h & shl(1, h_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if h1 && !z1
			-> `std.Before
		elif !h1 && z1
			-> `std.After
		;;
		h_k--
		z_k--
	;;

	if z_k < 0
		/* z is exhausted; bits 0..h_k of h, or any bit of l, make [ h ][ l ] larger */
		if (h & shr((-1 : uint64), 63 - h_k) != 0) || (l != 0)
			-> `std.Before
		else
			-> `std.Equal
		;;
	;;

	/*
	   Continue in l. If the leading 1 of [ h ][ l ] is in l itself (h_k
	   started negative), start at that bit, not at l's top bit.
	 */
	var l_k : int64 = std.min(63, hl_firstbit_e - hl_lastbit_e)
	while l_k >= 0 && z_k >= 0
		var l1 = l & shl(1, l_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if l1 && !z1
			-> `std.Before
		elif !l1 && z1
			-> `std.After
		;;
		l_k--
		z_k--
	;;

	/* whichever operand still has set bits (indices 0..k) is larger */
	if (z_k < 0) && (l & shr((-1 : uint64), 63 - l_k) != 0)
		-> `std.Before
	elif (l_k < 0) && (z & shr((-1 : uint64), 63 - z_k) != 0)
		-> `std.After
	;;
	-> `std.Equal
}

/* Find the first 1 bit in a bitstring, scanning down from bit `start`; -1 if none */
const find_first1_64 : (b : uint64, start : int64 -> int64) = {b : uint64, start : int64
	for var j = start; j >= 0; --j
		var m = shl(1, j)
		if b & m != 0
			-> j
		;;
	;;
	-> -1
}

/*
   As find_first1_64, but over the 128-bit [ h ][ l ], where l's bit 0 is
   position 0 and `start` is a position in that combined numbering.
 */
const find_first1_64_hl = {h, l, start
	var first1_h = find_first1_64(h, start - 64)
	if first1_h >= 0
		-> first1_h + 64
	;;

	-> find_first1_64(l, 63)
}

/*
   For [ h ][ l ], where bitpos_last is the position of the last bit that
   was included in the truncated result (l's last bit has position 0),
   decide whether rounding up/away is needed. This is true if

    - following bitpos_last is a 1, then a non-zero sequence, or

    - following bitpos_last is a 1, then a zero sequence, and the round
      would be to even
 */
const need_round_away = {h : uint64, l : uint64, bitpos_last : int64
	var first_omitted_is_1 = false
	var nonzero_beyond = false
	if bitpos_last > 64
		/* first omitted bit is h bit k = bitpos_last - 65; sticky bits are h bits 0..k-1, then all of l */
		first_omitted_is_1 = h & shl(1, bitpos_last - 1 - 64) != 0
		nonzero_beyond = nonzero_beyond || h & shr((-1 : uint64), 1 + 64 - (bitpos_last - 64)) != 0
		nonzero_beyond = nonzero_beyond || (l != 0)
	else
		/* first omitted bit is l bit bitpos_last - 1; sticky bits are below it */
		first_omitted_is_1 = l & shl(1, bitpos_last - 1) != 0
		nonzero_beyond = nonzero_beyond || l & shr((-1 : uint64), 1 + 64 - bitpos_last) != 0
	;;

	if !first_omitted_is_1
		-> false
	;;

	if nonzero_beyond
		-> true
	;;

	/* exact tie: round away only if the kept part is odd */
	var hl_is_odd = false

	if bitpos_last >= 64
		hl_is_odd = h & shl(1, bitpos_last - 64) != 0
	else
		hl_is_odd = l & shl(1, bitpos_last) != 0
	;;

	-> hl_is_odd
}