shithub: MicroHs

Download patch

ref: d4f9619762e9c1d1f33d06d8cd4489c324411f49
parent: 8432c18eaebfda38826a6e648e7380d803c67b8e
author: Lennart Augustsson <lennart@augustsson.net>
date: Fri Aug 30 16:57:38 EDT 2024

Implement rudimentary bytestrings.

--- /dev/null
+++ b/lib/Data/ByteString.hs
@@ -1,0 +1,58 @@
+module Data.ByteString(
+  ByteString,
+  append, append3,
+  pack, unpack,
+  ) where
+import Prelude hiding ((++))
+import Data.Word(Word8)
+
+data ByteString  -- primitive type: declared without constructors
+
+primBSappend  :: ByteString -> ByteString -> ByteString                 -- concatenation
+primBSappend  = primitive "bs++"
+primBSappend3 :: ByteString -> ByteString -> ByteString -> ByteString   -- three-way concatenation
+primBSappend3 = primitive "bs+++"  -- NOTE(review): the eval.c hunks of this patch register and print "bs+++" but show no evaluation case for it -- confirm before relying on it
+primBSEQ      :: ByteString -> ByteString -> Bool
+primBSEQ      = primitive "bs=="
+primBSNE      :: ByteString -> ByteString -> Bool
+primBSNE      = primitive "bs/="
+primBSLT      :: ByteString -> ByteString -> Bool
+primBSLT      = primitive "bs<"
+primBSLE      :: ByteString -> ByteString -> Bool
+primBSLE      = primitive "bs<="
+primBSGT      :: ByteString -> ByteString -> Bool
+primBSGT      = primitive "bs>"
+primBSGE      :: ByteString -> ByteString -> Bool
+primBSGE      = primitive "bs>="
+primBScmp     :: ByteString -> ByteString -> Ordering
+primBScmp     = primitive "bscmp"  -- the runtime's primitive table maps "bscmp" to its generic T_COMPARE tag
+primBSpack    :: [Word8] -> ByteString
+primBSpack    = primitive "bspack"
+primBSunpack  :: ByteString -> [Word8]
+primBSunpack  = primitive "bsunpack"
+
+instance Eq ByteString where        -- both methods are answered by runtime primitives
+  x == y = primBSEQ x y
+  x /= y = primBSNE x y
+
+instance Ord ByteString where       -- compare and all four relational operators are given explicitly
+  compare x y = primBScmp x y
+  x <  y      = primBSLT x y
+  x <= y      = primBSLE x y
+  x >  y      = primBSGT x y
+  x >= y      = primBSGE x y
+
+instance Show ByteString where      -- renders as "pack" directly followed by the byte list; the precedence is ignored
+  showsPrec _ bs rest = showString "pack" (showsPrec 0 (unpack bs) rest)
+
+append :: ByteString -> ByteString -> ByteString   -- bytes of the first argument followed by those of the second
+append front back = primBSappend front back
+
+append3 :: ByteString -> ByteString -> ByteString -> ByteString   -- x, then y, then z, as one bytestring
+append3 x y z = primBSappend x (primBSappend y z)  -- two "bs++" steps: this patch's eval.c hunks show no evaluation case for the "bs+++" primitive
+
+pack :: [Word8] -> ByteString       -- build a bytestring from a list of bytes
+pack bytes = primBSpack bytes
+
+unpack :: ByteString -> [Word8]     -- the bytes of a bytestring as a list
+unpack bs = primBSunpack bs
--- a/src/runtime/eval.c
+++ b/src/runtime/eval.c
@@ -177,6 +177,7 @@
                 T_TOPTR, T_TOINT, T_TODBL, T_TOFUNPTR,
                 T_BININT2, T_BININT1, T_UNINT1,
                 T_BINDBL2, T_BINDBL1, T_UNDBL1,
+                T_BINBS2, T_BINBS1,
 #if WANT_FLOAT
                 T_FADD, T_FSUB, T_FMUL, T_FDIV, T_FNEG, T_ITOF,
                 T_FEQ, T_FNE, T_FLT, T_FLE, T_FGT, T_FGE, T_FSHOW, T_FREAD,
