Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add /sse endpoint to test Server-Sent Events #160

Merged
merged 2 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions httpbin/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/url"
Expand Down Expand Up @@ -1108,6 +1109,115 @@ func (h *HTTPBin) Hostname(w http.ResponseWriter, _ *http.Request) {
})
}

// SSE writes a stream of events over a duration after an optional
// initial delay.
func (h *HTTPBin) SSE(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
var (
count = h.DefaultParams.SSECount
duration = h.DefaultParams.SSEDuration
delay = h.DefaultParams.SSEDelay
err error
)

if userCount := q.Get("count"); userCount != "" {
count, err = strconv.Atoi(userCount)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %w", err))
return
}
if count < 1 || int64(count) > h.maxSSECount {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: must in range [1, %d]", h.maxSSECount))
return
}
}

if userDuration := q.Get("duration"); userDuration != "" {
duration, err = parseBoundedDuration(userDuration, 1, h.MaxDuration)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid duration: %w", err))
return
}
}

if userDelay := q.Get("delay"); userDelay != "" {
delay, err = parseBoundedDuration(userDelay, 0, h.MaxDuration)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid delay: %w", err))
return
}
}

if duration+delay > h.MaxDuration {
http.Error(w, "Too much time", http.StatusBadRequest)
return
}

pause := duration
if count > 1 {
// compensate for lack of pause after final write (i.e. if we're
// writing 10 events, we will only pause 9 times)
pause = duration / time.Duration(count-1)
}

// Initial delay before we send any response data
if delay > 0 {
select {
case <-time.After(delay):
// ok
case <-r.Context().Done():
w.WriteHeader(499) // "Client Closed Request" https://httpstatuses.com/499
return
}
}

w.Header().Set("Content-Type", sseContentType)
w.WriteHeader(http.StatusOK)

flusher := w.(http.Flusher)

// special case when we only have one event to write
if count == 1 {
writeServerSentEvent(w, 0, time.Now())
flusher.Flush()
return
}

ticker := time.NewTicker(pause)
defer ticker.Stop()

for i := 0; i < count; i++ {
writeServerSentEvent(w, i, time.Now())
flusher.Flush()

// don't pause after last byte
if i == count-1 {
return
}

select {
case <-ticker.C:
// ok
case <-r.Context().Done():
return
}
}
}

// writeServerSentEvent writes the bytes that constitute a single server-sent
// event message, including both the event type and data.
func writeServerSentEvent(dst io.Writer, id int, ts time.Time) {
dst.Write([]byte("event: ping\n"))
dst.Write([]byte("data: "))
json.NewEncoder(dst).Encode(serverSentEvent{
ID: id,
Timestamp: ts.UnixMilli(),
})
// each SSE ends with two newlines (\n\n), the first of which is written
// automatically by json.NewEncoder().Encode()
dst.Write([]byte("\n"))
}

// WebSocketEcho - simple websocket echo server, where the max fragment size
// and max message size can be controlled by clients.
func (h *HTTPBin) WebSocketEcho(w http.ResponseWriter, r *http.Request) {
Expand Down
244 changes: 244 additions & 0 deletions httpbin/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ func createApp(opts ...OptionFunc) *HTTPBin {
DripDelay: 0,
DripDuration: 100 * time.Millisecond,
DripNumBytes: 10,
SSECount: 10,
SSEDelay: 0,
SSEDuration: 100 * time.Millisecond,
}),
WithMaxBodySize(maxBodySize),
WithMaxDuration(maxDuration),
Expand Down Expand Up @@ -2957,6 +2960,246 @@ func TestHostname(t *testing.T) {
})
}

