/*
 * shithub: ircs
 * ref: 5479f0fb4283ca4a5a58c9b7b50c5b86bde13aa1
 * dir: /main.c/
 */

#include <u.h>
#include <libc.h>
#include <auth.h>
#include <libsec.h>
#include <thread.h>
#include "dat.h"
#include "fns.h"

/* Sizing constants handed to the thread library. */
enum {
	Stacksize = 16 * 1024,	/* stack size used for every proc and thread created here */
	Chanmsgs = 10,		/* buffer depth of the inic and ctlc channels */
};

/*
 * Values stored in the global state variable and sent as
 * messages over the ctlc and inic channels.
 */
enum {
	Init = 0,	/* starting up, registration not yet confirmed */
	Ok = 1,		/* server welcomed us (001) */
	Fail = 2,	/* nick in use (433) or initial connect failed */
	Quit = 3,	/* quit requested on the command pipe */
	Connend = 4,	/* eof on the server connection */
	Connerr = 5,	/* read error on the server connection */
	Pipeend = 6,	/* eof on the command pipe */
	Pipeerr = 7,	/* read error on the command pipe */
	Reconn = 8,	/* reconnect in progress */
	Reconnok = 9,	/* reconnect succeeded */
};

/* Short names for the bookkeeping structures defined below. */
typedef struct Log Log;
typedef struct Ch Ch;
typedef struct User User;

/* One log destination. */
struct Log {
	int	fd;	/* descriptor log lines are written to; logsend skips the log unless fd > 0 */
	int	mday;	/* passed by address to ircfmttime; NOTE(review): presumably the day of month last stamped - confirm */
};

/* A channel, stored in the channels trie under its name. */
struct Ch {
	char	name[64];	/* channel name, truncated to fit */
	Log	log;		/* per-channel log; fd is 0 when there is no log directory */
	List	*users;		/* User entries recorded for this channel */
};

/* A user, stored in the users trie under the nick. */
struct User {
	char	nick[32];	/* current nick, truncated to fit */
	List	*channels;	/* Ch entries recorded for this user */
};

/* NOTE(review): presumably read by the thread library to size the main thread's stack - confirm */
int mainstacksize = Stacksize;
char *logdir;	/* directory of per-channel logs; stays nil when a log file is given with -f */

static char *service = "ircs";	/* names /srv/<service> and /tmp/<service>; set with -s */

static char *post;	/* path of the posted /srv file */
static char *file;	/* path of the main log file */
static char *mynick;	/* our current nick */
static char *user;	/* user name sent in the USER command */
static char *addr;	/* server address as given on the command line */

static int cmdfd = -1;	/* descriptor commands are read from */
static int ircfd = -1;	/* descriptor of the server connection */

static int state = Init;	/* one of the Init..Reconnok values */

static Channel *inic;	/* carries the startup result to threadmain */
static Channel *ctlc;	/* carries control messages to mainproc */

static Ioproc *outio;	/* used for writes to the server */
static Ioproc *logio;	/* used for writes to the logs */

static int debug;		/* -d: print traffic on fd 2 */
static int timestamps = 1;	/* cleared by -T: prefix log lines with a time */
static int usetls;		/* -e or -p: wrap the connection in TLS */
static int passwd;		/* -p: send a PASS obtained from the auth library */
static int rawlog;		/* -r: log messages as received instead of formatted */
static int userdb = 1;		/* cleared by -f and -U: track which users are on which channels */

static Log mainlog;	/* log used when no channel log applies */
static Trie *channels;	/* Ch structures keyed by channel name */
static Trie *users;	/* User structures keyed by nick */

/*
 * Duplicate s with strdup; exits through sysfatal when the copy
 * cannot be allocated. The copy is tagged with the caller's pc.
 */
static char *
estrdup(char *s)
{
	char *copy;

	copy = strdup(s);
	if(copy == nil)
		sysfatal("strdup: not enough mem");
	setmalloctag(copy, getcallerpc(&s));
	return copy;
}

/*
 * Write n bytes of buf to fd through io.
 * Returns 0 when all n bytes were written; otherwise reports
 * the failure with perror and returns -1.
 */
static int
eiowrite(Ioproc *io, int fd, void *buf, long n)
{
	long done;

	done = iowrite(io, fd, buf, n);
	if(done == n)
		return 0;
	perror("iowrite");
	return -1;
}

