shithub: drawterm

Download patch

ref: 43866763e399e223a6e9b6f16e0c09bb16be19fa
parent: 15e68cc285cf082696ab68faa16f4662f50306c1
author: cinap_lenrek <cinap_lenrek@felloff.net>
date: Mon Apr 3 21:59:45 EDT 2017

tlshand: sync with 9front

--- a/libsec/tlshand.c
+++ b/libsec/tlshand.c
@@ -17,7 +17,6 @@
 	TLSFinishedLen = 12,
 	SSL3FinishedLen = MD5dlen+SHA1dlen,
 	MaxKeyData = 160,	// amount of secret we may need
-	MaxChunk = 1<<15,
 	MAXdlen = SHA2_512dlen,
 	RandomSize = 32,
 	MasterSecretSize = 48,
@@ -100,13 +99,8 @@
 	HandshakeHash	handhash;
 	Finished	finished;
 
-	// input buffer for handshake messages
-	uchar recvbuf[MaxChunk];
-	uchar *rp, *ep;
-
-	// output buffer
-	uchar sendbuf[MaxChunk];
-	uchar *sendp;
+	uchar *sendp, *recvp, *recvw;
+	uchar buf[1<<16];
 } TlsConnection;
 
 typedef struct Msg{
@@ -443,7 +437,7 @@
 static int get16(uchar *p);
 static Bytes* newbytes(int len);
 static Bytes* makebytes(uchar* buf, int len);
-static Bytes* mptobytes(mpint* big);
+static Bytes* mptobytes(mpint* big, int len);
 static mpint* bytestomp(Bytes* bytes);
 static void freebytes(Bytes* b);
 static Ints* newints(int len);
@@ -695,6 +689,8 @@
 	c->hand = hand;
 	c->trace = trace;
 	c->version = ProtocolVersion;
+	c->sendp = c->buf;
+	c->recvp = c->recvw = &c->buf[sizeof(c->buf)];
 
 	memset(&m, 0, sizeof(m));
 	if(!msgRecv(c, &m)){
@@ -894,6 +890,7 @@
 	DHstate *dh = &sec->dh;
 	mpint *G, *P, *Y, *K;
 	Bytes *Yc;
+	int n;
 
 	if(p == nil || g == nil || Ys == nil)
 		return nil;
@@ -906,7 +903,8 @@
 
 	if(dh_new(dh, P, nil, G) == nil)
 		goto Out;
-	Yc = mptobytes(dh->y);
+	n = (mpsignif(P)+7)/8;
+	Yc = mptobytes(dh->y, n);
 	K = dh_finish(dh, Y);	/* zeros dh */
 	if(K == nil){
 		freebytes(Yc);
@@ -913,7 +911,7 @@
 		Yc = nil;
 		goto Out;
 	}
-	setMasterSecret(sec, mptobytes(K));
+	setMasterSecret(sec, mptobytes(K, n));
 
 Out:
 	mpfree(K);
@@ -933,6 +931,7 @@
 	ECpub *pub;
 	ECpoint K;
 	Bytes *Yc;
+	int n;
 
 	if(Ys == nil)
 		return nil;
@@ -958,8 +957,10 @@
 
 	ecgen(dom, Q);
 	ecmul(dom, pub, Q->d, &K);
-	setMasterSecret(sec, mptobytes(K.x));
-	Yc = newbytes(1 + 2*((mpsignif(dom->p)+7)/8));
+
+	n = (mpsignif(dom->p)+7)/8;
+	setMasterSecret(sec, mptobytes(K.x, n));
+	Yc = newbytes(1 + 2*n);
 	Yc->len = ecencodepub(dom, (ECpub*)Q, Yc->data, Yc->len);
 
 	mpfree(K.x);
@@ -993,6 +994,8 @@
 	c->hand = hand;
 	c->trace = trace;
 	c->cert = nil;
+	c->sendp = c->buf;
+	c->recvp = c->recvw = &c->buf[sizeof(c->buf)];
 
 	c->version = ProtocolVersion;
 	tlsSecInitc(c->sec, c->version);
@@ -1256,14 +1259,13 @@
 static int
 msgSend(TlsConnection *c, Msg *m, int act)
 {
-	uchar *p; // sendp = start of new message;  p = write pointer
-	int nn, n, i;
+	uchar *p, *e; // sendp = start of new message;  p = write pointer; e = end pointer
+	int n, i;
 
-	if(c->sendp == nil)
-		c->sendp = c->sendbuf;
 	p = c->sendp;
+	e = c->recvp;
 	if(c->trace)
-		c->trace("send %s", msgPrint((char*)p, (sizeof(c->sendbuf)) - (p - c->sendbuf), m));
+		c->trace("send %s", msgPrint((char*)p, e - p, m));
 
 	p[0] = m->tag;	// header - fill in size later
 	p += 4;
@@ -1273,119 +1275,111 @@
 		tlsError(c, EInternalError, "can't encode a %d", m->tag);
 		goto Err;
 	case HClientHello:
-		// version
-		put16(p, m->u.clientHello.version);
-		p += 2;
-
-		// random
+		if(p+2+RandomSize > e)
+			goto Overflow;
+		put16(p, m->u.clientHello.version), p += 2;
 		memmove(p, m->u.clientHello.random, RandomSize);
 		p += RandomSize;
 
-		// sid
-		n = m->u.clientHello.sid->len;
-		p[0] = n;
-		memmove(p+1, m->u.clientHello.sid->data, n);
-		p += n+1;
+		if(p+1+(n = m->u.clientHello.sid->len) > e)
+			goto Overflow;
+		*p++ = n;
+		memmove(p, m->u.clientHello.sid->data, n);
+		p += n;
 
-		n = m->u.clientHello.ciphers->len;
-		put16(p, n*2);
-		p += 2;
-		for(i=0; i<n; i++) {
-			put16(p, m->u.clientHello.ciphers->data[i]);
-			p += 2;
-		}
+		if(p+2+2*(n = m->u.clientHello.ciphers->len) > e)	// 2-byte length + 2 bytes per cipher suite
+			goto Overflow;
+		put16(p, n*2), p += 2;
+		for(i=0; i<n; i++)
+			put16(p, m->u.clientHello.ciphers->data[i]), p += 2;
 
-		n = m->u.clientHello.compressors->len;
-		p[0] = n;
-		memmove(p+1, m->u.clientHello.compressors->data, n);
-		p += n+1;
+		if(p+1+(n = m->u.clientHello.compressors->len) > e)
+			goto Overflow;
+		*p++ = n;
+		memmove(p, m->u.clientHello.compressors->data, n);
+		p += n;
 
-		if(m->u.clientHello.extensions == nil)
+		if(m->u.clientHello.extensions == nil
+		|| (n = m->u.clientHello.extensions->len) == 0)
 			break;
-		n = m->u.clientHello.extensions->len;
-		if(n == 0)
-			break;
-		put16(p, n);
-		memmove(p+2, m->u.clientHello.extensions->data, n);
-		p += n+2;
+		if(p+2+n > e)
+			goto Overflow;
+		put16(p, n), p += 2;
+		memmove(p, m->u.clientHello.extensions->data, n);
+		p += n;
 		break;
 	case HServerHello:
-		put16(p, m->u.serverHello.version);
-		p += 2;
-
-		// random
+		if(p+2+RandomSize > e)
+			goto Overflow;
+		put16(p, m->u.serverHello.version), p += 2;
 		memmove(p, m->u.serverHello.random, RandomSize);
 		p += RandomSize;
 
-		// sid
-		n = m->u.serverHello.sid->len;
-		p[0] = n;
-		memmove(p+1, m->u.serverHello.sid->data, n);
-		p += n+1;
+		if(p+1+(n = m->u.serverHello.sid->len) > e)
+			goto Overflow;
+		*p++ = n;
+		memmove(p, m->u.serverHello.sid->data, n);
+		p += n;
 
-		put16(p, m->u.serverHello.cipher);
-		p += 2;
-		p[0] = m->u.serverHello.compressor;
-		p += 1;
+		if(p+2+1 > e)
+			goto Overflow;
+		put16(p, m->u.serverHello.cipher), p += 2;
+		*p++ = m->u.serverHello.compressor;
 
-		if(m->u.serverHello.extensions == nil)
-			break;
-		n = m->u.serverHello.extensions->len;
-		if(n == 0)
+		if(m->u.serverHello.extensions == nil
+		|| (n = m->u.serverHello.extensions->len) == 0)
 			break;
-		put16(p, n);
-		memmove(p+2, m->u.serverHello.extensions->data, n);
-		p += n+2;
+		if(p+2+n > e)
+			goto Overflow;
+		put16(p, n), p += 2;
+		memmove(p, m->u.serverHello.extensions->data, n);
+		p += n;
 		break;
 	case HServerHelloDone:
 		break;
 	case HCertificate:
-		nn = 0;
+		n = 0;
 		for(i = 0; i < m->u.certificate.ncert; i++)
-			nn += 3 + m->u.certificate.certs[i]->len;
-		if(p + 3 + nn - c->sendbuf > sizeof(c->sendbuf)) {
-			tlsError(c, EInternalError, "output buffer too small for certificate");
-			goto Err;
-		}
-		put24(p, nn);
-		p += 3;
+			n += 3 + m->u.certificate.certs[i]->len;
+		if(p+3+n > e)
+			goto Overflow;
+		put24(p, n), p += 3;
 		for(i = 0; i < m->u.certificate.ncert; i++){
-			put24(p, m->u.certificate.certs[i]->len);
-			p += 3;
-			memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
-			p += m->u.certificate.certs[i]->len;
+			n = m->u.certificate.certs[i]->len;
+			put24(p, n), p += 3;
+			memmove(p, m->u.certificate.certs[i]->data, n);
+			p += n;
 		}
 		break;
 	case HCertificateVerify:
-		if(m->u.certificateVerify.sigalg != 0){
-			put16(p, m->u.certificateVerify.sigalg);
-			p += 2;
-		}
-		put16(p, m->u.certificateVerify.signature->len);
-		p += 2;
-		memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
-		p += m->u.certificateVerify.signature->len;
+		if(p+2+2+(n = m->u.certificateVerify.signature->len) > e)
+			goto Overflow;
+		if(m->u.certificateVerify.sigalg != 0)
+			put16(p, m->u.certificateVerify.sigalg), p += 2;
+		put16(p, n), p += 2;
+		memmove(p, m->u.certificateVerify.signature->data, n);
+		p += n;
 		break;
 	case HServerKeyExchange:
 		if(m->u.serverKeyExchange.pskid != nil){
-			n = m->u.serverKeyExchange.pskid->len;
-			put16(p, n);
-			p += 2;
+			if(p+2+(n = m->u.serverKeyExchange.pskid->len) > e)
+				goto Overflow;
+			put16(p, n), p += 2;
 			memmove(p, m->u.serverKeyExchange.pskid->data, n);
 			p += n;
 		}
 		if(m->u.serverKeyExchange.dh_parameters == nil)
 			break;
-		n = m->u.serverKeyExchange.dh_parameters->len;
+		if(p+(n = m->u.serverKeyExchange.dh_parameters->len) > e)
+			goto Overflow;
 		memmove(p, m->u.serverKeyExchange.dh_parameters->data, n);
 		p += n;
 		if(m->u.serverKeyExchange.dh_signature == nil)
 			break;
-		if(c->version >= TLS12Version){
-			put16(p, m->u.serverKeyExchange.sigalg);
-			p += 2;
-		}
-		n = m->u.serverKeyExchange.dh_signature->len;
+		if(p+2+2+(n = m->u.serverKeyExchange.dh_signature->len) > e)
+			goto Overflow;
+		if(c->version >= TLS12Version)
+			put16(p, m->u.serverKeyExchange.sigalg), p += 2;
 		put16(p, n), p += 2;
 		memmove(p, m->u.serverKeyExchange.dh_signature->data, n);
 		p += n;
@@ -1392,15 +1386,16 @@
 		break;
 	case HClientKeyExchange:
 		if(m->u.clientKeyExchange.pskid != nil){
-			n = m->u.clientKeyExchange.pskid->len;
-			put16(p, n);
-			p += 2;
+			if(p+2+(n = m->u.clientKeyExchange.pskid->len) > e)
+				goto Overflow;
+			put16(p, n), p += 2;
 			memmove(p, m->u.clientKeyExchange.pskid->data, n);
 			p += n;
 		}
 		if(m->u.clientKeyExchange.key == nil)
 			break;
-		n = m->u.clientKeyExchange.key->len;
+		if(p+2+(n = m->u.clientKeyExchange.key->len) > e)
+			goto Overflow;
 		if(isECDHE(c->cipher))
 			*p++ = n;
 		else if(isDHE(c->cipher) || c->version != SSL3Version)
@@ -1409,6 +1404,8 @@
 		p += n;
 		break;
 	case HFinished:
+		if(p+m->u.finished.n > e)
+			goto Overflow;
 		memmove(p, m->u.finished.verify, m->u.finished.n);
 		p += m->u.finished.n;
 		break;
@@ -1416,7 +1413,6 @@
 
 	// go back and fill in size
 	n = p - c->sendp;
-	assert(n <= sizeof(c->sendbuf));
 	put24(c->sendp+1, n-4);
 
 	// remember hash of Handshake messages
@@ -1425,8 +1421,8 @@
 
 	c->sendp = p;
 	if(act == AFlush){
-		c->sendp = c->sendbuf;
-		if(write(c->hand, c->sendbuf, p - c->sendbuf) < 0){
+		c->sendp = c->buf;
+		if(write(c->hand, c->buf, p - c->buf) < 0){
 			fprint(2, "write error: %r\n");
 			goto Err;
 		}
@@ -1433,6 +1429,8 @@
 	}
 	msgClear(m);
 	return 1;
+Overflow:
+	tlsError(c, EInternalError, "not enough send buffer for message (%d)", m->tag);
 Err:
 	msgClear(m);
 	return 0;
@@ -1441,25 +1439,28 @@
 static uchar*
 tlsReadN(TlsConnection *c, int n)
 {
-	uchar *p;
-	int nn, nr;
+	uchar *p, *e;
 
-	nn = c->ep - c->rp;
-	if(nn < n){
-		if(c->rp != c->recvbuf){
-			memmove(c->recvbuf, c->rp, nn);
-			c->rp = c->recvbuf;
-			c->ep = &c->recvbuf[nn];
-		}
-		for(; nn < n; nn += nr) {
-			nr = read(c->hand, &c->rp[nn], n - nn);
-			if(nr <= 0)
-				return nil;
-			c->ep += nr;
-		}
+	p = c->recvp;
+	if(n <= c->recvw - p){
+		c->recvp += n;
+		return p;
 	}
-	p = c->rp;
-	c->rp += n;
+	e = &c->buf[sizeof(c->buf)];
+	if(n > sizeof(c->buf) || e - n < c->sendp){	// test n first: e - n must never point before buf
+		tlsError(c, EDecodeError, "handshake message too long %d", n);
+		return nil;
+	}
+	c->recvp = e - n;
+	memmove(c->recvp, p, c->recvw - p);
+	c->recvw -= p - c->recvp;
+	p = c->recvp;
+	c->recvp += n;
+	while(c->recvw < c->recvp){
+		if((n = read(c->hand, c->recvw, e - c->recvw)) <= 0)
+			return nil;
+		c->recvw += n;
+	}
 	return p;
 }
 
@@ -1485,11 +1486,6 @@
 		}
 	}
 
-	if(n > sizeof(c->recvbuf)) {
-		tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->recvbuf));
-		return 0;
-	}
-
 	if(type == HSSL2ClientHello){
 		/* Cope with an SSL3 ClientHello expressed in SSL2 record format.
 			This is sent by some clients that we must interoperate
@@ -1512,10 +1508,8 @@
 		p += 6;
 		n -= 6;
 		if(nsid != 0 	/* no sid's, since shouldn't restart using ssl2 header */
-				|| nrandom < 16 || nn % 3)
+		|| nrandom < 16 || nn % 3 || n - nrandom < nn)
 			goto Err;
