Skip to content

Commit

Permalink
Verify protocol version
Browse files Browse the repository at this point in the history
Check version in handshake message.
Silently discard records with unsupported version in
record layer header.
  • Loading branch information
at-wat committed Mar 9, 2020
1 parent 5b11374 commit f1595d5
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 2 deletions.
7 changes: 5 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,9 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake) ([][]byte, error)
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1

recordLayerHeader := &recordLayerHeader{
protocolVersion: p.record.recordLayerHeader.protocolVersion,
contentType: p.record.recordLayerHeader.contentType,
contentLen: uint16(len(handshakeFragment)),
protocolVersion: p.record.recordLayerHeader.protocolVersion,
epoch: p.record.recordLayerHeader.epoch,
sequenceNumber: seq,
}
Expand Down Expand Up @@ -644,7 +644,10 @@ func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, err
// TODO: avoid separate unmarshal
h := &recordLayerHeader{}
if err := h.Unmarshal(buf); err != nil {
return false, &alert{alertLevelFatal, alertDecodeError}, err
// Decode error must be silently discarded
// [RFC6347 Section-4.1.2.7]
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
}

// Validate epoch
Expand Down
270 changes: 270 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"io"
"net"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -1308,3 +1309,272 @@ func TestServerTimeout(t *testing.T) {
default:
}
}

func TestProtocolVersionValidation(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

cookie := make([]byte, 20)
if _, err := rand.Read(cookie); err != nil {
t.Fatal(err)
}

var rand [28]byte
random := handshakeRandom{time.Unix(500, 0), rand}

localKeypair, err := generateKeypair(namedCurveX25519)
if err != nil {
t.Fatal(err)
}

config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
FlightInterval: 100 * time.Millisecond,
}

t.Run("Server", func(t *testing.T) {
serverCases := map[string]struct {
records []*recordLayer
}{
"ClientHelloVersion": {
records: []*recordLayer{
{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageClientHello{
version: protocolVersion{0xfe, 0xff}, // try to downgrade
cookie: cookie,
random: random,
cipherSuites: []cipherSuite{&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{}},
compressionMethods: defaultCompressionMethods,
}},
},
},
},
"SecondsClientHelloVersion": {
records: []*recordLayer{
{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageClientHello{
version: protocolVersion1_2,
cookie: cookie,
random: random,
cipherSuites: []cipherSuite{&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{}},
compressionMethods: defaultCompressionMethods,
}},
},
{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
sequenceNumber: 1,
},
content: &handshake{
handshakeHeader: handshakeHeader{
messageSequence: 1,
},
handshakeMessage: &handshakeMessageClientHello{
version: protocolVersion{0xfe, 0xff}, // try to downgrade
cookie: cookie,
random: random,
cipherSuites: []cipherSuite{&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{}},
compressionMethods: defaultCompressionMethods,
}},
},
},
},
}
for name, c := range serverCases {
c := c
t.Run(name, func(t *testing.T) {
ca, cb := dpipe.Pipe()
defer func() {
err := ca.Close()
if err != nil {
t.Error(err)
}
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
go func() {
defer wg.Done()
if _, err := testServer(ctx, cb, config, true); err != errUnsupportedProtocolVersion {
t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
}
}()

time.Sleep(50 * time.Millisecond)

resp := make([]byte, 1024)
for _, record := range c.records {
packet, err := record.Marshal()
if err != nil {
t.Fatal(err)
}
if _, werr := ca.Write(packet); werr != nil {
t.Fatal(werr)
}
n, rerr := ca.Read(resp[:cap(resp)])
if rerr != nil {
t.Fatal(rerr)
}
resp = resp[:n]
}

h := &recordLayerHeader{}
if err := h.Unmarshal(resp); err != nil {
t.Fatal("Failed to unmarshal response")
}
if h.contentType != contentTypeAlert {
t.Errorf("Peer must return alert to unsupported protocol version")
}
})
}
})

