Skip to content

Commit a0f7c24

Browse files
authored
models - bedrock - threading (#411)
1 parent 98c5a37 commit a0f7c24

File tree

7 files changed

+96
-47
lines changed

7 files changed

+96
-47
lines changed

src/strands/models/bedrock.py

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
- Docs: https://aws.amazon.com/bedrock/
44
"""
55

6+
import asyncio
67
import json
78
import logging
89
import os
9-
from typing import Any, AsyncGenerator, Iterable, List, Literal, Optional, Type, TypeVar, Union, cast
10+
import threading
11+
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
1012

1113
import boto3
1214
from botocore.config import Config as BotocoreConfig
@@ -245,17 +247,6 @@ def format_request(
245247
),
246248
}
247249

248-
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
249-
"""Format the Bedrock response events into standardized message chunks.
250-
251-
Args:
252-
event: A response event from the Bedrock model.
253-
254-
Returns:
255-
The formatted chunk.
256-
"""
257-
return cast(StreamEvent, event)
258-
259250
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
260251
"""Check if guardrail data contains any blocked policies.
261252
@@ -284,7 +275,7 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
284275
Returns:
285276
List of redaction events to yield.
286277
"""
287-
events: List[StreamEvent] = []
278+
events: list[StreamEvent] = []
288279

289280
if self.config.get("guardrail_redact_input", True):
290281
logger.debug("Redacting user input due to guardrail.")
@@ -327,7 +318,55 @@ async def stream(
327318
system_prompt: System prompt to provide context to the model.
328319
329320
Yields:
330-
Formatted message chunks from the model.
321+
Model events.
322+
323+
Raises:
324+
ContextWindowOverflowException: If the input exceeds the model's context window.
325+
ModelThrottledException: If the model service is throttling requests.
326+
"""
327+
328+
def callback(event: Optional[StreamEvent] = None) -> None:
329+
loop.call_soon_threadsafe(queue.put_nowait, event)
330+
if event is None:
331+
return
332+
333+
signal.wait()
334+
signal.clear()
335+
336+
loop = asyncio.get_event_loop()
337+
queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
338+
signal = threading.Event()
339+
340+
thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
341+
task = asyncio.create_task(thread)
342+
343+
while True:
344+
event = await queue.get()
345+
if event is None:
346+
break
347+
348+
yield event
349+
signal.set()
350+
351+
await task
352+
353+
def _stream(
354+
self,
355+
callback: Callable[..., None],
356+
messages: Messages,
357+
tool_specs: Optional[list[ToolSpec]] = None,
358+
system_prompt: Optional[str] = None,
359+
) -> None:
360+
"""Stream conversation with the Bedrock model.
361+
362+
This method operates in a separate thread to avoid blocking the async event loop with the call to
363+
Bedrock's converse_stream.
364+
365+
Args:
366+
callback: Function to send events to the main thread.
367+
messages: List of message objects to be processed by the model.
368+
tool_specs: List of tool specifications to make available to the model.
369+
system_prompt: System prompt to provide context to the model.
331370
332371
Raises:
333372
ContextWindowOverflowException: If the input exceeds the model's context window.
@@ -343,7 +382,6 @@ async def stream(
343382
try:
344383
logger.debug("got response from model")
345384
if streaming:
346-
# Streaming implementation
347385
response = self.client.converse_stream(**request)
348386
for chunk in response["stream"]:
349387
if (
@@ -354,33 +392,29 @@ async def stream(
354392
guardrail_data = chunk["metadata"]["trace"]["guardrail"]
355393
if self._has_blocked_guardrail(guardrail_data):
356394
for event in self._generate_redaction_events():
357-
yield event
358-
yield self.format_chunk(chunk)
395+
callback(event)
396+
397+
callback(chunk)
398+
359399
else:
360-
# Non-streaming implementation
361400
response = self.client.converse(**request)
362-
363-
# Convert and yield from the response
364401
for event in self._convert_non_streaming_to_streaming(response):
365-
yield event
402+
callback(event)
366403

367-
# Check for guardrail triggers after yielding any events (same as streaming path)
368404
if (
369405
"trace" in response
370406
and "guardrail" in response["trace"]
371407
and self._has_blocked_guardrail(response["trace"]["guardrail"])
372408
):
373409
for event in self._generate_redaction_events():
374-
yield event
410+
callback(event)
375411

376412
except ClientError as e:
377413
error_message = str(e)
378414

379-
# Handle throttling error
380415
if e.response["Error"]["Code"] == "ThrottlingException":
381416
raise ModelThrottledException(error_message) from e
382417

383-
# Handle context window overflow
384418
if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
385419
logger.warning("bedrock threw context window overflow error")
386420
raise ContextWindowOverflowException(e) from e
@@ -411,10 +445,11 @@ async def stream(
411445
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported"
412446
)
413447

414-
# Otherwise raise the error
415448
raise e
416449

417-
logger.debug("finished streaming response from model")
450+
finally:
451+
callback()
452+
logger.debug("finished streaming response from model")
418453

419454
def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
420455
"""Convert a non-streaming response to the streaming format.

src/strands/tools/executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ async def work(
5858
async for event in handler(tool_use):
5959
worker_queue.put_nowait((worker_id, event))
6060
await worker_event.wait()
61+
worker_event.clear()
6162

6263
result = cast(ToolResult, event)
6364
finally:

tests/strands/models/test_bedrock.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -398,13 +398,6 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type):
398398
assert tru_request == exp_request
399399

400400

401-
def test_format_chunk(model):
402-
tru_chunk = model.format_chunk("event")
403-
exp_chunk = "event"
404-
405-
assert tru_chunk == exp_chunk
406-
407-
408401
@pytest.mark.asyncio
409402
async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist):
410403
error_message = "Rate exceeded"

