shithub: MicroHs

Download patch

ref: 0c1958c69aacd03ce364b048176026bad28a8334
parent: cfd47722e3147d6ac70239ac5209947a12cd21a5
parent: e28034d6111403de48d7141c563b2e84eca7c20f
author: Lennart Augustsson <lennart@augustsson.net>
date: Wed Jan 15 22:23:20 EST 2025

Merge pull request #90 from konsumlamm/BitsInteger

Implement `Bits Integer` instance

--- a/lib/Data/Bits.hs
+++ b/lib/Data/Bits.hs
@@ -93,6 +93,18 @@
 
       w = finiteBitSize x
 
+bitDefault :: (Bits a, Num a) => Int -> a  -- default 'bit': a single 1-bit at position i
+bitDefault i = 1 `shiftL` i
+
+testBitDefault :: (Bits a, Num a) => a -> Int -> Bool  -- default 'testBit': mask x with 'bit' i
+testBitDefault x i = (x .&. bit i) /= 0
+
+popCountDefault :: (Bits a, Num a) => a -> Int  -- default 'popCount': one loop step per set bit
+popCountDefault = go 0
+  where  -- terminates only for values with finitely many 1-bits (so not a negative Integer)
+    go c 0 = c  -- no 1-bits left: c is the count
+    go c w = go (c + 1) (w .&. (w - 1)) -- clear the lowest set bit
+
 _overflowError :: a
 _overflowError = error "arithmetic overflow"
 
@@ -113,7 +125,9 @@
   unsafeShiftR = primIntShr
   bitSizeMaybe _ = Just _wordSize
   bitSize _ = _wordSize
-  bit n = primIntShl 1 n
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Int where
--- a/lib/Data/Int/Instances.hs
+++ b/lib/Data/Int/Instances.hs
@@ -106,7 +106,9 @@
   unsafeShiftR = bini8 primIntShr
   bitSizeMaybe _ = Just 8
   bitSize _ = 8
-  bit n = i8 (primIntShl 1 n)
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Int8 where
@@ -198,7 +200,9 @@
   unsafeShiftR = bini16 primIntShr
   bitSizeMaybe _ = Just 16
   bitSize _ = 16
-  bit n = i16 (primIntShl 1 n)
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Int16 where
@@ -290,7 +294,9 @@
   unsafeShiftR = bini32 primIntShr
   bitSizeMaybe _ = Just 32
   bitSize _ = 32
-  bit n = i32 (primIntShl 1 n)
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Int32 where
@@ -381,7 +387,9 @@
   unsafeShiftR = bini64 primIntShr
   bitSizeMaybe _ = Just 64
   bitSize _ = 64
-  bit n = i64 (primIntShl 1 n)
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Int64 where
--- a/lib/Data/Integer.hs
+++ b/lib/Data/Integer.hs
@@ -14,6 +14,7 @@
 import Prelude()              -- do not import Prelude
 import Primitives
 import Control.Error
+import Data.Bits
 import Data.Bool
 import Data.Char
 import Data.Enum
@@ -99,8 +100,49 @@
   enumFromTo = numericEnumFromTo
   enumFromThenTo = numericEnumFromThenTo
 
