Skip to content

Commit af8bb07

Browse files
committed
Add test case
1 parent b1f73cd commit af8bb07

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

test/transport_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,20 @@ package test
1919

2020
import (
2121
"context"
22+
"encoding/binary"
2223
"io"
2324
"net"
2425
"sync"
2526
"testing"
2627

28+
"golang.org/x/net/http2"
2729
"google.golang.org/grpc"
2830
"google.golang.org/grpc/codes"
2931
"google.golang.org/grpc/credentials"
32+
"google.golang.org/grpc/credentials/insecure"
3033
"google.golang.org/grpc/internal/grpcsync"
3134
"google.golang.org/grpc/internal/stubserver"
35+
"google.golang.org/grpc/internal/testutils"
3236
"google.golang.org/grpc/internal/transport"
3337
"google.golang.org/grpc/status"
3438

@@ -153,3 +157,98 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
153157
t.Fatal("Timeout expired when waiting for first client transport to close")
154158
}
155159
}
160+
161+
// Tests that an RST_STREAM frame that causes an io.ErrUnexpectedEOF while
162+
// reading a gRPC message is correctly converted to a gRPC status with code
163+
// CANCELLED. The test sends a data frame with a partial gRPC message, followed
164+
// by an RST_STREAM frame with HTTP/2 code CANCELLED. The test asserts the
165+
// client receives the correct status.
166+
func (s) TestRSTDuringMessageRead(t *testing.T) {
167+
lis, err := testutils.LocalTCPListener()
168+
if err != nil {
169+
t.Fatal(err)
170+
}
171+
defer lis.Close()
172+
173+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
174+
defer cancel()
175+
cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
176+
if err != nil {
177+
t.Fatalf("grpc.NewClient(%s) = %v", lis.Addr().String(), err)
178+
}
179+
defer cc.Close()
180+
181+
go func() {
182+
conn, err := lis.Accept()
183+
if err != nil {
184+
t.Errorf("lis.Accept() = %v", err)
185+
return
186+
}
187+
defer conn.Close()
188+
framer := http2.NewFramer(conn, conn)
189+
190+
if _, err := io.ReadFull(conn, make([]byte, len(clientPreface))); err != nil {
191+
t.Errorf("Error while reading client preface: %v", err)
192+
return
193+
}
194+
if err := framer.WriteSettings(); err != nil {
195+
t.Errorf("Error while writing settings: %v", err)
196+
return
197+
}
198+
if err := framer.WriteSettingsAck(); err != nil {
199+
t.Errorf("Error while writing settings: %v", err)
200+
return
201+
}
202+
var mu sync.Mutex
203+
for {
204+
frame, err := framer.ReadFrame()
205+
if err != nil {
206+
return
207+
}
208+
switch frame := frame.(type) {
209+
case *http2.HeadersFrame:
210+
// When the client creates a stream, write a partial gRPC
211+
// message followed by an RST_STREAM.
212+
go func() {
213+
buf := make([]byte, 1024)
214+
// Write the gRPC message length header.
215+
binary.BigEndian.PutUint32(buf[1:5], 2048)
216+
mu.Lock()
217+
if err := framer.WriteData(1, false, buf); err != nil {
218+
mu.Unlock()
219+
return
220+
}
221+
framer.WriteRSTStream(1, http2.ErrCodeCancel)
222+
mu.Unlock()
223+
}()
224+
// Wait for the client to close before closing the connection.
225+
<-ctx.Done()
226+
return
227+
case *http2.RSTStreamFrame:
228+
if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeFlowControl {
229+
t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeFlowControl", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))
230+
}
231+
return
232+
case *http2.PingFrame:
233+
mu.Lock()
234+
framer.WritePing(true, frame.Data)
235+
mu.Unlock()
236+
default:
237+
t.Logf("Server received frame: %v", frame)
238+
}
239+
}
240+
}()
241+
242+
// The server will send a partial gRPC message before cancelling the stream.
243+
// The client should get a gRPC status with code CANCELLED.
244+
client := testgrpc.NewTestServiceClient(cc)
245+
_, err = client.EmptyCall(ctx, &testpb.Empty{})
246+
247+
s, ok := status.FromError(err)
248+
if !ok {
249+
t.Fatalf("client.EmptyCall() returned non-status error: %v", err)
250+
}
251+
if s.Code() != codes.Canceled {
252+
t.Fatalf("client.EmptyCall() returned status %v with code %v, want %v", s, s.Code(), codes.Canceled)
253+
}
254+
}

0 commit comments

Comments
 (0)