shithub: mc

Download patch

ref: f43021ba4fcf39c9ebd718bb3e594f063680e584
parent: 599feae2daa1392f20e1dc807c58bcd74c20b45e
author: S. Gilles <sgilles@math.umd.edu>
date: Tue Mar 13 06:54:41 EDT 2018

Implement Kahan and Priest summation

Sparse testing for now

--- a/lib/math/bld.sub
+++ b/lib/math/bld.sub
@@ -4,5 +4,8 @@
 	# trunc
 	fpmath-trunc-impl.myr
 
+	# summation
+	fpmath-sum-impl.myr
+
 	lib ../std:std
 ;;
--- /dev/null
+++ b/lib/math/fpmath-sum-impl.myr
@@ -1,0 +1,124 @@
+use std
+
+pkg math =
+	pkglocal const kahan_sum32 : (l : flt32[:] -> flt32)
+	pkglocal const priest_sum32 : (l : flt32[:] -> flt32)
+
+	pkglocal const kahan_sum64 : (l : flt64[:] -> flt64)
+	pkglocal const priest_sum64 : (l : flt64[:] -> flt64)
+;;
+
+/*
+   Slices that free themselves: wrapping a scratch slice in one
+   of these types and passing it to `auto` releases it with
+   std.slfree when the enclosing scope exits.
+ */
+type doomed_flt32_arr = flt32[:]
+type doomed_flt64_arr = flt64[:]
+
+impl disposable doomed_flt32_arr =
+	__dispose__ = {a : doomed_flt32_arr; std.slfree((a : flt32[:])) }
+;;
+
+impl disposable doomed_flt64_arr =
+	__dispose__ = {a : doomed_flt64_arr; std.slfree((a : flt64[:])) }
+;;
+
+/*
+   Kahan's compensated summation. Fast and reasonably accurate,
+   although cancellation can cause relative error blowup. For
+   something slower, but more accurate, use something like Priest's
+   doubly compensated sums.
+
+   The empty slice sums to zero.
+ */
+pkglocal const kahan_sum32 = {l; -> kahan_sum_gen(l, (0.0 : flt32))}
+pkglocal const kahan_sum64 = {l; -> kahan_sum_gen(l, (0.0 : flt64))}
+
+generic kahan_sum_gen = {l : @f[:], zero : @f :: numeric,floating @f
+	/* running sum */
+	var s = zero
+	/* excess added by the previous step's rounding; removed from the next term */
+	var c = zero
+
+	for x : l
+		var y = x - c
+		var t = s + y
+		/* (t - s) is what was actually added; y is what was meant to be */
+		c = (t - s) - y
+		s = t
+	;;
+
+	-> s
+}
+
+/*
+   Priest's doubly compensated summation. Extremely accurate, but
+   relatively slow. For situations in which cancellation is not
+   expected, something like Kahan's compensated summation may be
+   more useful.
+
+   The algorithm needs its input ordered by decreasing magnitude,
+   so a sorted copy is summed and the caller's slice is left
+   untouched. The empty slice sums to zero.
+ */
+pkglocal const priest_sum32 = {l : flt32[:]
+	var l2 = std.sldup(l)
+	std.sort(l2, mag_cmp32)
+	auto (l2 : doomed_flt32_arr)
+	-> priest_sum_gen(l2, (0.0 : flt32))
+}
+
+/*
+   Orders floats by decreasing magnitude. With the sign bit
+   cleared, the unsigned order of the bit patterns matches the
+   order of the absolute values (NaNs compare larger than
+   infinity), so no floating-point comparison is needed.
+ */
+const mag_cmp32 = {f : flt32, g : flt32
+	var u = std.flt32bits(f) & ~(1 << 31)
+	var v = std.flt32bits(g) & ~(1 << 31)
+	/* arguments swapped so that larger magnitudes sort first */
+	-> std.numcmp(v, u)
+}
+
+pkglocal const priest_sum64 = {l : flt64[:]
+	var l2 = std.sldup(l)
+	/* sort the copy: sorting l would reorder the caller's data and leave l2 unsorted */
+	std.sort(l2, mag_cmp64)
+	auto (l2 : doomed_flt64_arr)
+	-> priest_sum_gen(l2, (0.0 : flt64))
+}
+
+/* flt64 counterpart of mag_cmp32: orders by decreasing magnitude */
+const mag_cmp64 = {f : flt64, g : flt64
+	var u = std.flt64bits(f) & ~(1 << 63)
+	var v = std.flt64bits(g) & ~(1 << 63)
+	-> std.numcmp(v, u)
+}
+
+/*
+   The doubly compensated summation loop. l must already be
+   ordered by decreasing magnitude.
+ */
+generic priest_sum_gen = {l : @f[:], zero : @f :: numeric,floating @f
+	/* running sum and its correction term */
+	var s = zero
+	var c = zero
+
+	for x : l
+		/* add x to the correction; u estimates the part of x lost to rounding */
+		var y = c + x
+		var u = x - (y - c)
+		/* add y to the sum; v estimates the part of y lost to rounding */
+		var t = (y + s)
+		var v = (y - (t - s))
+		/* fold both losses back in, then split into new sum and correction */
+		var z = u + v
+		s = t + z
+		c = z - (s - t)
+	;;
+
+	-> s
+}
+
--- a/lib/math/fpmath.myr
+++ b/lib/math/fpmath.myr
@@ -8,8 +8,9 @@
 		ceil  : (f : @f -> @f)
 		floor : (f : @f -> @f)
 
