shithub: mc

ref: 6d707df71c0fe11da7dd53bf36eb3d3fadb8ec9b
dir: /lib/math/sqrt-impl.myr/

View raw version
use std

use "fpmath"

/* See [Mul+10], sections 5.4 and 8.7 */
pkg math =
	/*
	   Square root entry points for the two float widths. Both are
	   declared pkglocal and are defined later in this file as thin
	   wrappers around the shared generic implementation.
	 */
	pkglocal const sqrt32 : (x : flt32 -> flt32)
	pkglocal const sqrt64 : (x : flt64 -> flt64)
;;

/*
   Three-operand fused multiply-add routines for each float width.
   They are only declared here; the definitions live outside this file.
 */
extern const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
extern const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)

/*
   Width-specific operations and constants that let one generic sqrt
   routine serve several float types. @f is the float type, @u the
   unsigned integer type holding its bit pattern, and @i the signed
   integer type used for exponents.
 */
type fltdesc(@f, @u, @i) = struct
	/* split a float into (is-negative, exponent, significand) */
	explode : (f : @f -> (bool, @i, @u))
	/* build a float from (is-negative, exponent, significand) */
	assem : (n : bool, e : @i, s : @u -> @f)
	/* three-operand fused multiply-add for this width */
	fma : (x : @f, y : @f, z : @f -> @f)
	/* reinterpret a float as its raw bit pattern, and back */
	tobits : (f : @f -> @u)
	frombits : (u : @u -> @f)
	/* bit pattern returned for negative or NaN input (a quiet NaN) */
	nan : @u
	/* exponent that explode reports for zero and subnormal inputs */
	emin : @i
	/* exponent that explode reports for infinity (and NaN) */
	emax : @i
	/* significand bit that is set once a value is normalized */
	normmask : @u
	/* mask selecting the sign bit of a raw bit pattern */
	sgnmask : @u
	/* (upper bound, starting estimate) pairs, both as raw bit patterns */
	ab : (@u, @u)[:]
	/* count used for both the refinement loop and the final ulp-adjustment loop */
	iterlim : int
;;

/*
   The starting point of the Newton-Raphson (N-R) iteration for
   1/sqrt, after the significand has been normalized to [1, 4).

   See [KM06] for the construction and notation. Case p = -2. The
   dividers a_i (left values) are chosen roughly so that the maximal
   error of N-R after 3 iterations, starting with the right value
   b_i, is less than an ulp (of the result). If g falls in
   [a_i, a_{i+1}), N-R should start with b_{i+1}.

   In the flt64 case, we need only one more iteration.
 */
/*
   flt32 table of (a_i, b_i) pairs, stored as raw flt32 bit patterns.
   sqrtgen scans it in order and starts from the b of the first entry
   whose a is >= the bits of the normalized input, so the entries must
   stay sorted by ascending a.
 */
const ab32 : (uint32, uint32)[7] = [
	(0x3f800000, 0x3f800000), /* Nothing should ever get normalized to < 1.0 */
	(0x3fa66666, 0x3f6f30ae), /* [1.0,  1.3 ) -> 0.9343365431 */
	(0x3fd9999a, 0x3f5173ca), /* [1.3,  1.7 ) -> 0.8181730509 */
	(0x40100000, 0x3f3691d3), /* [1.7,  2.25) -> 0.713162601  */
	(0x40333333, 0x3f215342), /* [2.25, 2.8 ) -> 0.6301766634 */
	(0x4059999a, 0x3f118e0e), /* [2.8,  3.4 ) -> 0.5685738325 */
	(0x40800000, 0x3f053049), /* [3.4,  4.0 ) -> 0.520268023  */
]

/*
   flt64 table of (a_i, b_i) pairs, stored as raw flt64 bit patterns.
   sqrtgen scans it in order and starts from the b of the first entry
   whose a is >= the bits of the normalized input, so the entries must
   stay sorted by ascending a.
 */
const ab64 : (uint64, uint64)[8] = [
	(0x3ff0000000000000, 0x3ff0000000000000), /* < 1.0 */
	(0x3ff3333333333333, 0x3fee892ce1608cbc), /* [1.0,  1.2)  -> 0.954245033445111356940060431953 */
	(0x3ff6666666666666, 0x3fec1513a2184094), /* [1.2,  1.4)  -> 0.877572838393478438234751592972 */
	(0x3ffc000000000000, 0x3fe9878eb3e9ba20), /* [1.4,  1.75) -> 0.797797538178034670863780775107 */
	(0x400199999999999a, 0x3fe6ccb14eeb238d), /* [1.75, 2.2)  -> 0.712486890924184046447464879748 */
	(0x400599999999999a, 0x3fe47717c17cd34f), /* [2.2,  2.7)  -> 0.639537694840876969060161627567 */
	(0x400b333333333333, 0x3fe258df212a8e9a), /* [2.7,  3.4)  -> 0.573348583963212421465982515656 */
	(0x4010000000000000, 0x3fe0a5989f2dc59a), /* [3.4,  4.0)  -> 0.520214377304159869552790951275 */
]

/*
   flt32 parameters for sqrtgen. The fma member is bound to the
   extern fma32 declared at the top of this file so that sqrtgen can
   reach the width-specific fused multiply-add through the descriptor.
 */
