shithub: riscv

Download patch

ref: 78eff200d85e0fc2f92622d221c0e3d81aaf9522
parent: 466cf20d3524b8e42edc333a6d2df2a01e99a95b
author: cinap_lenrek <cinap_lenrek@felloff.net>
date: Mon Oct 30 17:43:00 EDT 2023

ndb/dns: implement EDNS(0) extension (rfc6891)

To handle TCP fallback properly, the server must
avoid sending UDP responses that are too big for
the client to accept.

We used to accept UDP requests of up to 8K
(and responses of up to 8K when resolving),
and our UDP replies could be just as large.

Instead, we now advertise a UDP payload size
of 1232 bytes (assuming a 1280-byte MTU) and
honor even smaller sizes advertised by clients
(though never less than 512). Clients that do
not use EDNS get replies of at most 512 bytes.

This ensures that oversized responses are truncated,
signaling the client that it must retry with TCP.

Note that we still accept up to 8K of UDP data
regardless of the advertised size (for lucky clients).

--- a/sys/src/cmd/ip/snoopy/dns.c
+++ b/sys/src/cmd/ip/snoopy/dns.c
@@ -92,7 +92,13 @@
 	if(rr == nil)
 		return;
 	*rrp = rr->next;
-
+	if(rr->type == Topt){
+		m->p = seprint(m->p, m->e, "opt eflags=%#lux udpsize=%d data=%.*H",
+			rr->eflags, rr->udpsize,
+			rr->opt->dlen, rr->opt->data);
+		rrfree(rr);
+		return;
+	}
 	m->p = seprint(m->p, m->e, "%s name=%s ttl=%lud",
 		rrtypestr(rr->type),
 		rr->owner->name, rr->ttl);
@@ -469,6 +475,10 @@
 		rp->null = emalloc(sizeof(*rp->null));
 		setmalloctag(rp->null, rp->pc);
 		break;
+	case Topt:
+		rp->opt = emalloc(sizeof(*rp->opt));
+		setmalloctag(rp->opt, rp->pc);
+		break;
 	default:
 		if(rrsupported(rp->type))
 			break;
@@ -542,6 +552,11 @@
 			memset(t, 0, sizeof *t);	/* cause trouble */
 			free(t);
 		}
+		break;
+	case Topt:
+		free(rp->opt->data);
+		memset(rp->opt, 0, sizeof *rp->opt);	/* cause trouble */
+		free(rp->opt);
 		break;
 	default:
 		if(rrsupported(rp->type))
--- a/sys/src/cmd/ndb/convDNS2M.c
+++ b/sys/src/cmd/ndb/convDNS2M.c
@@ -198,14 +198,17 @@
 
 	NAME(rp->owner->name);
 	USHORT(rp->type);
-	USHORT(rp->owner->class);
-
-	if(rp->db || (ttl = (long)(rp->expire - now)) > rp->ttl)
-		ttl = rp->ttl;
-	if(ttl < 0)
-		ttl = 0;
-	ULONG(ttl);
-
+	if(rp->type == Topt) {
+		USHORT(rp->udpsize);
+		ULONG(rp->eflags);
+	} else {
+		if(rp->db || (ttl = (long)(rp->expire - now)) > rp->ttl)
+			ttl = rp->ttl;
+		if(ttl < 0)
+			ttl = 0;
+		USHORT(rp->owner->class);
+		ULONG(ttl);
+	}
 	lp = p;			/* leave room for the rdata length */
 	p += 2;
 	data = p;
@@ -301,6 +304,13 @@
 		SYMBOL(rp->caa->tag->name);
 		BYTES(rp->caa->data, rp->caa->dlen);
 		break;
+	case Topt:
+		BYTES(rp->opt->data, rp->opt->dlen);
+		break;
+	default:
+		if(rrsupported(rp->type))
+			break;
+		BYTES(rp->unknown->data, rp->unknown->dlen);
 	}
 
 	/* stuff in the rdata section length */
@@ -361,7 +371,17 @@
 	p = rrloop(m->qd, &m->qdcount, p, ep, &d, 1);
 	p = rrloop(m->an, &m->ancount, p, ep, &d, 0);
 	p = rrloop(m->ns, &m->nscount, p, ep, &d, 0);