/*
 * Handed to iocall by iosetname: names the calling thread after
 * the string passed as the only variadic argument.
 * Always returns 0.
 */
static long
_iosetname(va_list *arg)
{
	char *name;

	name = va_arg(*arg, char*);
	/*
	 * name has already been formatted by iosetname and embeds the
	 * -s service name, so it may contain '%'. Pass it as data,
	 * never as the format string.
	 */
	threadsetname("%s", name);
	return 0;
}

/*
 * Format a name, print-style, and have the proc behind io
 * adopt it as its thread name.
 */
static void
iosetname(Ioproc *io, char *name, ...)
{
	char label[Bufsize];
	va_list ap;

	va_start(ap, name);
	vsnprint(label, sizeof(label), name, ap);
	va_end(ap);

	iocall(io, _iosetname, label);
}

/*
 * Allocate a channel record called name. With a log directory,
 * the channel's log file logdir/name is created append-only
 * (fatal on failure); without one the log descriptor is 0,
 * which logsend treats as "no log".
 */
static Ch *
challoc(char *name)
{
	char path[128];
	Ch *ch;

	ch = emalloc(sizeof(*ch));
	snprint(ch->name, sizeof(ch->name), "%s", name);
	if(logdir == nil)
		ch->log.fd = 0;
	else{
		snprint(path, sizeof(path), "%s/%s", logdir, name);
		ch->log.fd = create(path, OWRITE, 0600 | DMAPPEND);
		if(ch->log.fd < 0)
			sysfatal("create: %r");
		ch->log.mday = 0;
	}
	ch->users = nil;
	return ch;
}

/*
 * Look name up in the channels trie, allocating and inserting
 * a new record when it is not there yet.
 */
static Ch *
chget(char *name)
{
	Rune rkey[64];
	Ch *ch;

	runesnprint(rkey, nelem(rkey), "%s", name);
	ch = trieget(channels, rkey);
	if(ch != nil)
		return ch;

	ch = challoc(name);
	trieadd(channels, rkey, ch);
	return ch;
}

/*
 * Remove c from the channels trie and release it: its user
 * list, its log descriptor (when one is open) and the record.
 */
static void
chfree(Ch *c)
{
	Rune rkey[64];
	int logfd;

	runesnprint(rkey, nelem(rkey), "%s", c->name);
	triedel(channels, rkey);
	listfree(&c->users);

	logfd = c->log.fd;
	if(logfd > 0)
		close(logfd);
	free(c);
}

/* Allocate a user record for nick with an empty channel list. */
static User *
useralloc(char *nick)
{
	User *usr;

	usr = emalloc(sizeof(*usr));
	usr->channels = nil;
	snprint(usr->nick, sizeof(usr->nick), "%s", nick);
	return usr;
}

static User *
userget(char *nick)
{
	User *u;
	Rune key[32];

	runesnprint(key, nelem(key), "%s", nick);
	u = trieget(users, key);
	if(u == nil){
		u = useralloc(nick);
		trieadd(users, key, u);
	}
	return u;
}

/*
 * Remove u from the users trie, then release its channel
 * list and the record itself.
 */
static void
userfree(User *u)
{
	Rune rkey[32];

	runesnprint(rkey, nelem(rkey), "%s", u->nick);
	triedel(users, rkey);

	listfree(&u->channels);
	free(u);
}

/*
 * Rename u: drop the trie entry for the old nick, store
 * newnick (truncated to fit) and re-insert u under it.
 */
static void
usernick(User *u, char *newnick)
{
	Rune oldkey[32], newkey[32];

	runesnprint(oldkey, nelem(oldkey), "%s", u->nick);
	triedel(users, oldkey);

	snprint(u->nick, sizeof(u->nick), "%s", newnick);

	runesnprint(newkey, nelem(newkey), "%s", u->nick);
	trieadd(users, newkey, u);
}

/*
 * Send msg to the server followed by CRLF. An empty msg sends
 * nothing. With -d the message is echoed on fd 2 first.
 */
static void
ircsend(char *msg)
{
	if(debug)
		fprint(2, "ircsend: %s\n", msg);
	if(*msg == 0)
		return;
	eiowrite(outio, ircfd, msg, strlen(msg));
	eiowrite(outio, ircfd, "\r\n", 2);
}

