shithub: mc

Download patch

ref: ffc55ee721c02d1caede8d7fbad9e8b687f01306
parent: 7ef2abad32fe3b273f16eeb28d63a63229dca3a6
author: Ori Bernstein <ori@eigenstate.org>
date: Thu Dec 31 20:05:50 EST 2015

Actually check array sizes when inferring.

--- a/mi/Makefile
+++ b/mi/Makefile
@@ -1,7 +1,6 @@
 LIB=libmi.a
 OBJ=cfg.o \
     dfcheck.o \
-    fold.o \
     match.o \
     reaching.o \
 
--- a/mi/fold.c
+++ /dev/null
@@ -1,242 +1,0 @@
-#include <stdlib.h>
-#include <stdio.h>
-#include <inttypes.h>
-#include <stdarg.h>
-#include <ctype.h>
-#include <string.h>
-#include <assert.h>
-#include <sys/types.h>
-#include <sys/stat.h>
-#include <fcntl.h>
-#include <unistd.h>
-
-#include "parse.h"
-#include "mi.h"
-
-static int getintlit(Node *n, vlong *v)
-{
-	Node *l;
-
-	if (exprop(n) != Olit)
-		return 0;
-	l = n->expr.args[0];
-	if (l->lit.littype != Lint)
-		return 0;
-	*v = l->lit.intval;
-	return 1;
-}
-
-static int isintval(Node *n, vlong val)
-{
-	vlong v;
-
-	if (!getintlit(n, &v))
-		return 0;
-	return v == val;
-}
-
-static Node *val(Srcloc loc, vlong val, Type *t)
-{
-	Node *l, *n;
-
-	l = mkint(loc, val);
-	n = mkexpr(loc, Olit, l, NULL);
-	l->lit.type = t;
-	n->expr.type = t;
-	return n;
-}
-
-static int issmallconst(Node *dcl)
-{
-	Type *t;
-
-	if (!dcl->decl.isconst)
-		return 0;
-	if (!dcl->decl.init)
-		return 0;
-	t = tybase(exprtype(dcl->decl.init));
-	if (t->type <= Tyflt64)
-		return 1;
-	return 0;
-}
-
-static Node *foldcast(Node *n)
-{
-	Type *to, *from;
-	Node *sub;
-
-	sub = n->expr.args[0];
-	to = exprtype(n);
-	from = exprtype(sub);
-
-	switch (tybase(to)->type) {
-	case Tybool:
-	case Tyint8: case Tyint16: case Tyint32: case Tyint64:
-	case Tyuint8: case Tyuint16: case Tyuint32: case Tyuint64:
-	case Tyint: case Tyuint: case Tychar: case Tybyte:
-	case Typtr:
-		switch (tybase(from)->type) {
-		case Tybool:
-		case Tyint8: case Tyint16: case Tyint32: case Tyint64:
-		case Tyuint8: case Tyuint16: case Tyuint32: case Tyuint64:
-		case Tyint: case Tyuint: case Tychar: case Tybyte:
-		case Typtr:
-			if (exprop(sub) == Olit || tybase(from)->type == tybase(to)->type) {
-				sub->expr.type = to;
-				return sub;
-			} else {
-				return n;
-			}
-		default:
-			return n;
-		}
-	default:
-		return n;
-	}
-	return n;
-}
-
-int idxcmp(const void *pa, const void *pb)
-{
-	Node *a, *b;
-	vlong av, bv;
-
-	a = *(Node **)pa;
-	b = *(Node **)pb;
-
-	assert(getintlit(a->expr.idx, &av));
-	assert(getintlit(b->expr.idx, &bv));
-
-	/* don't trust overflow with int64 */
-	if (av < bv)
-		return -1;
-	else if (av == bv)
-		return 0;
-	else
-		return 1;
-}
-
-Node *fold(Node *n, int foldvar)
-{
-	Node **args, *r;
-	Type *t;
-	vlong a, b;
-	size_t i;
-
-	if (!n)
-		return NULL;
-	if (n->type != Nexpr)
-		return n;
-
-	r = NULL;
-	args = n->expr.args;
-	if (n->expr.idx)
-		n->expr.idx = fold(n->expr.idx, foldvar);
-	for (i = 0; i < n->expr.nargs; i++)
-		args[i] = fold(args[i], foldvar);
-	switch (exprop(n)) {
-	case Ovar:
-		if (foldvar && issmallconst(decls[n->expr.did]))
-			r = fold(decls[n->expr.did]->decl.init, foldvar);
-		break;
-	case Oadd:
-		/* x + 0 = 0 */
-		if (isintval(args[0], 0))
-			r = args[1];
-		if (isintval(args[1], 0))
-			r = args[0];
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a + b, exprtype(n));
-		break;
-	case Osub:
-		/* x - 0 = 0 */
-		if (isintval(args[1], 0))
-			r = args[0];
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a - b, exprtype(n));
-		break;
-	case Omul:
-		/* 1 * x = x */
-		if (isintval(args[0], 1))
-			r = args[1];
-		if (isintval(args[1], 1))
-			r = args[0];
-		/* 0 * x = 0 */
-		if (isintval(args[0], 0))
-			r = args[0];
-		if (isintval(args[1], 0))
-			r = args[1];
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a * b, exprtype(n));
-		break;
-	case Odiv:
-		/* x/0 = error */
-		if (isintval(args[1], 0))
-			fatal(args[1], "division by zero");
-		/* x/1 = x */
-		if (isintval(args[1], 1))
-			r = args[0];
-		/* 0/x = 0 */
-		if (isintval(args[1], 0))
-			r = args[1];
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a / b, exprtype(n));
-		break;
-	case Omod:
-		/* x%1 = x */
-		if (isintval(args[1], 0))
-			r = args[0];
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a % b, exprtype(n));
-		break;
-	case Oneg:
-		if (getintlit(args[0], &a))
-			r = val(n->loc, -a, exprtype(n));
-		break;
-	case Obsl:
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a << b, exprtype(n));
-		break;
-	case Obsr:
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a >> b, exprtype(n));
-		break;
-	case Obor:
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a | b, exprtype(n));
-		break;
-	case Oband:
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a & b, exprtype(n));
-		break;
-	case Obxor:
-		if (getintlit(args[0], &a) && getintlit(args[1], &b))
-			r = val(n->loc, a ^ b, exprtype(n));
-		break;
-	case Omemb:
-		t = tybase(exprtype(args[0]));
-		/* we only fold lengths right now */
-		if (t->type == Tyarray && !strcmp(namestr(args[1]), "len")) {
-			r = t->asize;
-			r->expr.type = exprtype(n);
-		}
-		break;
-	case Oarr:
-		qsort(n->expr.args, n->expr.nargs, sizeof(Node*), idxcmp);
-		break;
-	case Ocast:
-		r = foldcast(n);
-		break;
-	default:
-		break;
-	}
-
-	if (r && n->expr.idx)
-		r->expr.idx = n->expr.idx;
-
-	if (r)
-		return r;
-	else
-		return n;
-}
-
--- a/mi/mi.h
+++ b/mi/mi.h
@@ -35,9 +35,6 @@
 	size_t *ndefs;
 };
 
