Skip to content

Commit f8b0d4e

Browse files
committed
models - bedrock - threading
1 parent 98c5a37 commit f8b0d4e

File tree

3 files changed

+63
-34
lines changed

3 files changed

+63
-34
lines changed

src/strands/models/bedrock.py

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
- Docs: https://aws.amazon.com/bedrock/
44
"""
55

6+
import asyncio
67
import json
78
import logging
89
import os
9-
from typing import Any, AsyncGenerator, Iterable, List, Literal, Optional, Type, TypeVar, Union, cast
10+
import threading
11+
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
1012

1113
import boto3
1214
from botocore.config import Config as BotocoreConfig
@@ -245,17 +247,6 @@ def format_request(
245247
),
246248
}
247249

248-
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
249-
"""Format the Bedrock response events into standardized message chunks.
250-
251-
Args:
252-
event: A response event from the Bedrock model.
253-
254-
Returns:
255-
The formatted chunk.
256-
"""
257-
return cast(StreamEvent, event)
258-
259250
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
260251
"""Check if guardrail data contains any blocked policies.
261252
@@ -284,7 +275,7 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
284275
Returns:
285276
List of redaction events to yield.
286277
"""
287-
events: List[StreamEvent] = []
278+
events: list[StreamEvent] = []
288279

289280
if self.config.get("guardrail_redact_input", True):
290281
logger.debug("Redacting user input due to guardrail.")
@@ -327,7 +318,55 @@ async def stream(
327318
system_prompt: System prompt to provide context to the model.
328319
329320
Yields:
330-
Formatted message chunks from the model.
321+
Model events.
322+
323+
Raises:
324+
ContextWindowOverflowException: If the input exceeds the model's context window.
325+
ModelThrottledException: If the model service is throttling requests.
326+
"""
327+
328+
def callback(event: Optional[StreamEvent] = None) -> None:
329+
loop.call_soon_threadsafe(queue.put_nowait, event)
330+
if event is None:
331+
return
332+
333+
signal.wait()
334+
signal.clear()
335+
336+
loop = asyncio.get_event_loop()
337+
queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
338+
signal = threading.Event()
339+
340+
thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
341+
task = asyncio.create_task(thread)
342+
343+
while True:
344+
event = await queue.get()
345+
if event is None:
346+
break
347+
348+
yield event
349+
signal.set()
350+
351+
await task
352+
353+
def _stream(
354+
self,
355+
callback: Callable[..., None],
356+
messages: Messages,
357+
tool_specs: Optional[list[ToolSpec]] = None,
358+
system_prompt: Optional[str] = None,
359+
) -> None:
360+
"""Stream conversation with the Bedrock model.
361+
362+
This method operates in a separate thread to avoid blocking the async event loop with the call to
363+
Bedrock's converse_stream.
364+
365+
Args:
366+
callback: Function to send events to the main thread.
367+
messages: List of message objects to be processed by the model.
368+
tool_specs: List of tool specifications to make available to the model.
369+
system_prompt: System prompt to provide context to the model.
331370
332371
Raises:
333372
ContextWindowOverflowException: If the input exceeds the model's context window.
@@ -343,7 +382,6 @@ async def stream(
343382
try:
344383
logger.debug("got response from model")
345384
if streaming:
346-
# Streaming implementation
347385
response = self.client.converse_stream(**request)
348386
for chunk in response["stream"]:
349387
if (
@@ -354,33 +392,29 @@ async def stream(
354392
guardrail_data = chunk["metadata"]["trace"]["guardrail"]
355393
if self._has_blocked_guardrail(guardrail_data):
356394
for event in self._generate_redaction_events():
357-
yield event
358-
yield self.format_chunk(chunk)
395+
callback(event)
396+
397+
callback(chunk)
398+
359399
else:
360-
# Non-streaming implementation
361400
response = self.client.converse(**request)
362-
363-
# Convert and yield from the response
364401
for event in self._convert_non_streaming_to_streaming(response):
365-
yield event
402+
callback(event)
366403

367-
# Check for guardrail triggers after yielding any events (same as streaming path)
368404
if (
369405
"trace" in response
370406
and "guardrail" in response["trace"]
371407
and self._has_blocked_guardrail(response["trace"]["guardrail"])
372408
):
373409
for event in self._generate_redaction_events():
374-
yield event
410+
callback(event)
375411

376412
except ClientError as e:
377413
error_message = str(e)
378414

379-
# Handle throttling error
380415
if e.response["Error"]["Code"] == "ThrottlingException":
381416
raise ModelThrottledException(error_message) from e
382417

383-
# Handle context window overflow
384418
if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
385419
logger.warning("bedrock threw context window overflow error")
386420
raise ContextWindowOverflowException(e) from e
@@ -411,10 +445,11 @@ async def stream(
411445
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported"
412446
)
413447

414-
# Otherwise raise the error
415448
raise e
416449

417-
logger.debug("finished streaming response from model")
450+
finally:
451+
callback()
452+
logger.debug("finished streaming response from model")
418453

419454
def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
420455
"""Convert a non-streaming response to the streaming format.

src/strands/tools/executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ async def work(
5858
async for event in handler(tool_use):
5959
worker_queue.put_nowait((worker_id, event))
6060
await worker_event.wait()
61+
worker_event.clear()
6162

6263
result = cast(ToolResult, event)
6364
finally:

tests/strands/models/test_bedrock.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -398,13 +398,6 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type):
398398
assert tru_request == exp_request
399399

400400

401-
def test_format_chunk(model):
402-
tru_chunk = model.format_chunk("event")
403-
exp_chunk = "event"
404-
405-
assert tru_chunk == exp_chunk
406-
407-
408401
@pytest.mark.asyncio
409402
async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist):
410403
error_message = "Rate exceeded"

0 commit comments

Comments
 (0)