Commit e90bb7e
fix: multiple tool calls in remote-vllm chat_completion
This fixes an issue in how we used the tool_call_buf from streaming tool calls in the remote-vllm provider, where it would end up concatenating parameters from multiple different tool call results instead of aggregating the results from each tool call separately. It also fixes an issue found while digging into that where we were accidentally mixing the JSON string form of tool call parameters with the string representation of the Python form, which meant we'd end up with single quotes in what should be double-quoted JSON strings.

The following tests are now passing 100% for the remote-vllm provider, where some of the test_text_inference tests were failing before this change:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/inference/test_text_inference.py --text-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"

VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/inference/test_vision_inference.py --vision-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"
```

Many of the agent tests are passing, although some are failing due to bugs in vLLM's pythonic tool parser for Llama models. See the PR at vllm-project/vllm#17917 and a gist at https://gist.github.com/bbrowning/b5007709015cb2aabd85e0bd08e6d60f for the changes needed there, which will have to be made upstream in vLLM.

Agent tests:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/agents/test_agents.py --text-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"
```

Signed-off-by: Ben Browning <bbrownin@redhat.com>
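For illustration only, here is a minimal standalone sketch of the buffering strategy this change adopts: streamed tool-call fragments are accumulated per `index` rather than into one shared buffer, and argument fragments are kept in JSON string form. The `ToolCallBuf` and `aggregate_tool_call_deltas` names are hypothetical stand-ins, not the provider's actual types:

```python
import json
from dataclasses import dataclass


@dataclass
class ToolCallBuf:
    # Hypothetical stand-in for the provider's per-call buffer.
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""


def aggregate_tool_call_deltas(deltas: list[dict]) -> dict[int, ToolCallBuf]:
    """Group streamed tool-call fragments by their `index` field."""
    bufs: dict[int, ToolCallBuf] = {}
    for delta in deltas:
        buf = bufs.setdefault(delta["index"], ToolCallBuf())
        buf.call_id += delta.get("id") or ""
        buf.tool_name += delta.get("name") or ""
        args = delta.get("arguments") or ""
        # Keep arguments as JSON text; dumping dicts avoids the single-quoted
        # Python repr that str() would produce.
        buf.arguments += args if isinstance(args, str) else json.dumps(args)
    return bufs


# Two interleaved tool calls stay separate instead of being concatenated:
deltas = [
    {"index": 1, "id": "tc_1", "name": "power", "arguments": '{"number": 28, '},
    {"index": 2, "id": "tc_2", "name": "multiple", "arguments": '{"first_number": 4, '},
    {"index": 1, "arguments": '"power": 3}'},
    {"index": 2, "arguments": '"second_number": 7}'},
]
for idx, buf in sorted(aggregate_tool_call_deltas(deltas).items()):
    print(idx, buf.tool_name, json.loads(buf.arguments))
```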
1 parent 43d4447 commit e90bb7e

File tree

3 files changed: +220 -39 lines changed


llama_stack/providers/remote/inference/vllm/vllm.py

Lines changed: 47 additions & 35 deletions
```diff
@@ -161,51 +161,63 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
 async def _process_vllm_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
 ) -> AsyncGenerator:
-    event_type = ChatCompletionResponseEventType.start
-    tool_call_buf = UnparseableToolCall()
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+        )
+    )
+    event_type = ChatCompletionResponseEventType.progress
+    tool_call_bufs: dict[str, UnparseableToolCall] = {}
     async for chunk in stream:
         if not chunk.choices:
             log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
             continue
         choice = chunk.choices[0]
         if choice.delta.tool_calls:
-            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += str(tool_call.tool_name)
-            tool_call_buf.call_id += tool_call.call_id
-            # TODO: remove str() when dict type for 'arguments' is no longer allowed
-            tool_call_buf.arguments += str(tool_call.arguments)
+            for delta_tool_call in choice.delta.tool_calls:
+                tool_call = convert_tool_call(delta_tool_call)
+                if delta_tool_call.index not in tool_call_bufs:
+                    tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
+                tool_call_buf = tool_call_bufs[delta_tool_call.index]
+                tool_call_buf.tool_name += str(tool_call.tool_name)
+                tool_call_buf.call_id += tool_call.call_id
+                tool_call_buf.arguments += (
+                    tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
+                )
         if choice.finish_reason:
-            args_str = tool_call_buf.arguments
-            args = None
-            try:
-                args = {} if not args_str else json.loads(args_str)
-            except Exception as e:
-                log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
-            if args:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=event_type,
-                        delta=ToolCallDelta(
-                            tool_call=ToolCall(
-                                call_id=tool_call_buf.call_id,
-                                tool_name=tool_call_buf.tool_name,
-                                arguments=args,
-                                arguments_json=args_str,
+            for _index, tool_call_buf in sorted(tool_call_bufs.items()):
+                args_str = tool_call_buf.arguments
+                args = None
+                try:
+                    args = {} if not args_str else json.loads(args_str)
+                except Exception as e:
+                    log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
+                if args:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=event_type,
+                            delta=ToolCallDelta(
+                                tool_call=ToolCall(
+                                    call_id=tool_call_buf.call_id,
+                                    tool_name=tool_call_buf.tool_name,
+                                    arguments=args,
+                                    arguments_json=args_str,
+                                ),
+                                parse_status=ToolCallParseStatus.succeeded,
                             ),
-                            parse_status=ToolCallParseStatus.succeeded,
-                        ),
+                        )
                     )
-                )
-            elif args_str:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            tool_call=str(tool_call_buf),
-                            parse_status=ToolCallParseStatus.failed,
-                        ),
+                elif args_str:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                tool_call=str(tool_call_buf),
+                                parse_status=ToolCallParseStatus.failed,
+                            ),
+                        )
                    )
-                )
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.complete,
```
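The `isinstance`/`json.dumps` branch above matters because calling `str()` on a dict of arguments produces the single-quoted Python repr rather than JSON, which then cannot be parsed back with `json.loads`. A quick standalone illustration, not part of the provider code:

```python
import json

args = {"number": 28, "power": 3}

print(str(args))         # {'number': 28, 'power': 3}  <- Python repr, not valid JSON
print(json.dumps(args))  # {"number": 28, "power": 3}  <- double-quoted JSON

# Only the JSON form round-trips:
json.loads(json.dumps(args))  # ok
# json.loads(str(args))       # would raise json.JSONDecodeError
```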

llama_stack/providers/utils/inference/openai_compat.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -531,13 +531,19 @@ async def _convert_content(content) -> dict:
             tool_name = tc.tool_name
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value
+
+            # arguments_json can be None, so attempt it first and fall back to arguments
+            if hasattr(tc, "arguments_json") and tc.arguments_json:
+                arguments = tc.arguments_json
+            else:
+                arguments = json.dumps(tc.arguments)
             result["tool_calls"].append(
                 {
                     "id": tc.call_id,
                     "type": "function",
                     "function": {
                         "name": tool_name,
-                        "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
+                        "arguments": arguments,
                     },
                 }
             )
```
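For context, a small standalone sketch of the fallback pattern used above: prefer the pre-serialized `arguments_json` string when it is present and non-empty, otherwise serialize the `arguments` dict. The `SimpleNamespace` objects are stand-ins for the real tool call type, used here only for illustration:

```python
import json
from types import SimpleNamespace


def tool_call_arguments(tc) -> str:
    """Prefer arguments_json; fall back to serializing the arguments dict."""
    if getattr(tc, "arguments_json", None):
        return tc.arguments_json
    return json.dumps(tc.arguments)


print(tool_call_arguments(SimpleNamespace(arguments={"x": 1}, arguments_json='{"x": 1}')))  # {"x": 1}
print(tool_call_arguments(SimpleNamespace(arguments={"x": 1}, arguments_json=None)))        # {"x": 1}
```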

tests/unit/providers/inference/test_remote_vllm.py

Lines changed: 166 additions & 3 deletions
```diff
@@ -24,6 +24,12 @@
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta as OpenAIChoiceDelta,
 )
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
 from openai.types.model import Model as OpenAIModel
 
 from llama_stack.apis.inference import (
@@ -206,8 +212,164 @@ async def mock_stream():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 1
-    assert chunks[0].event.stop_reason == StopReason.end_of_turn
+    assert len(chunks) == 2
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "complete"
+    assert chunks[1].event.stop_reason == StopReason.end_of_turn
+
+
+@pytest.mark.asyncio
+async def test_tool_call_delta_streaming_arguments_dict():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments="",
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 3
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "complete"
+
+
+@pytest.mark.asyncio
+async def test_multiple_tool_calls():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=2,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="multiple",
+                                    arguments='{"first_number": 4, "second_number": 7}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 4
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "progress"
+    assert chunks[2].event.delta.type == "tool_call"
+    assert chunks[2].event.delta.parse_status.value == "succeeded"
+    assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
+    assert chunks[3].event.event_type.value == "complete"
 
 
 @pytest.mark.asyncio
@@ -231,7 +393,8 @@ async def mock_stream():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 0
+    assert len(chunks) == 1
+    assert chunks[0].event.event_type.value == "start"
 
 
 def test_chat_completion_doesnt_block_event_loop(caplog):
```
