shithub: MicroHs

Download patch

ref: 8681ea7f41031e95eba880baf934d743cb558758
parent: d870a2fe07edafdc829a76af55792551a8fae394
parent: 24e80ae9a4bfea0c43a594af91819e8052fa6423
author: Lennart Augustsson <lennart@augustsson.net>
date: Sun Jan 19 03:46:14 EST 2025

Merge pull request #94 from konsumlamm/bitcount

Add popcount, clz, ctz primitives

--- a/ghc/PrimTable.hs
+++ b/ghc/PrimTable.hs
@@ -62,6 +62,9 @@
   , arithwi "shl" shiftL
   , arithwi "shr" shiftR
   , arith "ashr" shiftR
+  , arithu "popcount" popCount            -- one-argument bit-counting primitives; names match Translate.hs/eval.c
+  , arithu "clz" countLeadingZeros        -- leading zero bits
+  , arithu "ctz" countTrailingZeros       -- trailing zero bits
   , cmp "==" (==)
   , cmp "/=" (/=)
   , cmp "<"  (<)
--- a/lib/Data/Bits.hs
+++ b/lib/Data/Bits.hs
@@ -127,8 +127,10 @@
   bitSize _ = _wordSize
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount = primIntPopcount  -- runtime "popcount" primitive instead of popCountDefault
   zeroBits = 0
 
 instance FiniteBits Int where
   finiteBitSize _ = _wordSize
+  countLeadingZeros = primIntClz  -- counted over all _wordSize bits; _wordSize for 0
+  countTrailingZeros = primIntCtz  -- _wordSize for 0
--- a/lib/Data/Int/Instances.hs
+++ b/lib/Data/Int/Instances.hs
@@ -108,11 +108,13 @@
   bitSize _ = 8
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount (I8 x) = primIntPopcount (x .&. 0xff)  -- mask to the low 8 bits: x may hold higher bits (e.g. sign extension)
   zeroBits = 0
 
 instance FiniteBits Int8 where
   finiteBitSize _ = 8
+  countLeadingZeros (I8 x) = primIntClz (x .&. 0xff) - (_wordSize - 8)  -- primIntClz counts over _wordSize bits; drop the ones above bit 7
+  countTrailingZeros (I8 x) = if x == 0 then 8 else primIntCtz x  -- 8 for zero; assumes a nonzero x has a 1 in its low 8 bits
 
 --------------------------------------------------------------------------------
 ----    Int16
@@ -202,11 +204,13 @@
   bitSize _ = 16
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount (I16 x) = primIntPopcount (x .&. 0xffff)  -- count only the low 16 bits
   zeroBits = 0
 
 instance FiniteBits Int16 where
   finiteBitSize _ = 16
+  countLeadingZeros (I16 x) = primIntClz (x .&. 0xffff) - (_wordSize - 16)  -- discount the bits above bit 15
+  countTrailingZeros (I16 x) = if x == 0 then 16 else primIntCtz x  -- 16 for zero
 
 --------------------------------------------------------------------------------
 ----    Int32
@@ -296,11 +300,13 @@
   bitSize _ = 32
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount (I32 x) = primIntPopcount (x .&. 0xffffffff)  -- count only the low 32 bits
   zeroBits = 0
 
 instance FiniteBits Int32 where
   finiteBitSize _ = 32
+  countLeadingZeros (I32 x) = primIntClz (x .&. 0xffffffff) - (_wordSize - 32)  -- discount the bits above bit 31
+  countTrailingZeros (I32 x) = if x == 0 then 32 else primIntCtz x  -- 32 for zero
 
 --------------------------------------------------------------------------------
 ----    Int64
@@ -389,8 +395,10 @@
   bitSize _ = 64
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount (I64 x) = primIntPopcount x  -- no mask: assumes the underlying Int is 64 bits wide
   zeroBits = 0
 
 instance FiniteBits Int64 where
   finiteBitSize _ = 64
