From 21814e71db756f39b69fb1a3e06350fa555a79b1 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 20 Sep 2023 18:29:51 -0700 Subject: [PATCH] quic: validate connection id transport parameters Validate the original_destination_connection_id and initial_source_connection_id transport parameters. RFC 9000, Section 7.3 For golang/go#58547 Change-Id: I8343fd53c5cc946f15d3410c632b3895205fd597 Reviewed-on: https://go-review.googlesource.com/c/net/+/530036 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn.go | 8 ++++++- internal/quic/conn_id.go | 44 +++++++++++++++++++++++++++++++---- internal/quic/conn_id_test.go | 38 ++++++++++++++++++++++++++++-- internal/quic/conn_test.go | 4 ++++ 4 files changed, 87 insertions(+), 7 deletions(-) diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 60979125da..9db00fe092 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -86,6 +86,7 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip. // non-blocking operation. c.msgc = make(chan any, 1) + var originalDstConnID []byte if c.side == clientSide { if err := c.connIDState.initClient(c); err != nil { return nil, err @@ -95,6 +96,7 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip. if err := c.connIDState.initServer(c, initialConnID); err != nil { return nil, err } + originalDstConnID = initialConnID } // The smallest allowed maximum QUIC datagram size is 1200 bytes. @@ -105,9 +107,10 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip. c.streamsInit() c.lifetimeInit() - // TODO: initial_source_connection_id, retry_source_connection_id + // TODO: retry_source_connection_id if err := c.startTLS(now, initialConnID, transportParameters{ initialSrcConnID: c.connIDState.srcConnID(), + originalDstConnID: originalDstConnID, ackDelayExponent: ackDelayExponent, maxUDPPayloadSize: maxUDPPayloadSize, maxAckDelay: maxAckDelay, @@ -171,6 +174,9 @@ func (c *Conn) discardKeys(now time.Time, space numberSpace) { // receiveTransportParameters applies transport parameters sent by the peer. func (c *Conn) receiveTransportParameters(p transportParameters) error { + if err := c.connIDState.validateTransportParameters(c.side, p); err != nil { + return err + } c.streams.outflow.setMaxData(p.initialMaxData) c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi) c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni) diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go index eb2f3ecc15..045e646ac1 100644 --- a/internal/quic/conn_id.go +++ b/internal/quic/conn_id.go @@ -161,6 +161,39 @@ func (s *connIDState) issueLocalIDs(c *Conn) error { return nil } +// validateTransportParameters verifies the original_destination_connection_id and +// initial_source_connection_id transport parameters match the expected values. +func (s *connIDState) validateTransportParameters(side connSide, p transportParameters) error { + // TODO: Consider returning more detailed errors, for debugging. + switch side { + case clientSide: + // Verify original_destination_connection_id matches + // the transient remote connection ID we chose. + if len(s.remote) == 0 || s.remote[0].seq != -1 { + return localTransportError(errInternal) + } + if !bytes.Equal(s.remote[0].cid, p.originalDstConnID) { + return localTransportError(errTransportParameter) + } + // Remove the transient remote connection ID. + // We have no further need for it. + s.remote = append(s.remote[:0], s.remote[1:]...) + case serverSide: + if p.originalDstConnID != nil { + // Clients do not send original_destination_connection_id. + return localTransportError(errTransportParameter) + } + } + // Verify initial_source_connection_id matches the first remote connection ID. + if len(s.remote) == 0 || s.remote[0].seq != 0 { + return localTransportError(errInternal) + } + if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) { + return localTransportError(errTransportParameter) + } + return nil +} + // handlePacket updates the connection ID state during the handshake // (Initial and Handshake packets). func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) { @@ -170,10 +203,13 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) // We're a client connection processing the first Initial packet // from the server. Replace the transient remote connection ID // with the Source Connection ID from the packet. - s.remote[0] = connID{ + // Leave the transient ID the list for now, since we'll need it when + // processing the transport parameters. + s.remote[0].retired = true + s.remote = append(s.remote, connID{ seq: 0, cid: cloneBytes(srcConnID), - } + }) } case ptype == packetTypeInitial && c.side == serverSide: if len(s.remote) == 0 { @@ -185,7 +221,7 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) }) } case ptype == packetTypeHandshake && c.side == serverSide: - if len(s.local) > 0 && s.local[0].seq == -1 { + if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired { // We're a server connection processing the first Handshake packet from // the client. Discard the transient, client-chosen connection ID used // for Initial packets; the client will never send it again. @@ -213,7 +249,7 @@ func (s *connIDState) handleNewConnID(seq, retire int64, cid []byte, resetToken active := 0 for i := range s.remote { rcid := &s.remote[i] - if !rcid.retired && rcid.seq < s.retireRemotePriorTo { + if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo { s.retireRemote(rcid) } if !rcid.retired { diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go index c5289583d3..44755ecf45 100644 --- a/internal/quic/conn_id_test.go +++ b/internal/quic/conn_id_test.go @@ -48,6 +48,9 @@ func TestConnIDClientHandshake(t *testing.T) { t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal)) } wantRemote := []connID{{ + cid: testLocalConnID(-1), + seq: -1, + }, { cid: testPeerConnID(0), seq: 0, }} @@ -261,10 +264,12 @@ func TestConnIDPeerRetiresConnID(t *testing.T) { } func TestConnIDPeerWithZeroLengthConnIDSendsNewConnectionID(t *testing.T) { - // An endpoint that selects a zero-length connection ID during the handshake + // "An endpoint that selects a zero-length connection ID during the handshake // cannot issue a new connection ID." // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-8 - tc := newTestConn(t, clientSide) + tc := newTestConn(t, clientSide, func(p *transportParameters) { + p.initialSrcConnID = []byte{} + }) tc.peerConnID = []byte{} tc.ignoreFrame(frameTypeAck) tc.uncheckedHandshake() @@ -536,6 +541,7 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) { // Peer gives us more conn ids than our advertised limit, // including a conn id in the preferred address transport parameter. tc := newTestConn(t, serverSide, func(p *transportParameters) { + p.initialSrcConnID = []byte{} p.preferredAddrV4 = netip.MustParseAddrPort("0.0.0.0:0") p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0") p.preferredAddrConnID = testPeerConnID(1) @@ -552,3 +558,31 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) { code: errProtocolViolation, }) } + +func TestConnIDInitialSrcConnIDMismatch(t *testing.T) { + // "Endpoints MUST validate that received [initial_source_connection_id] + // parameters match received connection ID values." + // https://www.rfc-editor.org/rfc/rfc9000#section-7.3-3 + testSides(t, "", func(t *testing.T, side connSide) { + tc := newTestConn(t, side, func(p *transportParameters) { + p.initialSrcConnID = []byte("invalid") + }) + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypeCrypto) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + if side == clientSide { + // Server transport parameters are carried in the Handshake packet. + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + } + tc.wantFrame("initial_source_connection_id transport parameter mismatch", + packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errTransportParameter, + }) + }) +} diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index fd9e6e42e2..6a359e89a1 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -201,6 +201,10 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { TLSConfig: newTestTLSConfig(side), } peerProvidedParams := defaultTransportParameters() + peerProvidedParams.initialSrcConnID = testPeerConnID(0) + if side == clientSide { + peerProvidedParams.originalDstConnID = testLocalConnID(-1) + } for _, o := range opts { switch o := o.(type) { case func(*Config):