shithub: femtolisp

ref: c3f004633be7e8a15fe8e1e158dde24c2af2f765
dir: /3rd/mp/test/ld.c/

View raw version
#include "platform.h"
#include "mp.h"
#include "dat.h"
#include "fns.h"

/*
 * ldget returns bit n (0 or 1) of the two's-complement value in a.
 * Indices at or beyond a->n return the top (sign) bit, so the value
 * behaves as if sign-extended to unlimited width.
 *
 * n is unsigned: the former "n < 0" test could never be true and has
 * been removed; callers passing a negative int get a huge index and
 * therefore the sign bit, exactly as before.
 */
static int
ldget(ldint *a, uint32_t n)
{
	if(n >= a->n)
		return a->b[a->n - 1] & 1;
	return a->b[n] & 1;
}

/*
 * ldbits resizes a's bit array to n entries and records the new width.
 * Entries below min(old width, n) are preserved; newly added entries are
 * uninitialized and must be written by the caller.
 *
 * The realloc result is checked before being stored. Assigning it
 * straight to a->b would leak the old buffer and leave a NULL pointer
 * behind on failure, so allocation failure now aborts instead.
 */
static void
ldbits(ldint *a, int n)
{
	void *p;

	p = realloc(a->b, n);
	if(p == NULL && n > 0)
		abort();
	a->b = p;
	a->n = n;
}

/*
 * ldnorm shrinks a to the shortest two's-complement width that still
 * represents the same value. It drops high bits that merely repeat the
 * sign bit, keeping exactly one sign bit. An empty ldint becomes a
 * single 0 bit. Returns a.
 */
static ldint *
ldnorm(ldint *a)
{
	int top, sign;

	if(a->n <= 0){
		ldbits(a, 1);
		a->b[0] = 0;
		return a;
	}
	sign = a->b[a->n - 1];
	/* walk down from just below the sign bit past copies of it */
	for(top = (int)a->n - 2; top >= 0 && a->b[top] == sign; top--)
		;
	/* keep bits 0..top plus one sign bit */
	ldbits(a, top + 2);
	return a;
}

/*
 * ldneg negates a in place by two's complement (invert every bit, add 1).
 * If a was non-zero and its sign bit did not flip, the result overflowed
 * the current width (the most negative value). In that case one more bit
 * is appended, holding the flipped sign.
 */
static void
ldneg(ldint *a)
{
	int carry, oldsign, nonzero;
	uint32_t k;

	oldsign = a->b[a->n - 1];
	nonzero = 0;
	carry = 1;
	for(k = 0; k < a->n; k++){
		nonzero |= (a->b[k] != 0);
		carry += !(a->b[k] & 1);
		a->b[k] = carry & 1;
		carry >>= 1;
	}
	if(nonzero && a->b[a->n - 1] == oldsign){
		ldbits(a, a->n + 1);
		a->b[a->n - 1] = !oldsign;
	}
}

/* max returns the larger of two ints. */
static int
max(int a, int b)
{
	if(a < b)
		return b;
	return a;
}

/*
 * ldnew allocates an ldint with room for n bits (at least 1).
 * The bit array is zero-filled, so a fresh ldint holds the value 0
 * instead of indeterminate bytes. Allocation failure aborts rather than
 * returning a half-built object or dereferencing NULL.
 * Free the result with ldfree.
 */
ldint *
ldnew(int n)
{
	ldint *a;

	if(n <= 0)
		n = 1;
	a = malloc(sizeof *a);
	if(a == NULL)
		abort();
	a->b = calloc(n, 1);
	if(a->b == NULL)
		abort();
	a->n = n;
	return a;
}

/* ldfree releases an ldint created by ldnew; a nil argument is ignored. */
void
ldfree(ldint *a)
{
	if(a != nil){
		free(a->b);
		free(a);
	}
}

/*
 * ldtomp converts the two's-complement ldint a into the sign/magnitude
 * mpint b. A new mpint is allocated when b is nil. The magnitude is
 * produced bit by bit: bits are copied when a is non-negative, and
 * negated (invert plus carry-in 1) when a is negative. b is normalized
 * and returned.
 */
