shithub: sl

ref: 1820b6e70a22e92f69dab14a1d2368b98b96d501
dir: /src/equal.c/

View raw version
#include "sl.h"
#include "operators.h"
#include "cvalues.h"
#include "equal.h"
#include "hashing.h"

// Recursion budget for the cheap bounded_compare pass in sl_compare; when it
// runs out, sl_compare falls back to the cycle-safe cyc_compare.
#define BOUNDED_COMPARE_BOUND 128
// Budget handed to bounded_hash by hash_lispvalue.
#define BOUNDED_HASH_BOUND 16384

// MIX(a, b): combine an accumulated hash 'a' with another hash 'b'.
// doublehash(a): hash a 64-bit integer (used on the bit pattern of doubles).
// These expand to plain expressions (no trailing ';'), so they can be used
// anywhere an expression is allowed, e.g. in an unbraced if/else or inside
// another macro argument.
#if defined(BITS64)
#define MIX(a, b) inthash((sl_v)(a) ^ (sl_v)(b))
#define doublehash(a) inthash(a)
#else
#define MIX(a, b) int64to32hash(((u64int)(a) << 32) | (u64int)(b))
#define doublehash(a) int64to32hash(a)
#endif

// comparable tag: like tag(v), except that every fixnum reports TAG_NUM
#define cmptag(v) (isfixnum(v) ? TAG_NUM : tag(v))

// eq_class: find the representative of key's equivalence class in 'table'.
// Follows key -> parent links until an entry that maps to itself is reached.
// Returns sl_nil when key has never been entered into the table.
static sl_v
eq_class(sl_htable *table, sl_v key)
{
	for(;;){
		sl_v parent = (sl_v)ptrhash_get(table, (void*)key);

		if(parent == (sl_v)HT_NOTFOUND)
			return sl_nil;
		if(parent == key)
			return key;
		key = parent;
	}
}

// eq_union: put a and b into one equivalence class.
// 'c' is a's current class root (or sl_nil) and 'cb' is b's (or sl_nil).
// The merged root is c, or a itself when a had no class yet; b's old root,
// a and b are all pointed directly at that root.
static void
eq_union(sl_htable *table, sl_v a, sl_v b, sl_v c, sl_v cb)
{
	sl_v root = a;

	if(c != sl_nil)
		root = c;
	if(cb != sl_nil)
		ptrhash_put(table, (void*)cb, (void*)root);
	ptrhash_put(table, (void*)a, (void*)root);
	ptrhash_put(table, (void*)b, (void*)root);
}

// Forward declarations: bounded_compare and cyc_compare are mutually
// recursive with the vector helpers defined below.
static sl_v bounded_compare(sl_v a, sl_v b, int bound, bool eq);
static sl_v cyc_compare(sl_v a, sl_v b, sl_htable *table, bool eq);

// bounded_vector_compare: element-wise comparison of two vectors for
// bounded_compare, spending one level of 'bound' on the elements.
// With 'eq', vectors of different length are immediately "different".
// Returns sl_nil if any element comparison ran out of budget, the first
// nonzero element result, or else an ordering by length (0 if equal).
static sl_v
bounded_vector_compare(sl_v a, sl_v b, int bound, bool eq)
{
	usize na = vector_size(a);
	usize nb = vector_size(b);
	usize n, k;
	sl_v r;

	if(eq && na != nb)
		return fixnum(1);
	n = nb < na ? nb : na;
	for(k = 0; k < n; k++){
		r = bounded_compare(vector_elt(a, k), vector_elt(b, k), bound-1, eq);
		if(r == sl_nil || numval(r) != 0)
			return r;
	}
	if(na == nb)
		return fixnum(0);
	return na < nb ? fixnum(-1) : fixnum(1);
}

