shithub: femtolisp

ref: 8837293e1085d71e10ef9216c457458450419801
dir: /operators.c/

View raw version
#include "llt.h"

/*
 * Convert the number stored at data, whose numeric type is tag, to an mpint.
 *
 * Fixed-width integers and floating-point values go through the libmp
 * converters with a nil destination, and T_MPINT values are duplicated with
 * mpcopy, so (assuming libmp's nil-allocates convention) the result is a
 * fresh mpint owned by the caller. An unrecognized tag yields mpzero itself.
 * NOTE(review): mpzero looks like a shared constant; confirm that callers
 * never free or modify the result in that case.
 */
mpint *conv_to_mpint(void *data, numerictype_t tag)
{
    mpint *i = mpzero;
    switch (tag) {
    case T_INT8:   i = itomp(*(int8_t*)data, nil); break;
    case T_UINT8:  i = uitomp(*(uint8_t*)data, nil); break;
    case T_INT16:  i = itomp(*(int16_t*)data, nil); break;
    case T_UINT16: i = uitomp(*(uint16_t*)data, nil); break;
    case T_INT32:  i = itomp(*(int32_t*)data, nil); break;
    case T_UINT32: i = uitomp(*(uint32_t*)data, nil); break;
    case T_INT64:  i = vtomp(*(int64_t*)data, nil); break;
    // read the stored value through its own unsigned type, not via int64_t
    case T_UINT64: i = uvtomp(*(uint64_t*)data, nil); break;
    case T_MPINT:  i = mpcopy(*(mpint**)data); break;
    case T_FLOAT:  i = dtomp(*(float*)data, nil); break;
    case T_DOUBLE: i = dtomp(*(double*)data, nil); break;
    }
    return i;
}

/*
 * Return the number stored at data (numeric type tag) as a double.
 * 64-bit integers and mpints may be rounded by the conversion.
 * An unrecognized tag yields 0.
 */
double conv_to_double(void *data, numerictype_t tag)
{
    switch (tag) {
    case T_INT8:   return *(int8_t*)data;
    case T_UINT8:  return *(uint8_t*)data;
    case T_INT16:  return *(int16_t*)data;
    case T_UINT16: return *(uint16_t*)data;
    case T_INT32:  return *(int32_t*)data;
    case T_UINT32: return *(uint32_t*)data;
    case T_INT64: {
        int64_t v = *(int64_t*)data;
        double r = (double)v;
        // Defensive guard kept from the original ("can happen!"): if the
        // conversion ever loses the sign, restore it.
        if (v < 0 && r > 0)
            r = -r;
        return r;
    }
    case T_UINT64: return (double)*(uint64_t*)data;
    case T_MPINT:  return mptod(*(mpint**)data);
    case T_FLOAT:  return *(float*)data;
    case T_DOUBLE: return *(double*)data;
    }
    return 0;
}

/*
 * Store d into *dest converted to numeric type tag. For T_MPINT the result
 * of dtomp(d, nil) is stored.
 *
 * Converting a double that is out of range for the destination type is
 * undefined behavior in C, so the 64-bit targets are range-checked first:
 *   T_INT64   d >= 2^63 saturates to INT64_MAX; d < -2^63 and NaN give
 *             INT64_MIN.
 *   T_UINT64  0 <= d < 2^64 converts directly (so 2^63..2^64-1 are kept);
 *             d >= 2^64 saturates to all ones; a negative d is truncated
 *             to int64 first and wraps modulo 2^64, as conv_to_uint64 does;
 *             d < -2^63 and NaN give 2^63.
 * The narrower integer targets still use a plain C conversion, so d must be
 * in range for them.
 */