mpint *
ldtomp(ldint *a, mpint *b)
{
	uint32_t bit, ndig;
	int neg, carry;

	if(b == nil)
		b = mpnew(0);
	ndig = (a->n + Dbits - 1) / Dbits;
	mpbits(b, a->n);
	neg = a->b[a->n - 1] & 1;
	b->sign = neg ? -1 : 1;
	memset(b->p, 0, ndig * Dbytes);
	carry = neg;
	for(bit = 0; bit < a->n; bit++){
		carry += neg ^ (a->b[bit] & 1);
		if(carry & 1)
			b->p[bit / Dbits] |= (mpdigit)1 << (bit % Dbits);
		carry >>= 1;
	}
	b->top = ndig;
	mpnorm(b);
	return b;
}

/*
 * itold sets a to the value of the int n, allocating a when it is nil.
 * Returns the (normalized) ldint.
 *
 * Bits are taken from n converted to unsigned int. Right-shifting a
 * negative signed int is implementation-defined, whereas the unsigned
 * conversion is defined (modulo 2^N) to give the two's-complement bit
 * pattern on every platform.
 */
static ldint *
itold(int n, ldint *a)
{
	unsigned int u;
	uint32_t i;

	if(a == nil)
		a = ldnew(sizeof(n)*8);
	else
		ldbits(a, sizeof(n)*8);
	u = (unsigned int)n;
	for(i = 0; i < sizeof(n)*8; i++)
		a->b[i] = u >> i & 1;
	ldnorm(a);
	return a;
}

/*
 * pow2told sets a to 2^|n|, negated when n is negative, allocating a
 * when it is nil. Returns the normalized ldint.
 */
static ldint *
pow2told(int n, ldint *a)
{
	int e, width;

	e = abs(n);
	/* bits 0..e+1: the single set bit e plus a zero sign bit above it */
	width = e + 2;
	if(a != nil)
		ldbits(a, width);
	else
		a = ldnew(width);
	memset(a->b, 0, width);
	a->b[e] = 1;
	if(n < 0)
		ldneg(a);
	ldnorm(a);
	return a;
}

/*
 * Ring of 16 static output buffers shared by LFMT and MFMT. Each call
 * advances istr to the next slot and formats into it, so a returned string
 * remains valid until 16 further formatting calls have been made.
 * istr starts at -1 so the first call uses slot 0.
 * This is global mutable state with no locking.
 */
static char str[16][8192];
static int istr = -1;

/*
 * LFMT formats a as a hexadecimal string: an optional '-', then "0x",
 * then the magnitude in upper-case hex with leading zeros removed (at
 * least one digit). The result lives in the next slot of the static str
 * ring and stays valid until that slot is reused.
 *
 * Slots are reused round-robin, and MFMT writes into them too. The digit
 * area and its terminator are therefore cleared before nibbles are OR-ed
 * into it; otherwise stale characters from an earlier call would corrupt
 * the digits and the string would not be terminated. A value too wide for
 * a slot aborts instead of overrunning the buffer.
 */
char *
LFMT(ldint *a)
{
	char *b, *p;
	int s, c;
	uint32_t i, d;

	istr = (istr+1) % nelem(str);
	/* leave 3 bytes in front of the digits for the "-0x" prefix */
	b = str[istr] + 3;
	d = (a->n + 3) / 4;
	if(d + 1 > sizeof(str[istr]) - 3)
		abort();
	memset(b, 0, d + 1);
	/* negate on the fly (invert plus carry-in) when a is negative */
	c = s = a->b[a->n - 1];
	for(i = 0; i < a->n; i++){
		c += s^ldget(a, i);
		b[d - 1 - (i >> 2)] |= (c & 1) << (i & 3);
		c >>= 1;
	}
	for(i = 0; i < d; i++)
		b[i] = "0123456789ABCDEF"[(int)b[i]];
	p = b;
	while(*p == '0' && p[1] != 0) p++;
	*--p = 'x';
	*--p = '0';
	if(a->b[a->n - 1])
		*--p = '-';
	return p;
}

/*
 * MFMT formats m in base 16 via mptoa into the next slot of the static
 * str ring and returns mptoa's result.
 */