-		/* compute (s, t) with s = round-nearest(a+b), s + t = a + b */
-//		fast2sum : (a : @f, b : @f -> (@f, @f))
+		/* summation */
+		kahan_sum : (a : @f[:] -> @f)
+		priest_sum : (a : @f[:] -> @f)
 	;;
 
 	impl std.equatable flt32
@@ -38,6 +39,9 @@
 	trunc = {f; -> trunc32(f)}
 	floor = {f; -> floor32(f)}
 	ceil  = {f; -> ceil32(f)}
+
+	kahan_sum = {l; -> kahan_sum32(l) }
+	priest_sum = {l; -> priest_sum32(l) }
 ;;
 
 impl fpmath flt64 =
@@ -44,11 +48,21 @@
 	trunc = {f; -> trunc64(f)}
 	floor = {f; -> floor64(f)}
 	ceil  = {f; -> ceil64(f)}
+
+	kahan_sum = {l; -> kahan_sum64(l) }
+	priest_sum = {l; -> priest_sum64(l) }
 ;;
 
 extern const trunc32 : (f : flt32 -> flt32)
 extern const floor32 : (f : flt32 -> flt32)
 extern const ceil32  : (f : flt32 -> flt32)
+
 extern const trunc64 : (f : flt64 -> flt64)
 extern const floor64 : (f : flt64 -> flt64)
 extern const ceil64  : (f : flt64 -> flt64)
+
+extern const kahan_sum32  : (l : flt32[:] -> flt32)
+extern const priest_sum32 : (l : flt32[:] -> flt32)
+
+extern const kahan_sum64  : (l : flt64[:] -> flt64)
+extern const priest_sum64 : (l : flt64[:] -> flt64)
--- a/lib/math/test/fpmath-sum-impl.myr
+++ b/lib/math/test/fpmath-sum-impl.myr
@@ -4,152 +4,35 @@

 const main = {
 	testr.run([
-		[.name = "trunc-01",    .fn = trunc01],
-		[.name = "trunc-02",    .fn = trunc02],
-		[.name = "floor-01",    .fn = floor01],
-		[.name = "floor-02",    .fn = floor02],
-		[.name = "ceil-01",     .fn = ceil01],
-		[.name = "ceil-02",     .fn = ceil02],
-		[.name = "fast2sum-01", .fn = fast2sum01],
+		[.name = "sums-kahan-01",  .fn = sums_kahan_01],
+		[.name = "sums-priest-01", .fn = sums_priest_01],
 	][:])
 }

