Skip to content

Send GOAWAY to server on Client Transport Shutdown #7015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions internal/transport/controlbuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ type goAway struct {
code http2.ErrCode
debugData []byte
headsUp bool
closeConn error // if set, loopyWriter will exit, resulting in conn closure
closeConn error // if set, loopyWriter will exit with this error
}

func (*goAway) isTransportResponseFrame() bool { return false }
Expand Down Expand Up @@ -495,21 +495,22 @@ type loopyWriter struct {
ssGoAwayHandler func(*goAway) (bool, error)
}

func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger) *loopyWriter {
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error)) *loopyWriter {
var buf bytes.Buffer
l := &loopyWriter{
side: s,
cbuf: cbuf,
sendQuota: defaultWindowSize,
oiws: defaultWindowSize,
estdStreams: make(map[uint32]*outStream),
activeStreams: newOutStreamList(),
framer: fr,
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
bdpEst: bdpEst,
conn: conn,
logger: logger,
side: s,
cbuf: cbuf,
sendQuota: defaultWindowSize,
oiws: defaultWindowSize,
estdStreams: make(map[uint32]*outStream),
activeStreams: newOutStreamList(),
framer: fr,
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
bdpEst: bdpEst,
conn: conn,
logger: logger,
ssGoAwayHandler: goAwayHandler,
}
return l
}
Expand Down
30 changes: 24 additions & 6 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
readerErrCh := make(chan error, 1)
go t.reader(readerErrCh)
defer func() {
if err == nil {
err = <-readerErrCh
}
if err != nil {
// writerDone should be closed since the loopy goroutine
// wouldn't have started in the case this function returns an error.
close(t.writerDone)
t.Close(err)
}
}()
Expand Down Expand Up @@ -458,8 +458,12 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if err := t.framer.writer.Flush(); err != nil {
return nil, err
}
// Block until the server preface is received successfully or an error occurs.
if err = <-readerErrCh; err != nil {
return nil, err
}
go func() {
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger)
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler)
if err := t.loopy.run(); !isIOError(err) {
// Immediately close the connection, as the loopy writer returns
// when there are no more active streams and we were draining (the
Expand Down Expand Up @@ -517,6 +521,17 @@ func (t *http2Client) getPeer() *peer.Peer {
}
}

// OutgoingGoAwayHandler writes a GOAWAY to the connection. Always returns (false, err) as we want the GoAway
// to be the last frame loopy writes to the transport.
func (t *http2Client) outgoingGoAwayHandler(g *goAway) (bool, error) {
t.mu.Lock()
defer t.mu.Unlock()
if err := t.framer.fr.WriteGoAway(t.nextID-2, http2.ErrCodeNo, g.debugData); err != nil {
return false, err
}
return false, g.closeConn
}

func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) {
aud := t.createAudience(callHdr)
ri := credentials.RequestInfo{
Expand Down Expand Up @@ -966,7 +981,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.

// Close kicks off the shutdown process of the transport. This should be called
// only once on a transport. Once it is called, the transport should not be
// accessed any more.
// accessed anymore.
func (t *http2Client) Close(err error) {
t.mu.Lock()
// Make sure we only close once.
Expand All @@ -991,7 +1006,10 @@ func (t *http2Client) Close(err error) {
t.kpDormancyCond.Signal()
}
t.mu.Unlock()
t.controlBuf.finish()
// Per HTTP/2 spec, a GOAWAY frame must be sent before closing the
// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY.
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err})
<-t.writerDone
t.cancel()
t.conn.Close()
channelz.RemoveEntry(t.channelz.ID)
Expand Down
3 changes: 1 addition & 2 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
t.handleSettings(sf)

go func() {
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler)
err := t.loopy.run()
close(t.loopyWriterDone)
if !isIOError(err) {
Expand Down
91 changes: 91 additions & 0 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2659,3 +2659,94 @@ func TestConnectionError_Unwrap(t *testing.T) {
t.Error("ConnectionError does not unwrap")
}
}

// Test that in the event of a graceful client transport shutdown, i.e.,
// clientTransport.Close(), client sends a goaway to the server with the correct
// error code and debug data.
func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
// Create a server.
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening: %v", err)
}
defer lis.Close()
// greetDone is used to notify when server is done greeting the client.
greetDone := make(chan struct{})
// errorCh verifies that desired GOAWAY not received by server
errorCh := make(chan error)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Launch the server.
go func() {
sconn, err := lis.Accept()
if err != nil {
t.Errorf("Error while accepting: %v", err)
}
defer sconn.Close()
if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
t.Errorf("Error while writing settings ack: %v", err)
return
}
sfr := http2.NewFramer(sconn, sconn)
if err := sfr.WriteSettings(); err != nil {
t.Errorf("Error while writing settings %v", err)
return
}
fr, _ := sfr.ReadFrame()
if _, ok := fr.(*http2.SettingsFrame); !ok {
t.Errorf("Expected settings frame, got %v", fr)
}
fr, _ = sfr.ReadFrame()
if fr, ok := fr.(*http2.SettingsFrame); !ok && fr.IsAck() {
t.Errorf("Expected settings ACK frame, got %v", fr)
}
fr, _ = sfr.ReadFrame()
if fr, ok := fr.(*http2.HeadersFrame); !ok && fr.Flags.Has(http2.FlagHeadersEndStream) {
t.Errorf("Expected Headers frame with END_HEADERS frame, got %v", fr)
}
close(greetDone)