char *
MFMT(mpint *m)
{
	istr = (istr + 1) % nelem(str);
	return mptoa(m, 16, str[istr], sizeof(str[istr]));
}

/*
 * ldcmp compares the signed values a and b.
 * Returns a negative number if a < b, 0 if equal, and a positive number
 * if a > b.
 */
int
ldcmp(ldint *a, ldint *b)
{
	int sa, sb, bit, d;

	sa = a->b[a->n - 1];
	sb = b->b[b->n - 1];
	/* different signs: the negative one is smaller */
	if(sa != sb)
		return sb - sa;
	/*
	 * Same sign: compare from the highest bit below the common width's
	 * top. That top bit is the shared sign under sign extension.
	 */
	for(bit = max(a->n, b->n) - 2; bit >= 0; bit--){
		d = ldget(a, bit) - ldget(b, bit);
		if(d != 0)
			return d;
	}
	return 0;
}

/*
 * ldmagcmp compares |a| with |b| using ldcmp's sign convention.
 * Negative operands are negated in place for the comparison and negated
 * back afterwards.
 *
 * ldneg widens the most negative value by one bit, so after the round
 * trip such an operand would keep an extra redundant sign bit. The
 * restored value fits in the original width again, so each operand is
 * trimmed back to the width it had on entry.
 */
int
ldmagcmp(ldint *a, ldint *b)
{
	int s1, s2, r;
	uint32_t na, nb;

	na = a->n;
	nb = b->n;
	s1 = a->b[a->n - 1];
	s2 = b->b[b->n - 1];
	if(s1) ldneg(a);
	if(s2) ldneg(b);
	r = ldcmp(a, b);
	if(s1) ldneg(a);
	if(s2) ldneg(b);
	if(a->n != na) ldbits(a, na);
	if(b->n != nb) ldbits(b, nb);
	return r;
}

/*
 * ldmpeq reports whether the ldint a and the mpint b hold the same value
 * (1 if equal, 0 otherwise).
 *
 * The sign of a is checked against b->sign first. Without this, a short a
 * whose sign-extended bits happen to spell out b's magnitude would compare
 * equal: e.g. a = -1 against b = 2^64-1, or a = 5 against b = -(2^64-5)
 * with 64-bit digits.
 *
 * In the negative case the magnitude of a is built on the fly (invert
 * plus carry-in 1). A carry left over after b's top digit would mean |a|
 * has a bit above that digit, so that is also a mismatch.
 * NOTE(review): a zero b with negative sign (non-normalized) never
 * matches; presumably callers pass normalized mpints.
 */
int
ldmpeq(ldint *a, mpint *b)
{
	uint32_t i, c;
	int aneg;

	aneg = a->b[a->n - 1] & 1;
	if(b->sign > 0){
		if(aneg)
			return 0;
		for(i = 0; i < b->top * Dbits; i++)
			if(ldget(a, i) != (b->p[i / Dbits] >> (i & (Dbits - 1)) & 1))
				return 0;
		for(; i < a->n; i++)
			if(a->b[i] != 0)
				return 0;
		return 1;
	}else{
		if(!aneg)
			return 0;
		c = 1;
		for(i = 0; i < b->top * Dbits; i++){
			c += !ldget(a, i);
			if((c & 1) != (b->p[i / Dbits] >> (i & (Dbits - 1)) & 1))
				return 0;
			c >>= 1;
		}
		if(c != 0)
			return 0;
		for(; i < a->n; i++)
			if(a->b[i] != 1)
				return 0;
		return 1;
	}
}

/*
 * mptarget fills r with a random value of 0..15 digits and a random sign.
 * top is set directly without mpnorm, so r may be non-normalized
 * (e.g. a zero top digit). The digits come from prng — presumably
 * pseudo-random bytes; confirm against prng's definition.
 */
void
mptarget(mpint *r)
{
	int ndig;

	ndig = rand() & 15;
	mpbits(r, ndig * Dbits);
	r->top = ndig;
	prng((void *)r->p, ndig * Dbytes);
	r->sign = (rand() & 1) ? -1 : 1;
}

/*
 * ldadd stores a + b in q. The sum is computed one bit wider than the
 * wider operand, so it cannot overflow, and is then normalized.
 */
