From ec29a9498a02f880ede985f4671b24c62016f936 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 3 Nov 2023 16:37:26 -0700 Subject: [PATCH] quic: provide source conn ID when creating server conns New server-side conns need to know a variety of connection IDs, such as the Initial DCID used to create Initial encryption keys. We've been providing these as an ever-growing list of []byte parameters to newConn. Bundle them all up into a struct. Add the client's SCID to the set of IDs we pass to newConn. Up until now, we've been setting this when processing the first Initial packet from the client. Passing it to newConn will makes it available when logging the connection_started event. Update some test infrastructure to deal with the fact that we need to know the peer's SCID earlier in the test now. Change-Id: I760ee94af36125acf21c5bf135f1168830ba1ab8 Reviewed-on: https://go-review.googlesource.com/c/net/+/539341 Reviewed-by: Jonathan Amsterdam LUCI-TryBot-Result: Go LUCI --- internal/quic/conn.go | 22 ++++++++++++------ internal/quic/conn_id.go | 12 ++++++++-- internal/quic/conn_id_test.go | 5 +++- internal/quic/conn_test.go | 24 ++++++++++++------- internal/quic/listener.go | 19 ++++++++------- internal/quic/listener_test.go | 42 ++++++++-------------------------- 6 files changed, 65 insertions(+), 59 deletions(-) diff --git a/internal/quic/conn.go b/internal/quic/conn.go index b3d6feabc7..1292f2b20e 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -86,7 +86,15 @@ type connTestHooks interface { timeNow() time.Time } -func newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []byte, peerAddr netip.AddrPort, config *Config, l *Listener) (*Conn, error) { +// newServerConnIDs is connection IDs associated with a new server connection. +type newServerConnIDs struct { + srcConnID []byte // source from client's current Initial + dstConnID []byte // destination from client's current Initial + originalDstConnID []byte // destination from client's first Initial + retrySrcConnID []byte // source from server's Retry +} + +func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, l *Listener) (*Conn, error) { c := &Conn{ side: side, listener: l, @@ -115,11 +123,11 @@ func newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []b } initialConnID, _ = c.connIDState.dstConnID() } else { - initialConnID = originalDstConnID - if retrySrcConnID != nil { - initialConnID = retrySrcConnID + initialConnID = cids.originalDstConnID + if cids.retrySrcConnID != nil { + initialConnID = cids.retrySrcConnID } - if err := c.connIDState.initServer(c, initialConnID); err != nil { + if err := c.connIDState.initServer(c, cids); err != nil { return nil, err } } @@ -134,8 +142,8 @@ func newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []b if err := c.startTLS(now, initialConnID, transportParameters{ initialSrcConnID: c.connIDState.srcConnID(), - originalDstConnID: originalDstConnID, - retrySrcConnID: retrySrcConnID, + originalDstConnID: cids.originalDstConnID, + retrySrcConnID: cids.retrySrcConnID, ackDelayExponent: ackDelayExponent, maxUDPPayloadSize: maxUDPPayloadSize, maxAckDelay: maxAckDelay, diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go index 91ccaade14..b77ad8edf2 100644 --- a/internal/quic/conn_id.go +++ b/internal/quic/conn_id.go @@ -96,8 +96,8 @@ func (s *connIDState) initClient(c *Conn) error { return nil } -func (s *connIDState) initServer(c *Conn, dstConnID []byte) error { - dstConnID = cloneBytes(dstConnID) +func (s *connIDState) initServer(c *Conn, cids newServerConnIDs) error { + dstConnID := cloneBytes(cids.dstConnID) // Client-chosen, transient connection ID received in the first Initial packet. // The server will not use this as the Source Connection ID of packets it sends, // but remembers it because it may receive packets sent to this destination. @@ -121,6 +121,14 @@ func (s *connIDState) initServer(c *Conn, dstConnID []byte) error { conns.addConnID(c, dstConnID) conns.addConnID(c, locid) }) + + // Client chose its own connection ID. + s.remote = append(s.remote, remoteConnID{ + connID: connID{ + seq: 0, + cid: cloneBytes(cids.srcConnID), + }, + }) return nil } diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go index 63feec992e..314a6b3845 100644 --- a/internal/quic/conn_id_test.go +++ b/internal/quic/conn_id_test.go @@ -578,8 +578,11 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) { p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0") p.preferredAddrConnID = testPeerConnID(1) p.preferredAddrResetToken = make([]byte, 16) + }, func(cids *newServerConnIDs) { + cids.srcConnID = []byte{} + }, func(tc *testConn) { + tc.peerConnID = []byte{} }) - tc.peerConnID = []byte{} tc.writeFrames(packetTypeInitial, debugFrameCrypto{ diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index df28907f44..248be9641b 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -193,33 +193,38 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { TLSConfig: newTestTLSConfig(side), StatelessResetKey: testStatelessResetKey, } + var cids newServerConnIDs + if side == serverSide { + // The initial connection ID for the server is chosen by the client. + cids.srcConnID = testPeerConnID(0) + cids.dstConnID = testPeerConnID(-1) + } var configTransportParams []func(*transportParameters) + var configTestConn []func(*testConn) for _, o := range opts { switch o := o.(type) { case func(*Config): o(config) case func(*tls.Config): o(config.TLSConfig) + case func(cids *newServerConnIDs): + o(&cids) case func(p *transportParameters): configTransportParams = append(configTransportParams, o) + case func(p *testConn): + configTestConn = append(configTestConn, o) default: t.Fatalf("unknown newTestConn option %T", o) } } - var initialConnID []byte - if side == serverSide { - // The initial connection ID for the server is chosen by the client. - initialConnID = testPeerConnID(-1) - } - listener := newTestListener(t, config) listener.configTransportParams = configTransportParams + listener.configTestConn = configTestConn conn, err := listener.l.newConn( listener.now, side, - initialConnID, - nil, + cids, netip.MustParseAddrPort("127.0.0.1:443")) if err != nil { t.Fatal(err) @@ -244,6 +249,9 @@ func newTestConnForConn(t *testing.T, listener *testListener, conn *Conn) *testC recvDatagram: make(chan *datagram), } t.Cleanup(tc.cleanup) + for _, f := range listener.configTestConn { + f(tc) + } conn.testHooks = (*testConnHooks)(tc) if listener.peerTLSConn != nil { diff --git a/internal/quic/listener.go b/internal/quic/listener.go index 08f011092a..24484eb6f2 100644 --- a/internal/quic/listener.go +++ b/internal/quic/listener.go @@ -140,7 +140,7 @@ func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, er } addr := u.AddrPort() addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) - c, err := l.newConn(time.Now(), clientSide, nil, nil, addr) + c, err := l.newConn(time.Now(), clientSide, newServerConnIDs{}, addr) if err != nil { return nil, err } @@ -151,13 +151,13 @@ func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, er return c, nil } -func (l *Listener) newConn(now time.Time, side connSide, originalDstConnID, retrySrcConnID []byte, peerAddr netip.AddrPort) (*Conn, error) { +func (l *Listener) newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort) (*Conn, error) { l.connsMu.Lock() defer l.connsMu.Unlock() if l.closing { return nil, errors.New("listener closed") } - c, err := newConn(now, side, originalDstConnID, retrySrcConnID, peerAddr, l.config, l) + c, err := newConn(now, side, cids, peerAddr, l.config, l) if err != nil { return nil, err } @@ -296,19 +296,22 @@ func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { } else { now = time.Now() } - var originalDstConnID, retrySrcConnID []byte + cids := newServerConnIDs{ + srcConnID: p.srcConnID, + dstConnID: p.dstConnID, + } if l.config.RequireAddressValidation { var ok bool - retrySrcConnID = p.dstConnID - originalDstConnID, ok = l.validateInitialAddress(now, p, m.addr) + cids.retrySrcConnID = p.dstConnID + cids.originalDstConnID, ok = l.validateInitialAddress(now, p, m.addr) if !ok { return } } else { - originalDstConnID = p.dstConnID + cids.originalDstConnID = p.dstConnID } var err error - c, err := l.newConn(now, serverSide, originalDstConnID, retrySrcConnID, m.addr) + c, err := l.newConn(now, serverSide, cids, m.addr) if err != nil { // The accept queue is probably full. // We could send a CONNECTION_CLOSE to the peer to reject the connection. diff --git a/internal/quic/listener_test.go b/internal/quic/listener_test.go index 21717e2516..674d4e4a16 100644 --- a/internal/quic/listener_test.go +++ b/internal/quic/listener_test.go @@ -19,12 +19,12 @@ import ( ) func TestConnect(t *testing.T) { - newLocalConnPair(t, &Config{}, &Config{}) + NewLocalConnPair(t, &Config{}, &Config{}) } func TestStreamTransfer(t *testing.T) { ctx := context.Background() - cli, srv := newLocalConnPair(t, &Config{}, &Config{}) + cli, srv := NewLocalConnPair(t, &Config{}, &Config{}) data := makeTestData(1 << 20) srvdone := make(chan struct{}) @@ -61,11 +61,11 @@ func TestStreamTransfer(t *testing.T) { } } -func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) { +func NewLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) { t.Helper() ctx := context.Background() - l1 := newLocalListener(t, serverSide, conf1) - l2 := newLocalListener(t, clientSide, conf2) + l1 := NewLocalListener(t, serverSide, conf1) + l2 := NewLocalListener(t, clientSide, conf2) c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String()) if err != nil { t.Fatal(err) @@ -77,9 +77,11 @@ func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverCon return c2, c1 } -func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener { +func NewLocalListener(t *testing.T, side connSide, conf *Config) *Listener { t.Helper() if conf.TLSConfig == nil { + newConf := *conf + conf = &newConf conf.TLSConfig = newTestTLSConfig(side) } l, err := Listen("udp", "127.0.0.1:0", conf) @@ -101,6 +103,7 @@ type testListener struct { conns map[*Conn]*testConn acceptQueue []*testConn configTransportParams []func(*transportParameters) + configTestConn []func(*testConn) sentDatagrams [][]byte peerTLSConn *tls.QUICConn lastInitialDstConnID []byte // for parsing Retry packets @@ -251,33 +254,6 @@ func (tl *testListener) wantIdle(expectation string) { } } -func (tl *testListener) newClientTLS(srcConnID, dstConnID []byte) []byte { - peerProvidedParams := defaultTransportParameters() - peerProvidedParams.initialSrcConnID = srcConnID - peerProvidedParams.originalDstConnID = dstConnID - for _, f := range tl.configTransportParams { - f(&peerProvidedParams) - } - - config := &tls.QUICConfig{TLSConfig: newTestTLSConfig(clientSide)} - tl.peerTLSConn = tls.QUICClient(config) - tl.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams)) - tl.peerTLSConn.Start(context.Background()) - var data []byte - for { - e := tl.peerTLSConn.NextEvent() - switch e.Kind { - case tls.QUICNoEvent: - return data - case tls.QUICWriteData: - if e.Level != tls.QUICEncryptionLevelInitial { - tl.t.Fatal("initial data at unexpected level") - } - data = append(data, e.Data...) - } - } -} - // advance causes time to pass. func (tl *testListener) advance(d time.Duration) { tl.t.Helper()