/*
 * Append msg plus a newline to log. Nothing is written for an
 * empty msg or a log whose descriptor is not positive. Unless
 * timestamps are off, the line is prefixed with the raw time
 * (raw logging) or with whatever ircfmttime produces.
 */
static void
logsend(char *msg, long time, Log *log)
{
	char line[Bufsize];
	int len;

	if(*msg == 0 || log->fd <= 0)
		return;

	len = 0;
	if(timestamps){
		if(rawlog)
			len = snprint(line, sizeof(line), "%ld ", time);
		else
			len = ircfmttime(line, sizeof(line), time, &log->mday);
	}
	len += snprint(line+len, sizeof(line)-len, "%s\n", msg);
	eiowrite(logio, log->fd, line, len);
}

/*
 * Log msg, stamped with the current time, to the main log and,
 * when a log directory is in use, to every channel log as well.
 */
static void
loginfo(char *msg)
{
	Triewalk walk;
	Rune rkey[64];
	Ch *ch;
	long now;

	now = time(0);
	logsend(msg, now, &mainlog);
	if(logdir == nil)
		return;

	triewalk(channels, &walk, rkey, nelem(rkey));
	for(ch = trienext(&walk); ch != nil; ch = trienext(&walk))
		logsend(msg, now, &ch->log);
}

/*
 * Handle a JOIN from the server. With a log directory the line
 * goes to the channel's log and, when user tracking is on and
 * the message has a prefix, the user and channel are linked.
 * Without one the line goes to the main log and only our own
 * joins are remembered, so that rejoin() can replay them.
 */
static void
joinmsg(Ircmsg *irc, char *msg, long time)
{
	Ch *ch;
	Log *dst;
	User *usr;
	char *chname, *text, fmtbuf[Bufsize];

	chname = irc->par[0];

	text = msg;
	if(!rawlog){
		ircfmtjoin(irc, fmtbuf, sizeof(fmtbuf));
		text = fmtbuf;
	}

	dst = &mainlog;
	if(logdir != nil){
		ch = chget(chname);
		dst = &ch->log;

		if(userdb && irc->pre != nil){
			usr = userget(irc->nick);
			listadd(&ch->users, usr);
			listadd(&usr->channels, ch);
		}
	}else if(irc->pre != nil && strcmp(irc->nick, mynick) == 0){
		/* for rejoin() */
		chget(chname);
	}
	logsend(text, time, dst);
}

/*
 * Detach every user from channel c, freeing users that are left
 * without any channel, then empty c's user list.
 */
static void
freeusers(Ch *c)
{
	List *node;
	User *usr;

	node = c->users;
	while(node != nil){
		usr = node->val;
		listdel(&usr->channels, c);
		if(usr->channels == nil)
			userfree(usr);
		node = node->next;
	}
	listfree(&c->users);
}

/*
 * Send a JOIN for every channel in the channels trie, logging
 * each command to the main log first.
 */
static void
rejoin(void)
{
	Triewalk walk;
	Rune rkey[64];
	Ch *ch;
	long now;
	char cmd[Bufsize];

	now = time(0);
	triewalk(channels, &walk, rkey, nelem(rkey));
	for(ch = trienext(&walk); ch != nil; ch = trienext(&walk)){
		snprint(cmd, sizeof(cmd), "JOIN %s", ch->name);
		logsend(cmd, now, &mainlog);
		ircsend(cmd);
	}
}

/*
 * Handle a PART from the server. The line is logged to the
 * channel's log (or the main log without a log directory).
 * With user tracking the user is unlinked from the channel and
 * freed when it has no channels left. When we are the one
 * leaving, the channel itself is torn down.
 */
static void
partmsg(Ircmsg *irc, char *msg, long time)
{
	Ch *ch;
	User *usr;
	char *chname, *text, fmtbuf[Bufsize];
	int mine;

	chname = irc->par[0];

	text = msg;
	if(!rawlog){
		ircfmtpart(irc, fmtbuf, sizeof(fmtbuf));
		text = fmtbuf;
	}
	mine = irc->pre != nil && strcmp(irc->nick, mynick) == 0;

	if(logdir == nil){
		logsend(text, time, &mainlog);

		/* for rejoin() */
		if(mine){
			ch = chget(chname);
			freeusers(ch);
			chfree(ch);
		}
		return;
	}

	ch = chget(chname);
	if(userdb && irc->pre != nil){
		usr = userget(irc->nick);
		listdel(&ch->users, usr);
		listdel(&usr->channels, ch);
		if(usr->channels == nil)
			userfree(usr);
	}
	logsend(text, time, &ch->log);

	if(mine){
		freeusers(ch);
		chfree(ch);
	}
}