void
ldadd(ldint *a, ldint *b, ldint *q)
{
	int width, bit, carry;

	width = max(a->n, b->n) + 1;
	ldbits(q, width);
	carry = 0;
	for(bit = 0; bit < width; bit++){
		carry += ldget(a, bit);
		carry += ldget(b, bit);
		q->b[bit] = carry & 1;
		carry >>= 1;
	}
	ldnorm(q);
}

/*
 * ldmagadd stores |a| + |b| in q. Each operand's magnitude is produced
 * bit by bit: copied when non-negative, negated (invert plus carry-in 1)
 * when negative. The result has two bits of headroom and is normalized.
 */
void
ldmagadd(ldint *a, ldint *b, ldint *q)
{
	int width, bit, sa, sb, ca, cb, sum;

	width = max(a->n, b->n) + 2;
	ldbits(q, width);
	sa = a->b[a->n - 1] & 1;
	sb = b->b[b->n - 1] & 1;
	ca = sa;
	cb = sb;
	sum = 0;
	for(bit = 0; bit < width; bit++){
		ca += sa ^ ldget(a, bit);
		cb += sb ^ ldget(b, bit);
		sum += (ca & 1) + (cb & 1);
		q->b[bit] = sum & 1;
		sum >>= 1;
		ca >>= 1;
		cb >>= 1;
	}
	ldnorm(q);
}

/*
 * ldmagsub stores |a| - |b| in q. a contributes its magnitude. For b, the
 * inversion flag is the opposite of b's sign, so the second addend is
 * -|b|: a non-negative b is negated, while a negative b is used as is.
 * The result has two bits of headroom and is normalized.
 */
void
ldmagsub(ldint *a, ldint *b, ldint *q)
{
	int width, bit, sa, sb, ca, cb, sum;

	width = max(a->n, b->n) + 2;
	ldbits(q, width);
	sa = a->b[a->n - 1] & 1;
	sb = !(b->b[b->n - 1] & 1);
	ca = sa;
	cb = sb;
	sum = 0;
	for(bit = 0; bit < width; bit++){
		ca += sa ^ ldget(a, bit);
		cb += sb ^ ldget(b, bit);
		sum += (ca & 1) + (cb & 1);
		q->b[bit] = sum & 1;
		sum >>= 1;
		ca >>= 1;
		cb >>= 1;
	}
	ldnorm(q);
}

/*
 * ldsub stores a - b in q, computed as a + ~b + 1 one bit wider than the
 * wider operand, then normalized.
 */
void
ldsub(ldint *a, ldint *b, ldint *q)
{
	int width, bit, acc;

	width = max(a->n, b->n) + 1;
	ldbits(q, width);
	acc = 1;	/* the "+ 1" of the two's-complement negation of b */
	for(bit = 0; bit < width; bit++){
		acc += ldget(a, bit) + !ldget(b, bit);
		q->b[bit] = acc & 1;
		acc >>= 1;
	}
	ldnorm(q);
}

/*
 * lddiv divides a by b, leaving the quotient in q and the remainder in r.
 * Both are resized to n = max(a->n, b->n) + 1 bits during the computation
 * and normalized at the end. The work is done on |a| with a shift-and-
 * add/subtract loop (looks like non-restoring division — NOTE(review):
 * confirm). Signs are applied last: q is negated when a and b have
 * different signs, and r takes the sign of a.
 * There is no b == 0 check here; lddiv_ screens that case out first.
 */
