shithub: sl

ref: a70379d7e4b822f532fb0a8ccdd1624a90b64a68
dir: /src/operators.c/

View raw version
#include "sl.h"
#include "operators.h"

/*
 * Convert the number stored at data, whose representation is selected
 * by tag, to an mpint. Every known tag asks libmp for a new value (a nil
 * destination, or mpcopy() for T_MP). Floating-point values are converted
 * with dtomp(). An unknown tag returns the shared mpzero constant itself
 * rather than a new mpint.
 * NOTE(review): confirm callers never mpfree() that fallback result.
 */
mpint *
conv_to_mp(void *data, sl_numtype tag)
{
	mpint *res = mpzero;

	switch(tag){
	case T_S8:     res = itomp(*(s8int*)data, nil); break;
	case T_U8:     res = uitomp(*(u8int*)data, nil); break;
	case T_S16:    res = itomp(*(s16int*)data, nil); break;
	case T_U16:    res = uitomp(*(u16int*)data, nil); break;
	case T_S32:    res = itomp(*(s32int*)data, nil); break;
	case T_U32:    res = uitomp(*(u32int*)data, nil); break;
	case T_S64:    res = vtomp(*(s64int*)data, nil); break;
	case T_U64:    res = uvtomp(*(u64int*)data, nil); break;
	case T_MP:     res = mpcopy(*(mpint**)data); break;
	case T_FLOAT:  res = dtomp(*(float*)data, nil); break;
	case T_DOUBLE: res = dtomp(*(double*)data, nil); break;
	}
	return res;
}

/*
 * Convert the number stored at data, whose representation is selected by
 * tag, to a double. T_S64, T_U64 and T_MP values wider than a double's
 * 53-bit mantissa may be rounded. T_MP goes through mptod().
 * An unknown tag yields 0.
 */
sl_purefn
double
conv_to_double(void *data, sl_numtype tag)
{
	double d;
	switch(tag){
	case T_S8:  return *(s8int*)data;
	case T_U8:  return *(u8int*)data;
	case T_S16: return *(s16int*)data;
	case T_U16: return *(u16int*)data;
	case T_S32: return *(s32int*)data;
	case T_U32: return *(u32int*)data;
	case T_S64:
		d = *(s64int*)data;
		// The conversion above can come out positive for a negative
		// s64int (the original note only says "can happen!", which is
		// presumably a compiler/platform quirk). If so, restore the sign.
		// Keep this sequence as-is: it works around codegen behaviour.
		if(d > 0 && *(s64int*)data < 0)
			d = -d;
		return d;
	case T_U64: return *(u64int*)data;
	case T_MP:  return mptod(*(mpint**)data);
	case T_FLOAT:  return *(float*)data;
	case T_DOUBLE: return *(double*)data;
	}
	return 0;
}

/*
 * CONV_TO_INTTYPE(name, ctype) defines conv_to_<name>(). It converts the
 * number stored at data (representation selected by tag) to ctype using C
 * conversion semantics. An unknown tag yields 0.
 *
 * Floating-point values are handled as in conv_to_u64(): non-negative
 * values are cast directly, and negative ones go through s64int first.
 * Casting a negative float/double straight to an unsigned type such as
 * u32int is undefined behaviour; going through s64int makes it wrap like
 * a negative integer would. For every value that is in range for ctype,
 * the result is the same as a direct cast.
 *
 * FIXME sign with mpint
 */