void conv_from_double(void *dest, double d, numerictype_t tag)
{
    // 2^63 and 2^64, both exactly representable as doubles
    const double two63 = 9223372036854775808.0;
    const double two64 = 18446744073709551616.0;
    const int64_t i64min = -INT64_MAX - 1;
    int64_t s;

    switch (tag) {
    case T_INT8:   *(int8_t*)dest = d; break;
    case T_UINT8:  *(uint8_t*)dest = d; break;
    case T_INT16:  *(int16_t*)dest = d; break;
    case T_UINT16: *(uint16_t*)dest = d; break;
    case T_INT32:  *(int32_t*)dest = d; break;
    case T_UINT32: *(uint32_t*)dest = d; break;
    case T_INT64:
        if (d >= two63)
            s = INT64_MAX;
        else if (d >= -two63)
            s = (int64_t)d;     // in range: well-defined truncation
        else
            s = i64min;         // below range, or NaN
        *(int64_t*)dest = s;
        break;
    case T_UINT64:
        if (d >= two64)
            *(uint64_t*)dest = ~(uint64_t)0;
        else if (d >= 0)
            *(uint64_t*)dest = (uint64_t)d;
        else if (d >= -two63)
            *(uint64_t*)dest = (uint64_t)(int64_t)d;  // negative: wrap
        else
            *(uint64_t*)dest = (uint64_t)i64min;      // below range, or NaN
        break;
    case T_MPINT:  *(mpint**)dest = dtomp(d, nil); break;
    case T_FLOAT:  *(float*)dest = d; break;
    case T_DOUBLE: *(double*)dest = d; break;
    }
}

/*
 * CONV_TO_INTTYPE(name, ctype) defines
 *     ctype conv_to_<name>(void *data, numerictype_t tag)
 * which returns the number stored at data (numeric type tag) converted to
 * ctype. An unrecognized tag yields 0.
 *
 * Negative floats and doubles are truncated to int64_t first and only then
 * converted to ctype. Converting a negative floating value directly to an
 * unsigned type (conv_to_uint32) is undefined behavior, while the
 * integer-to-integer step is well defined. This is the same workaround that
 * conv_to_uint64 uses. Floating values outside the int64_t / ctype range are
 * still not range-checked.
 *
 * FIXME sign with mpint: T_MPINT goes through mptov for every ctype.
 */
#define CONV_TO_INTTYPE(name, ctype)                    \
ctype conv_to_##name(void *data, numerictype_t tag)     \
{                                                       \
    switch (tag) {                                      \
    case T_INT8:   return *(int8_t*)data;               \
    case T_UINT8:  return *(uint8_t*)data;              \
    case T_INT16:  return *(int16_t*)data;              \
    case T_UINT16: return *(uint16_t*)data;             \
    case T_INT32:  return *(int32_t*)data;              \
    case T_UINT32: return *(uint32_t*)data;             \
    case T_INT64:  return *(int64_t*)data;              \
    case T_UINT64: return *(uint64_t*)data;             \
    case T_MPINT:  return mptov(*(mpint**)data);        \
    case T_FLOAT:                                       \
        if (*(float*)data >= 0)                         \
            return *(float*)data;                       \
        return (int64_t)*(float*)data;                  \
    case T_DOUBLE:                                      \
        if (*(double*)data >= 0)                        \
            return *(double*)data;                      \
        return (int64_t)*(double*)data;                 \
    }                                                   \
    return 0;                                           \
}

CONV_TO_INTTYPE(int64, int64_t)
CONV_TO_INTTYPE(int32, int32_t)
CONV_TO_INTTYPE(uint32, uint32_t)

/*
 * Return the number stored at data (numeric type tag) as a uint64_t.
 *
 * Converting a negative float or double directly to an unsigned 64-bit
 * integer is undefined behavior, so negative floating values are truncated
 * to int64_t first. The signed-to-unsigned step then wraps them modulo
 * 2^64. Negative signed integer tags wrap the same way. An unrecognized tag
 * yields 0.
 */
uint64_t conv_to_uint64(void *data, numerictype_t tag)
{
    float f;
    double d;

    switch (tag) {
    case T_INT8:   return *(int8_t*)data;
    case T_UINT8:  return *(uint8_t*)data;
    case T_INT16:  return *(int16_t*)data;
    case T_UINT16: return *(uint16_t*)data;
    case T_INT32:  return *(int32_t*)data;
    case T_UINT32: return *(uint32_t*)data;
    case T_INT64:  return *(int64_t*)data;
    case T_UINT64: return *(uint64_t*)data;
    case T_MPINT:  return mptouv(*(mpint**)data);
    case T_FLOAT:
        f = *(float*)data;
        return f >= 0 ? (uint64_t)f : (uint64_t)(int64_t)f;
    case T_DOUBLE:
        d = *(double*)data;
        return d >= 0 ? (uint64_t)d : (uint64_t)(int64_t)d;
    }
    return 0;
}