frame, err := sfr.ReadFrame()
if err != nil {
return
}
switch fr := frame.(type) {
case *http2.GoAwayFrame:
// Records that the server successfully received a GOAWAY frame.
goAwayFrame := fr
if goAwayFrame.ErrCode == http2.ErrCodeNo {
t.Logf("Received goAway frame from client")
close(errorCh)
} else {
errorCh <- fmt.Errorf("received unexpected goAway frame: %v", err)
close(errorCh)
}
return
default:
errorCh <- fmt.Errorf("server received a frame other than GOAWAY: %v", err)
close(errorCh)
return
}
}()

ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {})
if err != nil {
t.Fatalf("Error while creating client transport: %v", err)
}
_, err = ct.NewStream(ctx, &CallHdr{})
if err != nil {
t.Fatalf("failed to open stream: %v", err)
}
// Wait until server receives the headers and settings frame as part of greet.
<-greetDone
ct.Close(errors.New("manually closed by client"))
t.Logf("Closed the client connection")
select {
case err := <-errorCh:
if err != nil {
t.Errorf("Error receiving the GOAWAY frame: %v", err)
}
case <-ctx.Done():
t.Errorf("Context timed out")
}
}
65 changes: 65 additions & 0 deletions test/goaway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package test

import (
"context"
"fmt"
"io"
"net"
"strings"
Expand Down Expand Up @@ -761,3 +762,67 @@ func (s) TestTwoGoAwayPingFrames(t *testing.T) {
t.Fatalf("Error waiting for graceful shutdown of the server: %v", err)
}
}

// TestClientSendsAGoAway tests the scenario where you get a go away ping
// frames from the client during graceful shutdown.
func (s) TestClientSendsAGoAway(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("error listening: %v", err)
}
ctCh := testutils.NewChannel()
go func() {
conn, err := lis.Accept()
if err != nil {
t.Errorf("error in lis.Accept(): %v", err)
}
ct := newClientTester(t, conn)
ctCh.Send(ct)
}()
defer lis.Close()

cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("error dialing: %v", err)
}

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

val, err := ctCh.Receive(ctx)
if err != nil {
t.Fatalf("timeout waiting for client transport (should be given after http2 creation)")
}
ct := val.(*clientTester)
goAwayReceived := make(chan struct{})
errCh := make(chan error)
go func() {
for {
f, err := ct.fr.ReadFrame()
if err != nil {
return
}
switch fr := f.(type) {
case *http2.GoAwayFrame:
fr = f.(*http2.GoAwayFrame)
if fr.ErrCode == http2.ErrCodeNo {
t.Logf("GoAway received from client")
close(goAwayReceived)
}
default:
t.Errorf("server tester received unexpected frame type %T", f)
errCh <- fmt.Errorf("server tester received unexpected frame type %T", f)
close(errCh)
}
}
}()
cc.Close()
defer ct.conn.Close()
select {
case <-goAwayReceived:
case err := <-errCh:
t.Errorf("Error receiving the goAway: %v", err)
case <-ctx.Done():
t.Errorf("Context timed out")
}
}