shithub: stashfs

ref: d6730b9aeffeb72392c63b75af8694627fda457f
dir: /stashfs.c/

View raw version
#include <u.h>
#include <libc.h>
#include <auth.h>
#include <fcall.h>
#include <thread.h>
#include <9p.h>
#include <libsec.h>
#include <authsrv.h>

/*
 * Layout of one encrypted block in the backing file:
 *	nonce[24] | poly1305-tag[16] | E(next-offset[8] | data[4096])
 * The tag authenticates the encrypted next-offset and data.
 */
enum {
	Overhead = 24+16+8,	/* nonce + tag + next-offset bytes in front of the data */
	Datasize = 4096,	/* plaintext bytes carried by each block */
	Blocksize = Overhead+Datasize,	/* size of a whole block in the backing file */
};

/* let the structures declared below be referred to by their bare names */
typedef struct FileHdr FileHdr;
typedef struct FileKey FileKey;
typedef struct FileAux FileAux;

typedef struct KeyEntry KeyEntry;

/*
 * Header found at the start of every backing file.  It is read and
 * written as raw bytes (see cryptopen and cryptcreate) and precedes
 * the first encrypted block (see off2block).
 */
struct FileHdr
{
	char	V[21];	/* version signature */
	uchar	N;	/* log2(computation cost) */
	uchar	R;	/* log2(scrypt block size) */
	uchar	P;	/* parallelization factor */
	uchar	S[32];	/* password salt */
	uchar	T[32];	/* random file nonce */
};

/* per-file cipher state, keyed by cryptfilekey and re-armed per block by cryptsetup */
struct FileKey
{
	DigestState	ds;	/* poly1305 authenticator state for the current block */
	Salsastate	cs;	/* xsalsa cipher state holding the file key */
};

/* state hung off File.aux for every encrypted file in the tree */
struct FileAux
{
	int		fd;	/* descriptor of the backing file, -1 while closed */
	int		mode;	/* mode fd was opened with (OREAD or ORDWR); 0 while closed */
	FileHdr		hdr;	/* copy of the backing file's header */
	FileKey		key;	/* cipher state derived from hdr and the password */
};

/*
 * One cached result of the scrypt key derivation, kept on keylist
 * and looked up by (N, R, P, S) in getkey.
 */
struct KeyEntry
{
	KeyEntry	*next;	/* next entry on keylist */

	uchar		K[32];	/* K = scrypt(pass, S, N, R, P) */
	uchar		S[32];	/* salt the key was derived with */
	uchar		N;	/* log2 of the scrypt cost parameter */
	uchar		R;	/* log2 of the scrypt block size */
	uchar		P;	/* scrypt parallelization factor */
};

int	verbose;	/* -v flag: print diagnostics on standard error */
uchar	buf[Blocksize];	/* scratch space for one block; used by cryptread, cryptwrite and cryptopen */
FileHdr defhdr = { "SCRYPTXSALSAPOLY1305\n", 16, 3, 1 };	/* header template (V, N, R, P) for new files */
char*	pass;	/* password obtained by getpass; nil until first asked for */
KeyEntry *keylist;	/* head of the list of keys already derived */

/*
 * Prompt for the password and store it in the global pass, wiping
 * and freeing any previous value first.  With confirm set, the
 * password must be typed twice; the prompts repeat until both
 * entries agree.  An empty or unreadable entry is fatal.
 * Returns pass.
 */
char*
getpass(int confirm)
{
	char *again;
	int differ;

	for(;;){
		if(pass != nil){
			memset(pass, 0, strlen(pass));
			free(pass);
		}
		pass = readcons("Password", nil, 1);
		if(pass == nil || pass[0] == 0)
			sysfatal("no password");
		if(!confirm)
			break;

		again = readcons("Confirm", nil, 1);
		if(again == nil || again[0] == 0)
			sysfatal("no password");
		differ = strcmp(again, pass) != 0;

		/* the confirmation copy is not needed past the comparison */
		memset(again, 0, strlen(again));
		free(again);

		if(!differ)
			break;
		fprint(2, "mismatch\n");
	}
	return pass;
}