-const trunc01 = {c
-	var flt32s : (flt32, flt32)[:] = [
-		(0.0, 0.0),
-		(-0.0, -0.0),
-		(1.0, 1.0),
-		(1.1, 1.0),
-		(0.9, 0.0),
-		(10664524000000000000.0, 10664524000000000000.0),
-		(-3.5, -3.0),
-		(101.999, 101.0),
-		(std.flt32nan(), std.flt32nan()),
-	][:]
+const sums_kahan_01 = {c
+	var sums : (flt32[:], flt32)[:] = [
+		([1.0, 2.0, 3.0][:], 6.0),

-	for (f, g) : flt32s
-		testr.eq(c, math.trunc(f), g)
-	;;
-}
-
-const trunc02 = {c
-	var flt64s : (flt64, flt64)[:] = [
-		(0.0, 0.0),
-		(-0.0, -0.0),
-		(1.0, 1.0),
-		(1.1, 1.0),
-		(0.9, 0.0),
-		(13809453812721350000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.0, 13809453812721350000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.0),
-		(-3.5, -3.0),
-		(101.999, 101.0),
-		(std.flt64nan(), std.flt64nan()),
+		/* Exact sum is 2.0 and naive summing gives 1.0. Kahan's correction
+		   is lost when folded into a later term, so it returns 3.0 here. */
+		([33554432.0, 33554430.0, -16777215.0, -16777215.0, -16777215.0, -16777215.0][:], 3.0)
 	][:]

-	for (f, g) : flt64s
-		testr.eq(c, math.trunc(f), g)
+	for (a, r) : sums
+		var s = math.kahan_sum(a)
+		testr.eq(c, r, s)
 	;;
 }

-const floor01 = {c
-	var flt32s : (flt32, flt32)[:] = [
-		(0.0, 0.0),
-		(-0.0, -0.0),
-		(0.5, 0.0),
-		(1.1, 1.0),
-		(10664524000000000000.0, 10664524000000000000.0),
-		(-3.5, -4.0),
-		(-101.999, -102.0),
-		(std.flt32nan(), std.flt32nan()),
+const sums_priest_01 = {c
+	var sums : (flt32[:], flt32)[:] = [
+		([1.0, 2.0, 3.0][:], 6.0),
+		/* Exact sum is 2.0; doubly compensated summation recovers it */
+		([33554432.0, 33554430.0, -16777215.0, -16777215.0, -16777215.0, -16777215.0][:], 2.0)
 	][:]

-	for (f, g) : flt32s
-		testr.eq(c, math.floor(f), g)
+	for (a, r) : sums
+		var s = math.priest_sum(a)
+		testr.eq(c, r, s)
 	;;
-}
-
-const floor02 = {c
-	var flt64s : (flt64, flt64)[:] = [
-		(0.0, 0.0),
-		(-0.0, -0.0),
-		(0.5, 0.0),
-		(1.1, 1.0),
-		(13809453812721350000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.0, 13809453812721350000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.0),
-		(-3.5, -4.0),
-		(-101.999, -102.0),
-		(std.flt64nan(), std.flt64nan()),
-	][:]
-
-	for (f, g) : flt64s
-		testr.eq(c, math.floor(f), g)
-	;;
-}
-
-const ceil01 = {c
-	var flt32s : (flt32, flt32)[:] = [
-		(0.0, 0.0),
-		(-0.0, -0.0),
-		(0.5, 1.0),
-		(-0.1, -0.0),
-		(1.1, 2.0),
-		(10664524000000000000.0, 10664524000000000000.0),
-		(-3.5, -3.0),
-		(-101.999, -101.0),
-		(std.flt32nan(), std.flt32nan()),
-	][:]
-
-	for (f, g) : flt32s
-		testr.eq(c, math.ceil(f), g)
-	;;
-}
-
-const ceil02 = {c
-	var flt64s : (flt64, flt64)[:] = [
-		(0.0, 0.0),
-		(-0.0, -0.0),
-		(0.5, 1.0),
-		(-0.1, -0.0),
-		(1.1, 2.0),
-		(13809453812721350000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.0, 13809453812721350000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.0),
-		(-3.5, -3.0),
-		(-101.999, -101.0),
-		(std.flt64nan(), std.flt64nan()),
-	][:]
-
-	for (f, g) : flt64s
-		testr.eq(c, math.ceil(f), g)
-	;;
-}
-
-const fast2sum01 = {c
-        var flt32s : (flt32, flt32, flt32, flt32)[:] = [
-                (1.0, 1.0, 2.0, 0.0),
-                (10664524000000000000.0, 1.11842, 10664524000000000000.0, 1.11842),
-                (1.11843, 10664524000000000000.0, 10664524000000000000.0, 1.11843),
-                (-21897.1324, 17323.22, -4573.912, 0.0),
-        ][:]
-
-        for (a, b, s1, t1) : flt32s
-                var s2, t2
-                (s2, t2) = math.fast2sum(a, b)
-                testr.eq(c, s1, s2)
-                testr.eq(c, t1, t2)
-        ;;
-
-        var flt64s : (flt64, flt64, flt64, flt64)[:] = [
-                (1.0, 1.0, 2.0, 0.0),
-                (-21897.1324, 17323.22, -4573.912399999997, 0.0),
-                (std.flt64frombits(0x78591b0672a81284), std.flt64frombits(0x6a8c3190e27a1884),
-                 std.flt64frombits(0x78591b0672a81284), std.flt64frombits(0x6a8c3190e27a1884)),
-                (std.flt64frombits(0x6a8c3190e27a1884), std.flt64frombits(0x78591b0672a81284),
-                 std.flt64frombits(0x78591b0672a81284), std.flt64frombits(0x6a8c3190e27a1884)),
-                (std.flt64frombits(0x78591b0672a81284), std.flt64frombits(0x7858273672ca19a0),
-                 std.flt64frombits(0x7868a11e72b91612), 0.0),
-        ][:]
-
-        for (a, b, s1, t1) : flt64s
-                var s2, t2
-                (s2, t2) = math.fast2sum(a, b)
-                testr.eq(c, s1, s2)
-                testr.eq(c, std.flt64bits(t1), std.flt64bits(t2))
-        ;;
 }