shithub: mc

ref: 77d584ee227d829e842efdc938dffa322665b22c
parent: f3265a99202d98f6f7e5969e34faae934ce754b9
author: S. Gilles <sgilles@math.umd.edu>
date: Tue Mar 20 05:15:36 EDT 2018

Implement fma64
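
In outline: the significands of x and y are multiplied exactly into
a pair of 64-bit words by splitting each into 32-bit halves,

    xs*ys = (xs_h*2^32 + xs_l) * (ys_h*2^32 + ys_l)
          = xs_h*ys_h*2^64 + (xs_h*ys_l + xs_l*ys_h)*2^32 + xs_l*ys_l,

then z is aligned against the 128-bit product and added or
subtracted, and the result is rounded once as it is packed back into
a flt64. Subnormal and infinite results are handled when the result
is assembled. fma32 is only stubbed out for now.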

--- a/lib/math/bld.sub
+++ b/lib/math/bld.sub
@@ -7,5 +7,8 @@
 	# summation
 	fpmath-sum-impl.myr
 
+	# fused-multiply-add
+	fpmath-fma-impl.myr
+
 	lib ../std:std
 ;;
--- /dev/null
+++ b/lib/math/fpmath-fma-impl.myr
@@ -1,0 +1,400 @@
+use std
+
+pkg math =
+	pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
+	pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
+;;
+
+const exp_mask32 : uint32 = 0xff << 23
+const exp_mask64 : uint64 = 0x7ff << 52
+
+pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
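+	/* fma32 is not yet implemented; this stub just returns 0.0 */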
+	-> 0.0
+}
+
+pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
+	var xn : bool, yn : bool, zn : bool
+	var xe : int64, ye : int64, ze : int64
+	var xs : uint64, ys : uint64, zs : uint64
+
+	var xb : uint64 = std.flt64bits(x)
+	var yb : uint64 = std.flt64bits(y)
+	var zb : uint64 = std.flt64bits(z)
+
+	/* handle NaNs, infinities, and zero products or addends up front */
+	if xb & exp_mask64 == exp_mask64 || \
+	   yb & exp_mask64 == exp_mask64
+		-> x * y + z
+	elif z == 0.0 || z == -0.0 || x * y == 0.0 || x * y == -0.0
+		-> x * y + z
+	elif zb & exp_mask64 == exp_mask64
+		-> z
+	;;
+
+	(xn, xe, xs) = std.flt64explode(x)
+	(yn, ye, ys) = std.flt64explode(y)
+	(zn, ze, zs) = std.flt64explode(z)
+	if xe == -1023
+		xe = -1022
+	;;
+	if ye == -1023
+		ye = -1022
+	;;
+	if ze == -1023
+		ze = -1022
+	;;
+
+	/* Keep product in high/low uint64s */
+	var xs_h : uint64 = xs >> 32
+	var ys_h : uint64 = ys >> 32
+	var xs_l : uint64 = xs & 0xffffffff
+	var ys_l : uint64 = ys & 0xffffffff
+
+	var t_l : uint64 = xs_l * ys_l
+	var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
+	var t_h : uint64 = xs_h * ys_h
+
+	var prod_l : uint64 = t_l + (t_m << 32)
+	var prod_h : uint64 = t_h + (t_m >> 32)
+	if t_l > prod_l
+		prod_h++
+	;;
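+	/*
+	   xs and ys are below 2^53, so xs_h and ys_h fit in 21 bits
+	   and t_m = xs_l*ys_h + xs_h*ys_l < 2^54 cannot overflow. The
+	   only carry to account for is from t_l + (t_m << 32) into
+	   prod_h, which the wrap-around check above handles.
+	 */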
+
+	var prod_n = xn != yn
+	var prod_lastbit_e = (xe - 52) + (ye - 52)
+	var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
+	var prod_firstbit_e = prod_lastbit_e + prod_first1
+
+	var z_firstbit_e = ze
+	var z_lastbit_e = ze - 52
+	var z_first1 = 52
+
+	/* a subnormal z has no implicit leading bit, so find its first set bit explicitly */
+	if (zb & exp_mask64 == 0)
+		z_first1 = find_first1_64(zs, z_first1)
+		z_firstbit_e = z_lastbit_e + z_first1
+	;;
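+	/*
+	   Naming: *_first1 is the index of the leading set bit (bit 0
+	   being the lsb of the low word), and *_firstbit_e/*_lastbit_e
+	   are the binary exponents carried by the leading set bit and
+	   by bit 0, respectively.
+	 */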
+
+	var res_n
+	var res_h = 0
+	var res_l = 0
+	var res_first1
+	var res_lastbit_e
+	var res_firstbit_e
+
+	if prod_n == zn
+		res_n = prod_n
+
+		/*
+		   Align prod and z so that the top bit of the
+		   sum lands in res_h, no higher than bit 54,
+		   then add.
+		 */
+		if prod_firstbit_e >= z_firstbit_e
+			/*
+			    [ prod_h ][ prod_l ]
+			         [ z...
+			 */
+			res_lastbit_e = prod_lastbit_e
+			(res_h, res_l) = (prod_h, prod_l)
+			(res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
+		else
+			/*
+			        [ prod_h ][ prod_l ]
+			    [ z...
+			 */
+			res_lastbit_e = z_lastbit_e - 64
+			res_h = zs
+			res_l = 0
+			if prod_lastbit_e >= res_lastbit_e + 64
+				/* In this situation, prod must be extremely subnormal */
+				res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
+			elif prod_lastbit_e >= res_lastbit_e
+				res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
+				res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
+				res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
+			elif prod_lastbit_e + 64 >= res_lastbit_e
+				res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
+				var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
+				var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
+				res_l = l1 + l2
+				if res_l < l1
+					res_h++
+				;;
+			elif prod_lastbit_e + 128 >= res_lastbit_e
+				res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
+			;;
+		;;
+	else
+		match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
+		| `std.Equal: -> 0.0
+		| `std.Before:
+			/* prod > z */
+			res_n = prod_n
+			res_lastbit_e = prod_lastbit_e
+			(res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
+		| `std.After:
+			/* z > prod */
+			res_n = zn
+			res_lastbit_e = z_lastbit_e - 64
+			(res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
+			(res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
+		;;
+	;;
+
+	res_first1 = 64 + find_first1_64(res_h, 55)
+	if res_first1 == 63
+		res_first1 = find_first1_64(res_l, 63)
+	;;
+	res_firstbit_e = res_first1 + res_lastbit_e
+
+	/*
+	   At this point, res_h and res_l are the high and low words
+	   of the result. They now need to be rounded and assembled
+	   into a flt64; subnormal and infinite results get special
+	   treatment.
+	 */
+	var res_s = 0
+	if res_firstbit_e <= -1023
+		/* Subnormal case */
+		if res_lastbit_e + 128 < 12 - 1022
+			res_s = shr(res_h, 12 - 1022 - (res_lastbit_e + 128))
+			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
+		elif res_lastbit_e + 64 < 12 - 1022
+			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
+			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
+		else
+			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
+			res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
+		;;
+
+		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
+			res_s++
+		;;
+
+		/* no exponent bits to set; a subnormal's biased exponent is zero */
+		var res = res_s
+		if res_n
+			res |= (1 << 63)
+		;;
+		-> std.flt64frombits(res)
+	;;
+
+	if res_firstbit_e >= 1024
+		/* Infinity case */
+		if res_n
+			-> std.flt64frombits(0xfff0000000000000)
+		else
+			-> std.flt64frombits(0x7ff0000000000000)
+		;;
+	;;
+
+	if res_first1 - 52 >= 64
+		res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
+		if need_round_away(res_h, res_l, res_first1 - 52)
+			res_s++
+		;;
+	elif res_first1 - 52 >= 0
+		res_s = shl(res_h, 64 - (res_first1 - 52))
+		res_s |= shr(res_l, res_first1 - 52)
+		if need_round_away(res_h, res_l, res_first1 - 52)
+			res_s++
+		;;
+	else
+		/* res_first1 < 52, so res_h is 0 and res_l holds everything */
+		res_s = shl(res_l, 52 - res_first1)
+	;;
+
+	/* Rounding may have carried out of bit 52; renormalize if so */
+	if res_s & (1 << 53) != 0
+		res_s >>= 1
+		res_firstbit_e++
+		if res_firstbit_e >= 1024
+			if res_n
+				-> std.flt64frombits(0xfff0000000000000)
+			else
+				-> std.flt64frombits(0x7ff0000000000000)
+			;;
+		;;
+	;;
+
+	-> std.flt64assem(res_n, res_firstbit_e, res_s)
+}
+
+/* >> and <<, but yielding 0 (instead of wrapping) when the shift is >= 64 or negative */
+const shr = {u : uint64, s : int64
+	if (s : uint64) >= 64
+		-> 0
+	else
+		-> u >> (s : uint64)
+	;;
+}
+
+const shl = {u : uint64, s : int64
+	if (s : uint64) >= 64
+		-> 0
+	else
+		-> u << (s : uint64)
+	;;
+}
+
+/*
+   Add (a << s) to [ h ][ l ], where if s < 0 then a corresponding
+   right-shift is used. This is aligned such that if s == 0, then
+   the result is [ h ][ l + a ]
+ */
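+/*
+   For example, add_shifted(0x0, 0xffffffffffffffff, 0x1, 0) gives
+   (0x1, 0x0): the carry out of the low word propagates into h.
+ */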
+const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
+	if s >= 64
+		-> (h + shl(a, s - 64), l)
+	elif s >= 0
+		var new_h = h + shr(a, 64 - s)
+		var sa = shl(a, s)
+		var new_l = l + sa
+		if new_l < l
+			new_h++
+		;;
+		-> (new_h, new_l)
+	else
+		var new_h = h
+		var sa = shr(a, -s)
+		var new_l = l + sa
+		if new_l < l
+			new_h++
+		;;
+		-> (new_h, new_l)
+	;;
+}
+
+/* As above, but subtract (a << s) */
+const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
+	if s >= 64
+		-> (h - shl(a, s - 64), l)
+	elif s >= 0
+		var new_h = h - shr(a, 64 - s)
+		var sa = shl(a, s)
+		var new_l = l - sa
+		if sa > l
+			new_h--
+		;;
+		-> (new_h, new_l)
+	else
+		var new_h = h
+		var sa = shr(a, -s)
+		var new_l = l - sa
+		if sa > l
+			new_h--
+		;;
+		-> (new_h, new_l)
+	;;
+}
+
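+/*
+   Compare the magnitude of the product [ h ][ l ] against z, given
+   the exponents of their leading and trailing bits. Returns
+   `std.Before if the product is strictly larger, `std.After if z
+   is strictly larger, and `std.Equal otherwise. Both values are
+   assumed non-zero.
+ */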
+const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
+	if hl_firstbit_e > z_firstbit_e
+		-> `std.Before
+	elif hl_firstbit_e < z_firstbit_e
+		-> `std.After
+	;;
+
+	var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
+	var z_k : int64 = (z_firstbit_e - z_lastbit_e)
+	while h_k >= 0 && z_k >= 0
+		var h1 = h & shl(1, h_k) != 0
+		var z1 = z & shl(1, z_k) != 0
+		if h1 && !z1
+			-> `std.Before
+		elif !h1 && z1
+			-> `std.After
+		;;
+		h_k--
+		z_k--
+	;;
+
+	if z_k < 0
+		if (h & shr((-1 : uint64), 63 - h_k) != 0) || (l != 0)
+			-> `std.Before
+		else
+			-> `std.Equal
+		;;
+	;;
+
+	var l_k : int64 = 63
+	while l_k >= 0 && z_k >= 0
+		var l1 = l & shl(1, l_k) != 0
+		var z1 = z & shl(1, z_k) != 0
+		if l1 && !z1
+			-> `std.Before
+		elif !l1 && z1
+			-> `std.After
+		;;
+		l_k--
+		z_k--
+	;;
+
+	if (z_k < 0) && (l & shr((-1 : uint64), 63 - l_k) != 0)
+		-> `std.Before
+	elif (l_k < 0) && (z & shr((-1 : uint64), 63 - z_k) != 0)
+		-> `std.After
+	;;
+
+	-> `std.Equal
+}
+
+/* Find the highest set bit at or below position start; -1 if there is none */
+const find_first1_64 = {b : uint64, start : int64
+	for var j = start; j >= 0; --j
+		var m = shl(1, j)
+		if b & m != 0
+			-> j
+		;;
+	;;
+
+	-> -1
+}
+
+const find_first1_64_hl = {h, l, start
+	var first1_h = find_first1_64(h, start - 64)
+	if first1_h >= 0
+		-> first1_h + 64
+	;;
+
+	-> find_first1_64(l, 63)
+}
+
+/*
+   For [ h ][ l ], where bitpos_last is the position of the last
+   bit that was included in the 53-bit wide result (l's last bit
+   has position 0), decide whether rounding up/away is needed. This
+   is true if
+
+    - the first omitted bit is a 1 and at least one later omitted
+      bit is also non-zero (more than half an ulp was dropped), or
+
+    - the first omitted bit is a 1 and all later omitted bits are
+      zero (exactly half an ulp), and the bit at bitpos_last is 1,
+      so rounding up gives an even significand
+ */
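+/*
+   For example, taking only the low word and bitpos_last = 2:
+   l = 0b110 rounds away (a tie, and the kept bit 2 is 1, so odd),
+   l = 0b010 does not (a tie, kept bit 2 is 0, so even), and
+   l = 0b011 rounds away (the omitted bits exceed half an ulp).
+ */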
+const need_round_away = {h : uint64, l : uint64, bitpos_last : int64
+	var first_omitted_is_1 = false
+	var nonzero_beyond = false
+	if bitpos_last > 64
+		first_omitted_is_1 = h & shl(1, bitpos_last - 1 - 64) != 0
+		nonzero_beyond = h & shr((-1 : uint64), 1 + 64 - (bitpos_last - 64)) != 0
+		nonzero_beyond = nonzero_beyond || (l != 0)
+	else
+		first_omitted_is_1 = l & shl(1, bitpos_last - 1) != 0
+		nonzero_beyond = l & shr((-1 : uint64), 1 + 64 - bitpos_last) != 0
+	;;
+
+	if !first_omitted_is_1
+		-> false
+	;;
+
+	if nonzero_beyond
+		-> true
+	;;
+
+	var hl_is_odd = false
+
+	if bitpos_last >= 64
+		hl_is_odd = h & shl(1, bitpos_last - 64) != 0
+	else
+		hl_is_odd = l & shl(1, bitpos_last) != 0
+	;;
+
+	-> hl_is_odd
+}
--- a/lib/math/fpmath.myr
+++ b/lib/math/fpmath.myr
@@ -3,12 +3,15 @@
 pkg math =
 	trait fpmath @f =
 
+		/* fpmath-fma-impl */
+		fma : (x : @f, y : @f, z : @f -> @f)
+
 		/* fpmath-trunc-impl */
 		trunc : (f : @f -> @f)
 		ceil  : (f : @f -> @f)
 		floor : (f : @f -> @f)
 
-		/* summation */
+		/* fpmath-sum-impl */
 		kahan_sum : (a : @f[:] -> @f)
 		priest_sum : (a : @f[:] -> @f)
 	;;
@@ -36,6 +39,8 @@
 ;;
 
 impl fpmath flt32 =
+	fma = {x, y, z; -> fma32(x, y, z)}
+
 	trunc = {f; -> trunc32(f)}
 	floor = {f; -> floor32(f)}
 	ceil  = {f; -> ceil32(f)}
@@ -45,6 +50,8 @@
 ;;
 
 impl fpmath flt64 =
+	fma = {x, y, z; -> fma64(x, y, z)}
+
 	trunc = {f; -> trunc64(f)}
 	floor = {f; -> floor64(f)}
 	ceil  = {f; -> ceil64(f)}
@@ -52,6 +59,9 @@
 	kahan_sum = {l; -> kahan_sum64(l) }
 	priest_sum = {l; -> priest_sum64(l) }
 ;;
+
+extern const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
+extern const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
 
 extern const trunc32 : (f : flt32 -> flt32)
 extern const floor32 : (f : flt32 -> flt32)
--- /dev/null
+++ b/lib/math/test/fpmath-fma-impl.myr
@@ -1,0 +1,46 @@
+use std
+use math
+use testr
+
+const main = {
+	testr.run([
+		[.name="fma-01", .fn = fma01],
+		[.name="fma-02", .fn = fma02],
+	][:])
+}
+
+const fma01 = {c
+	/* fma32 is still a stub; flt32 tests go here once it is implemented */
+}
+
+const fma02 = {c
+	var inputs : (uint64, uint64, uint64, uint64)[:] = [
+		/*
+		   These are mostly obtained by running fpmath-consensus
+		   with seed 1234. Each covers a different corner case.
+		 */
+		(0x0000000000000000, 0x0000000000000000, 0x0100000000000000, 0x0100000000000000),
+		(0x0000000000000000, 0x0000000000000000, 0x0200000000000000, 0x0200000000000000),
+		(0x00000000000009a4, 0x6900000000000002, 0x6834802690008002, 0x6834802690008002),
+		(0x49113e8d334802ab, 0x5c35d8c190aea62e, 0x6c1a8cc212dcb6e2, 0x6c1a8cc212dcb6e2),
+		(0x2b3e04e8f6266d83, 0xae84e20f62f99bda, 0xc9115a1ccea6ce27, 0xc9115a1ccea6ce27),
+		(0xa03ea9e9b09d932c, 0xded7bc19edcbf0c7, 0xbbc4c1f83b3f8f2e, 0x3f26be5f0c7b48e3),
+		(0xa5ec2141c1e6f339, 0xa2d80fc217f57b61, 0x00b3484b473ef1b8, 0x08d526cb86ee748d),
+		(0xccc6600ee88bb67c, 0xc692eeec9b51cf0f, 0xbf5f1ae3486401b0, 0x536a7a30857129db),
+		(0x5f9b9e449db17602, 0xbef22ae5b6a2b1c5, 0x6133e925e6bf8a12, 0x6133e925e6bf823b),
+		(0x7f851249841b6278, 0x3773388e53a375f4, 0x761c27fc2ffa57be, 0x7709506b0e99dc30),
+		(0x7c7cb20f3ca8af93, 0x800fd7f5cfd5baae, 0x14e4c09c9bb1e17e, 0xbc9c6a3fd0e58682),
+		(0xb5e8db2107f4463f, 0x614af740c0d7eb3b, 0xd7e3d25c4daa81e0, 0xd7e3d798d3ccdffb),
+		(0xae62c8be4cb45168, 0x90cc5236f3516c90, 0x0007f8b14f684558, 0x0007f9364eb1a815),
+		(0x5809f53e32a7e1ba, 0xcc647611ccaa5bf4, 0xdfbdb5c345ce7a56, 0xe480990da5526103),
+	][:]
+
+	for (x, y, z, r) : inputs
+		var xf : flt64 = std.flt64frombits(x)
+		var yf : flt64 = std.flt64frombits(y)
+		var zf : flt64 = std.flt64frombits(z)
+		var rf = math.fma(xf, yf, zf)
+		testr.check(c, rf == std.flt64frombits(r),
+			"0x{b=16,w=16,p=0} * 0x{b=16,w=16,p=0} + 0x{b=16,w=16,p=0} should be 0x{b=16,w=16,p=0}, was 0x{b=16,w=16,p=0}", x, y, z, r, std.flt64bits(rf))
+	;;
+}