Skip to content

Commit ac97195

Browse files
fix: address review feedback and add receive middleware test
- Move 'import inspect' to top of file - Pre-compute whether middleware is async using inspect.iscoroutinefunction() instead of checking on every message - Add test for receive_middleware to fix coverage
1 parent 90e8720 commit ac97195

File tree

2 files changed

+78
-11
lines changed

2 files changed

+78
-11
lines changed

src/mcp/shared/session.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import logging
23
from collections.abc import Awaitable, Callable
34
from contextlib import AsyncExitStack
@@ -209,22 +210,24 @@ def __init__(
209210
self._progress_callbacks = {}
210211
self._response_routers = []
211212
self._exit_stack = AsyncExitStack()
212-
self._send_middleware = send_middleware or []
213-
self._receive_middleware = receive_middleware or []
213+
# Pre-compute whether each middleware is async to avoid checking on every message
214+
self._send_middleware: list[tuple[MessageMiddleware, bool]] = [
215+
(m, inspect.iscoroutinefunction(m)) for m in (send_middleware or [])
216+
]
217+
self._receive_middleware: list[tuple[MessageMiddleware, bool]] = [
218+
(m, inspect.iscoroutinefunction(m)) for m in (receive_middleware or [])
219+
]
214220

215221
async def _apply_middleware(
216-
self, message: JSONRPCMessage, middleware_list: list[MessageMiddleware]
222+
self, message: JSONRPCMessage, middleware_list: list[tuple[MessageMiddleware, bool]]
217223
) -> JSONRPCMessage:
218224
"""Apply a list of middleware functions to a message."""
219-
import inspect
220-
221-
for middleware in middleware_list:
225+
for middleware, is_async in middleware_list:
222226
result = middleware(message)
223-
if inspect.isawaitable(result):
224-
message = await result
225-
else:
226-
message = result # type: ignore[assignment]
227-
return message
227+
if is_async:
228+
result = await result # type: ignore[misc]
229+
message = result # type: ignore[assignment]
230+
return message # type: ignore[return-value]
228231

229232
def add_response_router(self, router: ResponseRouter) -> None:
230233
"""

tests/client/test_session.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,3 +909,67 @@ async def mock_server():
909909
await session.initialize()
910910

911911
assert middleware_called
912+
913+
914+
@pytest.mark.anyio
915+
async def test_client_session_receive_middleware():
916+
"""Test that receive middleware can transform incoming messages."""
917+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
918+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
919+
920+
middleware_called = False
921+
received_response = None
922+
923+
def receive_transform(message: JSONRPCMessage) -> JSONRPCMessage:
924+
"""Middleware that observes incoming messages."""
925+
nonlocal middleware_called, received_response
926+
middleware_called = True
927+
if isinstance(message.root, JSONRPCResponse):
928+
received_response = message.root
929+
return message
930+
931+
async def mock_server():
932+
session_message = await client_to_server_receive.receive()
933+
jsonrpc_request = session_message.message
934+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
935+
936+
result = ServerResult(
937+
InitializeResult(
938+
protocolVersion=LATEST_PROTOCOL_VERSION,
939+
capabilities=ServerCapabilities(),
940+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
941+
)
942+
)
943+
944+
async with server_to_client_send:
945+
await server_to_client_send.send(
946+
SessionMessage(
947+
JSONRPCMessage(
948+
JSONRPCResponse(
949+
jsonrpc="2.0",
950+
id=jsonrpc_request.root.id,
951+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
952+
)
953+
)
954+
)
955+
)
956+
await client_to_server_receive.receive()
957+
958+
async with (
959+
ClientSession(
960+
server_to_client_receive,
961+
client_to_server_send,
962+
receive_middleware=[receive_transform],
963+
) as session,
964+
anyio.create_task_group() as tg,
965+
client_to_server_send,
966+
client_to_server_receive,
967+
server_to_client_send,
968+
server_to_client_receive,
969+
):
970+
tg.start_soon(mock_server)
971+
await session.initialize()
972+
973+
# Verify receive middleware was called and saw the response
974+
assert middleware_called
975+
assert received_response is not None

0 commit comments

Comments
 (0)