#define CONV_TO_INTTYPE(name, ctype) \
sl_purefn \
ctype \
conv_to_##name(void *data, sl_numtype tag) \
{ \
	switch(tag){ \
	case T_S8:  return (ctype)*(s8int*)data; \
	case T_U8:  return (ctype)*(u8int*)data; \
	case T_S16: return (ctype)*(s16int*)data; \
	case T_U16: return (ctype)*(u16int*)data; \
	case T_S32: return (ctype)*(s32int*)data; \
	case T_U32: return (ctype)*(u32int*)data; \
	case T_S64: return (ctype)*(s64int*)data; \
	case T_U64: return (ctype)*(u64int*)data; \
	case T_MP:  return (ctype)mptov(*(mpint**)data); \
	case T_FLOAT: \
		if(*(float*)data >= 0) \
			return (ctype)*(float*)data; \
		return (ctype)(s64int)*(float*)data; \
	case T_DOUBLE: \
		if(*(double*)data >= 0) \
			return (ctype)*(double*)data; \
		return (ctype)(s64int)*(double*)data; \
	} \
	return 0; \
}

CONV_TO_INTTYPE(s64, s64int)
CONV_TO_INTTYPE(s32, s32int)
CONV_TO_INTTYPE(u32, u32int)

/*
 * Convert the number stored at data (representation selected by tag) to
 * u64int. An unknown tag yields 0.
 *
 * Casting a negative float or double straight to u64int is undefined
 * behaviour. Negative floating-point values are therefore cast to s64int
 * first and then converted, which wraps modulo 2^64 like any negative
 * integer.
 */
sl_purefn
u64int
conv_to_u64(void *data, sl_numtype tag)
{
	u64int r = 0;

	switch(tag){
	case T_S8:  r = *(s8int*)data; break;
	case T_U8:  r = *(u8int*)data; break;
	case T_S16: r = *(s16int*)data; break;
	case T_U16: r = *(u16int*)data; break;
	case T_S32: r = *(s32int*)data; break;
	case T_U32: r = *(u32int*)data; break;
	case T_S64: r = *(s64int*)data; break;
	case T_U64: r = *(u64int*)data; break;
	case T_MP:  r = mptouv(*(mpint**)data); break;
	case T_FLOAT: {
		float f = *(float*)data;
		r = f >= 0 ? (u64int)f : (u64int)(s64int)f;
		break;
	}
	case T_DOUBLE: {
		double x = *(double*)data;
		r = x >= 0 ? (u64int)x : (u64int)(s64int)x;
		break;
	}
	}
	return r;
}

/*
 * Return whether the value at a is less than the value at b. Both are
 * stored in the representation selected by tag. T_MP values are ordered
 * with mpcmp(). An unknown tag yields false.
 */
sl_purefn
bool
cmp_same_lt(void *a, void *b, sl_numtype tag)
{
	bool less = false;

	switch(tag){
	case T_S8:     less = *(s8int*)a < *(s8int*)b; break;
	case T_U8:     less = *(u8int*)a < *(u8int*)b; break;
	case T_S16:    less = *(s16int*)a < *(s16int*)b; break;
	case T_U16:    less = *(u16int*)a < *(u16int*)b; break;
	case T_S32:    less = *(s32int*)a < *(s32int*)b; break;
	case T_U32:    less = *(u32int*)a < *(u32int*)b; break;
	case T_S64:    less = *(s64int*)a < *(s64int*)b; break;
	case T_U64:    less = *(u64int*)a < *(u64int*)b; break;
	case T_MP:     less = mpcmp(*(mpint**)a, *(mpint**)b) < 0; break;
	case T_FLOAT:  less = *(float*)a < *(float*)b; break;
	case T_DOUBLE: less = *(double*)a < *(double*)b; break;
	}
	return less;
}

/*
 * Return whether the values at a and b are equal. Both are stored in the
 * representation selected by tag. Floating-point values compare unequal
 * when NaN: the float case tests a, and the double case tests b. An
 * unknown tag yields false.
 */