@@ -192,6 +193,8 @@
                 T_IO_CCALL, T_IO_GC, T_DYNSYM,
                 T_NEWCASTRINGLEN, T_PEEKCASTRING, T_PEEKCASTRINGLEN,
                 T_FROMUTF8,
+                T_BSAPPEND, T_BSAPPEND3, T_BSEQ, T_BSNE, T_BSLT, T_BSLE, T_BSGT, T_BSGE,
+                T_BSPACK, T_BSUNPACK,
                 T_BSTR,
                 T_LAST_TAG,
 };
@@ -602,6 +605,7 @@
 NODEPTR combShowExn, combU, combK2;
 NODEPTR combBININT1, combBININT2, combUNINT1;
 NODEPTR combBINDBL1, combBINDBL2, combUNDBL1;
+NODEPTR combBINBS1, combBINBS2;
 NODEPTR comb_stdin, comb_stdout, comb_stderr;
 
 /* One node of each kind for primitives, these are never GCd. */
@@ -668,6 +672,18 @@
   { "fshow", T_FSHOW},
   { "fread", T_FREAD},
 #endif  /* WANT_FLOAT */
+  { "bs++", T_BSAPPEND},
+  { "bs+++", T_BSAPPEND3},
+  { "bs==", T_BSEQ, T_BSEQ},
+  { "bs/=", T_BSNE, T_BSNE},
+  { "bs<", T_BSLT},
+  { "bs<=", T_BSLE},
+  { "bs>", T_BSGT},
+  { "bs>=", T_BSGE},
+  { "bscmp", T_COMPARE},
+  { "bspack", T_BSPACK},
+  { "bsunpack", T_BSUNPACK},
+
   { "ord", T_I },
   { "chr", T_I },
   { "==", T_EQ, T_EQ },
@@ -769,6 +785,8 @@
     case T_BINDBL1: combBINDBL1 = n; break;
     case T_BINDBL2: combBINDBL2 = n; break;
     case T_UNDBL1: combUNDBL1 = n; break;
+    case T_BINBS1: combBINBS1 = n; break;
+    case T_BINBS2: combBINBS2 = n; break;
 #if WANT_STDIO
     case T_IO_STDIN:  comb_stdin  = n; SETTAG(n, T_PTR); PTR(n) = add_utf8(add_FILE(stdin));  break;
     case T_IO_STDOUT: comb_stdout = n; SETTAG(n, T_PTR); PTR(n) = add_utf8(add_FILE(stdout)); break;
@@ -801,6 +819,8 @@
     case T_BINDBL1: combBINDBL1 = n; break;
     case T_BINDBL2: combBINDBL2 = n; break;
     case T_UNDBL1: combUNDBL1 = n; break;
+    case T_BINBS1: combBINBS1 = n; break;
+    case T_BINBS2: combBINBS2 = n; break;
 #if WANT_STDIO
     case T_IO_STDIN:  comb_stdin  = n; SETTAG(n, T_PTR); PTR(n) = add_utf8(add_FILE(stdin));  break;
     case T_IO_STDOUT: comb_stdout = n; SETTAG(n, T_PTR); PTR(n) = add_utf8(add_FILE(stdout)); break;
@@ -2052,6 +2072,16 @@
   case T_FSHOW: putsb("fshow", f); break;
   case T_FREAD: putsb("fread", f); break;
 #endif
+  case T_BSAPPEND: putsb("bs++", f); break;
+  case T_BSAPPEND3: putsb("bs+++", f); break;
+  case T_BSEQ: putsb("bs==", f); break;
+  case T_BSNE: putsb("bs/=", f); break;
+  case T_BSLT: putsb("bs<", f); break;
+  case T_BSLE: putsb("bs<=", f); break;
+  case T_BSGT: putsb("bs>", f); break;
+  case T_BSGE: putsb("bs>=", f); break;
+  case T_BSPACK: putsb("bspack", f); break;
+  case T_BSUNPACK: putsb("bsunpack", f); break;
   case T_EQ: putsb("==", f); break;
   case T_NE: putsb("/=", f); break;
   case T_LT: putsb("<", f); break;
