Skip to content

Optimize and standardize the readSSEStream function handling #204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
.idea
.opencode
.claude

.vscode
60 changes: 10 additions & 50 deletions client/transport/sse.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package transport

import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -104,7 +102,7 @@ func (c *SSE) Start(ctx context.Context) error {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

go c.readSSE(resp.Body)
go c.readSSE(ctx, resp.Body)

// Wait for the endpoint to be received
timeout := time.NewTimer(30 * time.Second)
Expand All @@ -125,56 +123,18 @@ func (c *SSE) Start(ctx context.Context) error {

// readSSE continuously reads the SSE stream and processes events.
// It runs until the connection is closed or an error occurs.
func (c *SSE) readSSE(reader io.ReadCloser) {
defer reader.Close()

br := bufio.NewReader(reader)
var event, data string

for {
// when close or start's ctx cancel, the reader will be closed
// and the for loop will break.
line, err := br.ReadString('\n')
if err != nil {
if err == io.EOF {
// Process any pending event before exit
if event != "" && data != "" {
c.handleSSEEvent(event, data)
}
break
}
if !c.closed.Load() {
fmt.Printf("SSE stream error: %v\n", err)
}
return
}

// Remove only newline markers
line = strings.TrimRight(line, "\r\n")
if line == "" {
// Empty line means end of event
if event != "" && data != "" {
c.handleSSEEvent(event, data)
event = ""
data = ""
}
continue
}

if strings.HasPrefix(line, "event:") {
event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
} else if strings.HasPrefix(line, "data:") {
data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
}
func (c *SSE) readSSE(ctx context.Context, reader io.ReadCloser) {
if err := ReadSSEStream(ctx, reader, c.handleSSEEvent); err != nil && !c.closed.Load() {
fmt.Printf("SSE stream error: %v\n", err)
}
}

// handleSSEEvent processes SSE events based on their type.
// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
func (c *SSE) handleSSEEvent(event, data string) {
switch event {
func (c *SSE) handleSSEEvent(event sseEvent) {
switch event.event {
case "endpoint":
endpoint, err := c.baseURL.Parse(data)
endpoint, err := c.baseURL.Parse(event.data)
if err != nil {
fmt.Printf("Error parsing endpoint URL: %v\n", err)
return
Expand All @@ -188,15 +148,15 @@ func (c *SSE) handleSSEEvent(event, data string) {

case "message":
var baseMessage JSONRPCResponse
if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
if err := json.Unmarshal([]byte(event.data), &baseMessage); err != nil {
fmt.Printf("Error unmarshaling message: %v\n", err)
return
}

// Handle notification
if baseMessage.ID == nil {
var notification mcp.JSONRPCNotification
if err := json.Unmarshal([]byte(data), &notification); err != nil {
if err := json.Unmarshal([]byte(event.data), &notification); err != nil {
return
}
c.notifyMu.RLock()
Expand Down Expand Up @@ -340,7 +300,7 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti

req, err := http.NewRequestWithContext(
ctx,
"POST",
http.MethodPost,
c.endpoint.String(),
bytes.NewReader(notificationBytes),
)
Expand Down
68 changes: 68 additions & 0 deletions client/transport/sse_helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package transport

import (
"bufio"
"context"
"fmt"
"io"
"strings"
)

type sseEvent struct {
event string
data string
}

// ReadSSEStream continuously reads the SSE stream and processes events.
func ReadSSEStream(ctx context.Context, reader io.ReadCloser, onEvent func(event sseEvent)) error {
defer func(reader io.ReadCloser) {
err := reader.Close()
if err != nil {
fmt.Printf("Error closing reader: %v\n", err)
}
}(reader)

br := bufio.NewReader(reader)
var event, data strings.Builder

processEvent := func() {
if event.Len() > 0 || data.Len() > 0 {
onEvent(sseEvent{event: event.String(), data: data.String()})
event.Reset()
data.Reset()
}
}

for {
select {
case <-ctx.Done():
return nil
default:
line, err := br.ReadString('\n')
if err != nil {
if err == io.EOF {
// Handle last event when EOF
processEvent()
return nil
}
return fmt.Errorf("error reading SSE stream: %w", err)
}

// Remove only newline markers
line = strings.TrimRight(line, "\r\n")

switch {
case strings.HasPrefix(line, "event:"):
event.Reset()
event.WriteString(strings.TrimSpace(strings.TrimPrefix(line, "event:")))
case strings.HasPrefix(line, "data:"):
if data.Len() > 0 {
data.WriteString("\n")
}
data.WriteString(strings.TrimSpace(strings.TrimPrefix(line, "data:")))
case line == "":
processEvent()
}
}
}
}
54 changes: 7 additions & 47 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package transport

import (
"bufio"
"bytes"
"context"
"encoding/json"
Expand All @@ -10,7 +9,6 @@ import (
"mime"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -247,20 +245,20 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
// only close responseChan after readingSSE()
defer close(responseChan)

c.readSSE(ctx, reader, func(event, data string) {
c.readSSE(ctx, reader, func(event sseEvent) {

// (unsupported: batching)

var message JSONRPCResponse
if err := json.Unmarshal([]byte(data), &message); err != nil {
if err := json.Unmarshal([]byte(event.data), &message); err != nil {
fmt.Printf("failed to unmarshal message: %v\n", err)
return
}

// Handle notification
if message.ID == nil {
var notification mcp.JSONRPCNotification
if err := json.Unmarshal([]byte(data), &notification); err != nil {
if err := json.Unmarshal([]byte(event.data), &notification); err != nil {
fmt.Printf("failed to unmarshal notification: %v\n", err)
return
}
Expand Down Expand Up @@ -290,52 +288,14 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl

// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair.
// It will end when the reader is closed (or the context is done).
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
defer reader.Close()

br := bufio.NewReader(reader)
var event, data string

for {
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event sseEvent)) {
if err := ReadSSEStream(ctx, reader, handler); err != nil {
select {
case <-ctx.Done():
return
default:
line, err := br.ReadString('\n')
if err != nil {
if err == io.EOF {
// Process any pending event before exit
if event != "" && data != "" {
handler(event, data)
}
return
}
select {
case <-ctx.Done():
return
default:
fmt.Printf("SSE stream error: %v\n", err)
return
}
}

// Remove only newline markers
line = strings.TrimRight(line, "\r\n")
if line == "" {
// Empty line means end of event
if event != "" && data != "" {
handler(event, data)
event = ""
data = ""
}
continue
}

if strings.HasPrefix(line, "event:") {
event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
} else if strings.HasPrefix(line, "data:") {
data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
}
fmt.Printf("SSE stream error: %v\n", err)
return
}
}
}
Expand Down