shithub: MicroHs

Download patch

ref: 20d03699a8ffc93e8c8196fe68f8ff08b58eca0d
parent: 29147d26b73966a63d40841d0b4682bbe04e76b8
author: Lennart Augustsson <lennart@augustsson.net>
date: Thu Apr 4 06:07:13 EDT 2024

Implement ForeignPtr

--- /dev/null
+++ b/lib/Foreign/ForeignPtr.hs
@@ -1,0 +1,86 @@
+module Foreign.ForeignPtr ( 
+  ForeignPtr,
+  FinalizerPtr,
+  newForeignPtr,
+  newForeignPtr_,
+  addForeignPtrFinalizer,
+  withForeignPtr,
+  -- finalizeForeignPtr,
+  touchForeignPtr,
+  castForeignPtr,
+  plusForeignPtr,
+  mallocForeignPtr,
+  mallocForeignPtrBytes,
+  mallocForeignPtrArray,
+  mallocForeignPtrArray0,
+  ) where
+import Primitives
+import Foreign.Ptr
+import Foreign.Storable
+import Foreign.Marshal.Alloc
+import Foreign.Marshal.Array
+
+instance Eq (ForeignPtr a) where  -- two ForeignPtrs are equal iff their raw pointers are equal
+    p == q  =  unsafeForeignPtrToPtr p == unsafeForeignPtrToPtr q
+
+{-
+instance Ord (ForeignPtr a) where
+    compare p q  =  compare (unsafeForeignPtrToPtr p) (unsafeForeignPtrToPtr q)
+-}
+
+instance Show (ForeignPtr a) where  -- shown exactly like the underlying raw Ptr
+    showsPrec p f = showsPrec p (unsafeForeignPtrToPtr f)
+
+unsafeForeignPtrToPtr :: ForeignPtr a -> Ptr a  -- runtime primitive "fp2p": yields the payload pointer held by the ForeignPtr
+unsafeForeignPtrToPtr = primitive "fp2p"
+
+type FinalizerPtr a = FunPtr (Ptr a -> IO ())  -- pointer to a foreign function that takes the raw pointer
+
+foreign import ccall "&free" c_freefun :: FinalizerPtr a  -- address of C free, the finalizer used by the mallocForeignPtr* functions
+
+-- | Allocate storage for one value with 'malloc' and wrap it in a
+-- 'ForeignPtr' whose finalizer is the C @free@ function.
+mallocForeignPtr :: Storable a => IO (ForeignPtr a)
+mallocForeignPtr = malloc >>= newForeignPtr c_freefun
+
+-- | Allocate @size@ bytes with 'mallocBytes' and wrap them in a
+-- 'ForeignPtr' whose finalizer is the C @free@ function.
+mallocForeignPtrBytes :: Int -> IO (ForeignPtr a)
+mallocForeignPtrBytes size = mallocBytes size >>= newForeignPtr c_freefun
+
+-- | Allocate an array via 'mallocArray' (given @size@) and wrap it in a
+-- 'ForeignPtr' whose finalizer is the C @free@ function.
+mallocForeignPtrArray :: Storable a => Int -> IO (ForeignPtr a)
+mallocForeignPtrArray size = mallocArray size >>= newForeignPtr c_freefun
+
+-- | Allocate an array via 'mallocArray0' (given @size@) and wrap it in a
+-- 'ForeignPtr' whose finalizer is the C @free@ function.
+mallocForeignPtrArray0 :: Storable a => Int -> IO (ForeignPtr a)
+mallocForeignPtrArray0 size = mallocArray0 size >>= newForeignPtr c_freefun
+
+addForeignPtrFinalizer :: FinalizerPtr a -> ForeignPtr a -> IO ()  -- runtime primitive "fpfin": stores the finalizer, overwriting any one set before
+addForeignPtrFinalizer = primitive "fpfin"
+
+-- | Wrap a raw pointer in a 'ForeignPtr' and register the given finalizer on it.
+newForeignPtr :: FinalizerPtr a -> Ptr a -> IO (ForeignPtr a)
+newForeignPtr fin ptr =
+  newForeignPtr_ ptr >>= \ fp ->
+  addForeignPtrFinalizer fin fp >> return fp
+
+newForeignPtr_ :: Ptr a -> IO (ForeignPtr a)  -- runtime primitive "fpnew": a ForeignPtr with no finalizer function set
+newForeignPtr_ = primitive "fpnew"
+
+-- | Run an action on the raw pointer, then touch the 'ForeignPtr' and return the action's result.
+withForeignPtr :: ForeignPtr a -> (Ptr a -> IO b) -> IO b
+withForeignPtr fp io =
+  io (unsafeForeignPtrToPtr fp) >>= \ res ->
+  touchForeignPtr fp >> return res
+
+touchForeignPtr :: ForeignPtr a -> IO ()  -- forces its argument with seq and returns ()
+touchForeignPtr fp = fp `seq` return ()
+
+castForeignPtr :: ForeignPtr a -> ForeignPtr b  -- changes only the type parameter; implemented as an unsafe coercion
+castForeignPtr = primUnsafeCoerce
+
+plusForeignPtr :: ForeignPtr a -> Int -> ForeignPtr b  -- runtime primitive "fp+": raw pointer advanced by a byte offset, finalizer shared
+plusForeignPtr = primitive "fp+"
--- a/lib/Primitives.hs
+++ b/lib/Primitives.hs
@@ -31,6 +31,7 @@
 data IO a
 data Word
 data Ptr a
