Skip to content

Commit 52603e8

Browse files
Add SSE polling support (SEP-1699)
Implements SEP-1699 which enables servers to disconnect SSE connections at will by sending priming events and retry fields. This enables more efficient resource management on the server side while maintaining resumability. Key changes: - Server sends priming event (empty data with event ID) on SSE stream - Server can call close_sse_stream() to close stream while gathering events - Client auto-reconnects using server-provided retryInterval or exponential backoff - Added e2e integration tests and example server/client Github-Issue:#1654
1 parent 5983a65 commit 52603e8

File tree

12 files changed

+1130
-52
lines changed

12 files changed

+1130
-52
lines changed

examples/servers/everything-server/mcp_everything_server/server.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,29 @@
1414
from mcp.server.fastmcp import Context, FastMCP
1515
from mcp.server.fastmcp.prompts.base import UserMessage
1616
from mcp.server.session import ServerSession
17+
from mcp.server.streamable_http import (
18+
EventCallback,
19+
EventId,
20+
EventMessage,
21+
EventStore,
22+
StreamId,
23+
)
1724
from mcp.types import (
1825
AudioContent,
1926
Completion,
2027
CompletionArgument,
2128
CompletionContext,
2229
EmbeddedResource,
2330
ImageContent,
31+
JSONRPCMessage,
2432
PromptReference,
2533
ResourceTemplateReference,
2634
SamplingMessage,
2735
TextContent,
2836
TextResourceContents,
2937
)
3038
from pydantic import AnyUrl, BaseModel, Field
39+
from starlette.requests import Request
3140

3241
logger = logging.getLogger(__name__)
3342

@@ -39,8 +48,47 @@
3948
resource_subscriptions: set[str] = set()
4049
watched_resource_content = "Watched resource content"
4150

51+
52+
# Simple in-memory event store for SSE polling resumability (SEP-1699)
53+
class SimpleEventStore(EventStore):
54+
"""Simple in-memory event store for testing resumability."""
55+
56+
def __init__(self) -> None:
57+
self._events: list[tuple[StreamId, EventId, JSONRPCMessage]] = []
58+
self._event_id_counter = 0
59+
60+
async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
61+
"""Store an event and return its ID."""
62+
self._event_id_counter += 1
63+
event_id = str(self._event_id_counter)
64+
self._events.append((stream_id, event_id, message))
65+
return event_id
66+
67+
async def replay_events_after(
68+
self,
69+
last_event_id: EventId,
70+
send_callback: EventCallback,
71+
) -> StreamId | None:
72+
"""Replay events after the specified ID."""
73+
target_stream_id = None
74+
found = False
75+
for stream_id, event_id, message in self._events:
76+
if event_id == last_event_id:
77+
target_stream_id = stream_id
78+
found = True
79+
continue
80+
if found and stream_id == target_stream_id:
81+
await send_callback(EventMessage(message=message, event_id=event_id))
82+
return target_stream_id
83+
84+
85+
# Create event store for resumability
86+
event_store = SimpleEventStore()
87+
4288
mcp = FastMCP(
4389
name="mcp-conformance-test-server",
90+
event_store=event_store,
91+
sse_retry_interval=3000, # 3 seconds
4492
)
4593

4694

@@ -257,6 +305,33 @@ async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> s
257305
return f"Elicitation not supported or error: {str(e)}"
258306

259307