tests_integ/models/test_model_anthropic.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

3+
import pydantic
34
import pytest
4-
from pydantic import BaseModel
55

66
import strands
77
from strands import Agent
@@ -48,7 +48,7 @@ def agent(model, tools, system_prompt):
4848

4949
@pytest.fixture
5050
def weather():
51-
class Weather(BaseModel):
51+
class Weather(pydantic.BaseModel):
5252
"""Extracts the time and weather from the user's message with the exact strings."""
5353

5454
time: str
@@ -59,11 +59,16 @@ class Weather(BaseModel):
5959

6060
@pytest.fixture
6161
def yellow_color():
62-
class Color(BaseModel):
62+
class Color(pydantic.BaseModel):
6363
"""Describes a color."""
6464

6565
name: str
6666

67+
@pydantic.field_validator("name", mode="after")
68+
@classmethod
69+
def lower(_, value):
70+
return value.lower()
71+
6772
return Color(name="yellow")
6873

6974

tests_integ/models/test_model_bedrock.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import pydantic
12
import pytest
2-
from pydantic import BaseModel
33

44
import strands
55
from strands import Agent
@@ -39,11 +39,16 @@ def non_streaming_agent(non_streaming_model, system_prompt):
3939

4040
@pytest.fixture
4141
def yellow_color():
42-
class Color(BaseModel):
42+
class Color(pydantic.BaseModel):
4343
"""Describes a color."""
4444

4545
name: str
4646

47+
@pydantic.field_validator("name", mode="after")
48+
@classmethod
49+
def lower(_, value):
50+
return value.lower()
51+
4752
return Color(name="yellow")
4853

4954

@@ -136,7 +141,7 @@ def calculator(expression: str) -> float:
136141
def test_structured_output_streaming(streaming_model):
137142
"""Test structured output with streaming model."""
138143

139-
class Weather(BaseModel):
144+
class Weather(pydantic.BaseModel):
140145
time: str
141146
weather: str
142147

@@ -151,7 +156,7 @@ class Weather(BaseModel):
151156
def test_structured_output_non_streaming(non_streaming_model):
152157
"""Test structured output with non-streaming model."""
153158

154-
class Weather(BaseModel):
159+
class Weather(pydantic.BaseModel):
155160
time: str
156161
weather: str
157162

tests_integ/models/test_model_litellm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import pydantic
12
import pytest
2-
from pydantic import BaseModel
33

44
import strands
55
from strands import Agent
@@ -31,11 +31,16 @@ def agent(model, tools):
3131

3232
@pytest.fixture
3333
def yellow_color():
34-
class Color(BaseModel):
34+
class Color(pydantic.BaseModel):
3535
"""Describes a color."""
3636

3737
name: str
3838

39+
@pydantic.field_validator("name", mode="after")
40+
@classmethod
41+
def lower(_, value):
42+
return value.lower()
43+
3944
return Color(name="yellow")
4045

4146

@@ -47,7 +52,7 @@ def test_agent(agent):
4752

4853

4954
def test_structured_output(model):
50-
class Weather(BaseModel):
55+
class Weather(pydantic.BaseModel):
5156
time: str
5257
weather: str
5358

tests_integ/models/test_model_openai.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

3+
import pydantic
34
import pytest
4-
from pydantic import BaseModel
55

66
import strands
77
from strands import Agent, tool
@@ -42,7 +42,7 @@ def agent(model, tools):
4242

4343
@pytest.fixture
4444
def weather():
45-
class Weather(BaseModel):
45+
class Weather(pydantic.BaseModel):
4646
"""Extracts the time and weather from the user's message with the exact strings."""
4747

4848
time: str
@@ -53,11 +53,16 @@ class Weather(BaseModel):
5353

5454
@pytest.fixture
5555
def yellow_color():
56-
class Color(BaseModel):
56+
class Color(pydantic.BaseModel):
5757
"""Describes a color."""
5858

5959
name: str
6060

61+
@pydantic.field_validator("name", mode="after")
62+
@classmethod
63+
def lower(_, value):
64+
return value.lower()
65+
6166
return Color(name="yellow")
6267

6368

0 commit comments

Comments
 (0)