Skip to content

chore: Inline event loop helper functions #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 5 additions & 39 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,6 @@
MAX_DELAY = 240 # 4 minutes


def initialize_state(**kwargs: Any) -> Any:
"""Initialize the request state if not present.

Creates an empty request_state dictionary if one doesn't already exist in the
provided keyword arguments.

Args:
**kwargs: Keyword arguments that may contain a request_state.

Returns:
The updated kwargs dictionary with request_state initialized if needed.
"""
if "request_state" not in kwargs:
kwargs["request_state"] = {}
return kwargs


def event_loop_cycle(
model: Model,
system_prompt: Optional[str],
Expand Down Expand Up @@ -107,7 +90,8 @@ def event_loop_cycle(
event_loop_metrics: EventLoopMetrics = kwargs.get("event_loop_metrics", EventLoopMetrics())

# Initialize state and get cycle trace
kwargs = initialize_state(**kwargs)
if "request_state" not in kwargs:
kwargs["request_state"] = {}
cycle_start_time, cycle_trace = event_loop_metrics.start_cycle()
kwargs["event_loop_cycle_trace"] = cycle_trace

Expand Down Expand Up @@ -309,26 +293,6 @@ def recurse_event_loop(
)


def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetrics) -> Dict[str, Any]:
"""Prepare state for the next event loop cycle.

Updates the keyword arguments with the current event loop metrics and stores the current cycle ID as the parent
cycle ID for the next cycle. This maintains the parent-child relationship between cycles for tracing and metrics.

Args:
kwargs: Current keyword arguments containing event loop state.
event_loop_metrics: The metrics object tracking event loop execution.

Returns:
Updated keyword arguments ready for the next cycle.
"""
# Store parent cycle ID
kwargs["event_loop_metrics"] = event_loop_metrics
kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"]

return kwargs


def _handle_tool_execution(
stop_reason: StopReason,
message: Message,
Expand Down Expand Up @@ -402,7 +366,9 @@ def _handle_tool_execution(
parallel_tool_executor=tool_execution_handler,
)

kwargs = prepare_next_cycle(kwargs, event_loop_metrics)
# Store parent cycle ID for the next cycle
kwargs["event_loop_metrics"] = event_loop_metrics
kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"]

tool_result_message: Message = {
"role": "user",
Expand Down
107 changes: 73 additions & 34 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,27 +104,6 @@ def mock_tracer():
return tracer


@pytest.mark.parametrize(
("kwargs", "exp_state"),
[
(
{"request_state": {"key1": "value1"}},
{"key1": "value1"},
),
(
{},
{},
),
],
)
def test_initialize_state(kwargs, exp_state):
kwargs = strands.event_loop.event_loop.initialize_state(**kwargs)

tru_state = kwargs["request_state"]

assert tru_state == exp_state


def test_event_loop_cycle_text_response(
model,
model_id,
Expand Down Expand Up @@ -411,19 +390,6 @@ def test_event_loop_cycle_stop(
assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state


def test_prepare_next_cycle():
kwargs = {"event_loop_cycle_id": "c1"}
event_loop_metrics = strands.telemetry.metrics.EventLoopMetrics()
tru_result = strands.event_loop.event_loop.prepare_next_cycle(kwargs, event_loop_metrics)
exp_result = {
"event_loop_cycle_id": "c1",
"event_loop_parent_cycle_id": "c1",
"event_loop_metrics": event_loop_metrics,
}

assert tru_result == exp_result


def test_cycle_exception(
model,
system_prompt,
Expand Down Expand Up @@ -679,3 +645,76 @@ def test_event_loop_cycle_with_parent_span(
mock_tracer.start_event_loop_cycle_span.assert_called_once_with(
event_loop_kwargs=unittest.mock.ANY, parent_span=parent_span, messages=messages
)


def test_request_state_initialization():
# Call without providing request_state
tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle(
model=MagicMock(),
model_id=MagicMock(),
system_prompt=MagicMock(),
messages=MagicMock(),
tool_config=MagicMock(),
callback_handler=MagicMock(),
tool_handler=MagicMock(),
tool_execution_handler=MagicMock(),
)

# Verify request_state was initialized to empty dict
assert tru_request_state == {}

# Call with pre-existing request_state
initial_request_state = {"key": "value"}
tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle(
model=MagicMock(),
model_id=MagicMock(),
system_prompt=MagicMock(),
messages=MagicMock(),
tool_config=MagicMock(),
callback_handler=MagicMock(),
tool_handler=MagicMock(),
request_state=initial_request_state,
)

# Verify existing request_state was preserved
assert tru_request_state == initial_request_state


def test_prepare_next_cycle_in_tool_execution(model, tool_stream):
"""Test that cycle ID and metrics are properly updated during tool execution."""
model.converse.side_effect = [
tool_stream,
[
{"contentBlockStop": {}},
],
]

# Create a mock for recurse_event_loop to capture the kwargs passed to it
with unittest.mock.patch.object(strands.event_loop.event_loop, "recurse_event_loop") as mock_recurse:
# Set up mock to return a valid response
mock_recurse.return_value = (
"end_turn",
{"role": "assistant", "content": [{"text": "test text"}]},
strands.telemetry.metrics.EventLoopMetrics(),
{},
)

# Call event_loop_cycle which should execute a tool and then call recurse_event_loop
strands.event_loop.event_loop.event_loop_cycle(
model=model,
model_id=MagicMock(),
system_prompt=MagicMock(),
messages=MagicMock(),
tool_config=MagicMock(),
callback_handler=MagicMock(),
tool_handler=MagicMock(),
tool_execution_handler=MagicMock(),
)

assert mock_recurse.called

# Verify required properties are present
recursive_kwargs = mock_recurse.call_args[1]
assert "event_loop_metrics" in recursive_kwargs
assert "event_loop_parent_cycle_id" in recursive_kwargs
assert recursive_kwargs["event_loop_parent_cycle_id"] == recursive_kwargs["event_loop_cycle_id"]