@@ -3,10 +3,15 @@ package websocket_test
33import (
44 "bufio"
55 "fmt"
6+ "io"
67 "net"
78 "net/http"
89 "net/http/httptest"
10+ "os"
11+ "strings"
12+ "sync"
913 "testing"
14+ "time"
1015
1116 "github.com/mccutchen/go-httpbin/v2/httpbin/websocket"
1217 "github.com/mccutchen/go-httpbin/v2/internal/testing/assert"
@@ -220,6 +225,153 @@ func TestHandshakeOrder(t *testing.T) {
220225 })
221226}
222227
228+ func TestConnectionLimits (t * testing.T ) {
229+ t .Run ("maximum request duration is enforced" , func (t * testing.T ) {
230+ t .Parallel ()
231+
232+ maxDuration := 500 * time .Millisecond
233+
234+ srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
235+ ws := websocket .New (w , r , websocket.Limits {
236+ MaxDuration : maxDuration ,
237+ // TODO: test these limits as well
238+ MaxFragmentSize : 128 ,
239+ MaxMessageSize : 256 ,
240+ })
241+ if err := ws .Handshake (); err != nil {
242+ http .Error (w , err .Error (), http .StatusBadRequest )
243+ return
244+ }
245+ ws .Serve (websocket .EchoHandler )
246+ }))
247+ defer srv .Close ()
248+
249+ conn , err := net .Dial ("tcp" , srv .Listener .Addr ().String ())
250+ assert .NilError (t , err )
251+ defer conn .Close ()
252+
253+ reqParts := []string {
254+ "GET /websocket/echo HTTP/1.1" ,
255+ "Host: test" ,
256+ "Connection: upgrade" ,
257+ "Upgrade: websocket" ,
258+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==" ,
259+ "Sec-WebSocket-Version: 13" ,
260+ }
261+ reqBytes := []byte (strings .Join (reqParts , "\r \n " ) + "\r \n \r \n " )
262+ t .Logf ("raw request:\n %q" , reqBytes )
263+
264+ // first, we write the request line and headers, which should cause the
265+ // server to respond with a 101 Switching Protocols response.
266+ {
267+ n , err := conn .Write (reqBytes )
268+ assert .NilError (t , err )
269+ assert .Equal (t , n , len (reqBytes ), "incorrect number of bytes written" )
270+
271+ resp , err := http .ReadResponse (bufio .NewReader (conn ), nil )
272+ assert .NilError (t , err )
273+ assert .StatusCode (t , resp , http .StatusSwitchingProtocols )
274+ }
275+
276+ // next, we try to read from the connection, expecting the connection
277+ // to be closed after roughly maxDuration seconds
278+ {
279+ start := time .Now ()
280+ _ , err := conn .Read (make ([]byte , 1 ))
281+ elapsed := time .Since (start )
282+
283+ assert .Error (t , err , io .EOF )
284+ assert .RoughDuration (t , elapsed , maxDuration , 25 * time .Millisecond )
285+ }
286+ })
287+
288+ t .Run ("client closing connection" , func (t * testing.T ) {
289+ t .Parallel ()
290+
291+ // the client will close the connection well before the server closes
292+ // the connection. make sure the server properly handles the client
293+ // closure.
294+ var (
295+ clientTimeout = 100 * time .Millisecond
296+ serverTimeout = time .Hour // should never be reached
297+ elapsedClientTime time.Duration
298+ elapsedServerTime time.Duration
299+ wg sync.WaitGroup
300+ )
301+
302+ wg .Add (1 )
303+ srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
304+ defer wg .Done ()
305+ start := time .Now ()
306+ ws := websocket .New (w , r , websocket.Limits {
307+ MaxDuration : serverTimeout ,
308+ MaxFragmentSize : 128 ,
309+ MaxMessageSize : 256 ,
310+ })
311+ if err := ws .Handshake (); err != nil {
312+ http .Error (w , err .Error (), http .StatusBadRequest )
313+ return
314+ }
315+ ws .Serve (websocket .EchoHandler )
316+ elapsedServerTime = time .Since (start )
317+ }))
318+ defer srv .Close ()
319+
320+ conn , err := net .Dial ("tcp" , srv .Listener .Addr ().String ())
321+ assert .NilError (t , err )
322+ defer conn .Close ()
323+
324+ // should cause the client end of the connection to close well before
325+ // the max request time configured above
326+ conn .SetDeadline (time .Now ().Add (clientTimeout ))
327+
328+ reqParts := []string {
329+ "GET /websocket/echo HTTP/1.1" ,
330+ "Host: test" ,
331+ "Connection: upgrade" ,
332+ "Upgrade: websocket" ,
333+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==" ,
334+ "Sec-WebSocket-Version: 13" ,
335+ }
336+ reqBytes := []byte (strings .Join (reqParts , "\r \n " ) + "\r \n \r \n " )
337+ t .Logf ("raw request:\n %q" , reqBytes )
338+
339+ // first, we write the request line and headers, which should cause the
340+ // server to respond with a 101 Switching Protocols response.
341+ {
342+ n , err := conn .Write (reqBytes )
343+ assert .NilError (t , err )
344+ assert .Equal (t , n , len (reqBytes ), "incorrect number of bytes written" )
345+
346+ resp , err := http .ReadResponse (bufio .NewReader (conn ), nil )
347+ assert .NilError (t , err )
348+ assert .StatusCode (t , resp , http .StatusSwitchingProtocols )
349+ }
350+
351+ // next, we try to read from the connection, expecting the connection
352+ // to be closed after roughly clientTimeout seconds.
353+ //
354+ // the server should detect the closed connection and abort the
355+ // handler, also after roughly clientTimeout seconds.
356+ {
357+ start := time .Now ()
358+ _ , err := conn .Read (make ([]byte , 1 ))
359+ elapsedClientTime = time .Since (start )
360+
361+ // close client connection, which should interrupt the server's
362+ // blocking read call on the connection
363+ conn .Close ()
364+
365+ assert .Equal (t , os .IsTimeout (err ), true , "expected timeout error" )
366+ assert .RoughDuration (t , elapsedClientTime , clientTimeout , 10 * time .Millisecond )
367+
368+ // wait for the server to finish
369+ wg .Wait ()
370+ assert .RoughDuration (t , elapsedServerTime , clientTimeout , 10 * time .Millisecond )
371+ }
372+ })
373+ }
374+
223375// brokenHijackResponseWriter implements just enough to satisfy the
224376// http.ResponseWriter and http.Hijacker interfaces and get through the
225377// handshake before failing to actually hijack the connection.
0 commit comments