@@ -2330,6 +2360,22 @@
   return n;
 }
 
+NODEPTR
+bsunpack(struct bytestring bs)
+{
+  const uint8_t *bytes = (const uint8_t *)bs.string;
+  NODEPTR list = mkNil();
+  NODEPTR *slot = &list;        /* the field that currently holds the end of the list */
+  size_t k;
+  /* Turn the bytes of bs into a list with one mkInt() node per byte, first byte first. */
+  for (k = 0; k < bs.size; k++) {
+    NODEPTR byte = mkInt(bytes[k]);
+    *slot = mkCons(byte, *slot);        /* new cell replaces the old tail value and keeps it as its own tail */
+    slot = &ARG(*slot);                 /* the next cell goes into the ARG field of the one just made */
+  }
+  return list;
+}
+
 NODEPTR evali(NODEPTR n);
 
 /* Follow indirections */
@@ -2465,6 +2511,88 @@
   return name;
 }
 
+/* Evaluate the list n cell by cell, storing each element's integer value as one byte. */
+/* Returns the MALLOCed buffer with the byte count; the buffer may be longer than that. */
+struct bytestring
+evalbstr(NODEPTR n)
+{
+  size_t sz = 100;              /* current buffer capacity, doubled whenever it fills up */
+  uint8_t *buf = MALLOC(sz);
+  size_t offs;                  /* number of bytes stored so far */
+  uvalue_t c;
+  NODEPTR x;
+  struct bytestring bs;
+
+  if (!buf) memerr();
+  for (offs = 0;;) {
+    if (offs >= sz) {
+      sz *= 2;
+      buf = REALLOC(buf, sz);
+      if (!buf) memerr();
+    }
+    n = evali(n);
+    if (GETTAG(n) == T_K)            /* Nil */
+      break;
+    else if (GETTAG(n) == T_AP && GETTAG(x = indir(&FUN(n))) == T_AP && GETTAG(indir(&FUN(x))) == T_O) { /* Cons */
+      PUSH(n);                  /* protect from GC */
+      c = (uvalue_t)evalint(ARG(x));
+      n = POPTOP();
+      buf[offs++] = (uint8_t)c; /* keep the low 8 bits; buf holds unsigned bytes, so do not narrow through char */
+      n = ARG(n);
+    } else {
+      ERR("evalbstr not Nil/Cons");
+    }
+  }
+  bs.size = offs;
+  bs.string = buf;
+  return bs;
+}
+
+/* Return a newly MALLOCed bytestring holding the bytes of p followed by the bytes of q. */
+struct bytestring
+bsappend(struct bytestring p, struct bytestring q)
+{
+  struct bytestring r;
+  r.size = p.size + q.size;
+  r.string = MALLOC(r.size ? r.size : 1);   /* malloc(0) may return NULL; never request 0 bytes so NULL always means failure */
+  if (!r.string) memerr();
+  memcpy(r.string, p.string, p.size);
+  memcpy((uint8_t *)r.string + p.size, q.string, q.size);
+  return r;
+}
+
+/*
+ * Compare bytestrings lexicographically, byte by byte as unsigned values.
+ * Returns -1 if bsp < bsq, 0 if they are equal, and 1 if bsp > bsq.
+ *
+ * memcmp() alone is not enough for two reasons:
+ *  - the two strings can have different lengths, so it may only be
+ *    applied to the common prefix
+ *  - it only promises the sign of its result, not the values -1 and 1,
+ *    so the result is normalized here
+ */
+int
+bscompare(struct bytestring bsp, struct bytestring bsq)
+{
+  size_t len = bsp.size < bsq.size ? bsp.size : bsq.size;
+  int r;
+
+  /* Skip memcmp() for an empty prefix so it is never handed the */
+  /* pointer of a zero-length string. */
+  r = len ? memcmp(bsp.string, bsq.string, len) : 0;
+  if (r < 0)
+    return -1;
+  if (r > 0)
+    return 1;
+  /* The common prefix is identical. */
+  /* The shorter string is considered smaller. */
+  if (bsp.size < bsq.size)
+    return -1;
+  if (bsp.size > bsq.size)
+    return 1;
+  return 0;
+}
+
 /* Compares anything, but really only works well on strings.
  * if p < q  return -1
  * if p > q  return 1
@@ -2490,6 +2618,7 @@
   NODEPTR p, q;
   NODEPTR *ap, *aq;
   enum node_tag ptag, qtag;
+  int r;
 
   /* Since FUN(cmp) can be shared, allocate a copy for it. */
   GCCHECK(1);
