shithub: mc

ref: 34f2230c4a505f3b94bc33ed07f0839fe66a0e93
dir: /lib/math/pown-impl.myr/

View raw version
use std

use "fpmath"
use "log-impl"
use "log-overkill"
use "sum-impl"
use "util"

/*
   This file implements pown (x^n, where n is an integer) and rootn
   (x^(1/q), where q is an unsigned integer). We loosely follow [PEB04],
   but without their high-radix log_2; instead, we use log-overkill,
   which should be accurate enough.
 */
pkg math =
	/* Single precision: x^n for integer n, and the q-th root x^(1/q) */
	pkglocal const pown32 : (x : flt32, n : int32 -> flt32)
	pkglocal const rootn32 : (x : flt32, q : uint32 -> flt32)

	/* Double precision: same pair of operations */
	pkglocal const pown64 : (x : flt64, n : int64 -> flt64)
	pkglocal const rootn64 : (x : flt64, q : uint64 -> flt64)
;;

/*
   Format-specific parameters for the generic pown/rootn code.
   @f is the float type, @u the same-width unsigned integer (raw bits and
   significand), and @i the signed integer used for exponents (and for n
   in powngen). See desc32/desc64 for the two instantiations.
 */
type fltdesc(@f, @u, @i) = struct
	/* Split a float into (sign, exponent, significand) */
	explode : (f : @f -> (bool, @i, @u))
	/* Build a float from (sign, exponent, significand) in explode's convention */
	assem : (n : bool, e : @i, s : @u -> @f)
	/* Raw bit pattern conversions */
	tobits : (f : @f -> @u)
	frombits : (u : @u -> @f)
	/*
	   (hi, lo) bit-pattern pairs from accurate_logs; entry 128 is read
	   as log2_hi/log2_lo and used to convert base-2 exponents to base e,
	   i.e. presumably ln(2) split in two parts (see log-impl.myr).
	 */
	C : (@u, @u)[:]
	/* 1/ln(2) split into a high and a low part, as bit patterns */
	one_over_ln2_hi : @u
	one_over_ln2_lo : @u
	/* Bit patterns for NaN, +oo and -oo */
	nan : @u
	inf : @u
	neginf : @u
	/* Comparator used to sort summands before compensated summation */
	magcmp : (f : @f, g : @f -> std.order)
	/* Product of x and y as a (hi, lo) pair -- presumably exact; see its definition */
	two_by_two : (x : @f, y : @f -> (@f, @f))
	/* ln(x) as a (hi, lo) pair, used for the significand's logarithm */
	log_overkill : (x : @f -> (@f, @f))
	/* Exponent range of normal numbers */
	emin : @i
	emax : @i
	/* Range of @i, for detecting overflow in the final exponent */
	imax : @i
	imin : @i
;;

/*
   fltdesc instance for flt32 (IEEE binary32). Constants are stored as raw
   bit patterns and converted with frombits at the point of use.
 */
const desc32 : fltdesc(flt32, uint32, int32) =  [
	.explode = std.flt32explode,
	.assem = std.flt32assem,
	.tobits = std.flt32bits,
	.frombits = std.flt32frombits,
	.C = accurate_logs32[0:130], /* See log-impl.myr */
	.one_over_ln2_hi = 0x3fb8aa3b, /* 1/ln(2), top part */
	.one_over_ln2_lo = 0x32a57060, /* 1/ln(2), bottom part */
	.nan = 0x7fc00000, /* quiet NaN */
	.inf = 0x7f800000, /* +oo */
	.neginf = 0xff800000, /* -oo */
	.magcmp = mag_cmp32,
	.two_by_two = two_by_two32,
	.log_overkill = logoverkill32,
	.emin = -126, /* smallest normal exponent */
	.emax = 127, /* largest finite exponent */
	.imax = 2147483647, /* For detecting overflow in final exponent */
	.imin = -2147483648,
]

/*
   fltdesc instance for flt64 (IEEE binary64). Constants are stored as raw
   bit patterns and converted with frombits at the point of use.
 */
const desc64 : fltdesc(flt64, uint64, int64) =  [
	.explode = std.flt64explode,
	.assem = std.flt64assem,
	.tobits = std.flt64bits,
	.frombits = std.flt64frombits,
	.C = accurate_logs64[0:130], /* See log-impl.myr */
	.one_over_ln2_hi = 0x3ff71547652b82fe, /* 1/ln(2), top part */
	.one_over_ln2_lo = 0x3c7777d0ffda0d24, /* 1/ln(2), bottom part */
	.nan = 0x7ff8000000000000, /* quiet NaN */
	.inf = 0x7ff0000000000000, /* +oo */
	.neginf = 0xfff0000000000000, /* -oo */
	.magcmp = mag_cmp64,
	.two_by_two = two_by_two64,
	.log_overkill = logoverkill64,
	.emin = -1022, /* smallest normal exponent */
	.emax = 1023, /* largest finite exponent */
	.imax = 9223372036854775807, /* For detecting overflow in final exponent */
	.imin = -9223372036854775808,
]