// bounded_compare: compare a and b structurally, giving up once 'bound'
// levels of nesting have been used up.
// Returns fixnum(0) if a and b are equal, a nonzero fixnum if they differ
// (when !eq its sign orders a relative to b), or sl_nil if the bound ran out
// before the answer was known; sl_compare then falls back to cyc_compare.
// 'eq' means unordered comparison is sufficient, which permits shortcuts
// such as treating any two distinct symbols as simply "different".
// strange comparisons are resolved arbitrarily but consistently.
// ordering: number < cprim < function < vector < cvalue < symbol < cons
// Mixed-type pairs not decided inside the switch are ordered by tag value.
static sl_v
bounded_compare(sl_v a, sl_v b, int bound, bool eq)
{
	sl_v d;
	csl_v *cv;

	// the cons case loops back here for the cdr instead of recursing
compare_top:
	if(a == b)
		return fixnum(0);
	// out of budget: undecided
	if(bound <= 0)
		return sl_nil;
	int taga = tag(a);
	// a fixnum b always reports TAG_NUM here
	int tagb = cmptag(b);
	int c;
	switch(taga){
	case TAG_NUM :
	case TAG_NUM1:
		// a is a fixnum
		if(isfixnum(b))
			return (sl_fx)a < (sl_fx)b ? fixnum(-1) : fixnum(1);
		if(iscprim(b)){
			// a rune b is not compared numerically; it sorts before a fixnum
			if(cp_class(ptr(b)) == sl_runetype)
				return fixnum(1);
			return fixnum(numeric_compare(a, b, eq, true, false));
		}
		if(iscvalue(b)){
			cv = ptr(b);
			if(valid_numtype(cv_class(cv)->numtype))
				return fixnum(numeric_compare(a, b, eq, true, false));
		}
		// b is not numeric: numbers sort first
		return fixnum(-1);
	case TAG_SYM:
		// a != b here, so with eq two symbols are simply different
		if(eq || tagb < TAG_SYM)
			return fixnum(1);
		if(tagb > TAG_SYM)
			return fixnum(-1);
		// ordered comparison of two symbols: by name
		return fixnum(strcmp(symbol_name(a), symbol_name(b)));
	case TAG_VECTOR:
		if(isvector(b))
			return bounded_vector_compare(a, b, bound, eq);
		break;
	case TAG_CPRIM:
		// a rune sorts before every value that is not a rune, and a non-rune
		// cprim sorts after a rune; two runes (or two non-runes) fall
		// through to numeric_compare
		if(cp_class(ptr(a)) == sl_runetype){
			if(!iscprim(b) || cp_class(ptr(b)) != sl_runetype)
				return fixnum(-1);
		}else if(iscprim(b) && cp_class(ptr(b)) == sl_runetype)
			return fixnum(1);
		c = numeric_compare(a, b, eq, true, false);
		// a result of 2 is not used as an answer; presumably it means
		// "not comparable" (TODO confirm) — fall back to ordering by tag
		if(c != 2)
			return fixnum(c);
		break;
	case TAG_CVALUE:
		cv = ptr(a);
		// numeric cvalues are compared numerically when numeric_compare can decide
		if(valid_numtype(cv_class(cv)->numtype)){
			if((c = numeric_compare(a, b, eq, true, false)) != 2)
				return fixnum(c);
		}
		if(iscvalue(b)){
			// both POD (per cv_isPOD): delegate to cvalue_compare.
			// NOTE(review): any other pair of cvalues yields fixnum(1) in both
			// argument orders, so this ordering is not antisymmetric.
			if(cv_isPOD(ptr(a)) && cv_isPOD(ptr(b)))
				return cvalue_compare(a, b);
			return fixnum(1);
		}
		break;
	case TAG_FUNCTION:
		if(tagb == TAG_FUNCTION){
			// values above N_BUILTINS point at an sl_fn: compare its bcode,
			// vals and env fields in turn, each one level deeper
			if(uintval(a) > N_BUILTINS && uintval(b) > N_BUILTINS){
				sl_fn *fa = ptr(a);
				sl_fn *fb = ptr(b);
				d = bounded_compare(fa->bcode, fb->bcode, bound-1, eq);
				if(d == sl_nil || numval(d) != 0)
					return d;
				d = bounded_compare(fa->vals, fb->vals, bound-1, eq);
				if(d == sl_nil || numval(d) != 0)
					return d;
				d = bounded_compare(fa->env, fb->env, bound-1, eq);
				if(d == sl_nil || numval(d) != 0)
					return d;
				return fixnum(0);
			}
			// otherwise order by the raw untagged value
			return uintval(a) < uintval(b) ? fixnum(-1) : fixnum(1);
		}
		break;
	case TAG_CONS:
		// NOTE(review): relies on no tag being greater than TAG_CONS, so b
		// is a cons once this check passes — confirm against the tag values
		if(tagb < TAG_CONS)
			return fixnum(1);
		d = bounded_compare(car_(a), car_(b), bound-1, eq);
		if(d == sl_nil || numval(d) != 0)
			return d;
		// cars are equal: continue with the cdrs iteratively, spending one level
		a = cdr_(a); b = cdr_(b);
		bound--;
		goto compare_top;
	}
	// no type-specific answer: order by tag
	return taga < tagb ? fixnum(-1) : fixnum(1);
}