/*
 * Return the key derived from the password with scrypt parameters
 * N (log2 cost), R (log2 block size), P and salt S.  Keys already
 * derived are served from keylist; otherwise the password is asked
 * for if needed and the derivation runs in a child process that
 * hands the key back over a pipe.  Returns nil on failure; when the
 * child exits with a message, that message becomes the error string.
 */
KeyEntry*
getkey(uchar N, uchar R, uchar P, uchar S[32])
{
	int pfd[2], pid, n;
	KeyEntry *ke;
	Waitmsg *w;
	char *err;

	/* see if we did the key derivation already */
	for(ke = keylist; ke != nil; ke = ke->next)
		if(ke->N == N && ke->R == R && ke->P == P
		&& memcmp(ke->S, S, sizeof(ke->S)) == 0)
			return ke;

	if(verbose)
		fprint(2, "%s: will require %lldMB of memory\n",
			argv0, (128LL * (1LL<<R) * (1LL<<N)) >> 20);

	if(pass == nil)
		getpass(0);

	ke = emalloc9p(sizeof(*ke));
	ke->N = N;
	ke->R = R;
	ke->P = P;
	memmove(ke->S, S, sizeof(ke->S));
	if(pipe(pfd) < 0){
		/* nothing else holds the entry yet; don't leak it */
		free(ke);
		return nil;
	}
	pid = fork();
	if(pid < 0){
		/* release both pipe ends and the entry */
		close(pfd[0]);
		close(pfd[1]);
		free(ke);
		return nil;
	}
	if(pid == 0){
		/* child: derive the key and send it to the parent */
		close(pfd[0]);
		alarm(60*1000);	/* timeout after a minute */
		if((err = scrypt((uchar*)pass, strlen(pass),
			ke->S, sizeof(ke->S),
			1<<ke->N, 1<<ke->R, ke->P,
			ke->K, sizeof(ke->K))) != nil)
			exits(err);
		write(pfd[1], ke->K, sizeof(ke->K));
		exits(nil);
	}

	/* parent: a short read means the child did not deliver a key */
	close(pfd[1]);
	n = readn(pfd[0], ke->K, sizeof(ke->K));
	close(pfd[0]);
	while((w = wait()) != nil){
		if(w->pid == pid){
			if(verbose)
				fprint(2, "%s: spent %g seconds crunching key...\n",
					argv0, (double)w->time[2] / 1000.0);
			if(w->msg[0]){
				/* pass the child's exit message on as our error */
				werrstr("%s", w->msg);
				free(w);
				n = -1;
				break;
			}
		}
		free(w);
	}
	if(n != sizeof(ke->K)){
		/* K may hold part of a key; wipe it before letting go */
		memset(ke, 0, sizeof(*ke));
		free(ke);
		return nil;
	}

	ke->next = keylist;
	keylist = ke;

	return ke;
}

/*
 * Arm the cipher and authenticator state in k for the block whose
 * data starts at plaintext offset o, using the block's nonce.
 * Returns the 64-bit value the caller xors with the block's
 * next-offset field to encrypt or decrypt it.
 */
uvlong
cryptsetup(FileKey *k, uchar nonce[24], uvlong o)
{
	uchar otk[SalsaBsize];

	salsa_setiv(&k->cs, nonce);
	/*
	 * each file block is allotted (SalsaBsize+Datasize)/SalsaBsize
	 * cipher blocks: one for otk below plus Datasize/SalsaBsize for data
	 */
	salsa_setblock(&k->cs, (o / Datasize) * ((SalsaBsize+Datasize)/SalsaBsize));

	/* NOTE(review): encrypting zeros presumably yields the raw keystream — confirm */
	memset(otk, 0, sizeof(otk));
	salsa_encrypt(otk, sizeof(otk), &k->cs);

	/* first 256 bits used as one time authenticator key */
	memset(&k->ds, 0, sizeof(k->ds));
	poly1305(nil, 0, otk, 32, nil, &k->ds);

	/* last 64 bits used to encrypt next-offset */
	return GBIT64(otk+SalsaBsize-8);
}

