+ import json
+ import asyncio
from typing import Annotated, Any
from uuid import UUID

- from fastapi import BackgroundTasks, Depends, Header
+ from fastapi import BackgroundTasks, Depends, Header, HTTPException, status
+ from starlette.background import BackgroundTask
+ from fastapi.responses import StreamingResponse
from starlette.status import HTTP_201_CREATED
from uuid_extensions import uuid7

from .metrics import total_tokens_per_user
from .render import render_chat_input
from .router import router

COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"


+ async def wait_for_tasks(tasks: list[asyncio.Task]) -> None:
+     """Wait for all background tasks to complete."""
+     if tasks:
+         await asyncio.gather(*tasks, return_exceptions=True)
+
+
@router.post(
    "/sessions/{session_id}/chat",
    status_code=HTTP_201_CREATED,
    tags=["sessions", "chat"],
+     response_model=None,
)
async def chat(
    developer: Annotated[Developer, Depends(get_developer_data)],
@@ -39,7 +51,7 @@ async def chat(
    background_tasks: BackgroundTasks,
    x_custom_api_key: str | None = Header(None, alias="X-Custom-Api-Key"),
    connection_pool: Any = None,  # FIXME: Placeholder that should be removed
- ) -> ChatResponse:
+ ) -> ChatResponse | StreamingResponse:
    """
    Initiates a chat session.

@@ -73,16 +85,31 @@ async def chat(
        "user": str(developer.id),
        "tags": developer.tags,
        "custom_api_key": x_custom_api_key,
+         "stream": chat_input.stream,  # Enable streaming if requested
    }
-     evaluator = ToolCallsEvaluator(
-         tool_types={"system"}, developer_id=developer.id, completion_func=litellm.acompletion
-     )
-     model_response = await evaluator.completion(**{
-         **settings,
-         **params,
-     })
+     payload = {**settings, **params}
+
+     try:
+         evaluator = ToolCallsEvaluator(
+             tool_types={"system"}, developer_id=developer.id, completion_func=litellm.acompletion
+         )
+         model_response = await evaluator.completion(**payload)
+     except Exception as e:
+         import logging
+
+         logging.error(f"LLM completion error: {e!s}")
+         # Create basic error response
+         return ChatResponse(
+             id=uuid7(),
+             created_at=utcnow(),
+             jobs=[],
+             docs=doc_references,
+             usage=None,
+             choices=[],
+             error=f"Error getting model completion: {e!s}",
+         )

-     # Save the input and the response to the session history
+     # Save the input messages to the session history
    if chat_input.save:
        new_entries = [
            CreateEntryRequest.from_model_input(
@@ -93,21 +120,33 @@ async def chat(
            for msg in new_messages
        ]

-         # Add the response to the new entries
-         # FIXME: We need to save all the choices
-         new_entries.append(
-             CreateEntryRequest.from_model_input(
-                 model=settings["model"],
-                 **model_response.choices[0].model_dump()["message"],
-                 source="api_response",
-             ),
-         )
-         background_tasks.add_task(
-             create_entries,
-             developer_id=developer.id,
-             session_id=session_id,
-             data=new_entries,
-         )
+         # For non-streaming responses, add the response to the new entries immediately
+         if not chat_input.stream:
+             # FIXME: We need to save all the choices
+             new_entries.append(
+                 CreateEntryRequest.from_model_input(
+                     model=settings["model"],
+                     **model_response.choices[0].model_dump()["message"],
+                     source="api_response",
+                 ),
+             )
+             background_tasks.add_task(
+                 create_entries,
+                 developer_id=developer.id,
+                 session_id=session_id,
+                 data=new_entries,
+             )
+         else:
+             # For streaming, we need to collect all chunks and save at the end
+             # For now, just save the input messages and handle the response separately
+             background_tasks.add_task(
+                 create_entries,
+                 developer_id=developer.id,
+                 session_id=session_id,
+                 data=new_entries,
+             )
+             # The complete streamed response will be saved in the stream_chat_response function
+             # using a separate background task to avoid blocking the stream

    # Adaptive context handling
    jobs = []
@@ -120,9 +159,147 @@ async def chat(
        raise NotImplementedError(msg)

    # Return the response
-     # FIXME: Implement streaming for chat
-     chat_response_class = ChunkChatResponse if chat_input.stream else MessageChatResponse
+     # Handle streaming response if requested
+     stream_tasks: list[asyncio.Task] = []
+
+     if chat_input.stream:
+         # For streaming, we'll use an async generator to yield chunks
+         async def stream_chat_response():
+             """Stream chat response chunks to the client."""
+             # Create initial response with metadata
+             response_id = uuid7()
+             created_at = utcnow()
+
+             # Collect full response for metrics and optional saving
+             content_so_far = ""
+             final_usage = None
+             has_content = False
+
+             nonlocal stream_tasks
+
+             try:
+                 # Stream chunks from the model_response (CustomStreamWrapper from litellm)
+                 async for chunk in model_response:
+                     # Process a single chunk of the streaming response
+                     try:
+                         # Extract usage metrics if available
+                         if hasattr(chunk, "usage") and chunk.usage:
+                             final_usage = chunk.usage.model_dump()

+                         # Check if chunk has valid choices
+                         has_choices = (
+                             hasattr(chunk, "choices")
+                             and chunk.choices
+                             and len(chunk.choices) > 0
+                         )
+
+                         # Update metrics when we detect the final chunk
+                         if final_usage and has_choices and chunk.choices[0].finish_reason:
+                             # This is the last chunk with the finish reason
+                             total_tokens = final_usage.get("total_tokens", 0)
+                             total_tokens_per_user.labels(str(developer.id)).inc(
+                                 amount=total_tokens
+                             )
+
+                         # Collect content for the full response
+                         if has_choices and hasattr(chunk.choices[0], "delta"):
+                             delta = chunk.choices[0].delta
+                             if hasattr(delta, "content") and delta.content:
+                                 content_so_far += delta.content
+                                 has_content = True
+
+                         # Prepare the response chunk
+                         choices_to_send = []
+                         if has_choices:
+                             chunk_data = chunk.choices[0].model_dump()
+
+                             # Ensure delta always contains a role field
+                             if "delta" in chunk_data and "role" not in chunk_data["delta"]:
+                                 chunk_data["delta"]["role"] = "assistant"
+
+                             choices_to_send = [chunk_data]
+
+                         # Create and send the chunk response
+                         chunk_response = ChunkChatResponse(
+                             id=response_id,
+                             created_at=created_at,
+                             jobs=jobs,
+                             docs=doc_references,
+                             usage=final_usage,
+                             choices=choices_to_send,
+                         )
+                         yield chunk_response.model_dump_json() + "\n"
+
+                     except Exception as e:
+                         # Log error details for debugging but send a generic message to client
+                         import logging
+
+                         logging.error(f"Error processing chunk: {e!s}")
+
+                         error_response = {
+                             "id": str(response_id),
+                             "created_at": created_at.isoformat(),
+                             "error": "An error occurred while processing the response chunk.",
+                         }
+                         yield json.dumps(error_response) + "\n"
+                         # Continue processing remaining chunks
+                         continue
+
+                 # Save complete response to history if needed
+                 if chat_input.save and has_content:
+                     try:
+                         # Create entry for the complete response
+                         complete_entry = CreateEntryRequest.from_model_input(
+                             model=settings["model"],
+                             role="assistant",
+                             content=content_so_far,
+                             source="api_response",
+                         )
+                         # Create a task to save the entry without blocking the stream
+                         ref = asyncio.create_task(
+                             create_entries(
+                                 developer_id=developer.id,
+                                 session_id=session_id,
+                                 data=[complete_entry],
+                             )
+                         )
+                         stream_tasks.append(ref)
+                     except Exception as e:
+                         # Log the full error for debugging purposes
+                         import logging
+
+                         logging.error(f"Failed to save streamed response: {e!s}")
+
+                         # Send a minimal error message to the client
+                         error_response = {
+                             "id": str(response_id),
+                             "created_at": created_at.isoformat(),
+                             "error": "Failed to save response history.",
+                         }
+                         yield json.dumps(error_response) + "\n"
+             except Exception as e:
+                 # Log the detailed error for system debugging
+                 import logging
+
+                 logging.error(f"Streaming error: {e!s}")
+
+                 # Send a user-friendly error message to the client
+                 error_response = {
+                     "id": str(response_id),
+                     "created_at": created_at.isoformat(),
+                     "error": "An error occurred during the streaming response.",
+                 }
+                 yield json.dumps(error_response) + "\n"
+
+         # Return a streaming response with a background task to wait for all entry saving tasks
+         return StreamingResponse(
+             stream_chat_response(),
+             media_type="application/json",
+             background=BackgroundTask(wait_for_tasks, stream_tasks),
+         )
+
+     # For non-streaming, return the complete response
+     chat_response_class = MessageChatResponse
    chat_response: ChatResponse = chat_response_class(
        id=uuid7(),
        created_at=utcnow(),
@@ -132,8 +309,13 @@ async def chat(
        choices=[choice.model_dump() for choice in model_response.choices],
    )

-     total_tokens_per_user.labels(str(developer.id)).inc(
-         amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0,
-     )
+     # For non-streaming responses, update metrics and return the response
+     if not chat_input.stream:
+         total_tokens_per_user.labels(str(developer.id)).inc(
+             amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0,
+         )
+         return chat_response

-     return chat_response
+     # Note: For streaming responses, we've already returned the StreamingResponse above
+     # This code is unreachable for streaming responses
+     return None
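
For context on how the new streaming path is consumed: the endpoint emits newline-delimited JSON, where each line is a serialized `ChunkChatResponse` (or a small ad-hoc error object with an `error` key). The sketch below is illustrative only and not part of this diff; it assumes a running server, a valid `session_id`, a request body shaped like the session chat input (`messages`, `stream`, `save`), and that the `X-Custom-Api-Key` header plus whatever credentials `get_developer_data` expects are enough to authenticate.

```python
# Minimal client sketch (not part of the diff). Base URL, session id, payload
# fields, and header values are assumptions for illustration.
import asyncio
import json

import httpx


async def consume_chat_stream(base_url: str, session_id: str, api_key: str) -> str:
    payload = {"messages": [{"role": "user", "content": "Hello"}], "stream": True}
    content = ""
    async with httpx.AsyncClient(base_url=base_url, timeout=None) as client:
        async with client.stream(
            "POST",
            f"/sessions/{session_id}/chat",
            json=payload,
            headers={"X-Custom-Api-Key": api_key},
        ) as response:
            # Each non-empty line of the body is one JSON document.
            async for line in response.aiter_lines():
                if not line.strip():
                    continue
                chunk = json.loads(line)
                if "error" in chunk:
                    raise RuntimeError(chunk["error"])
                for choice in chunk.get("choices", []):
                    content += (choice.get("delta") or {}).get("content") or ""
    return content


if __name__ == "__main__":
    # Hypothetical values; replace with a real base URL, session id, and key.
    print(asyncio.run(consume_chat_stream("http://localhost:8080", "SESSION_ID", "KEY")))
```

Note that the stream is declared with `media_type="application/json"` even though the body is effectively newline-delimited JSON, so clients should split on newlines themselves rather than relying on the content type.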