diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index 6347ddae85..19c43858c5 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -7,6 +7,9 @@ package quic import ( + "bytes" + "encoding/binary" + "errors" "time" ) @@ -31,6 +34,9 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf) case packetType1RTT: n = c.handle1RTT(now, buf) + case packetTypeVersionNegotiation: + c.handleVersionNegotiation(now, buf) + return default: return } @@ -59,6 +65,11 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa c.abort(now, localTransportError(errProtocolViolation)) return -1 } + if p.version != quicVersion1 { + // The peer has changed versions on us mid-handshake? + c.abort(now, localTransportError(errProtocolViolation)) + return -1 + } if !c.acks[space].shouldProcess(p.num) { return n @@ -117,6 +128,42 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int { return len(buf) } +var errVersionNegotiation = errors.New("server does not support QUIC version 1") + +func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) { + if c.side != clientSide { + return // servers don't handle Version Negotiation packets + } + // "A client MUST discard any Version Negotiation packet if it has + // received and successfully processed any other packet [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2 + if !c.keysInitial.canRead() { + return // discarded Initial keys, connection is already established + } + if c.acks[initialSpace].seen.numRanges() != 0 { + return // processed at least one packet + } + _, srcConnID, versions := parseVersionNegotiation(pkt) + if len(c.connIDState.remote) < 1 || !bytes.Equal(c.connIDState.remote[0].cid, srcConnID) { + return // Source Connection ID doesn't match what we sent + } + for len(versions) >= 4 { + ver := binary.BigEndian.Uint32(versions) + if ver == 1 { + // "A client MUST discard a Version Negotiation packet that lists + // the QUIC version selected by the client." + // https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2 + return + } + versions = versions[4:] + } + // "A client that supports only this version of QUIC MUST + // abandon the current connection attempt if it receives + // a Version Negotiation packet, [with the two exceptions handled above]." + // https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2 + c.abortImmediately(now, errVersionNegotiation) +} + func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) { if len(payload) == 0 { // "An endpoint MUST treat receipt of a packet containing no frames diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 63f65b5579..00b02c2a31 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -64,7 +64,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { pnum := c.loss.nextNumber(initialSpace) p := longPacket{ ptype: packetTypeInitial, - version: 1, + version: quicVersion1, num: pnum, dstConnID: dstConnID, srcConnID: c.connIDState.srcConnID(), @@ -91,7 +91,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { pnum := c.loss.nextNumber(handshakeSpace) p := longPacket{ ptype: packetTypeHandshake, - version: 1, + version: quicVersion1, num: pnum, dstConnID: dstConnID, srcConnID: c.connIDState.srcConnID(), diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index d75b2eb690..fd9e6e42e2 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -409,7 +409,7 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) { keyNumber: tc.sendKeyNumber, keyPhaseBit: tc.sendKeyPhaseBit, frames: frames, - version: 1, + version: quicVersion1, dstConnID: dstConnID, srcConnID: tc.peerConnID, }}, diff --git a/internal/quic/listener.go b/internal/quic/listener.go index 03d8ec65f4..96b1e45934 100644 --- a/internal/quic/listener.go +++ b/internal/quic/listener.go @@ -239,32 +239,15 @@ func (l *Listener) listen() { func (l *Listener) handleDatagram(m *datagram, conns map[string]*Conn) { dstConnID, ok := dstConnIDForDatagram(m.b) if !ok { + m.recycle() return } c := conns[string(dstConnID)] if c == nil { - if getPacketType(m.b) != packetTypeInitial { - // This packet isn't trying to create a new connection. - // It might be associated with some connection we've lost state for. - // TODO: Send a stateless reset when appropriate. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.3 - return - } - var now time.Time - if l.testHooks != nil { - now = l.testHooks.timeNow() - } else { - now = time.Now() - } - var err error - c, err = l.newConn(now, serverSide, dstConnID, m.addr) - if err != nil { - // The accept queue is probably full. - // We could send a CONNECTION_CLOSE to the peer to reject the connection. - // Currently, we just drop the datagram. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5 - return - } + // TODO: Move this branch into a separate goroutine to avoid blocking + // the listener while processing packets. + l.handleUnknownDestinationDatagram(m) + return } // TODO: This can block the listener while waiting for the conn to accept the dgram. @@ -272,6 +255,67 @@ func (l *Listener) handleDatagram(m *datagram, conns map[string]*Conn) { c.sendMsg(m) } +func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { + defer func() { + if m != nil { + m.recycle() + } + }() + if len(m.b) < minimumClientInitialDatagramSize { + return + } + p, ok := parseGenericLongHeaderPacket(m.b) + if !ok { + // Not a long header packet, or not parseable. + // Short header (1-RTT) packets don't contain enough information + // to do anything useful with if we don't recognize the + // connection ID. + return + } + + switch p.version { + case quicVersion1: + case 0: + // Version Negotiation for an unknown connection. + return + default: + // Unknown version. + l.sendVersionNegotiation(p, m.addr) + return + } + if getPacketType(m.b) != packetTypeInitial { + // This packet isn't trying to create a new connection. + // It might be associated with some connection we've lost state for. + // TODO: Send a stateless reset when appropriate. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.3 + return + } + var now time.Time + if l.testHooks != nil { + now = l.testHooks.timeNow() + } else { + now = time.Now() + } + var err error + c, err := l.newConn(now, serverSide, p.dstConnID, m.addr) + if err != nil { + // The accept queue is probably full. + // We could send a CONNECTION_CLOSE to the peer to reject the connection. + // Currently, we just drop the datagram. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5 + return + } + c.sendMsg(m) + m = nil // don't recycle, sendMsg takes ownership +} + +func (l *Listener) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) { + m := newDatagram() + m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1) + l.sendDatagram(m.b, addr) + m.recycle() +} + func (l *Listener) sendDatagram(p []byte, addr netip.AddrPort) error { _, err := l.udpConn.WriteToUDPAddrPort(p, addr) return err diff --git a/internal/quic/packet.go b/internal/quic/packet.go index 8242bd0a97..7d69f96d27 100644 --- a/internal/quic/packet.go +++ b/internal/quic/packet.go @@ -6,7 +6,10 @@ package quic -import "fmt" +import ( + "encoding/binary" + "fmt" +) // packetType is a QUIC packet type. // https://www.rfc-editor.org/rfc/rfc9000.html#section-17 @@ -157,6 +160,33 @@ func dstConnIDForDatagram(pkt []byte) (id []byte, ok bool) { return b[:n], true } +// parseVersionNegotiation parses a Version Negotiation packet. +// The returned versions is a slice of big-endian uint32s. +// It returns (nil, nil, nil) for an invalid packet. +func parseVersionNegotiation(pkt []byte) (dstConnID, srcConnID, versions []byte) { + p, ok := parseGenericLongHeaderPacket(pkt) + if !ok { + return nil, nil, nil + } + if len(p.data)%4 != 0 { + return nil, nil, nil + } + return p.dstConnID, p.srcConnID, p.data +} + +// appendVersionNegotiation appends a Version Negotiation packet to pkt, +// returning the result. +func appendVersionNegotiation(pkt, dstConnID, srcConnID []byte, versions ...uint32) []byte { + pkt = append(pkt, headerFormLong|fixedBit) // header byte + pkt = append(pkt, 0, 0, 0, 0) // Version (0 for Version Negotiation) + pkt = appendUint8Bytes(pkt, dstConnID) // Destination Connection ID + pkt = appendUint8Bytes(pkt, srcConnID) // Source Connection ID + for _, v := range versions { + pkt = binary.BigEndian.AppendUint32(pkt, v) // Supported Version + } + return pkt +} + // A longPacket is a long header packet. type longPacket struct { ptype packetType @@ -177,3 +207,42 @@ type shortPacket struct { num packetNumber payload []byte } + +// A genericLongPacket is a long header packet of an arbitrary QUIC version. +// https://www.rfc-editor.org/rfc/rfc8999#section-5.1 +type genericLongPacket struct { + version uint32 + dstConnID []byte + srcConnID []byte + data []byte +} + +func parseGenericLongHeaderPacket(b []byte) (p genericLongPacket, ok bool) { + if len(b) < 5 || !isLongHeader(b[0]) { + return genericLongPacket{}, false + } + b = b[1:] + // Version (32), + var n int + p.version, n = consumeUint32(b) + if n < 0 { + return genericLongPacket{}, false + } + b = b[n:] + // Destination Connection ID Length (8), + // Destination Connection ID (0..2048), + p.dstConnID, n = consumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 2048/8 { + return genericLongPacket{}, false + } + b = b[n:] + // Source Connection ID Length (8), + // Source Connection ID (0..2048), + p.srcConnID, n = consumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 2048/8 { + return genericLongPacket{}, false + } + b = b[n:] + p.data = b + return p, true +} diff --git a/internal/quic/packet_test.go b/internal/quic/packet_test.go index b13a587e54..58c584e162 100644 --- a/internal/quic/packet_test.go +++ b/internal/quic/packet_test.go @@ -8,7 +8,9 @@ package quic import ( "bytes" + "encoding/binary" "encoding/hex" + "reflect" "strings" "testing" ) @@ -112,6 +114,124 @@ func TestPacketHeader(t *testing.T) { } } +func TestEncodeDecodeVersionNegotiation(t *testing.T) { + dstConnID := []byte("this is a very long destination connection id") + srcConnID := []byte("this is a very long source connection id") + versions := []uint32{1, 0xffffffff} + got := appendVersionNegotiation([]byte{}, dstConnID, srcConnID, versions...) + want := bytes.Join([][]byte{{ + 0b1100_0000, // header byte + 0, 0, 0, 0, // Version + byte(len(dstConnID)), + }, dstConnID, { + byte(len(srcConnID)), + }, srcConnID, { + 0x00, 0x00, 0x00, 0x01, + 0xff, 0xff, 0xff, 0xff, + }}, nil) + if !bytes.Equal(got, want) { + t.Fatalf("appendVersionNegotiation(nil, %x, %x, %v):\ngot %x\nwant %x", + dstConnID, srcConnID, versions, got, want) + } + gotDst, gotSrc, gotVersionBytes := parseVersionNegotiation(got) + if got, want := gotDst, dstConnID; !bytes.Equal(got, want) { + t.Errorf("parseVersionNegotiation: got dstConnID = %x, want %x", got, want) + } + if got, want := gotSrc, srcConnID; !bytes.Equal(got, want) { + t.Errorf("parseVersionNegotiation: got srcConnID = %x, want %x", got, want) + } + var gotVersions []uint32 + for len(gotVersionBytes) >= 4 { + gotVersions = append(gotVersions, binary.BigEndian.Uint32(gotVersionBytes)) + gotVersionBytes = gotVersionBytes[4:] + } + if got, want := gotVersions, versions; !reflect.DeepEqual(got, want) { + t.Errorf("parseVersionNegotiation: got versions = %v, want %v", got, want) + } +} + +func TestParseGenericLongHeaderPacket(t *testing.T) { + for _, test := range []struct { + name string + packet []byte + version uint32 + dstConnID []byte + srcConnID []byte + data []byte + }{{ + name: "long header packet", + packet: unhex(` + 80 01020304 04a1a2a3a4 05b1b2b3b4b5 c1 + `), + version: 0x01020304, + dstConnID: unhex(`a1a2a3a4`), + srcConnID: unhex(`b1b2b3b4b5`), + data: unhex(`c1`), + }, { + name: "zero everything", + packet: unhex(` + 80 00000000 00 00 + `), + version: 0, + dstConnID: []byte{}, + srcConnID: []byte{}, + data: []byte{}, + }} { + t.Run(test.name, func(t *testing.T) { + p, ok := parseGenericLongHeaderPacket(test.packet) + if !ok { + t.Fatalf("parseGenericLongHeaderPacket() = _, false; want true") + } + if got, want := p.version, test.version; got != want { + t.Errorf("version = %v, want %v", got, want) + } + if got, want := p.dstConnID, test.dstConnID; !bytes.Equal(got, want) { + t.Errorf("Destination Connection ID = {%x}, want {%x}", got, want) + } + if got, want := p.srcConnID, test.srcConnID; !bytes.Equal(got, want) { + t.Errorf("Source Connection ID = {%x}, want {%x}", got, want) + } + if got, want := p.data, test.data; !bytes.Equal(got, want) { + t.Errorf("Data = {%x}, want {%x}", got, want) + } + }) + } +} + +func TestParseGenericLongHeaderPacketErrors(t *testing.T) { + for _, test := range []struct { + name string + packet []byte + }{{ + name: "short header packet", + packet: unhex(` + 00 01020304 04a1a2a3a4 05b1b2b3b4b5 c1 + `), + }, { + name: "packet too short", + packet: unhex(` + 80 000000 + `), + }, { + name: "destination id too long", + packet: unhex(` + 80 00000000 02 00 + `), + }, { + name: "source id too long", + packet: unhex(` + 80 00000000 00 01 + `), + }} { + t.Run(test.name, func(t *testing.T) { + _, ok := parseGenericLongHeaderPacket(test.packet) + if ok { + t.Fatalf("parseGenericLongHeaderPacket() = _, true; want false") + } + }) + } +} + func unhex(s string) []byte { b, err := hex.DecodeString(strings.Map(func(c rune) rune { switch c { diff --git a/internal/quic/quic.go b/internal/quic/quic.go index cf4137e810..9de97b6d88 100644 --- a/internal/quic/quic.go +++ b/internal/quic/quic.go @@ -10,6 +10,13 @@ import ( "time" ) +// QUIC versions. +// We only support v1 at this time. +const ( + quicVersion1 = 1 + quicVersion2 = 0x6b3343cf // https://www.rfc-editor.org/rfc/rfc9369 +) + // connIDLen is the length in bytes of connection IDs chosen by this package. // Since 1-RTT packets don't include a connection ID length field, // we use a consistent length for all our IDs. diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go index 4167076889..81d17b8587 100644 --- a/internal/quic/tls_test.go +++ b/internal/quic/tls_test.go @@ -97,7 +97,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { packets: []*testPacket{{ ptype: packetTypeInitial, num: 0, - version: 1, + version: quicVersion1, srcConnID: clientConnIDs[0], dstConnID: transientConnID, frames: []debugFrame{ @@ -110,7 +110,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { packets: []*testPacket{{ ptype: packetTypeInitial, num: 0, - version: 1, + version: quicVersion1, srcConnID: serverConnIDs[0], dstConnID: clientConnIDs[0], frames: []debugFrame{ @@ -122,7 +122,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { }, { ptype: packetTypeHandshake, num: 0, - version: 1, + version: quicVersion1, srcConnID: serverConnIDs[0], dstConnID: clientConnIDs[0], frames: []debugFrame{ @@ -144,7 +144,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, - version: 1, + version: quicVersion1, srcConnID: clientConnIDs[0], dstConnID: serverConnIDs[0], frames: []debugFrame{ @@ -155,7 +155,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { }, { ptype: packetTypeHandshake, num: 0, - version: 1, + version: quicVersion1, srcConnID: clientConnIDs[0], dstConnID: serverConnIDs[0], frames: []debugFrame{ @@ -568,7 +568,7 @@ func TestConnAEADLimitReached(t *testing.T) { ptype: packetType1RTT, num: 1000, frames: []debugFrame{debugFramePing{}}, - version: 1, + version: quicVersion1, dstConnID: dstConnID, srcConnID: tc.peerConnID, }, 0) diff --git a/internal/quic/version_test.go b/internal/quic/version_test.go new file mode 100644 index 0000000000..cfb7ce4be7 --- /dev/null +++ b/internal/quic/version_test.go @@ -0,0 +1,110 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "bytes" + "context" + "crypto/tls" + "testing" +) + +func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { + config := &Config{ + TLSConfig: newTestTLSConfig(serverSide), + } + tl := newTestListener(t, config, nil) + + // Packet of unknown contents for some unrecognized QUIC version. + dstConnID := []byte{1, 2, 3, 4} + srcConnID := []byte{5, 6, 7, 8} + pkt := []byte{ + 0b1000_0000, + 0x00, 0x00, 0x00, 0x0f, + } + pkt = append(pkt, byte(len(dstConnID))) + pkt = append(pkt, dstConnID...) + pkt = append(pkt, byte(len(srcConnID))) + pkt = append(pkt, srcConnID...) + for len(pkt) < minimumClientInitialDatagramSize { + pkt = append(pkt, 0) + } + + tl.write(&datagram{ + b: pkt, + }) + gotPkt := tl.read() + if gotPkt == nil { + t.Fatalf("got no response; want Version Negotiaion") + } + if got := getPacketType(gotPkt); got != packetTypeVersionNegotiation { + t.Fatalf("got packet type %v; want Version Negotiaion", got) + } + gotDst, gotSrc, versions := parseVersionNegotiation(gotPkt) + if got, want := gotDst, srcConnID; !bytes.Equal(got, want) { + t.Errorf("got Destination Connection ID %x, want %x", got, want) + } + if got, want := gotSrc, dstConnID; !bytes.Equal(got, want) { + t.Errorf("got Source Connection ID %x, want %x", got, want) + } + if got, want := versions, []byte{0, 0, 0, 1}; !bytes.Equal(got, want) { + t.Errorf("got Supported Version %x, want %x", got, want) + } +} + +func TestVersionNegotiationClientAborts(t *testing.T) { + tc := newTestConn(t, clientSide) + p := tc.readPacket() // client Initial packet + tc.listener.write(&datagram{ + b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10), + }) + tc.wantIdle("connection does not send a CONNECTION_CLOSE") + if err := tc.conn.waitReady(canceledContext()); err != errVersionNegotiation { + t.Errorf("conn.waitReady() = %v, want errVersionNegotiation", err) + } +} + +func TestVersionNegotiationClientIgnoresAfterProcessingPacket(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + p := tc.readPacket() // client Initial packet + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.listener.write(&datagram{ + b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10), + }) + if err := tc.conn.waitReady(canceledContext()); err != context.Canceled { + t.Errorf("conn.waitReady() = %v, want context.Canceled", err) + } + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrameType("conn ignores Version Negotiation and continues with handshake", + packetTypeHandshake, debugFrameCrypto{}) +} + +func TestVersionNegotiationClientIgnoresMismatchingSourceConnID(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + p := tc.readPacket() // client Initial packet + tc.listener.write(&datagram{ + b: appendVersionNegotiation(nil, p.srcConnID, []byte("mismatch"), 10), + }) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrameType("conn ignores Version Negotiation and continues with handshake", + packetTypeHandshake, debugFrameCrypto{}) +}