Skip to content

Commit

Permalink
quic: include more detail in connection close errors
Browse files Browse the repository at this point in the history
When closing a connection with an error, include a reason
string in the CONNECTION_CLOSE frame as well as the
error code, when the code isn't sufficient to explain the error.

Change-Id: I055a4e11b222e87d1ff01d8c45fcb7cc17fe4196
Reviewed-on: https://go-review.googlesource.com/c/net/+/539342
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
  • Loading branch information
neild committed Nov 6, 2023
1 parent ec29a94 commit 434956a
Show file tree
Hide file tree
Showing 15 changed files with 178 additions and 61 deletions.
4 changes: 2 additions & 2 deletions internal/quic/conn_close.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (c *Conn) enterDraining(err error) {
if c.isDraining() {
return
}
if e, ok := c.lifetime.localErr.(localTransportError); ok && transportError(e) != errNo {
if e, ok := c.lifetime.localErr.(localTransportError); ok && e.code != errNo {
// If we've terminated the connection due to a peer protocol violation,
// record the final error on the connection as our reason for termination.
c.lifetime.finalErr = c.lifetime.localErr
Expand Down Expand Up @@ -220,7 +220,7 @@ func (c *Conn) Wait(ctx context.Context) error {
// Otherwise, Abort sends a transport error of APPLICATION_ERROR with the error's text.
func (c *Conn) Abort(err error) {
if err == nil {
err = localTransportError(errNo)
err = localTransportError{code: errNo}
}
c.sendMsg(func(now time.Time, c *Conn) {
c.abort(now, err)
Expand Down
5 changes: 4 additions & 1 deletion internal/quic/conn_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ func (c *Conn) shouldUpdateFlowControl(credit int64) bool {
func (c *Conn) handleStreamBytesReceived(n int64) error {
c.streams.inflow.usedLimit += n
if c.streams.inflow.usedLimit > c.streams.inflow.sentLimit {
return localTransportError(errFlowControl)
return localTransportError{
code: errFlowControl,
reason: "stream exceeded flow control limit",
}
}
return nil
}
Expand Down
61 changes: 40 additions & 21 deletions internal/quic/conn_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,25 +210,40 @@ func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p trans
// the transient remote connection ID we chose (client)
// or is empty (server).
if !bytes.Equal(s.originalDstConnID, p.originalDstConnID) {
return localTransportError(errTransportParameter)
return localTransportError{
code: errTransportParameter,
reason: "original_destination_connection_id mismatch",
}
}
s.originalDstConnID = nil // we have no further need for this
// Verify retry_source_connection_id matches the value from
// the server's Retry packet (when one was sent), or is empty.
if !bytes.Equal(p.retrySrcConnID, s.retrySrcConnID) {
return localTransportError(errTransportParameter)
return localTransportError{
code: errTransportParameter,
reason: "retry_source_connection_id mismatch",
}
}
s.retrySrcConnID = nil // we have no further need for this
// Verify initial_source_connection_id matches the first remote connection ID.
if len(s.remote) == 0 || s.remote[0].seq != 0 {
return localTransportError(errInternal)
return localTransportError{
code: errInternal,
reason: "remote connection id missing",
}
}
if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
return localTransportError(errTransportParameter)
return localTransportError{
code: errTransportParameter,
reason: "initial_source_connection_id mismatch",
}
}
if len(p.statelessResetToken) > 0 {
if c.side == serverSide {
return localTransportError(errTransportParameter)
return localTransportError{
code: errTransportParameter,
reason: "client sent stateless_reset_token",
}
}
token := statelessResetToken(p.statelessResetToken)
s.remote[0].resetToken = token
Expand All @@ -255,17 +270,6 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte)
},
}
}
case ptype == packetTypeInitial && c.side == serverSide:
if len(s.remote) == 0 {
// We're a server connection processing the first Initial packet
// from the client. Set the client's connection ID.
s.remote = append(s.remote, remoteConnID{
connID: connID{
seq: 0,
cid: cloneBytes(srcConnID),
},
})
}
case ptype == packetTypeHandshake && c.side == serverSide:
if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired {
// We're a server connection processing the first Handshake packet from
Expand Down Expand Up @@ -294,7 +298,10 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
// Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID
// frame as a connection error of type PROTOCOL_VIOLATION."
// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.15-6
return localTransportError(errProtocolViolation)
return localTransportError{
code: errProtocolViolation,
reason: "NEW_CONNECTION_ID from peer with zero-length DCID",
}
}

if retire > s.retireRemotePriorTo {
Expand All @@ -316,7 +323,10 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
}
if rcid.seq == seq {
if !bytes.Equal(rcid.cid, cid) {
return localTransportError(errProtocolViolation)
return localTransportError{
code: errProtocolViolation,
reason: "NEW_CONNECTION_ID does not match prior id",
}
}
have = true // yes, we've seen this sequence number
}
Expand Down Expand Up @@ -350,7 +360,10 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
// Retired connection IDs (including newly-retired ones) do not count
// against the limit.
// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5
return localTransportError(errConnectionIDLimit)
return localTransportError{
code: errConnectionIDLimit,
reason: "active_connection_id_limit exceeded",
}
}

// "An endpoint SHOULD limit the number of connection IDs it has retired locally
Expand All @@ -360,7 +373,10 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
// Set a limit of four times the active_connection_id_limit for
// the total number of remote connection IDs we keep state for locally.
if len(s.remote) > 4*activeConnIDLimit {
return localTransportError(errConnectionIDLimit)
return localTransportError{
code: errConnectionIDLimit,
reason: "too many unacknowledged RETIRE_CONNECTION_ID frames",
}
}

return nil
Expand All @@ -375,7 +391,10 @@ func (s *connIDState) retireRemote(rcid *remoteConnID) {

func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
if seq >= s.nextLocalSeq {
return localTransportError(errProtocolViolation)
return localTransportError{
code: errProtocolViolation,
reason: "RETIRE_CONNECTION_ID for unissued sequence number",
}
}
for i := range s.local {
if s.local[i].seq == seq {
Expand Down
40 changes: 32 additions & 8 deletions internal/quic/conn_recv.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,18 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa
if buf[0]&reservedLongBits != 0 {
// Reserved header bits must be 0.
// https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1
c.abort(now, localTransportError(errProtocolViolation))
c.abort(now, localTransportError{
code: errProtocolViolation,
reason: "reserved header bits are not zero",
})
return -1
}
if p.version != quicVersion1 {
// The peer has changed versions on us mid-handshake?
c.abort(now, localTransportError(errProtocolViolation))
c.abort(now, localTransportError{
code: errProtocolViolation,
reason: "protocol version changed during handshake",
})
return -1
}

Expand Down Expand Up @@ -129,7 +135,10 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
if buf[0]&reserved1RTTBits != 0 {
// Reserved header bits must be 0.
// https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1
c.abort(now, localTransportError(errProtocolViolation))
c.abort(now, localTransportError{
code: errProtocolViolation,
reason: "reserved header bits are not zero",
})
return -1
}

Expand Down Expand Up @@ -222,7 +231,10 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
// "An endpoint MUST treat receipt of a packet containing no frames
// as a connection error of type PROTOCOL_VIOLATION."
// https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3
c.abort(now, localTransportError(errProtocolViolation))
c.abort(now, localTransportError{
code: errProtocolViolation,
reason: "packet contains no frames",
})
return false
}
// frameOK verifies that ptype is one of the packets in mask.
Expand All @@ -232,7 +244,10 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
// that is not permitted as a connection error of type
// PROTOCOL_VIOLATION."
// https://www.rfc-editor.org/rfc/rfc9000#section-12.4-3
c.abort(now, localTransportError(errProtocolViolation))
c.abort(now, localTransportError{
code: errProtocolViolation,
reason: "frame not allowed in packet",
})
return false
}
return true
Expand Down Expand Up @@ -347,7 +362,10 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace,
n = c.handleHandshakeDoneFrame(now, space, payload)
}
if n < 0 {
c.abort(now, localTransportError(errFrameEncoding))
c.abort(now, localTransportError{
code: errFrameEncoding,
reason: "frame encoding error",
})
return false
}
payload = payload[n:]
Expand All @@ -360,7 +378,10 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte)
largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
if end > c.loss.nextNumber(space) {
// Acknowledgement of a packet we never sent.
c.abort(now, localTransportError(errProtocolViolation))
c.abort(now, localTransportError{
code: errProtocolViolation,
reason: "acknowledgement for unsent packet",
})
return
}
c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss)
Expand Down Expand Up @@ -521,7 +542,10 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa
if c.side == serverSide {
// Clients should never send HANDSHAKE_DONE.
// https://www.rfc-editor.org/rfc/rfc9000#section-19.20-4
c.abort(now, localTransportError(errProtocolViolation))
c.abort(now, localTransportError{
code: errProtocolViolation,
reason: "client sent HANDSHAKE_DONE",
})
return -1
}
if !c.isClosingOrDraining() {
Expand Down
2 changes: 1 addition & 1 deletion internal/quic/conn_send.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err
c.lifetime.connCloseSentTime = now
switch e := err.(type) {
case localTransportError:
c.w.appendConnectionCloseTransportFrame(transportError(e), 0, "")
c.w.appendConnectionCloseTransportFrame(e.code, 0, e.reason)
case *ApplicationError:
if space != appDataSpace {
// "CONNECTION_CLOSE frames signaling application errors (type 0x1d)
Expand Down
10 changes: 8 additions & 2 deletions internal/quic/conn_streams.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
if (id.initiator() == c.side) != (ftype == sendStream) {
// Received an invalid frame for unidirectional stream.
// For example, a RESET_STREAM frame for a send-only stream.
c.abort(now, localTransportError(errStreamState))
c.abort(now, localTransportError{
code: errStreamState,
reason: "invalid frame for unidirectional stream",
})
return nil
}
}
Expand All @@ -148,7 +151,10 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
}
// Received a frame for a stream that should be originated by us,
// but which we never created.
c.abort(now, localTransportError(errStreamState))
c.abort(now, localTransportError{
code: errStreamState,
reason: "received frame for unknown stream",
})
return nil
} else {
// if isOpen, this is a stream that was implicitly opened by a
Expand Down
44 changes: 43 additions & 1 deletion internal/quic/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,20 @@ func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
}
}

