shithub: mc

Download patch

ref: 8c252ab901a3375d608cb34619cafe9ac342bb36
parent: 9aabbe574e648e0a462b211c62da5eb5875343b7
author: S. Gilles <sgilles@math.umd.edu>
date: Wed Apr 18 04:31:57 EDT 2018

Add round-to-nearest-integer.

As distinct from trunc/floor/ceil, these obey the current rounding
mode (round to nearest) and return actual integers.

--- a/lib/math/bld.sub
+++ b/lib/math/bld.sub
@@ -1,6 +1,10 @@
 lib math =
 	fpmath.myr
 
+	# rounding (to actual integers)
+	round-impl+posixy-x64-sse4.s
+	round-impl.myr
+
 	# fused-multiply-add
 	fma-impl+posixy-x64-fma.s
 	fma-impl.myr
@@ -9,12 +13,12 @@
 	sqrt-impl+posixy-x64-sse2.s
 	sqrt-impl.myr
 
+	# summation
+	sum-impl.myr
+
 	# trunc, floor, ceil
 	trunc-impl+posixy-x64-sse4.s
 	trunc-impl.myr
-
-	# summation
-	sum-impl.myr
 
 	# util
 	util.myr
--- a/lib/math/fpmath.myr
+++ b/lib/math/fpmath.myr
@@ -9,18 +9,25 @@
 		/* sqrt-impl */
 		sqrt : (f : @f -> @f)
 
+		/* sum-impl */
+		kahan_sum : (a : @f[:] -> @f)
+		priest_sum : (a : @f[:] -> @f)
+
 		/* trunc-impl */
 		trunc : (f : @f -> @f)
 		ceil  : (f : @f -> @f)
 		floor : (f : @f -> @f)
+	;;
 
-		/* sum-impl */
-		kahan_sum : (a : @f[:] -> @f)
-		priest_sum : (a : @f[:] -> @f)
+	trait roundable @f -> @i =
+		/* round-impl: rn rounds the float f to the nearest integer and returns it as the integer type @i */
+		rn : (f : @f -> @i)
 	;;
 
 	impl std.equatable flt32
 	impl std.equatable flt64
+	impl roundable flt64 -> int64
+	impl roundable flt32 -> int32
 	impl fpmath flt32
 	impl fpmath flt64
 ;;
@@ -41,17 +48,26 @@
 	eq = {a : flt64, b : flt64; -> std.flt64bits(a) == std.flt64bits(b)}
 ;;
 
+impl roundable flt32 -> int32 =
+	rn = {f : flt32; -> rn32(f) }	/* forwards to rn32, declared extern further down in this file */
+;;
+
+impl roundable flt64 -> int64 =
+	rn = {f : flt64; -> rn64(f) }	/* forwards to rn64, declared extern further down in this file */
+;;
+
 impl fpmath flt32 =
 	fma = {x, y, z; -> fma32(x, y, z)}
 
 	sqrt = {f; -> sqrt32(f)}
 
+	kahan_sum = {l; -> kahan_sum32(l) }
+	priest_sum = {l; -> priest_sum32(l) }
+
 	trunc = {f; -> trunc32(f)}
 	floor = {f; -> floor32(f)}
 	ceil  = {f; -> ceil32(f)}
 
-	kahan_sum = {l; -> kahan_sum32(l) }
-	priest_sum = {l; -> priest_sum32(l) }
 ;;
 
 impl fpmath flt64 =
@@ -59,14 +75,17 @@
 
 	sqrt = {f; -> sqrt64(f)}
 
+	kahan_sum = {l; -> kahan_sum64(l) }
+	priest_sum = {l; -> priest_sum64(l) }
+
 	trunc = {f; -> trunc64(f)}
 	floor = {f; -> floor64(f)}
 	ceil  = {f; -> ceil64(f)}
-
-	kahan_sum = {l; -> kahan_sum64(l) }
-	priest_sum = {l; -> priest_sum64(l) }
 ;;
 
+extern const rn32 : (f : flt32 -> int32)	/* provided by round-impl (the .s or the .myr file listed in bld.sub) */
+extern const rn64 : (f : flt64 -> int64)	/* provided by round-impl (the .s or the .myr file listed in bld.sub) */
+
 extern const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
 extern const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
 
@@ -73,6 +92,12 @@
 extern const sqrt32 : (x : flt32 -> flt32)
 extern const sqrt64 : (x : flt64 -> flt64)
 