+data ForeignPtr a    -- no constructors; values are created by the runtime
 data FunPtr a
 data IOArray a
 
--- a/src/runtime/eval.c
+++ b/src/runtime/eval.c
@@ -162,7 +162,7 @@
 #endif  /* WANT_STDIO */
 #endif  /* !define(ERR) */
 
-enum node_tag { T_FREE, T_IND, T_AP, T_INT, T_DBL, T_PTR, T_FUNPTR, T_BADDYN, T_ARR,
+enum node_tag { T_FREE, T_IND, T_AP, T_INT, T_DBL, T_PTR, T_FUNPTR, T_FORPTR, T_BADDYN, T_ARR,
                 T_S, T_K, T_I, T_B, T_C,
                 T_A, T_Y, T_SS, T_BB, T_CC, T_P, T_R, T_O, T_U, T_Z,
                 T_K2, T_K3, T_K4, T_CCB,
@@ -170,6 +170,7 @@
                 T_AND, T_OR, T_XOR, T_INV, T_SHL, T_SHR, T_ASHR,
                 T_EQ, T_NE, T_LT, T_LE, T_GT, T_GE, T_ULT, T_ULE, T_UGT, T_UGE,
                 T_PEQ, T_PNULL, T_PADD, T_PSUB,
+                T_FPADD, T_FP2P, T_FPNEW, T_FPFIN,
                 T_TOPTR, T_TOINT, T_TODBL,
                 T_BININT2, T_BININT1, T_UNINT1,
                 T_BINDBL2, T_BINDBL1, T_UNDBL1,
@@ -193,7 +194,7 @@
 };
 #if 0
 static const char* tag_names[] = {
-  "FREE", "IND", "AP", "INT", "DBL", "PTR", "BADDYN", "ARR",
+  "FREE", "IND", "AP", "INT", "DBL", "PTR", "FUNPTR", "FORPTR", "BADDYN", "ARR",
   "S", "K", "I", "B", "C",
   "A", "Y", "SS", "BB", "CC", "P", "R", "O", "U", "Z",
   "K2", "K3", "K4", "CCB",
@@ -201,6 +202,7 @@
   "AND", "OR", "XOR", "INV", "SHL", "SHR", "ASHR",
   "EQ", "NE", "LT", "LE", "GT", "GE", "ULT", "ULE", "UGT", "UGE",
   "PEQ", "PNULL", "PADD", "PSUB",
+  "FPADD", "FP2P", "FPNEW", "FPFIN",
   "TOPTR", "TOINT", "TODBL",
 #if WANT_FLOAT
   "FADD", "FSUB", "FMUL", "FDIV", "FNEG", "ITOF",
@@ -224,6 +226,7 @@
 
 struct ioarray;
 struct ustring;
+struct forptr;
 
 typedef struct node {
   union {
@@ -239,6 +242,7 @@
     void           *uuptr;
     HsFunPtr        uufunptr;
     struct ioarray *uuarray;
+    struct forptr  *uuforptr;
   } uarg;
 } node;
 typedef struct node* NODEPTR;
@@ -256,6 +260,7 @@
 #define CSTR(p) (p)->uarg.uucstring
 #define PTR(p) (p)->uarg.uuptr
 #define FUNPTR(p) (p)->uarg.uufunptr
+#define FORPTR(p) (p)->uarg.uuforptr
 #define ARR(p) (p)->uarg.uuarray
 #define INDIR(p) ARG(p)
 #define NODE_SIZE sizeof(node)
@@ -287,6 +292,34 @@
 };
 struct ioarray *array_root = 0;
 
+/*
+ * A Haskell ForeignPtr has a normal pointer, and a finalizer
+ * function that is to be called when there are no more references
+ * to the ForeignPtr.
+ * A complication is that using plusForeignPtr creates a new
+ * ForeignPtr that must share the same finalizer.
+ * There is one struct forptr for each ForeignPtr.  It has a pointer
+ * to the actual data, and a pointer to a struct final which is shared
+ * between all ForeignPtrs that have been created with plusForeignPtr.
+ * During GC the used bit is set for any references to the forptr.
+ * The scan phase will traverse the struct final chain and run
+ * the finalizer, and free associated structs.
+ */
+struct final {
+  struct final  *next;      /* the next finalizer in the list rooted at final_root */
+  HsFunPtr       final;     /* function to call to release resource; 0 until one is set */
+  void          *arg;       /* argument to final when called */
+  struct forptr *back;      /* back pointer to the most recently created forptr sharing this finalizer */
+  int            used;      /* mark bit for GC */
+};
+
+struct forptr {
+  struct forptr *next;      /* the next ForeignPtr that shares the same finalizer */
+  void          *payload;   /* the actual pointer */
+  struct final  *finalizer; /* the finalizer for this ForeignPtr */
+};
+struct final *final_root = 0;   /* head of the list of all struct final */
+
 counter_t num_reductions = 0;
 counter_t num_alloc = 0;
 counter_t num_gc = 0;
@@ -622,6 +655,11 @@
   { "pcast", T_I },
   { "p+", T_PADD },
   { "p-", T_PSUB },
+  { "fpcast", T_I },
+  { "fp+", T_FPADD },
+  { "fp2p", T_FP2P },
+  { "fpnew", T_FPNEW },
+  { "fpfin", T_FPFIN },
   { "seq", T_SEQ },
   { "equal", T_EQUAL, T_EQUAL },
   { "sequal", T_EQUAL, T_EQUAL },
@@ -1861,6 +1899,8 @@
     break;
   case T_FUNPTR:
       ERR("Cannot serialize function pointers");
+  case T_FORPTR:
+      ERR("Cannot serialize foreign pointers");
     break;
   case T_STR:
     print_string(f, STR(n));
@@ -1932,6 +1972,10 @@
   case T_PNULL: putsb("pnull", f); break;
   case T_PADD: putsb("p+", f); break;
   case T_PSUB: putsb("p-", f); break;
+  case T_FPADD: putsb("fp+", f); break;
+  case T_FP2P: putsb("fp2p", f); break;
+  case T_FPNEW: putsb("fpnew", f); break;
+  case T_FPFIN: putsb("fpfin", f); break;
   case T_EQUAL: putsb("equal", f); break;
   case T_COMPARE: putsb("compare", f); break;
   case T_RNF: putsb("rnf", f); break;
@@ -2083,6 +2127,39 @@
   return n;
 }
 
