Skip to content

Commit

Permalink
crypto/tls: implement TLS 1.3 PSK authentication (server side)
Browse files Browse the repository at this point in the history
Added some assertions to testHandshake, but avoided checking the error
of one of the Close() because the one that would lose the race would
write the closeNotify to a connection closed on the other side which is
broken on js/wasm (#28650). Moved that Close() after the chan sync to
ensure it happens second.

Accepting a ticket with client certificates when NoClientCert is
configured is probably not a problem, and we could hide them to avoid
confusing the application, but the current behavior is to skip the
ticket, and I'd rather keep behavior changes to a minimum.

Updates #9671

Change-Id: I93b56e44ddfe3d48c2bef52c83285ba2f46f297a
Reviewed-on: https://go-review.googlesource.com/c/147445
Reviewed-by: Adam Langley <agl@golang.org>
  • Loading branch information
FiloSottile committed Nov 12, 2018
1 parent dc9021e commit 166c58b
Show file tree
Hide file tree
Showing 25 changed files with 1,960 additions and 955 deletions.
15 changes: 15 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,17 @@ const (
RequireAndVerifyClientCert
)

// requiresClientCert returns whether the ClientAuthType requires a client
// certificate to be provided.
func requiresClientCert(c ClientAuthType) bool {
switch c {
case RequireAnyClientCert, RequireAndVerifyClientCert:
return true
default:
return false
}
}

// ClientSessionState contains the state needed by clients to resume TLS
// sessions.
type ClientSessionState struct {
Expand Down Expand Up @@ -599,6 +610,10 @@ func ticketKeyFromBytes(b [32]byte) (key ticketKey) {
return key
}

// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session
// ticket, and the lifetime we set for tickets we send.
const maxSessionTicketLifetime = 7 * 24 * time.Hour

// Clone returns a shallow clone of c. It is safe to clone a Config that is
// being used concurrently by a TLS client or server.
func (c *Config) Clone() *Config {
Expand Down
51 changes: 33 additions & 18 deletions handshake_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,11 +869,14 @@ func TestClientKeyUpdate(t *testing.T) {
runClientTestTLS13(t, test)
}

func TestClientResumption(t *testing.T) {
// TODO(filippo): update to test both TLS 1.3 and 1.2 once PSK are
// supported server-side.
func TestResumption(t *testing.T) {
t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
}

func testResumption(t *testing.T, version uint16) {
serverConfig := &Config{
MaxVersion: version,
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
Certificates: testConfig.Certificates,
}
Expand All @@ -887,6 +890,7 @@ func TestClientResumption(t *testing.T) {
rootCAs.AddCert(issuer)

clientConfig := &Config{
MaxVersion: version,
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
ClientSessionCache: NewLRUClientSessionCache(32),
RootCAs: rootCAs,
Expand Down Expand Up @@ -924,9 +928,12 @@ func TestClientResumption(t *testing.T) {
testResumeState("Handshake", false)
ticket := getTicket()
testResumeState("Resume", true)
if !bytes.Equal(ticket, getTicket()) {
if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 {
t.Fatal("first ticket doesn't match ticket after resumption")
}
if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 {
t.Fatal("ticket didn't change after resumption")
}

key1 := randomKey()
serverConfig.SetSessionTicketKeys([][32]byte{key1})
Expand All @@ -946,16 +953,21 @@ func TestClientResumption(t *testing.T) {
// Reset serverConfig to ensure that calling SetSessionTicketKeys
// before the serverConfig is used works.
serverConfig = &Config{
MaxVersion: version,
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
Certificates: testConfig.Certificates,
}
serverConfig.SetSessionTicketKeys([][32]byte{key2})

testResumeState("FreshConfig", true)

clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
testResumeState("DifferentCipherSuite", false)
testResumeState("DifferentCipherSuiteRecovers", true)
// In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
// hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
if version != VersionTLS13 {
clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
testResumeState("DifferentCipherSuite", false)
testResumeState("DifferentCipherSuiteRecovers", true)
}

deleteTicket()
testResumeState("WithoutSessionTicket", false)
Expand All @@ -966,18 +978,21 @@ func TestClientResumption(t *testing.T) {
serverConfig.ClientAuth = RequireAndVerifyClientCert
clientConfig.Certificates = serverConfig.Certificates
testResumeState("InitialHandshake", false)
testResumeState("WithClientCertificates", true)

// Tickets should be removed from the session cache on TLS handshake failure
farFuture := func() time.Time { return time.Unix(16725225600, 0) }
serverConfig.Time = farFuture
_, _, err = testHandshake(t, clientConfig, serverConfig)
if err == nil {
t.Fatalf("handshake did not fail after client certificate expiry")
if version != VersionTLS13 {
// TODO(filippo): reenable when client authentication is implemented
testResumeState("WithClientCertificates", true)

// Tickets should be removed from the session cache on TLS handshake failure
farFuture := func() time.Time { return time.Unix(16725225600, 0) }
serverConfig.Time = farFuture
_, _, err = testHandshake(t, clientConfig, serverConfig)
if err == nil {
t.Fatalf("handshake did not fail after client certificate expiry")
}
serverConfig.Time = nil
testResumeState("AfterHandshakeFailure", false)
serverConfig.ClientAuth = NoClientCert
}
serverConfig.Time = nil
testResumeState("AfterHandshakeFailure", false)
serverConfig.ClientAuth = NoClientCert

clientConfig.ClientSessionCache = nil
testResumeState("WithoutSessionCache", false)
Expand Down
2 changes: 1 addition & 1 deletion handshake_client_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
return nil
}
lifetime := time.Duration(msg.lifetime) * time.Second
if lifetime > 7*24*time.Hour {
if lifetime > maxSessionTicketLifetime {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: received a session ticket with invalid lifetime")
}
Expand Down
126 changes: 82 additions & 44 deletions handshake_messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
}))
}

// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
func addUint64(b *cryptobyte.Builder, v uint64) {
b.AddUint32(uint32(v >> 32))
b.AddUint32(uint32(v))
}

// readUint64 decodes a big-endian, 64-bit value into out and advances over it.
// It reports whether the read was successful.
func readUint64(s *cryptobyte.String, out *uint64) bool {
var hi, lo uint32
if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
return false
}
*out = uint64(hi)<<32 | uint64(lo)
return true
}

// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
Expand Down Expand Up @@ -1266,89 +1283,110 @@ func (m *certificateMsgTLS13) marshal() []byte {
b.AddUint8(typeCertificate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(0) // certificate_request_context
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for i, cert := range m.certificate.Certificate {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if i > 0 {
// This library only supports OCSP and SCT for leaf certificates.
return
}
if m.ocspStapling {
b.AddUint16(extensionStatusRequest)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(statusTypeOCSP)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.certificate.OCSPStaple)
})
})
}
if m.scts {
b.AddUint16(extensionSCT)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sct := range m.certificate.SignedCertificateTimestamps {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(sct)
})
}
})
})
}
})
}
})

