ref: 8681ea7f41031e95eba880baf934d743cb558758
parent: d870a2fe07edafdc829a76af55792551a8fae394
parent: 24e80ae9a4bfea0c43a594af91819e8052fa6423
author: Lennart Augustsson <lennart@augustsson.net>
date: Sun Jan 19 03:46:14 EST 2025
Merge pull request #94 from konsumlamm/bitcount Add popcount, clz, ctz primitives
--- a/ghc/PrimTable.hs
+++ b/ghc/PrimTable.hs
@@ -62,6 +62,9 @@
, arithwi "shl" shiftL
, arithwi "shr" shiftR
, arith "ashr" shiftR
+ , arithu "popcount" popCount -- bit counting, implemented with the Data.Bits functions
+ , arithu "clz" countLeadingZeros
+ , arithu "ctz" countTrailingZeros
, cmp "==" (==)
, cmp "/=" (/=)
, cmp "<" (<)
--- a/lib/Data/Bits.hs
+++ b/lib/Data/Bits.hs
@@ -127,8 +127,10 @@
bitSize _ = _wordSize
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount = primIntPopcount -- runtime popcount primitive instead of the generic default
zeroBits = 0
instance FiniteBits Int where
finiteBitSize _ = _wordSize
+ countLeadingZeros = primIntClz -- runtime primitive; _wordSize for 0
+ countTrailingZeros = primIntCtz -- runtime primitive; _wordSize for 0
--- a/lib/Data/Int/Instances.hs
+++ b/lib/Data/Int/Instances.hs
@@ -108,11 +108,13 @@
bitSize _ = 8
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount (I8 x) = primIntPopcount (x .&. 0xff) -- count only the low 8 bits (negatives presumably carry sign-extension bits above them)
zeroBits = 0
instance FiniteBits Int8 where
finiteBitSize _ = 8
+ countLeadingZeros (I8 x) = primIntClz (x .&. 0xff) - (_wordSize - 8) -- drop the extra leading zeros of the word-sized primitive
+ countTrailingZeros (I8 x) = if x == 0 then 8 else primIntCtz x -- 0 special-cased: the primitive returns _wordSize for it
--------------------------------------------------------------------------------
---- Int16
@@ -202,11 +204,13 @@
bitSize _ = 16
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount (I16 x) = primIntPopcount (x .&. 0xffff) -- count only the low 16 bits (negatives presumably carry sign-extension bits above them)
zeroBits = 0
instance FiniteBits Int16 where
finiteBitSize _ = 16
+ countLeadingZeros (I16 x) = primIntClz (x .&. 0xffff) - (_wordSize - 16) -- drop the extra leading zeros of the word-sized primitive
+ countTrailingZeros (I16 x) = if x == 0 then 16 else primIntCtz x -- 0 special-cased: the primitive returns _wordSize for it
--------------------------------------------------------------------------------
---- Int32
@@ -296,11 +300,13 @@
bitSize _ = 32
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount (I32 x) = primIntPopcount (x .&. 0xffffffff) -- count only the low 32 bits (negatives presumably carry sign-extension bits above them)
zeroBits = 0
instance FiniteBits Int32 where
finiteBitSize _ = 32
+ countLeadingZeros (I32 x) = primIntClz (x .&. 0xffffffff) - (_wordSize - 32) -- drop the extra leading zeros of the word-sized primitive
+ countTrailingZeros (I32 x) = if x == 0 then 32 else primIntCtz x -- 0 special-cased: the primitive returns _wordSize for it
--------------------------------------------------------------------------------
---- Int64
@@ -389,8 +395,10 @@
bitSize _ = 64
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount (I64 x) = primIntPopcount x -- NOTE(review): relies on Int64 fitting one machine word (64-bit targets); confirm for 32-bit
zeroBits = 0
instance FiniteBits Int64 where
finiteBitSize _ = 64
+ countLeadingZeros = primIntClz . unI64 -- no offset applied, so this also needs a 64-bit word
+ countTrailingZeros = primIntCtz . unI64 -- gives _wordSize (64 on 64-bit targets) for 0
--- a/lib/Data/Word.hs
+++ b/lib/Data/Word.hs
@@ -97,11 +97,13 @@
bitSize _ = _wordSize
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount = primWordPopcount -- runtime popcount primitive instead of the generic default
zeroBits = 0
instance FiniteBits Word where
finiteBitSize _ = _wordSize
+ countLeadingZeros = primWordClz -- runtime primitive; _wordSize for 0
+ countTrailingZeros = primWordCtz -- runtime primitive; _wordSize for 0
--------------------------------------------------------------------------------
---- Word8
@@ -131,7 +133,7 @@
(*) = bin8 primWordMul
abs x = x
signum x = if x == 0 then 0 else 1
- fromInteger i = w8 (primIntToWord (_integerToInt i))
+ fromInteger i = w8 (_integerToWord i) -- convert the Integer straight to Word
instance Integral Word8 where
quot = bin8 primWordQuot
@@ -193,11 +195,13 @@
bitSize _ = 8
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount = primWordPopcount . unW8 -- no mask: the underlying Word presumably holds only the low 8 bits
zeroBits = 0
instance FiniteBits Word8 where
finiteBitSize _ = 8
+ countLeadingZeros (W8 x) = primWordClz x - (_wordSize - 8) -- drop the extra leading zeros of the word-sized primitive
+ countTrailingZeros (W8 x) = if x == 0 then 8 else primWordCtz x -- 0 special-cased: the primitive returns _wordSize for it
--------------------------------------------------------------------------------
---- Word16
@@ -227,7 +231,7 @@
(*) = bin16 primWordMul
abs x = x
signum x = if x == 0 then 0 else 1
- fromInteger i = w16 (primIntToWord (_integerToInt i))
+ fromInteger i = w16 (_integerToWord i) -- convert the Integer straight to Word
instance Integral Word16 where
quot = bin16 primWordQuot
@@ -288,11 +292,13 @@
bitSize _ = 16
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount = primWordPopcount . unW16 -- no mask: the underlying Word presumably holds only the low 16 bits
zeroBits = 0
instance FiniteBits Word16 where
finiteBitSize _ = 16
+ countLeadingZeros (W16 x) = primWordClz x - (_wordSize - 16) -- drop the extra leading zeros of the word-sized primitive
+ countTrailingZeros (W16 x) = if x == 0 then 16 else primWordCtz x -- 0 special-cased: the primitive returns _wordSize for it
--------------------------------------------------------------------------------
---- Word32
@@ -322,7 +328,7 @@
(*) = bin32 primWordMul
abs x = x
signum x = if x == 0 then 0 else 1
- fromInteger i = w32 (primIntToWord (_integerToInt i))
+ fromInteger i = w32 (_integerToWord i) -- convert the Integer straight to Word
instance Integral Word32 where
quot = bin32 primWordQuot
@@ -384,11 +390,13 @@
bitSize _ = 32
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount = primWordPopcount . unW32 -- no mask: the underlying Word presumably holds only the low 32 bits
zeroBits = 0
instance FiniteBits Word32 where
finiteBitSize _ = 32
+ countLeadingZeros (W32 x) = primWordClz x - (_wordSize - 32) -- drop the extra leading zeros of the word-sized primitive
+ countTrailingZeros (W32 x) = if x == 0 then 32 else primWordCtz x -- 0 special-cased: the primitive returns _wordSize for it
--------------------------------------------------------------------------------
---- Word64
@@ -418,7 +426,7 @@
(*) = bin64 primWordMul
abs x = x
signum x = if x == 0 then 0 else 1
- fromInteger i = w64 (primIntToWord (_integerToInt i))
+ fromInteger i = w64 (_integerToWord i) -- convert the Integer straight to Word
instance Integral Word64 where
quot = bin64 primWordQuot
@@ -480,8 +488,10 @@
bitSize _ = 64
bit = bitDefault
testBit = testBitDefault
- popCount = popCountDefault
+ popCount = primWordPopcount . unW64 -- NOTE(review): relies on Word64 fitting one machine word (64-bit targets); confirm for 32-bit
zeroBits = 0
instance FiniteBits Word64 where
finiteBitSize _ = 64
+ countLeadingZeros = primWordClz . unW64 -- no offset applied, so this also needs a 64-bit word
+ countTrailingZeros = primWordCtz . unW64 -- gives _wordSize (64 on 64-bit targets) for 0
--- a/lib/Primitives.hs
+++ b/lib/Primitives.hs
@@ -119,6 +119,12 @@
primWordAshr = primitive "ashr"
primWordInv :: Word -> Word
primWordInv = primitive "inv"
+primWordPopcount :: Word -> Int -- number of 1 bits
+primWordPopcount = primitive "popcount"
+primWordClz :: Word -> Int -- number of leading 0 bits; the word size for 0
+primWordClz = primitive "clz"
+primWordCtz :: Word -> Int -- number of trailing 0 bits; the word size for 0
+primWordCtz = primitive "ctz"
primWordToFloatWRaw :: Word -> FloatW
primWordToFloatWRaw = primitive "toDbl"
primWordFromFloatWRaw :: FloatW -> Word
@@ -136,6 +142,12 @@
primIntShr = primitive "ashr"
primIntInv :: Int -> Int
primIntInv = primitive "inv"
+primIntPopcount :: Int -> Int -- number of 1 bits of the raw word (so the word size for -1)
+primIntPopcount = primitive "popcount"
+primIntClz :: Int -> Int -- number of leading 0 bits of the raw word; the word size for 0
+primIntClz = primitive "clz"
+primIntCtz :: Int -> Int -- number of trailing 0 bits; the word size for 0
+primIntCtz = primitive "ctz"
primWordEQ :: Word -> Word -> Bool
primWordEQ = primitive "=="
--- a/src/MicroHs/Translate.hs
+++ b/src/MicroHs/Translate.hs
@@ -79,6 +79,9 @@
("shr", primitive "shr"),
("ashr", primitive "ashr"),
("subtract", primitive "subtract"),
+ ("popcount", primitive "popcount"), -- unary bit-counting primitives
+ ("clz", primitive "clz"),
+ ("ctz", primitive "ctz"),
("==", primitive "=="),
("/=", primitive "/="),
("<", primitive "<"),
--- a/src/runtime/config-mingw-64.h
+++ b/src/runtime/config-mingw-64.h
@@ -43,6 +43,22 @@
*/
#define FFS __builtin_ffsll
+#define POPCOUNT __builtin_popcountll /* ll form: long is only 32 bits on 64-bit Windows */
+
+#include <inttypes.h> /* uint64_t */
+
+static inline uint64_t clz(uint64_t x) {
+ if (x == 0) return 64; /* __builtin_clzll is undefined for 0 */
+ return __builtin_clzll(x);
+}
+#define CLZ clz
+
+static inline uint64_t ctz(uint64_t x) {
+ if (x == 0) return 64; /* __builtin_ctzll is undefined for 0 */
+ return __builtin_ctzll(x);
+}
+#define CTZ ctz
+
/*
* This is the character used for comma-separation in printf.
* Defaults to "'".
--- a/src/runtime/config-stm32f4.h
+++ b/src/runtime/config-stm32f4.h
@@ -142,6 +142,8 @@
}
#define FFS ffs
+#define CLZ __CLZ /* CMSIS intrinsic on a 32-bit value (assumes a 32-bit word here); POPCOUNT/CTZ use the eval.c fallbacks */
+
#define FFI_EXTRA \
{ "set_led", (funptr_t)set_led, FFI_IIV }, \
{ "busy_wait", (funptr_t)busy_wait, FFI_IV },
--- a/src/runtime/config-windows-64.h
+++ b/src/runtime/config-windows-64.h
@@ -62,6 +62,39 @@
}
#define FFS ffs
+/* NOTE: MSVC's __popcnt/__popcnt64 always emit the POPCNT instruction,
+ * which very old x86 CPUs do not have. */
+#if defined(_M_X64)
+#define POPCOUNT __popcnt64
+#elif defined(_M_IX86)
+#define POPCOUNT __popcnt
+#endif
+
+/* _BitScanReverse64/_BitScanForward64 exist only on x64 and ARM64.  On other
+ * targets CLZ/CTZ stay undefined and eval.c's portable fallbacks are used.
+ * Both scans return 0 (no bit found) for x == 0, which is mapped to 64. */
+#if defined(_M_X64) || defined(_M_ARM64)
+static inline uint64_t clz(uint64_t x) {
+  unsigned long index;  /* receives the position of the highest set bit */
+  if (_BitScanReverse64(&index, x)) {
+    return 63 - (uint64_t)index;
+  } else {
+    return 64;
+  }
+}
+#define CLZ clz
+
+static inline uint64_t ctz(uint64_t x) {
+  unsigned long index;  /* receives the position of the lowest set bit */
+  if (_BitScanForward64(&index, x)) {
+    return (uint64_t)index;
+  } else {
+    return 64;
+  }
+}
+#define CTZ ctz
+#endif
+
/*
* This is the character used for comma-separation in printf.
* Defaults to "'".
--- a/src/runtime/eval.c
+++ b/src/runtime/eval.c
@@ -127,6 +127,60 @@
}
#endif /* !defined(FFS) */
+#if !defined(POPCOUNT)
+/* Number of 1 bits in x.  GCC/clang (clang also defines __GNUC__) use the
+ * long long builtin on a widened value: the long variant truncates where
+ * long is narrower than uvalue_t (e.g. 64-bit Windows). */
+uvalue_t POPCOUNT(uvalue_t x) {
+#if defined(__GNUC__)
+  return __builtin_popcountll((unsigned long long)x);
+#else
+  uvalue_t count = 0;
+  while (x) {
+    x = x & (x - 1); // clear lowest 1 bit
+    count += 1;
+  }
+  return count;
+#endif
+}
+#endif
+
+#if !defined(CLZ)
+/* Number of leading 0 bits in x; WORD_SIZE (taken as uvalue_t's bit width) for 0. */
+uvalue_t CLZ(uvalue_t x) {
+  if (x == 0) return WORD_SIZE;   // the builtin is undefined for 0
+#if defined(__GNUC__)
+  /* widening adds leading zeros beyond uvalue_t's width; drop them */
+  return (uvalue_t)__builtin_clzll((unsigned long long)x) -
+         8 * (sizeof(unsigned long long) - sizeof(uvalue_t));
+#else
+  uvalue_t count = WORD_SIZE;
+  while (x) {
+    x = x >> 1;
+    count -= 1;
+  }
+  return count;
+#endif
+}
+#endif
+
+#if !defined(CTZ)
+/* Number of trailing 0 bits in x; WORD_SIZE for 0. */
+uvalue_t CTZ(uvalue_t x) {
+  if (x == 0) return WORD_SIZE;   // builtin undefined / loop endless for 0
+#if defined(__GNUC__)
+  return __builtin_ctzll((unsigned long long)x); // widening keeps the low bits
+#else
+  uvalue_t count = 0;
+  while ((x & 1) == 0) {
+    x = x >> 1;
+    count += 1;
+  }
+  return count;
+#endif
+}
+#endif
+
#if !defined(WANT_ARGS)
#define WANT_ARGS 1
#endif
@@ -183,6 +237,7 @@
T_K2, T_K3, T_K4, T_CCB,
T_ADD, T_SUB, T_MUL, T_QUOT, T_REM, T_SUBR, T_UQUOT, T_UREM, T_NEG,
T_AND, T_OR, T_XOR, T_INV, T_SHL, T_SHR, T_ASHR,
+ T_POPCOUNT, T_CLZ, T_CTZ, /* unary bit-counting ops */
T_EQ, T_NE, T_LT, T_LE, T_GT, T_GE, T_ULT, T_ULE, T_UGT, T_UGE, T_ICMP, T_UCMP,
T_FPADD, T_FP2P, T_FPNEW, T_FPFIN, // T_FPSTR,
T_TOPTR, T_TOINT, T_TODBL, T_TOFUNPTR,
@@ -217,6 +272,7 @@
"K2", "K3", "K4", "CCB",
"ADD", "SUB", "MUL", "QUOT", "REM", "SUBR", "UQUOT", "UREM", "NEG",
"AND", "OR", "XOR", "INV", "SHL", "SHR", "ASHR",
+ "POPCOUNT", "CLZ", "CTZ", /* enum order; NOTE(review): no "ICMP"/"UCMP" after "UGE" although the enum has T_ICMP, T_UCMP there -- check alignment if this list is indexed by tag */
"EQ", "NE", "LT", "LE", "GT", "GE", "ULT", "ULE", "UGT", "UGE",
"FPADD", "FP2P", "FPNEW", "FPFIN",
"TOPTR", "TOINT", "TODBL", "TOFUNPTR",
@@ -676,6 +732,9 @@
{ "shl", T_SHL },
{ "shr", T_SHR },
{ "ashr", T_ASHR },
+ { "popcount", T_POPCOUNT }, /* names used by primitive "..." in lib/Primitives.hs */
+ { "clz", T_CLZ },
+ { "ctz", T_CTZ },
#if WANT_FLOAT
{ "f+" , T_FADD, T_FADD},
{ "f-" , T_FSUB, T_FSUB},
@@ -2155,6 +2214,9 @@
case T_SHL: putsb("shl", f); break;
case T_SHR: putsb("shr", f); break;
case T_ASHR: putsb("ashr", f); break;
+ case T_POPCOUNT: putsb("popcount", f); break; /* same spelling as the primitive name table */
+ case T_CLZ: putsb("clz", f); break;
+ case T_CTZ: putsb("ctz", f); break;
#if WANT_FLOAT
case T_FADD: putsb("f+", f); break;
case T_FSUB: putsb("f-", f); break;
@@ -3198,6 +3260,9 @@
goto top;
case T_NEG:
case T_INV:
+ case T_POPCOUNT: /* unary ops share the NEG/INV path via combUNINT1 */
+ case T_CLZ:
+ case T_CTZ:
CHECK(1);
n = ARG(TOP(0));
PUSH(combUNINT1);
@@ -3568,9 +3633,12 @@
n = TOP(-1);
unint:
switch (GETTAG(p)) {
- case T_IND: p = INDIR(p); goto unint;
- case T_NEG: ru = -xu; break;
- case T_INV: ru = ~xu; break;
+ case T_IND: p = INDIR(p); goto unint;
+ case T_NEG: ru = -xu; break;
+ case T_INV: ru = ~xu; break;
+ case T_POPCOUNT: ru = POPCOUNT(xu); break; /* on the unsigned view xu, like NEG/INV; POPCOUNT/CLZ/CTZ come from the config header or this file's fallbacks */
+ case T_CLZ: ru = CLZ(xu); break;
+ case T_CTZ: ru = CTZ(xu); break;
default:
//fprintf(stderr, "tag=%d\n", GETTAG(FUN(TOP(0))));
ERR("UNINT");
--- /dev/null
+++ b/tests/BitCount.hs
@@ -1,0 +1,149 @@
+module BitCount where
+
+import Data.Bits
+import Data.Int
+import Data.Word
+
+main :: IO ()
+main = do
+ -- popcount (word-size-dependent results are compared with _wordSize instead of printed; expected output is tests/BitCount.ref)
+ print $ popCount (0 :: Word8)
+ print $ popCount (42 :: Word8)
+ print $ popCount (64 :: Word8)
+ print $ popCount (maxBound :: Word8)
+ print $ popCount (0 :: Word16)
+ print $ popCount (42 :: Word16)
+ print $ popCount (64 :: Word16)
+ print $ popCount (maxBound :: Word16)
+ print $ popCount (0 :: Word32)
+ print $ popCount (42 :: Word32)
+ print $ popCount (64 :: Word32)
+ print $ popCount (maxBound :: Word32)
+ print $ popCount (0 :: Word)
+ print $ popCount (42 :: Word)
+ print $ popCount (64 :: Word)
+ print $ popCount (maxBound :: Word) == _wordSize
+ print $ popCount (0 :: Int8)
+ print $ popCount (42 :: Int8)
+ print $ popCount (64 :: Int8)
+ print $ popCount (-1 :: Int8)
+ print $ popCount (-42 :: Int8)
+ print $ popCount (minBound :: Int8)
+ print $ popCount (maxBound :: Int8)
+ print $ popCount (0 :: Int16)
+ print $ popCount (42 :: Int16)
+ print $ popCount (64 :: Int16)
+ print $ popCount (-1 :: Int16)
+ print $ popCount (-42 :: Int16)
+ print $ popCount (minBound :: Int16)
+ print $ popCount (maxBound :: Int16)
+ print $ popCount (0 :: Int32)
+ print $ popCount (42 :: Int32)
+ print $ popCount (64 :: Int32)
+ print $ popCount (-1 :: Int32)
+ print $ popCount (-42 :: Int32)
+ print $ popCount (minBound :: Int32)
+ print $ popCount (maxBound :: Int32)
+ print $ popCount (0 :: Int)
+ print $ popCount (42 :: Int)
+ print $ popCount (64 :: Int)
+ print $ popCount (-1 :: Int) == _wordSize
+ print $ popCount (-42 :: Int) == _wordSize - 3
+ print $ popCount (minBound :: Int)
+ print $ popCount (maxBound :: Int) == _wordSize - 1
+
+ putStrLn ""
+
+ -- clz
+ print $ countLeadingZeros (0 :: Word8)
+ print $ countLeadingZeros (42 :: Word8)
+ print $ countLeadingZeros (64 :: Word8)
+ print $ countLeadingZeros (maxBound :: Word8)
+ print $ countLeadingZeros (0 :: Word16)
+ print $ countLeadingZeros (42 :: Word16)
+ print $ countLeadingZeros (64 :: Word16)
+ print $ countLeadingZeros (maxBound :: Word16)
+ print $ countLeadingZeros (0 :: Word32)
+ print $ countLeadingZeros (42 :: Word32)
+ print $ countLeadingZeros (64 :: Word32)
+ print $ countLeadingZeros (maxBound :: Word32)
+ print $ countLeadingZeros (0 :: Word) == _wordSize
+ print $ countLeadingZeros (42 :: Word) == _wordSize - 6
+ print $ countLeadingZeros (64 :: Word) == _wordSize - 7
+ print $ countLeadingZeros (maxBound :: Word)
+ print $ countLeadingZeros (0 :: Int8)
+ print $ countLeadingZeros (42 :: Int8)
+ print $ countLeadingZeros (64 :: Int8)
+ print $ countLeadingZeros (-1 :: Int8)
+ print $ countLeadingZeros (-42 :: Int8)
+ print $ countLeadingZeros (minBound :: Int8)
+ print $ countLeadingZeros (maxBound :: Int8)
+ print $ countLeadingZeros (0 :: Int16)
+ print $ countLeadingZeros (42 :: Int16)
+ print $ countLeadingZeros (64 :: Int16)
+ print $ countLeadingZeros (-1 :: Int16)
+ print $ countLeadingZeros (-42 :: Int16)
+ print $ countLeadingZeros (minBound :: Int16)
+ print $ countLeadingZeros (maxBound :: Int16)
+ print $ countLeadingZeros (0 :: Int32)
+ print $ countLeadingZeros (42 :: Int32)
+ print $ countLeadingZeros (64 :: Int32)
+ print $ countLeadingZeros (-1 :: Int32)
+ print $ countLeadingZeros (-42 :: Int32)
+ print $ countLeadingZeros (minBound :: Int32)
+ print $ countLeadingZeros (maxBound :: Int32)
+ print $ countLeadingZeros (0 :: Int) == _wordSize
+ print $ countLeadingZeros (42 :: Int) == _wordSize - 6
+ print $ countLeadingZeros (64 :: Int) == _wordSize - 7
+ print $ countLeadingZeros (-1 :: Int)
+ print $ countLeadingZeros (-42 :: Int)
+ print $ countLeadingZeros (minBound :: Int)
+ print $ countLeadingZeros (maxBound :: Int)
+
+ putStrLn ""
+
+ -- ctz
+ print $ countTrailingZeros (0 :: Word8)
+ print $ countTrailingZeros (42 :: Word8)
+ print $ countTrailingZeros (64 :: Word8)
+ print $ countTrailingZeros (maxBound :: Word8)
+ print $ countTrailingZeros (0 :: Word16)
+ print $ countTrailingZeros (42 :: Word16)
+ print $ countTrailingZeros (64 :: Word16)
+ print $ countTrailingZeros (maxBound :: Word16)
+ print $ countTrailingZeros (0 :: Word32)
+ print $ countTrailingZeros (42 :: Word32)
+ print $ countTrailingZeros (64 :: Word32)
+ print $ countTrailingZeros (maxBound :: Word32)
+ print $ countTrailingZeros (0 :: Word) == _wordSize
+ print $ countTrailingZeros (42 :: Word)
+ print $ countTrailingZeros (64 :: Word)
+ print $ countTrailingZeros (maxBound :: Word)
+ print $ countTrailingZeros (0 :: Int8)
+ print $ countTrailingZeros (42 :: Int8)
+ print $ countTrailingZeros (64 :: Int8)
+ print $ countTrailingZeros (-1 :: Int8)
+ print $ countTrailingZeros (-42 :: Int8)
+ print $ countTrailingZeros (minBound :: Int8)
+ print $ countTrailingZeros (maxBound :: Int8)
+ print $ countTrailingZeros (0 :: Int16)
+ print $ countTrailingZeros (42 :: Int16)
+ print $ countTrailingZeros (64 :: Int16)
+ print $ countTrailingZeros (-1 :: Int16)
+ print $ countTrailingZeros (-42 :: Int16)
+ print $ countTrailingZeros (minBound :: Int16)
+ print $ countTrailingZeros (maxBound :: Int16)
+ print $ countTrailingZeros (0 :: Int32)
+ print $ countTrailingZeros (42 :: Int32)
+ print $ countTrailingZeros (64 :: Int32)
+ print $ countTrailingZeros (-1 :: Int32)
+ print $ countTrailingZeros (-42 :: Int32)
+ print $ countTrailingZeros (minBound :: Int32)
+ print $ countTrailingZeros (maxBound :: Int32)
+ print $ countTrailingZeros (0 :: Int) == _wordSize
+ print $ countTrailingZeros (42 :: Int)
+ print $ countTrailingZeros (64 :: Int)
+ print $ countTrailingZeros (-1 :: Int)
+ print $ countTrailingZeros (-42 :: Int)
+ print $ countTrailingZeros (minBound :: Int) == _wordSize - 1
+ print $ countTrailingZeros (maxBound :: Int)
--- /dev/null
+++ b/tests/BitCount.ref
@@ -1,0 +1,134 @@
+0
+3
+1
+8
+0
+3
+1
+16
+0
+3
+1
+32
+0
+3
+1
+True
+0
+3
+1
+8
+5
+1
+7
+0
+3
+1
+16
+13
+1
+15
+0
+3
+1
+32
+29
+1
+31
+0
+3
+1
+True
+True
+1
+True
+
+8
+2
+1
+0
+16
+10
+9
+0
+32
+26
+25
+0
+True
+True
+True
+0
+8
+2
+1
+0
+0
+0
+1
+16
+10
+9
+0
+0
+0
+1
+32
+26
+25
+0
+0
+0
+1
+True
+True
+True
+0
+0
+0
+1
+
+8
+1
+6
+0
+16
+1
+6
+0
+32
+1
+6
+0
+True
+1
+6
+0
+8
+1
+6
+0
+1
+7
+0
+16
+1
+6
+0
+1
+15
+0
+32
+1
+6
+0
+1
+31
+0
+True
+1
+6
+0
+1
+True
+0