shithub: MicroHs

Download patch

ref: ca26282e9a7d2f8e8f1396172115b1efc4b75f62
parent: ac0c51a3279123d2a911e918cb87e8151a529c1a
author: Lennart Augustsson <lennart.augustsson@epicgames.com>
date: Sun Dec 17 13:53:43 EST 2023

Implement ST monad.  Just STRef so far.

--- a/TODO
+++ b/TODO
@@ -23,7 +23,6 @@
   - don't require kind signatures in forall
 * Try Oleg's abstraction algorithm
   - Seems to be slower
-* Use IORef for STRef
 * Redo type synonym expansion
   - Only non-injective synonyms necessitate expansion(?)
   - Do expansion during unification
--- /dev/null
+++ b/lib/Control/Monad/ST.hs
@@ -1,0 +1,20 @@
+module Control.Monad.ST(
+  ST,
+  runST,
+  ) where
+import Prelude
+import Primitives(primPerformIO)
+import Control.Monad.ST_Type
+
+runST :: forall a . (forall s . ST s a) -> a
+runST (ST ioa) = primPerformIO ioa
+
+instance forall s . Functor (ST s) where
+  fmap f (ST x) = ST (fmap f x)
+
+instance forall s . Applicative (ST s) where
+  pure x = ST (pure x)
+  ST x <*> ST y = ST (x <*> y)
+
+instance forall s . Monad (ST s) where
+  ST x >>= f = ST (x >>= (unST . f))
--- /dev/null
+++ b/lib/Control/Monad/ST_Type.hs
@@ -1,0 +1,10 @@
+-- This module should not be imported!
+module Control.Monad.ST_Type(
+  ST(..), unST,
+  ) where
+import Primitives(IO)
+
+-- The ST monad is implemented with the IO monad.
+newtype ST s a = ST (IO a)
+unST :: forall s a . ST s a -> IO a
+unST (ST io) = io
--- a/lib/Data/IORef.hs
+++ b/lib/Data/IORef.hs
@@ -16,3 +16,6 @@
 
 writeIORef :: forall a . IORef a -> a -> IO ()
 writeIORef (R p) a = primArrWrite p 0 a
+
+modifyIORef :: forall a . IORef a -> (a -> a) -> IO ()
+modifyIORef (R p) f = primArrRead p 0 `primBind` \ a -> primArrWrite p 0 (f a)
--- /dev/null
+++ b/lib/Data/STRef.hs
@@ -1,0 +1,21 @@
+module Data.STRef(
+  STRef,
+  newSTRef, readSTRef, writeSTRef, modifySTRef,
+  ) where
+import Prelude
+import Control.Monad.ST_Type
+import Data.IORef
+
+newtype STRef s a = R (IORef a)
+
+newSTRef :: forall s a . a -> ST s (STRef s a)
+newSTRef a = ST (R <$> newIORef a)
+
+readSTRef :: forall s a . STRef s a -> ST s a
+readSTRef (R p) = ST (readIORef p)
+
+writeSTRef :: forall s a . STRef s a -> a -> ST s ()
+writeSTRef (R p) a = ST (writeIORef p a)
+
+modifySTRef :: forall s a . STRef s a -> (a -> a) -> ST s ()
+modifySTRef (R p) f = ST (modifyIORef p f)
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -48,6 +48,7 @@
 	$(TMHS) TypeLits   && $(EVAL) > TypeLits.out   && diff TypeLits.ref TypeLits.out
 	$(TMHS) View       && $(EVAL) > View.out       && diff View.ref View.out
 	$(TMHS) IOArray    && $(EVAL) > IOArray.out    && diff IOArray.ref IOArray.out
+	$(TMHS) ST         && $(EVAL) > ST.out         && diff ST.ref ST.out
 
 errtest:
 	sh errtester.sh < errmsg.test
--- /dev/null
+++ b/tests/ST.hs
@@ -1,0 +1,27 @@
+module ST(main) where
+import Prelude
+import Control.Monad.ST
+import Data.STRef
+import Debug.Trace
+
+facST :: forall s . Int -> ST s Int
+facST n = do
+  ri <- newSTRef 1
+  rr <- newSTRef 1
+  let loop = do
+        i <- readSTRef ri
+        if i > n then
+          return ()
+         else do
+          writeSTRef ri (i + 1)
+          modifySTRef rr (i *)
+          loop
+  loop
+  readSTRef rr
+  
+fac :: Int -> Int
+fac n = runST (facST n)
+
+main :: IO ()
+main = do
+  print (fac 10)
--- /dev/null
+++ b/tests/ST.ref
@@ -1,0 +1,1 @@
+3628800
--