/*
 * Encrypt one block in place.  On entry buf+Overhead holds the
 * plaintext, n is Overhead plus the number of valid plaintext bytes
 * and o is the plaintext offset of the start of the block.  On
 * return buf holds a fresh nonce, the tag, the encrypted
 * next-offset and the encrypted data.  Returns the number of valid
 * data bytes, or -1 with the error string set.
 */
int
encryptbuf(FileKey *k, uchar *buf, ulong n, vlong o)
{
	uvlong e;

	/* every encryption of a block uses a newly generated nonce */
	genrandom(buf, 24);
	e = cryptsetup(k, buf, o);

	/* from here on n counts valid data bytes and o is the offset just past them */
	n -= Overhead;
	o += n;
	if(o < 0){
		werrstr("offset too large");
		return -1;
	}

	/* random fill block tail */
	if(n < Datasize)
		genrandom(buf+Overhead+n, Datasize-n);

	/* encrypt plaintext */
	salsa_encrypt(buf+Overhead, Datasize, &k->cs);

	e ^= o;	/* encrypt next-block offset */
	PBIT64(buf+40, e);

	/* authenticate ciphertext */
	poly1305(buf+40, 8+Datasize, nil, 0, buf+24, &k->ds);

	return n;
}

/*
 * Verify and decrypt one block in place.  buf holds the block as
 * read from the backing file, n is its size including Overhead and
 * o is the plaintext offset of the start of the block.  On success
 * buf+Overhead holds the plaintext, zero filled past the valid
 * data, and the number of valid data bytes is returned.  Returns -1
 * with the error string set when the tag does not match.
 */
int
decryptbuf(FileKey *k, uchar *buf, ulong n, vlong o)
{
	uchar tag[16];
	uvlong e;

	e = cryptsetup(k, buf, o);

	/* authenticate ciphertext */
	poly1305(buf+40, 8+Datasize, nil, 0, tag, &k->ds);

	/* check the tag */
	if(tsmemcmp(tag, buf+24, 16) != 0){
		werrstr("bad block tag");
		return -1;
	}

	/* decrypt ciphertext */
	salsa_encrypt(buf+Overhead, Datasize, &k->cs);

	/* decrypt next-block offset */
	e ^= GBIT64(buf+40);

	/* sanity check offset: out of range is treated as a block with no valid data */
	n -= Overhead;
	if(e < o || e > o+n)
		e = o;

	/* zero fill remainder */
	n = e - o;
	if(n < Datasize)
		memset(buf+Overhead+n, 0, Datasize - n);

	return n;
}

/*
 * Translate plaintext offset off into the position, within the
 * backing file, of the block that contains it.  When rem is not
 * nil, *rem receives the position of off inside that block's data.
 */
vlong
off2block(uvlong off, int *rem)
{
	uvlong blockno;

	blockno = off / Datasize;
	if(rem != nil)
		*rem = off % Datasize;

	/* blocks follow the file header back to back */
	return blockno * Blocksize + sizeof(FileHdr);
}

/*
 * Read up to n plaintext bytes at offset o of f into data.
 * Requests are clipped to the file length.  Returns the number of
 * bytes copied (0 at or past end of file), or -1 when a block can
 * not be read or fails to decrypt.
 */
int
cryptread(File *f, void *data, int n, vlong o)
{
	FileAux *aux;
	uchar *dst;
	int skip, got;

	if(o >= f->length)
		return 0;
	if(f->length - o < n)
		n = f->length - o;

	aux = f->aux;
	dst = (uchar*)data;
	while(n > 0){
		/* fetch and decrypt the block containing o; skip is o's position in it */
		if(pread(aux->fd, buf, Blocksize, off2block(o, &skip)) != Blocksize)
			return -1;
		got = decryptbuf(&aux->key, buf, Blocksize, o - skip);
		if(got < 0)
			return -1;

		/* valid bytes at and after o in this block */
		got -= skip;
		if(got <= 0)
			break;
		if(n < got)
			got = n;
		memmove(dst, buf+Overhead+skip, got);

		dst += got;
		n -= got;
		o += got;
	}

	return dst - (uchar*)data;
}

