@@ -19,16 +19,20 @@ package test
19
19
20
20
import (
21
21
"context"
22
+ "encoding/binary"
22
23
"io"
23
24
"net"
24
25
"sync"
25
26
"testing"
26
27
28
+ "golang.org/x/net/http2"
27
29
"google.golang.org/grpc"
28
30
"google.golang.org/grpc/codes"
29
31
"google.golang.org/grpc/credentials"
32
+ "google.golang.org/grpc/credentials/insecure"
30
33
"google.golang.org/grpc/internal/grpcsync"
31
34
"google.golang.org/grpc/internal/stubserver"
35
+ "google.golang.org/grpc/internal/testutils"
32
36
"google.golang.org/grpc/internal/transport"
33
37
"google.golang.org/grpc/status"
34
38
@@ -153,3 +157,75 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
153
157
t .Fatal ("Timeout expired when waiting for first client transport to close" )
154
158
}
155
159
}
160
+
161
+ // Tests that an RST_STREAM frame that causes an io.ErrUnexpectedEOF while
162
+ // reading a gRPC message is correctly converted to a gRPC status with code
163
+ // CANCELLED. The test sends a data frame with a partial gRPC message, followed
164
+ // by an RST_STREAM frame with HTTP/2 code CANCELLED. The test asserts the
165
+ // client receives the correct status.
166
+ func (s ) TestRSTDuringMessageRead (t * testing.T ) {
167
+ lis , err := testutils .LocalTCPListener ()
168
+ if err != nil {
169
+ t .Fatal (err )
170
+ }
171
+ defer lis .Close ()
172
+
173
+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
174
+ defer cancel ()
175
+ cc , err := grpc .NewClient (lis .Addr ().String (), grpc .WithTransportCredentials (insecure .NewCredentials ()))
176
+ if err != nil {
177
+ t .Fatalf ("grpc.NewClient(%s) = %v" , lis .Addr ().String (), err )
178
+ }
179
+ defer cc .Close ()
180
+
181
+ go func () {
182
+ conn , err := lis .Accept ()
183
+ if err != nil {
184
+ t .Errorf ("lis.Accept() = %v" , err )
185
+ return
186
+ }
187
+ defer conn .Close ()
188
+ framer := http2 .NewFramer (conn , conn )
189
+
190
+ if _ , err := io .ReadFull (conn , make ([]byte , len (clientPreface ))); err != nil {
191
+ t .Errorf ("Error while reading client preface: %v" , err )
192
+ return
193
+ }
194
+ if err := framer .WriteSettings (); err != nil {
195
+ t .Errorf ("Error while writing settings: %v" , err )
196
+ return
197
+ }
198
+ if err := framer .WriteSettingsAck (); err != nil {
199
+ t .Errorf ("Error while writing settings: %v" , err )
200
+ return
201
+ }
202
+ for ctx .Err () == nil {
203
+ frame , err := framer .ReadFrame ()
204
+ if err != nil {
205
+ return
206
+ }
207
+ switch frame := frame .(type ) {
208
+ case * http2.HeadersFrame :
209
+ // When the client creates a stream, write a partial gRPC
210
+ // message followed by an RST_STREAM.
211
+ const messageLen = 2048
212
+ buf := make ([]byte , messageLen / 2 )
213
+ // Write the gRPC message length header.
214
+ binary .BigEndian .PutUint32 (buf [1 :5 ], uint32 (messageLen ))
215
+ if err := framer .WriteData (1 , false , buf ); err != nil {
216
+ return
217
+ }
218
+ framer .WriteRSTStream (1 , http2 .ErrCodeCancel )
219
+ default :
220
+ t .Logf ("Server received frame: %v" , frame )
221
+ }
222
+ }
223
+ }()
224
+
225
+ // The server will send a partial gRPC message before cancelling the stream.
226
+ // The client should get a gRPC status with code CANCELLED.
227
+ client := testgrpc .NewTestServiceClient (cc )
228
+ if _ , err := client .EmptyCall (ctx , & testpb.Empty {}); status .Code (err ) != codes .Canceled {
229
+ t .Fatalf ("client.EmptyCall() returned %v; want status with code %v" , err , codes .Canceled )
230
+ }
231
+ }
0 commit comments