/*
 * Handle a QUIT from the server. With user tracking the line is
 * logged to every channel the user was recorded on (to the main
 * log if there were none), the user is removed from those
 * channels and freed. Otherwise the line goes to the main log.
 */
static void
quitmsg(Ircmsg *irc, char *msg, long time)
{
	User *usr;
	List *node;
	Ch *ch;
	char *text, fmtbuf[Bufsize];

	text = msg;
	if(!rawlog){
		ircfmtquit(irc, fmtbuf, sizeof(fmtbuf));
		text = fmtbuf;
	}

	if(logdir == nil || !userdb || irc->pre == nil){
		logsend(text, time, &mainlog);
		return;
	}

	usr = userget(irc->nick);
	if(usr->channels == nil)
		logsend(text, time, &mainlog);
	else
		for(node = usr->channels; node != nil; node = node->next){
			ch = node->val;
			logsend(text, time, &ch->log);
			listdel(&ch->users, usr);
		}
	userfree(usr);
}

/*
 * Handle a NICK from the server. With user tracking the line is
 * logged to every channel the user was recorded on (to the main
 * log if there were none) and the user is renamed. Otherwise
 * the line goes to the main log. When the change is our own,
 * mynick is replaced with a copy of the new nick.
 */
static void
nickmsg(Ircmsg *irc, char *msg, long time)
{
	User *usr;
	List *node;
	Ch *ch;
	char *newnick, *text, fmtbuf[Bufsize];

	newnick = irc->par[0];

	text = msg;
	if(!rawlog){
		ircfmtnick(irc, fmtbuf, sizeof(fmtbuf));
		text = fmtbuf;
	}

	if(logdir == nil || !userdb || irc->pre == nil)
		logsend(text, time, &mainlog);
	else{
		usr = userget(irc->nick);
		if(usr->channels == nil)
			logsend(text, time, &mainlog);
		else
			for(node = usr->channels; node != nil; node = node->next){
				ch = node->val;
				logsend(text, time, &ch->log);
			}
		usernick(usr, newnick);
	}

	if(irc->pre == nil || strcmp(irc->nick, mynick) != 0)
		return;
	free(mynick);
	mynick = estrdup(newnick);
}

/*
 * Handle a 353 (RPL_NAMREPLY) from the server. The line is
 * logged to the log of the channel named in the third
 * parameter, or to the main log when there is no log directory.
 * With user tracking, every name in the trailing list is linked
 * to that channel.
 */
static void
namreply(Ircmsg *irc, char *msg, long time)
{
	/*
	 * Membership prefixes that may precede a name: op and voice,
	 * plus the widely deployed owner, admin and halfop marks.
	 * None of them can start a nick, so stripping is safe; left
	 * in place they would file the user under a name that no
	 * later QUIT, NICK or PART carries.
	 */
	static char prefixes[] = "~&@%+";
	Ch *c;
	Log *log;
	User *u;
	char *channel, *nick, *m, *a[128], buf[Bufsize];
	int i, n;

	if(rawlog)
		m = msg;
	else {
		ircfmtnumeric(irc, buf, sizeof(buf));
		m = buf;
	}
	if(logdir != nil){
		channel = irc->par[2];
		c = chget(channel);
		log = &c->log;

		if(userdb){
			n = getfields(irc->trail, a, nelem(a), 1, " ");
			for(i = 0; i < n; i++){
				nick = a[i];
				/* the *nick test keeps strchr from matching its own terminator */
				while(*nick != 0 && strchr(prefixes, *nick) != nil)
					nick++;
				if(*nick == 0)
					continue;
				u = userget(nick);
				listadd(&c->users, u);
				listadd(&u->channels, c);
			}
		}
	} else
		log = &mainlog;

	logsend(m, time, log);
}