+  countLeadingZeros = primIntClz . unI64  -- NOTE(review): correct only when _wordSize == 64
+  countTrailingZeros = primIntCtz . unI64  -- zero gives _wordSize, i.e. 64 only when _wordSize == 64
--- a/lib/Data/Word.hs
+++ b/lib/Data/Word.hs
@@ -97,11 +97,13 @@
   bitSize _ = _wordSize
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount = primWordPopcount  -- runtime "popcount" primitive
   zeroBits = 0
 
 instance FiniteBits Word where
   finiteBitSize _ = _wordSize
+  countLeadingZeros = primWordClz  -- _wordSize for 0
+  countTrailingZeros = primWordCtz  -- _wordSize for 0
 
 --------------------------------------------------------------------------------
 ----    Word8
@@ -131,7 +133,7 @@
   (*)  = bin8 primWordMul
   abs x = x
   signum x = if x == 0 then 0 else 1
-  fromInteger i = w8 (primIntToWord (_integerToInt i))
+  fromInteger i = w8 (_integerToWord i)  -- convert to Word directly instead of via Int
 
 instance Integral Word8 where
   quot = bin8 primWordQuot
@@ -193,11 +195,13 @@
   bitSize _ = 8
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount = primWordPopcount . unW8  -- assumes the wrapped Word holds only the low 8 bits (TODO confirm w8 masks)
   zeroBits = 0
 
 instance FiniteBits Word8 where
   finiteBitSize _ = 8
+  countLeadingZeros (W8 x) = primWordClz x - (_wordSize - 8)  -- primWordClz counts over _wordSize bits; drop the ones above bit 7
+  countTrailingZeros (W8 x) = if x == 0 then 8 else primWordCtz x  -- 8 (not _wordSize) for zero
 
 --------------------------------------------------------------------------------
 ----    Word16
@@ -227,7 +231,7 @@
   (*)  = bin16 primWordMul
   abs x = x
   signum x = if x == 0 then 0 else 1
-  fromInteger i = w16 (primIntToWord (_integerToInt i))
+  fromInteger i = w16 (_integerToWord i)  -- convert to Word directly instead of via Int
 
 instance Integral Word16 where
   quot = bin16 primWordQuot
@@ -288,11 +292,13 @@
   bitSize _ = 16
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount = primWordPopcount . unW16  -- assumes the wrapped Word holds only the low 16 bits
   zeroBits = 0
 
 instance FiniteBits Word16 where
   finiteBitSize _ = 16
+  countLeadingZeros (W16 x) = primWordClz x - (_wordSize - 16)  -- drop the bits above bit 15
+  countTrailingZeros (W16 x) = if x == 0 then 16 else primWordCtz x  -- 16 for zero
 
 --------------------------------------------------------------------------------
 ----    Word32
@@ -322,7 +328,7 @@
   (*)  = bin32 primWordMul
   abs x = x
   signum x = if x == 0 then 0 else 1
-  fromInteger i = w32 (primIntToWord (_integerToInt i))
+  fromInteger i = w32 (_integerToWord i)  -- convert to Word directly instead of via Int
 
 instance Integral Word32 where
   quot = bin32 primWordQuot
@@ -384,11 +390,13 @@
   bitSize _ = 32
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount = primWordPopcount . unW32  -- assumes the wrapped Word holds only the low 32 bits
   zeroBits = 0
 
 instance FiniteBits Word32 where
   finiteBitSize _ = 32
+  countLeadingZeros (W32 x) = primWordClz x - (_wordSize - 32)  -- drop the bits above bit 31
+  countTrailingZeros (W32 x) = if x == 0 then 32 else primWordCtz x  -- 32 for zero
 
 --------------------------------------------------------------------------------
 ----    Word64
@@ -418,7 +426,7 @@
   (*)  = bin64 primWordMul
   abs x = x
   signum x = if x == 0 then 0 else 1
-  fromInteger i = w64 (primIntToWord (_integerToInt i))
+  fromInteger i = w64 (_integerToWord i)  -- convert to Word directly instead of via Int
 
 instance Integral Word64 where
   quot = bin64 primWordQuot