@@ -2561,6 +2690,19 @@
       if ((intptr_t)ff > (intptr_t)fg)
         CRET(1);
       break;
+    case T_FORPTR:
+      f = FORPTR(p)->payload.string;
+      g = FORPTR(q)->payload.string;
+      if (f < g)
+        CRET(-1);
+      if (f > g)
+        CRET(1);
+      break;
+    case T_BSTR:
+      r = bscompare(BSTR(p), BSTR(q));
+      if (r)
+        CRET(r);
+      break;
     case T_ARR:
       if (ARR(p) < ARR(q))
         CRET(-1);
@@ -2631,6 +2773,7 @@
 #endif
   enum node_tag tag;
   struct ioarray *arr;
+  struct bytestring xbs, ybs, rbs;
 
 #if MAXSTACKDEPTH
   counter_t old_cur_c_stack = cur_c_stack;
@@ -2821,6 +2964,18 @@
     GOIND(dblToString(xd));
 #endif  /* WANT_FLOAT */
 
+  case T_BSAPPEND:
+  case T_BSEQ:
+  case T_BSNE:
+  case T_BSLT:
+  case T_BSLE:
+  case T_BSGT:
+  case T_BSGE:
+    CHECK(2);
+    n = ARG(TOP(1));
+    PUSH(combBINBS2);
+    goto top;
+
   /* Retag a word sized value, keeping the value bits */
 #define CONV(t) do { CHECK(1); x = evali(ARG(TOP(0))); n = POPTOP(); SETTAG(n, t); SETVALUE(n, GETVALUE(x)); RET; } while(0)
   case T_TODBL: CONV(T_DBL);
@@ -2858,6 +3013,26 @@
     //printf("T_FROMUTF8 x = %p fp=%p payload.string=%p\n", x, x->uarg.uuforptr, x->uarg.uuforptr->payload.string);
     GOIND(mkStringU(BSTR(x)));
 