certificate := m.certificate
if !m.ocspStapling {
certificate.OCSPStaple = nil
}
if !m.scts {
certificate.SignedCertificateTimestamps = nil
}
marshalCertificate(b, certificate)
})

m.raw = b.BytesOrPanic()
return m.raw
}

func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for i, cert := range certificate.Certificate {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if i > 0 {
// This library only supports OCSP and SCT for leaf certificates.
return
}
if certificate.OCSPStaple != nil {
b.AddUint16(extensionStatusRequest)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(statusTypeOCSP)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(certificate.OCSPStaple)
})
})
}
if certificate.SignedCertificateTimestamps != nil {
b.AddUint16(extensionSCT)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sct := range certificate.SignedCertificateTimestamps {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(sct)
})
}
})
})
}
})
}
})
}

func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
*m = certificateMsgTLS13{raw: data}
s := cryptobyte.String(data)

var context, certList cryptobyte.String
var context cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
!s.ReadUint24LengthPrefixed(&certList) ||
!unmarshalCertificate(&s, &m.certificate) ||
!s.Empty() {
return false
}

m.scts = m.certificate.SignedCertificateTimestamps != nil
m.ocspStapling = m.certificate.OCSPStaple != nil

return true
}

func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
var certList cryptobyte.String
if !s.ReadUint24LengthPrefixed(&certList) {
return false
}
for !certList.Empty() {
var cert []byte
var extensions cryptobyte.String
if !readUint24LengthPrefixed(&certList, &cert) ||
!certList.ReadUint16LengthPrefixed(&extensions) {
return false
}
m.certificate.Certificate = append(m.certificate.Certificate, cert)
certificate.Certificate = append(certificate.Certificate, cert)
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
if len(m.certificate.Certificate) > 1 {
if len(certificate.Certificate) > 1 {
// This library only supports OCSP and SCT for leaf certificates.
continue
}

switch extension {
case extensionStatusRequest:
m.ocspStapling = true
var statusType uint8
if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
!readUint24LengthPrefixed(&extData, &m.certificate.OCSPStaple) ||
len(m.certificate.OCSPStaple) == 0 {
!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
len(certificate.OCSPStaple) == 0 {
return false
}
case extensionSCT:
m.scts = true
var sctList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
return false
Expand All @@ -1359,8 +1397,8 @@ func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
len(sct) == 0 {
return false
}
m.certificate.SignedCertificateTimestamps = append(
m.certificate.SignedCertificateTimestamps, sct)
certificate.SignedCertificateTimestamps = append(
certificate.SignedCertificateTimestamps, sct)
}
default:
// Ignore unknown extensions.
Expand Down
22 changes: 22 additions & 0 deletions handshake_messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var tests = []interface{}{
&nextProtoMsg{},
&newSessionTicketMsg{},
&sessionState{},
&sessionStateTLS13{},
&encryptedExtensionsMsg{},
&endOfEarlyDataMsg{},
&keyUpdateMsg{},
Expand Down Expand Up @@ -332,6 +333,27 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(s)
}

func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
s := &sessionStateTLS13{}
s.cipherSuite = uint16(rand.Intn(10000))
s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
s.createdAt = uint64(rand.Int63())
for i := 0; i < rand.Intn(2)+1; i++ {
s.certificate.Certificate = append(
s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
}
if rand.Intn(10) > 5 {
s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
}
if rand.Intn(10) > 5 {
for i := 0; i < rand.Intn(2)+1; i++ {
s.certificate.SignedCertificateTimestamps = append(
s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
}
}
return reflect.ValueOf(s)
}

func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &endOfEarlyDataMsg{}
return reflect.ValueOf(m)
Expand Down
14 changes: 9 additions & 5 deletions handshake_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,13 @@ func (hs *serverHandshakeState) checkForResumption() bool {
return false
}

var ok bool
var sessionTicket = append([]uint8{}, hs.clientHello.sessionTicket...)
if hs.sessionState, ok = c.decryptTicket(sessionTicket); !ok {
plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket)
if plaintext == nil {
return false
}
hs.sessionState = &sessionState{usedOldKey: usedOldKey}
ok := hs.sessionState.unmarshal(plaintext)
if !ok {
return false
}

Expand All @@ -352,7 +356,7 @@ func (hs *serverHandshakeState) checkForResumption() bool {
}

sessionHasClientCerts := len(hs.sessionState.certificates) != 0
needClientCerts := c.config.ClientAuth == RequireAnyClientCert || c.config.ClientAuth == RequireAndVerifyClientCert
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
return false
}
Expand Down Expand Up @@ -657,7 +661,7 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
masterSecret: hs.masterSecret,
certificates: hs.certsFromClient,
}
m.ticket, err = c.encryptTicket(&state)
m.ticket, err = c.encryptTicket(state.marshal())
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 166c58b

Please sign in to comment.