+struct forptr*
+mkForPtr(void *p)               /* new forptr for p with its own struct final; no finalizer function yet */
+{
+  struct final *fin = malloc(sizeof *fin);
+  struct forptr *fp = malloc(sizeof *fp);
+  if (fin == 0 || fp == 0)
+    memerr();
+  fp->payload   = p;
+  fp->finalizer = fin;
+  fp->next      = 0;            /* nothing else shares this finalizer yet */
+  fin->arg   = p;
+  fin->final = 0;               /* filled in later by the fpfin primitive */
+  fin->used  = 0;
+  fin->back  = fp;
+  fin->next  = final_root;      /* link into the global final_root list */
+  final_root = fin;
+  return fp;
+}
+
+struct forptr*
+addForPtr(struct forptr *ofp, int s)    /* new forptr at ofp's payload + s bytes, sharing ofp's finalizer */
+{
+  struct forptr *nfp = malloc(sizeof *nfp);
+  struct final *shared = ofp->finalizer;
+  if (nfp == 0)
+    memerr();
+  nfp->finalizer = shared;
+  nfp->payload   = (char*)ofp->payload + s;
+  nfp->next      = ofp;         /* chain to the forptr this one was made from */
+  shared->back   = nfp;         /* the newest forptr becomes the head of the chain */
+  return nfp;
+}
+
 static INLINE NODEPTR
 mkNil(void)
 {
@@ -2199,8 +2276,8 @@
   return PTR(n);
 }
 
-/* Evaluate to a T_PTR */
-void *
+/* Evaluate to a T_FUNPTR */
+HsFunPtr
 evalfunptr(NODEPTR n)
 {
   n = evali(n);
@@ -2212,6 +2289,19 @@
   return FUNPTR(n);
 }
 
