shithub: MicroHs

Download patch

ref: 0117f1320868055b945765fde57bf2462f38e9ff
parent: 9277a221708b1546518f8ee467e38c0ef7239d77
author: Lennart Augustsson <lennart@augustsson.net>
date: Tue Apr 9 06:51:47 EDT 2024

Add Data.Array

--- /dev/null
+++ b/lib/Data/Array.hs
@@ -1,0 +1,170 @@
+module Data.Array (
+    module Data.Ix,
+    Array,
+    array,
+    listArray,
+    accumArray,
+    (!),
+    bounds,
+    indices,
+    elems,
+    assocs,
+    (//),
+    accum,
+    ixmap,
+  ) where
+import Primitives(primPerformIO, primArrCopy, primArrEQ)
+import Data.Ix
+import Data.IOArray
+import Text.Show
+
+-- | An immutable boxed array with indices of type @i@ and elements of
+-- type @a@.  The elements live in an 'IOArray' that is filled in when the
+-- array is built and not written afterwards; '(//)' and 'accum' update a
+-- fresh copy instead.
+data Array i a
+   = Array (i,i)       -- bounds
+           !Int        -- = (rangeSize (l,u))
+           (IOArray a) -- elements
+
+-- | Apply a function to every element; the bounds are unchanged.
+instance Ix a => Functor (Array a) where
+  fmap f a@(Array b _ _) = array b [(i, f (a ! i)) | i <- range b]
+
+-- | Arrays are equal when they have the same index/element pairs, as in
+-- the Haskell report.  This compares contents: two separately built
+-- arrays with equal elements are equal even though their underlying
+-- 'IOArray's are different objects.
+instance (Ix a, Eq b)  => Eq (Array a b) where
+  (==) a1 a2 = assocs a1 == assocs a2
+
+-- | Lexicographic ordering on the index/element pairs, as in the Haskell
+-- report.
+instance (Ix a, Ord b) => Ord  (Array a b) where
+  compare a1 a2 = compare (assocs a1) (assocs a2)
+
+-- | Shows an array as @array bounds assocs@, e.g.
+-- @array (1,2) [(1,'a'),(2,'b')]@.
+instance (Ix a, Show a, Show b) => Show (Array a b) where
+  showsPrec p a =
+    showParen (p > appPrec) $
+    showString "array " .
+    showsPrec appPrec1 (bounds a) .
+    showChar ' ' .
+    showsPrec appPrec1 (assocs a)
+
+--instance (Ix a, Read a, Read b) => Read (Array a b) where
+--  readsPrec = undefined
+
+-- | Build an array from bounds and index/element pairs.  An index outside
+-- the bounds is an error, an index with no pair is left undefined, and if
+-- an index occurs more than once the last pair wins.
+array :: (Ix a) => (a,a) -> [(a,b)] -> Array a b
+array b ies =
+  let n = safeRangeSize b
+  in  unsafeArray' b n [(safeIndex b n i, e) | (i, e) <- ies]
+
+-- | Build an array from bounds and elements given in index order.  As in
+-- the Haskell report, surplus elements are ignored and a short list leaves
+-- the remaining elements undefined.  The list is zipped against the @n@
+-- slots instead of measured with 'length', so an infinite list is fine.
+listArray  :: (Ix a) => (a,a) -> [b] -> Array a b
+listArray b es =
+  let n = safeRangeSize b
+  in  unsafeArray' b n (zip [0 .. n - 1] es)
+
+-- | Build an array by combining each index/value pair into the element at
+-- that index with @f@; every element starts out as @z@.
+accumArray :: (Ix a) => (b -> c -> b) -> b -> (a,a) -> [(a,c)] -> Array a b
+accumArray f z b = accum f (array b [(i, z) | i <- range b])
+
+-- | Index an array; an index outside the bounds is an error.
+(!) :: (Ix a) => Array a b -> a -> b
+(!) (Array b n a) i = primPerformIO $ readIOArray a (safeIndex b n i)
+
+-- | The bounds the array was built with.
+bounds :: (Ix a) => Array a b -> (a,a)
+bounds (Array b _ _) = b
+
+-- | All indices, in the order given by 'range'.
+indices :: (Ix a) => Array a b -> [a]
+indices (Array b _ _) = range b
+
+-- | All elements, in index order.
+elems :: (Ix a) => Array a b -> [b]
+elems (Array _ _ a) = primPerformIO $ elemsIOArray a
+
+-- | All index/element pairs, in index order.
+assocs :: (Ix a) => Array a b -> [(a,b)]
+assocs a = zip (indices a) (elems a)
+
+-- | A copy of the array with the given elements replaced; the original is
+-- left unchanged.  If an index occurs more than once the last pair wins.
+(//) :: (Ix a) => Array a b -> [(a,b)] -> Array a b
+(//) (Array b n oa) ies = primPerformIO $ do
+  a <- primArrCopy oa
+  let adj (i, e) = writeIOArray a (safeIndex b n i) e
+  mapM_ adj ies
+  return $ Array b n a
+
+-- | Like 'accumArray', but starting from the elements of an existing array
+-- (which is copied, not modified).
+accum :: (Ix a) => (b -> c -> b) -> Array a b -> [(a,c)] -> Array a b
+accum f arr@(Array b n _) ies = unsafeAccum f arr [(safeIndex b n i, e) | (i, e) <- ies]
+
+-- | @ixmap b f a@ has bounds @b@, and its element at @i@ is @a ! f i@.
+ixmap :: (Ix a, Ix b) => (a,a) -> (a -> b) -> Array b c -> Array a c
+ixmap b f a = array b [(i, a ! f i) | i <- range b]
+
+-------
+-- Internal helpers.  The "unsafe" functions take 0-based offsets that the
+-- caller has already checked with 'safeIndex', not array indices.
+
+-- | Copy the array, then combine each (offset, value) pair into it with
+-- @f@.  Each new element is forced with 'seq' before it is stored so that
+-- repeated updates do not pile up unevaluated thunks.
+unsafeAccum :: (e -> a -> e) -> Array i e -> [(Int, a)] -> Array i e
+unsafeAccum f (Array b n oa) ies = primPerformIO $ do
+  a <- primArrCopy oa
+  let adj (i, e) = do
+        x <- readIOArray a i
+        let x' = f x e
+        seq x' (writeIOArray a i x')
+  mapM_ adj ies
+  return $ Array b n a
+
+-- | Allocate an @n@-element array whose slots all start as 'arrEleBottom',
+-- then store each (offset, value) pair in order.
+unsafeArray' :: (i,i) -> Int -> [(Int, e)] -> Array i e
+unsafeArray' b n ies = primPerformIO $ do
+  a <- newIOArray n arrEleBottom
+  mapM_ (\ (i, e) -> writeIOArray a i e) ies
+  return $ Array b n a
+
+-- | Placeholder for array slots that were never given a value.
+arrEleBottom :: a
+arrEleBottom = error "(Array.!): undefined array element"
+
+-- | Convert an index to a 0-based offset, failing unless it is in
+-- @[0, n)@.  This re-checks the result of 'index' rather than relying on
+-- every 'Ix' instance to reject out-of-range indices itself.
+safeIndex :: Ix i => (i, i) -> Int -> i -> Int
+safeIndex (l,u) n i | 0 <= i' && i' < n = i'
+                    | otherwise         = badSafeIndex i' n
+  where i' = index (l,u) i
+
+-- | The error raised by 'safeIndex' for an offset outside @[0, n)@.
+badSafeIndex :: Int -> Int -> a
+badSafeIndex i n = error $ "Error in array index; " ++ show i ++ " not in range [0.." ++ show n ++ ")"
+
+-- | 'rangeSize' of the bounds, with an error if it is negative.
+safeRangeSize :: Ix i => (i, i) -> Int
+safeRangeSize b =
+  let r = rangeSize b
+  in  if r < 0 then error "Negative range size" else r
+
+-- | Read every element of an 'IOArray', in offset order.
+elemsIOArray :: forall a . IOArray a -> IO [a]
+elemsIOArray a = do
+  s <- sizeIOArray a
+  mapM (readIOArray a) [0::Int .. s - 1]
--- a/lib/Primitives.hs
+++ b/lib/Primitives.hs
@@ -254,6 +254,9 @@
 primArrAlloc :: forall a . Int -> a -> IO (IOArray a)
 primArrAlloc = primitive "A.alloc"
 
+primArrCopy :: forall a . IOArray a -> IO (IOArray a)  -- new array of the same size, sharing the original's elements
+primArrCopy = primitive "A.copy"
+
 primArrSize :: forall a . IOArray a -> IO Int
 primArrSize = primitive "A.size"
 
@@ -263,6 +266,7 @@
 primArrWrite :: forall a . IOArray a -> Int -> a -> IO ()
 primArrWrite = primitive "A.write"
 
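+-- Not referentially transparent: presumably tests whether both arguments are the same array object, not whether their contents are equal.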
 primArrEQ :: forall a . IOArray a -> IOArray a -> Bool
 primArrEQ = primitive "A.=="
 
--- a/src/runtime/eval.c
+++ b/src/runtime/eval.c
@@ -177,7 +177,7 @@
                 T_FADD, T_FSUB, T_FMUL, T_FDIV, T_FNEG, T_ITOF,
                 T_FEQ, T_FNE, T_FLT, T_FLE, T_FGT, T_FGE, T_FSHOW, T_FREAD,
 #endif
-                T_ARR_ALLOC, T_ARR_SIZE, T_ARR_READ, T_ARR_WRITE, T_ARR_EQ,
+                T_ARR_ALLOC, T_ARR_COPY, T_ARR_SIZE, T_ARR_READ, T_ARR_WRITE, T_ARR_EQ,
                 T_RAISE, T_SEQ, T_EQUAL, T_COMPARE, T_RNF,
                 T_TICK,
                 T_IO_BIND, T_IO_THEN, T_IO_RETURN,
@@ -208,7 +208,7 @@
   "FADD", "FSUB", "FMUL", "FDIV", "FNEG", "ITOF",
   "FEQ", "FNE", "FLT", "FLE", "FGT", "FGE", "FSHOW", "FREAD",
 #endif
-  "ARR_ALLOC", "ARR_SIZE", "ARR_READ", "ARR_WRITE", "ARR_EQ",
+  "ARR_ALLOC", "ARR_COPY", "ARR_SIZE", "ARR_READ", "ARR_WRITE", "ARR_EQ",
   "RAISE", "SEQ", "EQUAL", "COMPARE", "RNF",
   "TICK",
   "IO_BIND", "IO_THEN", "IO_RETURN",
@@ -403,6 +403,24 @@
   return arr;
 }
 
+struct ioarray*
+arr_copy(struct ioarray *oarr)   /* new ioarray with the same size and element pointers as oarr */
+{
+  size_t sz = oarr->size;
+  struct ioarray *arr = MALLOC(sizeof(struct ioarray) + (sz-1) * sizeof(NODEPTR)); /* sz-1: presumably struct ioarray already has room for one slot */
+
+  if (!arr)
+    memerr();
+  arr->next = array_root;        /* push onto the array_root list */
+  array_root = arr;
+  arr->marked = 0;
+  arr->permanent = 0;
+  arr->size = sz;
+  memcpy(arr->array, oarr->array, sz * sizeof(NODEPTR)); /* copies the node pointers; the elements themselves are shared */
+  num_arr_alloc++;
+  return arr;
+}
+
 /*****************************************************************************/
 
 #if WANT_TICK
@@ -681,6 +699,7 @@
   { "raise", T_RAISE },
   { "catch", T_CATCH },
   { "A.alloc", T_ARR_ALLOC },
+  { "A.copy", T_ARR_COPY },
   { "A.size", T_ARR_SIZE },
   { "A.read", T_ARR_READ },
   { "A.write", T_ARR_WRITE },
@@ -2024,6 +2043,7 @@
   case T_RAISE: putsb("raise", f); break;
   case T_CATCH: putsb("catch", f); break;
   case T_ARR_ALLOC: putsb("A.alloc", f); break;
+  case T_ARR_COPY: putsb("A.copy", f); break;
   case T_ARR_SIZE: putsb("A.size", f); break;
   case T_ARR_READ: putsb("A.read", f); break;
   case T_ARR_WRITE: putsb("A.write", f); break;
@@ -2864,6 +2884,7 @@
   case T_PEEKCASTRING:
   case T_PEEKCASTRINGLEN:
   case T_ARR_ALLOC:
+  case T_ARR_COPY:
   case T_ARR_SIZE:
   case T_ARR_READ:
   case T_ARR_WRITE:
@@ -3304,6 +3325,19 @@
       size = evalint(ARG(TOP(1)));
       elem = ARG(TOP(2));
       arr = arr_alloc(size, elem);
+      n = alloc_node(T_ARR);
+      ARR(n) = arr;
+      RETIO(n);
+      }
+    case T_ARR_COPY:
+      {
+      struct ioarray *arr;
+      CHECKIO(1);
+      GCCHECK(1);
+      n = evali(ARG(TOP(1)));
+      if (GETTAG(n) != T_ARR)
+        ERR("T_ARR_COPY tag");
+      arr = arr_copy(ARR(n));
       n = alloc_node(T_ARR);
       ARR(n) = arr;
       RETIO(n);
--- /dev/null
+++ b/tests/Array.hs
@@ -1,0 +1,22 @@
+module Array where
+import Data.Array
+
+-- | Exercise the Data.Array API; the output is diffed against Array.ref.
+main :: IO ()
+main = do
+  let a = array (1::Int,3) [(1,'A'),(2,'b'),(3,'3')]
+  let b = array (1,2) [(1,'Q'),(2,'q')]
+  print a
+  print (a == a)
+  print (a == b)
+  print $ listArray (0,4) [1..5::Int]
+  print $ accumArray (+) 0 (0,1) [(1,10),(0,20),(1,5)]
+  print $ a ! 1
+  print $ a ! 3
+  print $ bounds a
+  print $ indices a
+  print $ elems a
+  print $ assocs a
+  print $ a // [(1,'w')]
+  print $ fmap fromEnum a
+  print $ ixmap (0,2) succ a
--- /dev/null
+++ b/tests/Array.ref
@@ -1,0 +1,14 @@
+array (1,3) [(1,'A'),(2,'b'),(3,'3')]
+True
+False
+array (0,4) [(0,1),(1,2),(2,3),(3,4),(4,5)]
+array (0,1) [(0,20),(1,15)]
+'A'
+'3'
+(1,3)
+[1,2,3]
+"Ab3"
+[(1,'A'),(2,'b'),(3,'3')]
+array (1,3) [(1,'w'),(2,'b'),(3,'3')]
+array (1,3) [(1,65),(2,98),(3,51)]
+array (0,2) [(0,'A'),(1,'b'),(2,'3')]
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -63,6 +63,7 @@
 	$(TMHS) Enum       && $(EVAL) > Enum.out       && diff Enum.ref Enum.out
 	$(TMHS) RecMdl     && $(EVAL) > RecMdl.out     && diff RecMdl.ref RecMdl.out
 	$(TMHS) ForeignPtr && $(EVAL) > ForeignPtr.out && diff ForeignPtr.ref ForeignPtr.out
+	$(TMHS) Array      && $(EVAL) > Array.out      && diff Array.ref Array.out
 
 errtest:
 	sh errtester.sh < errmsg.test
--