static void
lddiv(ldint *a, ldint *b, ldint *q, ldint *r)
{
	int n, i, j, c, s;
	
	n = max(a->n, b->n) + 1;
	ldbits(q, n);
	ldbits(r, n);
	/* the partial remainder starts at zero */
	memset(r->b, 0, n);
	/* q = |a|: copy a, negating it (invert plus carry-in 1) if negative */
	c = s = a->b[a->n-1];
	for(i = 0; i < n; i++){
		c += s ^ ldget(a, i);
		q->b[i] = c & 1;
		c >>= 1;
	}
	for(i = 0; i < n; i++){
		/* shift the double register (r:q) left one bit; r takes q's top bit */
		for(j = n-1; --j >= 0; )
			r->b[j + 1] = r->b[j];
		r->b[0] = q->b[n - 1];
		for(j = n-1; --j >= 0; )
			q->b[j + 1] = q->b[j];
		/* next quotient bit: 1 when r's top (sign) bit is clear */
		q->b[0] = !r->b[n - 1];
		/*
		 * If r's sign matches b's sign, subtract b (add ~b with carry-in 1);
		 * otherwise add b.
		 */
		c = s = r->b[n - 1] == b->b[b->n - 1];
		for(j = 0; j < n; j++){
			c += r->b[j] + (s ^ ldget(b, j));
			r->b[j] = c & 1;
			c >>= 1;
		}
	}
	/* final quotient bit: shift q once more and set its low bit */
	for(j = n-1; --j >= 0; )
		q->b[j + 1] = q->b[j];
	q->b[0] = 1;
	if(r->b[r->n - 1]){
		/* negative remainder: q -= 1 (add all ones) ... */
		c = 0;
		for(j = 0; j < n; j++){
			c += 1 + q->b[j];
			q->b[j] = c & 1;
			c >>= 1;
		}
		/* ... and r += |b| (adds -b when b is negative, b otherwise) */
		c = s = b->b[b->n - 1];
		for(j = 0; j < n; j++){
			c += r->b[j] + (s ^ ldget(b, j));
			r->b[j] = c & 1;
			c >>= 1;
		}
	}
	/* negate q when exactly one of a, b is negative */
	c = s = a->b[a->n-1] ^ b->b[b->n-1];
	for(j = 0; j < n; j++){
		c += s ^ q->b[j];
		q->b[j] = c & 1;
		c >>= 1;
	}
	/* give r the sign of a */
	c = s = a->b[a->n-1];
	for(j = 0; j < n; j++){
		c += s ^ r->b[j];
		r->b[j] = c & 1;
		c >>= 1;
	}
	ldnorm(q);
	ldnorm(r);
}

/*
 * lddiv_ is a division wrapper that tolerates b == 0. In that case every
 * existing bit of q and r is cleared (their widths are left unchanged);
 * otherwise it defers to lddiv. The mpint counterpart is mpdiv_.
 */
void
lddiv_(ldint *a, ldint *b, ldint *q, ldint *r)
{
	if(!ldmpeq(b, mpzero)){
		lddiv(a, b, q, r);
		return;
	}
	memset(q->b, 0, q->n);
	memset(r->b, 0, r->n);
}

/*
 * mpdiv_ is a division wrapper that tolerates b == 0. In that case q and
 * r are both set to zero; otherwise it defers to mpdiv.
 */
void
mpdiv_(mpint *a, mpint *b, mpint *q, mpint *r)
{
	if(mpcmp(b, mpzero) != 0){
		mpdiv(a, b, q, r);
		return;
	}
	mpassign(mpzero, q);
	mpassign(mpzero, r);
}

/* ldand stores the bitwise AND of a and b (sign-extended to a common width) in q. */
void
ldand(ldint *a, ldint *b, ldint *q)
{
	int width, k;

	width = max(a->n, b->n);
	ldbits(q, width);
	for(k = 0; k < width; k++)
		q->b[k] = ldget(a, k) && ldget(b, k);
	ldnorm(q);
}

/* ldbic stores a AND NOT b (bit clear, sign-extended to a common width) in q. */
void
ldbic(ldint *a, ldint *b, ldint *q)
{
	int width, k;

	width = max(a->n, b->n);
	ldbits(q, width);
	for(k = 0; k < width; k++)
		q->b[k] = ldget(a, k) && !ldget(b, k);
	ldnorm(q);
}

/* ldor stores the bitwise OR of a and b (sign-extended to a common width) in q. */
void
ldor(ldint *a, ldint *b, ldint *q)
{
	int width, k;

	width = max(a->n, b->n);
	ldbits(q, width);
	for(k = 0; k < width; k++)
		q->b[k] = ldget(a, k) || ldget(b, k);
	ldnorm(q);
}