@@ -480,8 +488,10 @@
   bitSize _ = 64
   bit = bitDefault
   testBit = testBitDefault
-  popCount = popCountDefault
+  popCount = primWordPopcount . unW64  -- no mask: assumes the underlying Word is 64 bits wide
   zeroBits = 0
 
 instance FiniteBits Word64 where
   finiteBitSize _ = 64
+  countLeadingZeros = primWordClz . unW64  -- NOTE(review): correct only when _wordSize == 64
+  countTrailingZeros = primWordCtz . unW64  -- zero gives _wordSize, i.e. 64 only when _wordSize == 64
--- a/lib/Primitives.hs
+++ b/lib/Primitives.hs
@@ -119,6 +119,12 @@
 primWordAshr  = primitive "ashr"
 primWordInv :: Word -> Word
 primWordInv  = primitive "inv"
+primWordPopcount :: Word -> Int  -- number of 1 bits
+primWordPopcount = primitive "popcount"
+primWordClz :: Word -> Int  -- leading zero bits over the whole word; word size for 0
+primWordClz = primitive "clz"
+primWordCtz :: Word -> Int  -- trailing zero bits; word size for 0
+primWordCtz = primitive "ctz"
 primWordToFloatWRaw :: Word -> FloatW
 primWordToFloatWRaw = primitive "toDbl"
 primWordFromFloatWRaw :: FloatW -> Word
@@ -136,6 +142,12 @@
 primIntShr  = primitive "ashr"
 primIntInv :: Int -> Int
 primIntInv  = primitive "inv"
+primIntPopcount :: Int -> Int  -- same runtime primitive as primWordPopcount, applied to the Int's bit pattern
+primIntPopcount = primitive "popcount"
+primIntClz :: Int -> Int  -- same runtime primitive as primWordClz
+primIntClz = primitive "clz"
+primIntCtz :: Int -> Int  -- same runtime primitive as primWordCtz
+primIntCtz = primitive "ctz"
 
 primWordEQ  :: Word -> Word -> Bool
 primWordEQ  = primitive "=="
--- a/src/MicroHs/Translate.hs
+++ b/src/MicroHs/Translate.hs
@@ -79,6 +79,9 @@
   ("shr", primitive "shr"),
   ("ashr", primitive "ashr"),
   ("subtract", primitive "subtract"),
+  ("popcount", primitive "popcount"),  -- bit-counting primitives; same names as in lib/Primitives.hs
+  ("clz", primitive "clz"),
+  ("ctz", primitive "ctz"),
   ("==", primitive "=="),
   ("/=", primitive "/="),
   ("<", primitive "<"),
--- a/src/runtime/config-mingw-64.h
+++ b/src/runtime/config-mingw-64.h
@@ -43,6 +43,22 @@
  */
 #define FFS __builtin_ffsll
 
+#define POPCOUNT __builtin_popcountll  /* "ll" variant, presumably because long is only 32 bits on Windows */
+
+#include <inttypes.h>
+
+static inline uint64_t clz(uint64_t x) {  /* leading zero bits of a 64-bit value */
+  if (x == 0) return 64;  /* __builtin_clzll is undefined for 0 */
+  return __builtin_clzll(x);
+}
+#define CLZ clz
+
+static inline uint64_t ctz(uint64_t x) {  /* trailing zero bits of a 64-bit value */
+  if (x == 0) return 64;  /* __builtin_ctzll is undefined for 0 */
+  return __builtin_ctzll(x);
+}
+#define CTZ ctz
+
 /*
  * This is the character used for comma-separation in printf.
  * Defaults to "'".
--- a/src/runtime/config-stm32f4.h
+++ b/src/runtime/config-stm32f4.h
@@ -142,6 +142,8 @@
 }
 #define FFS ffs
 
+#define CLZ __CLZ  /* presumably the CMSIS intrinsic; NOTE(review): eval.c's generic CLZ returns WORD_SIZE for 0, confirm __CLZ(0) == 32 */
+
 #define FFI_EXTRA \
   { "set_led",   (funptr_t)set_led,   FFI_IIV }, \
   { "busy_wait", (funptr_t)busy_wait, FFI_IV },
