shithub: MicroHs

Download patch

ref: 7d146bd75d98bea509acfad35ee5515423892e47
parent: 2ec34c1ab7655acce61166bd267290b590e20a1a
parent: 0022e67752ae28fa0b99302cdf7f6adc08b5cda7
author: Lennart Augustsson <lennart@augustsson.net>
date: Sun Jan 12 07:02:43 EST 2025

Merge pull request #89 from konsumlamm/shift

Fix UB in shifts & add missing `Bits` instances

--- a/lib/Data/Bits.hs
+++ b/lib/Data/Bits.hs
@@ -93,13 +93,24 @@
 
       w = finiteBitSize x
 
+_overflowError :: a  -- bottom value used by the checked shiftL/shiftR for a negative count
+_overflowError = error "arithmetic overflow"
+
 instance Bits Int where
   (.&.) = primIntAnd
   (.|.) = primIntOr
   xor   = primIntXor
   complement = primIntInv
-  shiftL = primIntShl
-  shiftR = primIntShr
+  x `shiftL` i  -- range-checked wrapper; out-of-range counts are UB in the raw primitive
+    | i < 0 = _overflowError
+    | i >= _wordSize = 0  -- every bit shifted out
+    | otherwise = x `primIntShl` i
+  unsafeShiftL = primIntShl  -- unchecked: caller must ensure 0 <= i < _wordSize
+  x `shiftR` i  -- sign-extending shift (assumes primIntShr is arithmetic)
+    | i < 0 = _overflowError
+    | i >= _wordSize = if x < 0 then -1 else 0  -- only sign copies remain: -1 or 0
+    | otherwise = x `primIntShr` i
+  unsafeShiftR = primIntShr  -- unchecked: caller must ensure 0 <= i < _wordSize
   bitSizeMaybe _ = Just _wordSize
   bitSize _ = _wordSize
   bit n = primIntShl 1 n
--- a/lib/Data/Int/Instances.hs
+++ b/lib/Data/Int/Instances.hs
@@ -12,6 +12,7 @@
 import Data.Integer_Type
 import Data.Integral
 import Data.List
+import Data.Maybe_Type
 import Data.Num
 import Data.Ord
 import Data.Ratio_Type
@@ -88,20 +89,28 @@
   (>)  = cmp8 primIntGT
   (>=) = cmp8 primIntGE
 
-{-
 instance Bits Int8 where
   (.&.) = bin8 primIntAnd
   (.|.) = bin8 primIntOr
   xor   = bin8 primIntXor
   complement = una8 primIntInv
-  shiftL = bini8 primIntShl
-  shiftR = bini8 primIntShr
+  x `shiftL` i  -- range-checked wrapper around unsafeShiftL
+    | i < 0 = _overflowError
+    | i >= 8 = 0  -- all 8 bits shifted out
+    | True = x `unsafeShiftL` i
+  unsafeShiftL = bini8 primIntShl  -- unchecked: caller must ensure 0 <= i < 8
+  x `shiftR` i  -- sign-extending shift (assumes primIntShr is arithmetic)
+    | i < 0 = _overflowError
+    | i >= 8 = if x < 0 then -1 else 0  -- only sign copies remain: -1 or 0
+    | True = x `unsafeShiftR` i
+  unsafeShiftR = bini8 primIntShr  -- unchecked: caller must ensure 0 <= i < 8
   bitSizeMaybe _ = Just 8
   bitSize _ = 8
   bit n = i8 (primIntShl 1 n)
   zeroBits = 0
--}
 
+instance FiniteBits Int8 where
+  finiteBitSize _ = 8  -- must agree with bitSize in the Bits Int8 instance
 
 --------------------------------------------------------------------------------
 ----    Int16
@@ -172,20 +181,28 @@
   (>)  = cmp16 primIntGT
   (>=) = cmp16 primIntGE
 
-{-
 instance Bits Int16 where
   (.&.) = bin16 primIntAnd
   (.|.) = bin16 primIntOr
   xor   = bin16 primIntXor
   complement = una16 primIntInv
-  shiftL = bini16 primIntShl
-  shiftR = bini16 primIntShr
+  x `shiftL` i  -- range-checked wrapper around unsafeShiftL
+    | i < 0 = _overflowError
+    | i >= 16 = 0  -- all 16 bits shifted out
+    | True = x `unsafeShiftL` i
+  unsafeShiftL = bini16 primIntShl  -- unchecked: caller must ensure 0 <= i < 16
+  x `shiftR` i  -- sign-extending shift (assumes primIntShr is arithmetic)
+    | i < 0 = _overflowError
+    | i >= 16 = if x < 0 then -1 else 0  -- only sign copies remain: -1 or 0
+    | True = x `unsafeShiftR` i
+  unsafeShiftR = bini16 primIntShr  -- unchecked: caller must ensure 0 <= i < 16
   bitSizeMaybe _ = Just 16
   bitSize _ = 16
   bit n = i16 (primIntShl 1 n)
   zeroBits = 0
--}
 
