3
3
- Docs: https://aws.amazon.com/bedrock/
4
4
"""
5
5
6
+ import asyncio
6
7
import json
7
8
import logging
8
9
import os
9
- from typing import Any , AsyncGenerator , Iterable , List , Literal , Optional , Type , TypeVar , Union , cast
10
+ import threading
11
+ from typing import Any , AsyncGenerator , Callable , Iterable , Literal , Optional , Type , TypeVar , Union
10
12
11
13
import boto3
12
14
from botocore .config import Config as BotocoreConfig
@@ -245,17 +247,6 @@ def format_request(
245
247
),
246
248
}
247
249
248
- def format_chunk (self , event : dict [str , Any ]) -> StreamEvent :
249
- """Format the Bedrock response events into standardized message chunks.
250
-
251
- Args:
252
- event: A response event from the Bedrock model.
253
-
254
- Returns:
255
- The formatted chunk.
256
- """
257
- return cast (StreamEvent , event )
258
-
259
250
def _has_blocked_guardrail (self , guardrail_data : dict [str , Any ]) -> bool :
260
251
"""Check if guardrail data contains any blocked policies.
261
252
@@ -284,7 +275,7 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
284
275
Returns:
285
276
List of redaction events to yield.
286
277
"""
287
- events : List [StreamEvent ] = []
278
+ events : list [StreamEvent ] = []
288
279
289
280
if self .config .get ("guardrail_redact_input" , True ):
290
281
logger .debug ("Redacting user input due to guardrail." )
@@ -327,7 +318,55 @@ async def stream(
327
318
system_prompt: System prompt to provide context to the model.
328
319
329
320
Yields:
330
- Formatted message chunks from the model.
321
+ Model events.
322
+
323
+ Raises:
324
+ ContextWindowOverflowException: If the input exceeds the model's context window.
325
+ ModelThrottledException: If the model service is throttling requests.
326
+ """
327
+
328
+ def callback (event : Optional [StreamEvent ] = None ) -> None :
329
+ loop .call_soon_threadsafe (queue .put_nowait , event )
330
+ if event is None :
331
+ return
332
+
333
+ signal .wait ()
334
+ signal .clear ()
335
+
336
+ loop = asyncio .get_event_loop ()
337
+ queue : asyncio .Queue [Optional [StreamEvent ]] = asyncio .Queue ()
338
+ signal = threading .Event ()
339
+
340
+ thread = asyncio .to_thread (self ._stream , callback , messages , tool_specs , system_prompt )
341
+ task = asyncio .create_task (thread )
342
+
343
+ while True :
344
+ event = await queue .get ()
345
+ if event is None :
346
+ break
347
+
348
+ yield event
349
+ signal .set ()
350
+
351
+ await task
352
+
353
+ def _stream (
354
+ self ,
355
+ callback : Callable [..., None ],
356
+ messages : Messages ,
357
+ tool_specs : Optional [list [ToolSpec ]] = None ,
358
+ system_prompt : Optional [str ] = None ,
359
+ ) -> None :
360
+ """Stream conversation with the Bedrock model.
361
+
362
+ This method operates in a separate thread to avoid blocking the async event loop with the call to
363
+ Bedrock's converse_stream.
364
+
365
+ Args:
366
+ callback: Function to send events to the main thread.
367
+ messages: List of message objects to be processed by the model.
368
+ tool_specs: List of tool specifications to make available to the model.
369
+ system_prompt: System prompt to provide context to the model.
331
370
332
371
Raises:
333
372
ContextWindowOverflowException: If the input exceeds the model's context window.
@@ -343,7 +382,6 @@ async def stream(
343
382
try :
344
383
logger .debug ("got response from model" )
345
384
if streaming :
346
- # Streaming implementation
347
385
response = self .client .converse_stream (** request )
348
386
for chunk in response ["stream" ]:
349
387
if (
@@ -354,33 +392,29 @@ async def stream(
354
392
guardrail_data = chunk ["metadata" ]["trace" ]["guardrail" ]
355
393
if self ._has_blocked_guardrail (guardrail_data ):
356
394
for event in self ._generate_redaction_events ():
357
- yield event
358
- yield self .format_chunk (chunk )
395
+ callback (event )
396
+
397
+ callback (chunk )
398
+
359
399
else :
360
- # Non-streaming implementation
361
400
response = self .client .converse (** request )
362
-
363
- # Convert and yield from the response
364
401
for event in self ._convert_non_streaming_to_streaming (response ):
365
- yield event
402
+ callback ( event )
366
403
367
- # Check for guardrail triggers after yielding any events (same as streaming path)
368
404
if (
369
405
"trace" in response
370
406
and "guardrail" in response ["trace" ]
371
407
and self ._has_blocked_guardrail (response ["trace" ]["guardrail" ])
372
408
):
373
409
for event in self ._generate_redaction_events ():
374
- yield event
410
+ callback ( event )
375
411
376
412
except ClientError as e :
377
413
error_message = str (e )
378
414
379
- # Handle throttling error
380
415
if e .response ["Error" ]["Code" ] == "ThrottlingException" :
381
416
raise ModelThrottledException (error_message ) from e
382
417
383
- # Handle context window overflow
384
418
if any (overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES ):
385
419
logger .warning ("bedrock threw context window overflow error" )
386
420
raise ContextWindowOverflowException (e ) from e
@@ -411,10 +445,11 @@ async def stream(
411
445
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported"
412
446
)
413
447
414
- # Otherwise raise the error
415
448
raise e
416
449
417
- logger .debug ("finished streaming response from model" )
450
+ finally :
451
+ callback ()
452
+ logger .debug ("finished streaming response from model" )
418
453
419
454
def _convert_non_streaming_to_streaming (self , response : dict [str , Any ]) -> Iterable [StreamEvent ]:
420
455
"""Convert a non-streaming response to the streaming format.
0 commit comments