From 84d59c41945d751b6451fc806d2c2239b840eaaf Mon Sep 17 00:00:00 2001 From: Abhishek Ranjan Date: Tue, 13 Aug 2024 15:37:40 +0530 Subject: [PATCH] Make isGreetingDone part of hangingConn struct --- internal/transport/transport_test.go | 60 ++++++++++++++-------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 3342f817ab0a..c22d96df8e12 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -428,16 +428,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server return server } -// isGreetingDone verifies that client-server setup is complete -// for the test. -var isGreetingDone = atomic.Bool{} - -func setUp(t *testing.T, port int, ht hType, options ...ConnectOptions) (*server, *http2Client, func()) { - var copts = ConnectOptions{} - if len(options) > 0 { - copts = options[0] - } - return setUpWithOptions(t, port, &ServerConfig{}, ht, copts) +func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) { + return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{}) } func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) { @@ -451,7 +443,6 @@ func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts cancel() // Do not cancel in success path. t.Fatalf("failed to create transport: %v", connErr) } - isGreetingDone.Store(true) return server, ct.(*http2Client), cancel } @@ -2758,21 +2749,17 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { } // hangingConn is a net.Conn wrapper for testing, simulating hanging connections -// after a GOAWAY frame is sent, of which Write operations pause until explicitly signaled -// or a timeout occurs. +// after a GOAWAY frame is sent, of which Write operations pause until explicitly +// signaled or a timeout occurs. type hangingConn struct { net.Conn - hangConn chan struct{} -} - -func (hc *hangingConn) Read(b []byte) (n int, err error) { - n, err = hc.Conn.Read(b) - return n, err + hangConn chan struct{} + isGreetingDone *atomic.Bool } func (hc *hangingConn) Write(b []byte) (n int, err error) { n, err = hc.Conn.Write(b) - if isGreetingDone.Load() == true { + if hc.isGreetingDone.Load() == true { // Hang the Write for more than goAwayLoopyWriterTimeout timer := time.NewTimer(time.Millisecond * 5) defer timer.Stop() @@ -2784,14 +2771,6 @@ func (hc *hangingConn) Write(b []byte) (n int, err error) { return n, err } -func hangingDialer(_ context.Context, addr string) (net.Conn, error) { - conn, err := net.Dial("tcp", addr) - if err != nil { - return nil, err - } - return &hangingConn{Conn: conn, hangConn: make(chan struct{})}, nil -} - // Tests the scenario where a client transport is closed and writing of the // GOAWAY frame as part of the close does not complete because of a network // hang. The test verifies that the client transport is closed without waiting @@ -2806,9 +2785,30 @@ func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) { goAwayLoopyWriterTimeout = origGoAwayLoopyTimeout }() - isGreetingDone.Store(false) + // Create the server set up. + connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) + server := setUpServerOnly(t, 0, &ServerConfig{}, normal) + addr := resolver.Address{Addr: "localhost:" + server.port} + isGreetingDone := &atomic.Bool{} + dialer := func(_ context.Context, addr string) (net.Conn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + isGreetingDone.Store(false) + return &hangingConn{Conn: conn, hangConn: make(chan struct{}), isGreetingDone: isGreetingDone}, nil + } + copts := ConnectOptions{Dialer: dialer} + copts.ChannelzParent = channelzSubChannel(t) + + // Create client transport with custom dialer + ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {}) + if connErr != nil { + cancel() // Do not cancel in success path. + t.Fatalf("failed to create transport: %v", connErr) + } + isGreetingDone.Store(true) - server, ct, cancel := setUp(t, 0, normal, ConnectOptions{Dialer: hangingDialer}) defer cancel() defer server.stop()