/*
 * Return 1 if *a < *b, where a and b both point at numbers of numeric type
 * tag, else 0. mpints are ordered with mpcmp. Floating comparisons follow C
 * semantics, so a NaN operand gives 0. An unrecognized tag yields 0.
 */
int cmp_same_lt(void *a, void *b, numerictype_t tag)
{
    int less = 0;

    switch (tag) {
    case T_INT8:   less = *(int8_t*)a < *(int8_t*)b; break;
    case T_UINT8:  less = *(uint8_t*)a < *(uint8_t*)b; break;
    case T_INT16:  less = *(int16_t*)a < *(int16_t*)b; break;
    case T_UINT16: less = *(uint16_t*)a < *(uint16_t*)b; break;
    case T_INT32:  less = *(int32_t*)a < *(int32_t*)b; break;
    case T_UINT32: less = *(uint32_t*)a < *(uint32_t*)b; break;
    case T_INT64:  less = *(int64_t*)a < *(int64_t*)b; break;
    case T_UINT64: less = *(uint64_t*)a < *(uint64_t*)b; break;
    case T_MPINT:  less = mpcmp(*(mpint**)a, *(mpint**)b) < 0; break;
    case T_FLOAT:  less = *(float*)a < *(float*)b; break;
    case T_DOUBLE: less = *(double*)a < *(double*)b; break;
    }
    return less;
}

/*
 * Return 1 if *a == *b, where a and b both point at numbers of numeric type
 * tag, else 0. mpints are compared with mpcmp. Floating comparisons follow C
 * semantics (NaN never equals anything; 0.0 equals -0.0). An unrecognized
 * tag yields 0.
 */
int cmp_same_eq(void *a, void *b, numerictype_t tag)
{
    int equal = 0;

    switch (tag) {
    case T_INT8:   equal = *(int8_t*)a == *(int8_t*)b; break;
    case T_UINT8:  equal = *(uint8_t*)a == *(uint8_t*)b; break;
    case T_INT16:  equal = *(int16_t*)a == *(int16_t*)b; break;
    case T_UINT16: equal = *(uint16_t*)a == *(uint16_t*)b; break;
    case T_INT32:  equal = *(int32_t*)a == *(int32_t*)b; break;
    case T_UINT32: equal = *(uint32_t*)a == *(uint32_t*)b; break;
    case T_INT64:  equal = *(int64_t*)a == *(int64_t*)b; break;
    case T_UINT64: equal = *(uint64_t*)a == *(uint64_t*)b; break;
    case T_MPINT:  equal = mpcmp(*(mpint**)a, *(mpint**)b) == 0; break;
    case T_FLOAT:  equal = *(float*)a == *(float*)b; break;
    case T_DOUBLE: equal = *(double*)a == *(double*)b; break;
    }
    return equal;
}

/* FIXME one is allocated for all compare ops */
static mpint *cmpmpint;

