diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go
index f63fae5b8b5a..44574e703d82 100644
--- a/internal/transport/http2_client.go
+++ b/internal/transport/http2_client.go
@@ -59,6 +59,11 @@ import (
 // atomically.
 var clientConnectionCounter uint64
 
+// GoAwayLoopyWriterTimeout bounds how long Close waits for the loopy writer to
+// finish after the GOAWAY frame has been queued.
+// NOTE(review): this replaces a hard-coded 5 * time.Second; confirm 5ms is intended.
+const GoAwayLoopyWriterTimeout = 5 * time.Millisecond
+
 var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))
 
 // http2Client implements the ClientTransport interface with HTTP2.
@@ -1006,10 +1011,13 @@ func (t *http2Client) Close(err error) {
 		t.kpDormancyCond.Signal()
 	}
 	t.mu.Unlock()
 	// Per HTTP/2 spec, a GOAWAY frame must be sent before closing the
-	// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY.
+	// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. Close also
+	// waits for the loopy writer to be closed, bounded by a timer so that a hung
+	// connection cannot block it indefinitely.
 	t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err})
-	timer := time.NewTimer(5 * time.Second)
+	timer := time.NewTimer(GoAwayLoopyWriterTimeout)
+	defer timer.Stop()
 	select {
 	case <-t.writerDone:
 	case <-timer.C:
diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go
index 23b2fbd284c7..b04cb7ce5451 100644
--- a/internal/transport/transport_test.go
+++ b/internal/transport/transport_test.go
@@ -90,9 +90,12 @@ const (
 	invalidHeaderField
 	delayRead
 	pingpong
-	goAwayFrameSize = 42
 )
 
+// goAwayFrameSize is the byte count of a single write that hangingConn
+// treats as the client's GOAWAY frame.
+const goAwayFrameSize = 42
+
 func (h *testStreamHandler) handleStreamAndNotify(s *Stream) {
 	if h.notify == nil {
 		return
@@ -2661,94 +2664,16 @@ func TestConnectionError_Unwrap(t *testing.T) {
 // clientTransport.Close(), client sends a goaway to the server with the correct
 // error code and debug data.
 func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
-	// Create a server.
-	lis, err := net.Listen("tcp", "localhost:0")
-	if err != nil {
-		t.Fatalf("Error while listening: %v", err)
-	}
-	defer lis.Close()
-	// greetDone is used to notify when server is done greeting the client.
-	greetDone := make(chan struct{})
-	// errorCh verifies that desired GOAWAY not received by server
-	errorCh := make(chan error)
-	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
-	defer cancel()
-	// Launch the server.
-	go func() {
-		sconn, err := lis.Accept()
-		if err != nil {
-			t.Errorf("Error while accepting: %v", err)
-		}
-		defer sconn.Close()
-		if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
-			t.Errorf("Error while writing settings ack: %v", err)
-			return
-		}
-		sfr := http2.NewFramer(sconn, sconn)
-		if err := sfr.WriteSettings(); err != nil {
-			t.Errorf("Error while writing settings %v", err)
-			return
-		}
-		fr, _ := sfr.ReadFrame()
-		if _, ok := fr.(*http2.SettingsFrame); !ok {
-			t.Errorf("Expected settings frame, got %v", fr)
-		}
-		fr, _ = sfr.ReadFrame()
-		if fr, ok := fr.(*http2.SettingsFrame); !ok || !fr.IsAck() {
-			t.Errorf("Expected settings ACK frame, got %v", fr)
-		}
-		fr, _ = sfr.ReadFrame()
-		if fr, ok := fr.(*http2.HeadersFrame); !ok || !fr.Flags.Has(http2.FlagHeadersEndHeaders) {
-			t.Errorf("Expected Headers frame with END_HEADERS frame, got %v", fr)
-		}
-		close(greetDone)
-
-		frame, err := sfr.ReadFrame()
-		if err != nil {
-			return
-		}
-		switch fr := frame.(type) {
-		case *http2.GoAwayFrame:
-			// Records that the server successfully received a GOAWAY frame.
-			goAwayFrame := fr
-			if goAwayFrame.ErrCode == http2.ErrCodeNo {
-				t.Logf("Received goAway frame from client")
-				close(errorCh)
-			} else {
-				errorCh <- fmt.Errorf("received unexpected goAway frame: %v", err)
-				close(errorCh)
-			}
-			return
-		default:
-			errorCh <- fmt.Errorf("server received a frame other than GOAWAY: %v", err)
-			close(errorCh)
-			return
-		}
-	}()
-
-	ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {})
-	if err != nil {
-		t.Fatalf("Error while creating client transport: %v", err)
-	}
-	_, err = ct.NewStream(ctx, &CallHdr{})
-	if err != nil {
-		t.Fatalf("failed to open stream: %v", err)
-	}
-	// Wait until server receives the headers and settings frame as part of greet.
-	<-greetDone
-	ct.Close(errors.New("manually closed by client"))
-	select {
-	case err := <-errorCh:
-		if err != nil {
-			t.Errorf("Error receiving the GOAWAY frame: %v", err)
-		}
-	case <-ctx.Done():
-		t.Errorf("Context timed out")
-	}
+	createClientServerConn(t, ConnectOptions{})
 }
 
+// writeHangSignal unblocks a hangingConn Write that is paused after a GOAWAY
+// frame was written. It is created and closed by TestClientCloseTimeoutOnHang.
 var writeHangSignal chan struct{}
 
+// hangingConn is a net.Conn wrapper for testing that simulates a hung
+// connection: once a GOAWAY-sized write has gone through, Write blocks until
+// writeHangSignal is signaled or a timer expires.
 type hangingConn struct {
 	net.Conn
 }
@@ -2761,8 +2686,14 @@ func (hc *hangingConn) Read(b []byte) (n int, err error) {
 func (hc *hangingConn) Write(b []byte) (n int, err error) {
 	n, err = hc.Conn.Write(b)
 	if n == goAwayFrameSize { // GOAWAY frame
-		writeHangSignal = make(chan struct{})
-		time.Sleep(15 * time.Second)
+		// Hold the writer just past the transport's close timeout, or until
+		// the test releases it, instead of sleeping for a fixed 15 seconds.
+		timer := time.NewTimer(GoAwayLoopyWriterTimeout + 1)
+		defer timer.Stop()
+		select {
+		case <-writeHangSignal:
+		case <-timer.C:
+		}
 	}
 	return n, err
 }
@@ -2801,8 +2732,25 @@ func hangingDialer(_ context.Context, addr string) (net.Conn, error) {
 
 // TestClientCloseTimeoutOnHang verifies that in the event of a graceful
 // client transport shutdown, i.e., clientTransport.Close(), if the conn hung
-// forever, client should still be close itself and do not wait for long.
+// for GoAwayLoopyWriterTimeout, the client should still close itself and
+// should not wait for long.
 func (s) TestClientCloseTimeoutOnHang(t *testing.T) {
+	writeHangSignal = make(chan struct{})
+	defer close(writeHangSignal)
+	ctx, _, _ := createClientServerConn(t, ConnectOptions{Dialer: hangingDialer})
+	select {
+	case <-writeHangSignal:
+		t.Errorf("error: channel closed too early.")
+	case <-ctx.Done():
+	}
+}
+
+// createClientServerConn starts a test server on localhost, dials it with
+// connectOptions, closes the client transport and then waits for a result on
+// errorCh or for ctx to be done. It returns the context, the error channel
+// and the client transport it used.
+func createClientServerConn(t *testing.T, connectOptions ConnectOptions) (context.Context, chan error, ClientTransport) {
+	t.Helper()
 	// Create a server.
 	lis, err := net.Listen("tcp", "localhost:0")
 	if err != nil {
@@ -2868,7 +2816,7 @@ func (s) TestClientCloseTimeoutOnHang(t *testing.T) {
 		}
 	}()
 
-	ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{Dialer: hangingDialer}, func(GoAwayReason) {})
+	ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, connectOptions, func(GoAwayReason) {})
 	if err != nil {
 		t.Fatalf("Error while creating client transport: %v", err)
 	}
@@ -2879,11 +2827,13 @@ func (s) TestClientCloseTimeoutOnHang(t *testing.T) {
 	// Wait until server receives the headers and settings frame as part of greet.
 	<-greetDone
 	ct.Close(errors.New("manually closed by client"))
-	defer close(writeHangSignal)
 	select {
-	case <-writeHangSignal:
-		t.Errorf("error: channel closed too early.")
+	case err := <-errorCh:
+		if err != nil {
+			t.Errorf("Error receiving the GOAWAY frame: %v", err)
+		}
 	case <-ctx.Done():
+		t.Errorf("Context timed out")
 	}
-
+	return ctx, errorCh, ct
 }