static void
srvmsg(char *msg, long time)
{
	Ircmsg irc;
	Ircfmt fmt;
	Ch *c;
	Log *log;
	char *target, *m, buf[Bufsize];

	ircparse(&irc, msg);
	if(debug)
		ircprint(&irc);

	fmt = nil;
	target = nil;

	if(strcmp(irc.cmd, "PRIVMSG") == 0){
		fmt = ircfmtpriv;
		target = irc.par[0];
	} else if(strcmp(irc.cmd, "JOIN") == 0){
		joinmsg(&irc, msg, time);
		return;
	} else if(strcmp(irc.cmd, "QUIT") == 0){
		quitmsg(&irc, msg, time);
		return;
	} else if(strcmp(irc.cmd, "PART") == 0){
		partmsg(&irc, msg, time);
		return;
	} else if(strcmp(irc.cmd, "NICK") == 0){
		nickmsg(&irc, msg, time);
		return;
	} else if(strcmp(irc.cmd, "MODE") == 0 ||
		strcmp(irc.cmd, "KICK") == 0 ||
		strcmp(irc.cmd, "TOPIC") == 0){
		target = irc.par[0];
	} else if(strcmp(irc.cmd, "001") == 0){
		/* welcome */
		if(state == Reconnok)
			rejoin();
		sendul(ctlc, Ok);
	} else if(strcmp(irc.cmd, "433") == 0){
		/* nick in use */
		sendul(ctlc, Fail);
	} else if(strcmp(irc.cmd, "332") == 0 ||
		strcmp(irc.cmd, "333") == 0 ||
		strcmp(irc.cmd, "366") == 0){
		/* 332 = RPL_TOPIC */
		/* 333 = RPL_TOPICWHOTIME */
		/* 366 = RPL_ENDOFNAMES */
		target = irc.par[1];
	} else if(strcmp(irc.cmd, "353") == 0){
		/* 353 = RPL_NAMREPLY */
		namreply(&irc, msg, time);
		return;
	}

	if(irc.cmd[0] >= '0' && irc.cmd[0] <= '9')
		fmt = ircfmtnumeric;

	if(logdir != nil && ircischan(target)){
		c = chget(target);
		log = &c->log;
	} else
		log = &mainlog;

	if(rawlog || fmt == nil)
		m = msg;
	else {
		fmt(&irc, buf, sizeof(buf));
		m = buf;
	}		
	logsend(m, time, log);
}

static void
usrmsg(char *msg, long time)
{
	Ircmsg irc;
	Ircfmt fmt;
	Ch *c;
	Log *log;
	char *target, *m, buf[Bufsize];

	ircparse(&irc, msg);
	if(debug)
		ircprint(&irc);

	fmt = nil;
	target = nil;

	if(strcmp(irc.cmd, "PRIVMSG") == 0){
		fmt = ircfmtpriv;
		target = irc.par[0];
	} else if(strcmp(irc.cmd, "QUIT") == 0){
		state = Quit;
	}
	if(logdir != nil && ircischan(target)){
		c = chget(target);
		log = &c->log;
	} else
		log = &mainlog;

	if(rawlog || fmt == nil)
		m = msg;
	else {
		fmt(&irc, buf, sizeof(buf));
		m = buf;
	}
	logsend(m, time, log);
	ircsend(msg);
}

/*
 * Set the modification time of the file open on fd to time,
 * leaving every other attribute alone. Failure is reported
 * with perror and otherwise ignored.
 */
static void
touch(int fd, long time)
{
	Dir st;

	nulldir(&st);
	st.mtime = time;
	if(dirfwstat(fd, &st) >= 0)
		return;
	perror("dirfwstat");
}

/*
 * Thread body: read from the server connection through the
 * Ioproc passed in v, split the stream into lines and handle
 * each one. PING is answered directly; every other line goes
 * to srvmsg. When reading stops, the reason is logged and
 * reported on ctlc (Connerr or Connend), the Ioproc is closed
 * and the thread exits.
 */
