feat: avoid consuming repeated agent message #6351

Open · wants to merge 6 commits into base: main
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import uuid
import warnings
from typing import (
Any,
@@ -824,6 +825,7 @@ async def on_messages_stream(

# STEP 3: Run the first inference
model_result = None
full_message_id = None
async for inference_output in self._call_llm(
model_client=model_client,
model_client_stream=model_client_stream,
@@ -837,6 +839,8 @@ async def on_messages_stream(
):
if isinstance(inference_output, CreateResult):
model_result = inference_output
elif isinstance(inference_output, tuple):
model_result, full_message_id = inference_output
else:
# Streaming chunk event
yield inference_output
@@ -875,6 +879,7 @@ async def on_messages_stream(
tool_call_summary_format=tool_call_summary_format,
output_content_type=output_content_type,
format_string=format_string,
full_message_id=full_message_id,
):
yield output_event

@@ -925,7 +930,7 @@ async def _call_llm(
agent_name: str,
cancellation_token: CancellationToken,
output_content_type: type[BaseModel] | None,
) -> AsyncGenerator[Union[CreateResult, ModelClientStreamingChunkEvent], None]:
) -> AsyncGenerator[Union[CreateResult, Tuple[CreateResult, str], ModelClientStreamingChunkEvent], None]:
"""
Perform a model inference and yield either streaming chunk events or the final CreateResult.
"""
@@ -935,6 +940,7 @@ async def _call_llm(
tools = (await workbench.list_tools()) + handoff_tools

if model_client_stream:
full_message_id = str(uuid.uuid4())
model_result: Optional[CreateResult] = None
async for chunk in model_client.create_stream(
llm_messages,
@@ -945,12 +951,14 @@ async def _call_llm(
if isinstance(chunk, CreateResult):
model_result = chunk
elif isinstance(chunk, str):
yield ModelClientStreamingChunkEvent(content=chunk, source=agent_name)
yield ModelClientStreamingChunkEvent(
content=chunk, source=agent_name, full_message_id=full_message_id
)
else:
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
if model_result is None:
raise RuntimeError("No final model result in streaming mode.")
yield model_result
yield (model_result, full_message_id)
else:
model_result = await model_client.create(
llm_messages,
@@ -978,6 +986,7 @@ async def _process_model_result(
tool_call_summary_format: str,
output_content_type: type[BaseModel] | None,
format_string: str | None = None,
full_message_id: str | None = None,
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
"""
Handle final or partial responses from model_result, including tool calls, handoffs,
@@ -998,8 +1007,10 @@ async def _process_model_result(
inner_messages=inner_messages,
)
else:
id = full_message_id if full_message_id else str(uuid.uuid4())
yield Response(
chat_message=TextMessage(
id=id,
content=model_result.content,
source=agent_name,
models_usage=model_result.usage,
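The hunks above change the streaming contract of _call_llm: a full_message_id is minted once per inference, every ModelClientStreamingChunkEvent carries it, and the final CreateResult is yielded as a (result, id) tuple so that on_messages_stream can stamp the same id onto the final TextMessage. A minimal, self-contained sketch of that contract follows; Chunk and FinalResult are hypothetical stand-ins, not the real autogen_agentchat classes.

# Sketch of the streaming contract introduced above: one id is minted per inference,
# stamped on every chunk, and returned alongside the final result.
import asyncio
import uuid
from dataclasses import dataclass
from typing import AsyncGenerator, List, Tuple, Union

@dataclass
class Chunk:
    content: str
    full_message_id: str

@dataclass
class FinalResult:
    content: str

async def call_llm_stream(tokens: List[str]) -> AsyncGenerator[Union[Chunk, Tuple[FinalResult, str]], None]:
    full_message_id = str(uuid.uuid4())  # minted once per inference
    pieces: List[str] = []
    for token in tokens:
        pieces.append(token)
        yield Chunk(content=token, full_message_id=full_message_id)
    # The final result travels with the id so the caller can reuse it for the final message.
    yield (FinalResult(content="".join(pieces)), full_message_id)

async def main() -> None:
    final_result, final_id = None, None
    async for item in call_llm_stream(["Hel", "lo"]):
        if isinstance(item, tuple):
            final_result, final_id = item  # same branch on_messages_stream takes above
        else:
            print("chunk", item.full_message_id, item.content)
    assert final_result is not None and final_id is not None
    print("final", final_id, final_result.content)

asyncio.run(main())
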
@@ -1,10 +1,12 @@
import logging
import re
import uuid
from typing import (
AsyncGenerator,
List,
Optional,
Sequence,
Tuple,
Union,
)

@@ -449,6 +451,8 @@ async def on_messages_stream(
):
if isinstance(inference_output, CreateResult):
model_result = inference_output
elif isinstance(inference_output, tuple):
model_result, _ = inference_output
else:
# Streaming chunk event
yield inference_output
@@ -646,27 +650,30 @@ async def _call_llm(
model_context: ChatCompletionContext,
agent_name: str,
cancellation_token: CancellationToken,
) -> AsyncGenerator[Union[CreateResult, ModelClientStreamingChunkEvent], None]:
) -> AsyncGenerator[Union[CreateResult, Tuple[CreateResult, str], ModelClientStreamingChunkEvent], None]:
"""
Perform a model inference and yield either streaming chunk events or the final CreateResult.
"""
all_messages = await model_context.get_messages()
llm_messages = cls._get_compatible_context(model_client=model_client, messages=system_messages + all_messages)

if model_client_stream:
full_message_id = str(uuid.uuid4())
model_result: Optional[CreateResult] = None
async for chunk in model_client.create_stream(
llm_messages, tools=[], cancellation_token=cancellation_token
):
if isinstance(chunk, CreateResult):
model_result = chunk
elif isinstance(chunk, str):
yield ModelClientStreamingChunkEvent(content=chunk, source=agent_name)
yield ModelClientStreamingChunkEvent(
content=chunk, source=agent_name, full_message_id=full_message_id
)
else:
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
if model_result is None:
raise RuntimeError("No final model result in streaming mode.")
yield model_result
yield (model_result, full_message_id)
else:
model_result = await model_client.create(llm_messages, tools=[], cancellation_token=cancellation_token)
yield model_result
@@ -6,6 +6,7 @@ class and includes specific fields relevant to the type of message being sent.

from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar
from uuid import uuid4

from autogen_core import Component, ComponentBase, FunctionCall, Image
from autogen_core.code_executor import CodeBlock, CodeResult
@@ -76,6 +77,9 @@ class BaseChatMessage(BaseMessage, ABC):
message using models and return a response as another :class:`BaseChatMessage`.
"""

id: str = Field(default_factory=lambda: str(uuid4()))
"""A unique identifier for the message."""

source: str
"""The name of the agent that sent this message."""

@@ -145,6 +149,9 @@ class BaseAgentEvent(BaseMessage, ABC):
a custom rendering of the content.
"""

id: str = Field(default_factory=lambda: str(uuid4()))
"""Unique identifier for the event."""

source: str
"""The name of the agent that sent this message."""

@@ -510,6 +517,9 @@ class ModelClientStreamingChunkEvent(BaseAgentEvent):
content: str
"""A string chunk from the model client."""

full_message_id: str | None = None
"""The ID of the full message that this chunk belongs to, if available."""

type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"

def to_text(self) -> str:
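The schema changes above give every chat message and agent event an auto-generated id, and let a streaming chunk point back to the message it will eventually become through full_message_id. A small illustrative sketch of that linkage with plain pydantic models; the class names are stand-ins, not the real messages types.

# Sketch of the id linkage added above, using plain pydantic models as stand-ins for
# BaseChatMessage and ModelClientStreamingChunkEvent.
from uuid import uuid4
from pydantic import BaseModel, Field

class Message(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))  # assigned automatically per message
    source: str
    content: str

class StreamingChunk(BaseModel):
    content: str
    full_message_id: str | None = None  # ties the chunk to the message it belongs to

final = Message(source="assistant", content="Hello!")
chunks = [StreamingChunk(content=piece, full_message_id=final.id) for piece in ("Hel", "lo!")]

# A consumer can now tell that these chunks and the final message carry the same text.
assert all(chunk.full_message_id == final.id for chunk in chunks)
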
@@ -114,6 +114,7 @@ async def Console(

streaming_chunks: List[str] = []

full_message_ids: set[str] = set()
async for message in stream:
if isinstance(message, TaskResult):
duration = time.time() - start_time
@@ -134,12 +135,17 @@ async def Console(
elif isinstance(message, Response):
duration = time.time() - start_time

message_id = ""
# Print final response.
if isinstance(message.chat_message, MultiModalMessage):
final_content = message.chat_message.to_text(iterm=render_image_iterm)
else:
message_id = message.chat_message.id
final_content = message.chat_message.to_text()
output = f"{'-' * 10} {message.chat_message.source} {'-' * 10}\n{final_content}\n"
# avoid printing this message as it is already printed in the streaming chunks
if message_id and message_id in full_message_ids:
output = ""
if message.chat_message.models_usage:
if output_stats:
output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n"
@@ -179,6 +185,8 @@ async def Console(
if isinstance(message, ModelClientStreamingChunkEvent):
await aprint(message.to_text(), end="", flush=True)
streaming_chunks.append(message.content)
if message.full_message_id:
full_message_ids.add(message.full_message_id)
else:
if streaming_chunks:
streaming_chunks.clear()
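The Console change above is a simple de-duplication: remember the full_message_id of every chunk that was already rendered, and suppress the body of the final response whose chat_message.id is in that set, so streamed text is not printed a second time (usage stats are still appended). A condensed sketch of that logic with hypothetical event classes:

# Sketch of the de-duplication the Console change implements. Chunk and FinalMessage
# are illustrative stand-ins, not the real event types.
from dataclasses import dataclass
from typing import List, Union

@dataclass
class Chunk:
    content: str
    full_message_id: str

@dataclass
class FinalMessage:
    id: str
    content: str

def render(events: List[Union[Chunk, FinalMessage]]) -> str:
    seen_ids: set[str] = set()
    out: List[str] = []
    for event in events:
        if isinstance(event, Chunk):
            out.append(event.content)          # print chunks as they arrive
            seen_ids.add(event.full_message_id)
        else:
            if event.id not in seen_ids:       # only print if not already streamed
                out.append(event.content)
    return "".join(out)

events = [Chunk("Hel", "m-1"), Chunk("lo", "m-1"), FinalMessage("m-1", "Hello")]
assert render(events) == "Hello"  # the final message body is not repeated
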
39 changes: 39 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -1538,3 +1538,42 @@ async def test_tools_deserialize_aware() -> None:
assert result.messages[-1].content == "Hello, World!" # type: ignore
assert result.messages[-1].type == "ToolCallSummaryMessage" # type: ignore
assert isinstance(result.messages[-1], ToolCallSummaryMessage) # type: ignore


@pytest.mark.asyncio
async def test_full_message_id_consistency() -> None:
mock_client = ReplayChatCompletionClient(
[
"Mock Response to verify message_id consistency",
]
)
agent = AssistantAgent(
"test_agent",
model_client=mock_client,
model_client_stream=True,
)

chunks: List[ModelClientStreamingChunkEvent] = []
final_message: TextMessage | None = None

async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert isinstance(message.messages[-1], TextMessage)
final_message = message.messages[-1]
elif isinstance(message, ModelClientStreamingChunkEvent):
chunks.append(message)

assert len(chunks) > 0, "Expected at least one streaming chunk"

# Verify the final message exists
assert final_message is not None, "Expected a final TextMessage"

# Verify all chunks have the same full_message_id
full_message_id = chunks[0].full_message_id
assert full_message_id is not None, "Expected chunk to have a full_message_id"

for chunk in chunks:
assert chunk.full_message_id == full_message_id, "All chunks should have the same full_message_id"

# Verify the final message has the same ID as the chunks
assert final_message.id == full_message_id, "The final message id should match the chunks' full_message_id"
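
For completeness, a hedged end-to-end usage sketch: run Console over a streaming AssistantAgent, and the reply should now appear once (as streamed chunks) rather than chunks followed by the same full message again. Import paths are assumed to match the current autogen-agentchat / autogen-ext packages.

# End-to-end sketch (assumed imports): with model_client_stream=True, the streamed
# chunks and the final TextMessage share one id, so Console skips the duplicate body.
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.replay import ReplayChatCompletionClient

async def main() -> None:
    client = ReplayChatCompletionClient(["Hello from the mock model."])
    agent = AssistantAgent("assistant", model_client=client, model_client_stream=True)
    await Console(agent.run_stream(task="Say hello."))

asyncio.run(main())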