33- Docs: https://aws.amazon.com/bedrock/
44"""
55
6+ import asyncio
67import json
78import logging
89import os
9- from typing import Any , AsyncGenerator , Iterable , List , Literal , Optional , Type , TypeVar , Union , cast
10+ import threading
11+ from typing import Any , AsyncGenerator , Callable , Iterable , Literal , Optional , Type , TypeVar , Union
1012
1113import boto3
1214from botocore .config import Config as BotocoreConfig
@@ -245,17 +247,6 @@ def format_request(
245247 ),
246248 }
247249
248- def format_chunk (self , event : dict [str , Any ]) -> StreamEvent :
249- """Format the Bedrock response events into standardized message chunks.
250-
251- Args:
252- event: A response event from the Bedrock model.
253-
254- Returns:
255- The formatted chunk.
256- """
257- return cast (StreamEvent , event )
258-
259250 def _has_blocked_guardrail (self , guardrail_data : dict [str , Any ]) -> bool :
260251 """Check if guardrail data contains any blocked policies.
261252
@@ -284,7 +275,7 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
284275 Returns:
285276 List of redaction events to yield.
286277 """
287- events : List [StreamEvent ] = []
278+ events : list [StreamEvent ] = []
288279
289280 if self .config .get ("guardrail_redact_input" , True ):
290281 logger .debug ("Redacting user input due to guardrail." )
@@ -327,7 +318,55 @@ async def stream(
327318 system_prompt: System prompt to provide context to the model.
328319
329320 Yields:
330- Formatted message chunks from the model.
321+ Model events.
322+
323+ Raises:
324+ ContextWindowOverflowException: If the input exceeds the model's context window.
325+ ModelThrottledException: If the model service is throttling requests.
326+ """
327+
328+ def callback (event : Optional [StreamEvent ] = None ) -> None :
329+ loop .call_soon_threadsafe (queue .put_nowait , event )
330+ if event is None :
331+ return
332+
333+ signal .wait ()
334+ signal .clear ()
335+
336+ loop = asyncio .get_event_loop ()
337+ queue : asyncio .Queue [Optional [StreamEvent ]] = asyncio .Queue ()
338+ signal = threading .Event ()
339+
340+ thread = asyncio .to_thread (self ._stream , callback , messages , tool_specs , system_prompt )
341+ task = asyncio .create_task (thread )
342+
343+ while True :
344+ event = await queue .get ()
345+ if event is None :
346+ break
347+
348+ yield event
349+ signal .set ()
350+
351+ await task
352+
353+ def _stream (
354+ self ,
355+ callback : Callable [..., None ],
356+ messages : Messages ,
357+ tool_specs : Optional [list [ToolSpec ]] = None ,
358+ system_prompt : Optional [str ] = None ,
359+ ) -> None :
360+ """Stream conversation with the Bedrock model.
361+
362+ This method operates in a separate thread to avoid blocking the async event loop with the call to
363+ Bedrock's converse_stream.
364+
365+ Args:
366+ callback: Function to send events to the main thread.
367+ messages: List of message objects to be processed by the model.
368+ tool_specs: List of tool specifications to make available to the model.
369+ system_prompt: System prompt to provide context to the model.
331370
332371 Raises:
333372 ContextWindowOverflowException: If the input exceeds the model's context window.
@@ -343,7 +382,6 @@ async def stream(
343382 try :
344383 logger .debug ("got response from model" )
345384 if streaming :
346- # Streaming implementation
347385 response = self .client .converse_stream (** request )
348386 for chunk in response ["stream" ]:
349387 if (
@@ -354,33 +392,29 @@ async def stream(
354392 guardrail_data = chunk ["metadata" ]["trace" ]["guardrail" ]
355393 if self ._has_blocked_guardrail (guardrail_data ):
356394 for event in self ._generate_redaction_events ():
357- yield event
358- yield self .format_chunk (chunk )
395+ callback (event )
396+
397+ callback (chunk )
398+
359399 else :
360- # Non-streaming implementation
361400 response = self .client .converse (** request )
362-
363- # Convert and yield from the response
364401 for event in self ._convert_non_streaming_to_streaming (response ):
365- yield event
402+ callback ( event )
366403
367- # Check for guardrail triggers after yielding any events (same as streaming path)
368404 if (
369405 "trace" in response
370406 and "guardrail" in response ["trace" ]
371407 and self ._has_blocked_guardrail (response ["trace" ]["guardrail" ])
372408 ):
373409 for event in self ._generate_redaction_events ():
374- yield event
410+ callback ( event )
375411
376412 except ClientError as e :
377413 error_message = str (e )
378414
379- # Handle throttling error
380415 if e .response ["Error" ]["Code" ] == "ThrottlingException" :
381416 raise ModelThrottledException (error_message ) from e
382417
383- # Handle context window overflow
384418 if any (overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES ):
385419 logger .warning ("bedrock threw context window overflow error" )
386420 raise ContextWindowOverflowException (e ) from e
@@ -411,10 +445,11 @@ async def stream(
411445 "https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported"
412446 )
413447
414- # Otherwise raise the error
415448 raise e
416449
417- logger .debug ("finished streaming response from model" )
450+ finally :
451+ callback ()
452+ logger .debug ("finished streaming response from model" )
418453
419454 def _convert_non_streaming_to_streaming (self , response : dict [str , Any ]) -> Iterable [StreamEvent ]:
420455 """Convert a non-streaming response to the streaming format.
0 commit comments