shithub: mc

Download patch

ref: 959b473779673322fbef51385a1cf730e73c7aac
parent: 1d7654102819c1db26c418cfdc9fb3032294157e
author: S. Gilles <sgilles@math.umd.edu>
date: Sun Mar 25 18:18:49 EDT 2018

Implement sqrt.

--- a/lib/math/bld.sub
+++ b/lib/math/bld.sub
@@ -1,6 +1,13 @@
 lib math =
 	fpmath.myr
 
+	# fused-multiply-add
+	fma-impl+posixy-x64-fma.s
+	fma-impl.myr
+
+	# sqrt
+	sqrt-impl.myr
+
 	# trunc, floor, ceil
 	trunc-impl+posixy-x64-sse4.s
 	trunc-impl.myr
@@ -7,10 +14,6 @@
 
 	# summation
 	sum-impl.myr
-
-	# fused-multiply-add
-	fma-impl+posixy-x64-fma.s
-	fma-impl.myr
 
 	lib ../std:std
 ;;
--- a/lib/math/fpmath.myr
+++ b/lib/math/fpmath.myr
@@ -6,6 +6,9 @@
 		/* fma-impl */
 		fma : (x : @f, y : @f, z : @f -> @f)
 
+		/* sqrt-impl */
+		sqrt : (f : @f -> @f)
+
 		/* trunc-impl */
 		trunc : (f : @f -> @f)
 		ceil  : (f : @f -> @f)
@@ -41,6 +44,8 @@
 impl fpmath flt32 =
 	fma = {x, y, z; -> fma32(x, y, z)}
 
+	sqrt = {f; -> sqrt32(f)}
+
 	trunc = {f; -> trunc32(f)}
 	floor = {f; -> floor32(f)}
 	ceil  = {f; -> ceil32(f)}
@@ -52,6 +57,8 @@
 impl fpmath flt64 =
 	fma = {x, y, z; -> fma64(x, y, z)}
 
+	sqrt = {f; -> sqrt64(f)}
+
 	trunc = {f; -> trunc64(f)}
 	floor = {f; -> floor64(f)}
 	ceil  = {f; -> ceil64(f)}
@@ -63,16 +70,20 @@
 extern const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
 extern const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
 
-extern const trunc32 : (f : flt32 -> flt32)
-extern const floor32 : (f : flt32 -> flt32)
-extern const ceil32  : (f : flt32 -> flt32)
+extern const sqrt32 : (x : flt32 -> flt32)
+extern const sqrt64 : (x : flt64 -> flt64)
 
+extern const trunc32 : (f : flt32 -> flt32)
 extern const trunc64 : (f : flt64 -> flt64)
+
+extern const floor32 : (f : flt32 -> flt32)
 extern const floor64 : (f : flt64 -> flt64)
+
+extern const ceil32  : (f : flt32 -> flt32)
 extern const ceil64  : (f : flt64 -> flt64)
 
 extern const kahan_sum32 : (l : flt32[:] -> flt32)
-extern const priest_sum32 : (l : flt32[:] -> flt32)
-
 extern const kahan_sum64 : (l : flt64[:] -> flt64)
+
+extern const priest_sum32 : (l : flt32[:] -> flt32)
 extern const priest_sum64 : (l : flt64[:] -> flt64)