/*
 * Write n bytes at plaintext offset o of f.  When data is nil the
 * range is filled with zeros instead (cryptsize uses this to grow
 * a file and to rewrite the last block).  Writes that cross block
 * boundaries are split and handled one block at a time.  Extends
 * f->length when the write ends past it.  Returns the number of
 * bytes taken from data (0 when data is nil), or -1 on error.
 */
int
cryptwrite(File *f, void *data, int n, vlong o)
{
	FileAux *a;
	vlong boff;
	int r, m;
	uchar *p;

	/*
	 * While the request runs past the end of the current block,
	 * recursively write the m bytes that still fit in it.  r is
	 * first the position within the block and then the recursive
	 * call's return value, i.e. how far p has to advance.
	 */
	for(p = (uchar*)data;;p += r, n -= m, o += m){
		boff = off2block(o, &r);
		m = Datasize - r;
		if(n <= m)
			break;
		r = cryptwrite(f, p, m, o);
		if(r < 0)
			return -1;
	}

	/*
	 * What remains fits in the block at boff, starting r bytes
	 * into its data.  A partial write to a block that lies within
	 * the current file length must first load the existing contents.
	 */
	a = f->aux;
	if(n < Datasize && boff < off2block(f->length + Datasize-1, nil)){
		if(pread(a->fd, buf, Blocksize, boff) != Blocksize)
			return -1;
		if(decryptbuf(&a->key, buf, Blocksize, o - r) < 0)
			return -1;
	}

	if(p == nil)
		memset(buf+Overhead+r, 0, n);
	else {
		memmove(buf+Overhead+r, p, n);
		p += n;
	}

	/*
	 * m becomes the count of valid bytes from position r onward:
	 * up to the end of the block or the end of the file, whichever
	 * comes first, but never less than what was just written.
	 */
	if(f->length - o < m){
		m = f->length - o;
		if(m < n)
			m = n;
	}
	if(encryptbuf(&a->key, buf, Overhead+r+m, o - r) < 0)
		return -1;
	if(pwrite(a->fd, buf, Blocksize, boff) != Blocksize)
		return -1;

	o += n;
	if(o > f->length)
		f->length = o;

	return p - (uchar*)data;
}

/*
 * Set the plaintext length of f to size.  Growing appends zeros
 * through cryptwrite in pieces of at most 1<<30 bytes; shrinking
 * truncates the backing file to the block holding the new end and
 * then rewrites that block.  Returns 0 on success, -1 on error.
 */
int
cryptsize(File *f, vlong size)
{
	FileAux *aux;
	vlong gap;
	Dir d;

	if(size == f->length)
		return 0;

	if(size > f->length){
		/* cryptwrite advances f->length, so the gap shrinks each round */
		for(;;){
			gap = size - f->length;
			if(gap <= 1<<30)
				break;
			if(cryptwrite(f, nil, 1<<30, f->length) < 0)
				return -1;
		}
		if(cryptwrite(f, nil, gap, f->length) < 0)
			return -1;
		return 0;
	}

	if(size < 0){
		werrstr("negative file size");
		return -1;
	}

	/* cut the backing file at the block boundary at or after size */
	aux = f->aux;
	nulldir(&d);
	d.length = off2block(size + Datasize-1, nil);
	if(dirfwstat(aux->fd, &d) < 0)
		return -1;

	f->length = size;

	/* rewrite last block */
	if(cryptwrite(f, nil, 0, size) < 0)
		return -1;
	return 0;
}

/*
 * Derive the cipher key for the file described by a->hdr and load
 * it into a->key.cs.  Returns 0 on success, or -1 when the
 * password key could not be obtained from getkey.
 */