sl_purefn
bool
cmp_same_eq(void *a, void *b, sl_numtype tag)
{
	bool same = false;

	switch(tag){
	case T_S8:     same = *(s8int*)a == *(s8int*)b; break;
	case T_U8:     same = *(u8int*)a == *(u8int*)b; break;
	case T_S16:    same = *(s16int*)a == *(s16int*)b; break;
	case T_U16:    same = *(u16int*)a == *(u16int*)b; break;
	case T_S32:    same = *(s32int*)a == *(s32int*)b; break;
	case T_U32:    same = *(u32int*)a == *(u32int*)b; break;
	case T_S64:    same = *(s64int*)a == *(s64int*)b; break;
	case T_U64:    same = *(u64int*)a == *(u64int*)b; break;
	case T_MP:     same = mpcmp(*(mpint**)a, *(mpint**)b) == 0; break;
	case T_FLOAT:  same = !isnan(*(float*)a) && *(float*)a == *(float*)b; break;
	case T_DOUBLE: same = !isnan(*(double*)b) && *(double*)a == *(double*)b; break;
	}
	return same;
}

/*
 * Scratch mpint used by the mixed-type comparisons below when one side is
 * a T_MP. It is allocated lazily on first use.
 * FIXME one is allocated for all compare ops
 */
static mpint *cmpmpint;

/*
 * Exact three-way comparison used to break a tie after the operands of a
 * mixed-type comparison compared equal as doubles.
 *
 * ip/itag is the non-floating operand and d is the floating-point one.
 * A T_FLOAT value widened to double is exact. Only T_S64, T_U64 and T_MP
 * can lose precision when converted to double. Every other tag converts
 * exactly, so for those a tie is a genuine equality and 0 is returned.
 *
 * Because (double)value == d, d is integer-valued. For a huge T_MP it may
 * also be infinite. d can also lie one past the integer type's range,
 * because (double)INT64_MAX and (double)UINT64_MAX round up to 2^63 and
 * 2^64. Those cases are settled before casting, since casting an
 * out-of-range double to an integer type is undefined behaviour.
 *
 * Returns <0, 0 or >0 when the integer is less than, equal to or greater
 * than d. d must not be NaN.
 */
static int
cmp_int_fp_tie(void *ip, sl_numtype itag, double d)
{
	s64int s, ds;
	u64int u, du;

	switch(itag){
	case T_S64:
		if(d >= 9223372036854775808.0)	/* 2^63 > INT64_MAX */
			return -1;
		if(d < -9223372036854775808.0)
			return 1;
		s = *(s64int*)ip;
		ds = (s64int)d;
		return (s > ds) - (s < ds);
	case T_U64:
		if(d >= 18446744073709551616.0)	/* 2^64 > UINT64_MAX */
			return -1;
		if(d < 0)
			return 1;
		u = *(u64int*)ip;
		du = (u64int)d;
		return (u > du) - (u < du);
	case T_MP:
		/* d is not NaN, so d - d is NaN only for an infinity; an mpint is finite */
		if(isnan(d - d))
			return d > 0 ? -1 : 1;
		if(cmpmpint == nil)
			cmpmpint = mpnew(0);
		/* d is integer-valued, so converting it to an mpint loses nothing */
		return mpcmp(*(mpint**)ip, dtomp(d, cmpmpint));
	default:
		return 0;
	}
}

/*
 * Return whether the number at a (representation atag) is less than the
 * number at b (representation btag). Any comparison with a NaN is false.
 * Same-tag operands are compared directly. Otherwise both are compared as
 * doubles, and ties are re-checked exactly whenever a T_S64, T_U64 or T_MP
 * operand may have been rounded by that conversion.
 */
