shithub: tlsclient

ref: 2bbc75bddc6f2a07056ff017108e35f14061041b
dir: /third_party/boringssl/src/ssl/test/runner/hpke/hpke.go/

View raw version
// Copyright (c) 2020, Google Inc.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

// Package hpke implements Hybrid Public Key Encryption (HPKE).
//
// See https://tools.ietf.org/html/draft-irtf-cfrg-hpke-12.
package hpke

import (
	"crypto"
	"crypto/aes"
	"crypto/cipher"
	"encoding/binary"
	"errors"
	"fmt"

	"golang.org/x/crypto/chacha20poly1305"
)

// KEM scheme IDs.
//
// Of these, only X25519WithHKDFSHA256 is passed to keySchedule by the Setup
// functions visible in this file; P256WithHKDFSHA256 is declared but not
// referenced here.
const (
	P256WithHKDFSHA256   uint16 = 0x0010
	X25519WithHKDFSHA256 uint16 = 0x0020
)

// HPKE AEAD IDs.
//
// newAEAD turns one of these IDs plus a key into a cipher.AEAD, and
// expectedKeyLength reports the key size each ID requires.
const (
	AES128GCM        uint16 = 0x0001
	AES256GCM        uint16 = 0x0002
	ChaCha20Poly1305 uint16 = 0x0003
)

// HPKE KDF IDs.
//
// GetHKDFHash maps each of these IDs to the corresponding crypto.Hash.
const (
	HKDFSHA256 uint16 = 0x0001
	HKDFSHA384 uint16 = 0x0002
	HKDFSHA512 uint16 = 0x0003
)

// Internal constants.
//
// These are the mode values accepted by keySchedule, which also writes the
// mode as the first byte of the key schedule context. In hpkeModeBase both psk
// and pskID must be empty; in hpkeModePSK both must be non-empty.
const (
	hpkeModeBase uint8 = 0
	hpkeModePSK  uint8 = 1
)

// GetHKDFHash maps an HPKE KDF ID to the crypto.Hash that the KDF is built
// on. If kdf is not one of the supported KDF IDs (HKDFSHA256, HKDFSHA384 or
// HKDFSHA512), it returns a zero hash and an error.
func GetHKDFHash(kdf uint16) (crypto.Hash, error) {
	var hash crypto.Hash
	switch kdf {
	case HKDFSHA256:
		hash = crypto.SHA256
	case HKDFSHA384:
		hash = crypto.SHA384
	case HKDFSHA512:
		hash = crypto.SHA512
	default:
		return 0, fmt.Errorf("unknown KDF: %d", kdf)
	}
	return hash, nil
}

// GenerateKeyPairFunc is a function that produces a key pair, returning the
// public key, the secret key, and an error. The sender setup functions accept
// one as ephemKeygen and pass it through to x25519Encap.
type GenerateKeyPairFunc func() (public []byte, secret []byte, e error)

// Context holds the HPKE state for a sender or a receiver.
//
// A Context is populated by keySchedule. Seal and Open advance seq, so the
// value is stateful and each call uses a different nonce.
type Context struct {
	// Algorithm identifiers the context was set up with. They are reported by
	// KEM, KDF and AEAD, and Export uses them to rebuild the suite ID.
	kemID  uint16
	kdfID  uint16
	aeadID uint16

	// aead is the cipher that newAEAD built from key; Seal and Open use it.
	aead cipher.AEAD

	// key is the AEAD key derived by keySchedule.
	key            []byte
	// baseNonce is XORed with the big-endian seq to form each nonce.
	baseNonce      []byte
	// seq is incremented after every Seal and every successful Open.
	seq            uint64
	// exporterSecret is the secret that Export expands from.
	exporterSecret []byte
}

// SetupBaseSenderX25519 corresponds to the spec's SetupBaseS(), but only
// supports X25519.
//
// It calls x25519Encap with publicKeyR and ephemKeygen, then runs keySchedule
// in base mode (no PSK) with the given KDF ID, AEAD ID and info. On success it
// returns the new context together with enc as produced by x25519Encap. On any
// failure it returns a nil context, a nil enc and the error, so callers never
// receive an enc for a context that was not created.
func SetupBaseSenderX25519(kdfID, aeadID uint16, publicKeyR, info []byte, ephemKeygen GenerateKeyPairFunc) (context *Context, enc []byte, err error) {
	sharedSecret, enc, err := x25519Encap(publicKeyR, ephemKeygen)
	if err != nil {
		return nil, nil, err
	}
	context, err = keySchedule(hpkeModeBase, X25519WithHKDFSHA256, kdfID, aeadID, sharedSecret, info, nil, nil)
	if err != nil {
		// Mirror the receiver functions: no partial results on error.
		return nil, nil, err
	}
	return context, enc, nil
}

