Skip to content

Commit e0324b1

Browse files
authored
fix: ensure websocket conns respect max duration (#156)
1 parent 1c61db6 commit e0324b1

File tree

4 files changed

+162
-0
lines changed

4 files changed

+162
-0
lines changed

httpbin/handlers.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,7 @@ func (h *HTTPBin) WebSocketEcho(w http.ResponseWriter, r *http.Request) {
11521152
}
11531153

11541154
ws := websocket.New(w, r, websocket.Limits{
1155+
MaxDuration: h.MaxDuration,
11551156
MaxFragmentSize: int(maxFragmentSize),
11561157
MaxMessageSize: int(maxMessageSize),
11571158
})

httpbin/websocket/websocket.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"io"
1313
"net/http"
1414
"strings"
15+
"time"
1516
"unicode/utf8"
1617
)
1718

@@ -80,6 +81,7 @@ var EchoHandler Handler = func(ctx context.Context, msg *Message) (*Message, err
8081

8182
// Limits define the limits imposed on a websocket connection.
8283
type Limits struct {
84+
MaxDuration time.Duration
8385
MaxFragmentSize int
8486
MaxMessageSize int
8587
}
@@ -88,6 +90,7 @@ type Limits struct {
8890
type WebSocket struct {
8991
w http.ResponseWriter
9092
r *http.Request
93+
maxDuration time.Duration
9194
maxFragmentSize int
9295
maxMessageSize int
9396
handshook bool
@@ -98,6 +101,7 @@ func New(w http.ResponseWriter, r *http.Request, limits Limits) *WebSocket {
98101
return &WebSocket{
99102
w: w,
100103
r: r,
104+
maxDuration: limits.MaxDuration,
101105
maxFragmentSize: limits.MaxFragmentSize,
102106
maxMessageSize: limits.MaxMessageSize,
103107
}
@@ -152,6 +156,10 @@ func (s *WebSocket) Serve(handler Handler) {
152156
}
153157
defer conn.Close()
154158

159+
// best effort attempt to ensure that our websocket conenctions do not
160+
// exceed the maximum request duration
161+
conn.SetDeadline(time.Now().Add(s.maxDuration))
162+
155163
// errors intentionally ignored here. it's serverLoop's responsibility to
156164
// properly close the websocket connection with a useful error message, and
157165
// any unexpected error returned from serverLoop is not actionable.

httpbin/websocket/websocket_autobahn_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func TestWebSocketServer(t *testing.T) {
5555

5656
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5757
ws := websocket.New(w, r, websocket.Limits{
58+
MaxDuration: 30 * time.Second,
5859
MaxFragmentSize: 1024 * 1024 * 16,
5960
MaxMessageSize: 1024 * 1024 * 16,
6061
})

httpbin/websocket/websocket_test.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@ package websocket_test
33
import (
44
"bufio"
55
"fmt"
6+
"io"
67
"net"
78
"net/http"
89
"net/http/httptest"
10+
"os"
11+
"strings"
12+
"sync"
913
"testing"
14+
"time"
1015

1116
"github.com/mccutchen/go-httpbin/v2/httpbin/websocket"
1217
"github.com/mccutchen/go-httpbin/v2/internal/testing/assert"
@@ -220,6 +225,153 @@ func TestHandshakeOrder(t *testing.T) {
220225
})
221226
}
222227

228+
func TestConnectionLimits(t *testing.T) {
229+
t.Run("maximum request duration is enforced", func(t *testing.T) {
230+
t.Parallel()
231+
232+
maxDuration := 500 * time.Millisecond
233+
234+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
235+
ws := websocket.New(w, r, websocket.Limits{
236+
MaxDuration: maxDuration,
237+
// TODO: test these limits as well
238+
MaxFragmentSize: 128,
239+
MaxMessageSize: 256,
240+
})
241+
if err := ws.Handshake(); err != nil {
242+
http.Error(w, err.Error(), http.StatusBadRequest)
243+
return
244+
}
245+
ws.Serve(websocket.EchoHandler)
246+
}))
247+
defer srv.Close()
248+
249+
conn, err := net.Dial("tcp", srv.Listener.Addr().String())
250+
assert.NilError(t, err)
251+
defer conn.Close()
252+
253+
reqParts := []string{
254+
"GET /websocket/echo HTTP/1.1",
255+
"Host: test",
256+
"Connection: upgrade",
257+
"Upgrade: websocket",
258+
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
259+
"Sec-WebSocket-Version: 13",
260+
}
261+
reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n")
262+
t.Logf("raw request:\n%q", reqBytes)
263+
264+
// first, we write the request line and headers, which should cause the
265+
// server to respond with a 101 Switching Protocols response.
266+
{
267+
n, err := conn.Write(reqBytes)
268+
assert.NilError(t, err)
269+
assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written")
270+
271+
resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
272+
assert.NilError(t, err)
273+
assert.StatusCode(t, resp, http.StatusSwitchingProtocols)
274+
}
275+
276+
// next, we try to read from the connection, expecting the connection
277+
// to be closed after roughly maxDuration seconds
278+
{
279+
start := time.Now()
280+
_, err := conn.Read(make([]byte, 1))
281+
elapsed := time.Since(start)
282+
283+
assert.Error(t, err, io.EOF)
284+
assert.RoughDuration(t, elapsed, maxDuration, 25*time.Millisecond)
285+
}
286+
})
287+
288+
t.Run("client closing connection", func(t *testing.T) {
289+
t.Parallel()
290+
291+
// the client will close the connection well before the server closes
292+
// the connection. make sure the server properly handles the client
293+
// closure.
294+
var (
295+
clientTimeout = 100 * time.Millisecond
296+
serverTimeout = time.Hour // should never be reached
297+
elapsedClientTime time.Duration
298+
elapsedServerTime time.Duration
299+
wg sync.WaitGroup
300+
)
301+
302+
wg.Add(1)
303+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
304+
defer wg.Done()
305+
start := time.Now()
306+
ws := websocket.New(w, r, websocket.Limits{
307+
MaxDuration: serverTimeout,
308+
MaxFragmentSize: 128,
309+
MaxMessageSize: 256,
310+
})
311+
if err := ws.Handshake(); err != nil {
312+
http.Error(w, err.Error(), http.StatusBadRequest)
313+
return
314+
}
315+
ws.Serve(websocket.EchoHandler)
316+
elapsedServerTime = time.Since(start)
317+
}))
318+
defer srv.Close()
319+
320+
conn, err := net.Dial("tcp", srv.Listener.Addr().String())
321+
assert.NilError(t, err)
322+
defer conn.Close()
323+
324+
// should cause the client end of the connection to close well before
325+
// the max request time configured above
326+
conn.SetDeadline(time.Now().Add(clientTimeout))
327+
328+
reqParts := []string{
329+
"GET /websocket/echo HTTP/1.1",
330+
"Host: test",
331+
"Connection: upgrade",
332+
"Upgrade: websocket",
333+
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
334+
"Sec-WebSocket-Version: 13",
335+
}
336+
reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n")
337+
t.Logf("raw request:\n%q", reqBytes)
338+
339+
// first, we write the request line and headers, which should cause the
340+
// server to respond with a 101 Switching Protocols response.
341+
{
342+
n, err := conn.Write(reqBytes)
343+
assert.NilError(t, err)
344+
assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written")
345+
346+
resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
347+
assert.NilError(t, err)
348+
assert.StatusCode(t, resp, http.StatusSwitchingProtocols)
349+
}
350+
351+
// next, we try to read from the connection, expecting the connection
352+
// to be closed after roughly clientTimeout seconds.
353+
//
354+
// the server should detect the closed connection and abort the
355+
// handler, also after roughly clientTimeout seconds.
356+
{
357+
start := time.Now()
358+
_, err := conn.Read(make([]byte, 1))
359+
elapsedClientTime = time.Since(start)
360+
361+
// close client connection, which should interrupt the server's
362+
// blocking read call on the connection
363+
conn.Close()
364+
365+
assert.Equal(t, os.IsTimeout(err), true, "expected timeout error")
366+
assert.RoughDuration(t, elapsedClientTime, clientTimeout, 10*time.Millisecond)
367+
368+
// wait for the server to finish
369+
wg.Wait()
370+
assert.RoughDuration(t, elapsedServerTime, clientTimeout, 10*time.Millisecond)
371+
}
372+
})
373+
}
374+
223375
// brokenHijackResponseWriter implements just enough to satisfy the
224376
// http.ResponseWriter and http.Hijacker interfaces and get through the
225377
// handshake before failing to actually hijack the connection.

0 commit comments

Comments
 (0)