shithub: femtolisp

ref: ec2a902acc1c05ed0a95c26249bbda4032c668e7
dir: /3rd/mp/mplogic.c/

View raw version
#include "platform.h"

/*
	mplogic computes sum = b1 | b2 on the two's-complement
	(infinitely sign-extended) forms of the operands, with the
	operands and the result transformed by the flag bits in fl:

	bit 0 (0x01): subtract 1 from b1
	bit 1 (0x02): invert b1
	bit 2 (0x04): subtract 1 from b2
	bit 3 (0x08): invert b2
	bit 4 (0x10): add 1 to output
	bit 5 (0x20): invert output

	Callers pass only the bits that select the logical operation
	(e.g. inverting both inputs and the output turns OR into AND).
	Because an mpint holds a sign and a magnitude, mplogic toggles
	the remaining bits itself: a negative input with magnitude m
	has two's complement ~(m-1), so its "subtract 1" and "invert"
	bits are flipped; when the result is negative, the output
	"invert" and "add 1" bits are flipped so the two's-complement
	result x is turned back into the magnitude ~x+1.
*/

static void
mplogic(mpint *b1, mpint *b2, mpint *sum, uint32_t fl)
{
	mpint *t;
	mpdigit *dp1, *dp2, *dpo, d1, d2, d;
	uint32_t c1, c2, co, i, neg;

	assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0);

	/* convert negative operands from sign-magnitude to two's complement */
	if(b1->sign < 0) fl ^= 0x03;
	if(b2->sign < 0) fl ^= 0x0c;

	/*
	 * The result is negative iff its sign-extension bits are ones:
	 * (b1 inverted | b2 inverted), flipped when the output is inverted.
	 * Tested on unsigned bits rather than by shifting into the sign bit
	 * of an int, which depends on implementation-defined conversion and
	 * right-shift behavior.
	 */
	neg = ((fl >> 1 | fl >> 3) ^ fl >> 5) & 1;
	sum->sign = neg ? -1 : 1;
	if(neg) fl ^= 0x30;	/* map the two's-complement output back to a magnitude */

	/* make b1 the longer operand, swapping its flag bits with b2's */
	if(b2->top > b1->top){
		t = b1;
		b1 = b2;
		b2 = t;
		fl = (fl >> 2 & 0x03) | (fl << 2 & 0x0c) | (fl & 0x30);
	}

	/* room for one bit beyond b1's digits, for a final carry out */
	mpbits(sum, b1->top*Dbits+1);
	dp1 = b1->p;
	dp2 = b2->p;
	dpo = sum->p;
	/* pending borrows (c1, c2) and carry (co), propagated digit by digit */
	c1 = fl & 1;
	c2 = fl >> 2 & 1;
	co = fl >> 4 & 1;
	for(i = 0; i < b1->top; i++){
		d1 = dp1[i] - c1;
		/* b2's digits past its top are taken as zero before inversion */
		if(i < b2->top)
			d2 = dp2[i] - c2;
		else
			d2 = 0;
		/* a borrow keeps propagating only while the digit wrapped to all ones */
		if(d1 != (mpdigit)-1) c1 = 0;
		if(d2 != (mpdigit)-1) c2 = 0;
		if((fl & 0x02) != 0) d1 ^= -(mpdigit)1;
		if((fl & 0x08) != 0) d2 ^= -(mpdigit)1;
		d = d1 | d2;
		if((fl & 0x20) != 0) d ^= -(mpdigit)1;
		d += co;
		/* the carry stops once a digit did not wrap around to zero */
		if(d != 0) co = 0;
		dpo[i] = d;
	}
	sum->top = i;
	if(co)
		dpo[sum->top++] = co;
	mpnorm(sum);
}

/* mpor: sum = b1 | b2, negative operands treated as two's complement */
void
mpor(mpint *b1, mpint *b2, mpint *sum)
{
	const uint32_t plain_or = 0;	/* no operand or output transforms */

	mplogic(b1, b2, sum, plain_or);
}

