Skip to content

Commit 96826fa

Browse files
committed
Cleanup some code in test_mcp.py
1 parent bfb5702 commit 96826fa

File tree

1 file changed

+66
-82
lines changed

1 file changed

+66
-82
lines changed

tests/integration/mcp/test_mcp.py

Lines changed: 66 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ def _get_streamable_server_base_url(transport: str) -> str:
4848
return f"http://{host_info['host']}:{host_info['port']}"
4949

5050

51+
def _get_server_base_url(transport: str) -> str:
52+
if transport == "sse":
53+
return _get_mcp_sse_server_base_url()
54+
elif transport.startswith("streamable-"):
55+
return _get_streamable_server_base_url(transport)
56+
else:
57+
raise ValueError(f"Unknown transport: {transport}")
58+
59+
5160
def _get_headers(
5261
server_base_url: str, project_name: str, push_to_explorer: bool = True
5362
) -> dict[str, str]:
@@ -58,6 +67,35 @@ def _get_headers(
5867
}
5968

6069

70+
async def _invoke_mcp_tool(
71+
transport, gateway_url, project_name, tool_name, tool_args, whl=None, push=True
72+
):
73+
if transport == "stdio":
74+
return await mcp_stdio_client_run(
75+
whl,
76+
project_name,
77+
"resources/mcp/stdio/messenger_server/main.py",
78+
push,
79+
tool_name,
80+
tool_args,
81+
)
82+
elif transport == "sse":
83+
return await mcp_sse_client_run(
84+
f"{gateway_url}/api/v1/gateway/mcp/sse",
85+
push,
86+
tool_name,
87+
tool_args,
88+
headers=_get_headers(_get_server_base_url(transport), project_name, push),
89+
)
90+
return await mcp_streamable_client_run(
91+
f"{gateway_url}/api/v1/gateway/mcp/streamable",
92+
push,
93+
tool_name,
94+
tool_args,
95+
headers=_get_headers(_get_server_base_url(transport), project_name, push),
96+
)
97+
98+
6199
@pytest.mark.asyncio
62100
@pytest.mark.timeout(30)
63101
@pytest.mark.parametrize(
@@ -81,34 +119,15 @@ async def test_mcp_with_gateway(
81119
project_name = "test-mcp-" + str(uuid.uuid4())
82120

83121
# Run the MCP client and make the tool call.
84-
if transport == "sse":
85-
result = await mcp_sse_client_run(
86-
gateway_url + "/api/v1/gateway/mcp/sse",
87-
push_to_explorer=True,
88-
tool_name="get_last_message_from_user",
89-
tool_args={"username": "Alice"},
90-
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
91-
)
92-
elif transport == "stdio":
93-
result = await mcp_stdio_client_run(
94-
invariant_gateway_package_whl_file,
95-
project_name,
96-
server_script_path="resources/mcp/stdio/messenger_server/main.py",
97-
push_to_explorer=True,
98-
tool_name="get_last_message_from_user",
99-
tool_args={"username": "Alice"},
100-
metadata_keys={"my-custom-key": "value1", "my-custom-key-2": "value2"},
101-
)
102-
else:
103-
result = await mcp_streamable_client_run(
104-
gateway_url + "/api/v1/gateway/mcp/streamable",
105-
push_to_explorer=True,
106-
tool_name="get_last_message_from_user",
107-
tool_args={"username": "Alice"},
108-
headers=_get_headers(
109-
_get_streamable_server_base_url(transport), project_name, True
110-
),
111-
)
122+
result = await _invoke_mcp_tool(
123+
transport,
124+
gateway_url,
125+
project_name,
126+
tool_name="get_last_message_from_user",
127+
tool_args={"username": "Alice"},
128+
whl=invariant_gateway_package_whl_file,
129+
push=True,
130+
)
112131

113132
assert result.isError is False
114133
assert (
@@ -203,33 +222,15 @@ async def test_mcp_with_gateway_and_logging_guardrails(
203222
)
204223

205224
# Run the MCP client and make the tool call.
206-
if transport == "sse":
207-
result = await mcp_sse_client_run(
208-
gateway_url + "/api/v1/gateway/mcp/sse",
209-
push_to_explorer=True,
210-
tool_name="get_last_message_from_user",
211-
tool_args={"username": "Alice"},
212-
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
213-
)
214-
elif transport == "stdio":
215-
result = await mcp_stdio_client_run(
216-
invariant_gateway_package_whl_file,
217-
project_name,
218-
server_script_path="resources/mcp/stdio/messenger_server/main.py",
219-
push_to_explorer=True,
220-
tool_name="get_last_message_from_user",
221-
tool_args={"username": "Alice"},
222-
)
223-
else:
224-
result = await mcp_streamable_client_run(
225-
gateway_url + "/api/v1/gateway/mcp/streamable",
226-
push_to_explorer=True,
227-
tool_name="get_last_message_from_user",
228-
tool_args={"username": "Alice"},
229-
headers=_get_headers(
230-
_get_streamable_server_base_url(transport), project_name, True
231-
),
232-
)
225+
result = await _invoke_mcp_tool(
226+
transport,
227+
gateway_url,
228+
project_name,
229+
tool_name="get_last_message_from_user",
230+
tool_args={"username": "Alice"},
231+
whl=invariant_gateway_package_whl_file,
232+
push=True,
233+
)
233234

234235
assert result.isError is False
235236
assert (
@@ -658,33 +659,16 @@ async def test_mcp_tool_list_blocking(
658659
return
659660

660661
# Run the MCP client and make the tools/list call.
661-
if transport == "sse":
662-
tools_result = await mcp_sse_client_run(
663-
gateway_url + "/api/v1/gateway/mcp/sse",
664-
push_to_explorer=True,
665-
tool_name="tools/list",
666-
tool_args={},
667-
headers=_get_headers(_get_mcp_sse_server_base_url(), project_name, True),
668-
)
669-
elif transport == "stdio":
670-
tools_result = await mcp_stdio_client_run(
671-
invariant_gateway_package_whl_file,
672-
project_name,
673-
server_script_path="resources/mcp/stdio/messenger_server/main.py",
674-
push_to_explorer=True,
675-
tool_name="tools/list",
676-
tool_args={},
677-
)
678-
else:
679-
tools_result = await mcp_streamable_client_run(
680-
gateway_url + "/api/v1/gateway/mcp/streamable",
681-
push_to_explorer=True,
682-
tool_name="tools/list",
683-
tool_args={},
684-
headers=_get_headers(
685-
_get_streamable_server_base_url(transport), project_name, True
686-
),
687-
)
662+
# Run the MCP client and make the tool call.
663+
tools_result = await _invoke_mcp_tool(
664+
transport,
665+
gateway_url,
666+
project_name,
667+
tool_name="tools/list",
668+
tool_args={},
669+
whl=invariant_gateway_package_whl_file,
670+
push=True,
671+
)
688672
assert "blocked_get_last_message_from_user" in str(tools_result), (
689673
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
690674
+ str(tools_result)

0 commit comments

Comments
 (0)