// cyc_vector_compare: cycle-safe vector comparison used by cyc_compare.
// A cheap pass first tries to tell the vectors apart without recursing
// (leaf elements via a one-level bounded_compare, non-leaf elements by tag).
// If that is inconclusive, the two vectors are merged into one equivalence
// class in 'table' (so meeting them again counts as equal) and non-leaf or
// function elements are compared recursively with cyc_compare.
// Returns the first nonzero element result, else an ordering by length.
static sl_v
cyc_vector_compare(sl_v a, sl_v b, sl_htable *table, bool eq)
{
	usize na = vector_size(a);
	usize nb = vector_size(b);
	usize n, k;
	sl_v r, ea, eb, class_a, class_b;

	if(eq && na != nb)
		return fixnum(1);
	n = nb < na ? nb : na;

	// pass 1: differences that need no recursion
	for(k = 0; k < n; k++){
		ea = vector_elt(a, k);
		eb = vector_elt(b, k);
		if(!leafp(ea) && !leafp(eb)){
			if(tag(ea) != tag(eb))
				return tag(ea) < tag(eb) ? fixnum(-1) : fixnum(1);
			continue;
		}
		r = bounded_compare(ea, eb, 1, eq);
		if(r != sl_nil && numval(r) != 0)
			return r;
	}

	// already known to be in the same class: treat as equal
	class_a = eq_class(table, a);
	class_b = eq_class(table, b);
	if(class_a != sl_nil && class_a == class_b)
		return fixnum(0);
	eq_union(table, a, b, class_a, class_b);

	// pass 2: recurse into everything pass 1 could not settle
	for(k = 0; k < n; k++){
		ea = vector_elt(a, k);
		if(leafp(ea) && tag(ea) != TAG_FUNCTION)
			continue;
		r = cyc_compare(ea, vector_elt(b, k), table, eq);
		if(numval(r) != 0)
			return r;
	}

	if(na == nb)
		return fixnum(0);
	return na < nb ? fixnum(-1) : fixnum(1);
}