t.Run("Client", func(t *testing.T) {
clientCases := map[string]struct {
records []*recordLayer
}{
"ServerHelloVersion": {
records: []*recordLayer{
{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
},
content: &handshake{
handshakeMessage: &handshakeMessageHelloVerifyRequest{
version: protocolVersion1_2,
cookie: cookie,
}},
},
{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
sequenceNumber: 1,
},
content: &handshake{
handshakeHeader: handshakeHeader{
messageSequence: 1,
},
handshakeMessage: &handshakeMessageServerHello{
version: protocolVersion{0xfe, 0xff}, // try to downgrade
random: random,
cipherSuite: &cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{},
compressionMethod: defaultCompressionMethods[0],
},
}},
{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
sequenceNumber: 2,
},
content: &handshake{
handshakeHeader: handshakeHeader{
messageSequence: 2,
},
handshakeMessage: &handshakeMessageCertificate{},
}},
{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
sequenceNumber: 3,
},
content: &handshake{
handshakeHeader: handshakeHeader{
messageSequence: 3,
},
handshakeMessage: &handshakeMessageServerKeyExchange{
ellipticCurveType: ellipticCurveTypeNamedCurve,
namedCurve: namedCurveX25519,
publicKey: localKeypair.publicKey,
hashAlgorithm: hashAlgorithmSHA256,
signatureAlgorithm: signatureAlgorithmECDSA,
signature: make([]byte, 64),
},
}},
{
recordLayerHeader: recordLayerHeader{
protocolVersion: protocolVersion1_2,
sequenceNumber: 4,
},
content: &handshake{
handshakeHeader: handshakeHeader{
messageSequence: 4,
},
handshakeMessage: &handshakeMessageServerHelloDone{},
}},
},
},
}
for name, c := range clientCases {
c := c
t.Run(name, func(t *testing.T) {
ca, cb := dpipe.Pipe()
defer func() {
err := ca.Close()
if err != nil {
t.Error(err)
}
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
go func() {
defer wg.Done()
if _, err := testClient(ctx, cb, config, true); err != errUnsupportedProtocolVersion {
t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
}
}()

time.Sleep(50 * time.Millisecond)

for _, record := range c.records {
if _, err := ca.Read(make([]byte, 1024)); err != nil {
t.Fatal(err)
}

packet, err := record.Marshal()
if err != nil {
t.Fatal(err)
}
if _, err := ca.Write(packet); err != nil {
t.Fatal(err)
}
}
resp := make([]byte, 1024)
n, err := ca.Read(resp)
if err != nil {
t.Fatal(err)
}
resp = resp[:n]

h := &recordLayerHeader{}
if err := h.Unmarshal(resp); err != nil {
t.Fatal("Failed to unmarshal response")
}
if h.contentType != contentTypeAlert {
t.Errorf("Peer must return alert to unsupported protocol version")
}
})
}
})
}
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ var (
errNoCertificates = &ErrFatal{errors.New("no certificates configured")}
errNoConfigProvided = &ErrFatal{errors.New("no config provided")}
errNoSupportedEllipticCurves = &ErrFatal{errors.New("client requested zero or more elliptic curves that are not supported by the server")}
errUnsupportedProtocolVersion = &ErrFatal{errors.New("unsupported protocol version")}
errPSKAndCertificate = &ErrFatal{errors.New("Certificate and PSK provided")} // nolint:stylecheck
errPSKAndIdentityMustBeSetForClient = &ErrFatal{errors.New("PSK and PSK Identity Hint must both be set for client")}
errRequestedButNoSRTPExtension = &ErrFatal{errors.New("SRTP support was requested but server did not respond with use_srtp extension")}
Expand Down
4 changes: 4 additions & 0 deletions flight0handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handsh
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}

if !clientHello.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}

state.remoteRandom = clientHello.random

if _, ok := findMatchingCipherSuite(clientHello.cipherSuites, cfg.localCipherSuites); !ok {
Expand Down
5 changes: 5 additions & 0 deletions flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handsh
state.handshakeRecvSequence = seq

if h, ok := msgs[handshakeTypeHelloVerifyRequest].(*handshakeMessageHelloVerifyRequest); ok {
// DTLS 1.2 clients must not assume that the server will use the protocol version
// specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
if !h.version.Equal(protocolVersion1_0) && !h.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}
state.cookie = append([]byte{}, h.cookie...)
return flight3, nil, nil
}
Expand Down
4 changes: 4 additions & 0 deletions flight2handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handsh
return 0, &alert{alertLevelFatal, alertInternalError}, nil
}

if !clientHello.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}

if len(clientHello.cookie) == 0 {
return 0, nil, nil
}
Expand Down
3 changes: 3 additions & 0 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh
state.handshakeRecvSequence = seq

if h, ok := msgs[handshakeTypeServerHello].(*handshakeMessageServerHello); ok {
if !h.version.Equal(protocolVersion1_2) {
return 0, &alert{alertLevelFatal, alertProtocolVersion}, errUnsupportedProtocolVersion
}
for _, extension := range h.extensions {
switch e := extension.(type) {
case *extensionUseSRTP:
Expand Down
12 changes: 12 additions & 0 deletions record_layer_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,26 @@ const (
dtls1_2Major = 0xfe
dtls1_2Minor = 0xfd

dtls1_0Major = 0xfe
dtls1_0Minor = 0xff

// VersionDTLS12 is the DTLS version in the same style as
// VersionTLSXX from crypto/tls
VersionDTLS12 = 0xfefd
)

var protocolVersion1_0 = protocolVersion{dtls1_0Major, dtls1_0Minor}
var protocolVersion1_2 = protocolVersion{dtls1_2Major, dtls1_2Minor}

// https://tools.ietf.org/html/rfc4346#section-6.2.1
type protocolVersion struct {
major, minor uint8
}

func (v protocolVersion) Equal(x protocolVersion) bool {
return v.major == x.major && v.minor == x.minor
}

func (r *recordLayerHeader) Marshal() ([]byte, error) {
if r.sequenceNumber > maxSequenceNumber {
return nil, errSequenceNumberOverflow
Expand Down Expand Up @@ -58,5 +66,9 @@ func (r *recordLayerHeader) Unmarshal(data []byte) error {
copy(seqCopy[2:], data[5:11])
r.sequenceNumber = binary.BigEndian.Uint64(seqCopy)

if !r.protocolVersion.Equal(protocolVersion1_0) && !r.protocolVersion.Equal(protocolVersion1_2) {
return errUnsupportedProtocolVersion
}

return nil
}

0 comments on commit f1595d5

Please sign in to comment.