ref: 7d146bd75d98bea509acfad35ee5515423892e47
parent: 2ec34c1ab7655acce61166bd267290b590e20a1a
parent: 0022e67752ae28fa0b99302cdf7f6adc08b5cda7
author: Lennart Augustsson <lennart@augustsson.net>
date: Sun Jan 12 07:02:43 EST 2025
Merge pull request #89 from konsumlamm/shift Fix UB in shifts & add missing `Bits` instances
--- a/lib/Data/Bits.hs
+++ b/lib/Data/Bits.hs
@@ -93,13 +93,24 @@
w = finiteBitSize x
+_overflowError :: a -- raised for a negative shift count
+_overflowError = error "arithmetic overflow"
+
instance Bits Int where
(.&.) = primIntAnd
(.|.) = primIntOr
xor = primIntXor
complement = primIntInv
- shiftL = primIntShl
- shiftR = primIntShr
+ x `shiftL` i
+ | i < 0 = _overflowError
+ | i >= _wordSize = 0
+ | otherwise = x `primIntShl` i
+ unsafeShiftL = primIntShl
+ x `shiftR` i
+ | i < 0 = _overflowError
+ | i >= _wordSize = if x < 0 then -1 else 0 -- arithmetic shift: all bits become the sign bit
+ | otherwise = x `primIntShr` i
+ unsafeShiftR = primIntShr
bitSizeMaybe _ = Just _wordSize
bitSize _ = _wordSize
bit n = primIntShl 1 n
--- a/lib/Data/Int/Instances.hs
+++ b/lib/Data/Int/Instances.hs
@@ -12,6 +12,7 @@
import Data.Integer_Type
import Data.Integral
import Data.List
+import Data.Maybe_Type
import Data.Num
import Data.Ord
import Data.Ratio_Type
@@ -88,20 +89,28 @@
(>) = cmp8 primIntGT
(>=) = cmp8 primIntGE
-{-
instance Bits Int8 where
(.&.) = bin8 primIntAnd
(.|.) = bin8 primIntOr
xor = bin8 primIntXor
complement = una8 primIntInv
- shiftL = bini8 primIntShl
- shiftR = bini8 primIntShr
+ x `shiftL` i
+ | i < 0 = _overflowError
+ | i >= 8 = 0
+ | True = x `unsafeShiftL` i
+ unsafeShiftL = bini8 primIntShl
+ x `shiftR` i
+ | i < 0 = _overflowError
+ | i >= 8 = if x < 0 then -1 else 0 -- arithmetic shift: all bits become the sign bit
+ | True = x `unsafeShiftR` i
+ unsafeShiftR = bini8 primIntShr
bitSizeMaybe _ = Just 8
bitSize _ = 8
bit n = i8 (primIntShl 1 n)
zeroBits = 0
--}
+instance FiniteBits Int8 where
+ finiteBitSize _ = 8
--------------------------------------------------------------------------------
---- Int16
@@ -172,20 +181,28 @@
(>) = cmp16 primIntGT
(>=) = cmp16 primIntGE
-{-
instance Bits Int16 where
(.&.) = bin16 primIntAnd
(.|.) = bin16 primIntOr
xor = bin16 primIntXor
complement = una16 primIntInv
- shiftL = bini16 primIntShl
- shiftR = bini16 primIntShr
+ x `shiftL` i
+ | i < 0 = _overflowError
+ | i >= 16 = 0
+ | True = x `unsafeShiftL` i
+ unsafeShiftL = bini16 primIntShl
+ x `shiftR` i
+ | i < 0 = _overflowError
+ | i >= 16 = if x < 0 then -1 else 0 -- arithmetic shift: all bits become the sign bit
+ | True = x `unsafeShiftR` i
+ unsafeShiftR = bini16 primIntShr
bitSizeMaybe _ = Just 16
bitSize _ = 16
bit n = i16 (primIntShl 1 n)
zeroBits = 0
--}
+instance FiniteBits Int16 where
+ finiteBitSize _ = 16
--------------------------------------------------------------------------------
---- Int32
@@ -256,21 +273,30 @@
(>) = cmp32 primIntGT
(>=) = cmp32 primIntGE
-{-
instance Bits Int32 where
(.&.) = bin32 primIntAnd
(.|.) = bin32 primIntOr
xor = bin32 primIntXor
complement = una32 primIntInv
- shiftL = bini32 primIntShl
- shiftR = bini32 primIntShr
+ x `shiftL` i
+ | i < 0 = _overflowError
+ | i >= 32 = 0
+ | True = x `unsafeShiftL` i
+ unsafeShiftL = bini32 primIntShl
+ x `shiftR` i
+ | i < 0 = _overflowError
+ | i >= 32 = if x < 0 then -1 else 0 -- arithmetic shift: all bits become the sign bit
+ | True = x `unsafeShiftR` i
+ unsafeShiftR = bini32 primIntShr
bitSizeMaybe _ = Just 32
bitSize _ = 32
bit n = i32 (primIntShl 1 n)
zeroBits = 0
--}
---------------------------------------------------------------------------------
+instance FiniteBits Int32 where
+ finiteBitSize _ = 32
+
+--------------------------------------------------------------------------------
---- Int64
-- Do sign extension by shifting.
@@ -338,19 +364,25 @@
(>) = cmp64 primIntGT
(>=) = cmp64 primIntGE
-{-
instance Bits Int64 where
(.&.) = bin64 primIntAnd
(.|.) = bin64 primIntOr
xor = bin64 primIntXor
complement = una64 primIntInv
- shiftL = bini64 primIntShl
- shiftR = bini64 primIntShr
+ x `shiftL` i
+ | i < 0 = _overflowError
+ | i >= 64 = 0
+ | True = x `unsafeShiftL` i
+ unsafeShiftL = bini64 primIntShl
+ x `shiftR` i
+ | i < 0 = _overflowError
+ | i >= 64 = if x < 0 then -1 else 0 -- arithmetic shift: all bits become the sign bit
+ | True = x `unsafeShiftR` i
+ unsafeShiftR = bini64 primIntShr
bitSizeMaybe _ = Just 64
bitSize _ = 64
bit n = i64 (primIntShl 1 n)
zeroBits = 0
--}
-
-
+instance FiniteBits Int64 where
+ finiteBitSize _ = 64
--- a/lib/Data/Word.hs
+++ b/lib/Data/Word.hs
@@ -83,14 +83,21 @@
(.|.) = primWordOr
xor = primWordXor
complement = primWordInv
- shiftL = primWordShl
- shiftR = primWordShr
+ shiftL x i
+ | i < 0 = _overflowError
+ | i >= _wordSize = 0 -- every bit shifted out
+ | True = primWordShl x i
+ unsafeShiftL = primWordShl -- no range check
+ shiftR x i
+ | i < 0 = _overflowError
+ | i >= _wordSize = 0 -- every bit shifted out
+ | True = primWordShr x i
+ unsafeShiftR = primWordShr -- no range check
bitSizeMaybe _ = Just _wordSize
bitSize _ = _wordSize
bit n = primWordShl 1 n
zeroBits = 0
-
instance FiniteBits Word where
finiteBitSize _ = _wordSize
@@ -170,8 +177,16 @@
(.|.) = bin8 primWordOr
xor = bin8 primWordXor
complement = una8 primWordInv
- shiftL = bini8 primWordShl
- shiftR = bini8 primWordShr
+ shiftL x i
+ | i < 0 = _overflowError
+ | i >= 8 = 0 -- every bit shifted out
+ | True = unsafeShiftL x i
+ unsafeShiftL = bini8 primWordShl -- no range check
+ shiftR x i
+ | i < 0 = _overflowError
+ | i >= 8 = 0 -- every bit shifted out
+ | True = unsafeShiftR x i
+ unsafeShiftR = bini8 primWordShr -- no range check
bitSizeMaybe _ = Just 8
bitSize _ = 8
bit n = w8 (primWordShl 1 n)
@@ -255,8 +270,16 @@
(.|.) = bin16 primWordOr
xor = bin16 primWordXor
complement = una16 primWordInv
- shiftL = bini16 primWordShl
- shiftR = bini16 primWordShr
+ shiftL x i
+ | i < 0 = _overflowError
+ | i >= 16 = 0 -- every bit shifted out
+ | True = unsafeShiftL x i
+ unsafeShiftL = bini16 primWordShl -- no range check
+ shiftR x i
+ | i < 0 = _overflowError
+ | i >= 16 = 0 -- every bit shifted out
+ | True = unsafeShiftR x i
+ unsafeShiftR = bini16 primWordShr -- no range check
bitSizeMaybe _ = Just 16
bitSize _ = 16
bit n = w16 (primWordShl 1 n)
@@ -341,8 +364,16 @@
(.|.) = bin32 primWordOr
xor = bin32 primWordXor
complement = una32 primWordInv
- shiftL = bini32 primWordShl
- shiftR = bini32 primWordShr
+ shiftL x i
+ | i < 0 = _overflowError
+ | i >= 32 = 0 -- every bit shifted out
+ | True = unsafeShiftL x i
+ unsafeShiftL = bini32 primWordShl -- no range check
+ shiftR x i
+ | i < 0 = _overflowError
+ | i >= 32 = 0 -- every bit shifted out
+ | True = unsafeShiftR x i
+ unsafeShiftR = bini32 primWordShr -- no range check
bitSizeMaybe _ = Just 32
bitSize _ = 32
bit n = w32 (primWordShl 1 n)
@@ -427,8 +458,16 @@
(.|.) = bin64 primWordOr
xor = bin64 primWordXor
complement = una64 primWordInv
- shiftL = bini64 primWordShl
- shiftR = bini64 primWordShr
+ shiftL x i
+ | i < 0 = _overflowError
+ | i >= 64 = 0 -- every bit shifted out
+ | True = unsafeShiftL x i
+ unsafeShiftL = bini64 primWordShl -- no range check
+ shiftR x i
+ | i < 0 = _overflowError
+ | i >= 64 = 0 -- every bit shifted out
+ | True = unsafeShiftR x i
+ unsafeShiftR = bini64 primWordShr -- no range check
bitSizeMaybe _ = Just 64
bitSize _ = 64
bit n = w64 (primWordShl 1 n)