+extern const kahan_sum32 : (l : flt32[:] -> flt32)
+extern const kahan_sum64 : (l : flt64[:] -> flt64)
+
+extern const priest_sum32 : (l : flt32[:] -> flt32)
+extern const priest_sum64 : (l : flt64[:] -> flt64)
+
 extern const trunc32 : (f : flt32 -> flt32)
 extern const trunc64 : (f : flt64 -> flt64)
 
@@ -81,9 +106,3 @@
 
 extern const ceil32  : (f : flt32 -> flt32)
 extern const ceil64  : (f : flt64 -> flt64)
-
-extern const kahan_sum32 : (l : flt32[:] -> flt32)
-extern const kahan_sum64 : (l : flt64[:] -> flt64)
-
-extern const priest_sum32 : (l : flt32[:] -> flt32)
-extern const priest_sum64 : (l : flt64[:] -> flt64)
--- /dev/null
+++ b/lib/math/round-impl+posixy-x64-sse4.s
@@ -1,0 +1,14 @@
+.globl math$rn32	/* math.rn32 : (f : flt32 -> int32) */
+.globl math$_rn32	/* same entry point under an underscore-prefixed name; presumably for platforms that mangle symbols that way -- confirm */
+math$rn32:
+math$_rn32:
+	cvtss2si	%xmm0, %eax	/* convert the flt32 in xmm0 to an int32 in eax, honouring the current rounding mode */
+	ret
+
+.globl math$rn64	/* math.rn64 : (f : flt64 -> int64) */
+.globl math$_rn64	/* same entry point under an underscore-prefixed name; presumably for platforms that mangle symbols that way -- confirm */
+math$rn64:
+math$_rn64:
+	cvtsd2si	%xmm0, %rax	/* convert the flt64 in xmm0 to an int64 in rax, honouring the current rounding mode */
+	ret
+
--- /dev/null
+++ b/lib/math/round-impl.myr
@@ -1,0 +1,70 @@
+use std
+
+use "util"
+
+pkg math =
+	const rn64 : (f : flt64 -> int64)	/* round a flt64 to the nearest int64 */
+	const rn32 : (f : flt32 -> int32)	/* round a flt32 to the nearest int32 */
+;;
+
+const rn64 = {f : flt64	/* round f to the nearest int64; need_round_away decides whether the truncated value is bumped */
+	var n : bool, e : int64, s : uint64
+	/* n is used below as the "negative" flag; e and s are the exponent and significand from std.flt64explode */
+	(n, e, s) = std.flt64explode(f)
+	/* NOTE(review): the thresholds below assume e is unbiased and s has its leading 1 at bit 52 -- confirm against std */
+	if e >= 63
+		-> -9223372036854775808	/* magnitude does not fit: return the int64 minimum regardless of sign */
+	elif e >= 52
+		var shifted : int64 = (( s << (e - 52 : uint64)) : int64)	/* no fractional bits remain; only scale up */
+		if n
+			-> -1 * shifted
+		else
+			-> shifted
+		;;
+	elif e < -1
+		-> 0	/* presumably a magnitude below 1/2, which rounds to zero */
+	;;
+	/* discard the low 52 - e bits of s, leaving the truncated magnitude */
+	var integral_s = (s >> (52 - e : uint64) : int64)
+	/* need_round_away comes from "util" (not shown here); presumably it inspects the discarded bits */
+	if need_round_away(0, s, 52 - e)
+		integral_s++
+	;;
+	/* reapply the sign to the rounded magnitude */
+	if n
+		-> -integral_s
+	else
+		-> integral_s
+	;;
+}
+
+const rn32 = {f : flt32	/* round f to the nearest int32; need_round_away decides whether the truncated value is bumped */
+	var n : bool, e : int32, s : uint32
+	/* n is used below as the "negative" flag; e and s are the exponent and significand from std.flt32explode */
+	(n, e, s) = std.flt32explode(f)
+	/* NOTE(review): the thresholds below assume e is unbiased and s has its leading 1 at bit 23 -- confirm against std */
+	if e >= 31
+		-> -2147483648	/* magnitude does not fit: return the int32 minimum regardless of sign */
+	elif e >= 23
+		var shifted : int32 = (( s << (e - 23 : uint32)) : int32)	/* no fractional bits remain; only scale up */
+		if n
+			->  -1 * shifted
+		else
+			->  shifted
+		;;
+	elif e < -1
+		-> 0	/* presumably a magnitude below 1/2, which rounds to zero */
+	;;
+	/* discard the low 23 - e bits of s, leaving the truncated magnitude */
+	var integral_s = (s >> (23 - e : uint32) : int32)
+	/* need_round_away comes from "util" (not shown here); its arguments are widened to 64 bits for the call */
+	if need_round_away(0, (s : uint64), (23 - e : int64))
+		integral_s++
+	;;
+	/* reapply the sign to the rounded magnitude */
+	if n
+		-> -integral_s
+	else
+		-> integral_s
+	;;
+}
--- /dev/null
+++ b/lib/math/test/round-impl.myr
@@ -1,0 +1,82 @@
+use std
+use math
+use testr
+
+const main = {	/* hands both rounding test cases to testr.run */
+	testr.run([
+		[.name = "round-01",    .fn = round01],	/* flt32 -> int32 cases */
+		[.name = "round-02",    .fn = round02],	/* flt64 -> int64 cases */
+	][:])
+}
+
+const round01 = {c	/* checks math.rn on flt32 inputs against the expected int32 results */
+	var inputs : (flt32, int32)[:] = [
+		(123.4, 123),
+		(0.0, 0),
+		(-0.0, 0),	/* negative zero is expected to give plain 0 */
+		(1.0, 1),
+		(1.1, 1),
+		(0.9, 1),
+		(15.3, 15),
+		(15.5, 16),	/* halfway case: the even neighbour is expected */
+		(15.7, 16),
+		(16.3, 16),
+		(16.5, 16),	/* halfway case: the even neighbour is expected */
+		(16.7, 17),
+		(-102.1, -102),
+		(-102.5, -102),	/* halfway case: the even neighbour is expected */
+		(-102.7, -103),
+		(-103.1, -103),
+		(-103.5, -104),	/* halfway case: the even neighbour is expected */
+		(-103.7, -104),
+		(2147483641.5, -2147483648),	/* NOTE(review): from here on the literals are presumably stored as +/-2^31 in flt32; the int32 minimum is expected for all of them */
+		(2147483646.5, -2147483648),
+		(2147483647.5, -2147483648),
+		(2147483649.0, -2147483648),
+		(-2147483641.5, -2147483648),
+		(-2147483646.5, -2147483648),
+		(-2147483647.5, -2147483648),
+		(-2147483649.0, -2147483648),
+	][:]
+	/* f is the input, g the expected result */
+	for (f, g) : inputs
+		testr.eq(c, math.rn(f), g)
+	;;
+}
+
+const round02 = {c	/* checks math.rn on flt64 inputs against the expected int64 results */
+	var inputs : (flt64, int64)[:] = [
+		(123.4, 123),
+		(0.0, 0),
+		(-0.0, 0),	/* negative zero is expected to give plain 0 */
+		(1.0, 1),
+		(1.1, 1),
+		(0.9, 1),
+		(15.3, 15),
+		(15.5, 16),	/* halfway case: the even neighbour is expected */
+		(15.7, 16),
+		(16.3, 16),
+		(16.5, 16),	/* halfway case: the even neighbour is expected */
+		(16.7, 17),
+		(-102.1, -102),
+		(-102.5, -102),	/* halfway case: the even neighbour is expected */
+		(-102.7, -103),
+		(-103.1, -103),
+		(-103.5, -104),	/* halfway case: the even neighbour is expected */
+		(-103.7, -104),
+		(2147483641.5, 2147483642),	/* from here on: magnitudes around the int32 limit, which still round normally into int64 */
+		(2147483646.5, 2147483646),
+		(2147483647.5, 2147483648),
+		(2147483649.0, 2147483649),
+		(-2147483641.5, -2147483642),
+		(-2147483646.5, -2147483646),
+		(-2147483647.5, -2147483648),
+		(-2147483649.0, -2147483649),
+		(9223372036854775806.1, -9223372036854775808),	/* NOTE(review): presumably stored as 2^63, which is out of range; the int64 minimum is expected */
+		(-9223372036854775806.1, -9223372036854775808),
+	][:]
+	/* f is the input, g the expected result */
+	for (f, g) : inputs
+		testr.eq(c, math.rn(f), g)
+	;;
+}