int
cryptfilekey(FileAux *a)
{
	uchar filekey[32];
	KeyEntry *master;
	FileHdr *hdr;

	hdr = &a->hdr;

	/* K = scrypt(pass, N, R, P, S); */
	master = getkey(hdr->N, hdr->R, hdr->P, hdr->S);
	if(master == nil)
		return -1;

	/* filekey = hkdf_sha256(T, V, K) */
	hkdf_x(hdr->T, sizeof(hdr->T),
		(uchar*)hdr->V, sizeof(hdr->V),
		master->K, sizeof(master->K),
		filekey, sizeof(filekey),
		hmac_sha2_256, SHA2_256dlen);

	/* xsalsa with 192 bit nonces */
	setupSalsastate(&a->key.cs, filekey, sizeof(filekey), nil, 24, 20);

	/* the derived key now lives only in the cipher state */
	memset(filekey, 0, sizeof(filekey));
	return 0;
}

/*
 * Close the backing file of f, if open, and wipe the key and
 * header kept for it.  Files without aux state are left alone:
 * only files set up in main and cryptcreate get a FileAux, while
 * destroyfid may pass in any file a fid had open.
 */
void
cryptclose(File *f)
{
	FileAux *a;

	a = f->aux;
	if(a == nil)
		return;
	if(a->fd >= 0){
		close(a->fd);
		a->fd = -1;
		memset(&a->key, 0, sizeof(a->key));
		memset(&a->hdr, 0, sizeof(a->hdr));
	}
	a->mode = 0;
}

/*
 * Open the backing file of f for the access requested by mode, read
 * and check its header, derive the file key and work out the
 * plaintext length from the last block.  A descriptor that is
 * already open and sufficient for mode is reused as is.  Returns 0
 * on success; on failure the file is left closed and -1 is returned
 * with the error string set.
 */
int
cryptopen(File *f, int mode)
{
	FileAux *a = f->aux;
	vlong o;
	int n;

	/* every mode other than OREAD becomes ORDWR: cryptwrite reads blocks too */
	mode &= 3;
	if(mode >= OWRITE)
		mode = ORDWR;
	if(a->fd >= 0){
		/*
		 * The descriptor is shared by every fid that has the file
		 * open.  Reopen only to gain write access; replacing an
		 * ORDWR descriptor by an OREAD one would break the fids
		 * that are still writing.
		 */
		if(a->mode == mode || a->mode == ORDWR)
			return 0;
		close(a->fd);
	}
	if((a->fd = open(f->name, mode)) < 0){
		werrstr("open: %r");
		goto Error;
	}
	a->mode = mode;

	/*
	 * o becomes the plaintext offset of the last block: take off
	 * that block and the header, then the overhead of each
	 * remaining block.  What is left must be whole data blocks.
	 */
	o = seek(a->fd, 0, 2) - Blocksize;
	o -= sizeof(FileHdr);
	o -= (o / Blocksize) * Overhead;
	if(o < 0 || o % Datasize){
		werrstr("bad file size");
		goto Error;
	}

	/* read header */
	if(seek(a->fd, 0, 0) != 0
	|| readn(a->fd, &a->hdr, sizeof(a->hdr)) != sizeof(a->hdr)
	|| memcmp(a->hdr.V, defhdr.V, sizeof(a->hdr.V)) != 0){
		werrstr("bad file header");
		goto Error;
	}

	/* decrypt last block to determine plaintext size */
	if(cryptfilekey(a) < 0
	|| pread(a->fd, buf, Blocksize, off2block(o, nil)) != Blocksize
	|| (n = decryptbuf(&a->key, buf, Blocksize, o)) < 0)
		goto Error;

	f->length = o+n;
	return 0;
Error:
	cryptclose(f);
	return -1;
}

/*
 * Create a new encrypted file called name below root: create the
 * backing file, add it to the tree, write a header with a fresh
 * file nonce and an initial empty block.  Directories can not be
 * created.  Returns the new File, or nil on error; a backing file
 * created along the way is removed again.
 */