-/* expression folding */
-Node *fold(Node *n, int foldvar);
-
 /* dataflow analysis */
 Reaching *reaching(Cfg *cfg);
 Node *assignee(Node *n);
--- a/parse/Makefile
+++ b/parse/Makefile
@@ -1,6 +1,7 @@
 LIB=libparse.a
 OBJ=bitset.o \
     dump.o \
+    fold.o \
     gram.o \
     htab.o \
     infer.o \
--- /dev/null
+++ b/parse/fold.c
@@ -1,0 +1,279 @@
+#include <stdlib.h>
+#include <stdio.h>
+#include <inttypes.h>
+#include <stdarg.h>
+#include <ctype.h>
+#include <string.h>
+#include <assert.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#include "parse.h"
+
+/* If n is an integer literal, stores its value in *v and returns 1;
+ * otherwise returns 0 and leaves *v untouched. */
+static int getintlit(Node *n, vlong *v)
+{
+	Node *l;
+
+	if (exprop(n) != Olit)
+		return 0;
+	l = n->expr.args[0];
+	if (l->lit.littype != Lint)
+		return 0;
+	*v = l->lit.intval;
+	return 1;
+}
+
+/* Returns 1 if n is an integer literal whose value is val. */
+static int isintval(Node *n, vlong val)
+{
+	vlong v;
+
+	if (!getintlit(n, &v))
+		return 0;
+	return v == val;
+}
+
+/* Returns 1 if a vlong may be shifted by b bits on the host without
+ * undefined behaviour (0 <= b < width of vlong). */
+static int okshift(vlong b)
+{
+	return b >= 0 && b < (vlong)(8 * sizeof(vlong));
+}
+
+/* Builds a new integer literal expression with value val and type t. */
+static Node *val(Srcloc loc, vlong val, Type *t)
+{
+	Node *l, *n;
+
+	l = mkint(loc, val);
+	n = mkexpr(loc, Olit, l, NULL);
+	l->lit.type = t;
+	n->expr.type = t;
+	return n;
+}
+
+/* Returns 1 if dcl is a constant with an initializer whose base type is
+ * at or below Tyflt64 in the type enumeration (presumably the scalar
+ * primitive types -- this relies on the Ty* ordering). */
+static int issmallconst(Node *dcl)
+{
+	Type *t;
+
+	if (!dcl->decl.isconst)
+		return 0;
+	if (!dcl->decl.init)
+		return 0;
+	t = tybase(exprtype(dcl->decl.init));
+	if (t->type <= Tyflt64)
+		return 1;
+	return 0;
+}
+
+/* Removes a cast between bool, integer, char, byte and pointer types when
+ * the operand is a literal or already has the target's base type: the
+ * operand is retyped to the cast's type and returned in place of n.
+ * Any other cast is returned unchanged. */
+static Node *foldcast(Node *n)
+{
+	Type *to, *from;
+	Node *sub;
+
+	sub = n->expr.args[0];
+	to = exprtype(n);
+	from = exprtype(sub);
+
+	switch (tybase(to)->type) {
+	case Tybool:
+	case Tyint8: case Tyint16: case Tyint32: case Tyint64:
+	case Tyuint8: case Tyuint16: case Tyuint32: case Tyuint64:
+	case Tyint: case Tyuint: case Tychar: case Tybyte:
+	case Typtr:
+		switch (tybase(from)->type) {
+		case Tybool:
+		case Tyint8: case Tyint16: case Tyint32: case Tyint64:
+		case Tyuint8: case Tyuint16: case Tyuint32: case Tyuint64:
+		case Tyint: case Tyuint: case Tychar: case Tybyte:
+		case Typtr:
+			if (exprop(sub) == Olit || tybase(from)->type == tybase(to)->type) {
+				sub->expr.type = to;
+				return sub;
+			} else {
+				return n;
+			}
+		default:
+			return n;
+		}
+	default:
+		return n;
+	}
+	return n;
+}
+
+/* qsort() comparator: orders array-initializer elements by their integer
+ * literal index (expr.idx). Every element must have one. */
+int idxcmp(const void *pa, const void *pb)
+{
+	Node *a, *b;
+	vlong av, bv;
+
+	a = *(Node **)pa;
+	b = *(Node **)pb;
+
+	/* Keep the getintlit() calls out of assert(): under NDEBUG they would
+	 * be compiled away, leaving av and bv uninitialized. */
+	if (!getintlit(a->expr.idx, &av) || !getintlit(b->expr.idx, &bv)) {
+		assert(!"array initializer index is not an integer literal");
+		abort();
+	}
+
+	/* compare explicitly: av - bv could overflow */
+	if (av < bv)
+		return -1;
+	else if (av == bv)
+		return 0;
+	else
+		return 1;
+}
+
+/*
+ * Folds the children and index expression of n, then tries to simplify n
+ * itself: integer literal arithmetic, identities such as x + 0 and x * 1,
+ * the len of a sized array, and trivial casts. With foldvar set, references
+ * to small constant declarations are replaced by their folded initializers.
+ *
+ * Returns the replacement node, or n if n itself was not replaced; n may
+ * still be modified in place (children replaced by their folded forms, Oarr
+ * elements sorted by index). Division or modulo by a literal zero is a
+ * fatal error.
+ */
+Node *fold(Node *n, int foldvar)
+{
+	Node **args, *r;
+	Type *t;
+	vlong a, b;
+	size_t i;
+
+	if (!n)
+		return NULL;
+	if (n->type != Nexpr)
+		return n;
+
+	r = NULL;
+	args = n->expr.args;
+	if (n->expr.idx)
+		n->expr.idx = fold(n->expr.idx, foldvar);
+	for (i = 0; i < n->expr.nargs; i++)
+		args[i] = fold(args[i], foldvar);
+	switch (exprop(n)) {
+	case Ovar:
+		if (foldvar && issmallconst(decls[n->expr.did]))
+			r = fold(decls[n->expr.did]->decl.init, foldvar);
+		break;
+	case Oadd:
+		/* 0 + x = x, x + 0 = x */
+		if (isintval(args[0], 0))
+			r = args[1];
+		if (isintval(args[1], 0))
+			r = args[0];
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a + b, exprtype(n));
+		break;
+	case Osub:
+		/* x - 0 = x */
+		if (isintval(args[1], 0))
+			r = args[0];
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a - b, exprtype(n));
+		break;
+	case Omul:
+		/* 1 * x = x */
+		if (isintval(args[0], 1))
+			r = args[1];
+		if (isintval(args[1], 1))
+			r = args[0];
+		/* 0 * x = 0 */
+		if (isintval(args[0], 0))
+			r = args[0];
+		if (isintval(args[1], 0))
+			r = args[1];
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a * b, exprtype(n));
+		break;
+	case Odiv:
+		/* x/0 = error */
+		if (isintval(args[1], 0))
+			fatal(args[1], "division by zero");
+		/* x/1 = x */
+		if (isintval(args[1], 1))
+			r = args[0];
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a / b, exprtype(n));
+		break;
+	case Omod:
+		/* x%0 = error; this also keeps a % b below from dividing by zero */
+		if (isintval(args[1], 0))
+			fatal(args[1], "division by zero");
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a % b, exprtype(n));
+		break;
+	case Oneg:
+		if (getintlit(args[0], &a))
+			r = val(n->loc, -a, exprtype(n));
+		break;
+	case Obsl:
+		/* shift through unsigned: left-shifting a negative vlong is UB */
+		if (getintlit(args[0], &a) && getintlit(args[1], &b) && okshift(b))
+			r = val(n->loc, (vlong)((unsigned long long)a << b), exprtype(n));
+		break;
+	case Obsr:
+		if (getintlit(args[0], &a) && getintlit(args[1], &b) && okshift(b))
+			r = val(n->loc, a >> b, exprtype(n));
+		break;
+	case Obor:
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a | b, exprtype(n));
+		break;
+	case Oband:
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a & b, exprtype(n));
+		break;
+	case Obxor:
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a ^ b, exprtype(n));
+		break;
+	case Omemb:
+		t = tybase(exprtype(args[0]));
+		/* only .len is folded, and only when the array has a size:
+		 * t->asize may be NULL */
+		if (t->type == Tyarray && t->asize && !strcmp(namestr(args[1]), "len")) {
+			/* NOTE(review): r is the type's own asize node, so retyping
+			 * it (and copying idx onto it below) mutates the array type;
+			 * consider folding into a fresh literal instead. */
+			r = t->asize;
+			r->expr.type = exprtype(n);
+		}
+		break;
+	case Oarr:
+		qsort(n->expr.args, n->expr.nargs, sizeof(Node*), idxcmp);
+		break;
+	case Ocast:
+		r = foldcast(n);
+		break;
+	default:
+		break;
+	}
+
+	/* carry n's index (used to order Oarr elements) over to the replacement */
+	if (r && n->expr.idx)
+		r->expr.idx = n->expr.idx;
+
+	if (r)
+		return r;
+	else
+		return n;
+}
+
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -901,6 +901,29 @@
 	return t;
 }
 