+  case T_BSUNPACK:
+    if (doing_rnf) RET;
+    CHECK(1);
+    x = evali(ARG(TOP(0)));
+    if (GETTAG(x) != T_BSTR) ERR("BSUNPACK");
+    POP(1);
+    n = TOP(-1);
+    GCCHECK(strNodes(BSTR(x).size));
+    GOIND(bsunpack(BSTR(x)));
+
+  case T_BSPACK:
+    CHECK(1);                   /* the list argument must be on the stack before TOP(0) is read, as in T_BSUNPACK */
+    {
+      struct bytestring bs = evalbstr(ARG(TOP(0)));   /* evaluates the whole list into a byte buffer */
+      POP(1); n = TOP(-1);      /* n is the application node that gets overwritten with the result */
+      SETTAG(n, T_BSTR);
+      FORPTR(n) = mkForPtr(bs);
+      RET;
+    }
+
   case T_RAISE:
     if (doing_rnf) RET;
     if (cur_handler) {
@@ -3132,6 +3307,50 @@
       SETDBL(n, rd);
       goto ret;
 #endif  /* WANT_FLOAT */
+
+    case T_BINBS2:
+      n = ARG(TOP(1));
+      TOP(0) = combBINBS1;
+      goto top;
+
+    case T_BINBS1:
+      /* First argument */
+#if SANITY
+      if (GETTAG(n) != T_BSTR)
+        ERR("BINBS 0");
+#endif  /* SANITY */
+      xbs = BSTR(n);
+      /* Second argument */
+      y = ARG(TOP(2));
+      while (GETTAG(y) == T_IND)
+        y = INDIR(y);
+#if SANITY
+      if (GETTAG(y) != T_BSTR)
+        ERR("BINBS 1");
+#endif  /* SANITY */
+      ybs = BSTR(y);
+      p = FUN(TOP(1));
+      POP(3);
+      n = TOP(-1);
+    binbs:
+      switch (GETTAG(p)) {
+      case T_IND:    p = INDIR(p); goto binbs;
+
+      case T_BSAPPEND: rbs = bsappend(xbs, ybs); break;
+      case T_BSEQ:   GOIND(bscompare(xbs, ybs) == 0 ? combTrue : combFalse);
+      case T_BSNE:   GOIND(bscompare(xbs, ybs) != 0 ? combTrue : combFalse);
+      case T_BSLT:   GOIND(bscompare(xbs, ybs) <  0 ? combTrue : combFalse);
+      case T_BSLE:   GOIND(bscompare(xbs, ybs) <= 0 ? combTrue : combFalse);
+      case T_BSGT:   GOIND(bscompare(xbs, ybs) >  0 ? combTrue : combFalse);
+      case T_BSGE:   GOIND(bscompare(xbs, ybs) >= 0 ? combTrue : combFalse);
+
+      default:
+        //fprintf(stderr, "tag=%d\n", GETTAG(FUN(TOP(0))));
+        ERR("BINBS");
+      }
+      SETTAG((n), T_BSTR);
+      FORPTR(n) = mkForPtr(rbs);
+      goto ret;
 
     default:
       stack_ptr = stk;
--- /dev/null
+++ b/tests/Bytestring.hs
@@ -1,0 +1,22 @@
+module Bytestring where
+import Data.Word
+import Data.ByteString
+
+bs1 :: ByteString
+bs1 = pack [1,2,3]
+
+bs2 :: ByteString
+bs2 = pack [1,2,4]              -- differs from bs1 only in the last byte
+
+bs3 :: ByteString
+bs3 = pack [1,2]                -- proper prefix of both bs1 and bs2
+
+main :: IO ()
+main = do
+  print (unpack bs1)            -- pack followed by unpack
+  print bs1                     -- goes through the Show instance
+  print $ bs1 `append` bs2
+  print [ op x y | op <- [(==), (/=), (<), (<=), (>), (>=)]   -- each comparison operator in turn (varies slowest)
+                 , x <- [bs1, bs2, bs3]                         -- applied to all nine ordered pairs
+                 , y <- [bs1, bs2, bs3]
+        ]
--- /dev/null
+++ b/tests/Bytestring.ref
@@ -1,0 +1,4 @@
+[1,2,3]
+pack[1,2,3]
+pack[1,2,3,1,2,4]
+[True,False,False,False,True,False,False,False,True,False,True,True,True,False,True,True,True,False,False,True,False,False,False,False,True,True,False,True,True,False,False,True,False,True,True,True,False,False,True,True,False,True,False,False,False,True,False,True,True,True,True,False,False,True]
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -68,6 +68,7 @@
 	$(TMHS) Eq1        && $(EVAL) > Eq1.out        && diff Eq1.ref Eq1.out
 	$(TMHS) Irref      && $(EVAL) > Irref.out      && diff Irref.ref Irref.out
 	$(TMHS) DfltSig    && $(EVAL) > DfltSig.out    && diff DfltSig.ref DfltSig.out
+	$(TMHS) Bytestring && $(EVAL) > Bytestring.out && diff Bytestring.ref Bytestring.out
 
 errtest:
 	sh errtester.sh $(MHS) < errmsg.test
--