@@ -3,10 +3,15 @@ package websocket_test
3
3
import (
4
4
"bufio"
5
5
"fmt"
6
+ "io"
6
7
"net"
7
8
"net/http"
8
9
"net/http/httptest"
10
+ "os"
11
+ "strings"
12
+ "sync"
9
13
"testing"
14
+ "time"
10
15
11
16
"github.com/mccutchen/go-httpbin/v2/httpbin/websocket"
12
17
"github.com/mccutchen/go-httpbin/v2/internal/testing/assert"
@@ -220,6 +225,153 @@ func TestHandshakeOrder(t *testing.T) {
220
225
})
221
226
}
222
227
228
+ func TestConnectionLimits (t * testing.T ) {
229
+ t .Run ("maximum request duration is enforced" , func (t * testing.T ) {
230
+ t .Parallel ()
231
+
232
+ maxDuration := 500 * time .Millisecond
233
+
234
+ srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
235
+ ws := websocket .New (w , r , websocket.Limits {
236
+ MaxDuration : maxDuration ,
237
+ // TODO: test these limits as well
238
+ MaxFragmentSize : 128 ,
239
+ MaxMessageSize : 256 ,
240
+ })
241
+ if err := ws .Handshake (); err != nil {
242
+ http .Error (w , err .Error (), http .StatusBadRequest )
243
+ return
244
+ }
245
+ ws .Serve (websocket .EchoHandler )
246
+ }))
247
+ defer srv .Close ()
248
+
249
+ conn , err := net .Dial ("tcp" , srv .Listener .Addr ().String ())
250
+ assert .NilError (t , err )
251
+ defer conn .Close ()
252
+
253
+ reqParts := []string {
254
+ "GET /websocket/echo HTTP/1.1" ,
255
+ "Host: test" ,
256
+ "Connection: upgrade" ,
257
+ "Upgrade: websocket" ,
258
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==" ,
259
+ "Sec-WebSocket-Version: 13" ,
260
+ }
261
+ reqBytes := []byte (strings .Join (reqParts , "\r \n " ) + "\r \n \r \n " )
262
+ t .Logf ("raw request:\n %q" , reqBytes )
263
+
264
+ // first, we write the request line and headers, which should cause the
265
+ // server to respond with a 101 Switching Protocols response.
266
+ {
267
+ n , err := conn .Write (reqBytes )
268
+ assert .NilError (t , err )
269
+ assert .Equal (t , n , len (reqBytes ), "incorrect number of bytes written" )
270
+
271
+ resp , err := http .ReadResponse (bufio .NewReader (conn ), nil )
272
+ assert .NilError (t , err )
273
+ assert .StatusCode (t , resp , http .StatusSwitchingProtocols )
274
+ }
275
+
276
+ // next, we try to read from the connection, expecting the connection
277
+ // to be closed after roughly maxDuration seconds
278
+ {
279
+ start := time .Now ()
280
+ _ , err := conn .Read (make ([]byte , 1 ))
281
+ elapsed := time .Since (start )
282
+
283
+ assert .Error (t , err , io .EOF )
284
+ assert .RoughDuration (t , elapsed , maxDuration , 25 * time .Millisecond )
285
+ }
286
+ })
287
+
288
+ t .Run ("client closing connection" , func (t * testing.T ) {
289
+ t .Parallel ()
290
+
291
+ // the client will close the connection well before the server closes
292
+ // the connection. make sure the server properly handles the client
293
+ // closure.
294
+ var (
295
+ clientTimeout = 100 * time .Millisecond
296
+ serverTimeout = time .Hour // should never be reached
297
+ elapsedClientTime time.Duration
298
+ elapsedServerTime time.Duration
299
+ wg sync.WaitGroup
300
+ )
301
+
302
+ wg .Add (1 )
303
+ srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
304
+ defer wg .Done ()
305
+ start := time .Now ()
306
+ ws := websocket .New (w , r , websocket.Limits {
307
+ MaxDuration : serverTimeout ,
308
+ MaxFragmentSize : 128 ,
309
+ MaxMessageSize : 256 ,
310
+ })
311
+ if err := ws .Handshake (); err != nil {
312
+ http .Error (w , err .Error (), http .StatusBadRequest )
313
+ return
314
+ }
315
+ ws .Serve (websocket .EchoHandler )
316
+ elapsedServerTime = time .Since (start )
317
+ }))
318
+ defer srv .Close ()
319
+
320
+ conn , err := net .Dial ("tcp" , srv .Listener .Addr ().String ())
321
+ assert .NilError (t , err )
322
+ defer conn .Close ()
323
+
324
+ // should cause the client end of the connection to close well before
325
+ // the max request time configured above
326
+ conn .SetDeadline (time .Now ().Add (clientTimeout ))
327
+
328
+ reqParts := []string {
329
+ "GET /websocket/echo HTTP/1.1" ,
330
+ "Host: test" ,
331
+ "Connection: upgrade" ,
332
+ "Upgrade: websocket" ,
333
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==" ,
334
+ "Sec-WebSocket-Version: 13" ,
335
+ }
336
+ reqBytes := []byte (strings .Join (reqParts , "\r \n " ) + "\r \n \r \n " )
337
+ t .Logf ("raw request:\n %q" , reqBytes )
338
+
339
+ // first, we write the request line and headers, which should cause the
340
+ // server to respond with a 101 Switching Protocols response.
341
+ {
342
+ n , err := conn .Write (reqBytes )
343
+ assert .NilError (t , err )
344
+ assert .Equal (t , n , len (reqBytes ), "incorrect number of bytes written" )
345
+
346
+ resp , err := http .ReadResponse (bufio .NewReader (conn ), nil )
347
+ assert .NilError (t , err )
348
+ assert .StatusCode (t , resp , http .StatusSwitchingProtocols )
349
+ }
350
+
351
+ // next, we try to read from the connection, expecting the connection
352
+ // to be closed after roughly clientTimeout seconds.
353
+ //
354
+ // the server should detect the closed connection and abort the
355
+ // handler, also after roughly clientTimeout seconds.
356
+ {
357
+ start := time .Now ()
358
+ _ , err := conn .Read (make ([]byte , 1 ))
359
+ elapsedClientTime = time .Since (start )
360
+
361
+ // close client connection, which should interrupt the server's
362
+ // blocking read call on the connection
363
+ conn .Close ()
364
+
365
+ assert .Equal (t , os .IsTimeout (err ), true , "expected timeout error" )
366
+ assert .RoughDuration (t , elapsedClientTime , clientTimeout , 10 * time .Millisecond )
367
+
368
+ // wait for the server to finish
369
+ wg .Wait ()
370
+ assert .RoughDuration (t , elapsedServerTime , clientTimeout , 10 * time .Millisecond )
371
+ }
372
+ })
373
+ }
374
+
223
375
// brokenHijackResponseWriter implements just enough to satisfy the
224
376
// http.ResponseWriter and http.Hijacker interfaces and get through the
225
377
// handshake before failing to actually hijack the connection.
0 commit comments