int cmp_lt(void *a, numerictype_t atag, void *b, numerictype_t btag)
{
    if (atag==btag)
        return cmp_same_lt(a, b, atag);

    double da = conv_to_double(a, atag);
    double db = conv_to_double(b, btag);

    // casting to double will only get the wrong answer for big int64s
    // that differ in low bits
    if (da < db && !isnan(da) && !isnan(db))
        return 1;
    if (db < da)
        return 0;

    if (cmpmpint == nil && (atag == T_MPINT || btag == T_MPINT))
        cmpmpint = mpnew(0);

    if (atag == T_UINT64) {
        if (btag == T_INT64) {
            if (*(int64_t*)b >= 0)
                return (*(uint64_t*)a < (uint64_t)*(int64_t*)b);
            return ((int64_t)*(uint64_t*)a < *(int64_t*)b);
        }
        else if (btag == T_DOUBLE) {
            if (db != db) return 0;
            return (*(uint64_t*)a < (uint64_t)*(double*)b);
        }
        else if (btag == T_MPINT) {
            return mpcmp(uvtomp(*(uint64_t*)a, cmpmpint), *(mpint**)b) < 0;
        }
    }
    else if (atag == T_INT64) {
        if (btag == T_UINT64) {
            if (*(int64_t*)a >= 0)
                return ((uint64_t)*(int64_t*)a < *(uint64_t*)b);
            return (*(int64_t*)a < (int64_t)*(uint64_t*)b);
        }
        else if (btag == T_DOUBLE) {
            if (db != db) return 0;
            return (*(int64_t*)a < (int64_t)*(double*)b);
        }
        else if (btag == T_MPINT) {
            return mpcmp(vtomp(*(int64_t*)a, cmpmpint), *(mpint**)b) < 0;
        }
    }
    if (btag == T_UINT64) {
        if (atag == T_DOUBLE) {
            if (da != da) return 0;
            return (*(uint64_t*)b > (uint64_t)*(double*)a);
        }
        else if (atag == T_MPINT) {
            return mpcmp(*(mpint**)a, uvtomp(*(uint64_t*)b, cmpmpint)) < 0;
        }
    }
    else if (btag == T_INT64) {
        if (atag == T_DOUBLE) {
            if (da != da) return 0;
            return (*(int64_t*)b > (int64_t)*(double*)a);
        }
        else if (atag == T_MPINT) {
            return mpcmp(*(mpint**)a, vtomp(*(int64_t*)b, cmpmpint)) < 0;
        }
    }
    return 0;
}

int cmp_eq(void *a, numerictype_t atag, void *b, numerictype_t btag,
           int equalnans)
{
    union { double d; int64_t i64; } u, v;
    if (atag==btag && (!equalnans || atag < T_FLOAT))
        return cmp_same_eq(a, b, atag);

    double da = conv_to_double(a, atag);
    double db = conv_to_double(b, btag);

    if ((int)atag >= T_FLOAT && (int)btag >= T_FLOAT) {
        if (equalnans) {
            u.d = da; v.d = db;
            return u.i64 == v.i64;
        }
        return (da == db);
    }

    if (da != db)
        return 0;

    if (cmpmpint == nil && (atag == T_MPINT || btag == T_MPINT))
        cmpmpint = mpnew(0);

    if (atag == T_UINT64) {
        // this is safe because if a had been bigger than INT64_MAX,
        // we would already have concluded that it's bigger than b.
        if (btag == T_INT64)
            return ((int64_t)*(uint64_t*)a == *(int64_t*)b);
        else if (btag == T_DOUBLE)
            return (*(uint64_t*)a == (uint64_t)(int64_t)*(double*)b);
        else if (btag == T_MPINT)
            return mpcmp(uvtomp(*(uint64_t*)a, cmpmpint), *(mpint**)b) == 0;
    }
    else if (atag == T_INT64) {
        if (btag == T_UINT64)
            return (*(int64_t*)a == (int64_t)*(uint64_t*)b);
        else if (btag == T_DOUBLE)
            return (*(int64_t*)a == (int64_t)*(double*)b);
        else if (btag == T_MPINT)
            return mpcmp(vtomp(*(int64_t*)a, cmpmpint), *(mpint**)b) == 0;
    }
    else if (btag == T_UINT64) {
        if (atag == T_INT64)
            return ((int64_t)*(uint64_t*)b == *(int64_t*)a);
        else if (atag == T_DOUBLE)
            return (*(uint64_t*)b == (uint64_t)(int64_t)*(double*)a);
        else if (atag == T_MPINT)
            return mpcmp(*(mpint**)a, uvtomp(*(uint64_t*)b, cmpmpint)) == 0;
    }
    else if (btag == T_INT64) {
        if (atag == T_UINT64)
            return (*(int64_t*)b == (int64_t)*(uint64_t*)a);
        else if (atag == T_DOUBLE)
            return (*(int64_t*)b == (int64_t)*(double*)a);
        else if (atag == T_MPINT)
            return mpcmp(*(mpint**)a, vtomp(*(int64_t*)b, cmpmpint)) == 0;
    }
    return 1;
}