+	if(m->edns) {
+		assert(m->edns->next == nil);
+		m->edns->next = m->ar;
+		m->ar = m->edns;
+	}
 	p = rrloop(m->ar, &m->arcount, p, ep, &d, 0);
+	if(m->edns) {
+		assert(m->edns == m->ar);
+		m->ar = m->edns->next;
+		m->edns->next = nil;
+	}
 	if(p > ep) {
 		trunc = Ftrunc;
 		dnslog("udp packet full; truncating my reply");
--- a/sys/src/cmd/ndb/convM2DNS.c
+++ b/sys/src/cmd/ndb/convM2DNS.c
@@ -338,10 +338,14 @@
 
 	type = mstypehack(sp, type, "convM2RR");
 	rp = rralloc(type);
-	rp->owner = dnlookup(dname, class, 1);
-	rp->type = type;
-
-	ULONG(rp->ttl);
+	if(type == Topt) {
+		rp->owner = dnlookup(dname, Cin, 1);
+		rp->udpsize = class;
+		ULONG(rp->eflags);
+	} else {
+		rp->owner = dnlookup(dname, class, 1);
+		ULONG(rp->ttl);
+	}
 	USHORT(len);			/* length of data following */
 	data = sp->p;
 	assert(data != nil);
@@ -465,6 +469,9 @@
 		SYMBOL(rp->caa->tag);
 		BYTES(rp->caa->data, rp->caa->dlen);
 		break;
+	case Topt:
+		BYTES(rp->opt->data, rp->opt->dlen);
+		break;
 	default:
 		if(rrsupported(type)){
 			sp->p = data + len;
@@ -592,6 +599,7 @@
 	if (sp->err)
 		err = strdup(sp->err);		/* live with bad ar's */
 	m->ar = rrloop(sp, "hints",	m->arcount, 0);
+	m->edns = nil;
 	if (sp->trunc)
 		m->flags |= Ftrunc;
 	if (sp->stop)
--- a/sys/src/cmd/ndb/dblookup.c
+++ b/sys/src/cmd/ndb/dblookup.c
@@ -112,7 +112,7 @@
 
 	/* so far only internet lookups are implemented */
 	if(class != Cin)
-		return 0;
+		return nil;
 
 	err = Rname;
 	rp = nil;
@@ -131,7 +131,7 @@
 	if(opendatabase() < 0)
 		goto out;
 	if(dp->rr)
-		err = 0;
+		err = Rok;
 
 	/* first try the given name */
 	if(cfg.cachedb)
@@ -146,7 +146,7 @@
 		snprint(buf, sizeof buf, "*%s", wild);
 		ndp = idnlookup(buf, class, 1);
 		if(ndp->rr)
-			err = 0;
+			err = Rok;
 		if(cfg.cachedb)
 			rp = rrlookup(ndp, type, NOneg);
 		else
--- a/sys/src/cmd/ndb/dn.c
+++ b/sys/src/cmd/ndb/dn.c
@@ -718,9 +718,10 @@
 	for(; rp; rp = next){
 		next = rp->next;
 		rp->next = nil;
-		/* avoid any outside spoofing */
-		if(cfg.cachedb && !rp->db && inmyarea(rp->owner->name)
-		|| !rrsupported(rp->type))
+		if(rp->type == Tall
+		|| rp->type == Topt
+		|| !rrsupported(rp->type)
+		|| cfg.cachedb && !rp->db && inmyarea(rp->owner->name))
 			rrfree(rp);
 		else
 			rrattach1(rp, auth);
@@ -1284,6 +1285,10 @@
 				rp->caa->flags, dnname(rp->caa->tag),
 				rp->caa->dlen, rp->caa->data);
 		break;
+	case Topt:
+		fmtprint(&fstr, "\t%#lux %d %.*H", rp->eflags, rp->udpsize,
+			rp->opt->dlen, rp->opt->data);
+		break;
 	default:
 		if(rrsupported(rp->type))
 			break;
@@ -1910,6 +1915,10 @@
 		rp->null = emalloc(sizeof(*rp->null));
 		setmalloctag(rp->null, rp->pc);
 		break;
+	case Topt:
+		rp->opt = emalloc(sizeof(*rp->opt));
+		setmalloctag(rp->opt, rp->pc);
+		break;
 	default:
 		if(rrsupported(type))
 			break;
@@ -1975,6 +1984,11 @@
 			memset(t, 0, sizeof *t);	/* cause trouble */
 			free(t);
 		}
+		break;
+	case Topt:
+		free(rp->opt->data);
+		memset(rp->opt, 0, sizeof *rp->opt);	/* cause trouble */
+		free(rp->opt);
 		break;
 	default:
 		if(rrsupported(rp->type))
--- a/sys/src/cmd/ndb/dnresolve.c
+++ b/sys/src/cmd/ndb/dnresolve.c
@@ -499,6 +499,41 @@
 		mp->qdcount = 1;
 }
 