--- a/src/runtime/config-windows-64.h
+++ b/src/runtime/config-windows-64.h
@@ -62,6 +62,32 @@
 }
 #define FFS ffs
 
+#if defined(_M_X64)
+#define POPCOUNT __popcnt64  /* NOTE(review): __popcnt64 requires CPU support for the POPCNT instruction */
+#elif defined(_M_IX86)
+#define POPCOUNT __popcnt  /* NOTE(review): clz/ctz below use _BitScan*64, which MSVC documents only for 64-bit targets */
+#endif  /* otherwise eval.c's generic POPCOUNT is used */
+
+static inline uint64_t clz(uint64_t x) {  /* leading zero bits of a 64-bit value; 64 for 0 */
+  unsigned long count;
+  if (_BitScanReverse64(&count, x)) {
+    return 63 - (uint64_t)count;  /* count = index of the highest set bit */
+  } else {
+    return 64;
+  }
+}
+#define CLZ clz
+
+static inline uint64_t ctz(uint64_t x) {  /* trailing zero bits of a 64-bit value; 64 for 0 */
+  unsigned long count;
+  if (_BitScanForward64(&count, x)) {
+    return (uint64_t)count;  /* count = index of the lowest set bit */
+  } else {
+    return 64;
+  }
+}
+#define CTZ ctz
+
 /*
  * This is the character used for comma-separation in printf.
  * Defaults to "'".
--- a/src/runtime/eval.c
+++ b/src/runtime/eval.c
@@ -127,6 +127,60 @@
 }
 #endif  /* !defined(FFS) */
 
+#if !defined(POPCOUNT)
+/* Number of 1 bits in x. */
+uvalue_t POPCOUNT(uvalue_t x) {
+#if defined(__GNUC__) || defined(__clang__)
+  /* unsigned long long has >= 64 bits; unsigned long may be only 32 (LLP64). */
+  return __builtin_popcountll(x);
+#else
+  uvalue_t count = 0;
+  while (x) {
+    x = x & (x - 1); // clear lowest 1 bit
+    count += 1;
+  }
+  return count;
+#endif
+}
+#endif
+
+#if !defined(CLZ)
+/* Number of leading 0 bits in x, counted within WORD_SIZE bits; WORD_SIZE if x == 0. */
+uvalue_t CLZ(uvalue_t x) {
+  if (x == 0) return WORD_SIZE; /* __builtin_clz* is undefined for 0 */
+#if defined(__GNUC__) || defined(__clang__)
+  /* __builtin_clzll counts within unsigned long long, which may have more
+   * than WORD_SIZE bits (e.g. on 32-bit targets); subtract the excess. */
+  return (uvalue_t)__builtin_clzll(x) - (8 * sizeof(unsigned long long) - WORD_SIZE);
+#else
+  uvalue_t count = WORD_SIZE;
+  while (x) {
+    x = x >> 1;
+    count -= 1;
+  }
+  return count;
+#endif
+}
+#endif
+
+#if !defined(CTZ)
+/* Number of trailing 0 bits in x; WORD_SIZE if x == 0. */
+uvalue_t CTZ(uvalue_t x) {
+  if (x == 0) return WORD_SIZE; /* __builtin_ctz* is undefined for 0 */
+#if defined(__GNUC__) || defined(__clang__)
+  /* For nonzero x the result is the same whatever the operand width. */
+  return __builtin_ctzll(x);
+#else
+  uvalue_t count = 0;
+  while ((x & 1) == 0) {
+    x = x >> 1;
+    count += 1;
+  }
+  return count;
+#endif
+}
+#endif
+
 #if !defined(WANT_ARGS)
 #define WANT_ARGS 1
 #endif
@@ -183,6 +237,7 @@
                 T_K2, T_K3, T_K4, T_CCB,
                 T_ADD, T_SUB, T_MUL, T_QUOT, T_REM, T_SUBR, T_UQUOT, T_UREM, T_NEG,
                 T_AND, T_OR, T_XOR, T_INV, T_SHL, T_SHR, T_ASHR,