func TestSSE(t *testing.T) {
t.Parallel()

parseServerSentEvent := func(t *testing.T, buf *bufio.Reader) (serverSentEvent, error) {
t.Helper()

// match "event: ping" line
eventLine, err := buf.ReadBytes('\n')
if err != nil {
return serverSentEvent{}, err
}
_, eventType, _ := bytes.Cut(eventLine, []byte(":"))
assert.Equal(t, string(bytes.TrimSpace(eventType)), "ping", "unexpected event type")

// match "data: {...}" line
dataLine, err := buf.ReadBytes('\n')
if err != nil {
return serverSentEvent{}, err
}
_, data, _ := bytes.Cut(dataLine, []byte(":"))
var event serverSentEvent
assert.NilError(t, json.Unmarshal(data, &event))

// match newline after event data
b, err := buf.ReadByte()
if err != nil && err != io.EOF {
assert.NilError(t, err)
}
if b != '\n' {
t.Fatalf("expected newline after event data, got %q", b)
}

return event, nil
}

parseServerSentEventStream := func(t *testing.T, resp *http.Response) []serverSentEvent {
t.Helper()
buf := bufio.NewReader(resp.Body)
var events []serverSentEvent
for {
event, err := parseServerSentEvent(t, buf)
if err == io.EOF {
break
}
assert.NilError(t, err)
events = append(events, event)
}
return events
}

okTests := []struct {
params *url.Values
duration time.Duration
count int
}{
// there are useful defaults for all values
{&url.Values{}, 0, 10},

// go-style durations are accepted
{&url.Values{"duration": {"5ms"}}, 5 * time.Millisecond, 10},
{&url.Values{"duration": {"10ns"}}, 0, 10},
{&url.Values{"delay": {"5ms"}}, 5 * time.Millisecond, 10},
{&url.Values{"delay": {"0h"}}, 0, 10},

// or floating point seconds
{&url.Values{"duration": {"0.25"}}, 250 * time.Millisecond, 10},
{&url.Values{"duration": {"1"}}, 1 * time.Second, 10},
{&url.Values{"delay": {"0.25"}}, 250 * time.Millisecond, 10},
{&url.Values{"delay": {"0"}}, 0, 10},

{&url.Values{"count": {"1"}}, 0, 1},
{&url.Values{"count": {"011"}}, 0, 11},
{&url.Values{"count": {fmt.Sprintf("%d", app.maxSSECount)}}, 0, int(app.maxSSECount)},

{&url.Values{"duration": {"250ms"}, "delay": {"250ms"}}, 500 * time.Millisecond, 10},
{&url.Values{"duration": {"250ms"}, "delay": {"0.25s"}}, 500 * time.Millisecond, 10},
}
for _, test := range okTests {
test := test
t.Run(fmt.Sprintf("ok/%s", test.params.Encode()), func(t *testing.T) {
t.Parallel()

url := "/sse?" + test.params.Encode()

start := time.Now()
req := newTestRequest(t, "GET", url)
resp := must.DoReq(t, client, req)
assert.StatusCode(t, resp, http.StatusOK)
events := parseServerSentEventStream(t, resp)

if elapsed := time.Since(start); elapsed < test.duration {
t.Fatalf("expected minimum duration of %s, request took %s", test.duration, elapsed)
}
assert.ContentType(t, resp, sseContentType)
assert.DeepEqual(t, resp.TransferEncoding, []string{"chunked"}, "unexpected Transfer-Encoding header")
assert.Equal(t, len(events), test.count, "unexpected number of events")
})
}

badTests := []struct {
params *url.Values
code int
}{
{&url.Values{"duration": {"0"}}, http.StatusBadRequest},
{&url.Values{"duration": {"0s"}}, http.StatusBadRequest},
{&url.Values{"duration": {"1m"}}, http.StatusBadRequest},
{&url.Values{"duration": {"-1ms"}}, http.StatusBadRequest},
{&url.Values{"duration": {"1001"}}, http.StatusBadRequest},
{&url.Values{"duration": {"-1"}}, http.StatusBadRequest},
{&url.Values{"duration": {"foo"}}, http.StatusBadRequest},

{&url.Values{"delay": {"1m"}}, http.StatusBadRequest},
{&url.Values{"delay": {"-1ms"}}, http.StatusBadRequest},
{&url.Values{"delay": {"1001"}}, http.StatusBadRequest},
{&url.Values{"delay": {"-1"}}, http.StatusBadRequest},
{&url.Values{"delay": {"foo"}}, http.StatusBadRequest},

{&url.Values{"count": {"foo"}}, http.StatusBadRequest},
{&url.Values{"count": {"0"}}, http.StatusBadRequest},
{&url.Values{"count": {"-1"}}, http.StatusBadRequest},
{&url.Values{"count": {"0xff"}}, http.StatusBadRequest},
{&url.Values{"count": {fmt.Sprintf("%d", app.maxSSECount+1)}}, http.StatusBadRequest},

// request would take too long
{&url.Values{"duration": {"750ms"}, "delay": {"500ms"}}, http.StatusBadRequest},
}
for _, test := range badTests {
test := test
t.Run(fmt.Sprintf("bad/%s", test.params.Encode()), func(t *testing.T) {
t.Parallel()
url := "/sse?" + test.params.Encode()
req := newTestRequest(t, "GET", url)
resp := must.DoReq(t, client, req)
defer consumeAndCloseBody(resp)
assert.StatusCode(t, resp, test.code)
})
}

t.Run("writes are actually incremmental", func(t *testing.T) {
t.Parallel()

var (
duration = 100 * time.Millisecond
count = 3
endpoint = fmt.Sprintf("/sse?duration=%s&count=%d", duration, count)

// Match server logic for calculating the delay between writes
wantPauseBetweenWrites = duration / time.Duration(count-1)
)

req := newTestRequest(t, "GET", endpoint)
resp := must.DoReq(t, client, req)
buf := bufio.NewReader(resp.Body)
eventCount := 0

// Here we read from the response one byte at a time, and ensure that
// at least the expected delay occurs for each read.
//
// The request above includes an initial delay equal to the expected
// wait between writes so that even the first iteration of this loop
// expects to wait the same amount of time for a read.
for i := 0; ; i++ {
start := time.Now()
event, err := parseServerSentEvent(t, buf)
if err == io.EOF {
break
}
assert.NilError(t, err)
gotPause := time.Since(start)

// We expect to read exactly one byte on each iteration. On the
// last iteration, we expct to hit EOF after reading the final
// byte, because the server does not pause after the last write.
assert.Equal(t, event.ID, i, "unexpected SSE event ID")

// only ensure that we pause for the expected time between writes
// (allowing for minor mismatch in local timers and server timers)
// after the first byte.
if i > 0 {
assert.RoughDuration(t, gotPause, wantPauseBetweenWrites, 3*time.Millisecond)
}

eventCount++
}

assert.Equal(t, eventCount, count, "unexpected number of events")
})

t.Run("handle cancelation during initial delay", func(t *testing.T) {
t.Parallel()

// For this test, we expect the client to time out and cancel the
// request after 10ms. The handler should still be in its intitial
// delay period, so this will result in a request error since no status
// code will be written before the cancelation.
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
defer cancel()

req := newTestRequest(t, "GET", "/sse?duration=500ms&delay=500ms").WithContext(ctx)
if _, err := client.Do(req); !os.IsTimeout(err) {
t.Fatalf("expected timeout error, got %s", err)
}
})

t.Run("handle cancelation during stream", func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

req := newTestRequest(t, "GET", "/sse?duration=900ms&delay=0&count=2").WithContext(ctx)
resp := must.DoReq(t, client, req)
defer consumeAndCloseBody(resp)

// In this test, the server should have started an OK response before
// our client timeout cancels the request, so we should get an OK here.
assert.StatusCode(t, resp, http.StatusOK)

// But, we should time out while trying to read the whole response
// body.
body, err := io.ReadAll(resp.Body)
if !os.IsTimeout(err) {
t.Fatalf("expected timeout reading body, got %s", err)
}

// partial read should include the first whole event
event, err := parseServerSentEvent(t, bufio.NewReader(bytes.NewReader(body)))
assert.NilError(t, err)
assert.Equal(t, event.ID, 0, "unexpected SSE event ID")
})

t.Run("ensure HEAD request works with streaming responses", func(t *testing.T) {
t.Parallel()
req := newTestRequest(t, "HEAD", "/sse?duration=900ms&delay=100ms")
resp := must.DoReq(t, client, req)
assert.StatusCode(t, resp, http.StatusOK)
assert.BodySize(t, resp, 0)
})
}

func TestWebSocketEcho(t *testing.T) {
// ========================================================================
// Note: Here we only test input validation for the websocket endpoint.
Expand Down Expand Up @@ -3028,6 +3271,7 @@ func TestWebSocketEcho(t *testing.T) {
})
}
}

func newTestServer(handler http.Handler) (*httptest.Server, *http.Client) {
srv := httptest.NewServer(handler)
client := srv.Client()
Expand Down
Loading
Loading