+instance FiniteBits Int16 where
+  finiteBitSize _ = 16  -- must agree with bitSize in the Bits Int16 instance
 
 --------------------------------------------------------------------------------
 ----    Int32
@@ -256,21 +273,30 @@
   (>)  = cmp32 primIntGT
   (>=) = cmp32 primIntGE
 
-{-
 instance Bits Int32 where
   (.&.) = bin32 primIntAnd
   (.|.) = bin32 primIntOr
   xor   = bin32 primIntXor
   complement = una32 primIntInv
-  shiftL = bini32 primIntShl
-  shiftR = bini32 primIntShr
+  x `shiftL` i  -- range-checked wrapper around unsafeShiftL
+    | i < 0 = _overflowError
+    | i >= 32 = 0  -- all 32 bits shifted out
+    | True = x `unsafeShiftL` i
+  unsafeShiftL = bini32 primIntShl  -- unchecked: caller must ensure 0 <= i < 32
+  x `shiftR` i  -- sign-extending shift (assumes primIntShr is arithmetic)
+    | i < 0 = _overflowError
+    | i >= 32 = if x < 0 then -1 else 0  -- only sign copies remain: -1 or 0
+    | True = x `unsafeShiftR` i
+  unsafeShiftR = bini32 primIntShr  -- unchecked: caller must ensure 0 <= i < 32
   bitSizeMaybe _ = Just 32
   bitSize _ = 32
   bit n = i32 (primIntShl 1 n)
   zeroBits = 0
--}
 
---------------------------------------------------------------------------------
+instance FiniteBits Int32 where
+  finiteBitSize _ = 32  -- must agree with bitSize in the Bits Int32 instance
+
+--------------------------------------------------------------------------------
 ----    Int64
 
 -- Do sign extension by shifting.
@@ -338,19 +364,25 @@
   (>)  = cmp64 primIntGT
   (>=) = cmp64 primIntGE
 
-{-
 instance Bits Int64 where
   (.&.) = bin64 primIntAnd
   (.|.) = bin64 primIntOr
   xor   = bin64 primIntXor
   complement = una64 primIntInv
-  shiftL = bini64 primIntShl
-  shiftR = bini64 primIntShr
+  x `shiftL` i  -- range-checked wrapper around unsafeShiftL
+    | i < 0 = _overflowError
+    | i >= 64 = 0  -- all 64 bits shifted out
+    | True = x `unsafeShiftL` i
+  unsafeShiftL = bini64 primIntShl  -- unchecked: caller must ensure 0 <= i < 64
+  x `shiftR` i  -- sign-extending shift (assumes primIntShr is arithmetic)
+    | i < 0 = _overflowError
+    | i >= 64 = if x < 0 then -1 else 0  -- only sign copies remain: -1 or 0
+    | True = x `unsafeShiftR` i
+  unsafeShiftR = bini64 primIntShr  -- unchecked: caller must ensure 0 <= i < 64
   bitSizeMaybe _ = Just 64
   bitSize _ = 64
   bit n = i64 (primIntShl 1 n)
   zeroBits = 0
--}
 
-
-
+instance FiniteBits Int64 where
+  finiteBitSize _ = 64  -- must agree with bitSize in the Bits Int64 instance
--- a/lib/Data/Word.hs
+++ b/lib/Data/Word.hs
@@ -83,14 +83,21 @@
   (.|.) = primWordOr
   xor   = primWordXor
   complement = primWordInv
-  shiftL = primWordShl
-  shiftR = primWordShr
+  x `shiftL` i  -- range-checked wrapper; out-of-range counts are UB in the raw primitive
+    | i < 0 = _overflowError
+    | i >= _wordSize = 0  -- every bit shifted out
+    | True = x `primWordShl` i
+  unsafeShiftL = primWordShl  -- unchecked: caller must ensure 0 <= i < _wordSize
+  x `shiftR` i  -- unsigned right shift
+    | i < 0 = _overflowError
+    | i >= _wordSize = 0  -- unsigned: every bit shifted out, result is 0
+    | True = x `primWordShr` i
+  unsafeShiftR = primWordShr  -- unchecked: caller must ensure 0 <= i < _wordSize
   bitSizeMaybe _ = Just _wordSize
   bitSize _ = _wordSize
   bit n = primWordShl 1 n
   zeroBits = 0
 