+                T_POPCOUNT, T_CLZ, T_CTZ,  /* unary bit-counting ops */
                 T_EQ, T_NE, T_LT, T_LE, T_GT, T_GE, T_ULT, T_ULE, T_UGT, T_UGE, T_ICMP, T_UCMP,
                 T_FPADD, T_FP2P, T_FPNEW, T_FPFIN, // T_FPSTR,
                 T_TOPTR, T_TOINT, T_TODBL, T_TOFUNPTR,
@@ -217,6 +272,7 @@
   "K2", "K3", "K4", "CCB",
   "ADD", "SUB", "MUL", "QUOT", "REM", "SUBR", "UQUOT", "UREM", "NEG",
   "AND", "OR", "XOR", "INV", "SHL", "SHR", "ASHR",
+  "POPCOUNT", "CLZ", "CTZ",  /* same position as the T_* tags above (presumably indexed by tag) */
   "EQ", "NE", "LT", "LE", "GT", "GE", "ULT", "ULE", "UGT", "UGE",
   "FPADD", "FP2P", "FPNEW", "FPFIN",
   "TOPTR", "TOINT", "TODBL", "TOFUNPTR",
@@ -676,6 +732,9 @@
   { "shl", T_SHL },
   { "shr", T_SHR },
   { "ashr", T_ASHR },
+  { "popcount", T_POPCOUNT },  /* names used by lib/Primitives.hs and Translate.hs */
+  { "clz", T_CLZ },
+  { "ctz", T_CTZ },
 #if WANT_FLOAT
   { "f+" , T_FADD, T_FADD},
   { "f-" , T_FSUB, T_FSUB},
@@ -2155,6 +2214,9 @@
   case T_SHL: putsb("shl", f); break;
   case T_SHR: putsb("shr", f); break;
   case T_ASHR: putsb("ashr", f); break;
+  case T_POPCOUNT: putsb("popcount", f); break;  /* same spelling as the primitive table entry */
+  case T_CLZ: putsb("clz", f); break;
+  case T_CTZ: putsb("ctz", f); break;
 #if WANT_FLOAT
   case T_FADD: putsb("f+", f); break;
   case T_FSUB: putsb("f-", f); break;
@@ -3198,6 +3260,9 @@
     goto top;
   case T_NEG:
   case T_INV:
+  case T_POPCOUNT:  /* shares the unary NEG/INV path through combUNINT1 */
+  case T_CLZ:
+  case T_CTZ:
     CHECK(1);
     n = ARG(TOP(0));
     PUSH(combUNINT1);
