
Commit 20b65d9

feat: Add streaming support
1 parent 0849f71 commit 20b65d9

File tree: 1 file changed, +213 -31 lines changed

agents-api/agents_api/routers/sessions/chat.py (+213 -31)
@@ -1,7 +1,11 @@
+import json
+import asyncio
 from typing import Annotated, Any
 from uuid import UUID
 
-from fastapi import BackgroundTasks, Depends, Header
+from fastapi import BackgroundTasks, Depends, Header, HTTPException, status
+from starlette.background import BackgroundTask
+from fastapi.responses import StreamingResponse
 from starlette.status import HTTP_201_CREATED
 from uuid_extensions import uuid7
 
@@ -23,14 +27,22 @@
 from .metrics import total_tokens_per_user
 from .render import render_chat_input
 from .router import router
+from .render import render_chat_input
 
 COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
 
 
+async def wait_for_tasks(tasks: list[asyncio.Task]) -> None:
+    """Wait for all background tasks to complete."""
+    if tasks:
+        await asyncio.gather(*tasks, return_exceptions=True)
+
+
 @router.post(
     "/sessions/{session_id}/chat",
     status_code=HTTP_201_CREATED,
     tags=["sessions", "chat"],
+    response_model=None,
 )
 async def chat(
     developer: Annotated[Developer, Depends(get_developer_data)],
@@ -39,7 +51,7 @@ async def chat(
     background_tasks: BackgroundTasks,
     x_custom_api_key: str | None = Header(None, alias="X-Custom-Api-Key"),
     connection_pool: Any = None,  # FIXME: Placeholder that should be removed
-) -> ChatResponse:
+) -> ChatResponse | StreamingResponse:  # FIXME: Update type to include StreamingResponse
     """
     Initiates a chat session.
 
@@ -73,16 +85,31 @@ async def chat(
         "user": str(developer.id),
         "tags": developer.tags,
         "custom_api_key": x_custom_api_key,
+        "stream": chat_input.stream,  # Enable streaming if requested
     }
-    evaluator = ToolCallsEvaluator(
-        tool_types={"system"}, developer_id=developer.id, completion_func=litellm.acompletion
-    )
-    model_response = await evaluator.completion(**{
-        **settings,
-        **params,
-    })
+    payload = {**settings, **params}
+
+    try:
+        evaluator = ToolCallsEvaluator(
+            tool_types={"system"}, developer_id=developer.id, completion_func=litellm.acompletion
+        )
+        model_response = await evaluator.completion(**payload)
+    except Exception as e:
+        import logging
+
+        logging.error(f"LLM completion error: {e!s}")
+        # Create basic error response
+        return ChatResponse(
+            id=uuid7(),
+            created_at=utcnow(),
+            jobs=[],
+            docs=doc_references,
+            usage=None,
+            choices=[],
+            error=f"Error getting model completion: {e!s}",
+        )
 
-    # Save the input and the response to the session history
+    # Save the input messages to the session history
     if chat_input.save:
         new_entries = [
             CreateEntryRequest.from_model_input(
@@ -93,21 +120,33 @@ async def chat(
             for msg in new_messages
         ]
 
-        # Add the response to the new entries
-        # FIXME: We need to save all the choices
-        new_entries.append(
-            CreateEntryRequest.from_model_input(
-                model=settings["model"],
-                **model_response.choices[0].model_dump()["message"],
-                source="api_response",
-            ),
-        )
-        background_tasks.add_task(
-            create_entries,
-            developer_id=developer.id,
-            session_id=session_id,
-            data=new_entries,
-        )
+        # For non-streaming responses, add the response to the new entries immediately
+        if not chat_input.stream:
+            # FIXME: We need to save all the choices
+            new_entries.append(
+                CreateEntryRequest.from_model_input(
+                    model=settings["model"],
+                    **model_response.choices[0].model_dump()["message"],
+                    source="api_response",
+                ),
+            )
+            background_tasks.add_task(
+                create_entries,
+                developer_id=developer.id,
+                session_id=session_id,
+                data=new_entries,
+            )
+        else:
+            # For streaming, we need to collect all chunks and save at the end
+            # For now, just save the input messages and handle response separately
+            background_tasks.add_task(
+                create_entries,
+                developer_id=developer.id,
+                session_id=session_id,
+                data=new_entries,
+            )
+            # The complete streamed response will be saved in the stream_chat_response function
+            # using a separate background task to avoid blocking the stream
 
     # Adaptive context handling
     jobs = []
@@ -120,9 +159,147 @@ async def chat(
         raise NotImplementedError(msg)
 
     # Return the response
-    # FIXME: Implement streaming for chat
-    chat_response_class = ChunkChatResponse if chat_input.stream else MessageChatResponse
+    # Handle streaming response if requested
+    stream_tasks: list[asyncio.Task] = []
+
+    if chat_input.stream:
+        # For streaming, we'll use an async generator to yield chunks
+        async def stream_chat_response():
+            """Stream chat response chunks to the client."""
+            # Create initial response with metadata
+            response_id = uuid7()
+            created_at = utcnow()
+
+            # Collect full response for metrics and optional saving
+            content_so_far = ""
+            final_usage = None
+            has_content = False
+
+            nonlocal stream_tasks
+
+            try:
+                # Stream chunks from the model_response (CustomStreamWrapper from litellm)
+                async for chunk in model_response:
+                    # Process a single chunk of the streaming response
+                    try:
+                        # Extract usage metrics if available
+                        if hasattr(chunk, "usage") and chunk.usage:
+                            final_usage = chunk.usage.model_dump()
 
+                        # Check if chunk has valid choices
+                        has_choices = (
+                            hasattr(chunk, "choices")
+                            and chunk.choices
+                            and len(chunk.choices) > 0
+                        )
+
+                        # Update metrics when we detect the final chunk
+                        if final_usage and has_choices and chunk.choices[0].finish_reason:
+                            # This is the last chunk with the finish reason
+                            total_tokens = final_usage.get("total_tokens", 0)
+                            total_tokens_per_user.labels(str(developer.id)).inc(
+                                amount=total_tokens
+                            )
+
+                        # Collect content for the full response
+                        if has_choices and hasattr(chunk.choices[0], "delta"):
+                            delta = chunk.choices[0].delta
+                            if hasattr(delta, "content") and delta.content:
+                                content_so_far += delta.content
+                                has_content = True
+
+                        # Prepare the response chunk
+                        choices_to_send = []
+                        if has_choices:
+                            chunk_data = chunk.choices[0].model_dump()
+
+                            # Ensure delta always contains a role field
+                            if "delta" in chunk_data and "role" not in chunk_data["delta"]:
+                                chunk_data["delta"]["role"] = "assistant"
+
+                            choices_to_send = [chunk_data]
+
+                        # Create and send the chunk response
+                        chunk_response = ChunkChatResponse(
+                            id=response_id,
+                            created_at=created_at,
+                            jobs=jobs,
+                            docs=doc_references,
+                            usage=final_usage,
+                            choices=choices_to_send,
+                        )
+                        yield chunk_response.model_dump_json() + "\n"
+
+                    except Exception as e:
+                        # Log error details for debugging but send a generic message to client
+                        import logging
+
+                        logging.error(f"Error processing chunk: {e!s}")
+
+                        error_response = {
+                            "id": str(response_id),
+                            "created_at": created_at.isoformat(),
+                            "error": "An error occurred while processing the response chunk.",
+                        }
+                        yield json.dumps(error_response) + "\n"
+                        # Continue processing remaining chunks
+                        continue
+
+                # Save complete response to history if needed
+                if chat_input.save and has_content:
+                    try:
+                        # Create entry for the complete response
+                        complete_entry = CreateEntryRequest.from_model_input(
+                            model=settings["model"],
+                            role="assistant",
+                            content=content_so_far,
+                            source="api_response",
+                        )
+                        # Create a task to save the entry without blocking the stream
+                        ref = asyncio.create_task(
+                            create_entries(
+                                developer_id=developer.id,
+                                session_id=session_id,
+                                data=[complete_entry],
+                            )
+                        )
+                        stream_tasks.append(ref)
+                    except Exception as e:
+                        # Log the full error for debugging purposes
+                        import logging
+
+                        logging.error(f"Failed to save streamed response: {e!s}")
+
+                        # Send a minimal error message to the client
+                        error_response = {
+                            "id": str(response_id),
+                            "created_at": created_at.isoformat(),
+                            "error": "Failed to save response history.",
+                        }
+                        yield json.dumps(error_response) + "\n"
+            except Exception as e:
+                # Log the detailed error for system debugging
+                import logging
+
+                logging.error(f"Streaming error: {e!s}")
+
+                # Send a user-friendly error message to the client
+                error_response = {
+                    "id": str(response_id),
+                    "created_at": created_at.isoformat(),
+                    "error": "An error occurred during the streaming response.",
+                }
+                yield json.dumps(error_response) + "\n"
+
+        # Return a streaming response with a background task to wait for all entry saving tasks
+        return StreamingResponse(
+            stream_chat_response(),
+            media_type="application/json",
+            background=BackgroundTask(wait_for_tasks, stream_tasks),
+        )
+
+    # For non-streaming, return the complete response
+    chat_response_class = MessageChatResponse
     chat_response: ChatResponse = chat_response_class(
         id=uuid7(),
         created_at=utcnow(),
@@ -132,8 +309,13 @@ async def chat(
         choices=[choice.model_dump() for choice in model_response.choices],
     )
 
-    total_tokens_per_user.labels(str(developer.id)).inc(
-        amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0,
-    )
+    # For non-streaming responses, update metrics and return the response
+    if not chat_input.stream:
+        total_tokens_per_user.labels(str(developer.id)).inc(
+            amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0,
+        )
+        return chat_response
 
-    return chat_response
+    # Note: For streaming responses, we've already returned the StreamingResponse above
+    # This code is unreachable for streaming responses
+    return None
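Not part of the commit: a minimal client-side sketch of how the new streaming mode might be consumed. It assumes httpx as the HTTP client and uses a placeholder base URL, session ID, auth header, and request body fields (messages, stream, save); only the endpoint path and the line-delimited JSON chunk format (one ChunkChatResponse or error object per line) come from the diff above.

# Hypothetical client sketch (not part of this commit).
# Assumptions: httpx is available, the base URL / session ID / auth header are
# placeholders, and the request body accepts "messages", "stream", and "save".
import asyncio
import json

import httpx

BASE_URL = "http://localhost:8080"  # placeholder
SESSION_ID = "00000000-0000-0000-0000-000000000000"  # placeholder


async def consume_stream() -> None:
    body = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,  # ask the endpoint for a StreamingResponse
        "save": True,    # server saves input now, streamed output at the end
    }
    async with httpx.AsyncClient(base_url=BASE_URL, timeout=None) as client:
        async with client.stream(
            "POST",
            f"/sessions/{SESSION_ID}/chat",
            json=body,
            headers={"Authorization": "Bearer <api-key>"},  # placeholder auth
        ) as response:
            # Each non-empty line is one JSON chunk emitted by stream_chat_response.
            async for line in response.aiter_lines():
                if not line.strip():
                    continue
                chunk = json.loads(line)
                if "error" in chunk:
                    print(f"\n[stream error] {chunk['error']}")
                    continue
                for choice in chunk.get("choices", []):
                    delta = choice.get("delta") or {}
                    if delta.get("content"):
                        print(delta["content"], end="", flush=True)


if __name__ == "__main__":
    asyncio.run(consume_stream())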
