Commit 8577e9f

skyvanguard committed
fix: send error to client when SSE stream disconnects without response
When the SSE stream disconnects before the server sends a response and no event ID has been received (making reconnection impossible), the client was left hanging indefinitely on read_stream.receive(). Now the transport sends a JSONRPCError to unblock the client, which surfaces as an McpError. Also handles the case where reconnection attempts are exhausted without receiving a response.

Github-Issue: #1811
Reported-by: ivanbelenky
1 parent a7ddfda commit 8577e9f
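
For context, a minimal client-side sketch (not part of the commit) of what the change means in practice: a tool call whose SSE stream drops before a response now fails with an McpError instead of blocking forever. The server URL and the "slow_tool" name are placeholders; the client wiring mirrors the test added below.

# Illustrative sketch only - SERVER_URL and "slow_tool" are placeholders.
import anyio
import httpx

from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamable_http_client
from mcp.shared.exceptions import McpError

SERVER_URL = "http://127.0.0.1:8000/mcp/"  # placeholder


async def call_tool_with_disconnect_handling() -> None:
    # A short read timeout makes an SSE disconnect easy to provoke.
    http_client = httpx.AsyncClient(timeout=httpx.Timeout(5.0, read=0.5))
    async with streamable_http_client(SERVER_URL, http_client=http_client) as (
        read_stream,
        write_stream,
        _,
    ):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            try:
                await session.call_tool("slow_tool", {})
            except McpError as exc:
                # With this commit, the transport's JSONRPCError surfaces here
                # instead of the call hanging on read_stream.receive().
                print(f"Tool call failed: {exc.error.message}")


if __name__ == "__main__":
    anyio.run(call_tool_with_disconnect_handling)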

2 files changed: +196 -2 lines

src/mcp/client/streamable_http.py (18 additions, 2 deletions)
@@ -337,9 +337,13 @@ async def _handle_sse_response(
             logger.debug(f"SSE stream ended: {e}")  # pragma: no cover
 
         # Stream ended without response - reconnect if we received an event with ID
-        if last_event_id is not None:  # pragma: no branch
+        if last_event_id is not None:
             logger.info("SSE stream disconnected, reconnecting...")
             await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
+        else:
+            # No event ID received before disconnect - cannot reconnect,
+            # send error to unblock the client
+            await self._send_disconnect_error(ctx)
 
     async def _handle_reconnection(
         self,
@@ -350,8 +354,9 @@ async def _handle_reconnection(
     ) -> None:
         """Reconnect with Last-Event-ID to resume stream after server disconnect."""
         # Bail if max retries exceeded
-        if attempt >= MAX_RECONNECTION_ATTEMPTS:  # pragma: no cover
+        if attempt >= MAX_RECONNECTION_ATTEMPTS:
             logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
+            await self._send_disconnect_error(ctx)
             return
 
         # Always wait - use server value or default
@@ -417,6 +422,17 @@ async def _send_session_terminated_error(self, read_stream_writer: StreamWriter,
         session_message = SessionMessage(jsonrpc_error)
         await read_stream_writer.send(session_message)
 
+    async def _send_disconnect_error(self, ctx: RequestContext) -> None:
+        """Send a disconnect error to unblock the client waiting on the read stream."""
+        if isinstance(ctx.session_message.message, JSONRPCRequest):
+            request_id = ctx.session_message.message.id
+            jsonrpc_error = JSONRPCError(
+                jsonrpc="2.0",
+                id=request_id,
+                error=ErrorData(code=-32000, message="SSE stream disconnected before receiving a response"),
+            )
+            await ctx.read_stream_writer.send(SessionMessage(jsonrpc_error))
+
     async def post_writer(
         self,
         client: httpx.AsyncClient,
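
For reference, a rough sketch of the JSON-RPC error that _send_disconnect_error places on the client's read stream, assuming the original request had id 1 (the exact serialization is handled by the SDK's JSONRPCError and ErrorData models):

# Approximate shape of the error delivered to the pending request
# (a request id of 1 is assumed here for illustration).
disconnect_error = {
    "jsonrpc": "2.0",
    "id": 1,
    "error": {
        "code": -32000,
        "message": "SSE stream disconnected before receiving a response",
    },
}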
New test file (178 additions, 0 deletions)
"""Test for issue #1811 - client hangs after SSE disconnection.

When the SSE stream disconnects before the server sends a response (e.g., due to
a read timeout), the client's read_stream_writer was never sent an error message,
causing the client to hang indefinitely on .receive(). The fix sends a JSONRPCError
when the stream disconnects without a resumable event ID.
"""

import multiprocessing
import socket
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

import anyio
import httpx
import pytest
from starlette.applications import Starlette
from starlette.routing import Mount

from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.exceptions import McpError
from mcp.types import TextContent, Tool
from tests.test_helpers import wait_for_server

SERVER_NAME = "test_sse_disconnect_server"


def get_free_port() -> int:
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def create_slow_server_app() -> Starlette:
    """Create a server with a tool that takes a long time to respond."""
    server = Server(SERVER_NAME)

    @server.list_tools()
    async def handle_list_tools() -> list[Tool]:
        return [
            Tool(
                name="slow_tool",
                description="A tool that takes a long time",
                input_schema={"type": "object", "properties": {}},
            )
        ]

    @server.call_tool()
    async def handle_call_tool(name: str, arguments: dict[str, object]) -> list[TextContent]:
        # Sleep long enough that the client timeout fires first
        await anyio.sleep(30)
        return [TextContent(type="text", text="done")]

    session_manager = StreamableHTTPSessionManager(app=server, stateless=True)

    @asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
        async with session_manager.run():
            yield

    return Starlette(
        routes=[Mount("/mcp", app=session_manager.handle_request)],
        lifespan=lifespan,
    )


def create_fast_server_app() -> Starlette:
    """Create a server with a fast tool for sanity testing."""
    server = Server(SERVER_NAME)

    @server.list_tools()
    async def handle_list_tools() -> list[Tool]:
        return [
            Tool(
                name="fast_tool",
                description="A fast tool",
                input_schema={"type": "object", "properties": {}},
            )
        ]

    @server.call_tool()
    async def handle_call_tool(name: str, arguments: dict[str, object]) -> list[TextContent]:
        return [TextContent(type="text", text="fast result")]

    session_manager = StreamableHTTPSessionManager(app=server, stateless=True)

    @asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
        async with session_manager.run():
            yield

    return Starlette(
        routes=[Mount("/mcp", app=session_manager.handle_request)],
        lifespan=lifespan,
    )


def run_server(port: int, slow: bool = True) -> None:
    """Run the server in a separate process."""
    import uvicorn

    app = create_slow_server_app() if slow else create_fast_server_app()
    uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning")


@pytest.fixture
def slow_server_url():
    """Start the slow server and return its URL."""
    port = get_free_port()
    proc = multiprocessing.Process(target=run_server, args=(port, True), daemon=True)
    proc.start()
    wait_for_server(port)

    yield f"http://127.0.0.1:{port}"

    proc.kill()
    proc.join(timeout=2)


@pytest.fixture
def fast_server_url():
    """Start the fast server and return its URL."""
    port = get_free_port()
    proc = multiprocessing.Process(target=run_server, args=(port, False), daemon=True)
    proc.start()
    wait_for_server(port)

    yield f"http://127.0.0.1:{port}"

    proc.kill()
    proc.join(timeout=2)


@pytest.mark.anyio
async def test_client_receives_error_on_sse_disconnect(slow_server_url: str):
    """Client should receive an error instead of hanging when SSE stream disconnects.

    When the read timeout fires before the server sends a response, the SSE stream
    is closed. Previously, if no event ID had been received, the client would hang
    forever. Now it should raise McpError with the disconnect message.
    """
    # Use a short read timeout so the SSE stream disconnects quickly
    short_timeout_client = httpx.AsyncClient(
        timeout=httpx.Timeout(5.0, read=0.5),
    )

    async with streamable_http_client(
        f"{slow_server_url}/mcp/",
        http_client=short_timeout_client,
    ) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()

            # Call the slow tool - the read timeout should fire
            # and the client should receive an error instead of hanging
            with pytest.raises(McpError, match="SSE stream disconnected"):
                await session.call_tool("slow_tool", {})


@pytest.mark.anyio
async def test_fast_tool_still_works_normally(fast_server_url: str):
    """Ensure normal (fast) tool calls still work correctly after the fix."""
    client = httpx.AsyncClient(timeout=httpx.Timeout(5.0))

    async with streamable_http_client(
        f"{fast_server_url}/mcp/",
        http_client=client,
    ) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()

            result = await session.call_tool("fast_tool", {})
            assert result.content[0].type == "text"
            assert isinstance(result.content[0], TextContent)
            assert result.content[0].text == "fast result"