@@ -3568,9 +3633,12 @@
       n = TOP(-1);
     unint:
       switch (GETTAG(p)) {
-      case T_IND:   p = INDIR(p); goto unint;
-      case T_NEG:   ru = -xu; break;
-      case T_INV:   ru = ~xu; break;
+      case T_IND:      p = INDIR(p); goto unint;
+      case T_NEG:      ru = -xu; break;
+      case T_INV:      ru = ~xu; break;
+      case T_POPCOUNT: ru = POPCOUNT(xu); break;  /* bit counts operate on the unsigned view xu */
+      case T_CLZ:      ru = CLZ(xu); break;
+      case T_CTZ:      ru = CTZ(xu); break;
       default:
         //fprintf(stderr, "tag=%d\n", GETTAG(FUN(TOP(0))));
         ERR("UNINT");
--- /dev/null
+++ b/tests/BitCount.hs
@@ -1,0 +1,149 @@
+module BitCount where
+
+import Data.Bits
+import Data.Int
+import Data.Word
+
+main :: IO ()
+main = do
+  -- popCount: number of 1 bits
+  print (popCount (0 :: Word8))
+  print (popCount (42 :: Word8))
+  print (popCount (64 :: Word8))
+  print (popCount (maxBound :: Word8))
+  print (popCount (0 :: Word16))
+  print (popCount (42 :: Word16))
+  print (popCount (64 :: Word16))
+  print (popCount (maxBound :: Word16))
+  print (popCount (0 :: Word32))
+  print (popCount (42 :: Word32))
+  print (popCount (64 :: Word32))
+  print (popCount (maxBound :: Word32))
+  print (popCount (0 :: Word))
+  print (popCount (42 :: Word))
+  print (popCount (64 :: Word))
+  print (popCount (maxBound :: Word) == _wordSize)  -- compared, not printed: the count depends on the word size
+  print (popCount (0 :: Int8))
+  print (popCount (42 :: Int8))
+  print (popCount (64 :: Int8))
+  print (popCount (-1 :: Int8))
+  print (popCount (-42 :: Int8))
+  print (popCount (minBound :: Int8))
+  print (popCount (maxBound :: Int8))
+  print (popCount (0 :: Int16))
+  print (popCount (42 :: Int16))
+  print (popCount (64 :: Int16))
+  print (popCount (-1 :: Int16))
+  print (popCount (-42 :: Int16))
+  print (popCount (minBound :: Int16))
+  print (popCount (maxBound :: Int16))
+  print (popCount (0 :: Int32))
+  print (popCount (42 :: Int32))
+  print (popCount (64 :: Int32))
+  print (popCount (-1 :: Int32))
+  print (popCount (-42 :: Int32))
+  print (popCount (minBound :: Int32))
+  print (popCount (maxBound :: Int32))
+  print (popCount (0 :: Int))
+  print (popCount (42 :: Int))
+  print (popCount (64 :: Int))
+  print (popCount (-1 :: Int) == _wordSize)
+  print (popCount (-42 :: Int) == _wordSize - 3)
+  print (popCount (minBound :: Int))
+  print (popCount (maxBound :: Int) == _wordSize - 1)
+
+  putStrLn ""
+
+  -- countLeadingZeros
+  print (countLeadingZeros (0 :: Word8))
+  print (countLeadingZeros (42 :: Word8))
+  print (countLeadingZeros (64 :: Word8))
+  print (countLeadingZeros (maxBound :: Word8))
+  print (countLeadingZeros (0 :: Word16))
+  print (countLeadingZeros (42 :: Word16))
+  print (countLeadingZeros (64 :: Word16))
+  print (countLeadingZeros (maxBound :: Word16))
+  print (countLeadingZeros (0 :: Word32))
+  print (countLeadingZeros (42 :: Word32))
+  print (countLeadingZeros (64 :: Word32))
+  print (countLeadingZeros (maxBound :: Word32))
+  print (countLeadingZeros (0 :: Word) == _wordSize)
+  print (countLeadingZeros (42 :: Word) == _wordSize - 6)
+  print (countLeadingZeros (64 :: Word) == _wordSize - 7)
+  print (countLeadingZeros (maxBound :: Word))
+  print (countLeadingZeros (0 :: Int8))
+  print (countLeadingZeros (42 :: Int8))
+  print (countLeadingZeros (64 :: Int8))
+  print (countLeadingZeros (-1 :: Int8))
+  print (countLeadingZeros (-42 :: Int8))
+  print (countLeadingZeros (minBound :: Int8))
+  print (countLeadingZeros (maxBound :: Int8))
+  print (countLeadingZeros (0 :: Int16))
+  print (countLeadingZeros (42 :: Int16))
+  print (countLeadingZeros (64 :: Int16))
+  print (countLeadingZeros (-1 :: Int16))
+  print (countLeadingZeros (-42 :: Int16))
+  print (countLeadingZeros (minBound :: Int16))
+  print (countLeadingZeros (maxBound :: Int16))
+  print (countLeadingZeros (0 :: Int32))
+  print (countLeadingZeros (42 :: Int32))
+  print (countLeadingZeros (64 :: Int32))
+  print (countLeadingZeros (-1 :: Int32))
+  print (countLeadingZeros (-42 :: Int32))
+  print (countLeadingZeros (minBound :: Int32))
+  print (countLeadingZeros (maxBound :: Int32))
+  print (countLeadingZeros (0 :: Int) == _wordSize)
+  print (countLeadingZeros (42 :: Int) == _wordSize - 6)
+  print (countLeadingZeros (64 :: Int) == _wordSize - 7)
+  print (countLeadingZeros (-1 :: Int))
+  print (countLeadingZeros (-42 :: Int))
+  print (countLeadingZeros (minBound :: Int))
+  print (countLeadingZeros (maxBound :: Int))
+
+  putStrLn ""
+
+  -- countTrailingZeros
+  print (countTrailingZeros (0 :: Word8))
+  print (countTrailingZeros (42 :: Word8))
+  print (countTrailingZeros (64 :: Word8))
+  print (countTrailingZeros (maxBound :: Word8))
+  print (countTrailingZeros (0 :: Word16))
+  print (countTrailingZeros (42 :: Word16))
+  print (countTrailingZeros (64 :: Word16))
+  print (countTrailingZeros (maxBound :: Word16))
+  print (countTrailingZeros (0 :: Word32))
+  print (countTrailingZeros (42 :: Word32))
+  print (countTrailingZeros (64 :: Word32))
+  print (countTrailingZeros (maxBound :: Word32))
+  print (countTrailingZeros (0 :: Word) == _wordSize)
+  print (countTrailingZeros (42 :: Word))
+  print (countTrailingZeros (64 :: Word))
+  print (countTrailingZeros (maxBound :: Word))
+  print (countTrailingZeros (0 :: Int8))
+  print (countTrailingZeros (42 :: Int8))
+  print (countTrailingZeros (64 :: Int8))
+  print (countTrailingZeros (-1 :: Int8))
+  print (countTrailingZeros (-42 :: Int8))
+  print (countTrailingZeros (minBound :: Int8))
+  print (countTrailingZeros (maxBound :: Int8))
+  print (countTrailingZeros (0 :: Int16))
+  print (countTrailingZeros (42 :: Int16))
+  print (countTrailingZeros (64 :: Int16))
+  print (countTrailingZeros (-1 :: Int16))
+  print (countTrailingZeros (-42 :: Int16))
+  print (countTrailingZeros (minBound :: Int16))
+  print (countTrailingZeros (maxBound :: Int16))
+  print (countTrailingZeros (0 :: Int32))
+  print (countTrailingZeros (42 :: Int32))
+  print (countTrailingZeros (64 :: Int32))
+  print (countTrailingZeros (-1 :: Int32))
+  print (countTrailingZeros (-42 :: Int32))
+  print (countTrailingZeros (minBound :: Int32))
+  print (countTrailingZeros (maxBound :: Int32))
+  print (countTrailingZeros (0 :: Int) == _wordSize)
+  print (countTrailingZeros (42 :: Int))
+  print (countTrailingZeros (64 :: Int))
+  print (countTrailingZeros (-1 :: Int))
+  print (countTrailingZeros (-42 :: Int))
+  print (countTrailingZeros (minBound :: Int) == _wordSize - 1)
+  print (countTrailingZeros (maxBound :: Int))
--- /dev/null
+++ b/tests/BitCount.ref
@@ -1,0 +1,134 @@
+0
+3
+1
+8
+0
+3
+1
+16
+0
+3
+1
+32
+0
+3
+1
+True
+0
+3
+1
+8
+5
+1
+7
+0
+3
+1
+16
+13
+1
+15
+0
+3
+1
+32
+29
+1
+31
+0
+3
+1
+True
+True
+1
+True
+
+8
+2
+1
+0
+16
+10
+9
+0
+32
+26
+25
+0
+True
+True
+True
+0
+8
+2
+1
+0
+0
+0
+1
+16
+10
+9
+0
+0
+0
+1
+32
+26
+25
+0
+0
+0
+1
+True
+True
+True
+0
+0
+0
+1
+
+8
+1
+6
+0
+16
+1
+6
+0
+32
+1
+6
+0
+True
+1
+6
+0
+8
+1
+6
+0
+1
+7
+0
+16
+1
+6
+0
+1
+15
+0
+32
+1
+6
+0
+1
+31
+0
+True
+1
+6
+0
+1
+True
+0