File*
cryptcreate(File *root, char *name, int mode)
{
	FileAux *a;
	File *f;
	Dir *d;
	int fd;

	if((mode & DMDIR) != 0){
		werrstr("can't create directory");
		return nil;
	}
	/* read/write permission bits are limited to those of the directory */
	mode = (mode & ~0666) | ((mode & root->mode) & 0666);
	if((fd = create(name, OEXCL|ORDWR, mode)) < 0)
		return nil;

	/* d is only needed for createfile; free it whether or not that succeeds */
	d = dirfstat(fd);
	f = nil;
	if(d != nil)
		f = createfile(root, name, d->uid, d->mode, nil);
	free(d);
	if(f == nil){
		close(fd);
		remove(name);
		return nil;
	}

	a = emalloc9p(sizeof(FileAux));
	a->fd = fd;
	a->mode = ORDWR;
	a->hdr = defhdr;
	f->aux = a;
	f->length = 0;
	genrandom(a->hdr.T, sizeof(a->hdr.T));

	/* derive the key, then write the header followed by one empty block */
	if(cryptfilekey(a) < 0
	|| write(fd, &a->hdr, sizeof(a->hdr)) != sizeof(a->hdr)
	|| cryptwrite(f, nil, 0, 0) < 0){
		removefile(f);
		closefile(f);
		remove(name);
		return nil;
	}

	return f;
}

/*
 * Handler shared by all 9P requests this server implements; the
 * request type selects the action.  Requests on a directory other
 * than Tcreate are answered without doing anything.  Tread and
 * Twrite reply with a byte count, the other requests with a plain
 * success, and every failure with the current error string.
 */
void
fsrpc(Req *req)
{
	Fcall *in;
	File *file;
	Dir nd;
	int count, done;

	in = &req->ifcall;
	file = req->fid->file;
	count = -1;
	done = 0;

	if((file->qid.type & QTDIR) != 0 && in->type != Tcreate){
		respond(req, nil);
		return;
	}

	switch(in->type){
	case Tcreate:
		file = cryptcreate(file, in->name, in->perm);
		if(file == nil)
			break;
		req->fid->file = file;
		req->ofcall.qid = file->qid;
		done = 1;
		break;
	case Topen:
		if(cryptopen(file, in->mode) < 0)
			break;
		if((in->mode & OTRUNC) != 0 && cryptsize(file, 0) < 0)
			break;
		done = 1;
		break;
	case Tremove:
		if(remove(file->name) < 0)
			break;
		done = 1;
		break;
	case Twstat:
		/* a length of ~0 means the length is not to be changed */
		if(req->d.length != ~0LL && cryptsize(file, req->d.length) < 0)
			break;
		/* rename the backing file first, then the entry in the tree */
		if(req->d.name[0] != '\0' && strcmp(req->d.name, file->name) != 0){
			nulldir(&nd);
			nd.name = req->d.name;
			if(dirwstat(file->name, &nd) < 0)
				break;
			free(file->name);
			file->name = estrdup9p(nd.name);
		}
		done = 1;
		break;
	case Tread:
		count = cryptread(file, req->ofcall.data, in->count, in->offset);
		break;
	case Twrite:
		/* a write beyond the end first extends the file up to the offset */
		if(in->offset > file->length && cryptsize(file, in->offset) < 0)
			break;
		count = cryptwrite(file, in->data, in->count, in->offset);
		break;
	}

	if(!done){
		if(count < 0){
			responderror(req);
			return;
		}
		req->ofcall.count = count;
	}
	respond(req, nil);
}

/*
 * Called when a fid goes away.  For a fid that had a file open,
 * close the backing file once few enough references remain and
 * honour ORCLOSE by removing the backing file and its tree entry.
 */
void
destroyfid(Fid *fid)
{
	File *file;

	file = fid->file;
	if(file == nil || fid->omode == -1)
		return;

	/* NOTE(review): 2 presumably counts the tree's reference plus this fid's — confirm */
	if(file->ref <= 2)
		cryptclose(file);

	/* the tree entry goes only if the backing file could be removed */
	if((fid->omode & ORCLOSE) != 0 && file->parent != nil
	&& remove(file->name) >= 0)
		removefile(file);
}

/*
 * Destroy callback handed to alloctree: close the backing file and
 * release the FileAux attached to f.
 */
void
destroyfile(File *f)
{
	void *aux;

	cryptclose(f);

	/* detach before freeing so f never points at freed memory */
	aux = f->aux;
	f->aux = nil;
	free(aux);
}

