Skip to content

Commit dc457ff

Browse files
Implement SSE polling support (SEP-1699)
Implements the SSE polling behavior defined in SEP-1699: the server sends a priming event (empty data with an event ID) on the SSE stream; the server can call close_sse_stream() to trigger client reconnection; and the client auto-reconnects using the server-provided retryInterval or, if none is given, exponential backoff. Github-Issue:#1654
1 parent 9cef4ff commit dc457ff

File tree

3 files changed

+96
-15
lines changed

3 files changed

+96
-15
lines changed

src/mcp/client/streamable_http.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,14 @@ def _get_next_reconnection_delay(self, attempt: int) -> float:
175175
Returns:
176176
Time to wait in seconds before next reconnection attempt
177177
"""
178-
# TODO: Implement proper delay calculation with server retry for SEP-1699
179-
return 1.0
178+
# Use server-provided retry value if available
179+
if self._server_retry_seconds is not None:
180+
return self._server_retry_seconds
181+
182+
# Fall back to exponential backoff
183+
opts = self.reconnection_options
184+
delay = opts.initial_reconnection_delay * (opts.reconnection_delay_grow_factor**attempt)
185+
return min(delay, opts.max_reconnection_delay)
180186

181187
async def _handle_sse_event(
182188
self,
@@ -376,25 +382,39 @@ async def _handle_sse_response(
376382
attempt: int = 0,
377383
) -> None:
378384
"""Handle SSE response from the server with automatic reconnection."""
385+
has_priming_event = False
386+
last_event_id: str | None = None
387+
is_complete = False
388+
379389
try:
380390
event_source = EventSource(response)
381391
async for sse in event_source.aiter_sse(): # pragma: no branch
382-
is_complete, _has_event_id = await self._handle_sse_event(
392+
is_complete, has_event_id = await self._handle_sse_event(
383393
sse,
384394
ctx.read_stream_writer,
385395
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
386396
is_initialization=is_initialization,
387397
)
388398

399+
# Track priming events
400+
if has_event_id:
401+
has_priming_event = True
402+
last_event_id = sse.id
403+
389404
# If the SSE event indicates completion, like returning response/error
390405
# break the loop
391406
if is_complete:
392407
await response.aclose()
393408
break
394409
except Exception as e: # pragma: no cover
395410
logger.exception("Error reading SSE stream:")
396-
await ctx.read_stream_writer.send(e)
397-
# TODO: Implement auto-reconnection for SEP-1699
411+
# Don't send exception if we can reconnect
412+
if not (has_priming_event and last_event_id):
413+
await ctx.read_stream_writer.send(e)
414+
415+
# Auto-reconnect if stream ended without completion and we have priming event
416+
if not is_complete and has_priming_event and last_event_id: # pragma: no cover
417+
await self._attempt_sse_reconnection(ctx, last_event_id, attempt)
398418

399419
async def _attempt_sse_reconnection( # pragma: no cover
400420
self,
@@ -407,8 +427,42 @@ async def _attempt_sse_reconnection( # pragma: no cover
407427
Called when SSE stream ends without receiving a response/error,
408428
but we have a priming event indicating resumability.
409429
"""
410-
# TODO: Implement SSE reconnection for SEP-1699
411-
pass
430+
max_retries = self.reconnection_options.max_retries
431+
432+
if attempt >= max_retries:
433+
error_msg = f"Max reconnection attempts ({max_retries}) exceeded"
434+
logger.error(error_msg)
435+
await ctx.read_stream_writer.send(StreamableHTTPError(error_msg))
436+
return
437+
438+
# Calculate delay (uses server retry if available, else exponential backoff)
439+
delay = self._get_next_reconnection_delay(attempt)
440+
logger.info(f"SSE stream closed, reconnecting in {delay:.1f}s (attempt {attempt + 1}/{max_retries})")
441+
442+
await anyio.sleep(delay)
443+
444+
# Build resumption context with last_event_id
445+
resumption_metadata = ClientMessageMetadata(
446+
resumption_token=last_event_id,
447+
on_resumption_token_update=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
448+
)
449+
450+
resumption_ctx = RequestContext(
451+
client=ctx.client,
452+
headers=ctx.headers,
453+
session_id=ctx.session_id,
454+
session_message=ctx.session_message,
455+
metadata=resumption_metadata,
456+
read_stream_writer=ctx.read_stream_writer,
457+
sse_read_timeout=ctx.sse_read_timeout,
458+
)
459+
460+
try:
461+
await self._handle_resumption_request(resumption_ctx)
462+
except Exception as e:
463+
logger.warning(f"Reconnection attempt {attempt + 1} failed: {e}")
464+
# Recursive retry with incremented attempt counter
465+
await self._attempt_sse_reconnection(ctx, last_event_id, attempt + 1)
412466

413467
async def _handle_unexpected_content_type(
414468
self,

src/mcp/server/streamable_http.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,25 @@ async def _create_priming_event(self, stream_id: str) -> dict[str, str | int] |
263263
Returns:
264264
Event data dictionary for the priming event, or None if no event store
265265
"""
266-
# TODO: Implement priming event creation for SEP-1699
267-
return None
266+
if self._event_store is None:
267+
return None
268+
269+
# Store an empty message to get an event ID
270+
# Using a minimal JSON-RPC notification as a placeholder - it won't be sent as actual data
271+
priming_event_id = await self._event_store.store_event(
272+
stream_id, JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "_priming"})
273+
)
274+
275+
event_data: dict[str, str | int] = {
276+
"id": priming_event_id,
277+
"data": "", # Empty data for priming event
278+
}
279+
280+
# Add retry interval if configured (sse_starlette expects int, not str)
281+
if self._retry_interval is not None:
282+
event_data["retry"] = self._retry_interval
283+
284+
return event_data
268285

269286
async def _clean_up_memory_streams(self, request_id: RequestId) -> None: # pragma: no cover
270287
"""Clean up memory streams for a given request ID."""
@@ -739,8 +756,17 @@ async def close_sse_stream(self, request_id: RequestId) -> None:
739756
Args:
740757
request_id: The request ID (or stream key) of the stream to close
741758
"""
742-
# TODO: Implement stream closing for SEP-1699
743-
pass
759+
request_id_str = str(request_id)
760+
if request_id_str in self._request_streams:
761+
try:
762+
sender, receiver = self._request_streams[request_id_str]
763+
await sender.aclose()
764+
await receiver.aclose()
765+
except Exception: # pragma: no cover
766+
# Stream might already be closed
767+
logger.debug(f"Error closing SSE stream {request_id_str} - may already be closed")
768+
finally:
769+
self._request_streams.pop(request_id_str, None)
744770

745771
async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover
746772
"""Handle unsupported HTTP methods."""

tests/shared/test_streamable_http.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,11 +1663,12 @@ async def test_handle_sse_event_skips_empty_data():
16631663
write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1)
16641664

16651665
try:
1666-
# Call _handle_sse_event with empty data - should return False and not raise
1667-
result = await transport._handle_sse_event(mock_sse, write_stream)
1666+
# Call _handle_sse_event with empty data - should return (False, False) and not raise
1667+
is_complete, has_event_id = await transport._handle_sse_event(mock_sse, write_stream)
16681668

1669-
# Should return False (not complete) for empty data
1670-
assert result is False
1669+
# Should return (False, False): not complete, and no event ID, for empty data
1670+
assert is_complete is False
1671+
assert has_event_id is False
16711672

16721673
# Nothing should have been written to the stream
16731674
# Check buffer is empty (statistics().current_buffer_used returns buffer size)

0 commit comments

Comments
 (0)