// cyc_compare: structural comparison that also terminates on cyclic data.
// Every pair of conses, vectors or functions it descends into is merged
// into one equivalence class in 'table' (eq_class/eq_union); a pair found
// to be in the same class already is reported equal, which is what cuts
// off the recursion on cycles.
// Returns fixnum(0) for equal and a nonzero fixnum otherwise; callers
// apply numval() to the result without checking for sl_nil.
// 'eq' has the same meaning as for bounded_compare.
static sl_v
cyc_compare(sl_v a, sl_v b, sl_htable *table, bool eq)
{
	sl_v d, ca, cb;
	// the cons and function cases loop back here for their last field
cyc_compare_top:
	if(a == b)
		return fixnum(0);
	if(iscons(a)){
		if(iscons(b)){
			sl_v aa = car_(a);
			sl_v da = cdr_(a);
			sl_v ab = car_(b);
			sl_v db = cdr_(b);
			int tagaa = tag(aa);
			int tagda = tag(da);
			int tagab = tag(ab);
			int tagdb = tag(db);
			// first try to prove them different without recursing:
			// compare leaf cars/cdrs directly, then the tags of cars and cdrs
			if(leafp(aa) || leafp(ab)){
				d = bounded_compare(aa, ab, 1, eq);
				if(d != sl_nil && numval(d) != 0)
					return d;
			}
			if(tagaa < tagab)
				return fixnum(-1);
			if(tagaa > tagab)
				return fixnum(1);
			if(leafp(da) || leafp(db)){
				d = bounded_compare(da, db, 1, eq);
				if(d != sl_nil && numval(d) != 0)
					return d;
			}
			if(tagda < tagdb)
				return fixnum(-1);
			if(tagda > tagdb)
				return fixnum(1);

			// already in the same class: treat as equal (this breaks cycles)
			ca = eq_class(table, a);
			cb = eq_class(table, b);
			if(ca != sl_nil && ca == cb)
				return fixnum(0);

			eq_union(table, a, b, ca, cb);
			// recurse on the cars, iterate on the cdrs
			d = cyc_compare(aa, ab, table, eq);
			if(numval(d) != 0)
				return d;
			a = da;
			b = db;
			goto cyc_compare_top;
		}else{
			// a is a cons and b is not
			return fixnum(1);
		}
	}
	if(isvector(a) && isvector(b))
		return cyc_vector_compare(a, b, table, eq);
	if(isfunction(a) && isfunction(b)){
		sl_fn *fa = ptr(a);
		sl_fn *fb = ptr(b);
		// NOTE(review): unlike the other bounded_compare calls here, this
		// result is not checked for sl_nil before numval(); presumably a
		// bcode value never needs more than one level — confirm.
		d = bounded_compare(fa->bcode, fb->bcode, 1, eq);
		if(numval(d) != 0)
			return d;

		ca = eq_class(table, a);
		cb = eq_class(table, b);
		if(ca != sl_nil && ca == cb)
			return fixnum(0);

		eq_union(table, a, b, ca, cb);
		// recurse on vals, iterate on env
		d = cyc_compare(fa->vals, fb->vals, table, eq);
		if(numval(d) != 0)
			return d;
		a = fa->env;
		b = fb->env;
		goto cyc_compare_top;
	}
	// remaining cases are settled by a single-level bounded_compare
	return bounded_compare(a, b, 1, eq);
}

// Size argument for htable_new/htable_reset on the cyc_compare scratch
// table. comparehash_init and sl_compare must agree on it, so it is
// defined once here instead of repeating the literal.
#define EQUAL_EQ_HASHTABLE_SIZE 512

// Scratch table used by every sl_compare call that needs cyc_compare.
// It is reset after each such call. It is a single file-static object with
// no locking around it.
static sl_htable equal_eq_hashtable;

// comparehash_init: set up the scratch table. Presumably called once
// before the first sl_compare (TODO confirm at the call site).
void
comparehash_init(void)
{
	htable_new(&equal_eq_hashtable, EQUAL_EQ_HASHTABLE_SIZE);
}

// sl_compare: compare a and b structurally. Returns fixnum(0) when they are
// equal and a nonzero fixnum otherwise.
// 'eq' means unordered comparison is sufficient
// A cheap depth-bounded pass runs first. Only if it is inconclusive
// (sl_nil) does the cycle-safe cyc_compare run, after which the scratch
// table is emptied for the next call.
sl_v
sl_compare(sl_v a, sl_v b, bool eq)
{
	sl_v guess = bounded_compare(a, b, BOUNDED_COMPARE_BOUND, eq);
	if(guess == sl_nil){
		guess = cyc_compare(a, b, &equal_eq_hashtable, eq);
		htable_reset(&equal_eq_hashtable, EQUAL_EQ_HASHTABLE_SIZE);
	}
	return guess;
}

/*
  optimizations:
  - use hash updates instead of calling lookup then insert. i.e. get the
	bp once and use it twice.
  * preallocate hash table and call reset() instead of new/free
  * less redundant tag checking, 3-bit tags
*/