/* Single-precision x^n; the algorithm lives in powngen. */
const pown32 = {x : flt32, n : int32 -> flt32
	var result : flt32 = powngen(x, n, desc32)
	-> result
}

/* Double-precision x^n; the algorithm lives in powngen. */
const pown64 = {x : flt64, n : int64 -> flt64
	var result : flt64 = powngen(x, n, desc64)
	-> result
}

/*
   Shared implementation of pown32/pown64: returns x^n for an integer n,
   computed as 2^(n * log2(x)) with extra-precision intermediates. d
   supplies the format-specific helpers and constants.

   Special values (IEEE exceptions are not raised):
     pown(x, 0)    = 1 for every x, including NaN
     pown(NaN, n)  = NaN for n != 0
     pown(+-0, n)  = -oo for -0 and odd n < 0, otherwise +oo for n < 0;
                     +-0 for odd n > 0, +0 for even n > 0
     pown(x, 1)    = x
     pown(+-oo, n) = -oo for -oo and odd n > 0, otherwise +oo for n > 0;
                     -0 for -oo and odd n < 0, otherwise +0 for n < 0
 */
generic powngen = {x : @f, n : @i, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i
	var xb
	xb = d.tobits(x)

	var xn : bool, xe : @i, xs : @u
	(xn, xe, xs) = d.explode(x)

	var nf : @f = (n : @f)

	/*
	   Special cases. Note we do not follow IEEE exceptions.
	 */
	if n == 0
		/*
		   Anything^0 is 1. We're taking the view that x is a tiny range of reals,
		   so a dense subset of them are 1, even if x is 0.0.
		 */
		-> 1.0
	elif std.isnan(x)
		/* Propagate NaN (why doesn't this come first? Ask IEEE.) */
		-> d.frombits(d.nan)
	elif x == 0.0
		/* -0.0 compares equal to 0.0, so this covers both; xn holds the sign */
		if n < 0 && (n % 2 == -1) && xn
			/* (-0)^n = -oo (n odd, negative) */
			-> d.frombits(d.neginf)
		elif n < 0
			-> d.frombits(d.inf)
		elif n % 2 == 1
			/* (+/- 0)^n = +/- 0 (n odd) */
			-> d.assem(xn, d.emin - 1, 0)
		else
			-> 0.0
		;;
	elif n == 1
		/* Anything^1 is itself */
		-> x
	elif xb == d.inf || xb == d.neginf
		/*
		   Handle +/- oo explicitly instead of relying on the exponent
		   arithmetic below to overflow or underflow: with n == -1 that
		   path can land on a tiny nonzero value rather than 0.
		 */
		if n > 0 && xn && n % 2 == 1
			-> d.frombits(d.neginf)
		elif n > 0
			-> d.frombits(d.inf)
		elif xn && n % 2 == -1
			/* n < 0 here: (-oo)^n = -0 for odd n */
			-> d.assem(true, d.emin - 1, 0)
		else
			-> 0.0
		;;
	;;

	/*
	   Subnormal x: assem(false, 0, xs) below needs a normalized
	   significand. Scale by 2^64 (exact, and enough to normalize every
	   flt32 and flt64 subnormal), explode that, and undo the scaling on
	   the exponent.
	 */
	var tiny = d.assem(false, d.emin, 0)
	if x < tiny && x > -tiny
		(xn, xe, xs) = d.explode(x * 18446744073709551616.0)
		xe -= 64
	;;

	/* (-f)^n = (-1)^n * (f)^n. Figure this out now, then pretend f >= 0.0 */
	var ult_sgn = 1.0
	if xn && (n % 2 == 1 || n % 2 == -1)
		ult_sgn = -1.0
	;;

	/*
	   Compute (with x = xs * 2^e)

	     x^n  = 2^(n*log2(xs)) * 2^(n*e)

	          = 2^(I + F) * 2^(n*e)
	          = 2^(F) * 2^(I+n*e)

	   Since n, e, and I are all integers, we can get the last part from
	   scale2. The hard part is computing I and F, and then computing 2^F.
	 */
	var ln_xs_hi, ln_xs_lo
	(ln_xs_hi, ln_xs_lo) = d.log_overkill(d.assem(false, 0, xs))

	/* Now x^n = 2^(n * [ ln_xs / ln(2) ]) * 2^(n * e) */

	var ls1 : @f[8]
	(ls1[0], ls1[1]) = d.two_by_two(ln_xs_hi, d.frombits(d.one_over_ln2_hi))
	(ls1[2], ls1[3]) = d.two_by_two(ln_xs_hi, d.frombits(d.one_over_ln2_lo))
	(ls1[4], ls1[5]) = d.two_by_two(ln_xs_lo, d.frombits(d.one_over_ln2_hi))
	(ls1[6], ls1[7]) = d.two_by_two(ln_xs_lo, d.frombits(d.one_over_ln2_lo))

	/*
	   Now log2(xs) = Sum(ls1), so

	     x^n = 2^(n * Sum(ls1)) * 2^(n * e)
	 */
	var E1, E2
	(E1, E2) = double_compensated_sum(ls1[0:8])
	var ls2 : @f[5]
	var ls2s : @f[5]
	var I = 0
	(ls2[0], ls2[1]) = d.two_by_two(E1, nf)
	(ls2[2], ls2[3]) = d.two_by_two(E2, nf)
	ls2[4] = 0.0

	/* Now x^n = 2^(Sum(ls2)) * 2^(n * e) */

	/* Move the rounded integer part (via rn) of each term into I */
	for var j = 0; j < 5; ++j
		var i = rn(ls2[j])
		I += i
		ls2[j] -= (i : @f)
	;;

	var F1, F2
	std.slcp(ls2s[0:5], ls2[0:5])
	std.sort(ls2s[0:5], d.magcmp)
	(F1, F2) = double_compensated_sum(ls2s[0:5])

	/* The leftover fractions may still sum outside [0, 1]; fold that into I too */
	if (F1 < 0.0 || F1 > 1.0)
		var i = rn(F1)
		I += i
		ls2[4] -= (i : @f)
		std.slcp(ls2s[0:5], ls2[0:5])
		std.sort(ls2s[0:5], d.magcmp)
		(F1, F2) = double_compensated_sum(ls2s[0:5])
	;;

	/* Now, x^n = 2^(F1 + F2) * 2^(I + n*e). */
	var ls3 : @f[6]
	var log2_hi, log2_lo
	(log2_hi, log2_lo) = d.C[128]
	(ls3[0], ls3[1]) = d.two_by_two(F1, d.frombits(log2_hi))
	(ls3[2], ls3[3]) = d.two_by_two(F1, d.frombits(log2_lo))
	(ls3[4], ls3[5]) = d.two_by_two(F2, d.frombits(log2_hi))
	var G1, G2
	(G1, G2) = double_compensated_sum(ls3[0:6])

	/*
	   NOTE(review): exp(G1 + G2) ~= exp(G1) * (1 + G2), so the first-order
	   correction is exp(G1) * G2; adding G2 unscaled is off by less than
	   |G2|. Confirm whether that sub-ulp difference is intended.
	 */
	var base = exp(G1) + G2
	var pow_xen = xe * n
	var pow = pow_xen + I
	if pow_xen / n != xe || (I > 0 && d.imax - I < pow_xen) || (I < 0 && d.imin - I > pow_xen)
		/*
		   The exponent overflowed. There's no way this is representable. We need
		   to at least recover the correct sign. If the overflow was from the
		   multiplication, then the sign we want is the sign that pow_xen should
		   have been. If the overflow was from the addition, then we still want
		   the sign that pow_xen should have had.
		 */
		if (xe > 0) == (n > 0)
			pow = 2 * d.emax
		else
			pow = 2 * d.emin
		;;
	;;

	-> ult_sgn * scale2(base, pow)
}