const desc32 : fltdesc(flt32, uint32, int32) =  [
	.explode = std.flt32explode,
	.assem = std.flt32assem,
	.fma = fma32,
	.tobits = std.flt32bits,
	.frombits = std.flt32frombits,
	.nan = 0x7fc00000,
	.emin = -127,
	.emax = 128,
	.normmask = 1 << 23,
	.sgnmask = 1 << 31,
	.ab = ab32[0:7],
	.iterlim = 3,
]

/*
   flt64 parameters for sqrtgen. The fma member is bound to the
   extern fma64 declared at the top of this file.
 */
const desc64 : fltdesc(flt64, uint64, int64) =  [
	.explode = std.flt64explode,
	.assem = std.flt64assem,
	.fma = fma64,
	.tobits = std.flt64bits,
	.frombits = std.flt64frombits,
	.nan = 0x7ff8000000000000,
	.emin = -1023,
	.emax = 1024,
	.normmask = 1 << 52,
	.sgnmask = 1 << 63,
	.ab = ab64[0:8],
	.iterlim = 4,
]

/* Square root of a flt32: the generic routine with the flt32 descriptor. */
const sqrt32 = {x : flt32
	-> sqrtgen(x, desc32)
}

/* Square root of a flt64: the generic routine with the flt64 descriptor. */
const sqrt64 = {x : flt64
	-> sqrtgen(x, desc64)
}

/*
   Generic square root.

   x is the input and d supplies every width-specific operation and
   constant. Returns:
     - x itself for +/- 0.0 and for +inf,
     - the quiet NaN d.nan for any other negative input and for NaN,
     - otherwise the result of a Newton-Raphson refinement of 1/sqrt
       followed by up to d.iterlim one-ulp corrections.

   All fused multiply-adds go through d.fma, in the same way that all
   other width-specific operations go through d.
 */
generic sqrtgen = {x : @f, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i
	var n : bool, e : @i, s : @u, e2 : @i
	(n, e, s) = d.explode(x)

	/* Special cases: +/- 0.0, negative, NaN, and +inf */
	if e == d.emin && s == 0
		-> x
	elif n || std.isnan(x)
		/* Make sure to return a quiet NaN */
		-> d.frombits(d.nan)
	elif e == d.emax
		-> x
	;;

	/*
	   Remove a factor of 2^(even) to normalize significand.
	 */
	if e == d.emin
		/* subnormal: shift left until the normalized bit is set */
		e = d.emin + 1
		while s & d.normmask == 0
			s <<= 1
			e--
		;;
	;;
	/*
	   e2 is the even part of the exponent; it is halved and added
	   back to the result's exponent later. What is left in e is 0
	   or 1.
	 */
	if e % 2 != 0
		e2 = e - 1
		e = 1
	else
		e2 = e
		e = 0
	;;

	var a : @f = d.assem(false, e, s)
	var au : @u = d.tobits(a)

	/*
	   We shall perform iterated Newton-Raphson in order to
	   compute 1/sqrt(g), then multiply by g to obtain sqrt(g).
	   This is faster than calculating sqrt(g) directly because
	   it avoids division. (The multiplication by g is built
	   into Markstein's r, g, n variables.)
	 */
	/* starting estimate: b of the first table entry whose a is >= au */
	var xn : @f = d.frombits(0)
	for (ai, beta) : d.ab
		if au <= ai
			xn = d.frombits(beta)
			break
		;;
	;;

	/* split up "x_{n+1} = x_n (3 - ax_n^2)/2" */
	var epsn = d.fma(-1.0 * a, xn * xn, 1.0)
	var rn = 0.5 * epsn
	var gn = a * xn
	var hn = 0.5 * xn
	for var j = 0; j < d.iterlim; ++j
		rn = d.fma(-1.0 * gn, hn, 0.5)
		gn = d.fma(gn, rn, gn)
		hn = d.fma(hn, rn, hn)
	;;

	/*
	   gn is almost what we want, except that we might want to
	   adjust by an ulp in one direction or the other. This is
	   the Tuckerman test.

	   Exhaustive testing has shown that we need only 3 adjustments
	   in the flt32 case (and it should be 4 in the flt64 case).
	 */
	/* put back half of the even exponent that was removed above */
	(_, e, s) = d.explode(gn)
	e += (e2 / 2)
	var r : @f = d.assem(false, e, s)

	for var j = 0; j < d.iterlim; ++j
		/* neighbours of r, one step away in the raw bit pattern */
		var r_plus_ulp : @f = d.frombits(d.tobits(r) + 1)
		var r_minus_ulp : @f = d.frombits(d.tobits(r) - 1)

		/* r * (r - ulp) - x with its sign bit clear: r is too big */
		var delta_1 = d.fma(r, r_minus_ulp, -1.0 * x)
		if d.tobits(delta_1) & d.sgnmask == 0
			r = r_minus_ulp
		else
			/* r * (r + ulp) - x with its sign bit set: r is too small */
			var delta_2 = d.fma(r, r_plus_ulp, -1.0 * x)
			if d.tobits(delta_2) & d.sgnmask != 0
				r = r_plus_ulp
			else
				-> r
			;;
		;;
	;;

	-> r
}