/* mpand: sum = b1 & b2, obtained by De Morgan as ~(~b1 | ~b2) */
void
mpand(mpint *b1, mpint *b2, mpint *sum)
{
	/* invert b1 (0x02), invert b2 (0x08), invert output (0x20) */
	const uint32_t and_flags = 0x02 | 0x08 | 0x20;

	mplogic(b1, b2, sum, and_flags);
}

/* mpbic: sum = b1 & ~b2 (bit clear), obtained as ~(~b1 | b2) */
void
mpbic(mpint *b1, mpint *b2, mpint *sum)
{
	/* invert b1 (0x02), invert output (0x20); b2 is used as-is */
	const uint32_t bic_flags = 0x02 | 0x20;

	mplogic(b1, b2, sum, bic_flags);
}

/* mpnot: r = ~b, via the two's-complement identity ~b == -(b + 1) */
void
mpnot(mpint *b, mpint *r)
{
	mpadd(b, mpone, r);
	/* a zero result (b == -1) keeps the sign mpadd gave it */
	if(r->top == 0)
		return;
	/* negate: xor with -2 maps a sign of 1 to -1 and -1 to 1 */
	r->sign ^= -2;
}

/*
	mpxor computes sum = b1 ^ b2, treating negative operands as
	two's complement.  A negative value with magnitude m has two's
	complement ~(m-1); since ~x ^ y == ~(x ^ y), the inversions are
	never applied explicitly: each negative operand only has 1
	subtracted, and when exactly one operand is negative the result
	is negative with magnitude (xor of the adjusted digits) + 1.
*/
void
mpxor(mpint *b1, mpint *b2, mpint *sum)
{
	mpint *t;
	mpdigit *dp1, *dp2, *dpo, d1, d2, d;
	uint32_t c1, c2, co, i;

	assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0);
	/* make b1 the longer operand */
	if(b2->top > b1->top){
		t = b1;
		b1 = b2;
		b2 = t;
	}
	/* pending "subtract 1" borrow for each negative operand */
	c1 = b1->sign < 0;
	c2 = b2->sign < 0;
	/*
	 * Exactly one negative operand gives a negative result whose
	 * magnitude needs a final +1.  The signs are tested directly
	 * instead of packing them into an int flag word and shifting it
	 * into the sign bit, which overflows (undefined behavior) when
	 * the signs differ.
	 */
	co = c1 ^ c2;
	sum->sign = co ? -1 : 1;
	/* room for one bit beyond b1's digits, for a final carry out */
	mpbits(sum, b1->top*Dbits+1);
	dp1 = b1->p;
	dp2 = b2->p;
	dpo = sum->p;
	for(i = 0; i < b1->top; i++){
		d1 = dp1[i] - c1;
		/* b2's digits past its top are taken as zero */
		if(i < b2->top)
			d2 = dp2[i] - c2;
		else
			d2 = 0;
		/* a borrow keeps propagating only while the digit wrapped to all ones */
		if(d1 != (mpdigit)-1) c1 = 0;
		if(d2 != (mpdigit)-1) c2 = 0;
		d = d1 ^ d2;
		d += co;
		/* the carry stops once a digit did not wrap around to zero */
		if(d != 0) co = 0;
		dpo[i] = d;
	}
	sum->top = i;
	if(co)
		dpo[sum->top++] = co;
	mpnorm(sum);
}

/*
 * mpasr: arithmetic shift right, r = b >> n.
 * For n > 0 and non-positive sign, the value is adjusted by +1 before
 * and -1 after the shift so that negative values round downward,
 * presumably because mpright shifts the magnitude and keeps the sign
 * (rounding toward zero) -- TODO confirm against mpright.
 */
void
mpasr(mpint *b, int n, mpint *r)
{
	if(n > 0 && b->sign <= 0){
		mpadd(b, mpone, r);
		mpright(r, n, r);
		mpsub(r, mpone, r);
		return;
	}
	/* positive b, or n <= 0: handed to mpright unchanged */
	mpright(b, n, r);
}