ref: ffc55ee721c02d1caede8d7fbad9e8b687f01306
parent: 7ef2abad32fe3b273f16eeb28d63a63229dca3a6
author: Ori Bernstein <ori@eigenstate.org>
date: Thu Dec 31 20:05:50 EST 2015
Actually check array sizes when inferring.
--- a/mi/Makefile
+++ b/mi/Makefile
@@ -1,7 +1,6 @@
LIB=libmi.a
OBJ=cfg.o \
dfcheck.o \
- fold.o \
match.o \
reaching.o \
--- a/mi/fold.c
+++ /dev/null
@@ -1,242 +1,0 @@
-#include <stdlib.h>
-#include <stdio.h>
-#include <inttypes.h>
-#include <stdarg.h>
-#include <ctype.h>
-#include <string.h>
-#include <assert.h>
-#include <sys/types.h>
-#include <sys/stat.h>
-#include <fcntl.h>
-#include <unistd.h>
-
-#include "parse.h"
-#include "mi.h"
-
-static int getintlit(Node *n, vlong *v)
-{
- Node *l;
-
- if (exprop(n) != Olit)
- return 0;
- l = n->expr.args[0];
- if (l->lit.littype != Lint)
- return 0;
- *v = l->lit.intval;
- return 1;
-}
-
-static int isintval(Node *n, vlong val)
-{
- vlong v;
-
- if (!getintlit(n, &v))
- return 0;
- return v == val;
-}
-
-static Node *val(Srcloc loc, vlong val, Type *t)
-{
- Node *l, *n;
-
- l = mkint(loc, val);
- n = mkexpr(loc, Olit, l, NULL);
- l->lit.type = t;
- n->expr.type = t;
- return n;
-}
-
-static int issmallconst(Node *dcl)
-{
- Type *t;
-
- if (!dcl->decl.isconst)
- return 0;
- if (!dcl->decl.init)
- return 0;
- t = tybase(exprtype(dcl->decl.init));
- if (t->type <= Tyflt64)
- return 1;
- return 0;
-}
-
-static Node *foldcast(Node *n)
-{
- Type *to, *from;
- Node *sub;
-
- sub = n->expr.args[0];
- to = exprtype(n);
- from = exprtype(sub);
-
- switch (tybase(to)->type) {
- case Tybool:
- case Tyint8: case Tyint16: case Tyint32: case Tyint64:
- case Tyuint8: case Tyuint16: case Tyuint32: case Tyuint64:
- case Tyint: case Tyuint: case Tychar: case Tybyte:
- case Typtr:
- switch (tybase(from)->type) {
- case Tybool:
- case Tyint8: case Tyint16: case Tyint32: case Tyint64:
- case Tyuint8: case Tyuint16: case Tyuint32: case Tyuint64:
- case Tyint: case Tyuint: case Tychar: case Tybyte:
- case Typtr:
- if (exprop(sub) == Olit || tybase(from)->type == tybase(to)->type) {
- sub->expr.type = to;
- return sub;
- } else {
- return n;
- }
- default:
- return n;
- }
- default:
- return n;
- }
- return n;
-}
-
-int idxcmp(const void *pa, const void *pb)
-{
- Node *a, *b;
- vlong av, bv;
-
- a = *(Node **)pa;
- b = *(Node **)pb;
-
- assert(getintlit(a->expr.idx, &av));
- assert(getintlit(b->expr.idx, &bv));
-
- /* don't trust overflow with int64 */
- if (av < bv)
- return -1;
- else if (av == bv)
- return 0;
- else
- return 1;
-}
-
-Node *fold(Node *n, int foldvar)
-{
- Node **args, *r;
- Type *t;
- vlong a, b;
- size_t i;
-
- if (!n)
- return NULL;
- if (n->type != Nexpr)
- return n;
-
- r = NULL;
- args = n->expr.args;
- if (n->expr.idx)
- n->expr.idx = fold(n->expr.idx, foldvar);
- for (i = 0; i < n->expr.nargs; i++)
- args[i] = fold(args[i], foldvar);
- switch (exprop(n)) {
- case Ovar:
- if (foldvar && issmallconst(decls[n->expr.did]))
- r = fold(decls[n->expr.did]->decl.init, foldvar);
- break;
- case Oadd:
- /* x + 0 = 0 */
- if (isintval(args[0], 0))
- r = args[1];
- if (isintval(args[1], 0))
- r = args[0];
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a + b, exprtype(n));
- break;
- case Osub:
- /* x - 0 = 0 */
- if (isintval(args[1], 0))
- r = args[0];
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a - b, exprtype(n));
- break;
- case Omul:
- /* 1 * x = x */
- if (isintval(args[0], 1))
- r = args[1];
- if (isintval(args[1], 1))
- r = args[0];
- /* 0 * x = 0 */
- if (isintval(args[0], 0))
- r = args[0];
- if (isintval(args[1], 0))
- r = args[1];
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a * b, exprtype(n));
- break;
- case Odiv:
- /* x/0 = error */
- if (isintval(args[1], 0))
- fatal(args[1], "division by zero");
- /* x/1 = x */
- if (isintval(args[1], 1))
- r = args[0];
- /* 0/x = 0 */
- if (isintval(args[1], 0))
- r = args[1];
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a / b, exprtype(n));
- break;
- case Omod:
- /* x%1 = x */
- if (isintval(args[1], 0))
- r = args[0];
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a % b, exprtype(n));
- break;
- case Oneg:
- if (getintlit(args[0], &a))
- r = val(n->loc, -a, exprtype(n));
- break;
- case Obsl:
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a << b, exprtype(n));
- break;
- case Obsr:
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a >> b, exprtype(n));
- break;
- case Obor:
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a | b, exprtype(n));
- break;
- case Oband:
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a & b, exprtype(n));
- break;
- case Obxor:
- if (getintlit(args[0], &a) && getintlit(args[1], &b))
- r = val(n->loc, a ^ b, exprtype(n));
- break;
- case Omemb:
- t = tybase(exprtype(args[0]));
- /* we only fold lengths right now */
- if (t->type == Tyarray && !strcmp(namestr(args[1]), "len")) {
- r = t->asize;
- r->expr.type = exprtype(n);
- }
- break;
- case Oarr:
- qsort(n->expr.args, n->expr.nargs, sizeof(Node*), idxcmp);
- break;
- case Ocast:
- r = foldcast(n);
- break;
- default:
- break;
- }
-
- if (r && n->expr.idx)
- r->expr.idx = n->expr.idx;
-
- if (r)
- return r;
- else
- return n;
-}
-
--- a/mi/mi.h
+++ b/mi/mi.h
@@ -35,9 +35,6 @@
size_t *ndefs;
};
-/* expression folding */
-Node *fold(Node *n, int foldvar);
-
/* dataflow analysis */
Reaching *reaching(Cfg *cfg);
Node *assignee(Node *n);
--- a/parse/Makefile
+++ b/parse/Makefile
@@ -1,6 +1,7 @@
LIB=libparse.a
OBJ=bitset.o \
dump.o \
+ fold.o \
gram.o \
htab.o \
infer.o \
--- /dev/null
+++ b/parse/fold.c
@@ -1,0 +1,241 @@
+#include <stdlib.h>
+#include <stdio.h>
+#include <inttypes.h>
+#include <stdarg.h>
+#include <ctype.h>
+#include <string.h>
+#include <assert.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#include "parse.h"
+
+/*
+ * Extracts the value of an integer literal expression.
+ *
+ * When `n` is an Olit expression whose first argument is an Lint
+ * literal, the literal's value is stored through `v` and 1 is returned.
+ * Otherwise 0 is returned and `*v` is left untouched.
+ */
+static int getintlit(Node *n, vlong *v)
+{
+	Node *lit;
+
+	if (exprop(n) == Olit) {
+		lit = n->expr.args[0];
+		if (lit->lit.littype == Lint) {
+			*v = lit->lit.intval;
+			return 1;
+		}
+	}
+	return 0;
+}
+
+/* Returns 1 iff `n` is an integer literal expression equal to `val`, else 0. */
+static int isintval(Node *n, vlong val)
+{
+	vlong got;
+
+	return getintlit(n, &got) && got == val;
+}
+
+/*
+ * Builds an Olit expression at `loc` wrapping a fresh integer literal
+ * with value `val`. The type `t` is recorded on both the literal node
+ * and the wrapping expression.
+ */
+static Node *val(Srcloc loc, vlong val, Type *t)
+{
+	Node *lit, *expr;
+
+	lit = mkint(loc, val);
+	expr = mkexpr(loc, Olit, lit, NULL);
+	expr->expr.type = t;
+	lit->lit.type = t;
+	return expr;
+}
+
+/*
+ * Decides whether the declaration `dcl` may be substituted by its
+ * initializer during folding: it must be marked const, must have an
+ * initializer, and the base type of that initializer must have a type
+ * tag no greater than Tyflt64 (this relies on the ordering of the type
+ * tag enumeration). Returns 1 if so, 0 otherwise.
+ */
+static int issmallconst(Node *dcl)
+{
+	Node *init;
+
+	init = dcl->decl.init;
+	if (!dcl->decl.isconst || !init)
+		return 0;
+	return tybase(exprtype(init))->type <= Tyflt64;
+}
+
+/*
+ * Returns 1 when the type tag `ty` is one of the tags that foldcast()
+ * is willing to cast between (bool, the sized and unsized integers,
+ * char, byte and pointer), 0 otherwise.
+ */
+static int castfoldable(int ty)
+{
+	switch (ty) {
+	case Tybool:
+	case Tyint8: case Tyint16: case Tyint32: case Tyint64:
+	case Tyuint8: case Tyuint16: case Tyuint32: case Tyuint64:
+	case Tyint: case Tyuint: case Tychar: case Tybyte:
+	case Typtr:
+		return 1;
+	default:
+		return 0;
+	}
+}
+
+/*
+ * Folds away a cast expression `n` when possible.
+ *
+ * If both the source and destination base types are foldable (see
+ * castfoldable) and the operand is either a literal or already has the
+ * same base type tag as the destination, the operand is retyped in
+ * place to the destination type and returned instead of the cast.
+ * In every other case `n` is returned unchanged.
+ */
+static Node *foldcast(Node *n)
+{
+	Node *inner;
+	Type *dst, *src;
+
+	inner = n->expr.args[0];
+	dst = exprtype(n);
+	src = exprtype(inner);
+
+	if (!castfoldable(tybase(dst)->type))
+		return n;
+	if (!castfoldable(tybase(src)->type))
+		return n;
+	if (exprop(inner) != Olit && tybase(src)->type != tybase(dst)->type)
+		return n;
+
+	inner->expr.type = dst;
+	return inner;
+}
+
+/*
+ * qsort() comparator ordering array-literal elements by their index.
+ *
+ * `pa` and `pb` point at Node* elements whose expr.idx is expected to be
+ * an integer literal expression. Returns -1, 0 or 1 when the index of
+ * the first element is less than, equal to, or greater than that of the
+ * second.
+ */
+int idxcmp(const void *pa, const void *pb)
+{
+	Node *a, *b;
+	vlong av, bv;
+	int aok, bok;
+
+	a = *(Node *const *)pa;
+	b = *(Node *const *)pb;
+
+	/*
+	 * The lookups must happen outside assert(): with NDEBUG defined the
+	 * assert argument is not evaluated, which would leave av and bv
+	 * uninitialized when they are compared below.
+	 */
+	av = 0;
+	bv = 0;
+	aok = getintlit(a->expr.idx, &av);
+	bok = getintlit(b->expr.idx, &bv);
+	assert(aok && bok);
+	(void)aok;
+	(void)bok;
+
+	/* compare rather than subtract: don't trust overflow with int64 */
+	if (av < bv)
+		return -1;
+	else if (av == bv)
+		return 0;
+	else
+		return 1;
+}
+
+/*
+ * Performs constant folding on the expression `n`.
+ *
+ * The index expression and all arguments of `n` are folded first and
+ * stored back into `n`. Then, depending on the operator, `n` may be
+ * replaced by one of its operands (identity operations), by a freshly
+ * built integer literal (both operands literal), by the initializer of
+ * a constant declaration, or by the size expression of an array type
+ * (for `.len`). Oarr arguments are sorted in place by index.
+ *
+ * n:       the node to fold. NULL yields NULL; nodes that are not
+ *          Nexpr are returned unchanged.
+ * foldvar: when nonzero, an Ovar referring to a declaration accepted by
+ *          issmallconst() is replaced by that declaration's folded
+ *          initializer.
+ *
+ * Returns the replacement node, or `n` itself when nothing was folded.
+ * Note that the replacement may be an existing node (an operand, a
+ * declaration's initializer, or an array type's size expression); if
+ * `n` carried an index expression it is copied onto the replacement.
+ * A literal zero divisor for `/` or `%` is reported through fatal().
+ */
+Node *fold(Node *n, int foldvar)
+{
+	Node **args, *r;
+	Type *t;
+	vlong a, b;
+	size_t i;
+
+	if (!n)
+		return NULL;
+	if (n->type != Nexpr)
+		return n;
+
+	r = NULL;
+	args = n->expr.args;
+	/* fold the children first so the cases below see literals */
+	if (n->expr.idx)
+		n->expr.idx = fold(n->expr.idx, foldvar);
+	for (i = 0; i < n->expr.nargs; i++)
+		args[i] = fold(args[i], foldvar);
+	switch (exprop(n)) {
+	case Ovar:
+		if (foldvar && issmallconst(decls[n->expr.did]))
+			r = fold(decls[n->expr.did]->decl.init, foldvar);
+		break;
+	case Oadd:
+		/* 0 + x = x */
+		if (isintval(args[0], 0))
+			r = args[1];
+		/* x + 0 = x */
+		if (isintval(args[1], 0))
+			r = args[0];
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a + b, exprtype(n));
+		break;
+	case Osub:
+		/* x - 0 = x */
+		if (isintval(args[1], 0))
+			r = args[0];
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a - b, exprtype(n));
+		break;
+	case Omul:
+		/* 1 * x = x */
+		if (isintval(args[0], 1))
+			r = args[1];
+		if (isintval(args[1], 1))
+			r = args[0];
+		/* 0 * x = 0 */
+		if (isintval(args[0], 0))
+			r = args[0];
+		if (isintval(args[1], 0))
+			r = args[1];
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a * b, exprtype(n));
+		break;
+	case Odiv:
+		/* x/0 = error */
+		if (isintval(args[1], 0))
+			fatal(args[1], "division by zero");
+		/* x/1 = x */
+		if (isintval(args[1], 1))
+			r = args[0];
+		/*
+		 * 0/x is not folded: x is not known to be nonzero here, so
+		 * replacing the division would hide a run-time divide by zero.
+		 */
+		/* b != 0 keeps the host division defined even if fatal() returns */
+		if (getintlit(args[0], &a) && getintlit(args[1], &b) && b != 0)
+			r = val(n->loc, a / b, exprtype(n));
+		break;
+	case Omod:
+		/* x%0 = error, same as division */
+		if (isintval(args[1], 0))
+			fatal(args[1], "division by zero");
+		/* b != 0 keeps the host modulo defined even if fatal() returns */
+		if (getintlit(args[0], &a) && getintlit(args[1], &b) && b != 0)
+			r = val(n->loc, a % b, exprtype(n));
+		break;
+	case Oneg:
+		if (getintlit(args[0], &a))
+			r = val(n->loc, -a, exprtype(n));
+		break;
+	case Obsl:
+		/*
+		 * Only fold shift counts the host can evaluate; negative or
+		 * oversized counts are undefined in C and are left unfolded.
+		 * The shift is done unsigned so a negative `a` is well defined.
+		 */
+		if (getintlit(args[0], &a) && getintlit(args[1], &b) &&
+		    b >= 0 && b < (vlong)(8 * sizeof(vlong)))
+			r = val(n->loc, (vlong)((unsigned long long)a << b), exprtype(n));
+		break;
+	case Obsr:
+		if (getintlit(args[0], &a) && getintlit(args[1], &b) &&
+		    b >= 0 && b < (vlong)(8 * sizeof(vlong)))
+			r = val(n->loc, a >> b, exprtype(n));
+		break;
+	case Obor:
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a | b, exprtype(n));
+		break;
+	case Oband:
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a & b, exprtype(n));
+		break;
+	case Obxor:
+		if (getintlit(args[0], &a) && getintlit(args[1], &b))
+			r = val(n->loc, a ^ b, exprtype(n));
+		break;
+	case Omemb:
+		t = tybase(exprtype(args[0]));
+		/*
+		 * we only fold lengths right now; an array type without a
+		 * size expression has nothing to fold to.
+		 */
+		if (t->type == Tyarray && t->asize && !strcmp(namestr(args[1]), "len")) {
+			r = t->asize;
+			r->expr.type = exprtype(n);
+		}
+		break;
+	case Oarr:
+		/* order the elements by their (literal) index */
+		qsort(n->expr.args, n->expr.nargs, sizeof(Node*), idxcmp);
+		break;
+	case Ocast:
+		r = foldcast(n);
+		break;
+	default:
+		break;
+	}
+
+	/* keep the element index when the node itself is replaced */
+	if (r && n->expr.idx)
+		r->expr.idx = n->expr.idx;
+
+	if (r)
+		return r;
+	else
+		return n;
+}
+
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -901,6 +901,29 @@
return t;
}
+/*
+ * Reconciles the size expressions of two array types being unified.
+ *
+ * Each size that is present is constant-folded (with variable folding
+ * enabled) and must fold down to a literal; otherwise an error is
+ * reported at `ctx`. If only one of the types has a size, it is copied
+ * to the other. If both have one, the literal values must be equal or
+ * an error is reported.
+ *
+ * st:   inference state, used only to render context in diagnostics.
+ * ctx:  node supplying the source location and context for diagnostics.
+ * a, b: the two array types; their asize fields may be updated.
+ */
+static void checksize(Inferstate *st, Node *ctx, Type *a, Type *b)
+{
+	if (a->asize)
+		a->asize = fold(a->asize, 1);
+	if (b->asize)
+		b->asize = fold(b->asize, 1);
+	if (a->asize && exprop(a->asize) != Olit)
+		lfatal(ctx->loc, "%s: array size is not constant near %s",
+			tystr(a), ctxstr(st, ctx));
+	/*
+	 * Guard on b's own size: b->asize may be NULL while a->asize is
+	 * set, and must still be checked when a->asize is NULL.
+	 */
+	if (b->asize && exprop(b->asize) != Olit)
+		lfatal(ctx->loc, "%s: array size is not constant near %s",
+			tystr(b), ctxstr(st, ctx));
+
+	if (!a->asize) {
+		a->asize = b->asize;
+	} else if (!b->asize) {
+		b->asize = a->asize;
+	} else if (!litvaleq(a->asize->expr.args[0], b->asize->expr.args[0])) {
+		/* both sizes are literals here; compare values, ignoring type */
+		lfatal(ctx->loc, "array size of %s does not match %s near %s",
+			tystr(a), tystr(b), ctxstr(st, ctx));
+	}
+}
+
+
/* Unifies two types, or errors if the types are not unifiable. */
static Type *unify(Inferstate *st, Node *ctx, Type *u, Type *v)
{
@@ -947,6 +970,10 @@
if (a->type == Tyvar && b->type != Tyvar) {
if (occurs(a, b))
typeerror(st, a, b, ctx, "Infinite type\n");
+ }
+
+ if (a->type == Tyarray && b->type == Tyarray) {
+ checksize(st, ctx, a, b);
}
/* if the tyrank of a is 0 (ie, a raw tyvar), just unify.
--- a/parse/node.c
+++ b/parse/node.c
@@ -429,10 +429,13 @@
int liteq(Node *a, Node *b)
{
+ return litvaleq(a, b) && tyeq(a->lit.type, b->lit.type);
+}
+
+int litvaleq(Node *a, Node *b)
+{
assert(a->type == Nlit && b->type == Nlit);
if (a->lit.littype != b->lit.littype)
- return 0;
- if (!tyeq(a->lit.type, b->lit.type))
return 0;
switch (a->lit.littype) {
case Lvoid: return 1;
--- a/parse/parse.h
+++ b/parse/parse.h
@@ -454,6 +454,7 @@
void **htkeys(Htab *ht, size_t *nkeys);
/* useful key types */
int liteq(Node *a, Node *b);
+int litvaleq(Node *a, Node *b);
ulong strhash(void *key);
int streq(void *a, void *b);
ulong strlithash(void *key);
@@ -623,6 +624,9 @@
void writeuse(FILE *fd, Node *file);
void tagexports(Node *file, int hidelocal);
void addextlibs(Node *file, char **libs, size_t nlibs);
+
+/* expression folding */
+Node *fold(Node *n, int foldvar);
/* typechecking/inference */
void infer(Node *file);