/*
   Rootn differs from pown just enough to justify splitting it out into
   a separate function.
 */
/* Single-precision q-th root of x; the algorithm lives in rootngen. */
const rootn32 = {x : flt32, q : uint32 -> flt32
	var result : flt32 = rootngen(x, q, desc32)
	-> result
}

/* Double-precision q-th root of x; the algorithm lives in rootngen. */
const rootn64 = {x : flt64, q : uint64 -> flt64
	var result : flt64 = rootngen(x, q, desc64)
	-> result
}

/*
   Shared implementation of rootn32/rootn64: returns x^(1/q), the q-th
   root of x, for an unsigned q. d supplies the format-specific helpers
   and constants.

   Special values (IEEE exceptions are not raised):
     rootn(x, 0)    = NaN (x^(1/0) is undefined)
     rootn(NaN, q)  = NaN
     rootn(+-0, q)  = +-0 for odd q, +0 for even q
     rootn(x, 1)    = x
     rootn(x, q)    = NaN for x < 0 (including -oo) and even q
     rootn(+-oo, q) = +-oo otherwise
 */
generic rootngen = {x : @f, q : @u, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i
	var xb
	xb = d.tobits(x)

	var xn : bool, xe : @i, xs : @u
	(xn, xe, xs) = d.explode(x)

	var qf : @f = (q : @f)

	/*
	   Special cases. Note we do not follow IEEE exceptions.
	 */
	if q == 0 || std.isnan(x)
		/* The zeroth root is undefined; NaN propagates */
		-> d.frombits(d.nan)
	elif x == 0.0
		/*
		   -0.0 compares equal to 0.0, so this covers both. The q-th root
		   of zero is zero: keep the sign for odd q, +0 for even q.
		 */
		if xn && q % 2 == 1
			-> x
		else
			-> 0.0
		;;
	elif q == 1
		/* Anything^1/1 is itself */
		-> x
	elif xn && q % 2 == 0
		/* An even root of a negative number (including -oo) is not real */
		-> d.frombits(d.nan)
	elif xb == d.inf || xb == d.neginf
		/* (+oo)^(1/q) = +oo, and (-oo)^(1/q) = -oo for odd q */
		-> x
	;;

	/*
	   Subnormal x: assem(false, 0, xs) below needs a normalized
	   significand. Scale by 2^64 (exact, and enough to normalize every
	   flt32 and flt64 subnormal), explode that, and undo the scaling on
	   the exponent.
	 */
	var tiny = d.assem(false, d.emin, 0)
	if x < tiny && x > -tiny
		(xn, xe, xs) = d.explode(x * 18446744073709551616.0)
		xe -= 64
	;;

	/* As in pown. Negative x with even q was rejected above. */
	var ult_sgn = 1.0
	if xn && (q % 2 == 1)
		ult_sgn = -1.0
	;;

	/* Similar to pown. Let e/q = E + psi, with E an integer.

	   x^(1/q) = e^(log(xs)/q) * 2^(e/q)
	           = e^(log(xs)/q) * 2^(psi) * 2^E
	           = e^(log(xs)/q) * e^(log(2) * psi) * 2^E
	           = e^( log(xs)/q  +  log(2) * psi ) * 2^E

	   I've opted to do things just in terms of natural base here because we
	   don't have an integer part, I, that we can slide over in infinite
	   precision.
	 */

	/* Calculate 1/q in very high precision */
	var r1 = 1.0 / qf
	var r2 = -math.fma(r1, qf, -1.0) / qf
	var ln_xs_hi, ln_xs_lo
	(ln_xs_hi, ln_xs_lo) = d.log_overkill(d.assem(false, 0, xs))
	var ls1 : @f[12]
	(ls1[0], ls1[1]) = d.two_by_two(ln_xs_hi, r1)
	(ls1[2], ls1[3]) = d.two_by_two(ln_xs_hi, r2)
	(ls1[4], ls1[5]) = d.two_by_two(ln_xs_lo, r1)

	var E : @i
	if q > std.abs(xe)
		/* Don't cast q to @i unless we're sure it's in small range */
		E = 0
	else
		E = xe / (q : @i)
	;;
	var qpsi = xe - q * E
	var psi_hi = (qpsi : @f) / qf
	var psi_lo = -math.fma(psi_hi, qf, -(qpsi : @f)) / qf
	var log2_hi, log2_lo
	(log2_hi, log2_lo) = d.C[128]
	(ls1[ 6], ls1[ 7]) = d.two_by_two(psi_hi, d.frombits(log2_hi))
	(ls1[ 8], ls1[ 9]) = d.two_by_two(psi_hi, d.frombits(log2_lo))
	(ls1[10], ls1[11]) = d.two_by_two(psi_lo, d.frombits(log2_hi))

	var G1, G2
	(G1, G2) = double_compensated_sum(ls1[0:12])
	/* G1 + G2 approximates log(xs)/q + log(2)*psi */

	/*
	   NOTE(review): exp(G1 + G2) ~= exp(G1) * (1 + G2); adding G2 unscaled
	   drops the exp(G1) factor on the correction term (a sub-ulp effect).
	   Confirm this is intended.
	 */
	var base = exp(G1) + G2

	-> ult_sgn * scale2(base, E)
}