--- /dev/null
+++ b/lib/math/references
@@ -1,0 +1,11 @@
+References
+
+[KM06]
+Peter Kornerup and Jean-Michel Muller. “Choosing starting values
+for certain Newton–Raphson iterations”. In: Theoretical Computer
+Science 351.1 (2006), pp. 101–110. doi:
+https://doi.org/10.1016/j.tcs.2005.09.056.
+
+[Mul+10]
+Jean-Michel Muller et al. Handbook of floating-point arithmetic.
+Boston: Birkhäuser, 2010. isbn: 9780817647049.
--- /dev/null
+++ b/lib/math/sqrt-impl.myr
@@ -1,0 +1,188 @@
+use std
+
+use "fpmath"
+
+/* See [Mul+10], sections 5.4 and 8.7 */
+pkg math =
+	pkglocal const sqrt32 : (f : flt32 -> flt32)
+	pkglocal const sqrt64 : (f : flt64 -> flt64)
+;;
+
+extern const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
+extern const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
+
+type fltdesc(@f, @u, @i) = struct	/* per-format operations and constants consumed by sqrtgen */
+	explode : (f : @f -> (bool, @i, @u))	/* float -> (is-negative, exponent, significand) */
+	assem : (n : bool, e : @i, s : @u -> @f)	/* rebuilds a float from those three parts */
+	fma : (x : @f, y : @f, z : @f -> @f)	/* fused multiply-add at this width (fma32 / fma64) */
+	tobits : (f : @f -> @u)	/* raw bit pattern of a float */
+	frombits : (u : @u -> @f)	/* float carrying the given bit pattern */
+	nan : @f	/* value sqrtgen returns for negative or NaN input */
+	emin : @i	/* exponent sqrtgen expects from explode for zeros and subnormals */
+	emax : @i	/* exponent sqrtgen treats as +inf once NaN has been ruled out */
+	normmask : @u	/* significand bit that is set once a value is normalized */
+	sgnmask : @u	/* sign bit within the raw bit pattern */
+	ab : (@u, @u)[:]	/* (upper bound, starting guess) pairs, both as raw bits */
+	iterlim : int	/* Newton-Raphson pass count; also caps the final ulp adjustments */
+;;
+
+/*
+   The starting point of the N-R iteration of 1/sqrt, after significand
+   has been normalized to [1, 4).
+
+   See [KM06] for the construction and notation. Case p = -2. The
+   dividers (left values) are chosen roughly so that maximal error
+   of N-R, after 3 iterations, starting with the right value, is
+   less than an ulp (of the result). If g falls in [a_i, a_{i+1}),
+   N-R should start with b_{i+1}.
+
+   In the flt64 case, we need only one more iteration.
+ */
+const ab32 : (uint32, uint32)[:] = [
+	(0x3f800000, 0x3f800000), /* exactly 1.0 -> 1.0; nothing should ever get normalized to < 1.0 */
+	(0x3fa66666, 0x3f6f30ae), /* [1.0,  1.3 ) -> 0.9343365431 */
+	(0x3fd9999a, 0x3f5173ca), /* [1.3,  1.7 ) -> 0.8181730509 */
+	(0x40100000, 0x3f3691d3), /* [1.7,  2.25) -> 0.713162601  */
+	(0x40333333, 0x3f215342), /* [2.25, 2.8 ) -> 0.6301766634 */
+	(0x4059999a, 0x3f118e0e), /* [2.8,  3.4 ) -> 0.5685738325 */
+	(0x40800000, 0x3f053049), /* [3.4,  4.0 ) -> 0.520268023  */
+][:]
+
+const ab64 : (uint64, uint64)[:] = [
+	(0x3ff0000000000000, 0x3ff0000000000000), /* exactly 1.0 -> 1.0; nothing is normalized to < 1.0 */
+	(0x3ff3333333333333, 0x3fee892ce1608cbc), /* [1.0,  1.2)  -> 0.954245033445111356940060431953 */
+	(0x3ff6666666666666, 0x3fec1513a2184094), /* [1.2,  1.4)  -> 0.877572838393478438234751592972 */
+	(0x3ffc000000000000, 0x3fe9878eb3e9ba20), /* [1.4,  1.75) -> 0.797797538178034670863780775107 */
+	(0x400199999999999a, 0x3fe6ccb14eeb238d), /* [1.75, 2.2)  -> 0.712486890924184046447464879748 */
+	(0x400599999999999a, 0x3fe47717c17cd34f), /* [2.2,  2.7)  -> 0.639537694840876969060161627567 */
+	(0x400b333333333333, 0x3fe258df212a8e9a), /* [2.7,  3.4)  -> 0.573348583963212421465982515656 */
+	(0x4010000000000000, 0x3fe0a5989f2dc59a), /* [3.4,  4.0)  -> 0.520214377304159869552790951275 */
+][:]
+
+const sqrt32 = {f : flt32
+	/* Describe the flt32 layout, then defer to the format-generic sqrtgen. */
+	var desc : fltdesc(flt32, uint32, int32) = [
+		/* exponent bounds and bit masks */
+		.emin = -127, .emax = 128,
+		.sgnmask = 1 << 31, .normmask = 1 << 23,
+		/* conversions and arithmetic */
+		.explode = std.flt32explode, .assem = std.flt32assem,
+		.tobits = std.flt32bits, .frombits = std.flt32frombits,
+		.fma = fma32,
+		.nan = std.flt32nan(),
+		/* Newton-Raphson starting values and iteration budget */
+		.ab = ab32,
+		.iterlim = 3,
+	]
+	-> sqrtgen(f, desc)
+}
+
+const sqrt64 = {f : flt64
+	/* Describe the flt64 layout, then defer to the format-generic sqrtgen. */
+	var desc : fltdesc(flt64, uint64, int64) = [
+		/* exponent bounds and bit masks */
+		.emin = -1023, .emax = 1024,
+		.sgnmask = 1 << 63, .normmask = 1 << 52,
+		/* conversions and arithmetic */
+		.explode = std.flt64explode, .assem = std.flt64assem,
+		.tobits = std.flt64bits, .frombits = std.flt64frombits,
+		.fma = fma64,
+		.nan = std.flt64nan(),
+		/* Newton-Raphson starting values and iteration budget */
+		.ab = ab64,
+		.iterlim = 4,
+	]
+	-> sqrtgen(f, desc)
+}
+
+generic sqrtgen = {f : @f, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable,fpmath @f, numeric,integral @u, numeric,integral @i
+	/* Computes sqrt(f) for the format described by d. +/- 0.0 and +inf are returned */
+	/* unchanged; negative and NaN inputs give d.nan. */
+	var n, e, s, e2
+	(n, e, s) = d.explode(f)
+
+	/* Special cases: +/- 0.0, negative, NaN, and +inf */
+	if e == d.emin && s == 0
+		-> f
+	elif n || std.isnan(f)
+		-> d.nan
+	elif e == d.emax
+		-> f
+	;;
+
+	/* Remove a factor of 2^(even) to normalize significand; e2 keeps that even exponent. */
+	if e == d.emin
+		e = d.emin + 1
+		while s & d.normmask == 0
+			s <<= 1
+			e--
+		;;
+	;;
+	if e % 2 != 0
+		e2 = e - 1
+		e = 1
+	else
+		e2 = e
+		e = 0
+	;;
+
+	var a = d.assem(false, e, s)
+	var au = d.tobits(a)
+
+	/*
+	   We shall perform iterated Newton-Raphson in order to
+	   compute 1/sqrt(g), then multiply by g to obtain sqrt(g).
+	   This is faster than calculating sqrt(g) directly because
+	   it avoids division. (The multiplication by g is built
+	   into Markstein's r, g, n variables.)
+	 */
+	var xn = d.frombits(0)
+	for (ai, beta) : d.ab
+		if au <= ai
+			xn = d.frombits(beta)
+			break
+		;;
+	;;
+
+	/* split up "x_{n+1} = x_n (3 - ax_n^2)/2" */
+	/* rn is recomputed from gn and hn at the top of every pass, so it needs no seed */
+	var rn
+	var gn = a * xn
+	var hn = 0.5 * xn
+	for var j = 0; j < d.iterlim; ++j
+		rn = d.fma(-1.0 * gn, hn, 0.5)
+		gn = d.fma(gn, rn, gn)
+		hn = d.fma(hn, rn, hn)
+	;;
+
+	/*
+	   gn is almost what we want, except that we might want to
+	   adjust by an ulp in one direction or the other. This is
+	   the Tuckerman test.
+
+	   Exhaustive testing has shown that we need only 3 adjustments
+	   in the flt32 case (and it should be 4 in the flt64 case).
+	 */
+	(_, e, s) = d.explode(gn)
+	e += (e2 / 2)
+	var r = d.assem(false, e, s)
+
+	for var j = 0; j < d.iterlim; ++j
+		var r_plus_ulp = d.frombits(d.tobits(r) + 1)
+		var r_minus_ulp = d.frombits(d.tobits(r) - 1)
+
+		var delta_1 = d.fma(r, r_minus_ulp, -1.0 * f)
+		if d.tobits(delta_1) & d.sgnmask == 0
+			r = r_minus_ulp
+		else
+			var delta_2 = d.fma(r, r_plus_ulp, -1.0 * f)
+			if d.tobits(delta_2) & d.sgnmask != 0
+				r = r_plus_ulp
+			else
+				-> r
+			;;
+		;;
+	;;
+
+	-> r
+}
--- a/lib/math/sum-impl.myr
+++ b/lib/math/sum-impl.myr
@@ -1,5 +1,6 @@
 use std
 
