Skip to content

feat(event_loop): make event loop settings configurable #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
from ..types.content import ContentBlock, Message, Messages
from ..types.event_loop import EventLoopConfig
from ..types.exceptions import ContextWindowOverflowException
from ..types.models import Model
from ..types.tools import ToolResult, ToolUse
Expand Down Expand Up @@ -198,6 +199,7 @@ def __init__(
record_direct_tool_call: bool = True,
load_tools_from_directory: bool = True,
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
event_loop_config: Optional[EventLoopConfig] = None,
*,
name: Optional[str] = None,
description: Optional[str] = None,
Expand Down Expand Up @@ -232,6 +234,8 @@ def __init__(
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
Defaults to True.
trace_attributes: Custom trace attributes to apply to the agent's trace span.
event_loop_config: Configuration for the event loop behavior.
If None, default values will be used.
name: name of the Agent
Defaults to None.
description: description of what the Agent does
Expand Down Expand Up @@ -281,6 +285,7 @@ def __init__(
self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry)

self.event_loop_metrics = EventLoopMetrics()
self.event_loop_config = event_loop_config

# Initialize tracer instance (no-op if not configured)
self.tracer = get_tracer()
Expand Down
18 changes: 9 additions & 9 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..telemetry.tracer import get_tracer
from ..tools.executor import run_tools, validate_and_prepare_tools
from ..types.content import Message
from ..types.event_loop import EventLoopConfig
from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
from ..types.streaming import Metrics, StopReason
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
Expand All @@ -36,10 +37,6 @@

logger = logging.getLogger(__name__)

MAX_ATTEMPTS = 6
INITIAL_DELAY = 4
MAX_DELAY = 240 # 4 minutes


async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Execute a single cycle of the event loop.
Expand Down Expand Up @@ -108,9 +105,12 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
usage: Any
metrics: Metrics

# Get event loop configuration or use defaults
event_loop_config = agent.event_loop_config or EventLoopConfig()

# Retry loop for handling throttling exceptions
current_delay = INITIAL_DELAY
for attempt in range(MAX_ATTEMPTS):
current_delay = event_loop_config.initial_delay
for attempt in range(event_loop_config.max_attempts):
model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None
model_invoke_span = tracer.start_model_invoke_span(
messages=agent.messages,
Expand Down Expand Up @@ -162,7 +162,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
)

if isinstance(e, ModelThrottledException):
if attempt + 1 == MAX_ATTEMPTS:
if attempt + 1 == event_loop_config.max_attempts:
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
raise e

Expand All @@ -171,11 +171,11 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
"| throttling exception encountered "
"| delaying before next retry",
current_delay,
MAX_ATTEMPTS,
event_loop_config.max_attempts,
attempt + 1,
)
time.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)
current_delay = min(current_delay * 2, event_loop_config.max_delay)

yield {"callback": {"event_loop_throttled_delay": current_delay, **kwargs}}
else:
Expand Down
18 changes: 18 additions & 0 deletions src/strands/types/event_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Event loop-related type definitions for the SDK."""

from dataclasses import dataclass
from typing import Literal

from typing_extensions import TypedDict
Expand Down Expand Up @@ -46,3 +47,20 @@ class Metrics(TypedDict):
- "stop_sequence": Stop sequence encountered
- "tool_use": Model requested to use a tool
"""


@dataclass(frozen=True)
class EventLoopConfig:
    """Configuration for the event loop's retry and throttling behavior.

    Instances are immutable (frozen dataclass). Values are validated when the
    object is constructed, so an unusable configuration is rejected up front
    instead of surfacing later while a throttled request is being retried.

    Attributes:
        max_attempts: Maximum number of attempts for a throttled request,
            counting the first one (default: 6). Must be at least 1.
        initial_delay: Delay in seconds before the first retry (default: 4).
            Must not be negative.
        max_delay: Upper bound in seconds for the delay between retries
            (default: 240). Must not be negative or smaller than
            ``initial_delay``.

    Raises:
        ValueError: If any of the constraints above is violated.
    """

    max_attempts: int = 6
    initial_delay: int = 4
    max_delay: int = 240

    def __post_init__(self) -> None:
        """Reject configurations the retry loop cannot honor."""
        # With fewer than one attempt the retry loop has no iterations, so the
        # model would never be invoked at all.
        if self.max_attempts < 1:
            raise ValueError(f"max_attempts must be at least 1, got {self.max_attempts}")

        # Sleeping for a negative duration is an error; fail at construction
        # time rather than on the first throttled request.
        if self.initial_delay < 0:
            raise ValueError(f"initial_delay must be non-negative, got {self.initial_delay}")
        if self.max_delay < 0:
            raise ValueError(f"max_delay must be non-negative, got {self.max_delay}")

        # The first retry waits exactly initial_delay, so a smaller max_delay
        # would be exceeded immediately.
        if self.max_delay < self.initial_delay:
            raise ValueError(
                f"max_delay ({self.max_delay}) must be greater than or equal to "
                f"initial_delay ({self.initial_delay})"
            )
43 changes: 43 additions & 0 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from strands.telemetry.metrics import EventLoopMetrics
from strands.tools.registry import ToolRegistry
from strands.types.event_loop import EventLoopConfig
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
from tests.fixtures.mock_hook_provider import MockHookProvider

Expand Down Expand Up @@ -133,6 +134,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis
mock.tool_registry = tool_registry
mock.thread_pool = thread_pool
mock.event_loop_metrics = EventLoopMetrics()
mock.event_loop_config = EventLoopConfig()
mock._hooks = hook_registry

return mock
Expand Down Expand Up @@ -709,6 +711,47 @@ async def test_request_state_initialization(alist):
assert tru_request_state == initial_request_state


@pytest.mark.asyncio
async def test_event_loop_with_custom_config(mock_time, agent, model, agenerator, alist):
    """Test that the event loop uses custom configuration values when provided.

    The configuration is chosen so that both the custom initial delay and the
    custom delay cap influence the observed sleeps: the second delay would be 4
    after doubling, but must be capped to ``max_delay`` (3).
    """
    # Throttle twice, then succeed on the third attempt.
    model.stream.side_effect = [
        ModelThrottledException("ThrottlingException"),
        ModelThrottledException("ThrottlingException"),
        agenerator(
            [
                {"contentBlockDelta": {"delta": {"text": "test text"}}},
                {"contentBlockStop": {}},
            ]
        ),
    ]

    # max_delay sits between initial_delay and its double so the cap is
    # actually exercised; with a larger cap the test could not tell whether
    # max_delay was being honored.
    custom_config = EventLoopConfig(
        max_attempts=3,
        initial_delay=2,
        max_delay=3,
    )
    agent.event_loop_config = custom_config

    stream = strands.event_loop.event_loop.event_loop_cycle(
        agent=agent,
        kwargs={},
    )
    events = await alist(stream)
    tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"]

    # Verify the final response
    assert tru_stop_reason == "end_turn"
    assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]}
    assert tru_request_state == {}

    # Verify that sleep was called with the custom delay values:
    # first the initial delay (2), then min(2 * 2, max_delay) == 3.
    assert mock_time.sleep.call_count == 2
    assert mock_time.sleep.call_args_list == [call(2), call(3)]


@pytest.mark.asyncio
async def test_prepare_next_cycle_in_tool_execution(agent, model, tool_stream, agenerator, alist):
"""Test that cycle ID and metrics are properly updated during tool execution."""
Expand Down