+RR*
+getednsopt(DNSmsg *mp)	/* detach EDNS OPT record(s) from mp->ar; caller frees the list */
+{
+	RR *rp, *tp;
+	rp = rrremtype(&mp->ar, Topt);
+	if(rp == nil)
+		return nil;
+	for(tp = rp; tp != nil; tp = tp->next)
+		mp->arcount--;	/* one per removed record, not just the first */
+	if(rp->udpsize < 512)	/* never assume a payload size below 512 */
+		rp->udpsize = 512;
+	return rp;
+}
+
+RR*
+mkednsopt(void)	/* new EDNS OPT record for outgoing messages; caller frees it */
+{
+	RR *rp;
+
+	rp = rralloc(Topt);
+	rp->owner = dnlookup("", Cin, 1);	/* OPT records are owned by the empty (root) name */
+	rp->eflags = 0;	/* extended rcode 0, version 0 (Evers), DO bit (Ednssecok) clear */
+
+	/*
+	 * Advertise a safe UDP response size
+	 * instead of Maxudp as that is just
+	 * the worst case we can accept.
+	 *
+	 * 1232 = MTU(1280)-IPv6(40)-UDP(8).
+	 */
+	rp->udpsize = 1232;
+
+	return rp;
+}
+
 /* generate a DNS UDP query packet, return size of request (without Udphdr) */
 int
 mkreq(DN *dp, int type, uchar *pkt, int flags, ushort id)
@@ -516,7 +551,9 @@
 	rp = rralloc(type);
 	rp->owner = dp;
 	initdnsmsg(&m, rp, flags, id);
+	m.edns = mkednsopt();
 	len = convDNS2M(&m, &pkt[Udphdrsize], Maxudp);
+	rrfreelist(m.edns);
 	rrfreelist(rp);
 	return len;
 }
@@ -925,13 +962,22 @@
 	Query nq;
 	DN *ndp;
 	RR *tp, *soarr;
-	int rv;
+	int rv, rcode;
 
 	if(mp->an == nil)
 		stats.negans++;
 
+	/* get the rcode */
+	rcode = mp->flags & Rmask;
+
+	/* get extended rcode from edns */
+	if((tp = getednsopt(mp)) != nil){
+		rcode = (rcode & 15) | (tp->eflags & Ercode) >> 20;
+		rrfreelist(tp);
+	}
+
 	/* ignore any error replies */
-	switch(mp->flags & Rmask){
+	switch(rcode){
 	case Rrefused:
 	case Rserver:
 		stats.negserver++;
@@ -1023,7 +1069,7 @@
 		 *  they can legitimately come from a cache.
 		 */
 		if( /* (mp->flags & Fauth) && */ mp->an == nil)
-			cacheneg(qp->dp, qp->type, (mp->flags & Rmask), soarr);
+			cacheneg(qp->dp, qp->type, rcode, soarr);
 		else
 			rrfreelist(soarr);
 		return 1;
@@ -1034,7 +1080,7 @@
 		 *  negative responses need not be authoritative:
 		 *  they can legitimately come from a cache.
 		 */
-		cacheneg(qp->dp, qp->type, (mp->flags & Rmask), soarr);
+		cacheneg(qp->dp, qp->type, rcode, soarr);
 		return 1;
 	}
 	stats.negnorname++;
@@ -1203,11 +1249,11 @@
 			/* exponential backoff of requests */
 			if((1UL<<p->nx) > ndest)
 				continue;
-			if(writenet(qp, Udp, fd, pkt, len, p) == 0)
-				n++;
 			p->nx++;
+			if(writenet(qp, Udp, fd, pkt, len, p) < 0)
+				continue;
+			n++;
 		}
-
 		/* nothing left to send to */
 		if (n == 0)
 			break;
