Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 35 additions & 27 deletions mcp/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ func writeEvent(w io.Writer, evt Event) (int, error) {
// TODO(rfindley): consider a different API here that makes failure modes more
// apparent.
func scanEvents(r io.Reader) iter.Seq2[Event, error] {
scanner := bufio.NewScanner(r)
const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size
scanner.Buffer(nil, maxTokenSize)
reader := bufio.NewReader(r)

// TODO: investigate proper behavior when events are out of order, or have
// non-standard names.
Expand All @@ -94,31 +92,43 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
evt Event
dataBuf *bytes.Buffer // if non-nil, preceding field was also data
)
flushData := func() {
yieldEvent := func() bool {
if dataBuf != nil {
evt.Data = dataBuf.Bytes()
dataBuf = nil
}
if evt.Empty() {
return true
}
if !yield(evt, nil) {
return false
}
evt = Event{}
return true
}
for scanner.Scan() {
line := scanner.Bytes()
for {
line, err := reader.ReadBytes('\n')
if err != nil && !errors.Is(err, io.EOF) {
yield(Event{}, fmt.Errorf("error reading event: %v", err))
return
}
line = bytes.TrimRight(line, "\r\n")
isEOF := errors.Is(err, io.EOF)

if len(line) == 0 {
flushData()
// \n\n is the record delimiter
if !evt.Empty() && !yield(evt, nil) {
if !yieldEvent() {
return
}
if isEOF {
return
}
evt = Event{}
continue
}
before, after, found := bytes.Cut(line, []byte{':'})
if !found {
yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line)))
yield(Event{}, fmt.Errorf("%w: malformed line in SSE stream: %q", errMalformedEvent, string(line)))
return
}
if !bytes.Equal(before, dataKey) {
flushData()
}
switch {
case bytes.Equal(before, eventKey):
evt.Name = strings.TrimSpace(string(after))
Expand All @@ -128,27 +138,20 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] {
evt.Retry = strings.TrimSpace(string(after))
case bytes.Equal(before, dataKey):
data := bytes.TrimSpace(after)
if dataBuf != nil {
dataBuf.WriteByte('\n')
if dataBuf == nil {
dataBuf = new(bytes.Buffer)
dataBuf.Write(data)
} else {
dataBuf = new(bytes.Buffer)
dataBuf.WriteByte('\n')
dataBuf.Write(data)
}
}
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize)
}
if !yield(Event{}, err) {

if isEOF {
yieldEvent()
return
}
}
flushData()
if !evt.Empty() {
yield(evt, nil)
}
}
}

Expand Down Expand Up @@ -310,6 +313,11 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string,
// index is no longer available.
var ErrEventsPurged = errors.New("data purged")

// errMalformedEvent is returned when an SSE event cannot be parsed due to format violations.
// This is a hard error indicating corrupted data or protocol violations, as opposed to
// transient I/O errors which may be retryable.
var errMalformedEvent = errors.New("malformed event")

// After implements [EventStore.After].
func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] {
// Return the data items to yield.
Expand Down
72 changes: 72 additions & 0 deletions mcp/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,78 @@ func TestScanEvents(t *testing.T) {
input: "invalid line\n\n",
wantErr: "malformed line",
},
{
name: "message with 2 data lines and another event",
input: "event: message\ndata: hello\ndata: hello\ndata: hello\n\nevent:keepalive",
want: []Event{
{Name: "message", Data: []byte("hello\nhello\nhello")},
{Name: "keepalive"},
},
},
{
name: "event with multiple lines",
input: "event: message\ndata: hello\ndata: hello\ndata: hello\nid:1",
want: []Event{
{Name: "message", ID: "1", Data: []byte("hello\nhello\nhello")},
},
},
{
name: "multiple events, out of order keys",
input: strings.Join([]string{
"event:message",
"data: hello0",
"\n",
"data: hello1",
"data: hello1",
"id:1",
"event:message",
"\n",
"event:message",
"data: hello3",
"data: hello3",
"id:3",
"\n",
"data: hello4",
"data: hello4",
"id:4",
"event:message",
}, "\n"),
want: []Event{
{Name: "message", Data: []byte("hello0")},
{Name: "message", ID: "1", Data: []byte("hello1\nhello1")},
{Name: "message", ID: "3", Data: []byte("hello3\nhello3")},
{Name: "message", ID: "4", Data: []byte("hello4\nhello4")},
},
},
{
name: "non-continuous data items in the event",
input: "event: foo\ndata: 123\nretry: 5\ndata: 456",
want: []Event{
{Name: "foo", Data: []byte("123\n456"), Retry: "5"},
},
},
{
name: "no-data events",
input: "event: foo\n\nevent: bar",
want: []Event{
{Name: "foo"},
{Name: "bar"},
},
},
{
name: "empty data event",
input: "event: foo\ndata:\n\nevent: bar",
want: []Event{
{Name: "foo"},
{Name: "bar"},
},
},
{

name: "malformed data event",
input: "someline",
wantErr: "malformed event",
},
}

for _, tt := range tests {
Expand Down
3 changes: 2 additions & 1 deletion mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

if req.Method != http.MethodGet {
http.Error(w, "invalid method", http.StatusMethodNotAllowed)
w.Header().Set("Allow", "GET, POST")
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}

Expand Down
12 changes: 11 additions & 1 deletion mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1859,7 +1859,17 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary
if ctx.Err() != nil {
return "", 0, true // don't reconnect: client cancelled
}
break

// Malformed events are hard errors that indicate corrupted data or protocol // violations. These should fail the connection permanently.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Malformed events are hard errors that indicate corrupted data or protocol // violations. These should fail the connection permanently.
// Malformed events are hard errors that indicate corrupted data or protocol
// violations. These should fail the connection permanently.

if errors.Is(err, errMalformedEvent) {
c.fail(fmt.Errorf("%s: %v", requestSummary, err))
return "", 0, true
}

// Network/I/O errors during reading should trigger reconnection, not permanent failure.
// Return from processStream so handleSSE can attempt to reconnect.
c.logger.Debug(fmt.Sprintf("%s: stream read error (will attempt reconnect): %v", requestSummary, err))
return lastEventID, reconnectDelay, false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, I think we actually need to be more subtle than this: a malformed event is a hard error, and different from a reader error.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did some adaptaptation. let me know WDYT

}

if evt.ID != "" {
Expand Down