diff --git a/httpbin/handlers.go b/httpbin/handlers.go index 7ecb7bc3..a189a07f 100644 --- a/httpbin/handlers.go +++ b/httpbin/handlers.go @@ -1152,6 +1152,7 @@ func (h *HTTPBin) WebSocketEcho(w http.ResponseWriter, r *http.Request) { } ws := websocket.New(w, r, websocket.Limits{ + MaxDuration: h.MaxDuration, MaxFragmentSize: int(maxFragmentSize), MaxMessageSize: int(maxMessageSize), }) diff --git a/httpbin/websocket/websocket.go b/httpbin/websocket/websocket.go index e85f79a4..ebac941e 100644 --- a/httpbin/websocket/websocket.go +++ b/httpbin/websocket/websocket.go @@ -12,6 +12,7 @@ import ( "io" "net/http" "strings" + "time" "unicode/utf8" ) @@ -80,6 +81,7 @@ var EchoHandler Handler = func(ctx context.Context, msg *Message) (*Message, err // Limits define the limits imposed on a websocket connection. type Limits struct { + MaxDuration time.Duration MaxFragmentSize int MaxMessageSize int } @@ -88,6 +90,7 @@ type Limits struct { type WebSocket struct { w http.ResponseWriter r *http.Request + maxDuration time.Duration maxFragmentSize int maxMessageSize int handshook bool @@ -98,6 +101,7 @@ func New(w http.ResponseWriter, r *http.Request, limits Limits) *WebSocket { return &WebSocket{ w: w, r: r, + maxDuration: limits.MaxDuration, maxFragmentSize: limits.MaxFragmentSize, maxMessageSize: limits.MaxMessageSize, } @@ -152,6 +156,10 @@ func (s *WebSocket) Serve(handler Handler) { } defer conn.Close() + // best effort attempt to ensure that our websocket conenctions do not + // exceed the maximum request duration + conn.SetDeadline(time.Now().Add(s.maxDuration)) + // errors intentionally ignored here. it's serverLoop's responsibility to // properly close the websocket connection with a useful error message, and // any unexpected error returned from serverLoop is not actionable. diff --git a/httpbin/websocket/websocket_autobahn_test.go b/httpbin/websocket/websocket_autobahn_test.go index c1ae0c95..1ddde8c4 100644 --- a/httpbin/websocket/websocket_autobahn_test.go +++ b/httpbin/websocket/websocket_autobahn_test.go @@ -55,6 +55,7 @@ func TestWebSocketServer(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ws := websocket.New(w, r, websocket.Limits{ + MaxDuration: 30 * time.Second, MaxFragmentSize: 1024 * 1024 * 16, MaxMessageSize: 1024 * 1024 * 16, }) diff --git a/httpbin/websocket/websocket_test.go b/httpbin/websocket/websocket_test.go index cb1084e0..d6c02d10 100644 --- a/httpbin/websocket/websocket_test.go +++ b/httpbin/websocket/websocket_test.go @@ -3,10 +3,15 @@ package websocket_test import ( "bufio" "fmt" + "io" "net" "net/http" "net/http/httptest" + "os" + "strings" + "sync" "testing" + "time" "github.com/mccutchen/go-httpbin/v2/httpbin/websocket" "github.com/mccutchen/go-httpbin/v2/internal/testing/assert" @@ -220,6 +225,153 @@ func TestHandshakeOrder(t *testing.T) { }) } +func TestConnectionLimits(t *testing.T) { + t.Run("maximum request duration is enforced", func(t *testing.T) { + t.Parallel() + + maxDuration := 500 * time.Millisecond + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws := websocket.New(w, r, websocket.Limits{ + MaxDuration: maxDuration, + // TODO: test these limits as well + MaxFragmentSize: 128, + MaxMessageSize: 256, + }) + if err := ws.Handshake(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ws.Serve(websocket.EchoHandler) + })) + defer srv.Close() + + conn, err := net.Dial("tcp", srv.Listener.Addr().String()) + assert.NilError(t, err) + defer conn.Close() + + reqParts := []string{ + "GET /websocket/echo HTTP/1.1", + "Host: test", + "Connection: upgrade", + "Upgrade: websocket", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13", + } + reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n") + t.Logf("raw request:\n%q", reqBytes) + + // first, we write the request line and headers, which should cause the + // server to respond with a 101 Switching Protocols response. + { + n, err := conn.Write(reqBytes) + assert.NilError(t, err) + assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written") + + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + assert.NilError(t, err) + assert.StatusCode(t, resp, http.StatusSwitchingProtocols) + } + + // next, we try to read from the connection, expecting the connection + // to be closed after roughly maxDuration seconds + { + start := time.Now() + _, err := conn.Read(make([]byte, 1)) + elapsed := time.Since(start) + + assert.Error(t, err, io.EOF) + assert.RoughDuration(t, elapsed, maxDuration, 25*time.Millisecond) + } + }) + + t.Run("client closing connection", func(t *testing.T) { + t.Parallel() + + // the client will close the connection well before the server closes + // the connection. make sure the server properly handles the client + // closure. + var ( + clientTimeout = 100 * time.Millisecond + serverTimeout = time.Hour // should never be reached + elapsedClientTime time.Duration + elapsedServerTime time.Duration + wg sync.WaitGroup + ) + + wg.Add(1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer wg.Done() + start := time.Now() + ws := websocket.New(w, r, websocket.Limits{ + MaxDuration: serverTimeout, + MaxFragmentSize: 128, + MaxMessageSize: 256, + }) + if err := ws.Handshake(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ws.Serve(websocket.EchoHandler) + elapsedServerTime = time.Since(start) + })) + defer srv.Close() + + conn, err := net.Dial("tcp", srv.Listener.Addr().String()) + assert.NilError(t, err) + defer conn.Close() + + // should cause the client end of the connection to close well before + // the max request time configured above + conn.SetDeadline(time.Now().Add(clientTimeout)) + + reqParts := []string{ + "GET /websocket/echo HTTP/1.1", + "Host: test", + "Connection: upgrade", + "Upgrade: websocket", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13", + } + reqBytes := []byte(strings.Join(reqParts, "\r\n") + "\r\n\r\n") + t.Logf("raw request:\n%q", reqBytes) + + // first, we write the request line and headers, which should cause the + // server to respond with a 101 Switching Protocols response. + { + n, err := conn.Write(reqBytes) + assert.NilError(t, err) + assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written") + + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + assert.NilError(t, err) + assert.StatusCode(t, resp, http.StatusSwitchingProtocols) + } + + // next, we try to read from the connection, expecting the connection + // to be closed after roughly clientTimeout seconds. + // + // the server should detect the closed connection and abort the + // handler, also after roughly clientTimeout seconds. + { + start := time.Now() + _, err := conn.Read(make([]byte, 1)) + elapsedClientTime = time.Since(start) + + // close client connection, which should interrupt the server's + // blocking read call on the connection + conn.Close() + + assert.Equal(t, os.IsTimeout(err), true, "expected timeout error") + assert.RoughDuration(t, elapsedClientTime, clientTimeout, 10*time.Millisecond) + + // wait for the server to finish + wg.Wait() + assert.RoughDuration(t, elapsedServerTime, clientTimeout, 10*time.Millisecond) + } + }) +} + // brokenHijackResponseWriter implements just enough to satisfy the // http.ResponseWriter and http.Hijacker interfaces and get through the // handshake before failing to actually hijack the connection.