-		if(c->trace && (n - nrandom != nn))
-			c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
 		/* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
 		nciph = 0;
 		for(i = 0; i < nn; i += 3)
@@ -1805,15 +1799,11 @@
 		break;
 	}
 
-	if(type != HClientHello && type != HServerHello && n != 0)
+	if(n != 0 && type != HClientHello && type != HServerHello)
 		goto Short;
 Ok:
-	if(c->trace){
-		char *buf;
-		buf = emalloc(8000);
-		c->trace("recv %s", msgPrint(buf, 8000, m));
-		free(buf);
-	}
+	if(c->trace)
+		c->trace("recv %s", msgPrint((char*)c->sendp, c->recvp - c->sendp, m));
 	return 1;
 Short:
 	tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
@@ -2623,8 +2613,9 @@
 	K.y = mpnew(0);
 
 	ecmul(dom, Y, Q->d, &K);
-	setMasterSecret(sec, mptobytes(K.x));
 
+	setMasterSecret(sec, mptobytes(K.x, (mpsignif(dom->p)+7)/8));
+
 	mpfree(K.x);
 	mpfree(K.y);
 
@@ -2857,7 +2848,7 @@
 	y = factotum_rsa_decrypt(sec->rpc, bytestomp(data));
 	if(y == nil)
 		return nil;
-	data = mptobytes(y);
+	data = mptobytes(y, (mpsignif(y)+7)/8);
 	if((data->len = pkcs1unpadbuf(data->data, data->len, sec->rsapub->n, 2)) < 0){
 		freebytes(data);
 		return nil;
@@ -2883,10 +2874,11 @@
 		werrstr("bad digest algorithm");
 		return nil;
 	}
+
 	signedMP = factotum_rsa_decrypt(sec->rpc, pkcs1padbuf(buf, digestlen, sec->rsapub->n, 1));
 	if(signedMP == nil)
 		return nil;
-	signature = mptobytes(signedMP);
+	signature = mptobytes(signedMP, (mpsignif(sec->rsapub->n)+7)/8);
 	mpfree(signedMP);
 	return signature;
 }
@@ -2998,14 +2990,12 @@
  * Convert mpint* to Bytes, putting high order byte first.
  */
 static Bytes*
-mptobytes(mpint* big)
+mptobytes(mpint *big, int len)
 {
 	Bytes* ans;
-	int n;
 
-	n = (mpsignif(big)+7)/8;
-	if(n == 0) n = 1;
-	ans = newbytes(n);
+	if(len == 0) len++;
+	ans = newbytes(len);
 	mptober(big, ans->data, ans->len);
 	return ans;
 }