/* ldxor stores the bitwise XOR of a and b (sign-extended to a common width) in q. */
void
ldxor(ldint *a, ldint *b, ldint *q)
{
	int width, k;

	width = max(a->n, b->n);
	ldbits(q, width);
	for(k = 0; k < width; k++)
		q->b[k] = ldget(a, k) != ldget(b, k);
	ldnorm(q);
}

/*
 * ldleft stores a shifted left by n bits in b (normalized).
 * A negative n shifts right by -n bits instead. In that case, when a is
 * negative and any discarded bit was set, 1 is added to the shifted
 * value, so the right shift rounds toward zero rather than toward minus
 * infinity. If every bit is shifted out, b becomes 0.
 */
void
ldleft(ldint *a, int n, ldint *b)
{
	uint32_t k, drop, keep;
	int carry;

	if(n >= 0){
		/* left shift: place a above n zero bits */
		ldbits(b, a->n + n);
		memmove(b->b + n, a->b, a->n);
		memset(b->b, 0, n);
		ldnorm(b);
		return;
	}
	drop = (uint32_t)-n;
	if(drop >= a->n){
		b->n = 0;
		ldnorm(b);
		return;
	}
	carry = 0;
	if(a->b[a->n - 1])
		for(k = 0; k < drop && carry == 0; k++)
			if(a->b[k])
				carry = 1;
	keep = a->n - drop;
	ldbits(b, keep);
	for(k = 0; k < keep; k++){
		carry += a->b[k + drop] & 1;
		b->b[k] = carry & 1;
		carry >>= 1;
	}
	ldnorm(b);
}

/*
 * ldright stores a shifted right by n bits in b. It is ldleft with the
 * shift negated, so negative values round toward zero.
 */
void
ldright(ldint *a, int n, ldint *b)
{
	int shift;

	shift = -n;
	ldleft(a, shift, b);
}

/*
 * ldasr stores a arithmetically shifted right by n bits in b. The high
 * bits, including the sign, are kept as they are, which gives
 * floor(a / 2^n). A negative n is delegated to ldleft as a left shift by
 * -n. When every bit is shifted out, b becomes a single copy of a's sign
 * bit (0 or -1).
 */
void
ldasr(ldint *a, int n, ldint *b)
{
	if(n < 0){
		ldleft(a, -n, b);
	}else if((uint32_t)n >= a->n){
		ldbits(b, 1);
		b->b[0] = a->b[a->n - 1];
	}else{
		ldbits(b, a->n - n);
		memmove(b->b, a->b + n, a->n - n);
		ldnorm(b);
	}
}

/*
 * ldnot stores the bitwise complement of a (the value -a - 1) in b,
 * at a's width.
 */
void
ldnot(ldint *a, ldint *b)
{
	uint32_t k;

	ldbits(b, a->n);
	for(k = a->n; k-- > 0; )
		b->b[k] = a->b[k] ^ 1;
}

/*
 * xorshift advances the 32-bit xorshift state (shift triple 13, 17, 5)
 * and returns the new state, which is also stored back through the
 * pointer. A zero state stays zero.
 */
static uint32_t
xorshift(uint32_t *state)
{
	uint32_t v;

	v = *state;
	v ^= v << 13;
	v ^= v >> 17;
	v ^= v << 5;
	return *state = v;
}

/*
 * testgen fills a with test input number i:
 *   i < 257        the small integer i - 128
 *   257 <= i < 514 2^|i - 385|, negated when i < 385
 *   otherwise      Dbits * (1..16) pseudo-random bits, seeded from i via
 *                  xorshift. These are not normalized.
 */
void
testgen(int i, ldint *a)
{
	uint32_t seed, bits, k;

	if(i < 257){
		itold(i - 128, a);
		return;
	}
	if(i < 514){
		pow2told(i - 385, a);
		return;
	}
	seed = i;
	/* discard a few outputs so nearby seeds diverge */
	for(k = 0; k < 3; k++)
		xorshift(&seed);
	ldbits(a, Dbits * (1 + (xorshift(&seed) & 15)));
	bits = 0;
	for(k = 0; k < a->n; k++){
		if(k % 32 == 0)
			bits = xorshift(&seed);
		a->b[k] = bits & 1;
		bits >>= 1;
	}
}