// SetupBaseReceiverX25519 corresponds to the spec's SetupBaseR(), but only
// supports X25519.
//
// It calls x25519Decap with enc and secretKeyR, then feeds the resulting
// shared secret into keySchedule in base mode (no PSK) together with the given
// KDF ID, AEAD ID and info. A decapsulation failure yields a nil context and
// the error; otherwise the result of keySchedule is returned as is.
func SetupBaseReceiverX25519(kdfID, aeadID uint16, enc, secretKeyR, info []byte) (context *Context, err error) {
	var sharedSecret []byte
	if sharedSecret, err = x25519Decap(enc, secretKeyR); err != nil {
		return nil, err
	}
	context, err = keySchedule(hpkeModeBase, X25519WithHKDFSHA256, kdfID, aeadID, sharedSecret, info, nil, nil)
	return context, err
}

// SetupPSKSenderX25519 corresponds to the spec's SetupPSKS(), but only supports
// X25519.
//
// It calls x25519Encap with publicKeyR and ephemKeygen, then runs keySchedule
// in PSK mode with the given KDF ID, AEAD ID, info, psk and pskID. On success
// it returns the new context together with enc as produced by x25519Encap. On
// any failure it returns a nil context, a nil enc and the error, so callers
// never receive an enc for a context that was not created.
func SetupPSKSenderX25519(kdfID, aeadID uint16, publicKeyR, info, psk, pskID []byte, ephemKeygen GenerateKeyPairFunc) (context *Context, enc []byte, err error) {
	sharedSecret, enc, err := x25519Encap(publicKeyR, ephemKeygen)
	if err != nil {
		return nil, nil, err
	}
	context, err = keySchedule(hpkeModePSK, X25519WithHKDFSHA256, kdfID, aeadID, sharedSecret, info, psk, pskID)
	if err != nil {
		// Mirror the receiver functions: no partial results on error.
		return nil, nil, err
	}
	return context, enc, nil
}

// SetupPSKReceiverX25519 corresponds to the spec's SetupPSKR(), but only
// supports X25519.
//
// It calls x25519Decap with enc and secretKeyR, then feeds the resulting
// shared secret into keySchedule in PSK mode together with the given KDF ID,
// AEAD ID, info, psk and pskID. If either step fails, a nil context and the
// error are returned.
func SetupPSKReceiverX25519(kdfID, aeadID uint16, enc, secretKeyR, info, psk, pskID []byte) (context *Context, err error) {
	var sharedSecret []byte
	if sharedSecret, err = x25519Decap(enc, secretKeyR); err != nil {
		return nil, err
	}
	if context, err = keySchedule(hpkeModePSK, X25519WithHKDFSHA256, kdfID, aeadID, sharedSecret, info, psk, pskID); err != nil {
		return nil, err
	}
	return context, nil
}

// KEM returns the KEM scheme ID stored in the context.
func (c *Context) KEM() uint16 {
	return c.kemID
}

// KDF returns the KDF ID stored in the context.
func (c *Context) KDF() uint16 {
	return c.kdfID
}

// AEAD returns the AEAD ID stored in the context.
func (c *Context) AEAD() uint16 {
	return c.aeadID
}

// Overhead returns the value reported by the Overhead method of the context's
// underlying AEAD.
func (c *Context) Overhead() int {
	return c.aead.Overhead()
}

// Seal encrypts plaintext with the context's AEAD, authenticating
// additionalData alongside it, and returns the result in a newly allocated
// slice. The nonce comes from computeNonce, and the sequence number is
// advanced afterwards so the next call uses a different nonce.
func (c *Context) Seal(plaintext, additionalData []byte) []byte {
	nonce := c.computeNonce()
	sealed := c.aead.Seal(nil, nonce, plaintext, additionalData)
	// This nonce is now spent; move on to the next sequence number.
	c.incrementSeq()
	return sealed
}

// Open decrypts and authenticates ciphertext and additionalData with the
// context's AEAD, using the nonce from computeNonce. On success it returns the
// plaintext and advances the sequence number. If the AEAD reports an error,
// Open returns nil and that error and leaves the sequence number unchanged.
func (c *Context) Open(ciphertext, additionalData []byte) ([]byte, error) {
	nonce := c.computeNonce()
	opened, openErr := c.aead.Open(nil, nonce, ciphertext, additionalData)
	if openErr != nil {
		// seq is deliberately not advanced on failure.
		return nil, openErr
	}
	c.incrementSeq()
	return opened, nil
}

// Export derives a secret from the context's exporter secret. It calls
// labeledExpand with the context's KDF hash and suite ID, the label "sec",
// exporterContext as the info input and length as the requested output
// length, and returns the result.
func (c *Context) Export(exporterContext []byte, length int) []byte {
	var (
		suite = buildSuiteID(c.kemID, c.kdfID, c.aeadID)
		hash  = getKDFHash(c.kdfID)
		label = []byte("sec")
	)
	return labeledExpand(hash, c.exporterSecret, suite, label, exporterContext, length)
}

// buildSuiteID returns the 10-byte suite identifier: the ASCII string "HPKE"
// followed by kemID, kdfID and aeadID, each encoded as a big-endian uint16.
func buildSuiteID(kemID, kdfID, aeadID uint16) []byte {
	suite := append(make([]byte, 0, 10), "HPKE"...)
	for _, id := range [...]uint16{kemID, kdfID, aeadID} {
		suite = appendBigEndianUint16(suite, id)
	}
	return suite
}

