|
1 | 1 | import anyio |
2 | 2 | import pytest |
3 | | -from pydantic import AnyUrl |
4 | 3 |
|
5 | 4 | from mcp.server.fastmcp import FastMCP |
6 | | -from mcp.shared.memory import ( |
7 | | - create_connected_server_and_client_session as create_session, |
8 | | -) |
9 | | - |
10 | | -_sleep_time_seconds = 0.01 |
11 | | -_resource_name = "slow://slow_resource" |
| 5 | +from mcp.shared.memory import create_connected_server_and_client_session as create_session |
12 | 6 |
|
13 | 7 |
|
14 | 8 | @pytest.mark.anyio |
15 | 9 | async def test_messages_are_executed_concurrently(): |
16 | 10 | server = FastMCP("test") |
17 | | - call_timestamps = [] |
| 11 | + event = anyio.Event() |
| 12 | + tool_started = anyio.Event() |
| 13 | + call_order = [] |
18 | 14 |
|
19 | 15 | @server.tool("sleep") |
20 | 16 | async def sleep_tool(): |
21 | | - call_timestamps.append(("tool_start_time", anyio.current_time())) |
22 | | - await anyio.sleep(_sleep_time_seconds) |
23 | | - call_timestamps.append(("tool_end_time", anyio.current_time())) |
| 17 | + call_order.append("waiting_for_event") |
| 18 | + tool_started.set() |
| 19 | + await event.wait() |
| 20 | + call_order.append("tool_end") |
24 | 21 | return "done" |
25 | 22 |
|
26 | | - @server.resource(_resource_name) |
27 | | - async def slow_resource(): |
28 | | - call_timestamps.append(("resource_start_time", anyio.current_time())) |
29 | | - await anyio.sleep(_sleep_time_seconds) |
30 | | - call_timestamps.append(("resource_end_time", anyio.current_time())) |
| 23 | + @server.tool("trigger") |
| 24 | + async def trigger(): |
| 25 | + # Wait for tool to start before setting the event |
| 26 | + await tool_started.wait() |
| 27 | + call_order.append("trigger_started") |
| 28 | + event.set() |
| 29 | + call_order.append("trigger_end") |
31 | 30 | return "slow" |
32 | 31 |
|
33 | 32 | async with create_session(server._mcp_server) as client_session: |
| 33 | + # First tool will wait on event, second will set it |
34 | 34 | async with anyio.create_task_group() as tg: |
35 | | - for _ in range(10): |
36 | | - tg.start_soon(client_session.call_tool, "sleep") |
37 | | - tg.start_soon(client_session.read_resource, AnyUrl(_resource_name)) |
38 | | - |
39 | | - active_calls = 0 |
40 | | - max_concurrent_calls = 0 |
41 | | - for call_type, _ in sorted(call_timestamps, key=lambda x: x[1]): |
42 | | - if "start" in call_type: |
43 | | - active_calls += 1 |
44 | | - max_concurrent_calls = max(max_concurrent_calls, active_calls) |
45 | | - else: |
46 | | - active_calls -= 1 |
47 | | - print(f"Max concurrent calls: {max_concurrent_calls}") |
48 | | - assert max_concurrent_calls > 1, "No concurrent calls were executed" |
49 | | - |
50 | | - |
51 | | -def main(): |
52 | | - anyio.run(test_messages_are_executed_concurrently) |
53 | | - |
54 | | - |
55 | | -if __name__ == "__main__": |
56 | | - import logging |
57 | | - |
58 | | - logging.basicConfig(level=logging.DEBUG) |
59 | | - |
60 | | - main() |
| 35 | + # Start the tool first (it will wait on event) |
| 36 | + tg.start_soon(client_session.call_tool, "sleep") |
| 37 | + # Then the trigger tool will set the event to allow the first tool to continue |
| 38 | + await client_session.call_tool("trigger") |
| 39 | + |
| 40 | + # Verify that both ran concurrently |
| 41 | + assert call_order == [ |
| 42 | + "waiting_for_event", |
| 43 | + "trigger_started", |
| 44 | + "trigger_end", |
| 45 | + "tool_end", |
| 46 | + ], f"Expected concurrent execution, but got: {call_order}" |
0 commit comments