shithub: mc

Download patch

ref: 00affd843ab9b371c3a18bfa79b4d82b261911a1
parent: 2e89856cee5ff88161295f3ea0d200f72eb4d31f
author: Ori Bernstein <ori@eigenstate.org>
date: Mon Dec 21 16:39:28 EST 2015

Add custom iterators.

--- a/6/gen.c
+++ b/6/gen.c
@@ -96,7 +96,7 @@
 		d = n;
 	}
 	t = tybase(decltype(d));
-	if (d && d->decl.isconst && d->decl.isglobl)
+	if (d && d->decl.isconst && d->decl.isglobl && !d->decl.isgeneric)	/* trait protos are now const+globl too (type.c); exclude generic decls */
 		return t->type == Tyfunc || t->type == Tycode;
 	return 0;
 }
--- a/6/simp.c
+++ b/6/simp.c
@@ -83,6 +83,8 @@
 
 static void append(Simp *s, Node *n)
 {
+	if (debugopt['S'])	/* debug trace: dump each statement as it is appended */
+		dump(n, stdout);
 	lappend(&s->stmts, &s->nstmts, n);
 }
 
@@ -431,6 +433,26 @@
 	s->nloopexit--;
 }
 
+static void simploopmatch(Simp *s, Node *pat, Node *val, Node *ltrue, Node *lfalse)
+{
+	Node **cap, **out, *lload;
+	size_t i, ncap, nout;
+
+	/* match val against pat; the generated match code branches to lfalse on failure */
+	lload = genlbl(pat->loc);
+	out = NULL;
+	nout = 0;
+	cap = NULL;
+	ncap = 0;
+	genonematch(pat, val, lload, lfalse, &out, &nout, &cap, &ncap);
+	for (i = 0; i < nout; i++)
+		simp(s, out[i]);
+	simp(s, lload);	/* success path: run the capture code, then jump to ltrue */
+	for (i = 0; i < ncap; i++)
+		simp(s, cap[i]);
+	jmp(s, ltrue);	/* callers need not emit their own jump to ltrue */
+}
+
 /* pat; seq; 
  *      body;;
  *
@@ -446,19 +468,18 @@
  *           cjmp (cond) :match :end
  *      :match
  *           ...match...
- *           cjmp (match) :body :step
+ *           cjmp (match) :load :step
+ *      :load
+ *           matchval = load; jmp :body
  *      :end
  */
 static void simpidxiter(Simp *s, Node *n)
 {
-	Node *lbody, *lload, *lstep, *lcond, *lmatch, *lend;
+	Node *lbody, *lstep, *lcond, *lmatch, *lend;
 	Node *idx, *len, *dcl, *seq, *val, *done;
-	Node **cap, **out;
-	size_t i, ncap, nout;
 	Node *zero;
 
 	lbody = genlbl(n->loc);
-	lload = genlbl(n->loc);
 	lstep = genlbl(n->loc);
 	lcond = genlbl(n->loc);
 	lmatch = genlbl(n->loc);
@@ -493,16 +514,7 @@
 	val = load(idxaddr(s, seq, idx));
 
 	/* pattern match */
-	out = NULL;
-	nout = 0;
-	cap = NULL;
-	ncap = 0;
-	genonematch(n->iterstmt.elt, val, lload, lstep, &out, &nout, &cap, &ncap);
-	for (i = 0; i < nout; i++)
-		simp(s, out[i]);
-	simp(s, lload);
-	for (i = 0; i < ncap; i++)
-		simp(s, cap[i]);
+	simploopmatch(s, n->iterstmt.elt, val, lbody, lstep);	/* ends with jmp to lbody, so the jmp below is never reached */
 	jmp(s, lbody);
 	simp(s, lend);
 
@@ -510,6 +522,26 @@
 	s->nloopexit--;
 }
 