bool
cmp_lt(void *a, sl_numtype atag, void *b, sl_numtype btag)
{
	if(atag == btag)
		return cmp_same_lt(a, b, atag);

	double da = conv_to_double(a, atag);
	double db = conv_to_double(b, btag);

	if(isnan(da) || isnan(db))
		return false;

	if(da < db)
		return true;
	if(db < da)
		return false;

	// Equal as doubles. That is exact unless a T_S64, T_U64 or T_MP
	// was rounded, so settle those pairings in the integer domain.
	if(cmpmpint == nil && (atag == T_MP || btag == T_MP))
		cmpmpint = mpnew(0);

	if(atag == T_U64){
		if(btag == T_S64)
			return *(s64int*)b >= 0 && *(u64int*)a < (u64int)*(s64int*)b;
		if(btag == T_MP)
			return mpcmp(uvtomp(*(u64int*)a, cmpmpint), *(mpint**)b) < 0;
	}
	if(atag == T_S64){
		if(btag == T_U64)
			return *(s64int*)a >= 0 && (u64int)*(s64int*)a < *(u64int*)b;
		if(btag == T_MP)
			return mpcmp(vtomp(*(s64int*)a, cmpmpint), *(mpint**)b) < 0;
	}
	if(atag == T_MP){
		if(btag == T_U64)
			return mpcmp(*(mpint**)a, uvtomp(*(u64int*)b, cmpmpint)) < 0;
		if(btag == T_S64)
			return mpcmp(*(mpint**)a, vtomp(*(s64int*)b, cmpmpint)) < 0;
	}
	// One floating-point operand: da/db hold its exact value, since a
	// float widens to double without loss.
	if(btag == T_FLOAT || btag == T_DOUBLE)
		return cmp_int_fp_tie(a, atag, db) < 0;
	if(atag == T_FLOAT || atag == T_DOUBLE)
		return cmp_int_fp_tie(b, btag, da) > 0;
	return false;
}

/*
 * Return whether the numbers at a (representation atag) and b (btag) are
 * numerically equal.
 *
 * When equalnans is set and both operands are floating point, they are
 * compared by the bit pattern of their double values instead. NaNs with
 * identical bits then compare equal, while 0.0 and -0.0 do not. Otherwise
 * a NaN equals nothing.
 *
 * Mixed-tag ties after conversion to double are re-checked exactly for
 * T_S64, T_U64 and T_MP, which that conversion can round.
 */
bool
cmp_eq(void *a, sl_numtype atag, void *b, sl_numtype btag, bool equalnans)
{
	union {
		double d;
		s64int i64;
	}u, v;

	if(atag == btag && (!equalnans || atag < T_FLOAT))
		return cmp_same_eq(a, b, atag);

	double da = conv_to_double(a, atag);
	double db = conv_to_double(b, btag);

	if((int)atag >= T_FLOAT && (int)btag >= T_FLOAT){
		if(equalnans){
			u.d = da; v.d = db;
			return u.i64 == v.i64;
		}
		return da == db;
	}

	// This also rejects a NaN on either side.
	if(da != db)
		return false;

	if(cmpmpint == nil && (atag == T_MP || btag == T_MP))
		cmpmpint = mpnew(0);

	if(atag == T_U64){
		// When b >= 0, its bits read as u64int give its value.
		if(btag == T_S64)
			return *(s64int*)b >= 0 && *(u64int*)a == *(u64int*)b;
		if(btag == T_MP)
			return mpcmp(uvtomp(*(u64int*)a, cmpmpint), *(mpint**)b) == 0;
	}
	if(atag == T_S64){
		if(btag == T_U64)
			return *(s64int*)a >= 0 && *(u64int*)a == *(u64int*)b;
		if(btag == T_MP)
			return mpcmp(vtomp(*(s64int*)a, cmpmpint), *(mpint**)b) == 0;
	}
	if(atag == T_MP){
		if(btag == T_U64)
			return mpcmp(*(mpint**)a, uvtomp(*(u64int*)b, cmpmpint)) == 0;
		if(btag == T_S64)
			return mpcmp(*(mpint**)a, vtomp(*(s64int*)b, cmpmpint)) == 0;
	}
	// At most one operand is floating point here.
	if(btag == T_FLOAT || btag == T_DOUBLE)
		return cmp_int_fp_tie(a, atag, db) == 0;
	if(atag == T_FLOAT || atag == T_DOUBLE)
		return cmp_int_fp_tie(b, btag, da) == 0;
	return true;
}