Skip to content

Commit 8cb8422

Browse files
author
Fede Kamelhar
committed
Modularizing Event Loop
1 parent 7042af1 commit 8cb8422

File tree

2 files changed

+135
-80
lines changed

2 files changed

+135
-80
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 134 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,19 @@
2828

2929
logger = logging.getLogger(__name__)
3030

31+
MAX_ATTEMPTS = 6
32+
INITIAL_DELAY = 4
33+
MAX_DELAY = 240 # 4 minutes
3134

32-
def initialize_state(**kwargs: Any) -> Any:
35+
36+
def initialize_state(kwargs: Dict[str, Any]) -> Dict[str, Any]:
3337
"""Initialize the request state if not present.
3438
3539
Creates an empty request_state dictionary if one doesn't already exist in the
3640
provided keyword arguments.
3741
3842
Args:
39-
**kwargs: Keyword arguments that may contain a request_state.
43+
kwargs: Keyword arguments that may contain a request_state.
4044
4145
Returns:
4246
The updated kwargs dictionary with request_state initialized if needed.
@@ -51,7 +55,7 @@ def event_loop_cycle(
5155
system_prompt: Optional[str],
5256
messages: Messages,
5357
tool_config: Optional[ToolConfig],
54-
callback_handler: Any,
58+
callback_handler: Callable[..., Any],
5559
tool_handler: Optional[ToolHandler],
5660
tool_execution_handler: Optional[ParallelToolExecutorInterface] = None,
5761
**kwargs: Any,
@@ -103,7 +107,7 @@ def event_loop_cycle(
103107
event_loop_metrics: EventLoopMetrics = kwargs.get("event_loop_metrics", EventLoopMetrics())
104108

105109
# Initialize state and get cycle trace
106-
kwargs = initialize_state(**kwargs)
110+
kwargs = initialize_state(kwargs)
107111
cycle_start_time, cycle_trace = event_loop_metrics.start_cycle()
108112
kwargs["event_loop_cycle_trace"] = cycle_trace
109113

@@ -130,9 +134,9 @@ def event_loop_cycle(
130134
stop_reason: StopReason
131135
usage: Any
132136
metrics: Metrics
133-
max_attempts = 6
134-
initial_delay = 4
135-
max_delay = 240 # 4 minutes
137+
max_attempts = MAX_ATTEMPTS
138+
initial_delay = INITIAL_DELAY
139+
max_delay = MAX_DELAY
136140
current_delay = initial_delay
137141

138142
# Retry loop for handling throttling exceptions
@@ -204,80 +208,29 @@ def event_loop_cycle(
204208

205209
# If the model is requesting to use tools
206210
if stop_reason == "tool_use":
207-
tool_uses: List[ToolUse] = []
208-
tool_results: List[ToolResult] = []
209-
invalid_tool_use_ids: List[str] = []
210-
211-
# Extract and validate tools
212-
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
213-
214-
# Check if tools are available for execution
215-
if tool_uses:
216-
if tool_handler is None:
217-
raise ValueError("toolUse present but tool handler not set")
218-
if tool_config is None:
219-
raise ValueError("toolUse present but tool config not set")
220-
221-
# Create the tool handler process callable
222-
tool_handler_process: Callable[[ToolUse], ToolResult] = partial(
223-
tool_handler.process,
224-
messages=messages,
225-
model=model,
226-
system_prompt=system_prompt,
227-
tool_config=tool_config,
228-
callback_handler=callback_handler,
229-
**kwargs,
230-
)
231-
232-
# Execute tools (parallel or sequential)
233-
run_tools(
234-
handler=tool_handler_process,
235-
tool_uses=tool_uses,
236-
event_loop_metrics=event_loop_metrics,
237-
request_state=cast(Any, kwargs["request_state"]),
238-
invalid_tool_use_ids=invalid_tool_use_ids,
239-
tool_results=tool_results,
240-
cycle_trace=cycle_trace,
241-
parent_span=cycle_span,
242-
parallel_tool_executor=tool_execution_handler,
211+
if not tool_handler:
212+
raise EventLoopException(
213+
"Model requested tool use but no tool handler provided",
214+
kwargs["request_state"],
243215
)
244216

245-
# Update state for the next cycle
246-
kwargs = prepare_next_cycle(kwargs, event_loop_metrics)
247-
248-
# Create the tool result message
249-
tool_result_message: Message = {
250-
"role": "user",
251-
"content": [{"toolResult": result} for result in tool_results],
252-
}
253-
messages.append(tool_result_message)
254-
callback_handler(message=tool_result_message)
255-
256-
if cycle_span:
257-
tracer.end_event_loop_cycle_span(
258-
span=cycle_span, message=message, tool_result_message=tool_result_message
259-
)
260-
261-
# Check if we should stop the event loop
262-
if kwargs["request_state"].get("stop_event_loop"):
263-
event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
264-
return (
265-
stop_reason,
266-
message,
267-
event_loop_metrics,
268-
kwargs["request_state"],
269-
)
270-
271-
# Recursive call to continue the conversation
272-
return recurse_event_loop(
273-
model=model,
274-
system_prompt=system_prompt,
275-
messages=messages,
276-
tool_config=tool_config,
277-
callback_handler=callback_handler,
278-
tool_handler=tool_handler,
279-
**kwargs,
280-
)
217+
# Handle tool execution
218+
return _handle_tool_execution(
219+
stop_reason,
220+
message,
221+
model,
222+
system_prompt,
223+
messages,
224+
tool_config,
225+
tool_handler,
226+
callback_handler,
227+
tool_execution_handler,
228+
event_loop_metrics,
229+
cycle_trace,
230+
cycle_span,
231+
cycle_start_time,
232+
kwargs,
233+
)
281234

282235
# End the cycle and return results
283236
event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
@@ -377,3 +330,105 @@ def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetr
377330
kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"]
378331

379332
return kwargs
333+
334+
335+
def _handle_tool_execution(
336+
stop_reason: StopReason,
337+
message: Message,
338+
model: Model,
339+
system_prompt: Optional[str],
340+
messages: Messages,
341+
tool_config: ToolConfig,
342+
tool_handler: ToolHandler,
343+
callback_handler: Callable[..., Any],
344+
tool_execution_handler: Optional[ParallelToolExecutorInterface],
345+
event_loop_metrics: EventLoopMetrics,
346+
cycle_trace: Trace,
347+
cycle_span: Any,
348+
cycle_start_time: float,
349+
kwargs: Dict[str, Any],
350+
) -> Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
351+
tool_uses: List[ToolUse] = []
352+
tool_results: List[ToolResult] = []
353+
invalid_tool_use_ids: List[str] = []
354+
355+
"""
356+
Handles the execution of tools requested by the model during an event loop cycle.
357+
358+
Args:
359+
stop_reason (StopReason): The reason the model stopped generating.
360+
message (Message): The message from the model that may contain tool use requests.
361+
model (Model): The model provider instance.
362+
system_prompt (Optional[str]): The system prompt instructions for the model.
363+
messages (Messages): The conversation history messages.
364+
tool_config (ToolConfig): Configuration for available tools.
365+
tool_handler (ToolHandler): Handler for tool execution.
366+
callback_handler (Callable[..., Any]): Callback for processing events as they happen.
367+
tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution.
368+
event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop.
369+
cycle_trace (Trace): Trace object for the current event loop cycle.
370+
cycle_span (Any): Span object for tracing the cycle (type may vary).
371+
cycle_start_time (float): Start time of the current cycle.
372+
kwargs (Dict[str, Any]): Additional keyword arguments, including request state.
373+
374+
Returns:
375+
Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
376+
- The stop reason,
377+
- The updated message,
378+
- The updated event loop metrics,
379+
- The updated request state.
380+
"""
381+
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
382+
383+
if not tool_uses:
384+
return stop_reason, message, event_loop_metrics, kwargs["request_state"]
385+
386+
tool_handler_process = partial(
387+
tool_handler.process,
388+
messages=messages,
389+
model=model,
390+
system_prompt=system_prompt,
391+
tool_config=tool_config,
392+
callback_handler=callback_handler,
393+
**kwargs,
394+
)
395+
396+
run_tools(
397+
handler=tool_handler_process,
398+
tool_uses=tool_uses,
399+
event_loop_metrics=event_loop_metrics,
400+
request_state=cast(Any, kwargs["request_state"]),
401+
invalid_tool_use_ids=invalid_tool_use_ids,
402+
tool_results=tool_results,
403+
cycle_trace=cycle_trace,
404+
parent_span=cycle_span,
405+
parallel_tool_executor=tool_execution_handler,
406+
)
407+
408+
kwargs = prepare_next_cycle(kwargs, event_loop_metrics)
409+
410+
tool_result_message: Message = {
411+
"role": "user",
412+
"content": [{"toolResult": result} for result in tool_results],
413+
}
414+
415+
messages.append(tool_result_message)
416+
callback_handler(message=tool_result_message)
417+
418+
if cycle_span:
419+
tracer = get_tracer()
420+
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)
421+
422+
if kwargs["request_state"].get("stop_event_loop", False):
423+
event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
424+
return stop_reason, message, event_loop_metrics, kwargs["request_state"]
425+
426+
return recurse_event_loop(
427+
model=model,
428+
system_prompt=system_prompt,
429+
messages=messages,
430+
tool_config=tool_config,
431+
callback_handler=callback_handler,
432+
tool_handler=tool_handler,
433+
**kwargs,
434+
)

tests/strands/event_loop/test_event_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def mock_tracer():
118118
],
119119
)
120120
def test_initialize_state(kwargs, exp_state):
121-
kwargs = strands.event_loop.event_loop.initialize_state(**kwargs)
121+
kwargs = strands.event_loop.event_loop.initialize_state(kwargs)
122122

123123
tru_state = kwargs["request_state"]
124124

0 commit comments

Comments
 (0)