--- a/sys/src/cmd/ndb/dns.h
+++ b/sys/src/cmd/ndb/dns.h
@@ -115,6 +115,11 @@
 	Frecurse=	1<<8,	/* request recursion */
 	Fcanrec=	1<<7,	/* server can recurse */
 
+	/* EDNS flags (eflags) */
+	Ercode=		0xff<<24,
+	Evers=		0xff<<16,
+	Ednssecok=	1<<15,
+
 	Domlen=		256,	/* max domain name length (with NULL) */
 	Labellen=	64,	/* max domain label length (with NULL) */
 	Strlen=		256,	/* max string length (with NULL) */
@@ -163,6 +168,7 @@
 typedef struct Txt	Txt;
 typedef struct Caa	Caa;
 typedef struct Unknown	Unknown;
+typedef struct Opt	Opt;
 
 /*
  *  a structure to track a request and any slave process handling it
@@ -237,6 +243,10 @@
 {
 	Block;
 };
+struct Opt
+{
+	Block;	/* raw, unparsed rdata (EDNS options) of an OPT record */
+};
 
 /*
  *  text strings
@@ -272,6 +282,7 @@
 		DN	*mb;	/* mailbox - mg, minfo */
 		DN	*ip;	/* ip address - a, aaaa */
 		DN	*rp;	/* rp arg - rp */
+		ulong	eflags;	/* EDNS(0) flags - opt */
 		uintptr	arg0;	/* arg[01] are compared to find dups in dn.c */
 	};
 	union {			/* discriminated by negative & type */
@@ -282,6 +293,7 @@
 		ulong	pref;	/* preference value - mx */
 		ulong	local;	/* ns served from local database - ns */
 		ushort	port;	/* - srv */
+		ushort	udpsize;/* requester's UDP payload size - opt */
 		uintptr	arg1;	/* arg[01] are compared to find dups in dn.c */
 	};
 	union {			/* discriminated by type */
@@ -294,6 +306,7 @@
 		Null	*null;
 		Txt	*txt;
 		Unknown	*unknown;
+		Opt	*opt;
 	};
 };
 
@@ -330,13 +343,6 @@
 	ushort	weight;
 };
 
-typedef struct Rrlist Rrlist;
-struct Rrlist
-{
-	int	count;
-	RR	*rrs;
-};
-
 /*
  *  domain messages
  */
@@ -352,6 +358,7 @@
 	RR	*ns;
 	int	arcount;	/* hints */
 	RR	*ar;
+	RR	*edns;		/* edns option */
 };
 
 /*
@@ -503,7 +510,9 @@
 /* dnresolve.c */
 RR*	dnresolve(char*, int, int, Request*, RR**, int, int, int, int*);
 int	udpport(char *);
-int	mkreq(DN *dp, int type, uchar *pkt, int flags, ushort reqno);
+int	mkreq(DN*, int type, uchar *pkt, int flags, ushort);
+RR*	mkednsopt(void);
+RR*	getednsopt(DNSmsg*);
 
 /* dnserver.c */
 void	dnserver(DNSmsg*, DNSmsg*, Request*, uchar *, int);
--- a/sys/src/cmd/ndb/dnserver.c
+++ b/sys/src/cmd/ndb/dnserver.c
@@ -6,6 +6,18 @@
 static RR*	doextquery(DNSmsg*, Request*, int);
 static void	hint(RR**, RR*);
 
+/* merge rcode and flags into repp->flags; with edns, rcode bits above the low 4 go to eflags */
+static void
+setflags(DNSmsg *repp, int rcode, int flags)
+{
+	if(repp->edns != nil){
+		repp->edns->eflags = (rcode >> 4) << 24;
+		rcode &= 15;
+	}
+	repp->flags |= rcode & Rmask;
+	repp->flags |= flags & ~Rmask;
+}
+
 /*
  *  answer a dns request
  */
@@ -20,7 +32,6 @@
 	RR *tp, *neg, *rp;
 
 	recursionflag = cfg.nonrecursive? 0: Fcanrec;
-	memset(repp, 0, sizeof(*repp));
 	repp->id = reqp->id;
 	repp->flags = Fresp | recursionflag | Oquery;
 