static void
ircin(void *v)
{
	Ioproc *io;
	char buf[Bufsize], *e, *p;
	long n, t;

	io = v;
	threadsetname("ircin");
	iosetname(io, "%s net reader", service);

	/* p is the end of the buffered data, e the end of the buffer */
	p = buf;
	e = buf + sizeof(buf);

	/*
	 * NOTE(review): if buf fills up without containing a newline,
	 * e - p is 0 and the next ioread is asked for zero bytes;
	 * presumably that returns 0 and ends up reported as eof - confirm.
	 */
	while((n = ioread(io, ircfd, p, e - p)) > 0){
		t = time(0);
		p += n;
		/* inside this loop e is reused to point at the newline found */
		while((p > buf) && (e = memchr(buf, '\n', p - buf))){
			/* terminate the line, dropping a preceding CR as well */
			if((e > buf) && (e[-1] == '\r'))
				e[-1] = 0;
			*e++ = 0;
			if(strncmp(buf, "PING", 4) == 0){
				/* turn PING into PONG in place and send it back */
				buf[1] = 'O';
				ircsend(buf);
				touch(mainlog.fd, t);
			} else
				srvmsg(buf, t);
			/* move whatever follows the line to the front of buf */
			p -= e - buf;
			if(p > buf)
				memmove(buf, e, p - buf);
		}
		/* restore e to the end of the buffer for the next read */
		e = buf + sizeof(buf);
	}
	if(n < 0){
		perror("ircin: ioread");
		loginfo("ircs: connection error");
		sendul(ctlc, Connerr);
	}
	if(n == 0){
		loginfo("ircs: connection eof");
		sendul(ctlc, Connend);
	}
	closeioproc(io);
	threadexits(nil);
}

/*
 * Thread body: read commands from the command pipe through the
 * Ioproc passed in v. Input starting with "quit" sends Quit on
 * ctlc and stops reading; anything else is stripped of a final
 * newline and carriage return and given to usrmsg. On a read
 * error or eof the reason is logged and reported on ctlc. The
 * Ioproc is closed before the thread exits.
 */
static void
cmdin(void *v)
{
	Ioproc *io;
	char line[Bufsize];
	long len, now;

	io = v;
	threadsetname("cmdin");
	iosetname(io, "%s cmd reader", service);

	for(;;){
		len = ioread(io, cmdfd, line, sizeof(line)-1);
		if(len <= 0)
			break;
		now = time(0);
		line[len] = 0;
		if(debug)
			fprint(2, "cmdin: %s", line);
		if(strncmp(line, "quit", 4) == 0){
			sendul(ctlc, Quit);
			break;
		}
		if(line[len-1] == '\n')
			line[len-1] = 0;
		if(len > 1 && line[len-2] == '\r')
			line[len-2] = 0;
		usrmsg(line, now);
	}

	if(len < 0){
		perror("cmdin: ioread");
		loginfo("ircs: command pipe error");
		sendul(ctlc, Pipeerr);
	}else if(len == 0){
		loginfo("ircs: command pipe eof");
		sendul(ctlc, Pipeend);
	}
	closeioproc(io);
	threadexits(nil);
}

/*
 * Remove the posted /srv file if it exists; a failed removal
 * is reported with perror.
 */
static void
cleanup(void)
{
	if(access(post, AEXIST) != 0)
		return;
	if(remove(post) < 0)
		perror("remove");
}

/*
 * Note handler: report the note on fd 2, remove the posted
 * /srv file and terminate every thread with status "note".
 */
static int
note(void *, char *msg)
{
	int pid;

	pid = getpid();
	fprint(2, "%d: received note: %s\n", pid, msg);
	cleanup();
	threadexitsall("note");
	return 0;
}

/* Print the command synopsis on fd 2 and exit the calling thread with status "usage". */
static void
usage(void)
{
	fprint(2,
		"usage: %s [-deprTU] [-s srvname] [-f file] nick[!user] addr\n",
		argv0);
	threadexits("usage");
}

/*
 * (Re)connect to the IRC server and register.
 *
 * Any previous connection is closed first. With -p a password
 * is obtained through auth_getuserpasswd. The server is dialed
 * (default service 6697 with TLS, 6667 without), TLS is started
 * when requested, and PASS (if any), USER and NICK are sent.
 * The PASS line is logged without the password.
 *
 * Returns 0 on success. On failure returns -1 and leaves ircfd
 * at -1, so that neither a retry nor a concurrent writer can
 * touch a descriptor number that has since been reused for
 * something else (a channel log, for instance).
 */
