Skip to content

Commit

Permalink
quic: compute pnum len from max ack received, not sent
Browse files Browse the repository at this point in the history
QUIC packet numbers are truncated to include only the least
significant bits of the packet number. The number of bits
which must be retained is computed based on the largest
packet number known to have been received by the peer.
See RFC 9000, section 17.1.

We were incorrectly using the largest packet number
we have received *from* the peer. Oops.

(Test infrastructure change: Include the header byte
in the testPacket structure, so we can see how many
bytes the packet number was encoded with. Ignore this
byte when comparing packets.)

Change-Id: Iec17c69f007f8b39d14d24b0ca216c6a0018ae22
Reviewed-on: https://go-review.googlesource.com/c/net/+/545575
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
  • Loading branch information
neild committed Dec 18, 2023
1 parent b952594 commit b0eb4d6
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 9 deletions.
6 changes: 3 additions & 3 deletions internal/quic/conn_send.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {
pad := false
var sentInitial *sentPacket
if c.keysInitial.canWrite() {
pnumMaxAcked := c.acks[initialSpace].largestSeen()
pnumMaxAcked := c.loss.spaces[initialSpace].maxAcked
pnum := c.loss.nextNumber(initialSpace)
p := longPacket{
ptype: packetTypeInitial,
Expand Down Expand Up @@ -93,7 +93,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {

// Handshake packet.
if c.keysHandshake.canWrite() {
pnumMaxAcked := c.acks[handshakeSpace].largestSeen()
pnumMaxAcked := c.loss.spaces[handshakeSpace].maxAcked
pnum := c.loss.nextNumber(handshakeSpace)
p := longPacket{
ptype: packetTypeHandshake,
Expand Down Expand Up @@ -124,7 +124,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) {

// 1-RTT packet.
if c.keysAppData.canWrite() {
pnumMaxAcked := c.acks[appDataSpace].largestSeen()
pnumMaxAcked := c.loss.spaces[appDataSpace].maxAcked
pnum := c.loss.nextNumber(appDataSpace)
c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID)
c.appendFrames(now, appDataSpace, pnum, limit)
Expand Down
43 changes: 43 additions & 0 deletions internal/quic/conn_send_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,46 @@ func TestAckElicitingAck(t *testing.T) {
}
t.Errorf("after sending %v PINGs, got no ack-eliciting response", count)
}

func TestSendPacketNumberSize(t *testing.T) {
tc := newTestConn(t, clientSide, permissiveTransportParameters)
tc.handshake()

recvPing := func() *testPacket {
t.Helper()
tc.conn.ping(appDataSpace)
p := tc.readPacket()
if p == nil {
t.Fatalf("want packet containing PING, got none")
}
return p
}

// Desynchronize the packet numbers the conn is sending and the ones it is receiving,
// by having the conn send a number of unacked packets.
for i := 0; i < 16; i++ {
recvPing()
}

// Establish the maximum packet number the conn has received an ACK for.
maxAcked := recvPing().num
tc.writeAckForAll()

// Make the conn send a sequence of packets.
// Check that the packet number is encoded with two bytes once the difference between the
// current packet and the max acked one is sufficiently large.
for want := maxAcked + 1; want < maxAcked+0x100; want++ {
p := recvPing()
if p.num != want {
t.Fatalf("received packet number %v, want %v", p.num, want)
}
gotPnumLen := int(p.header&0x03) + 1
wantPnumLen := 1
if p.num-maxAcked >= 0x80 {
wantPnumLen = 2
}
if gotPnumLen != wantPnumLen {
t.Fatalf("packet number 0x%x encoded with %v bytes, want %v (max acked = %v)", p.num, gotPnumLen, wantPnumLen, maxAcked)
}
}
}
15 changes: 13 additions & 2 deletions internal/quic/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func (d testDatagram) String() string {

type testPacket struct {
ptype packetType
header byte
version uint32
num packetNumber
keyPhaseBit bool
Expand Down Expand Up @@ -599,12 +600,18 @@ func (tc *testConn) readFrame() (debugFrame, packetType) {
func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
tc.t.Helper()
got := tc.readDatagram()
if !reflect.DeepEqual(got, want) {
if !datagramEqual(got, want) {
tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
}
}

func datagramEqual(a, b *testDatagram) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
if a.paddedSize != b.paddedSize ||
a.addr != b.addr ||
len(a.packets) != len(b.packets) {
Expand All @@ -622,16 +629,18 @@ func datagramEqual(a, b *testDatagram) bool {
func (tc *testConn) wantPacket(expectation string, want *testPacket) {
tc.t.Helper()
got := tc.readPacket()
if !reflect.DeepEqual(got, want) {
if !packetEqual(got, want) {
tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want)
}
}

func packetEqual(a, b *testPacket) bool {
ac := *a
ac.frames = nil
ac.header = 0
bc := *b
bc.frames = nil
bc.header = 0
if !reflect.DeepEqual(ac, bc) {
return false
}
Expand Down Expand Up @@ -839,6 +848,7 @@ func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte)
}
d.packets = append(d.packets, &testPacket{
ptype: p.ptype,
header: buf[0],
version: p.version,
num: p.num,
dstConnID: p.dstConnID,
Expand Down Expand Up @@ -880,6 +890,7 @@ func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte)
}
d.packets = append(d.packets, &testPacket{
ptype: packetType1RTT,
header: hdr[0],
num: pnum,
dstConnID: hdr[1:][:len(tc.peerConnID)],
keyPhaseBit: hdr[0]&keyPhaseBit != 0,
Expand Down
3 changes: 1 addition & 2 deletions internal/quic/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"io"
"net"
"net/netip"
"reflect"
"testing"
"time"
)
Expand Down Expand Up @@ -242,7 +241,7 @@ func (te *testEndpoint) readDatagram() *testDatagram {
func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) {
te.t.Helper()
got := te.readDatagram()
if !reflect.DeepEqual(got, want) {
if !datagramEqual(got, want) {
te.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
}
}
Expand Down
3 changes: 1 addition & 2 deletions internal/quic/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"reflect"
"testing"
"time"
)
Expand Down Expand Up @@ -56,7 +55,7 @@ func (tc *testConn) handshake() {
fillCryptoFrames(want, tc.cryptoDataOut)
i++
}
if !reflect.DeepEqual(got, want) {
if !datagramEqual(got, want) {
t.Fatalf("dgram %v:\ngot %v\n\nwant %v", i, got, want)
}
if i >= len(dgrams) {
Expand Down

0 comments on commit b0eb4d6

Please sign in to comment.