-------------------------------------------------
-
+instance Bits Integer where  -- negative values behave as infinite two's-complement bit strings
+  (.&.) = andI
+  (.|.) = orI
+  xor = xorI
+  complement x = negOneI - x -- -x = complement x + 1 => complement x = -1 - x
+  I sign ds `unsafeShiftL` i  -- multiply by 2^i: q whole zero digits, then an r-bit shift
+    | null ds = zeroI
+    | otherwise =
+      let (q, r) = quotRem i shiftD
+      in I sign (replicate q 0 ++ shiftLD ds r)
+  x `shiftL` i
+    | i < 0 = _overflowError
+    | otherwise = x `unsafeShiftL` i
+  I sign ds `unsafeShiftR` i  -- arithmetic shift right, rounding toward negative infinity
+    | null ds = zeroI
+    | otherwise =
+      let
+        (q, r) = quotRem i shiftD
+        (rs, ds') = splitAt q ds  -- rs: the q low digits that are shifted out entirely
+        (ds'', shiftedOut1s) = shiftRD ds' r
+      in case sign of  -- for Minus, bump the magnitude if any 1-bit was dropped (floor)
+        Minus | shiftedOut1s || any (/= 0) rs -> I sign (add1 ds'')
+        _ -> I sign ds''
+  x `shiftR` i
+    | i < 0 = _overflowError
+    | otherwise = x `unsafeShiftR` i
+  x `shift` i  -- i > 0 shifts left, i < 0 shifts right
+    | i < 0 = x `unsafeShiftR` (-i)
+    | i > 0 = x `unsafeShiftL` i
+    | otherwise = x
+  rotate = shift  -- no fixed width, so rotating is just shifting
+  bit i = oneI `shiftL` i
+  testBit = testBitI
+  zeroBits = zeroI
+  bitSizeMaybe _ = Nothing  -- unbounded width
+  popCount (I sign ds) =  -- 1-bits of the digit list, negated when the sign is Minus
+    let count = sum (map popCount ds)
+    in case sign of
+      Plus -> count
+      Minus -> -count
+
+------------------------------------------------
+
 isZero :: Integer -> Bool
 isZero (I _ ds) = null ds
 
@@ -242,7 +284,7 @@
     qr ci []     res = (res, [ci])
     qr ci (x:xs) res = qr r xs (q:res)
       where
-        cx = ci * maxD + x
+        cx = ci `unsafeShiftL` shiftD + x
         q = quot cx y
         r = rem cx y
 
@@ -328,6 +370,91 @@
 _intListToInteger :: [Int] -> Integer
 _intListToInteger ads@(x : ds) = if x == -1 then - f ds else f ads
   where f = foldr (\ d a -> a * integerListBase + toInteger d) 0
+
+---------------------------------
+
+andI :: Integer -> Integer -> Integer  -- two's-complement AND; a negative -m acts as complement (m - 1)
+andI (I Plus  xs) (I Plus  ys) = bI Plus  (andDigits xs ys)
+andI (I Plus  xs) (I Minus ys) = bI Plus  (andNotDigits (sub1 ys) xs)  -- x .&. complement (m - 1)
+andI (I Minus xs) (I Plus  ys) = bI Plus  (andNotDigits (sub1 xs) ys)
+andI (I Minus xs) (I Minus ys) = bI Minus (orDigits (sub1 xs) (sub1 ys))  -- complement a .&. complement b = complement (a .|. b)
+
+orI :: Integer -> Integer -> Integer  -- two's-complement OR; a negative -m acts as complement (m - 1)
+orI (I Plus  xs) (I Plus  ys) = bI Plus  (orDigits xs ys)
+orI (I Plus  xs) (I Minus ys) = bI Minus (andNotDigits xs (sub1 ys))  -- x .|. complement b = complement (b .&. complement x)
+orI (I Minus xs) (I Plus  ys) = bI Minus (andNotDigits ys (sub1 xs))
+orI (I Minus xs) (I Minus ys) = bI Minus (andDigits (sub1 xs) (sub1 ys))  -- complement a .|. complement b = complement (a .&. b)
+
+xorI :: Integer -> Integer -> Integer  -- two's-complement XOR; one complemented operand complements the result
+xorI (I Plus  xs) (I Plus  ys) = bI Plus  (xorDigits xs ys)
+xorI (I Plus  xs) (I Minus ys) = bI Minus (xorDigits xs (sub1 ys))
+xorI (I Minus xs) (I Plus  ys) = bI Minus (xorDigits (sub1 xs) ys)
+xorI (I Minus xs) (I Minus ys) = bI Plus  (xorDigits (sub1 xs) (sub1 ys))  -- the two complements cancel
+
+bI :: Sign -> [Digit] -> Integer  -- build an Integer from the digits produced by a bitwise op
+bI Plus  ds = sI Plus  ds
+bI Minus ds = sI Minus (add1 ds)  -- Minus digits ds stand for complement ds, i.e. -(ds + 1)
+
+add1 :: [Digit] -> [Digit]  -- add one to a digit-list magnitude
+add1 ds = add ds [1]
+
+sub1 :: [Digit] -> [Digit]  -- subtract one from a magnitude; only applied to Minus digits (presumably never [])
+sub1 ds = sub ds [1]
+
+andDigits :: [Digit] -> [Digit] -> [Digit]  -- digit-wise AND
+andDigits (x : xs) (y : ys) = (x .&. y) : andDigits xs ys
+andDigits _        _        = []  -- a missing digit is 0, so AND stops at the shorter list
+
+andNotDigits :: [Digit] -> [Digit] -> [Digit]  -- digit-wise (complement x .&. y)
+andNotDigits []       []       = []
+andNotDigits []       ys       = ys  -- missing x digits are 0, and complement 0 keeps y
+andNotDigits xs       []       = []  -- missing y digits are 0
+andNotDigits (x : xs) (y : ys) = (complement x .&. y) : andNotDigits xs ys
+
+orDigits :: [Digit] -> [Digit] -> [Digit]  -- digit-wise OR; the longer list's tail is kept
+orDigits []       []       = []
+orDigits []       ys       = ys
+orDigits xs       []       = xs
+orDigits (x : xs) (y : ys) = (x .|. y) : orDigits xs ys
+
+xorDigits :: [Digit] -> [Digit] -> [Digit]  -- digit-wise XOR; the longer list's tail is kept
+xorDigits []       []       = []
+xorDigits []       ys       = ys
+xorDigits xs       []       = xs
+xorDigits (x : xs) (y : ys) = (x `xor` y) : xorDigits xs ys
+
+shiftLD :: [Digit] -> Int -> [Digit]  -- shift digits left by i bits (callers pass 0 <= i < shiftD)
+shiftLD ds 0 = ds
+shiftLD ds i = go 0 ds
+  where
+    go ci [] = if ci == 0 then [] else [ci]  -- emit the final carry only if it is nonzero
+    go ci (d : ds) =
+      let
+        x = (d `unsafeShiftL` i) .|. ci  -- ci: bits carried out of the previous (lower) digit
+        co = quotMaxD x  -- bits that overflow into the next digit
+        s = remMaxD x  -- bits that stay in this digit
+      in s : go co ds
+
+shiftRD :: [Digit] -> Int -> ([Digit], Bool)  -- shift right by i bits; the Bool says whether any 1-bit was lost
+shiftRD ds 0 = (ds, False)
+shiftRD ds i =
+  let (rs, ds') = splitAt 1 (shiftLD ds (shiftD - i))  -- left by shiftD - i, drop low digit: net right shift by i
+  in (ds', any (/= 0) rs)
+
+testBitI :: Integer -> Int -> Bool  -- bit i of the two's-complement representation
+testBitI (I Plus  ds) i =
+  case ds !? q of  -- q: digit index, r: bit within that digit
+    Just d -> testBit d r
+    Nothing -> False  -- above the top digit every bit is 0
+  where (q, r) = quotRem i shiftD
+testBitI (I Minus ds) i =
+  -- not (testBitI (complement (I Minus ds)) i), where complement of -m is m - 1
+  case ds !? q of
+    Just d ->
+      let d' = if all (== 0) (take q ds) then d - 1 else d  -- digit q of m - 1: borrowed from only if all lower digits are 0
+      in not (testBit d' r)  -- if d was 0, d - 1 wraps to all ones, so bit r is still right
+    Nothing -> True  -- m - 1 has no bits here, so its complement has 1s
+  where (q, r) = quotRem i shiftD
 
 ---------------------------------
 {-
--- a/lib/Data/Integer_Type.hs
+++ b/lib/Data/Integer_Type.hs
@@ -14,19 +14,22 @@
 type Digit = Word
 
 maxD :: Digit
-maxD =
+maxD = 1 `primWordShl` shiftD  -- the digit base, 2^shiftD
+
+shiftD :: Int  -- bits per Integer digit: half the machine word size
+shiftD =
   if _wordSize `primIntEQ` 64 then
-    (4294967296 :: Word) -- 2^32, this is used so multiplication of two digits doesn't overflow a 64 bit Word
+    (32 :: Int) -- this is used so multiplication of two digits doesn't overflow a 64 bit Word
   else if _wordSize `primIntEQ` 32 then
-    (65536 :: Word)      -- 2^16, this is used so multiplication of two digits doesn't overflow a 32 bit Word
+    (16 :: Int) -- this is used so multiplication of two digits doesn't overflow a 32 bit Word
   else
     error "Integer: unsupported word size"
 
 quotMaxD :: Digit -> Digit
-quotMaxD d = d `primWordQuot` maxD
+quotMaxD d = d `primWordShr` shiftD  -- same as quot by maxD, since maxD = 2^shiftD
 
 remMaxD :: Digit -> Digit
-remMaxD d = d `primWordRem` maxD
+remMaxD d = d `primWordAnd` (maxD `primWordSub` 1)  -- same as rem by maxD: keep the low shiftD bits
 
 -- Sadly, we also need a bunch of functions.
 
@@ -38,8 +41,8 @@
   where
     f sign i =
       let
-        high = i `primWordQuot` maxD
-        low = i `primWordRem` maxD
+        high = quotMaxD i
+        low = remMaxD i
       in if high `primWordEQ` 0 then I sign [low] else I sign [low, high]
 
 _integerToInt :: Integer -> Int
@@ -51,8 +54,8 @@
   | high `primWordEQ` 0 = I Plus [low]
   | True                = I Plus [low, high]
   where
-    high = i `primWordQuot` maxD
-    low = i `primWordRem` maxD
+    high = quotMaxD i
+    low = remMaxD i
 
 _integerToWord :: Integer -> Word
 _integerToWord (I sign ds) =
@@ -64,7 +67,7 @@
       case ds of
         []          -> 0 :: Word
         [d1]        -> d1
-        d1 : d2 : _ -> d1 `primWordAdd` (maxD `primWordMul` d2)
+        d1 : d2 : _ -> d1 `primWordAdd` (d2 `primWordShl` shiftD)
 
 _integerToFloatW :: Integer -> FloatW
 _integerToFloatW (I sign ds) = s `primFloatWMul` loop ds
--- a/lib/Data/Word.hs
+++ b/lib/Data/Word.hs
@@ -95,7 +95,9 @@
   unsafeShiftR = primWordShr
   bitSizeMaybe _ = Just _wordSize
   bitSize _ = _wordSize
-  bit n = primWordShl 1 n
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Word where
@@ -189,7 +191,9 @@
   unsafeShiftR = bini8 primWordShr
   bitSizeMaybe _ = Just 8
   bitSize _ = 8
-  bit n = w8 (primWordShl 1 n)
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Word8 where
@@ -282,7 +286,9 @@
   unsafeShiftR = bini16 primWordShr
   bitSizeMaybe _ = Just 16
   bitSize _ = 16
-  bit n = w16 (primWordShl 1 n)
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Word16 where
@@ -376,7 +382,9 @@
   unsafeShiftR = bini32 primWordShr
   bitSizeMaybe _ = Just 32
   bitSize _ = 32
-  bit n = w32 (primWordShl 1 n)
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Word32 where
@@ -470,7 +478,9 @@
   unsafeShiftR = bini64 primWordShr
   bitSizeMaybe _ = Just 64
   bitSize _ = 64
-  bit n = w64 (primWordShl 1 n)
+  bit = bitDefault  -- 1 `shiftL` i
+  testBit = testBitDefault  -- (x .&. bit i) /= 0
+  popCount = popCountDefault  -- repeatedly clears the lowest set bit
   zeroBits = 0
 
 instance FiniteBits Word64 where