func datagramEqual(a, b *testDatagram) bool {
if a.paddedSize != b.paddedSize ||
a.addr != b.addr ||
len(a.packets) != len(b.packets) {
return false
}
for i := range a.packets {
if !packetEqual(a.packets[i], b.packets[i]) {
return false
}
}
return true
}

// wantPacket indicates that we expect the Conn to send a packet.
func (tc *testConn) wantPacket(expectation string, want *testPacket) {
tc.t.Helper()
Expand All @@ -603,6 +617,25 @@ func (tc *testConn) wantPacket(expectation string, want *testPacket) {
}
}

func packetEqual(a, b *testPacket) bool {
ac := *a
ac.frames = nil
bc := *b
bc.frames = nil
if !reflect.DeepEqual(ac, bc) {
return false
}
if len(a.frames) != len(b.frames) {
return false
}
for i := range a.frames {
if !frameEqual(a.frames[i], b.frames[i]) {
return false
}
}
return true
}

// wantFrame indicates that we expect the Conn to send a frame.
func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) {
tc.t.Helper()
Expand All @@ -613,11 +646,20 @@ func (tc *testConn) wantFrame(expectation string, wantType packetType, want debu
if gotType != wantType {
tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
}
if !reflect.DeepEqual(got, want) {
if !frameEqual(got, want) {
tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want)
}
}

func frameEqual(a, b debugFrame) bool {
switch af := a.(type) {
case debugFrameConnectionCloseTransport:
bf, ok := b.(debugFrameConnectionCloseTransport)
return ok && af.code == bf.code
}
return reflect.DeepEqual(a, b)
}

// wantFrameType indicates that we expect the Conn to send a frame,
// although we don't care about the contents.
func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
Expand Down
5 changes: 4 additions & 1 deletion internal/quic/crypto_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ type cryptoStream struct {
func (s *cryptoStream) handleCrypto(off int64, b []byte, f func([]byte) error) error {
end := off + int64(len(b))
if end-s.inset.min() > cryptoBufferSize {
return localTransportError(errCryptoBufferExceeded)
return localTransportError{
code: errCryptoBufferExceeded,
reason: "crypto buffer exceeded",
}
}
s.inset.add(off, end)
if off == s.in.start {
Expand Down
10 changes: 8 additions & 2 deletions internal/quic/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,16 @@ func (e transportError) String() string {
}

// A localTransportError is an error sent to the peer.
type localTransportError transportError
type localTransportError struct {
code transportError
reason string
}

func (e localTransportError) Error() string {
return "closed connection: " + transportError(e).String()
if e.reason == "" {
return fmt.Sprintf("closed connection: %v", e.code)
}
return fmt.Sprintf("closed connection: %v: %q", e.code, e.reason)
}

// A peerTransportError is an error received from the peer.
Expand Down
2 changes: 1 addition & 1 deletion internal/quic/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (l *Listener) Close(ctx context.Context) error {
if !l.closing {
l.closing = true
for c := range l.conns {
c.Abort(localTransportError(errNo))
c.Abort(localTransportError{code: errNo})
}
if len(l.conns) == 0 {
l.udpConn.Close()
Expand Down
2 changes: 1 addition & 1 deletion internal/quic/packet_protection.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ func (k *updatingKeyPair) unprotect(pkt []byte, pnumOff int, pnumMax packetNumbe
if err != nil {
k.authFailures++
if k.authFailures >= aeadIntegrityLimit(k.r.suite) {
return nil, 0, localTransportError(errAEADLimitReached)
return nil, 0, localTransportError{code: errAEADLimitReached}
}
return nil, 0, err
}
Expand Down
Loading

0 comments on commit 434956a

Please sign in to comment.