Commit 2be5556

fix: implement real-time streaming for Agent.start() method
- Add streaming generator support to Agent.start() method
- Implement _start_stream() method for streaming logic
- Add _chat_stream() method to route streaming to appropriate handlers
- Add _custom_llm_stream() for custom LLM streaming support
- Add _openai_stream() for OpenAI client streaming support
- Add chat_completion_with_tools_stream() method to OpenAI client
- Maintain backward compatibility for existing code
- Add proper error handling and chat history management

Fixes #981

Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>
1 parent 63286ff commit 2be5556

File tree: 5 files changed (+520 -1 lines changed)

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 185 additions & 1 deletion
@@ -1937,7 +1937,191 @@ def run(self):

    def start(self, prompt: str, **kwargs):
        """Start the agent with a prompt. This is a convenience method that wraps chat()."""
-        return self.chat(prompt, **kwargs)
+        # Check if streaming is enabled and user wants streaming chunks
+        if self.stream and kwargs.get('stream', True):
+            return self._start_stream(prompt, **kwargs)
+        else:
+            return self.chat(prompt, **kwargs)
+
+    def _start_stream(self, prompt: str, **kwargs):
+        """Generator method that yields streaming chunks from the agent."""
+        # Import here to avoid circular imports
+        from typing import Generator
+
+        # Reset the final display flag for each new conversation
+        self._final_display_shown = False
+
+        # Search for existing knowledge if any knowledge is provided
+        if self.knowledge:
+            search_results = self.knowledge.search(prompt, agent_id=self.agent_id)
+            if search_results:
+                # Check if search_results is a list of dictionaries or strings
+                if isinstance(search_results, dict) and 'results' in search_results:
+                    # Extract memory content from the results
+                    knowledge_content = "\n".join([result['memory'] for result in search_results['results']])
+                else:
+                    # If search_results is a list of strings, join them directly
+                    knowledge_content = "\n".join(search_results)
+
+                # Append found knowledge to the prompt
+                prompt = f"{prompt}\n\nKnowledge: {knowledge_content}"
+
+        # Get streaming response using the internal streaming method
+        for chunk in self._chat_stream(prompt, **kwargs):
+            yield chunk
+
+    def _chat_stream(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None, reasoning_steps=False, **kwargs):
+        """Internal streaming method that yields chunks from the LLM response."""
+        # Use the same logic as chat() but yield chunks instead of returning final response
+        if self._using_custom_llm:
+            # For custom LLM, yield chunks from the LLM instance
+            for chunk in self._custom_llm_stream(prompt, temperature, tools, output_json, output_pydantic, reasoning_steps, **kwargs):
+                yield chunk
+        else:
+            # For standard OpenAI client, yield chunks from the streaming response
+            for chunk in self._openai_stream(prompt, temperature, tools, output_json, output_pydantic, reasoning_steps, **kwargs):
+                yield chunk
+
+    def _custom_llm_stream(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None, reasoning_steps=False, **kwargs):
+        """Handle streaming for custom LLM instances."""
+        try:
+            # Special handling for MCP tools when using provider/model format
+            if tools is None or (isinstance(tools, list) and len(tools) == 0):
+                tool_param = self.tools
+            else:
+                tool_param = tools
+
+            # Convert MCP tool objects to OpenAI format if needed
+            if tool_param is not None:
+                from ..mcp.mcp import MCP
+                if isinstance(tool_param, MCP) and hasattr(tool_param, 'to_openai_tool'):
+                    openai_tool = tool_param.to_openai_tool()
+                    if openai_tool:
+                        if isinstance(openai_tool, list):
+                            tool_param = openai_tool
+                        else:
+                            tool_param = [openai_tool]
+
+            # Store chat history length for potential rollback
+            chat_history_length = len(self.chat_history)
+
+            # Normalize prompt content for consistent chat history storage
+            normalized_content = prompt
+            if isinstance(prompt, list):
+                normalized_content = next((item["text"] for item in prompt if item.get("type") == "text"), "")
+
+            # Prevent duplicate messages
+            if not (self.chat_history and
+                    self.chat_history[-1].get("role") == "user" and
+                    self.chat_history[-1].get("content") == normalized_content):
+                self.chat_history.append({"role": "user", "content": normalized_content})
+
+            # Get streaming response from LLM instance
+            if hasattr(self.llm_instance, 'get_response_stream'):
+                # Use streaming method if available
+                stream_response = self.llm_instance.get_response_stream(
+                    prompt=prompt,
+                    system_prompt=self._build_system_prompt(tools),
+                    chat_history=self.chat_history,
+                    temperature=temperature,
+                    tools=tool_param,
+                    output_json=output_json,
+                    output_pydantic=output_pydantic,
+                    verbose=self.verbose,
+                    markdown=self.markdown,
+                    console=self.console,
+                    agent_name=self.name,
+                    agent_role=self.role,
+                    agent_tools=[t.__name__ if hasattr(t, '__name__') else str(t) for t in (tools if tools is not None else self.tools)],
+                    reasoning_steps=reasoning_steps,
+                    execute_tool_fn=self.execute_tool
+                )
+
+                accumulated_response = ""
+                for chunk in stream_response:
+                    accumulated_response += chunk
+                    yield chunk
+
+                # Add final response to chat history
+                self.chat_history.append({"role": "assistant", "content": accumulated_response})
+
+            else:
+                # Fallback to regular response if streaming not available
+                response_text = self.llm_instance.get_response(
+                    prompt=prompt,
+                    system_prompt=self._build_system_prompt(tools),
+                    chat_history=self.chat_history,
+                    temperature=temperature,
+                    tools=tool_param,
+                    output_json=output_json,
+                    output_pydantic=output_pydantic,
+                    verbose=self.verbose,
+                    markdown=self.markdown,
+                    console=self.console,
+                    agent_name=self.name,
+                    agent_role=self.role,
+                    agent_tools=[t.__name__ if hasattr(t, '__name__') else str(t) for t in (tools if tools is not None else self.tools)],
+                    reasoning_steps=reasoning_steps,
+                    execute_tool_fn=self.execute_tool,
+                    stream=True
+                )
+
+                self.chat_history.append({"role": "assistant", "content": response_text})
+                # Yield the complete response as a single chunk
+                yield response_text
+
+        except Exception as e:
+            # Rollback chat history on error
+            self.chat_history = self.chat_history[:chat_history_length]
+            yield f"Error: {str(e)}"
+
+    def _openai_stream(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None, reasoning_steps=False, **kwargs):
+        """Handle streaming for standard OpenAI client."""
+        try:
+            # Use the new _build_messages helper method
+            messages, original_prompt = self._build_messages(prompt, temperature, output_json, output_pydantic)
+
+            # Store chat history length for potential rollback
+            chat_history_length = len(self.chat_history)
+
+            # Normalize original_prompt for consistent chat history storage
+            normalized_content = original_prompt
+            if isinstance(original_prompt, list):
+                normalized_content = next((item["text"] for item in original_prompt if item.get("type") == "text"), "")
+
+            # Prevent duplicate messages
+            if not (self.chat_history and
+                    self.chat_history[-1].get("role") == "user" and
+                    self.chat_history[-1].get("content") == normalized_content):
+                self.chat_history.append({"role": "user", "content": normalized_content})
+
+            # Get streaming response from OpenAI client
+            if self._openai_client is None:
+                raise ValueError("OpenAI client is not initialized. Please provide OPENAI_API_KEY or use a custom LLM provider.")
+
+            # Stream the response using OpenAI client
+            accumulated_response = ""
+            for chunk in self._openai_client.chat_completion_with_tools_stream(
+                messages=messages,
+                model=self.llm,
+                temperature=temperature,
+                tools=self._format_tools_for_completion(tools),
+                execute_tool_fn=self.execute_tool,
+                reasoning_steps=reasoning_steps,
+                verbose=self.verbose,
+                max_iterations=10
+            ):
+                accumulated_response += chunk
+                yield chunk
+
+            # Add the accumulated response to chat history
+            self.chat_history.append({"role": "assistant", "content": accumulated_response})
+
+        except Exception as e:
+            # Rollback chat history on error
+            self.chat_history = self.chat_history[:chat_history_length]
+            yield f"Error: {str(e)}"

    def execute(self, task, context=None):
        """Execute a task synchronously - backward compatibility method"""

src/praisonai-agents/praisonaiagents/llm/openai_client.py

Lines changed: 136 additions & 0 deletions
@@ -1049,6 +1049,142 @@ async def achat_completion_with_tools(
                break

        return final_response
+
+    def chat_completion_with_tools_stream(
+        self,
+        messages: List[Dict[str, Any]],
+        model: str = "gpt-4o",
+        temperature: float = 0.7,
+        tools: Optional[List[Any]] = None,
+        execute_tool_fn: Optional[Callable] = None,
+        reasoning_steps: bool = False,
+        verbose: bool = True,
+        max_iterations: int = 10,
+        **kwargs
+    ):
+        """
+        Create a streaming chat completion with tool support.
+
+        This method yields chunks of the response as they are generated,
+        enabling real-time streaming to the user.
+
+        Args:
+            messages: List of message dictionaries
+            model: Model to use
+            temperature: Temperature for generation
+            tools: List of tools (can be callables, dicts, or strings)
+            execute_tool_fn: Function to execute tools
+            reasoning_steps: Whether to show reasoning
+            verbose: Whether to show verbose output
+            max_iterations: Maximum tool calling iterations
+            **kwargs: Additional API parameters
+
+        Yields:
+            String chunks of the response as they are generated
+        """
+        start_time = time.time()
+
+        # Format tools for OpenAI API
+        formatted_tools = self.format_tools(tools)
+
+        # Continue tool execution loop until no more tool calls are needed
+        iteration_count = 0
+
+        while iteration_count < max_iterations:
+            try:
+                # Create streaming response
+                response_stream = self._sync_client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    temperature=temperature,
+                    tools=formatted_tools if formatted_tools else None,
+                    stream=True,
+                    **kwargs
+                )
+
+                full_response_text = ""
+                reasoning_content = ""
+                chunks = []
+
+                # Stream the response chunk by chunk
+                for chunk in response_stream:
+                    chunks.append(chunk)
+                    if chunk.choices[0].delta.content:
+                        content = chunk.choices[0].delta.content
+                        full_response_text += content
+                        yield content
+
+                    # Handle reasoning content if enabled
+                    if reasoning_steps and hasattr(chunk.choices[0].delta, "reasoning_content"):
+                        rc = chunk.choices[0].delta.reasoning_content
+                        if rc:
+                            reasoning_content += rc
+                            yield f"[Reasoning: {rc}]"
+
+                # Process the complete response to check for tool calls
+                final_response = process_stream_chunks(chunks)
+
+                if not final_response:
+                    return
+
+                # Check for tool calls
+                tool_calls = getattr(final_response.choices[0].message, 'tool_calls', None)
+
+                if tool_calls and execute_tool_fn:
+                    # Convert ToolCall dataclass objects to dict for JSON serialization
+                    serializable_tool_calls = []
+                    for tc in tool_calls:
+                        if isinstance(tc, ToolCall):
+                            # Convert dataclass to dict
+                            serializable_tool_calls.append({
+                                "id": tc.id,
+                                "type": tc.type,
+                                "function": tc.function
+                            })
+                        else:
+                            # Already an OpenAI object, keep as is
+                            serializable_tool_calls.append(tc)
+
+                    messages.append({
+                        "role": "assistant",
+                        "content": final_response.choices[0].message.content,
+                        "tool_calls": serializable_tool_calls
+                    })
+
+                    for tool_call in tool_calls:
+                        # Handle both ToolCall dataclass and OpenAI object
+                        if isinstance(tool_call, ToolCall):
+                            function_name = tool_call.function["name"]
+                            arguments = json.loads(tool_call.function["arguments"])
+                        else:
+                            function_name = tool_call.function.name
+                            arguments = json.loads(tool_call.function.arguments)
+
+                        if verbose:
+                            yield f"\n[Calling function: {function_name}]"
+
+                        # Execute the tool
+                        tool_result = execute_tool_fn(function_name, arguments)
+                        results_str = json.dumps(tool_result) if tool_result else "Function returned an empty output"
+
+                        if verbose:
+                            yield f"\n[Function result: {results_str}]"
+
+                        messages.append({
+                            "role": "tool",
+                            "tool_call_id": tool_call.id if hasattr(tool_call, 'id') else tool_call['id'],
+                            "content": results_str
+                        })
+
+                    # Continue the loop to allow more tool calls
+                    iteration_count += 1
+                else:
+                    # No tool calls, we're done
+                    break
+
+            except Exception as e:
+                yield f"Error: {str(e)}"
+                break

    def parse_structured_output(
        self,
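
To make the tool-call loop above concrete, the sketch below shows how the new client method might be driven directly. It is illustrative only: the OpenAIClient class name and its no-argument constructor are assumptions based on this file's layout, not confirmed by the diff, while the keyword arguments follow the signature added above. execute_tool_fn receives the function name and parsed arguments for each tool call, and its JSON-serialised result is appended to the conversation before streaming continues.

# Hypothetical direct use of chat_completion_with_tools_stream(); the class name and
# constructor are assumed, the argument names follow the diff above.
from praisonaiagents.llm.openai_client import OpenAIClient

def get_weather(city: str) -> str:
    """Toy tool used to exercise the tool-call loop."""
    return f"It is sunny in {city}"

def run_tool(name, args):
    # Dispatch tool calls by name; only one tool is registered in this sketch
    return get_weather(**args) if name == "get_weather" else None

client = OpenAIClient()  # assumed constructor; may require an API key or base URL

for chunk in client.chat_completion_with_tools_stream(
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    model="gpt-4o",
    tools=[get_weather],
    execute_tool_fn=run_tool,
    verbose=True,
):
    print(chunk, end="", flush=True)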