@@ -37,14 +48,14 @@
 		dnslog("%d: server: response code 0%o (%s), req from %I",
 			req->id, rcode, errmsg, srcip);
 		/* provide feedback to clients who send us trash */
-		repp->flags = (rcode&Rmask) | Fresp | Fcanrec | Oquery;
+		setflags(repp, rcode, Fresp | Fcanrec | Oquery);
 		return;
 	}
-	if(!rrsupported(repp->qd->type)){
+	if(repp->qd->type == Topt || !rrsupported(repp->qd->type)){
 		if(debug)
 			dnslog("%d: server: unsupported request %s from %I",
 				req->id, rrname(repp->qd->type, tname, sizeof tname), srcip);
-		repp->flags = Runimplimented | Fresp | Fcanrec | Oquery;
+		setflags(repp, Runimplimented, Fresp | Fcanrec | Oquery);
 		return;
 	}
 
@@ -52,7 +63,7 @@
 		if(debug)
 			dnslog("%d: server: unsupported class %d from %I",
 				req->id, repp->qd->owner->class, srcip);
-		repp->flags = Runimplimented | Fresp | Fcanrec | Oquery;
+		setflags(repp, Runimplimented, Fresp | Fcanrec | Oquery);
 		return;
 	}
 
@@ -63,13 +74,13 @@
 				dnslog("%d: server: unsupported xfr request %s for %s from %I",
 					req->id, rrname(repp->qd->type, tname, sizeof tname),
 					repp->qd->owner->name, srcip);