+/* Evaluate n and return the struct forptr it holds; under SANITY the result must be a T_FORPTR node. */
+struct forptr *
+evalforptr(NODEPTR n)
+{
+  NODEPTR r = evali(n);
+#if SANITY
+  if (GETTAG(r) != T_FORPTR) {
+    ERR1("evalforptr, bad tag %d", GETTAG(r));
+  }
+#endif
+  return FORPTR(r);
+}
+
 /* Evaluate a string, returns a newly allocated buffer. */
 /* XXX this is cheating, should use continuations */
 /* XXX the malloc()ed string is leaked if we yield in here. */
@@ -2428,6 +2518,7 @@
   NODEPTR x, y, z, w;
   value_t xi, yi, r;
   void *xp, *yp;
+  struct forptr *xfp;
 #if WANT_FLOAT
   flt_t xd, rd;
 #endif  /* WANT_FLOAT */
@@ -2472,6 +2563,7 @@
 #define SETDBL(n,d)    do { SETTAG((n), T_DBL); SETDBLVALUE((n), (d)); } while(0)
 #define SETPTR(n,r)    do { SETTAG((n), T_PTR); PTR(n) = (r); } while(0)
 #define SETFUNPTR(n,r) do { SETTAG((n), T_FUNPTR); FUNPTR(n) = (r); } while(0)
+#define SETFORPTR(n,r) do { SETTAG((n), T_FORPTR); FORPTR(n) = (r); } while(0)
 #define OPINT1(e)      do { CHECK(1); xi = evalint(ARG(TOP(0)));                            e; POP(1); n = TOP(-1); } while(0);
 #define OPPTR2(e)      do { CHECK(2); xp = evalptr(ARG(TOP(0))); yp = evalptr(ARG(TOP(1))); e; POP(2); n = TOP(-1); } while(0);
 #define CMPP(op)       do { OPPTR2(r = xp op yp); GOIND(r ? combTrue : combFalse); } while(0)
@@ -2633,6 +2725,9 @@
   case T_PADD: CHECK(2); xp = evalptr(ARG(TOP(0))); yi = evalint(ARG(TOP(1))); POP(2); n = TOP(-1); SETPTR(n, (char*)xp + yi); RET;
   case T_PSUB: CHECK(2); xp = evalptr(ARG(TOP(0))); yp = evalptr(ARG(TOP(1))); POP(2); n = TOP(-1); SETINT(n, (char*)xp - (char*)yp); RET;
 
+  case T_FPADD: CHECK(2); xfp = evalforptr(ARG(TOP(0))); yi = evalint(ARG(TOP(1))); POP(2); n = TOP(-1); SETFORPTR(n, addForPtr(xfp, yi)); RET; /* "fp+": first argument is a ForeignPtr node, so it is read with evalforptr */
+  case T_FP2P:  CHECK(1); xfp = evalforptr(ARG(TOP(0))); POP(1); n = TOP(-1); SETPTR(n, xfp->payload); RET; /* "fp2p": result is the raw payload pointer */
+
   case T_ARR_EQ:
     {
       CHECK(2);
@@ -2735,6 +2830,8 @@
   case T_ARR_SIZE:
   case T_ARR_READ:
   case T_ARR_WRITE:
+  case T_FPNEW:
+  case T_FPFIN:
     RET;
 
   case T_DYNSYM:
@@ -3165,6 +3262,7 @@
       NODEPTR elem;
       struct ioarray *arr;
       CHECKIO(2);
+      GCCHECK(1);
       size = evalint(ARG(TOP(1)));
       elem = ARG(TOP(2));
       arr = arr_alloc(size, elem);
@@ -3204,6 +3302,23 @@
       }
       ARR(n)->array[i] = ARG(TOP(3));
       RETIO(combUnit);
+      }
+
+    case T_FPNEW:               /* "fpnew": wrap a raw pointer in a fresh ForeignPtr node */
+      {
+        CHECKIO(1);
+        void *xp = evalptr(ARG(TOP(1)));   /* the Ptr argument */
+        n = alloc_node(T_FORPTR);
+        SETFORPTR(n, mkForPtr(xp));        /* mkForPtr allocates the forptr and its struct final */
+        RETIO(n);
+      }
+    case T_FPFIN:               /* "fpfin": set the finalizer function of a ForeignPtr */
+      {
+        CHECKIO(2);                                       /* two arguments are read below: TOP(1) and TOP(2) */
+        struct forptr *xfp = evalforptr(ARG(TOP(2)));     /* the ForeignPtr */
+        HsFunPtr yp = evalfunptr(ARG(TOP(1)));            /* the finalizer function pointer */
+        xfp->finalizer->final = yp;                       /* overwrites any finalizer stored earlier */
+        RETIO(combUnit);
       }
 
     default:
--