@@ -48,6 +48,15 @@ def _get_streamable_server_base_url(transport: str) -> str:
48
48
return f"http://{ host_info ['host' ]} :{ host_info ['port' ]} "
49
49
50
50
51
+ def _get_server_base_url (transport : str ) -> str :
52
+ if transport == "sse" :
53
+ return _get_mcp_sse_server_base_url ()
54
+ elif transport .startswith ("streamable-" ):
55
+ return _get_streamable_server_base_url (transport )
56
+ else :
57
+ raise ValueError (f"Unknown transport: { transport } " )
58
+
59
+
51
60
def _get_headers (
52
61
server_base_url : str , project_name : str , push_to_explorer : bool = True
53
62
) -> dict [str , str ]:
@@ -58,6 +67,35 @@ def _get_headers(
58
67
}
59
68
60
69
70
+ async def _invoke_mcp_tool (
71
+ transport , gateway_url , project_name , tool_name , tool_args , whl = None , push = True
72
+ ):
73
+ if transport == "stdio" :
74
+ return await mcp_stdio_client_run (
75
+ whl ,
76
+ project_name ,
77
+ "resources/mcp/stdio/messenger_server/main.py" ,
78
+ push ,
79
+ tool_name ,
80
+ tool_args ,
81
+ )
82
+ elif transport == "sse" :
83
+ return await mcp_sse_client_run (
84
+ f"{ gateway_url } /api/v1/gateway/mcp/sse" ,
85
+ push ,
86
+ tool_name ,
87
+ tool_args ,
88
+ headers = _get_headers (_get_server_base_url (transport ), project_name , push ),
89
+ )
90
+ return await mcp_streamable_client_run (
91
+ f"{ gateway_url } /api/v1/gateway/mcp/streamable" ,
92
+ push ,
93
+ tool_name ,
94
+ tool_args ,
95
+ headers = _get_headers (_get_server_base_url (transport ), project_name , push ),
96
+ )
97
+
98
+
61
99
@pytest .mark .asyncio
62
100
@pytest .mark .timeout (30 )
63
101
@pytest .mark .parametrize (
@@ -81,34 +119,15 @@ async def test_mcp_with_gateway(
81
119
project_name = "test-mcp-" + str (uuid .uuid4 ())
82
120
83
121
# Run the MCP client and make the tool call.
84
- if transport == "sse" :
85
- result = await mcp_sse_client_run (
86
- gateway_url + "/api/v1/gateway/mcp/sse" ,
87
- push_to_explorer = True ,
88
- tool_name = "get_last_message_from_user" ,
89
- tool_args = {"username" : "Alice" },
90
- headers = _get_headers (_get_mcp_sse_server_base_url (), project_name , True ),
91
- )
92
- elif transport == "stdio" :
93
- result = await mcp_stdio_client_run (
94
- invariant_gateway_package_whl_file ,
95
- project_name ,
96
- server_script_path = "resources/mcp/stdio/messenger_server/main.py" ,
97
- push_to_explorer = True ,
98
- tool_name = "get_last_message_from_user" ,
99
- tool_args = {"username" : "Alice" },
100
- metadata_keys = {"my-custom-key" : "value1" , "my-custom-key-2" : "value2" },
101
- )
102
- else :
103
- result = await mcp_streamable_client_run (
104
- gateway_url + "/api/v1/gateway/mcp/streamable" ,
105
- push_to_explorer = True ,
106
- tool_name = "get_last_message_from_user" ,
107
- tool_args = {"username" : "Alice" },
108
- headers = _get_headers (
109
- _get_streamable_server_base_url (transport ), project_name , True
110
- ),
111
- )
122
+ result = await _invoke_mcp_tool (
123
+ transport ,
124
+ gateway_url ,
125
+ project_name ,
126
+ tool_name = "get_last_message_from_user" ,
127
+ tool_args = {"username" : "Alice" },
128
+ whl = invariant_gateway_package_whl_file ,
129
+ push = True ,
130
+ )
112
131
113
132
assert result .isError is False
114
133
assert (
@@ -203,33 +222,15 @@ async def test_mcp_with_gateway_and_logging_guardrails(
203
222
)
204
223
205
224
# Run the MCP client and make the tool call.
206
- if transport == "sse" :
207
- result = await mcp_sse_client_run (
208
- gateway_url + "/api/v1/gateway/mcp/sse" ,
209
- push_to_explorer = True ,
210
- tool_name = "get_last_message_from_user" ,
211
- tool_args = {"username" : "Alice" },
212
- headers = _get_headers (_get_mcp_sse_server_base_url (), project_name , True ),
213
- )
214
- elif transport == "stdio" :
215
- result = await mcp_stdio_client_run (
216
- invariant_gateway_package_whl_file ,
217
- project_name ,
218
- server_script_path = "resources/mcp/stdio/messenger_server/main.py" ,
219
- push_to_explorer = True ,
220
- tool_name = "get_last_message_from_user" ,
221
- tool_args = {"username" : "Alice" },
222
- )
223
- else :
224
- result = await mcp_streamable_client_run (
225
- gateway_url + "/api/v1/gateway/mcp/streamable" ,
226
- push_to_explorer = True ,
227
- tool_name = "get_last_message_from_user" ,
228
- tool_args = {"username" : "Alice" },
229
- headers = _get_headers (
230
- _get_streamable_server_base_url (transport ), project_name , True
231
- ),
232
- )
225
+ result = await _invoke_mcp_tool (
226
+ transport ,
227
+ gateway_url ,
228
+ project_name ,
229
+ tool_name = "get_last_message_from_user" ,
230
+ tool_args = {"username" : "Alice" },
231
+ whl = invariant_gateway_package_whl_file ,
232
+ push = True ,
233
+ )
233
234
234
235
assert result .isError is False
235
236
assert (
@@ -658,33 +659,16 @@ async def test_mcp_tool_list_blocking(
658
659
return
659
660
660
661
# Run the MCP client and make the tools/list call.
661
- if transport == "sse" :
662
- tools_result = await mcp_sse_client_run (
663
- gateway_url + "/api/v1/gateway/mcp/sse" ,
664
- push_to_explorer = True ,
665
- tool_name = "tools/list" ,
666
- tool_args = {},
667
- headers = _get_headers (_get_mcp_sse_server_base_url (), project_name , True ),
668
- )
669
- elif transport == "stdio" :
670
- tools_result = await mcp_stdio_client_run (
671
- invariant_gateway_package_whl_file ,
672
- project_name ,
673
- server_script_path = "resources/mcp/stdio/messenger_server/main.py" ,
674
- push_to_explorer = True ,
675
- tool_name = "tools/list" ,
676
- tool_args = {},
677
- )
678
- else :
679
- tools_result = await mcp_streamable_client_run (
680
- gateway_url + "/api/v1/gateway/mcp/streamable" ,
681
- push_to_explorer = True ,
682
- tool_name = "tools/list" ,
683
- tool_args = {},
684
- headers = _get_headers (
685
- _get_streamable_server_base_url (transport ), project_name , True
686
- ),
687
- )
662
+ # Run the MCP client and make the tool call.
663
+ tools_result = await _invoke_mcp_tool (
664
+ transport ,
665
+ gateway_url ,
666
+ project_name ,
667
+ tool_name = "tools/list" ,
668
+ tool_args = {},
669
+ whl = invariant_gateway_package_whl_file ,
670
+ push = True ,
671
+ )
688
672
assert "blocked_get_last_message_from_user" in str (tools_result ), (
689
673
"Expected the tool names to be renamed and blocked because of the blocking guardrail on the tools/list call. Instead got: "
690
674
+ str (tools_result )
0 commit comments