2 changes: 2 additions & 0 deletions tests/entrypoints/openai/test_response_api_with_harmony.py
@@ -341,6 +341,8 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
events.append(event)

assert len(events) > 0
response_completed_event = events[-1]
assert len(response_completed_event.response.output) > 0

if background:
starting_after = 5
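For context, a minimal sketch of what the new assertions pin down, assuming an AsyncOpenAI-style client against the Responses API (names mirror the test's fixtures):

```python
# Minimal sketch (assumes an AsyncOpenAI client): the last streamed event
# is `response.completed`, and its `response.output` must be non-empty.
stream = await client.responses.create(model=model_name,
                                       input="Hello",
                                       stream=True)
events = [event async for event in stream]
assert len(events) > 0
assert events[-1].type == "response.completed"
assert len(events[-1].response.output) > 0
```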
83 changes: 79 additions & 4 deletions tests/entrypoints/test_context.py
@@ -4,7 +4,7 @@
from unittest.mock import MagicMock, patch

import pytest
-from openai_harmony import StreamState
+from openai_harmony import Author, Message, Role, StreamState, TextContent

from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
from vllm.outputs import CompletionOutput, RequestOutput
@@ -312,9 +312,9 @@ async def test_negative_tool_tokens_edge_case():
@pytest.mark.asyncio
async def test_streaming_multi_turn_token_counting(mock_parser):
"""Test token counting for streaming multi-turn conversations.
-This test focuses on how StreamingHarmonyContext counts tokens in a
-multi-turn conversation with streaming (token-by-token) outputs and
+
+This test focuses on how StreamingHarmonyContext counts tokens in a
+multi-turn conversation with streaming (token-by-token) outputs and
message boundaries.
"""
# Create a streaming context
@@ -423,3 +423,78 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
additional_tool_tokens = 13 - 8 - 3 # = 2
assert context.num_tool_output_tokens == expected_tool_tokens \
+ additional_tool_tokens
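
The `13 - 8 - 3` arithmetic follows the context's inference rule for tool tokens; here is a standalone sketch with illustrative names (not vLLM's API):

```python
# Sketch of the inference behind `13 - 8 - 3`: tokens present in the new
# prompt that were neither the previous prompt nor the model's previous
# decode output must have been injected by the tool.
def inferred_tool_tokens(curr_prompt: int, prev_prompt: int,
                         prev_decode: int) -> int:
    return curr_prompt - prev_prompt - prev_decode

assert inferred_tool_tokens(13, 8, 3) == 2
```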


@pytest.mark.asyncio
async def test_streaming_message_synchronization(mock_parser):
"""Test message synchronization logic from lines 413-417 in context.py.

This test verifies that when parser.messages contains more messages than
the context's _messages (minus initial messages), the context properly
extends its message list with the new parser messages.
"""

# Create a streaming context with some initial messages
initial_messages = [
Message(
author=Author(role=Role.USER, name="user"),
content=[TextContent(text="Hello")],
recipient=Role.ASSISTANT,
)
]
context = StreamingHarmonyContext(messages=initial_messages,
available_tools=[])

# Verify initial state
assert len(context._messages) == 1
assert context.num_init_messages == 1

# Mock parser to have more messages than the context
# Simulate the parser having produced one new message
mock_parser.messages = [
Message(
author=Author(role=Role.ASSISTANT, name="assistant"),
content=[TextContent(text="Response 1")],
recipient=Role.USER,
),
]

# This should trigger the message synchronization logic
context.append_output(
create_mock_request_output(prompt_token_ids=[1, 2, 3],
output_token_ids=[101],
finished=False))

# Verify that messages were synchronized
assert len(context._messages) == 2

# Verify the new messages were added correctly
assert context._messages[1].content[0].text == "Response 1"

# Test the specific condition from line 413-414:
# len(self._messages) - self.num_init_messages < len(self.parser.messages)
messages_minus_init = len(context._messages) - context.num_init_messages
parser_messages_count = len(mock_parser.messages)

# After synchronization, they should be equal (no longer less than)
assert messages_minus_init == parser_messages_count

# Test edge case: add one more parser message
mock_parser.messages.append(
Message(
author=Author(role=Role.ASSISTANT, name="assistant"),
content=[TextContent(text="Response 4")],
recipient=Role.USER,
))

# Create another output to trigger synchronization again
mock_output2 = create_mock_request_output(prompt_token_ids=[1, 2, 3],
output_token_ids=[102],
finished=True)

context.append_output(mock_output2)

# Verify the new message was added; num_init_messages is still 1
assert len(context._messages) == 3
assert context.num_init_messages == 1
assert context._messages[2].content[0].text == "Response 4"
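
As a reading aid, here is a minimal sketch of the synchronization rule this test exercises, with plain lists standing in for the context and parser (illustrative names only):

```python
# Sketch of the extend-on-lag rule from StreamingHarmonyContext.append_output:
# parser-produced messages are mirrored into the context once the context
# falls behind, skipping the initial (prompt) messages.
def sync(context_msgs: list, num_init: int, parser_msgs: list) -> None:
    already_synced = len(context_msgs) - num_init
    if already_synced < len(parser_msgs):
        # Copy only the parser messages not yet mirrored.
        context_msgs.extend(parser_msgs[already_synced:])

msgs = ["user: Hello"]                      # one initial message
sync(msgs, 1, ["assistant: Response 1"])
assert len(msgs) == 2
sync(msgs, 1, ["assistant: Response 1"])    # idempotent when in sync
assert len(msgs) == 2
```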
11 changes: 10 additions & 1 deletion vllm/entrypoints/context.py
@@ -151,6 +151,9 @@ def append_output(self, output: Union[RequestOutput,
self._update_decode_token_usage(output)
# Move current turn to previous turn for next turn's calculations
self.previous_turn = self.current_turn.copy()
# In the non-streaming case, append_output is called only once
# before tool calling, so we can append all the parser messages
# to _messages.
output_msgs = self.parser.messages
# The responses finish reason is set in the last message
self.finish_reason = output.outputs[0].finish_reason
@@ -387,7 +390,7 @@ def __init__(self, *args, **kwargs):

@property
def messages(self) -> list:
-    return self.parser.messages
+    return self._messages

def append_output(self, output: Union[RequestOutput,
list[Message]]) -> None:
@@ -412,6 +415,11 @@ def append_output(self, output: Union[RequestOutput,
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens()
self.last_tok = tok
            if len(self._messages) - self.num_init_messages < len(
                    self.parser.messages):
                self._messages.extend(
                    self.parser.messages[len(self._messages) -
                                         self.num_init_messages:])
        else:
            # Handle the case of tool output in direct message format
            assert len(output) == 1, "Tool output should be a single message"

Review thread on the synchronization condition above:

Collaborator: let's also add a unit test covering this behavior. the test can be constructed similar to https://github.com/vllm-project/vllm/blob/main/tests/entrypoints/test_context.py#L313

Collaborator: +1

Contributor Author: ty for the suggestion, just added

Contributor Author: ready for re-review @chaunceyjiang
@@ -424,6 +432,7 @@ def append_output(self, output: Union[RequestOutput,
for tok in toks:
self.parser.process(tok)
self.last_tok = toks[-1]
# TODO: add tool_output messages to self._messages

def is_expecting_start(self) -> bool:
return self.parser.state == StreamState.EXPECT_START