static int
connect(void)
{
	TLSconn tls;
	UserPasswd *u;
	char buf[Bufsize];
	long t;
	int fd;

	if(ircfd >= 0){
		close(ircfd);
		ircfd = -1;
	}

	if(passwd){
		u = auth_getuserpasswd(auth_getkey,
			"proto=pass service=irc server=%q user=%q", addr, mynick);
		if(u == nil){
			perror("auth_getuserpasswd");
			return -1;
		}
	}else
		u = nil;

	ircfd = dial(netmkaddr(addr, "tcp", usetls ? "6697" : "6667"), nil, nil, nil);
	if(ircfd < 0){
		perror("dial");
		goto err;
	}
	if(usetls){
		memset(&tls, 0, sizeof(tls));
		fd = tlsClient(ircfd, &tls);
		free(tls.cert);
		if(fd < 0){
			perror("tlsClient");
			goto err;
		}
		ircfd = fd;
	}

	t = time(0);
	if(u != nil){
		/* log the bare command, then append the secret for the wire only */
		snprint(buf, sizeof(buf), "PASS");
		logsend(buf, t, &mainlog);
		snprint(buf+4, sizeof(buf)-4, " %s", u->passwd);
		ircsend(buf);
		free(u);
	}
	snprint(buf, sizeof(buf), "USER %s 0 * :<nil>", user);
	logsend(buf, t, &mainlog);
	ircsend(buf);

	snprint(buf, sizeof(buf), "NICK %s", mynick);
	logsend(buf, t, &mainlog);
	ircsend(buf);

	return 0;
err:
	free(u);
	if(ircfd >= 0){
		close(ircfd);
		/* forget the number: the next connect() must not close it again */
		ircfd = -1;
	}
	return -1;
}

/* Run freeusers on every channel in the channels trie. */
static void
freeallusers(void)
{
	Triewalk walk;
	Rune rkey[64];
	Ch *ch;

	triewalk(channels, &walk, rkey, nelem(rkey));
	for(ch = trienext(&walk); ch != nil; ch = trienext(&walk))
		freeusers(ch);
}

/*
 * Proc body: call connect() until it succeeds, sleeping between
 * attempts for 5, 30, 60, 300 and then 900 seconds (the last
 * delay repeats). The attempt count is shown in the thread
 * name. Success is reported as Reconnok on ctlc.
 */
static void
reconnproc(void *)
{
	int attempts, slot;
	long sec[] = { 5, 30, 60, 300, 900 };

	threadsetname("%s reconnect", service);
	slot = 0;
	for(attempts = 1; connect() < 0; attempts++){
		threadsetname("%s reconnect attempts=%d",
			service, attempts);
		sleep(sec[slot] * 1000);
		if(slot < nelem(sec)-1)
			slot++;
	}
	sendul(ctlc, Reconnok);
	threadexits(nil);
}

/*
 * Proc body: the control loop. Sets up the writer Ioprocs,
 * makes the first connection (reporting Fail on inic when that
 * fails), starts the two reader threads and then acts on the
 * messages arriving on ctlc. While still in the Init state each
 * message is also passed on to inic. On leaving the loop the
 * /srv file is removed and all threads are terminated.
 */
static void
mainproc(void *)
{
	ulong ev;
	int running;

	threadnotify(note, 1);
	threadsetname("%s %s %s %s", mynick, addr, post,
		logdir != nil ? logdir : file);

	outio = ioproc();
	logio = ioproc();
	iosetname(outio, "%s net writer", service);
	iosetname(logio, "%s log writer", service);

	if(connect() < 0){
		sendul(inic, Fail);
		threadexits("no");
	}
	threadcreate(cmdin, ioproc(), Stacksize);
	threadcreate(ircin, ioproc(), Stacksize);

	running = 1;
	while(running){
		ev = recvul(ctlc);
		if(state == Init)
			sendul(inic, ev);

		switch(ev){
		case Ok:
			state = Ok;
			break;
		case Reconnok:
			/* connection is back: start a fresh reader */
			state = Reconnok;
			threadcreate(ircin, ioproc(), Stacksize);
			break;
		case Connend:
		case Connerr:
			if(state == Quit){
				running = 0;
				break;
			}
			if(state == Init)
				break;
			state = Reconn;
			if(userdb)
				freeallusers();
			loginfo("ircs: reconnecting");
			proccreate(reconnproc, nil, Stacksize);
			break;
		case Pipeend:
		case Pipeerr:
		case Fail:
		case Quit:
			if(state != Init)
				running = 0;
			break;
		default:
			assert(0);
		}
	}
	cleanup();
	threadexitsall(nil);
}