func newAEAD(aeadID uint16, key []byte) (cipher.AEAD, error) {
	if len(key) != expectedKeyLength(aeadID) {
		return nil, errors.New("wrong key length for specified AEAD")
	}
	switch aeadID {
	case AES128GCM, AES256GCM:
		block, err := aes.NewCipher(key)
		if err != nil {
			return nil, err
		}
		aead, err := cipher.NewGCM(block)
		if err != nil {
			return nil, err
		}
		return aead, nil
	case ChaCha20Poly1305:
		aead, err := chacha20poly1305.New(key)
		if err != nil {
			return nil, err
		}
		return aead, nil
	}
	return nil, errors.New("unsupported AEAD")
}

// keySchedule derives a Context from a KEM shared secret.
//
// mode must be hpkeModeBase or hpkeModePSK. In base mode psk and pskID must
// both be empty; in PSK mode both must be non-empty. Violating either rule, or
// passing any other mode, panics. kemID, kdfID and aeadID select the suite,
// and info is hashed into the key schedule context.
//
// The AEAD key, base nonce and exporter secret are each produced by
// labeledExpand from a secret extracted from sharedSecret and psk. An error is
// returned only if newAEAD rejects the derived key.
func keySchedule(mode uint8, kemID, kdfID, aeadID uint16, sharedSecret, info, psk, pskID []byte) (*Context, error) {
	// Check that the PSK arguments are consistent with the requested mode.
	havePSK, havePSKID := len(psk) > 0, len(pskID) > 0
	switch {
	case mode != hpkeModeBase && mode != hpkeModePSK:
		panic("unknown mode")
	case mode == hpkeModeBase && (havePSK || havePSKID):
		panic("unnecessary psk inputs were provided")
	case mode == hpkeModePSK && !(havePSK && havePSKID):
		panic("missing psk inputs")
	}

	hash := getKDFHash(kdfID)
	suite := buildSuiteID(kemID, kdfID, aeadID)

	// The key schedule context is mode || psk_id_hash || info_hash.
	pskIDDigest := labeledExtract(hash, nil, suite, []byte("psk_id_hash"), pskID)
	infoDigest := labeledExtract(hash, nil, suite, []byte("info_hash"), info)
	schedule := append([]byte{mode}, pskIDDigest...)
	schedule = append(schedule, infoDigest...)

	secret := labeledExtract(hash, sharedSecret, suite, []byte("secret"), psk)
	key := labeledExpand(hash, secret, suite, []byte("key"), schedule, expectedKeyLength(aeadID))

	aead, err := newAEAD(aeadID, key)
	if err != nil {
		return nil, err
	}

	// seq is left at its zero value: the first message uses sequence number 0.
	ctx := &Context{
		kemID:  kemID,
		kdfID:  kdfID,
		aeadID: aeadID,
		aead:   aead,
		key:    key,
	}
	ctx.baseNonce = labeledExpand(hash, secret, suite, []byte("base_nonce"), schedule, aead.NonceSize())
	ctx.exporterSecret = labeledExpand(hash, secret, suite, []byte("exp"), schedule, hash.Size())
	return ctx, nil
}

// computeNonce returns the AEAD nonce for the current sequence number: a new
// slice the same length as c.baseNonce, holding c.baseNonce XORed with c.seq
// written big-endian into its last 8 bytes. Neither c.baseNonce nor c.seq is
// modified. It panics if c.baseNonce is shorter than 8 bytes.
//
// The receiver is a pointer, matching the other Context methods, so the
// struct is not copied on every Seal and Open.
func (c *Context) computeNonce() []byte {
	nonce := make([]byte, len(c.baseNonce))
	// Write the big-endian |c.seq| value into the last 8 bytes of the
	// zero-filled |nonce|; any leading bytes stay zero.
	binary.BigEndian.PutUint64(nonce[len(nonce)-8:], c.seq)
	// XOR the padded |seq| with |c.baseNonce|.
	for i, b := range c.baseNonce {
		nonce[i] ^= b
	}
	return nonce
}

// incrementSeq advances the sequence number by one. It panics with
// "sequence overflow" if the increment wraps the counter back to zero.
func (c *Context) incrementSeq() {
	if c.seq++; c.seq == 0 {
		panic("sequence overflow")
	}
}

func expectedKeyLength(aeadID uint16) int {
	switch aeadID {
	case AES128GCM:
		return 128 / 8
	case AES256GCM:
		return 256 / 8
	case ChaCha20Poly1305:
		return chacha20poly1305.KeySize
	}
	panic("unsupported AEAD")
}

// appendBigEndianUint16 appends v to b as two bytes, most significant byte
// first, and returns the extended slice.
func appendBigEndianUint16(b []byte, v uint16) []byte {
	high := byte(v >> 8)
	low := byte(v & 0xff)
	return append(b, high, low)
}