shithub: mc

ref: 3c881e8b133876fc8cff92106616bb93166bf113
dir: /lib/math/fpmath-fma-impl.myr/

use std

pkg math =
	pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
	pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
;;

const exp_mask32 : uint32 = 0xff << 23
const exp_mask64 : uint64 = 0x7ff << 52

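/*
   Compute x * y + z with a single rounding. With 24-bit
   significands, the product x * y is exact in flt64, so the only
   remaining hazard is double rounding when the flt64 sum is
   narrowed back to flt32; the code below detects and repairs that
   case.
 */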
pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
	var xn, yn
	(xn, _, _) = std.flt32explode(x)
	(yn, _, _) = std.flt32explode(y)
	var xd : flt64 = flt64fromflt32(x)
	var yd : flt64 = flt64fromflt32(y)
	var zd : flt64 = flt64fromflt32(z)
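	/*
	   Both significands fit in 24 bits, so xd * yd (at most 48
	   significant bits) is computed exactly in flt64.
	 */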
	var prod : flt64 = xd * yd
	var pn, pe, ps
	(pn, pe, ps) = std.flt64explode(prod)
	if pe == -1023
		pe = -1022
	;;
	if pn != (xn != yn)
		/* In case of NaNs, sign might not have been preserved */
		pn = (xn != yn)
		prod = std.flt64assem(pn, pe, ps)
	;;

	var r : flt64 = prod + zd
	var rn, re, rs
	(rn, re, rs) = std.flt64explode(r)

	/*
	   At this point, r is probably the correct answer. The
	   only issue is that if truncating r to a flt32 causes
	   rounding-to-even, and r was already rounded in that
	   direction in the first place, then we'd be over-rounding.
	   The only way this could possibly be a problem is if the
	   addition step rounded to the halfway point (a 100...0,
	   with the 1 just outside truncation range).
	 */
	if re == 1024 || rs & 0x1fffffff != 0x10000000
		-> flt32fromflt64(r)
	;;

	/* We can check if rounding was performed by undoing */
	if flt32fromflt64(r - prod) == z
		-> flt32fromflt64(r)
	;;

	/*
	   At this point, there's definitely about to be a rounding
	   error. To figure out what to do, compute prod + z with
	   round-to-zero. If we get r again, then it's okay to round
	   r upwards, because it hasn't been rounded away from zero
	   yet and we allow ourselves one such rounding.
	 */
	var zn, ze, zs
	(zn, ze, zs) = std.flt64explode(zd)
	if ze == -1023
		ze = -1022
	;;
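	/*
	   Recompute prod + z, truncated towards zero: align the
	   smaller operand against the larger, and when subtracting,
	   borrow one extra if any nonzero bits were shifted out, so
	   the magnitude is never rounded up.
	 */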
	var rtn, rte, rts
	if pe >= ze && pn == zn
		(rtn, rte, rts) = (pn, pe, ps)
		rts += shr(zs, pe - ze)
	elif pe > ze || (pe == ze && ps > zs)
		(rtn, rte, rts) = (pn, pe, ps)
		rts -= shr(zs, pe - ze)
		if shr((-1 : uint64), 64 - std.min(64, (pe - ze))) & zs != 0
			rts--
		;;
	elif pe < ze && pn == zn
		(rtn, rte, rts) = (zn, ze, zs)
		rts += shr(ps, ze - pe)
	else
		(rtn, rte, rts) = (zn, ze, zs)
		rts -= shr(ps, ze - pe)
		if shr((-1 : uint64), 64 - std.min(64, (ze - pe))) & ps != 0
			rts--
		;;
	;;

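	/* Renormalize if the truncating add carried out into bit 53 */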
	if rts & (1 << 53) != 0
		rts >>= 1
		rte++
	;;

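	/*
	   r matches the truncated sum, so it has not yet been rounded
	   away from zero; round the truncated sum up one ulp before
	   narrowing.
	 */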
	if rn == rtn && rs == rts && re == rte
		rts++
		if rts & (1 << 53) != 0
			rts >>= 1
			rte++
		;;
	;;
	-> flt32fromflt64(std.flt64assem(rtn, rte, rts))
}

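/*
   Widen a flt32 to flt64. This is always exact; only NaN/Inf and
   subnormal inputs need special handling.
 */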
const flt64fromflt32 = {f : flt32
	var n, e, s
	(n, e, s) = std.flt32explode(f)
	var xs : uint64 = (s : uint64)
	var xe : int64 = (e : int64)

	if e == 128
		-> std.flt64assem(n, 1024, xs)
	elif e == -127
		/*
		  All subnormals in single precision (except 0.0s)
		  can be upgraded to double precision, since the
		  exponent range is so much wider.
		 */
		var first1 = find_first1_64(xs, 23)
		if first1 < 0
			-> std.flt64assem(n, -1023, 0)
		;;
		xs = xs << (52 - (first1 : uint64))
		xe = -126 - (23 - first1)
		-> std.flt64assem(n, xe, xs)
	;;

	-> std.flt64assem(n, xe, xs << (52 - 23))
}

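/*
   Narrow a flt64 to flt32, rounding to nearest (ties to even).
   Overflow becomes infinity, NaN stays NaN, and values too small
   for a flt32 subnormal go to zero.
 */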
const flt32fromflt64 = {f : flt64
	var n : bool, e : int64, s : uint64
	(n, e, s) = std.flt64explode(f)
	var ts : uint32
	var te : int32 = (e : int32)

	if e >= 128
		if e == 1024 && s != 0
			/* NaN */
			-> std.flt32assem(n, 128, 1)
		else
			/* infinity */
			-> std.flt32assem(n, 128, 0)
		;;
	;;

	if e >= -126
		/* normal */
		ts = ((s >> (52 - 23)) : uint32)
		if need_round_away(0, s, 52 - 23)
			ts++
			if ts & (1 << 24) != 0
				ts >>= 1
				te++
			;;
		;;
		-> std.flt32assem(n, te, ts)
	;;

	/* subnormal already, will have to go to 0 */
	if e == -1023
		-> std.flt32assem(n, -127, 0)
	;;

	/* subnormal (at least, it will be) */
	te = -127
	var shift : int64 = (52 - 23) + (-126 - e)
	var ts1 = shr(s, shift)
	ts = (ts1 : uint32)
	if need_round_away(0, s, shift)
		ts++
	;;
	-> std.flt32assem(n, te, ts)
}

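/*
   There is no wider type to lean on here, so the significand of
   x * y is kept exactly in a pair of uint64s, z is added or
   subtracted at full precision, and only the final result is
   rounded.
 */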
pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
	var xn : bool, yn : bool, zn : bool
	var xe : int64, ye : int64, ze : int64
	var xs : uint64, ys : uint64, zs : uint64

	var xb : uint64 = std.flt64bits(x)
	var yb : uint64 = std.flt64bits(y)
	var zb : uint64 = std.flt64bits(z)

	/* check for both NaNs and infinities */
	if xb & exp_mask64 == exp_mask64 || \
	   yb & exp_mask64 == exp_mask64
		-> x * y + z
	elif z == 0.0 || z == -0.0 || x * y == 0.0 || x * y == -0.0
		-> x * y + z
	elif zb & exp_mask64 == exp_mask64
		-> z
	;;

	(xn, xe, xs) = std.flt64explode(x)
	(yn, ye, ys) = std.flt64explode(y)
	(zn, ze, zs) = std.flt64explode(z)
	if xe == -1023
		xe = -1022
	;;
	if ye == -1023
		ye = -1022
	;;
	if ze == -1023
		ze = -1022
	;;

	/* Keep product in high/low uint64s */
	var xs_h : uint64 = xs >> 32
	var ys_h : uint64 = ys >> 32
	var xs_l : uint64 = xs & 0xffffffff
	var ys_l : uint64 = ys & 0xffffffff

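	/*
	   xs_h and ys_h are at most 21 bits (53-bit significands),
	   so t_m fits in a uint64 without overflow.
	 */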
	var t_l : uint64 = xs_l * ys_l
	var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
	var t_h : uint64 = xs_h * ys_h

	var prod_l : uint64 = t_l + (t_m << 32)
	var prod_h : uint64 = t_h + (t_m >> 32)
	if t_l > prod_l
		prod_h++
	;;

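	/*
	   prod_lastbit_e is the exponent weight of bit 0 of
	   [ prod_h ][ prod_l ]; prod_firstbit_e is that of the
	   leading 1.
	 */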
	var prod_n = xn != yn
	var prod_lastbit_e = (xe - 52) + (ye - 52)
	var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
	var prod_firstbit_e = prod_lastbit_e + prod_first1

	var z_firstbit_e = ze
	var z_lastbit_e = ze - 52
	var z_first1 = 52

	/* subnormals could throw firstbit_e calculations out of whack */
	if (zb & exp_mask64 == 0)
		z_first1 = find_first1_64(zs, z_first1)
		z_firstbit_e = z_lastbit_e + z_first1
	;;

	var res_n
	var res_h = 0
	var res_l = 0
	var res_first1
	var res_lastbit_e
	var res_firstbit_e

	if prod_n == zn
		res_n = prod_n

		/*
		   Align prod and z so that the top bit of the
		   result is either 53 or 54, then add.
		 */
		if prod_firstbit_e >= z_firstbit_e
			/*
			    [ prod_h ][ prod_l ]
			         [ z...
			 */
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = (prod_h, prod_l)
			(res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
		else
			/*
			        [ prod_h ][ prod_l ]
			    [ z...
			 */
			res_lastbit_e = z_lastbit_e - 64
			res_h = zs
			res_l = 0
			if prod_lastbit_e >= res_lastbit_e + 64
				/* In this situation, prod must be extremely subnormal */
				res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
			elif prod_lastbit_e >= res_lastbit_e
				res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
				res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
				res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
			elif prod_lastbit_e + 64 >= res_lastbit_e
				res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
				var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
				var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
				res_l = l1 + l2
				if res_l < l1
					res_h++
				;;
			elif prod_lastbit_e + 128 >= res_lastbit_e
				res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
			;;
		;;
	else
		match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
		| `std.Equal: -> 0.0
		| `std.Before:
			/* prod > z */
			res_n = prod_n
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
		| `std.After:
			/* z > prod */
			res_n = zn
			res_lastbit_e = z_lastbit_e - 64
			(res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
			(res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
		;;
	;;

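	/* Locate the leading 1 of [ res_h ][ res_l ] */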
	res_first1 = 64 + find_first1_64(res_h, 55)
	if res_first1 == 63
		res_first1 = find_first1_64(res_l, 63)
	;;
	res_firstbit_e = res_first1 + res_lastbit_e

	/*
	   Finally, res_h and res_l are the high and low bits of
	   the result. They now need to be assembled into a flt64.
	   Subnormals and infinities could be a problem.
	 */
	var res_s = 0
	if res_firstbit_e <= -1023
		/* Subnormal case */
		if res_lastbit_e + 128 < 12 - 1022
			res_s = shr(res_h, 12 - 1022 - (res_lastbit_e + 128))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		elif res_lastbit_e + 64 < 12 - 1022
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		else
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
		;;

		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
			res_s++
		;;

		/* No need for exponents, they are all zero */
		var res = res_s
		if res_n
			res |= (1 << 63)
		;;
		-> std.flt64frombits(res)
	;;

	if res_firstbit_e >= 1024
		/* Infinity case */
		if res_n
			-> std.flt64frombits(0xfff0000000000000)
		else
			-> std.flt64frombits(0x7ff0000000000000)
		;;
	;;

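	/*
	   Normal case: bring the leading 1 down to bit 52 of res_s,
	   rounding on the bits shifted out.
	 */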
	if res_first1 - 52 >= 64
		res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	elif res_first1 - 52 >= 0
		res_s = shl(res_h, 64 - (res_first1 - 52))
		res_s |= shr(res_l, res_first1 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	else
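		/* res_first1 < 52 implies res_h == 0; shift res_l up, exactly */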
		res_s = shl(res_l, 52 - res_first1)
	;;

	/* The res_s++s might have messed everything up */
	if res_s & (1 << 53) != 0
		res_s >>= 1
		res_firstbit_e++
		if res_firstbit_e >= 1024
			if res_n
				-> std.flt64frombits(0xfff0000000000000)
			else
				-> std.flt64frombits(0x7ff0000000000000)
			;;
		;;
	;;

	-> std.flt64assem(res_n, res_firstbit_e, res_s)
}

/* >> and <<, but without wrapping when the shift is >= 64 */
const shr : (u : uint64, s : int64 -> uint64) = {u : uint64, s : int64
	if (s : uint64) >= 64
		-> 0
	else
		-> u >> (s : uint64)
	;;
}

const shl : (u : uint64, s : int64 -> uint64) = {u : uint64, s : int64
	if (s : uint64) >= 64
		-> 0
	else
		-> u << (s : uint64)
	;;
}

/*
   Add (a << s) to [ h ][ l ], where if s < 0 then a corresponding
   right-shift is used. This is aligned such that if s == 0, then
   the result is [ h ][ l + a ]
 */
const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	if s >= 64
		-> (h + shl(a, s - 64), l)
	elif s >= 0
		var new_h = h + shr(a, 64 - s)
		var sa = shl(a, s)
		var new_l = l + sa
		if new_l < l
			new_h++
		;;
		-> (new_h, new_l)
	else
		var new_h = h
		var sa = shr(a, -s)
		var new_l = l + sa
		if new_l < l
			new_h++
		;;
		-> (new_h, new_l)
	;;
}

/* As above, but subtract (a << s) */
const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	if s >= 64
		-> (h - shl(a, s - 64), l)
	elif s >= 0
		var new_h = h - shr(a, 64 - s)
		var sa = shl(a, s)
		var new_l = l - sa
		if sa > l
			new_h--
		;;
		-> (new_h, new_l)
	else
		var new_h = h
		var sa = shr(a, -s)
		var new_l = l - sa
		if sa > l
			new_h--
		;;
		-> (new_h, new_l)
	;;
}

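/*
   Compare the magnitudes of [ h ][ l ] and z, given the exponents
   of their first and last bits: `std.Before if hl is larger,
   `std.After if z is larger, `std.Equal otherwise.
 */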
const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
	if hl_firstbit_e > z_firstbit_e
		-> `std.Before
	elif hl_firstbit_e < z_firstbit_e
		-> `std.After
	;;

	var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
	var z_k : int64 = (z_firstbit_e - z_lastbit_e)
	while h_k >= 0 && z_k >= 0
		var h1 = h & shl(1, h_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if h1 && !z1
			-> `std.Before
		elif !h1 && z1
			-> `std.After
		;;
		h_k--
		z_k--
	;;

	if z_k < 0
		if (h & shr((-1 : uint64), 64 - h_k) != 0) || (l != 0)
			-> `std.Before
		else
			-> `std.Equal
		;;
	;;

	var l_k : int64 = 63
	while l_k >= 0 && z_k >= 0
		var l1 = l & shl(1, l_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if l1 && !z1
			-> `std.Before
		elif !l1 && z1
			-> `std.After
		;;
		l_k--
		z_k--
	;;

	if (z_k < 0) && (l & shr((-1 : uint64), 64 - l_k) != 0)
		-> `std.Before
	elif (l_k < 0) && (z & shr((-1 : uint64), 64 - z_k) != 0)
		-> `std.After
	;;

	-> `std.Equal
}

/* Find the first 1 bit in a bitstring */
const find_first1_64 : (b : uint64, start : int64 -> int64) = {b : uint64, start : int64
	for var j = start; j >= 0; --j
		var m = shl(1, j)
		if b & m != 0
			-> j
		;;
	;;

	-> -1
}

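/* Find the first 1 bit in [ h ][ l ]; bit 0 of l has index 0 */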
const find_first1_64_hl = {h, l, start
	var first1_h = find_first1_64(h, start - 64)
	if first1_h >= 0
		-> first1_h + 64
	;;

	-> find_first1_64(l, 63)
}

/*
   For [ h ][ l ], where bitpos_last is the position of the last
   bit that was included in the truncated result (l's last bit has
   position 0), decide whether rounding up/away is needed. This is
   true if

    - following bitpos_last is a 1, then a non-zero sequence, or

    - following bitpos_last is a 1, then a zero sequence, and the
      round would be to even
 */
const need_round_away = {h : uint64, l : uint64, bitpos_last : int64
	var first_omitted_is_1 = false
	var nonzero_beyond = false
	if bitpos_last > 64
		first_omitted_is_1 = h & shl(1, bitpos_last - 1 - 64) != 0
		nonzero_beyond = nonzero_beyond || h & shr((-1 : uint64), 1 + 64 - (bitpos_last - 64)) != 0
		nonzero_beyond = nonzero_beyond || (l != 0)
	else
		first_omitted_is_1 = l & shl(1, bitpos_last - 1) != 0
		nonzero_beyond = nonzero_beyond || l & shr((-1 : uint64), 1 + 64 - bitpos_last) != 0
	;;

	if !first_omitted_is_1
		-> false
	;;

	if nonzero_beyond
		-> true
	;;

	var hl_is_odd = false

	if bitpos_last >= 64
		hl_is_odd = h & shl(1, bitpos_last - 64) != 0
	else
		hl_is_odd = l & shl(1, bitpos_last) != 0
	;;

	-> hl_is_odd
}