+/* Var node for ty's impl of trait function fn; NULL if tr has no fn. */
+static Node *itertraitfn(Srcloc loc, Trait *tr, char *fn, Type *ty)
+{
+	Node *dcl, *var;
+	size_t i;
+
+	for (i = 0; i < tr->nfuncs; i++) {
+		if (strcmp(fn, declname(tr->funcs[i])) != 0)
+			continue;
+		dcl = htget(tr->funcs[i]->decl.impls, ty);
+		if (!dcl)	/* no impl for ty: fail loudly instead of a NULL deref */
+			die("%s: no impl for type %s", fn, tystr(ty));
+		var = mkexpr(loc, Ovar, dcl->decl.name, NULL);
+		var->expr.type = dcl->decl.type;
+		var->expr.did = dcl->decl.did;
+		return var;
+	}
+	return NULL;
+}
+
 /* for pat in seq
 * 	body;;
 * =>
@@ -517,18 +549,82 @@
 * 	.elt = elt
 * 	:body
 * 		..body..
- * 		__iterfin__(&seq, &elt)
+ * 	:clean
+ * 		__iterfin__(&seq, &elt)
 * 	:step
 * 		cond = __iternext__(&seq, &eltout)
 * 		cjmp (cond) :match :end
 * 	:match
 * 		...match...
- * 		cjmp (match) :body :step
+ * 		cjmp (match) :load :clean
+ * 	:load
+ * 		...load matches...
+ * 		jmp :body
 * 	:end
 */
 static void simptraititer(Simp *s, Node *n)
 {
-	die("unimplemented");
+	Node *lbody, *lclean, *lstep, *lmatch, *lend;
+	Node *more, *val, *iter, *valptr, *iterptr;
+	Node *func, *call, *asn;
+	Trait *tr;
+
+	val = temp(s, n->iterstmt.elt);
+	valptr = mkexpr(val->loc, Oaddr, val, NULL);
+	valptr->expr.type = mktyptr(n->loc, exprtype(val));
+	iter = temp(s, n->iterstmt.seq);
+	iterptr = mkexpr(iter->loc, Oaddr, iter, NULL);
+	iterptr->expr.type = mktyptr(n->loc, exprtype(iter));
+	tr = traittab[Tciter];
+
+	/* create labels */
+	lbody = genlbl(n->loc);
+	lclean = genlbl(n->loc);
+	lstep = genlbl(n->loc);
+	lmatch = genlbl(n->loc);
+	lend = genlbl(n->loc);
+	/*
+	 * 'continue' resumes at :clean so that every value produced by
+	 * __iternext__ is passed to __iterfin__ before the next step.
+	 * NOTE(review): 'break' exits via :end without calling __iterfin__.
+	 */
+	lappend(&s->loopstep, &s->nloopstep, lclean);
+	lappend(&s->loopexit, &s->nloopexit, lend);
+
+	asn = assign(s, iter, n->iterstmt.seq);
+	append(s, asn);
+	jmp(s, lstep);
+	simp(s, lbody);
+	/* body */
+	simp(s, n->iterstmt.body);
+	simp(s, lclean);
+
+	/* call iterator cleanup */
+	func = itertraitfn(n->loc, tr, "__iterfin__", exprtype(iter));
+	call = mkexpr(n->loc, Ocall, func, iterptr, valptr, NULL);
+	call->expr.type = mktype(n->loc, Tyvoid);
+	append(s, call);
+
+	simp(s, lstep);
+	/* call iterator step; a false result ends the loop */
+	func = itertraitfn(n->loc, tr, "__iternext__", exprtype(iter));
+	call = mkexpr(n->loc, Ocall, func, iterptr, valptr, NULL);
+	more = gentemp(n->loc, mktype(n->loc, Tybool), NULL);
+	call->expr.type = exprtype(more);
+	asn = assign(s, more, call);
+	append(s, asn);
+	cjmp(s, more, lmatch, lend);
+
+	/*
+	 * pattern match: simploopmatch ends with a jump to lbody on success
+	 * and branches to lclean on failure, so nothing follows it here.
+	 */
+	simp(s, lmatch);
+	simploopmatch(s, n->iterstmt.elt, val, lbody, lclean);
+	simp(s, lend);
+
+	s->nloopstep--;
+	s->nloopexit--;
 }
 
 static void simpiter(Simp *s, Node *n)
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -890,7 +890,7 @@
 {
 	Type *t;
 
-	if (a->nsub == 1)
+	if (a->type == Tyslice || a->type == Tyarray)	/* builtin sequences: element type is sub[0] */
 		t = a->sub[0];
 	else
 		t = htget(st->seqbase, a);
@@ -1652,9 +1652,9 @@
 			fatal(n, "trait %s already specialized with %s on %s:%d",
 				namestr(t->name), tystr(n->impl.type),
 				fname(sym->loc), lnum(sym->loc));
-		htput(proto->decl.impls, n->impl.type, ty);
 		dcl->decl.name = name;
 		putdcl(file->file.globls, dcl);
+		htput(proto->decl.impls, n->impl.type, dcl);	/* impls now maps impl type -> specialized decl */
 		if (debugopt['S'])
 			printf("specializing trait [%d]%s:%s => %s:%s\n", n->loc.line,
 					namestr(proto->decl.name), tystr(type(st, proto)), namestr(name),
@@ -1997,15 +1997,18 @@
 
 static void checkvar(Inferstate *st, Node *n)
 {
-	Node *dcl;
+	Node *proto, *dcl;
 	Type *ty;
 
-	dcl = decls[n->expr.did];
+	proto = decls[n->expr.did];
 	ty = NULL;
+	dcl = NULL;
 	if (n->expr.param)
-		ty = htget(dcl->decl.impls, tf(st, n->expr.param));
+		dcl = htget(proto->decl.impls, tf(st, n->expr.param));	/* specialized decl for this param type, if any */
+	if (dcl)
+		ty = dcl->decl.type;
 	if (!ty)
-		ty = tyfreshen(st, NULL, type(st, dcl));
+		ty = tyfreshen(st, NULL, type(st, proto));	/* no specialization found: freshen the decl's own type */
 	unify(st, n, type(st, n), ty);
 }
 
--- a/parse/type.c
+++ b/parse/type.c
@@ -868,6 +868,8 @@
 	func->decl.trait = tr;
 	func->decl.impls = mkht(tyhash, tyeq); 
 	func->decl.isgeneric = 1;
+	func->decl.isconst = 1;	/* NOTE(review): 6/gen.c's const-fn check still excludes it via isgeneric */
+	func->decl.isglobl = 1;
 
 	lappend(&tr->funcs, &tr->nfuncs, func);
 	putdcl(st, func);
@@ -885,6 +887,8 @@
 	func->decl.trait = tr;
 	func->decl.impls = mkht(tyhash, tyeq); 
 	func->decl.isgeneric = 1;
+	func->decl.isconst = 1;	/* NOTE(review): 6/gen.c's const-fn check still excludes it via isgeneric */
+	func->decl.isglobl = 1;
 
 	lappend(&tr->funcs, &tr->nfuncs, func);
 	putdcl(st, func);
--- /dev/null
+++ b/test/custiter.myr
@@ -1,0 +1,33 @@
+use std
+
+type range = struct
+	lo	: int
+	hi	: int
+;;
+
+/* iterate from lo up to and including hi */
+impl iterable range -> int =
+	__iternext__ = {rng, output
+		if rng.lo > rng.hi
+			-> false
+		else
+			output# = rng.lo++
+			-> true
+		;;
+	}
+
+	__iterfin__ = {it, val
+	}
+;;
+
+const main = {
+	var r : range
+	var x : int
+
+	r = [.lo=6, .hi=11]
+	for v in r
+		x = v
+		std.put("{}", x)
+	;;
+	std.put("\n")
+}
--- a/test/tests
+++ b/test/tests
@@ -66,6 +66,7 @@
 B loop		P	0123401236789
 B subrangefor	P       12
 B patiter	P	23512
+B custiter	P	67891011
 B condiftrue	E	7
 B condiffalse	E	9
 B condifrel	E	7