ref: ca26282e9a7d2f8e8f1396172115b1efc4b75f62
parent: ac0c51a3279123d2a911e918cb87e8151a529c1a
author: Lennart Augustsson <lennart.augustsson@epicgames.com>
date: Sun Dec 17 13:53:43 EST 2023
Implement ST monad. Just STRef so far.
--- a/TODO
+++ b/TODO
@@ -23,7 +23,6 @@
- don't require kind signatures in forall
* Try Oleg's abstraction algorithm
- Seems to be slower
-* Use IORef for STRef
* Redo type synonym expansion
- Only non-injective synonyms necessitate expansion(?)
- Do expansion during unification
--- /dev/null
+++ b/lib/Control/Monad/ST.hs
@@ -1,0 +1,29 @@
+module Control.Monad.ST(
+  ST,
+  runST,
+  ) where
+import Prelude
+import Primitives(primPerformIO)
+import Control.Monad.ST_Type
+
+-- | Run an 'ST' computation and return its result.
+--
+-- 'ST' is a wrapper around 'IO', so the wrapped action is executed with
+-- 'primPerformIO'.  This is only sound because of the rank-2 type: @s@ is
+-- quantified inside the argument, so the result type @a@ cannot mention
+-- @s@ and nothing tagged with it (such as an @STRef s@) can escape.
+runST :: forall a . (forall s . ST s a) -> a
+runST (ST action) = primPerformIO action
+
+-- The instances unwrap the newtype, use the corresponding IO operation,
+-- and re-wrap the result.
+
+instance forall s . Functor (ST s) where
+  fmap f (ST io) = ST (f <$> io)
+
+instance forall s . Applicative (ST s) where
+  pure = ST . pure
+  ST iof <*> ST iox = ST (iof <*> iox)
+
+instance forall s . Monad (ST s) where
+  ST io >>= k = ST (io >>= \ a -> unST (k a))
--- /dev/null
+++ b/lib/Control/Monad/ST_Type.hs
@@ -1,0 +1,18 @@
+-- Internal module: it exports the 'ST' constructor, which lets any 'IO'
+-- action be wrapped as an 'ST' computation.  Only library modules (such as
+-- Control.Monad.ST and Data.STRef) should import it; user code should
+-- import Control.Monad.ST instead.
+module Control.Monad.ST_Type(
+  ST(..), unST,
+  ) where
+import Primitives(IO)
+
+-- | A state-thread computation.  The phantom parameter @s@ only ties
+-- values to a particular 'runST' call; at run time an 'ST' action is
+-- just an 'IO' action.
+newtype ST s a = ST (IO a)
+
+-- | Expose the underlying 'IO' action of an 'ST' computation.
+unST :: forall s a . ST s a -> IO a
+unST st = case st of
+  ST io -> io
--- a/lib/Data/IORef.hs
+++ b/lib/Data/IORef.hs
@@ -16,3 +16,11 @@
 writeIORef :: forall a . IORef a -> a -> IO ()
 writeIORef (R p) a = primArrWrite p 0 a
+
+-- | Replace the contents of an 'IORef' by applying a function to them.
+-- @f a@ is not forced here before being written, so repeated calls may
+-- build up unevaluated applications (unless primArrWrite forces it).
+modifyIORef :: forall a . IORef a -> (a -> a) -> IO ()
+modifyIORef (R arr) f =
+  primArrRead arr 0 `primBind` \ old ->
+  primArrWrite arr 0 (f old)
--- /dev/null
+++ b/lib/Data/STRef.hs
@@ -1,0 +1,29 @@
+-- | Mutable references in the 'ST' monad.
+-- An @STRef s a@ is a thin wrapper around an @IORef a@; the extra @s@
+-- parameter ties the reference to the state thread that created it.
+module Data.STRef(
+  STRef,
+  newSTRef, readSTRef, writeSTRef, modifySTRef,
+  ) where
+import Prelude
+import Control.Monad.ST_Type
+import Data.IORef
+
+-- The constructor is not exported, so references come only from 'newSTRef'.
+newtype STRef s a = R (IORef a)
+
+-- | Allocate a new reference holding the given value.
+newSTRef :: forall s a . a -> ST s (STRef s a)
+newSTRef x = ST (fmap R (newIORef x))
+
+-- | Read the current contents of a reference.
+readSTRef :: forall s a . STRef s a -> ST s a
+readSTRef (R ref) = ST (readIORef ref)
+
+-- | Overwrite the contents of a reference.
+writeSTRef :: forall s a . STRef s a -> a -> ST s ()
+writeSTRef (R ref) x = ST (writeIORef ref x)
+
+-- | Apply a function to the contents of a reference (via 'modifyIORef').
+modifySTRef :: forall s a . STRef s a -> (a -> a) -> ST s ()
+modifySTRef (R ref) f = ST (modifyIORef ref f)
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -48,6 +48,7 @@
$(TMHS) TypeLits && $(EVAL) > TypeLits.out && diff TypeLits.ref TypeLits.out
$(TMHS) View && $(EVAL) > View.out && diff View.ref View.out
$(TMHS) IOArray && $(EVAL) > IOArray.out && diff IOArray.ref IOArray.out
+ $(TMHS) ST && $(EVAL) > ST.out && diff ST.ref ST.out
errtest:
sh errtester.sh < errmsg.test
--- /dev/null
+++ b/tests/ST.hs
@@ -1,0 +1,29 @@
+module ST(main) where
+import Prelude
+import Control.Monad.ST
+import Data.STRef
+import Debug.Trace
+
+-- | Factorial of n (1 when n < 1), computed imperatively: a counter runs
+-- from 1 up to n and an accumulator is multiplied by each counter value.
+facST :: forall s . Int -> ST s Int
+facST n = do
+  counter <- newSTRef 1
+  acc <- newSTRef 1
+  let step = do
+        k <- readSTRef counter
+        if k <= n
+          then do
+            modifySTRef acc (k *)
+            writeSTRef counter (k + 1)
+            step
+          else return ()
+  step
+  readSTRef acc
+
+-- | Pure factorial obtained by running 'facST'.
+fac :: Int -> Int
+fac n = runST (facST n)
+
+main :: IO ()
+main = print (fac 10)
--- /dev/null
+++ b/tests/ST.ref
@@ -1,0 +1,1 @@
+3628800
--
⑨