Skip to content

Commit 69b7e04

Browse files
Alec Solderalecsolder
authored andcommitted
Adding one unit test which works locally, fixing type to name translation issues
Signed-off-by: Alec Solder <alecs@fb.com>
1 parent e6ad3ae commit 69b7e04

File tree

5 files changed

+69
-25
lines changed

5 files changed

+69
-25
lines changed

tests/entrypoints/openai/test_response_api_with_harmony.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,24 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
373373
assert response.status == "completed"
374374

375375

376+
@pytest.mark.asyncio
377+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
378+
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
379+
async def test_mcp_tool(client: OpenAI, model_name: str):
380+
response = await client.responses.create(
381+
model=model_name,
382+
input="Please multiply 123 and 456 using the available python tool.",
383+
tools=[{
384+
"type": "mcp",
385+
"server_label": "code_interpreter",
386+
# URL unused for DemoToolServer
387+
"server_url": "http://localhost:8888"
388+
}],
389+
)
390+
assert response is not None
391+
assert response.status == "completed"
392+
393+
376394
def get_weather(latitude, longitude):
377395
response = requests.get(
378396
f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m" # noqa

vllm/entrypoints/context.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from contextlib import AsyncExitStack
99
from typing import TYPE_CHECKING, Optional, Union
1010

11+
from openai.types.responses.tool import Mcp
1112
from openai_harmony import Author, Message, Role, StreamState, TextContent
1213

1314
from vllm.entrypoints.harmony_utils import (
@@ -22,6 +23,21 @@
2223
logger = logging.getLogger(__name__)
2324

2425

26+
# This is currently needed as the tool type doesn't 1:1 match the
27+
# tool namespace, which is what is used to look up the
28+
# connection to the tool server
29+
def _map_tool_name_to_tool_type(tool_name: str) -> str:
30+
if tool_name == "browser":
31+
return "web_search_preview"
32+
elif tool_name == "python":
33+
return "code_interpreter"
34+
elif tool_name == "container":
35+
return "container"
36+
else:
37+
raise ValueError(
38+
f"Built in tool name not defined in mapping: {tool_name}")
39+
40+
2541
class TurnTokens:
2642
"""Tracks token counts for a single conversation turn."""
2743

@@ -59,8 +75,8 @@ def render_for_completion(self) -> list[int]:
5975

6076
@abstractmethod
6177
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
62-
exit_stack: AsyncExitStack,
63-
request_id: str) -> None:
78+
exit_stack: AsyncExitStack, request_id: str,
79+
mcp_tools: dict[str, Mcp]) -> None:
6480
pass
6581

6682
@abstractmethod
@@ -96,8 +112,8 @@ def render_for_completion(self) -> list[int]:
96112
raise NotImplementedError("Should not be called.")
97113

98114
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
99-
exit_stack: AsyncExitStack,
100-
request_id: str) -> None:
115+
exit_stack: AsyncExitStack, request_id: str,
116+
mcp_tools: dict[str, Mcp]) -> None:
101117
pass
102118

103119
async def cleanup_session(self) -> None:
@@ -310,20 +326,18 @@ async def call_python_tool(self, tool_session: Union["ClientSession",
310326
recipient=Role.ASSISTANT)
311327
]
312328

313-
async def init_tool_sessions(
314-
self,
315-
tool_server: Optional[ToolServer],
316-
exit_stack: AsyncExitStack,
317-
request_id: str,
318-
mcp_tool_headers: dict[str, dict[str, str]] = None) -> None:
329+
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
330+
exit_stack: AsyncExitStack, request_id: str,
331+
mcp_tools: dict[str, Mcp]):
319332
if tool_server:
320333
for tool_name in self.available_tools:
321334
if tool_name not in self._tool_sessions:
335+
tool_type = _map_tool_name_to_tool_type(tool_name)
336+
headers = mcp_tools[
337+
tool_type].headers if tool_type in mcp_tools else None
322338
tool_session = await exit_stack.enter_async_context(
323-
tool_server.new_session(
324-
tool_name, request_id,
325-
mcp_tool_headers.get(tool_name)
326-
if tool_name in mcp_tool_headers else {}))
339+
tool_server.new_session(tool_name, request_id,
340+
headers))
327341
logger.info("Created new session for %s", tool_name)
328342
self._tool_sessions[tool_name] = tool_session
329343
exit_stack.push_async_exit(self.cleanup_session)

vllm/entrypoints/harmony_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,10 @@ def get_developer_message(
126126
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
127127
for tool in tools:
128128
if tool.type in ("web_search_preview", "code_interpreter",
129-
"container"):
129+
"container", "mcp"):
130130
# These are built-in tools that are added to the system message.
131+
# Adding in MCP for now until we support MCP tools executed
132+
# server side
131133
pass
132134

133135
elif tool.type == "function":

vllm/entrypoints/openai/serving_responses.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,12 @@ async def responses_full_generator(
451451

452452
async with AsyncExitStack() as exit_stack:
453453
try:
454+
mcp_tools = {
455+
tool.server_label: tool
456+
for tool in request.tools if tool.type == "mcp"
457+
}
454458
await context.init_tool_sessions(self.tool_server, exit_stack,
455-
request.request_id)
459+
request.request_id, mcp_tools)
456460
async for _ in result_generator:
457461
pass
458462
except asyncio.CancelledError:
@@ -1635,8 +1639,12 @@ def _send_event(event: BaseModel):
16351639
async with AsyncExitStack() as exit_stack:
16361640
processer = None
16371641
if self.use_harmony:
1642+
mcp_tools = {
1643+
tool.server_label: tool
1644+
for tool in request.tools if tool.type == "mcp"
1645+
}
16381646
await context.init_tool_sessions(self.tool_server, exit_stack,
1639-
request.request_id)
1647+
request.request_id, mcp_tools)
16401648
processer = self._process_harmony_streaming_events
16411649
else:
16421650
processer = self._process_simple_streaming_events

vllm/entrypoints/tool_server.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
async def list_server_and_tools(server_url: str):
1919
from mcp import ClientSession
2020
from mcp.client.sse import sse_client
21-
2221
async with sse_client(url=server_url) as streams, ClientSession(
2322
*streams) as session:
2423
initialize_response = await session.initialize()
@@ -155,14 +154,14 @@ async def new_session(self,
155154
from mcp import ClientSession
156155
from mcp.client.sse import sse_client
157156
url = self.urls.get(tool_name)
158-
if headers is None:
159-
headers = {}
160-
headers = {"x-session-id": session_id, **headers}
157+
request_headers = {"x-session-id": session_id}
158+
if headers is not None:
159+
request_headers.update(headers)
161160
if not url:
162161
raise KeyError(f"Tool '{tool_name}' is not supported")
163-
async with sse_client(url=url,
164-
headers=headers) as streams, ClientSession(
165-
*streams) as session:
162+
async with sse_client(
163+
url=url, headers=request_headers) as streams, ClientSession(
164+
*streams) as session:
166165
await session.initialize()
167166
yield session
168167

@@ -198,7 +197,10 @@ def get_tool_description(self,
198197
raise ValueError(f"Unknown tool {tool_name}")
199198

200199
@asynccontextmanager
201-
async def new_session(self, tool_name: str, session_id: str):
200+
async def new_session(self,
201+
tool_name: str,
202+
session_id: str,
203+
headers: Optional[dict[str, str]] = None):
202204
if tool_name not in self.tools:
203205
raise KeyError(f"Tool '{tool_name}' is not supported")
204206
yield self.tools[tool_name]

0 commit comments

Comments
 (0)