Skip to content

Commit 33f3401

Browse files
authored
Feat: Expose MCP message handler configuration (#1795)
1 parent c96a40a commit 33f3401

File tree

2 files changed

+149
-0
lines changed

2 files changed

+149
-0
lines changed

src/agents/mcp/server.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1313
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
14+
from mcp.client.session import MessageHandlerFnT
1415
from mcp.client.sse import sse_client
1516
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
1617
from mcp.shared.message import SessionMessage
@@ -103,6 +104,7 @@ def __init__(
103104
use_structured_content: bool = False,
104105
max_retry_attempts: int = 0,
105106
retry_backoff_seconds_base: float = 1.0,
107+
message_handler: MessageHandlerFnT | None = None,
106108
):
107109
"""
108110
Args:
@@ -124,6 +126,8 @@ def __init__(
124126
Defaults to no retries.
125127
retry_backoff_seconds_base: The base delay, in seconds, used for exponential
126128
backoff between retries.
129+
message_handler: Optional handler invoked for session messages as delivered by the
130+
ClientSession.
127131
"""
128132
super().__init__(use_structured_content=use_structured_content)
129133
self.session: ClientSession | None = None
@@ -135,6 +139,7 @@ def __init__(
135139
self.client_session_timeout_seconds = client_session_timeout_seconds
136140
self.max_retry_attempts = max_retry_attempts
137141
self.retry_backoff_seconds_base = retry_backoff_seconds_base
142+
self.message_handler = message_handler
138143

139144
# The cache is always dirty at startup, so that we fetch tools at least once
140145
self._cache_dirty = True
@@ -272,6 +277,7 @@ async def connect(self):
272277
timedelta(seconds=self.client_session_timeout_seconds)
273278
if self.client_session_timeout_seconds
274279
else None,
280+
message_handler=self.message_handler,
275281
)
276282
)
277283
server_result = await session.initialize()
@@ -394,6 +400,7 @@ def __init__(
394400
use_structured_content: bool = False,
395401
max_retry_attempts: int = 0,
396402
retry_backoff_seconds_base: float = 1.0,
403+
message_handler: MessageHandlerFnT | None = None,
397404
):
398405
"""Create a new MCP server based on the stdio transport.
399406
@@ -421,6 +428,8 @@ def __init__(
421428
Defaults to no retries.
422429
retry_backoff_seconds_base: The base delay, in seconds, for exponential
423430
backoff between retries.
431+
message_handler: Optional handler invoked for session messages as delivered by the
432+
ClientSession.
424433
"""
425434
super().__init__(
426435
cache_tools_list,
@@ -429,6 +438,7 @@ def __init__(
429438
use_structured_content,
430439
max_retry_attempts,
431440
retry_backoff_seconds_base,
441+
message_handler=message_handler,
432442
)
433443

434444
self.params = StdioServerParameters(
@@ -492,6 +502,7 @@ def __init__(
492502
use_structured_content: bool = False,
493503
max_retry_attempts: int = 0,
494504
retry_backoff_seconds_base: float = 1.0,
505+
message_handler: MessageHandlerFnT | None = None,
495506
):
496507
"""Create a new MCP server based on the HTTP with SSE transport.
497508
@@ -521,6 +532,8 @@ def __init__(
521532
Defaults to no retries.
522533
retry_backoff_seconds_base: The base delay, in seconds, for exponential
523534
backoff between retries.
535+
message_handler: Optional handler invoked for session messages as delivered by the
536+
ClientSession.
524537
"""
525538
super().__init__(
526539
cache_tools_list,
@@ -529,6 +542,7 @@ def __init__(
529542
use_structured_content,
530543
max_retry_attempts,
531544
retry_backoff_seconds_base,
545+
message_handler=message_handler,
532546
)
533547

534548
self.params = params
@@ -595,6 +609,7 @@ def __init__(
595609
use_structured_content: bool = False,
596610
max_retry_attempts: int = 0,
597611
retry_backoff_seconds_base: float = 1.0,
612+
message_handler: MessageHandlerFnT | None = None,
598613
):
599614
"""Create a new MCP server based on the Streamable HTTP transport.
600615
@@ -625,6 +640,8 @@ def __init__(
625640
Defaults to no retries.
626641
retry_backoff_seconds_base: The base delay, in seconds, for exponential
627642
backoff between retries.
643+
message_handler: Optional handler invoked for session messages as delivered by the
644+
ClientSession.
628645
"""
629646
super().__init__(
630647
cache_tools_list,
@@ -633,6 +650,7 @@ def __init__(
633650
use_structured_content,
634651
max_retry_attempts,
635652
retry_backoff_seconds_base,
653+
message_handler=message_handler,
636654
)
637655

638656
self.params = params

tests/mcp/test_message_handler.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import contextlib
2+
3+
import anyio
4+
import pytest
5+
from mcp.client.session import MessageHandlerFnT
6+
from mcp.shared.message import SessionMessage
7+
from mcp.shared.session import RequestResponder
8+
from mcp.types import (
9+
ClientResult,
10+
Implementation,
11+
InitializeResult,
12+
ServerCapabilities,
13+
ServerNotification,
14+
ServerRequest,
15+
)
16+
17+
from agents.mcp.server import (
18+
MCPServerSse,
19+
MCPServerStdio,
20+
MCPServerStreamableHttp,
21+
_MCPServerWithClientSession,
22+
)
23+
24+
HandlerMessage = (
25+
RequestResponder[ServerRequest, ClientResult]
26+
| ServerNotification
27+
| Exception
28+
)
29+
30+
31+
class _StubClientSession:
32+
"""Stub ClientSession that records the configured message handler."""
33+
34+
def __init__(
35+
self,
36+
read_stream,
37+
write_stream,
38+
read_timeout_seconds,
39+
*,
40+
message_handler=None,
41+
**_: object,
42+
) -> None:
43+
self.message_handler = message_handler
44+
45+
async def __aenter__(self):
46+
return self
47+
48+
async def __aexit__(self, exc_type, exc, tb):
49+
return False
50+
51+
async def initialize(self) -> InitializeResult:
52+
capabilities = ServerCapabilities.model_construct()
53+
server_info = Implementation.model_construct(name="stub", version="1.0")
54+
return InitializeResult(
55+
protocolVersion="2024-11-05",
56+
capabilities=capabilities,
57+
serverInfo=server_info,
58+
)
59+
60+
61+
class _MessageHandlerTestServer(_MCPServerWithClientSession):
62+
def __init__(self, handler: MessageHandlerFnT | None):
63+
super().__init__(
64+
cache_tools_list=False,
65+
client_session_timeout_seconds=None,
66+
message_handler=handler,
67+
)
68+
69+
def create_streams(self):
70+
@contextlib.asynccontextmanager
71+
async def _streams():
72+
send_stream, recv_stream = (
73+
anyio.create_memory_object_stream[SessionMessage | Exception](1))
74+
try:
75+
yield recv_stream, send_stream, None
76+
finally:
77+
await recv_stream.aclose()
78+
await send_stream.aclose()
79+
80+
return _streams()
81+
82+
@property
83+
def name(self) -> str:
84+
return "test-server"
85+
86+
87+
@pytest.mark.asyncio
88+
async def test_client_session_receives_message_handler(monkeypatch):
89+
captured: dict[str, object] = {}
90+
91+
def _recording_client_session(*args, **kwargs):
92+
session = _StubClientSession(*args, **kwargs)
93+
captured["message_handler"] = session.message_handler
94+
return session
95+
96+
monkeypatch.setattr("agents.mcp.server.ClientSession", _recording_client_session)
97+
98+
class _AsyncHandler:
99+
async def __call__(self, message: HandlerMessage) -> None:
100+
del message
101+
102+
handler: MessageHandlerFnT = _AsyncHandler()
103+
104+
server = _MessageHandlerTestServer(handler)
105+
106+
try:
107+
await server.connect()
108+
finally:
109+
await server.cleanup()
110+
111+
assert captured["message_handler"] is handler
112+
113+
114+
@pytest.mark.parametrize(
115+
"server_cls, params",
116+
[
117+
(MCPServerSse, {"url": "https://example.com"}),
118+
(MCPServerStreamableHttp, {"url": "https://example.com"}),
119+
(MCPServerStdio, {"command": "python"}),
120+
],
121+
)
122+
def test_message_handler_propagates_to_server_base(server_cls, params):
123+
class _AsyncHandler:
124+
async def __call__(self, message: HandlerMessage) -> None:
125+
del message
126+
127+
handler: MessageHandlerFnT = _AsyncHandler()
128+
129+
server = server_cls(params, message_handler=handler)
130+
131+
assert server.message_handler is handler

0 commit comments

Comments
 (0)