308+
@mcp.tool()
309+
async def test_reconnection(ctx: Context[ServerSession, None]) -> str:
310+
"""Tests SSE polling via server-initiated disconnect (SEP-1699)
311+
312+
This tool closes the SSE stream mid-call, requiring the client to reconnect
313+
with Last-Event-ID to receive the remaining events.
314+
"""
315+
# Send notification before disconnect
316+
await ctx.info("Notification before disconnect")
317+
318+
# Get session_id from request headers
319+
request = ctx.request_context.request
320+
if isinstance(request, Request):
321+
session_id = request.headers.get("mcp-session-id")
322+
if session_id:
323+
# Trigger server-initiated SSE disconnect
324+
await mcp.session_manager.close_sse_stream(session_id, ctx.request_id)
325+
326+
# Wait for client to reconnect
327+
await asyncio.sleep(0.2)
328+
329+
# Send notification after disconnect (will be replayed via event store)
330+
await ctx.info("Notification after disconnect")
331+
332+
return "Reconnection test completed successfully"
333+
334+
260335
@mcp.tool()
261336
def test_error_handling() -> str:
262337
"""Tests error response handling"""
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
SSE Polling Example Client
3+
4+
Demonstrates client-side behavior during server-initiated SSE disconnect.
5+
6+
Key features:
7+
- Automatic reconnection when server closes SSE stream
8+
- Event replay via Last-Event-ID header (handled internally by the transport)
9+
- Progress notifications via logging callback
10+
11+
This client connects to the SSE polling server and calls the `long-task` tool.
12+
The server disconnects at 50% progress, and the client automatically reconnects
13+
to receive remaining progress updates.
14+
15+
Run:
16+
# First start the server:
17+
uv run examples/snippets/servers/sse_polling_server.py
18+
19+
# Then run this client:
20+
uv run examples/snippets/clients/sse_polling_client.py
21+
"""
22+
23+
import asyncio
24+
import logging
25+
26+
from mcp import ClientSession
27+
from mcp.client.streamable_http import StreamableHTTPReconnectionOptions, streamablehttp_client
28+
from mcp.types import LoggingMessageNotificationParams, TextContent
29+
30+
logging.basicConfig(
31+
level=logging.INFO,
32+
format="%(asctime)s - %(levelname)s - %(message)s",
33+
)
34+
logger = logging.getLogger(__name__)
35+
36+
37+
async def main() -> None:
38+
print("SSE Polling Example Client")
39+
print("=" * 50)
40+
print()
41+
42+
# Track notifications received via the logging callback
43+
notifications_received: list[str] = []
44+
45+
async def logging_callback(params: LoggingMessageNotificationParams) -> None:
46+
"""Called when a log message notification is received from the server."""
47+
data = params.data
48+
if data:
49+
data_str = str(data)
50+
notifications_received.append(data_str)
51+
print(f"[Progress] {data_str}")
52+
53+
# Configure reconnection behavior
54+
reconnection_options = StreamableHTTPReconnectionOptions(
55+
initial_reconnection_delay=1.0, # Start with 1 second
56+
max_reconnection_delay=30.0, # Cap at 30 seconds
57+
reconnection_delay_grow_factor=1.5, # Exponential backoff
58+
max_retries=5, # Try up to 5 times
59+
)
60+
61+
print("[Client] Connecting to server...")
62+
63+
async with streamablehttp_client(
64+
"http://localhost:3001/mcp",
65+
reconnection_options=reconnection_options,
66+
) as (read_stream, write_stream, get_session_id):
67+
# Create session with logging callback to receive progress notifications
68+
async with ClientSession(
69+
read_stream,
70+
write_stream,
71+
logging_callback=logging_callback,
72+
) as session:
73+
# Initialize the session
74+
await session.initialize()
75+
session_id = get_session_id()
76+
print(f"[Client] Connected! Session ID: {session_id}")
77+
78+
# List available tools
79+
tools = await session.list_tools()
80+
tool_names = [t.name for t in tools.tools]
81+
print(f"[Client] Available tools: {tool_names}")
82+
print()
83+
84+
# Call the long-running task
85+
print("[Client] Calling long-task tool...")
86+
print("[Client] The server will disconnect at 50% and we'll auto-reconnect")
87+
print()
88+
89+
# Call the tool
90+
result = await session.call_tool("long-task", {})
91+
92+
print()
93+
print("[Client] Task completed!")
94+
if result.content and isinstance(result.content[0], TextContent):
95+
print(f"[Result] {result.content[0].text}")
96+
else:
97+
print("[Result] No content")
98+
print()
99+
print(f"[Summary] Received {len(notifications_received)} progress notifications")
100+
101+
102+
if __name__ == "__main__":
103+
asyncio.run(main())

0 commit comments

Comments
 (0)