-			repp->flags = Runimplimented | Fresp | recursionflag | Oquery;
+			setflags(repp, Runimplimented, Fresp | recursionflag | Oquery);
 			return;
 		}
 	}
 	if(myarea == nil && cfg.nonrecursive) {
 		/* we don't recurse and we're not authoritative */
-		repp->flags = Rok | Fresp | Oquery;
+		setflags(repp, Rok, Fresp | Oquery);
 		neg = nil;
 	} else {
 		/*
@@ -89,7 +100,7 @@
 			dp = dnlookup(repp->qd->owner->name, repp->qd->owner->class, 0);
 			if(dp->rr == nil)
 				if(reqp->flags & Frecurse)
-					repp->flags |= dp->respcode | Fauth;
+					setflags(repp, dp->respcode, Fauth);
 		}
 	}
 
@@ -145,7 +156,7 @@
 				tp = rrlookup(neg->negsoaowner, Tsoa, NOneg);
 				rrcat(&repp->ns, tp);
 			}
-			repp->flags |= neg->negrcode;
+			setflags(repp, neg->negrcode, repp->flags);
 		}
 	}
 
--- a/sys/src/cmd/ndb/dntcpserver.c
+++ b/sys/src/cmd/ndb/dntcpserver.c
@@ -23,6 +23,7 @@
 	volatile uchar pkt[Maxpkt], callip[IPaddrlen];
 	volatile DNSmsg reqmsg, repmsg;
 	volatile Request req;
+	RR *volatile edns;	/* the variable itself is volatile (not the RR), like err */
 	char *volatile err;
 
 	/*
@@ -54,10 +55,11 @@
 	/* loop on requests */
 	for(;; putactivity(&req)){
 		memset(&reqmsg, 0, sizeof reqmsg);
+		edns = nil;
 
 		ms = (long)(req.aborttime - nowms);
 		if(ms < Minreqtm){
-		noreq:
+		hangup:
 			close(fd);
 			_exits(0);
 		}
@@ -64,12 +66,12 @@
 		alarm(ms);
 		if(readn(fd, pkt, 2) != 2){
 			alarm(0);
-			goto noreq;
+			goto hangup;
 		}
 		len = pkt[0]<<8 | pkt[1];
 		if(len <= 0 || len > Maxtcp || readn(fd, pkt+2, len) != len){
 			alarm(0);
-			goto noreq;
+			goto hangup;
 		}
 		alarm(0);
 
@@ -111,10 +113,17 @@
 		logrequest(req.id, 0, "rcvd", callip, caller,
 			reqmsg.qd->owner->name, reqmsg.qd->type);
 
+		if((reqmsg.edns = getednsopt(&reqmsg)) != nil){
+			if(reqmsg.edns->eflags & Evers)
+				rcode = Rbadvers;
+			edns = mkednsopt();
+		}
+
 		/* loop through each question */
 		while(reqmsg.qd){
 			memset(&repmsg, 0, sizeof(repmsg));
-			if(reqmsg.qd->type == Taxfr)
+			repmsg.edns = edns;
+			if(rcode == Rok && reqmsg.qd->type == Taxfr)
 				rv = dnzone(fd, pkt, &reqmsg, &repmsg, &req, callip);
 			else {
 				dnserver(&reqmsg, &repmsg, &req, callip, rcode);
@@ -124,10 +133,14 @@
 			if(rv < 0)
 				goto out;
 		}
+		rrfreelist(edns);
+		rrfreelist(reqmsg.edns);
 		freeanswers(&reqmsg);
 	}
 out:
 	close(fd);
+	rrfreelist(edns);
+	rrfreelist(reqmsg.edns);
 	freeanswers(&reqmsg);
 	putactivity(&req);
 	_exits(0);
--- a/sys/src/cmd/ndb/dnudpserver.c
+++ b/sys/src/cmd/ndb/dnudpserver.c
@@ -4,7 +4,7 @@
 #include "dns.h"
 
 static int	udpannounce(char*, char*);
-static void	reply(int, uchar*, DNSmsg*, Request*);
+static void	reply(int, uchar*, int, DNSmsg*, Request*);
 
 typedef struct Inprogress Inprogress;
 struct Inprogress
@@ -65,6 +65,7 @@
 	volatile uchar pkt[Udphdrsize + Maxudp];
 	volatile DNSmsg reqmsg, repmsg;
 	Inprogress *volatile p;
+	RR *volatile edns;	/* the variable itself is volatile (not the RR), like p and uh */
 	volatile Request req;
 	Udphdr *volatile uh;
 
@@ -98,6 +99,8 @@
 	/* loop on requests */
 	for(;; putactivity(&req)){
 		memset(&reqmsg, 0, sizeof reqmsg);
+		edns = nil;
+
 		procsetname("%s: udp server %s: served %d", mntpt, addr, served);
 
 		len = read(fd, pkt, sizeof pkt);
@@ -156,24 +159,35 @@
 		logrequest(req.id, 0, "rcvd", uh->raddr, caller,
 			reqmsg.qd->owner->name, reqmsg.qd->type);
 
+		/* determine response size */
+		len = 512;	/* default */
+		if((reqmsg.edns = getednsopt(&reqmsg)) != nil){
+			if(reqmsg.edns->eflags & Evers)
+				rcode = Rbadvers;
+			edns = mkednsopt();
+			len = Maxudp;
+			if(edns->udpsize < len)
+				len = edns->udpsize;
+			if(reqmsg.edns->udpsize < len)
+				len = reqmsg.edns->udpsize;
+		}
+
 		/* loop through each question */
 		while(reqmsg.qd){
 			memset(&repmsg, 0, sizeof repmsg);
-			switch(op){
-			case Oquery:
-				dnserver(&reqmsg, &repmsg, &req, uh->raddr, rcode);
-				break;
-			case Onotify:
+			repmsg.edns = edns;
+			if(rcode == Rok && op == Onotify)
 				dnnotify(&reqmsg, &repmsg, &req);
-				break;
-			}
-			/* send reply on fd to address in pkt's udp hdr */
-			reply(fd, pkt, &repmsg, &req);
+			else
+				dnserver(&reqmsg, &repmsg, &req, uh->raddr, rcode);
+			reply(fd, pkt, len, &repmsg, &req);
 			freeanswers(&repmsg);
 		}
+		rrfreelist(edns);
 
 		p->inuse = 0;
 freereq:
+		rrfreelist(reqmsg.edns);
 		freeanswers(&reqmsg);
 		if(req.isslave){
 			putactivity(&req);
@@ -183,13 +197,11 @@
 }
 
 static void
-reply(int fd, uchar *pkt, DNSmsg *rep, Request *req)
+reply(int fd, uchar *pkt, int len, DNSmsg *rep, Request *req)
 {
-	int len;
-
 	logreply(req->id, "send", pkt, rep);
 
-	len = convDNS2M(rep, &pkt[Udphdrsize], Maxudp);
+	len = convDNS2M(rep, &pkt[Udphdrsize], len);
 	len += Udphdrsize;
 	if(write(fd, pkt, len) != len)
 		dnslog("%d: error sending reply to %I: %r",