-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fix: Prevent session manager shutdown on individual session crash #841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
694cabc
a804442
4026edf
7671963
2a27e65
b12b370
8067bc9
ec85fa3
ab0c809
131a58d
6165a30
e25081b
61f3dc2
208f13b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,7 +52,6 @@ class StreamableHTTPSessionManager: | |
json_response: Whether to use JSON responses instead of SSE streams | ||
stateless: If True, creates a completely fresh transport for each request | ||
with no session tracking or state persistence between requests. | ||
|
||
""" | ||
|
||
def __init__( | ||
|
@@ -173,12 +172,15 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA | |
async with http_transport.connect() as streams: | ||
read_stream, write_stream = streams | ||
task_status.started() | ||
await self.app.run( | ||
read_stream, | ||
write_stream, | ||
self.app.create_initialization_options(), | ||
stateless=True, | ||
) | ||
try: | ||
await self.app.run( | ||
read_stream, | ||
write_stream, | ||
self.app.create_initialization_options(), | ||
stateless=True, | ||
) | ||
except Exception as e: | ||
logger.warning(f"Stateless session crashed: {e}", exc_info=True) | ||
|
||
# Assert task group is not None for type checking | ||
assert self._task_group is not None | ||
|
@@ -233,12 +235,33 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE | |
async with http_transport.connect() as streams: | ||
read_stream, write_stream = streams | ||
task_status.started() | ||
await self.app.run( | ||
read_stream, | ||
write_stream, | ||
self.app.create_initialization_options(), | ||
stateless=False, # Stateful mode | ||
) | ||
try: | ||
await self.app.run( | ||
read_stream, | ||
write_stream, | ||
self.app.create_initialization_options(), | ||
stateless=False, # Stateful mode | ||
) | ||
except Exception as e: | ||
logger.warning( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are you using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I originally logged it at error but it was lighting up my bug reporting system with uncaught exceptions related to trivial things like the client closing the tcp connection unexpectedly. "Be usable in the real world" has to take priority. |
||
f"Session {http_transport.mcp_session_id} crashed: {e}", | ||
exc_info=True, | ||
) | ||
finally: | ||
# Only remove from instances if not terminated | ||
if ( | ||
http_transport.mcp_session_id | ||
and http_transport.mcp_session_id in self._server_instances | ||
and not ( | ||
hasattr(http_transport, "_terminated") and http_transport._terminated # pyright: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To avoid having to override the linter, should we have
|
||
) | ||
): | ||
logger.info( | ||
"Cleaning up crashed session " | ||
f"{http_transport.mcp_session_id} from " | ||
"active instances." | ||
) | ||
del self._server_instances[http_transport.mcp_session_id] | ||
|
||
# Assert task group is not None for type checking | ||
assert self._task_group is not None | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,12 @@ | ||
"""Tests for StreamableHTTPSessionManager.""" | ||
|
||
from unittest.mock import AsyncMock | ||
|
||
import anyio | ||
import pytest | ||
|
||
from mcp.server.lowlevel import Server | ||
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER | ||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager | ||
|
||
|
||
|
@@ -71,3 +74,124 @@ async def send(message): | |
await manager.handle_request(scope, receive, send) | ||
|
||
assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value) | ||
|
||
|
||
class TestException(Exception): | ||
__test__ = False # Prevent pytest from collecting this as a test class | ||
pass | ||
|
||
|
||
@pytest.fixture | ||
async def running_manager(): | ||
app = Server("test-cleanup-server") | ||
# It's important that the app instance used by the manager is the one we can patch | ||
manager = StreamableHTTPSessionManager(app=app) | ||
async with manager.run(): | ||
# Patch app.run here if it's simpler, or patch it within the test | ||
yield manager, app | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_stateful_session_cleanup_on_graceful_exit(running_manager): | ||
manager, app = running_manager | ||
|
||
mock_mcp_run = AsyncMock(return_value=None) | ||
# This will be called by StreamableHTTPSessionManager's run_server -> self.app.run | ||
app.run = mock_mcp_run | ||
|
||
sent_messages = [] | ||
|
||
async def mock_send(message): | ||
sent_messages.append(message) | ||
|
||
scope = {"type": "http", "method": "POST", "path": "/mcp", "headers": []} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you need the headers here to make CI pass:
|
||
|
||
async def mock_receive(): | ||
return {"type": "http.request", "body": b"", "more_body": False} | ||
|
||
# Trigger session creation | ||
await manager.handle_request(scope, mock_receive, mock_send) | ||
|
||
# Extract session ID from response headers | ||
session_id = None | ||
for msg in sent_messages: | ||
if msg["type"] == "http.response.start": | ||
for header_name, header_value in msg.get("headers", []): | ||
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): | ||
session_id = header_value.decode() | ||
break | ||
if session_id: # Break outer loop if session_id is found | ||
break | ||
|
||
assert session_id is not None, "Session ID not found in response headers" | ||
|
||
# Ensure MCPServer.run was called | ||
mock_mcp_run.assert_called_once() | ||
|
||
# At this point, mock_mcp_run has completed, and the finally block in | ||
# StreamableHTTPSessionManager's run_server should have executed. | ||
|
||
# To ensure the task spawned by handle_request finishes and cleanup occurs: | ||
# Give other tasks a chance to run. This is important for the finally block. | ||
await anyio.sleep(0.01) | ||
|
||
assert ( | ||
session_id not in manager._server_instances | ||
), "Session ID should be removed from _server_instances after graceful exit" | ||
assert not manager._server_instances, "No sessions should be tracked after the only session exits gracefully" | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_stateful_session_cleanup_on_exception(running_manager): | ||
manager, app = running_manager | ||
|
||
mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash")) | ||
app.run = mock_mcp_run | ||
|
||
sent_messages = [] | ||
|
||
async def mock_send(message): | ||
sent_messages.append(message) | ||
# If an exception occurs, the transport might try to send an error response | ||
# For this test, we mostly care that the session is established enough | ||
# to get an ID | ||
if message["type"] == "http.response.start" and message["status"] >= 500: | ||
pass # Expected if TestException propagates that far up the transport | ||
|
||
scope = {"type": "http", "method": "POST", "path": "/mcp", "headers": []} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you need the headers here to make CI pass:
|
||
|
||
async def mock_receive(): | ||
return {"type": "http.request", "body": b"", "more_body": False} | ||
|
||
# It's possible handle_request itself might raise an error if the TestException | ||
# isn't caught by the transport layer before propagating. | ||
# The key is that the session manager's internal task for MCPServer.run | ||
# encounters the exception. | ||
try: | ||
await manager.handle_request(scope, mock_receive, mock_send) | ||
except TestException: | ||
# This might be caught here if not handled by StreamableHTTPServerTransport's | ||
# error handling | ||
pass | ||
Comment on lines
+166
to
+175
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we actually need this try catch and explanation? I ran this test locally 50x without the Try/Catch and the test worked fine without it. Can we remove the try catch and comment? When would the exception not be caught by the app.run try-catch you added in the manager? |
||
|
||
session_id = None | ||
for msg in sent_messages: | ||
if msg["type"] == "http.response.start": | ||
for header_name, header_value in msg.get("headers", []): | ||
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): | ||
session_id = header_value.decode() | ||
break | ||
if session_id: # Break outer loop if session_id is found | ||
break | ||
|
||
assert session_id is not None, "Session ID not found in response headers" | ||
|
||
mock_mcp_run.assert_called_once() | ||
|
||
# Give other tasks a chance to run to ensure the finally block executes | ||
await anyio.sleep(0.01) | ||
|
||
assert ( | ||
session_id not in manager._server_instances | ||
), "Session ID should be removed from _server_instances after an exception" | ||
assert not manager._server_instances, "No sessions should be tracked after the only session crashes" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did we look into
.run
function to see how to handle the error there?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@NAVNAV221 That should also happen, but there should be no scenarios in which a per-session error is allowed to destabilize the entire server until reboot. This is a catch-all to make sure that the server as a whole survives any errors where proper error handling was missed.