Skip to content

Commit 537fe8d

Browse files
authored
transport: Propagate status code on receiving RST_STREAM during message read (#8289) (#8317)
1 parent f32eab3 commit 537fe8d

File tree

3 files changed

+81
-3
lines changed

3 files changed

+81
-3
lines changed

internal/transport/http2_client.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1242,7 +1242,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
12421242
statusCode = codes.DeadlineExceeded
12431243
}
12441244
}
1245-
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false)
1245+
st := status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode)
1246+
t.closeStream(s, st.Err(), false, http2.ErrCodeNo, st, nil, false)
12461247
}
12471248

12481249
func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) {

internal/transport/transport_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -919,8 +919,9 @@ func (s) TestLargeMessageSuspension(t *testing.T) {
919919
}
920920
// The server will send an RST stream frame on observing the deadline
921921
// expiration making the client stream fail with a DeadlineExceeded status.
922-
if _, err := s.readTo(make([]byte, 8)); err != io.EOF {
923-
t.Fatalf("Read got unexpected error: %v, want %v", err, io.EOF)
922+
_, err = s.readTo(make([]byte, 8))
923+
if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
924+
t.Fatalf("Read got unexpected error: %v, want status with code %v", err, codes.DeadlineExceeded)
924925
}
925926
if got, want := s.Status().Code(), codes.DeadlineExceeded; got != want {
926927
t.Fatalf("Read got status %v with code %v, want %v", s.Status(), got, want)

test/transport_test.go

+76
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,20 @@ package test
1919

2020
import (
2121
"context"
22+
"encoding/binary"
2223
"io"
2324
"net"
2425
"sync"
2526
"testing"
2627

28+
"golang.org/x/net/http2"
2729
"google.golang.org/grpc"
2830
"google.golang.org/grpc/codes"
2931
"google.golang.org/grpc/credentials"
32+
"google.golang.org/grpc/credentials/insecure"
3033
"google.golang.org/grpc/internal/grpcsync"
3134
"google.golang.org/grpc/internal/stubserver"
35+
"google.golang.org/grpc/internal/testutils"
3236
"google.golang.org/grpc/internal/transport"
3337
"google.golang.org/grpc/status"
3438

@@ -153,3 +157,75 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
153157
t.Fatal("Timeout expired when waiting for first client transport to close")
154158
}
155159
}
160+
161+
// Tests that an RST_STREAM frame that causes an io.ErrUnexpectedEOF while
162+
// reading a gRPC message is correctly converted to a gRPC status with code
163+
// CANCELLED. The test sends a data frame with a partial gRPC message, followed
164+
// by an RST_STREAM frame with HTTP/2 code CANCELLED. The test asserts the
165+
// client receives the correct status.
166+
func (s) TestRSTDuringMessageRead(t *testing.T) {
167+
lis, err := testutils.LocalTCPListener()
168+
if err != nil {
169+
t.Fatal(err)
170+
}
171+
defer lis.Close()
172+
173+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
174+
defer cancel()
175+
cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
176+
if err != nil {
177+
t.Fatalf("grpc.NewClient(%s) = %v", lis.Addr().String(), err)
178+
}
179+
defer cc.Close()
180+
181+
go func() {
182+
conn, err := lis.Accept()
183+
if err != nil {
184+
t.Errorf("lis.Accept() = %v", err)
185+
return
186+
}
187+
defer conn.Close()
188+
framer := http2.NewFramer(conn, conn)
189+
190+
if _, err := io.ReadFull(conn, make([]byte, len(clientPreface))); err != nil {
191+
t.Errorf("Error while reading client preface: %v", err)
192+
return
193+
}
194+
if err := framer.WriteSettings(); err != nil {
195+
t.Errorf("Error while writing settings: %v", err)
196+
return
197+
}
198+
if err := framer.WriteSettingsAck(); err != nil {
199+
t.Errorf("Error while writing settings: %v", err)
200+
return
201+
}
202+
for ctx.Err() == nil {
203+
frame, err := framer.ReadFrame()
204+
if err != nil {
205+
return
206+
}
207+
switch frame := frame.(type) {
208+
case *http2.HeadersFrame:
209+
// When the client creates a stream, write a partial gRPC
210+
// message followed by an RST_STREAM.
211+
const messageLen = 2048
212+
buf := make([]byte, messageLen/2)
213+
// Write the gRPC message length header.
214+
binary.BigEndian.PutUint32(buf[1:5], uint32(messageLen))
215+
if err := framer.WriteData(1, false, buf); err != nil {
216+
return
217+
}
218+
framer.WriteRSTStream(1, http2.ErrCodeCancel)
219+
default:
220+
t.Logf("Server received frame: %v", frame)
221+
}
222+
}
223+
}()
224+
225+
// The server will send a partial gRPC message before cancelling the stream.
226+
// The client should get a gRPC status with code CANCELLED.
227+
client := testgrpc.NewTestServiceClient(cc)
228+
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Canceled {
229+
t.Fatalf("client.EmptyCall() returned %v; want status with code %v", err, codes.Canceled)
230+
}
231+
}

0 commit comments

Comments
 (0)