/*
 * Create a pipe, post one end as /srv/<service> and make the
 * other end descriptor 0, from which commands are then read.
 * Any failure is fatal.
 */
static void
postpipe(void)
{
	int srvfd, pfd[2];

	assert(service != nil && *service != 0);

	post = smprint("/srv/%s", service);
	if(post == nil)
		sysfatal("smprint: %r");

	if((srvfd = create(post, OWRITE, 0600)) < 0)
		sysfatal("create: %r");
	if(pipe(pfd) < 0)
		sysfatal("pipe: %r");
	if(fprint(srvfd, "%d", pfd[1]) < 0)
		sysfatal("can't post: %r");
	close(srvfd);
	close(pfd[1]);

	if(dup(pfd[0], 0) < 0)
		sysfatal("dup: %r");
	close(pfd[0]);
	cmdfd = 0;
}

/*
 * Open the main log and make it descriptor 1. Without -f the
 * log directory /tmp/<service> is used (created if missing) and
 * the main log is the file "log" inside it. Any failure is fatal.
 */
static void
initlog(void)
{
	int fd;

	if(file == nil){
		assert(service != nil && *service != 0);

		if((logdir = smprint("/tmp/%s", service)) == nil)
			sysfatal("smprint: %r");

		if(access(logdir, AEXIST) < 0){
			if((fd = create(logdir, OREAD, DMDIR | 0700)) < 0)
				sysfatal("create: %r");
			close(fd);
		}

		if((file = smprint("%s/log", logdir)) == nil)
			sysfatal("smprint: %r");
	}

	if((fd = create(file, OWRITE, 0600 | DMAPPEND)) < 0)
		sysfatal("create: %r");
	if(dup(fd, 1) < 0)
		sysfatal("dup: %r");
	close(fd);

	mainlog.mday = 0;
	mainlog.fd = 1;
}

static void
setnickuser(char *nickuser)
{
	char *p;

	if(p = strchr(nickuser, '!')){
		*p = 0;
		if(*(++p) != 0)
			user = p;
		else
			user = nickuser;
	} else
		user = nickuser;

	if(*nickuser != 0)
		mynick = nickuser;
	else
		sysfatal("empty nick");
}

/*
 * Entry point: parse the options, set up the log and the posted
 * command pipe, create the channels and tries, start mainproc
 * in its own note group and wait on inic for the outcome of the
 * first connection. On success only this thread exits; on
 * failure everything is shut down with status "no".
 */
void
threadmain(int argc, char *argv[])
{
	ARGBEGIN{
	case 'd':
		debug++;
		break;
	case 'e':
		usetls++;
		break;
	case 'f':
		/* a single log file leaves no per-channel logs to track users for */
		file = EARGF(usage());
		userdb = 0;
		break;
	case 'p':
		/* a password is only sent over TLS */
		passwd++;
		usetls++;
		break;
	case 'r':
		rawlog++;
		break;
	case 's':
		service = EARGF(usage());
		break;
	case 'T':
		timestamps = 0;
		break;
	case 'U':
		userdb = 0;
		break;
	default:
		usage();
	}ARGEND

	if(argc != 2)
		usage();

	setnickuser(argv[0]);
	addr = argv[1];

	fprint(2, "initializing, please wait\n");

	initlog();
	postpipe();
	threadnotify(note, 1);

	inic = chancreate(sizeof(ulong), Chanmsgs);
	ctlc = chancreate(sizeof(ulong), Chanmsgs);
	if(inic == nil || ctlc == nil)
		sysfatal("chancreate");

	channels = triealloc();
	users = triealloc();

	procrfork(mainproc, nil, Stacksize, RFNOTEG);

	if(recvul(inic) != Ok){
		fprint(2, "init failed, exiting\n");
		cleanup();
		threadexitsall("no");
	}
	fprint(2, "init ok\n");
	chanclose(inic);
	threadexits(nil);
}