+/* Folds both array sizes, requires each to be a literal constant, and checks
+ * that they agree; an unsized array adopts the other side's size. */
+static void checksize(Inferstate *st, Node *ctx, Type *a, Type *b)
+{
+	if (a->asize)
+		a->asize = fold(a->asize, 1);
+	if (b->asize)
+		b->asize = fold(b->asize, 1);
+	if (a->asize && exprop(a->asize) != Olit)
+		lfatal(ctx->loc, "%s: array size is not constant near %s",
+				tystr(a), ctxstr(st, ctx));
+	if (b->asize && exprop(b->asize) != Olit)
+		lfatal(ctx->loc, "%s: array size is not constant near %s",
+				tystr(b), ctxstr(st, ctx));
+	if (!a->asize)
+		a->asize = b->asize;
+	else if (!b->asize)
+		b->asize = a->asize;
+	else if (!litvaleq(a->asize->expr.args[0], b->asize->expr.args[0]))
+		lfatal(ctx->loc, "array size of %s does not match %s near %s",
+				tystr(a), tystr(b), ctxstr(st, ctx));
+}
+
 /* Unifies two types, or errors if the types are not unifiable. */
 static Type *unify(Inferstate *st, Node *ctx, Type *u, Type *v)
 {
@@ -947,6 +970,10 @@
 	if (a->type == Tyvar && b->type != Tyvar) {
 		if (occurs(a, b))
 			typeerror(st, a, b, ctx, "Infinite type\n");
+	}
+
+	if (a->type == Tyarray && b->type == Tyarray) {
+		checksize(st, ctx, a, b);
 	}
 
 	/* if the tyrank of a is 0 (ie, a raw tyvar), just unify.
--- a/parse/node.c
+++ b/parse/node.c
@@ -429,10 +429,13 @@
 
 int liteq(Node *a, Node *b)
 {
+	return litvaleq(a, b) && tyeq(a->lit.type, b->lit.type);	/* same value and same type */
+}
+
+int litvaleq(Node *a, Node *b)
+{
 	assert(a->type == Nlit && b->type == Nlit);
 	if (a->lit.littype != b->lit.littype)
-		return 0;
-	if (!tyeq(a->lit.type, b->lit.type))
 		return 0;
 	switch (a->lit.littype) {
 	case Lvoid:	return 1;
--- a/parse/parse.h
+++ b/parse/parse.h
@@ -454,6 +454,7 @@
 void **htkeys(Htab *ht, size_t *nkeys);
 /* useful key types */
 int liteq(Node *a, Node *b);
+int litvaleq(Node *a, Node *b);
 ulong strhash(void *key);
 int streq(void *a, void *b);
 ulong strlithash(void *key);
@@ -623,6 +624,9 @@
 void writeuse(FILE *fd, Node *file);
 void tagexports(Node *file, int hidelocal);
 void addextlibs(Node *file, char **libs, size_t nlibs);
+
+/* expression folding */
+Node *fold(Node *n, int foldvar);
 
 /* typechecking/inference */
 void infer(Node *file);