+/* For references, see [Mul+10] section 6.3 */
 pkg math =
 	pkglocal const kahan_sum32 : (l : flt32[:] -> flt32)
 	pkglocal const priest_sum32 : (l : flt32[:] -> flt32)
--- /dev/null
+++ b/lib/math/test/sqrt-impl.myr
@@ -1,0 +1,83 @@
+use std
+use math
+use testr
+
+const main = {
+	testr.run([
+		[.name="sqrt-01", .fn = sqrt01],	/* flt32 cases */
+		[.name="sqrt-02", .fn = sqrt02],	/* flt64 cases */
+	][:])
+}
+
+const sqrt01 = {c
+	/* flt32 cases: each pair is (bits of x, expected bits of sqrt(x)); a NaN pattern means "any NaN" */
+	var inputs : (uint32, uint32)[:] = [
+		(0x00000000, 0x00000000),
+		(0x80000000, 0x80000000),
+		(0x80000001, 0x7ff80000),
+		(0x8aaaaaaa, 0x7ff80000),
+		(0x3f800000, 0x3f800000),
+		(0x40800000, 0x40000000),
+		(0x41100000, 0x40400000),
+		(0x3e800000, 0x3f000000),
+		(0x3a3a0000, 0x3cda35fe),
+		(0x017a1000, 0x207d038b),
+		(0x00fc0500, 0x20339b45),
+		(0x160b0000, 0x2abca321),
+		(0x00000800, 0x1d000000),
+		(0x7f690a00, 0x5f743ff8),
+		(0x7f5c0e00, 0x5f6d590c),
+	][:]
+
+	for (x, y) : inputs
+		var xf : flt32 = std.flt32frombits(x)
+		var rf = math.sqrt(xf)
+		/* Match bit patterns so -0.0 differs from 0.0; NaN never equals NaN, so accept any NaN for a NaN row. */
+		testr.check(c, std.flt32bits(rf) == y || (std.isnan(rf) && std.isnan(std.flt32frombits(y))),
+			"sqrt(0x{b=16,w=8,p=0}) should be 0x{b=16,w=8,p=0}, was 0x{b=16,w=8,p=0}", x, y, std.flt32bits(rf))
+	;;
+}
+
+const sqrt02 = {c
+	/* flt64 cases: each pair is (bits of x, expected bits of sqrt(x)); a NaN pattern means "any NaN" */
+	var inputs : (uint64, uint64)[:] = [
+		(0x0000000000000000ul, 0x0000000000000000ul),
+		(0x8000000000000000ul, 0x8000000000000000ul),
+		(0x8000000000000001ul, 0x7ff8000000000000ul),
+		(0x8aaaaaaaaaaaaaaaul, 0x7ff8000000000000ul),
+		(0x0606e437817acd16ul, 0x22fb10b36e4ae795ul),
+		(0x3ff0000000000000ul, 0x3ff0000000000000ul),
+		(0x4010000000000000ul, 0x4000000000000000ul),
+		(0x3fd0000000000000ul, 0x3fe0000000000000ul),
+		(0x1bbffa831c8f220eul, 0x2dd69ead9d353d6cul),
+		(0x3f0e0f7339499f0bul, 0x3f7f03d8229c8b81ul),
+		(0x3ca510f548e0f3ecul, 0x3e49f6bcadd1e806ul),
+		(0x044ef24a3cca214bul, 0x221f780430319d58ul),
+		(0x7ab034357a1e0474ul, 0x5d501a0593fd8d49ul),
+		(0x216b2df113b38de7ul, 0x30ad7dcc6f26285aul),
+		(0x2e2de34118496c06ul, 0x370eed0301fdade1ul),
+		(0x155bf26b4fb0b2c8ul, 0x2aa5255cf9bd799cul),
+		(0x4b8004df0ac137aaul, 0x45b6a40fee232f2aul),
+		(0x1acaf23d7b0bf80cul, 0x2d5d5d56beda3392ul),
+		(0x3f97ea4c6399a8e6ul, 0x3fc38fb000d55805ul),
+		(0x78f36ea1656dec48ul, 0x5c71a1fce3f370e4ul),
+		(0x409636d438489edbul, 0x4042da4eeac985aaul),
+		(0x72dfd27869ffd768ul, 0x5966907fc9668f57ul),
+		(0x1f483c585e4f03dcul, 0x2f9bd93c3bd1f884ul),
+		(0x7ade25ea6bb6464eul, 0x5d65f681bdbcdf4eul),
+		(0x24ffe5593b0836dbul, 0x3276973038d3bbddul),
+		(0x03e92ac739ec355eul, 0x21ec60eea1d102e8ul),
+		(0x76b656a961a4f64eul, 0x5b52e7cc1d30f55bul),
+		(0x5bc2fac208381d11ul, 0x4dd8a4f5203ab3d2ul),
+		(0x000578e105ac27aaul, 0x1ff2b6d3204e206eul),
+		(0x00057e1016b7c1edul, 0x1ff2bfae3a8e21bbul),
+	][:]
+
+	for (x, y) : inputs
+		var xf : flt64 = std.flt64frombits(x)
+		var rf = math.sqrt(xf)
+		/* Match bit patterns so -0.0 differs from 0.0; NaN never equals NaN, so accept any NaN for a NaN row. */
+		testr.check(c, std.flt64bits(rf) == y || (std.isnan(rf) && std.isnan(std.flt64frombits(y))),
+			"sqrt(0x{b=16,w=16,p=0}) should be 0x{b=16,w=16,p=0}, was 0x{b=16,w=16,p=0}", x, y, std.flt64bits(rf))
+	;;
+}