Skip to content

Commit 1421aad

Browse files
authored
feat: Implement the core system of typed hooks & callbacks (#304)
Relates to #231 Add the HookRegistry and a small subset of events (AgentInitializedEvent, StartRequestEvent, EndRequestEvent) as a POC for how hooks will work.
1 parent d601615 commit 1421aad

File tree

10 files changed

+581
-12
lines changed

10 files changed

+581
-12
lines changed

src/strands/agent/agent.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pydantic import BaseModel
2121

2222
from ..event_loop.event_loop import event_loop_cycle
23+
from ..experimental.hooks import AgentInitializedEvent, EndRequestEvent, HookRegistry, StartRequestEvent
2324
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2425
from ..handlers.tool_handler import AgentToolHandler
2526
from ..models.bedrock import BedrockModel
@@ -308,6 +309,10 @@ def __init__(
308309
self.name = name
309310
self.description = description
310311

312+
self._hooks = HookRegistry()
313+
# Register built-in hook providers (like ConversationManager) here
314+
self._hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
315+
311316
@property
312317
def tool(self) -> ToolCaller:
313318
"""Call tool as a function.
@@ -405,21 +410,26 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
405410
that the agent will use when responding.
406411
prompt: The prompt to use for the agent.
407412
"""
408-
messages = self.messages
409-
if not messages and not prompt:
410-
raise ValueError("No conversation history or prompt provided")
413+
self._hooks.invoke_callbacks(StartRequestEvent(agent=self))
411414

412-
# add the prompt as the last message
413-
if prompt:
414-
messages.append({"role": "user", "content": [{"text": prompt}]})
415+
try:
416+
messages = self.messages
417+
if not messages and not prompt:
418+
raise ValueError("No conversation history or prompt provided")
415419

416-
# get the structured output from the model
417-
events = self.model.structured_output(output_model, messages)
418-
for event in events:
419-
if "callback" in event:
420-
self.callback_handler(**cast(dict, event["callback"]))
420+
# add the prompt as the last message
421+
if prompt:
422+
messages.append({"role": "user", "content": [{"text": prompt}]})
421423

422-
return event["output"]
424+
# get the structured output from the model
425+
events = self.model.structured_output(output_model, messages)
426+
for event in events:
427+
if "callback" in event:
428+
self.callback_handler(**cast(dict, event["callback"]))
429+
430+
return event["output"]
431+
finally:
432+
self._hooks.invoke_callbacks(EndRequestEvent(agent=self))
423433

424434
async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
425435
"""Process a natural language prompt and yield events as an async iterator.
@@ -473,6 +483,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
473483

474484
def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
475485
"""Execute the agent's event loop with the given prompt and parameters."""
486+
self._hooks.invoke_callbacks(StartRequestEvent(agent=self))
487+
476488
try:
477489
# Extract key parameters
478490
yield {"callback": {"init_event_loop": True, **kwargs}}
@@ -487,6 +499,7 @@ def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str,
487499

488500
finally:
489501
self.conversation_manager.apply_management(self)
502+
self._hooks.invoke_callbacks(EndRequestEvent(agent=self))
490503

491504
def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
492505
"""Execute the event loop cycle with retry logic for context window limits.

src/strands/experimental/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Experimental features.
2+
3+
This module implements experimental features that are subject to change in future revisions without notice.
4+
"""
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Typed hook system for extending agent functionality.
2+
3+
This module provides a composable mechanism for building objects that can hook
4+
into specific events during the agent lifecycle. The hook system enables both
5+
built-in SDK components and user code to react to or modify agent behavior
6+
through strongly-typed event callbacks.
7+
8+
Example Usage:
9+
```python
10+
from strands.hooks import HookProvider, HookRegistry
11+
from strands.hooks.events import StartRequestEvent, EndRequestEvent
12+
13+
class LoggingHooks(HookProvider):
14+
def register_hooks(self, registry: HookRegistry) -> None:
15+
registry.add_callback(StartRequestEvent, self.log_start)
16+
registry.add_callback(EndRequestEvent, self.log_end)
17+
18+
def log_start(self, event: StartRequestEvent) -> None:
19+
print(f"Request started for {event.agent.name}")
20+
21+
def log_end(self, event: EndRequestEvent) -> None:
22+
print(f"Request completed for {event.agent.name}")
23+
24+
# Use with agent
25+
agent = Agent(hooks=[LoggingHooks()])
26+
```
27+
28+
This replaces the older callback_handler approach with a more composable,
29+
type-safe system that supports multiple subscribers per event type.
30+
"""
31+
32+
from .events import AgentInitializedEvent, EndRequestEvent, StartRequestEvent
33+
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry
34+
35+
__all__ = [
36+
"AgentInitializedEvent",
37+
"StartRequestEvent",
38+
"EndRequestEvent",
39+
"HookEvent",
40+
"HookProvider",
41+
"HookCallback",
42+
"HookRegistry",
43+
]
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Hook events emitted as part of invoking Agents.
2+
3+
This module defines the events that are emitted as Agents run through the lifecycle of a request.
4+
"""
5+
6+
from dataclasses import dataclass
7+
8+
from .registry import HookEvent
9+
10+
11+
@dataclass
12+
class AgentInitializedEvent(HookEvent):
13+
"""Event triggered when an agent has finished initialization.
14+
15+
This event is fired after the agent has been fully constructed and all
16+
built-in components have been initialized. Hook providers can use this
17+
event to perform setup tasks that require a fully initialized agent.
18+
"""
19+
20+
pass
21+
22+
23+
@dataclass
24+
class StartRequestEvent(HookEvent):
25+
"""Event triggered at the beginning of a new agent request.
26+
27+
This event is fired when the agent begins processing a new user request,
28+
before any model inference or tool execution occurs. Hook providers can
29+
use this event to perform request-level setup, logging, or validation.
30+
31+
This event is triggered at the beginning of the following api calls:
32+
- Agent.__call__
33+
- Agent.stream_async
34+
- Agent.structured_output
35+
"""
36+
37+
pass
38+
39+
40+
@dataclass
41+
class EndRequestEvent(HookEvent):
42+
"""Event triggered at the end of an agent request.
43+
44+
This event is fired after the agent has completed processing a request,
45+
regardless of whether it completed successfully or encountered an error.
46+
Hook providers can use this event for cleanup, logging, or state persistence.
47+
48+
Note: This event uses reverse callback ordering, meaning callbacks registered
49+
later will be invoked first during cleanup.
50+
51+
This event is triggered at the end of the following api calls:
52+
- Agent.__call__
53+
- Agent.stream_async
54+
- Agent.structured_output
55+
"""
56+
57+
@property
58+
def should_reverse_callbacks(self) -> bool:
59+
"""Return True to invoke callbacks in reverse order for proper cleanup.
60+
61+
Returns:
62+
True, indicating callbacks should be invoked in reverse order.
63+
"""
64+
return True
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""Hook registry system for managing event callbacks in the Strands Agent SDK.
2+
3+
This module provides the core infrastructure for the typed hook system, enabling
4+
composable extension of agent functionality through strongly-typed event callbacks.
5+
The registry manages the mapping between event types and their associated callback
6+
functions, supporting both individual callback registration and bulk registration
7+
via hook provider objects.
8+
"""
9+
10+
from dataclasses import dataclass
11+
from typing import TYPE_CHECKING, Callable, Generator, Generic, Protocol, Type, TypeVar
12+
13+
if TYPE_CHECKING:
14+
from ...agent import Agent
15+
16+
17+
@dataclass
18+
class HookEvent:
19+
"""Base class for all hook events.
20+
21+
Attributes:
22+
agent: The agent instance that triggered this event.
23+
"""
24+
25+
agent: "Agent"
26+
27+
@property
28+
def should_reverse_callbacks(self) -> bool:
29+
"""Determine if callbacks for this event should be invoked in reverse order.
30+
31+
Returns:
32+
False by default. Override to return True for events that should
33+
invoke callbacks in reverse order (e.g., cleanup/teardown events).
34+
"""
35+
return False
36+
37+
38+
T = TypeVar("T", bound=Callable)
39+
TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True)
40+
41+
42+
class HookProvider(Protocol):
43+
"""Protocol for objects that provide hook callbacks to an agent.
44+
45+
Hook providers offer a composable way to extend agent functionality by
46+
subscribing to various events in the agent lifecycle. This protocol enables
47+
building reusable components that can hook into agent events.
48+
49+
Example:
50+
```python
51+
class MyHookProvider(HookProvider):
52+
def register_hooks(self, registry: HookRegistry) -> None:
53+
hooks.add_callback(StartRequestEvent, self.on_request_start)
54+
hooks.add_callback(EndRequestEvent, self.on_request_end)
55+
56+
agent = Agent(hooks=[MyHookProvider()])
57+
```
58+
"""
59+
60+
def register_hooks(self, registry: "HookRegistry") -> None:
61+
"""Register callback functions for specific event types.
62+
63+
Args:
64+
registry: The hook registry to register callbacks with.
65+
"""
66+
...
67+
68+
69+
class HookCallback(Protocol, Generic[TEvent]):
70+
"""Protocol for callback functions that handle hook events.
71+
72+
Hook callbacks are functions that receive a single strongly-typed event
73+
argument and perform some action in response. They should not return
74+
values and any exceptions they raise will propagate to the caller.
75+
76+
Example:
77+
```python
78+
def my_callback(event: StartRequestEvent) -> None:
79+
print(f"Request started for agent: {event.agent.name}")
80+
```
81+
"""
82+
83+
def __call__(self, event: TEvent) -> None:
84+
"""Handle a hook event.
85+
86+
Args:
87+
event: The strongly-typed event to handle.
88+
"""
89+
...
90+
91+
92+
class HookRegistry:
93+
"""Registry for managing hook callbacks associated with event types.
94+
95+
The HookRegistry maintains a mapping of event types to callback functions
96+
and provides methods for registering callbacks and invoking them when
97+
events occur.
98+
99+
The registry handles callback ordering, including reverse ordering for
100+
cleanup events, and provides type-safe event dispatching.
101+
"""
102+
103+
def __init__(self) -> None:
104+
"""Initialize an empty hook registry."""
105+
self._registered_callbacks: dict[Type, list[HookCallback]] = {}
106+
107+
def add_callback(self, event_type: Type[TEvent], callback: HookCallback[TEvent]) -> None:
108+
"""Register a callback function for a specific event type.
109+
110+
Args:
111+
event_type: The class type of events this callback should handle.
112+
callback: The callback function to invoke when events of this type occur.
113+
114+
Example:
115+
```python
116+
def my_handler(event: StartRequestEvent):
117+
print("Request started")
118+
119+
registry.add_callback(StartRequestEvent, my_handler)
120+
```
121+
"""
122+
callbacks = self._registered_callbacks.setdefault(event_type, [])
123+
callbacks.append(callback)
124+
125+
def add_hook(self, hook: HookProvider) -> None:
126+
"""Register all callbacks from a hook provider.
127+
128+
This method allows bulk registration of callbacks by delegating to
129+
the hook provider's register_hooks method. This is the preferred
130+
way to register multiple related callbacks.
131+
132+
Args:
133+
hook: The hook provider containing callbacks to register.
134+
135+
Example:
136+
```python
137+
class MyHooks(HookProvider):
138+
def register_hooks(self, registry: HookRegistry):
139+
registry.add_callback(StartRequestEvent, self.on_start)
140+
registry.add_callback(EndRequestEvent, self.on_end)
141+
142+
registry.add_hook(MyHooks())
143+
```
144+
"""
145+
hook.register_hooks(self)
146+
147+
def invoke_callbacks(self, event: TEvent) -> None:
148+
"""Invoke all registered callbacks for the given event.
149+
150+
This method finds all callbacks registered for the event's type and
151+
invokes them in the appropriate order. For events with is_after_callback=True,
152+
callbacks are invoked in reverse registration order.
153+
154+
Args:
155+
event: The event to dispatch to registered callbacks.
156+
157+
Raises:
158+
Any exceptions raised by callback functions will propagate to the caller.
159+
160+
Example:
161+
```python
162+
event = StartRequestEvent(agent=my_agent)
163+
registry.invoke_callbacks(event)
164+
```
165+
"""
166+
for callback in self.get_callbacks_for(event):
167+
callback(event)
168+
169+
def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]:
170+
"""Get callbacks registered for the given event in the appropriate order.
171+
172+
This method returns callbacks in registration order for normal events,
173+
or reverse registration order for events that have is_after_callback=True.
174+
This enables proper cleanup ordering for teardown events.
175+
176+
Args:
177+
event: The event to get callbacks for.
178+
179+
Yields:
180+
Callback functions registered for this event type, in the appropriate order.
181+
182+
Example:
183+
```python
184+
event = EndRequestEvent(agent=my_agent)
185+
for callback in registry.get_callbacks_for(event):
186+
callback(event)
187+
```
188+
"""
189+
event_type = type(event)
190+
191+
callbacks = self._registered_callbacks.get(event_type, [])
192+
if event.should_reverse_callbacks:
193+
yield from reversed(callbacks)
194+
else:
195+
yield from callbacks

0 commit comments

Comments
 (0)