/*
 * 9P service description.  All request handlers share fsrpc, which
 * switches on the request type; fs.tree is filled in by main.
 */
Srv fs = {
	.create = fsrpc,
	.open = fsrpc,
	.remove = fsrpc,
	.wstat = fsrpc,
	.read = fsrpc,
	.write = fsrpc,
	.destroyfid = destroyfid,
};

/* print the command line synopsis on standard error and exit with status "usage" */
void
usage(void)
{
	fprint(2, "usage: %s [-Dv] [-N N] [-R R] [-P P] [-m mtpt] dir\n", argv0);
	exits("usage");
}

/*
 * Parse the options, build the file tree from the directory (or the
 * single file) named on the command line, test-decrypt every file
 * and mount the service on mtpt.  Files that fail to decrypt are
 * reported and left out of the tree.  When no file made it
 * necessary to ask for the password, it is asked for, with
 * confirmation, before mounting.
 */
void
main(int argc, char *argv[])
{
	int fd, nd;
	char *mtpt, *s;
	Dir *d, *dd;

	mtpt = ".";
	fmtinstall('H', encodefmt);

	/* salt for new files, unless an existing file supplies one below */
	genrandom(defhdr.S, sizeof(defhdr.S));

	ARGBEGIN {
	case 'N':
		defhdr.N = atoi(EARGF(usage()));
		break;
	case 'R':
		defhdr.R = atoi(EARGF(usage()));
		break;
	case 'P':
		defhdr.P = atoi(EARGF(usage()));
		break;
	case 'm':
		mtpt = EARGF(usage());
		break;
	case 'v':
		verbose++;
		break;
	case 'D':
		chatty9p++;
		break;
	default:
		usage();
	} ARGEND;

	if(argc != 1)
		usage();

	if((d = dirstat(argv[0])) == nil)
		sysfatal("stat: %r");

	if((d->qid.type & QTDIR) != 0){
		/* a directory: serve every file in it */
		if(chdir(argv[0]) < 0)
			sysfatal("chdir: %r");
		if((fd = open(".", OREAD)) < 0)
			sysfatal("open: %r");
		dd = d;
		d = nil;
		nd = dirreadall(fd, &d);
		close(fd);
	} else {
		/* a single file: serve just it, from within its directory */
		if((s = strrchr(argv[0], '/')) != nil){
			*s = 0;
			/* a path such as "/name" leaves an empty directory part, meaning the root */
			if(chdir(argv[0][0] != 0 ? argv[0] : "/") < 0)
				sysfatal("chdir: %r");
		}
		if((dd = dirstat(".")) == nil)
			sysfatal("stat: %r");
		nd = 1;
	}
	fs.tree = alloctree(dd->uid, dd->gid, dd->mode, destroyfile);
	free(dd);

	for(; nd > 0; nd--, d++){
		FileAux *a;
		File *f;

		if((d->qid.type & QTDIR) != 0)
			continue;

		if((f = createfile(fs.tree->root, d->name, d->uid, d->mode, nil)) == nil)
			continue;

		a = emalloc9p(sizeof(FileAux));
		a->fd = -1;
		a->mode = 0;
		f->aux = a;
		if(cryptopen(f, 0) < 0){
			fprint(2, "%s: can't decrypt %s: %r\n", argv0, f->name);
			removefile(f);
			closefile(f);
			continue;
		}

		/*
		 * try to reuse the salt for the directory; take it
		 * before cryptclose wipes the header in a
		 */
		memmove(defhdr.S, a->hdr.S, sizeof(a->hdr.S));

		cryptclose(f);
		closefile(f);
	}

	if(verbose)
		fprint(2, "%d files decrypted\n", fs.tree->root->nchild);

	/* no file required the password yet: ask now and derive the default key up front */
	if(pass == nil){
		getpass(1);
		getkey(defhdr.N, defhdr.R, defhdr.P, defhdr.S);
	}

	postmountsrv(&fs, nil, mtpt, MBEFORE|MCREATE);
	exits(nil);
}