-
 instance FiniteBits Word where
   finiteBitSize _ = _wordSize
 
@@ -170,8 +177,16 @@
   (.|.) = bin8 primWordOr
   xor   = bin8 primWordXor
   complement = una8 primWordInv
-  shiftL = bini8 primWordShl
-  shiftR = bini8 primWordShr
+  x `shiftL` i  -- range-checked wrapper around unsafeShiftL
+    | i < 0 = _overflowError
+    | i >= 8 = 0  -- all 8 bits shifted out
+    | True = x `unsafeShiftL` i
+  unsafeShiftL = bini8 primWordShl  -- unchecked: caller must ensure 0 <= i < 8
+  x `shiftR` i  -- unsigned right shift
+    | i < 0 = _overflowError
+    | i >= 8 = 0  -- unsigned: all bits shifted out, result is 0
+    | True = x `unsafeShiftR` i
+  unsafeShiftR = bini8 primWordShr  -- unchecked: caller must ensure 0 <= i < 8
   bitSizeMaybe _ = Just 8
   bitSize _ = 8
   bit n = w8 (primWordShl 1 n)
@@ -255,8 +270,16 @@
   (.|.) = bin16 primWordOr
   xor   = bin16 primWordXor
   complement = una16 primWordInv
-  shiftL = bini16 primWordShl
-  shiftR = bini16 primWordShr
+  x `shiftL` i  -- range-checked wrapper around unsafeShiftL
+    | i < 0 = _overflowError
+    | i >= 16 = 0  -- all 16 bits shifted out
+    | True = x `unsafeShiftL` i
+  unsafeShiftL = bini16 primWordShl  -- unchecked: caller must ensure 0 <= i < 16
+  x `shiftR` i  -- unsigned right shift
+    | i < 0 = _overflowError
+    | i >= 16 = 0  -- unsigned: all bits shifted out, result is 0
+    | True = x `unsafeShiftR` i
+  unsafeShiftR = bini16 primWordShr  -- unchecked: caller must ensure 0 <= i < 16
   bitSizeMaybe _ = Just 16
   bitSize _ = 16
   bit n = w16 (primWordShl 1 n)
@@ -341,8 +364,16 @@
   (.|.) = bin32 primWordOr
   xor   = bin32 primWordXor
   complement = una32 primWordInv
-  shiftL = bini32 primWordShl
-  shiftR = bini32 primWordShr
+  x `shiftL` i  -- range-checked wrapper around unsafeShiftL
+    | i < 0 = _overflowError
+    | i >= 32 = 0  -- all 32 bits shifted out
+    | True = x `unsafeShiftL` i
+  unsafeShiftL = bini32 primWordShl  -- unchecked: caller must ensure 0 <= i < 32
+  x `shiftR` i  -- unsigned right shift
+    | i < 0 = _overflowError
+    | i >= 32 = 0  -- unsigned: all bits shifted out, result is 0
+    | True = x `unsafeShiftR` i
+  unsafeShiftR = bini32 primWordShr  -- unchecked: caller must ensure 0 <= i < 32
   bitSizeMaybe _ = Just 32
   bitSize _ = 32
   bit n = w32 (primWordShl 1 n)
@@ -427,8 +458,16 @@
   (.|.) = bin64 primWordOr
   xor   = bin64 primWordXor
   complement = una64 primWordInv
-  shiftL = bini64 primWordShl
-  shiftR = bini64 primWordShr
+  x `shiftL` i  -- range-checked wrapper around unsafeShiftL
+    | i < 0 = _overflowError
+    | i >= 64 = 0  -- all 64 bits shifted out
+    | True = x `unsafeShiftL` i
+  unsafeShiftL = bini64 primWordShl  -- unchecked: caller must ensure 0 <= i < 64
+  x `shiftR` i  -- unsigned right shift
+    | i < 0 = _overflowError
+    | i >= 64 = 0  -- unsigned: all bits shifted out, result is 0
+    | True = x `unsafeShiftR` i
+  unsafeShiftR = bini64 primWordShr  -- unchecked: caller must ensure 0 <= i < 64
   bitSizeMaybe _ = Just 64
   bitSize _ = 64
   bit n = w64 (primWordShl 1 n)