
Commit d69d14a

qandrew authored and charlifu committed
[gpt-oss][1][bugfix] fix streaming final output (vllm-project#24466)
Signed-off-by: Andrew Xia <axia@meta.com>
Signed-off-by: charlifu <charlifu@amd.com>
1 parent 36d901b commit d69d14a

File tree

3 files changed: +91 −5 lines changed


tests/entrypoints/openai/test_response_api_with_harmony.py

Lines changed: 2 additions & 0 deletions

@@ -364,6 +364,8 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
         events.append(event)
 
     assert len(events) > 0
+    response_completed_event = events[-1]
+    assert len(response_completed_event.response.output) > 0
 
     if background:
         starting_after = 5
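
The two new assertions pin down the user-visible contract this commit restores: the terminal response.completed event of a streamed Responses API call must carry the generated output, not an empty list. A minimal client-side sketch of that contract, assuming an OpenAI-compatible client pointed at a locally served gpt-oss model (base URL, API key, and model name are placeholders, not taken from this commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

events = []
stream = client.responses.create(
    model="gpt-oss-20b",  # placeholder model name
    input="Say hello.",
    stream=True,
)
for event in stream:
    events.append(event)

# Before the fix, streaming could end with an empty response.output;
# afterwards the completed response carries the assistant's messages.
final = events[-1]
assert final.type == "response.completed"
assert len(final.response.output) > 0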

tests/entrypoints/test_context.py

Lines changed: 79 additions & 4 deletions

@@ -4,7 +4,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from openai_harmony import StreamState
+from openai_harmony import Author, Message, Role, StreamState, TextContent
 
 from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -312,9 +312,9 @@ async def test_negative_tool_tokens_edge_case():
 @pytest.mark.asyncio
 async def test_streaming_multi_turn_token_counting(mock_parser):
     """Test token counting for streaming multi-turn conversations.
-
-    This test focuses on how StreamingHarmonyContext counts tokens in a
-    multi-turn conversation with streaming (token-by-token) outputs and
+
+    This test focuses on how StreamingHarmonyContext counts tokens in a
+    multi-turn conversation with streaming (token-by-token) outputs and
     message boundaries.
     """
     # Create a streaming context
@@ -423,3 +423,78 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
     additional_tool_tokens = 13 - 8 - 3  # = 2
     assert context.num_tool_output_tokens == expected_tool_tokens \
         + additional_tool_tokens
+
+
+@pytest.mark.asyncio
+async def test_streaming_message_synchronization(mock_parser):
+    """Test message synchronization logic from lines 413-417 in context.py.
+
+    This test verifies that when parser.messages contains more messages than
+    the context's _messages (minus initial messages), the context properly
+    extends its message list with the new parser messages.
+    """
+    # Create a streaming context with some initial messages
+    initial_messages = [
+        Message(
+            author=Author(role=Role.USER, name="user"),
+            content=[TextContent(text="Hello")],
+            recipient=Role.ASSISTANT,
+        )
+    ]
+    context = StreamingHarmonyContext(messages=initial_messages,
+                                      available_tools=[])
+
+    # Verify initial state
+    assert len(context._messages) == 1
+    assert context.num_init_messages == 1
+
+    # Mock parser to have more messages than context
+    # Simulate parser having processed 3 new messages
+    mock_parser.messages = [
+        Message(
+            author=Author(role=Role.ASSISTANT, name="assistant"),
+            content=[TextContent(text="Response 1")],
+            recipient=Role.USER,
+        ),
+    ]
+
+    # This should trigger the message synchronization logic
+    context.append_output(
+        create_mock_request_output(prompt_token_ids=[1, 2, 3],
+                                   output_token_ids=[101],
+                                   finished=False))
+
+    # Verify that messages were synchronized
+    assert len(context._messages) == 2
+
+    # Verify the new messages were added correctly
+    assert context._messages[1].content[0].text == "Response 1"
+
+    # Test the specific condition from line 413-414:
+    # len(self._messages) - self.num_init_messages < len(self.parser.messages)
+    messages_minus_init = len(context._messages) - context.num_init_messages
+    parser_messages_count = len(mock_parser.messages)
+
+    # After synchronization, they should be equal (no longer less than)
+    assert messages_minus_init == parser_messages_count
+
+    # Test edge case: add one more parser message
+    mock_parser.messages.append(
+        Message(
+            author=Author(role=Role.ASSISTANT, name="assistant"),
+            content=[TextContent(text="Response 4")],
+            recipient=Role.USER,
+        ))
+
+    # Create another output to trigger synchronization again
+    mock_output2 = create_mock_request_output(prompt_token_ids=[1, 2, 3],
+                                              output_token_ids=[102],
+                                              finished=True)
+
+    context.append_output(mock_output2)
+
+    # Verify the fourth message was added, num_init_messages is still 1
+    assert len(context._messages) == 3
+    assert context.num_init_messages == 1
+    assert context._messages[2].content[0].text == "Response 4"
vllm/entrypoints/context.py

Lines changed: 10 additions & 1 deletion
@@ -151,6 +151,9 @@ def append_output(self, output: Union[RequestOutput,
             self._update_decode_token_usage(output)
             # Move current turn to previous turn for next turn's calculations
             self.previous_turn = self.current_turn.copy()
+            # append_output is called only once before tool calling
+            # in non-streaming case
+            # so we can append all the parser messages to _messages
             output_msgs = self.parser.messages
             # The responses finish reason is set in the last message
             self.finish_reason = output.outputs[0].finish_reason
@@ -387,7 +390,7 @@ def __init__(self, *args, **kwargs):
 
     @property
     def messages(self) -> list:
-        return self.parser.messages
+        return self._messages
 
     def append_output(self, output: Union[RequestOutput,
                                           list[Message]]) -> None:
@@ -412,6 +415,11 @@ def append_output(self, output: Union[RequestOutput,
             # Check if the current token is part of reasoning content
             self._update_num_reasoning_tokens()
             self.last_tok = tok
+            if len(self._messages) - self.num_init_messages < len(
+                    self.parser.messages):
+                self._messages.extend(
+                    self.parser.messages[len(self._messages) -
+                                         self.num_init_messages:])
         else:
             # Handle the case of tool output in direct message format
             assert len(output) == 1, "Tool output should be a single message"
@@ -424,6 +432,7 @@ def append_output(self, output: Union[RequestOutput,
             for tok in toks:
                 self.parser.process(tok)
             self.last_tok = toks[-1]
+            # TODO: add tool_output messages to self._messages
 
     def is_expecting_start(self) -> bool:
         return self.parser.state == StreamState.EXPECT_START
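
The slice arithmetic in the new synchronization block is easier to see with plain lists. Below is a pure-Python sketch, with strings standing in for openai_harmony Message objects and names mirroring the diff:

# Pure-Python sketch of the synchronization step above; plain strings stand
# in for openai_harmony Message objects.
num_init_messages = 1                 # e.g. the user's opening message
_messages = ["user: Hello"]           # context copy, includes init messages
parser_messages = ["assistant: Response 1"]  # parser holds only new output

# Condition from the diff: the context lags behind the parser.
if len(_messages) - num_init_messages < len(parser_messages):
    # The offset len(_messages) - num_init_messages counts how many parser
    # messages the context already holds, so the slice yields only new ones.
    _messages.extend(parser_messages[len(_messages) - num_init_messages:])

assert _messages == ["user: Hello", "assistant: Response 1"]

Together with the messages property now returning self._messages instead of self.parser.messages, the streaming context accumulates a complete transcript across turns, which is what lets the final streamed response carry a non-empty output.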