// bounded_hash: hash 'a' by walking its structure, with a budget of 'bound'
// for nested vectors and conses so that cyclic or very deep values still
// terminate.
// *oob: output argument, means we hit the limit specified by 'bound'
// somewhere in the walk; the hash then covers only the visited part.
// Budget sharing: each vector element gets bound/2. Along a list spine each
// car gets bound/2, and the remaining budget drops by 1 per cell or is
// halved when that car ran out of budget.
static uintptr
bounded_hash(sl_v a, int bound, bool *oob)
{
	// reinterpret a double as its 64-bit pattern for hashing
	union {
		double d;
		s64int i64;
	}u;
	sl_numtype nt;
	usize i, len;
	csl_v *cv;
	sl_cprim *cp;
	void *data;
	u8int *mpbytes;
	uintptr h = 0;
	int tg = tag(a);
	bool oob2;

	*oob = false;
	switch(tg){
	case TAG_NUM :
	case TAG_NUM1:
		// fixnums hash as the bits of their double value, like numeric cprims below
		u.d = (double)numval(a);
		return doublehash(u.i64);
	case TAG_FUNCTION:
		// functions above N_BUILTINS hash by their bytecode; others by value
		if(uintval(a) > N_BUILTINS)
			return bounded_hash(((sl_fn*)ptr(a))->bcode, bound, oob);
		return inthash(a);
	case TAG_SYM:
		// symbols hash by their stored hash field
		return ((sl_sym*)ptr(a))->hash;
	case TAG_CPRIM:
		cp = ptr(a);
		data = cp_data(cp);
		if(cp_class(cp) == sl_runetype)
			return inthash(*(Rune*)data);
		nt = cp_numtype(cp);
		u.d = conv_to_double(data, nt);
		return doublehash(u.i64);
	case TAG_CVALUE:
		cv = ptr(a);
		data = cv_data(cv);
		if(cv->type == sl_mptype){
			// Hash the big-endian byte form of the bignum. Given a nil buffer,
			// mptobe allocates one and stores it through its last argument,
			// so pass a genuine u8int* rather than punning &data (a void**)
			// to u8int**, which violates C's effective-type (aliasing) rules.
			len = mptobe(*(mpint**)data, nil, 0, &mpbytes);
			h = memhash(mpbytes, len);
			MEM_FREE(mpbytes);
		}else{
			h = memhash(data, cv_len(cv));
		}
		return h;

	case TAG_VECTOR:
		if(bound <= 0){
			*oob = true;
			return 1;
		}
		len = vector_size(a);
		for(i = 0; i < len; i++){
			// the ^1 here and ^2 on a list tail presumably keep vectors and
			// lists with the same elements from hashing alike
			h = MIX(h, bounded_hash(vector_elt(a, i), bound/2, &oob2)^1);
			if(oob2)
				bound /= 2;
			*oob = *oob || oob2;
		}
		return h;

	case TAG_CONS:
		do{
			if(bound <= 0){
				*oob = true;
				return h;
			}
			h = MIX(h, bounded_hash(car_(a), bound/2, &oob2));
			// bounds balancing: try to share the bounds efficiently
			// so we can hash better when a list is cdr-deep (a common case)
			if(oob2)
				bound /= 2;
			else
				bound--;
			// recursive OOB propagation. otherwise this case is slow:
			// (hash '#2=((#0=(#1=(#1#) . #0#)) . #2#))
			*oob = *oob || oob2;
			a = cdr_(a);
		}while(iscons(a));
		// hash the non-cons tail of the list (e.g. the final cdr)
		h = MIX(h, bounded_hash(a, bound-1, &oob2)^2);
		*oob = *oob || oob2;
		return h;
	}
	return 0;
}

// equal_lispvalue: nonzero when a and b are equal.
// Pairs that eq_comparable accepts are decided by identity alone; all
// others go through the structural, unordered sl_compare.
int
equal_lispvalue(sl_v a, sl_v b)
{
	if(!eq_comparable(a, b))
		return numval(sl_compare(a, b, true)) == 0;
	return a == b;
}

// hash_lispvalue: hash a value with bounded_hash using BOUNDED_HASH_BOUND.
// Whether the bound was hit is not reported to the caller.
uintptr
hash_lispvalue(sl_v a)
{
	bool truncated = false;
	uintptr h;

	h = bounded_hash(a, BOUNDED_HASH_BOUND, &truncated);
	return h;
}

// (hash x): expects exactly one argument (checked by argcount) and returns
// hash_lispvalue of it as a fixnum.
BUILTIN("hash", hash)
{
	uintptr h;

	argcount(nargs, 1);
	h = hash_lispvalue(args[0]);
	return fixnum(h);
}