|
| 1 | +"""Hook registry system for managing event callbacks in the Strands Agent SDK. |
| 2 | +
|
| 3 | +This module provides the core infrastructure for the typed hook system, enabling |
| 4 | +composable extension of agent functionality through strongly-typed event callbacks. |
| 5 | +The registry manages the mapping between event types and their associated callback |
| 6 | +functions, supporting both individual callback registration and bulk registration |
| 7 | +via hook provider objects. |
| 8 | +""" |
| 9 | + |
| 10 | +from dataclasses import dataclass |
| 11 | +from typing import TYPE_CHECKING, Callable, Generator, Generic, Protocol, Type, TypeVar |
| 12 | + |
| 13 | +if TYPE_CHECKING: |
| 14 | + from ...agent import Agent |
| 15 | + |
| 16 | + |
| 17 | +@dataclass |
| 18 | +class HookEvent: |
| 19 | + """Base class for all hook events. |
| 20 | +
|
| 21 | + Attributes: |
| 22 | + agent: The agent instance that triggered this event. |
| 23 | + """ |
| 24 | + |
| 25 | + agent: "Agent" |
| 26 | + |
| 27 | + @property |
| 28 | + def should_reverse_callbacks(self) -> bool: |
| 29 | + """Determine if callbacks for this event should be invoked in reverse order. |
| 30 | +
|
| 31 | + Returns: |
| 32 | + False by default. Override to return True for events that should |
| 33 | + invoke callbacks in reverse order (e.g., cleanup/teardown events). |
| 34 | + """ |
| 35 | + return False |
| 36 | + |
| 37 | + |
| 38 | +T = TypeVar("T", bound=Callable) |
| 39 | +TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True) |
| 40 | + |
| 41 | + |
| 42 | +class HookProvider(Protocol): |
| 43 | + """Protocol for objects that provide hook callbacks to an agent. |
| 44 | +
|
| 45 | + Hook providers offer a composable way to extend agent functionality by |
| 46 | + subscribing to various events in the agent lifecycle. This protocol enables |
| 47 | + building reusable components that can hook into agent events. |
| 48 | +
|
| 49 | + Example: |
| 50 | + ```python |
| 51 | + class MyHookProvider(HookProvider): |
| 52 | + def register_hooks(self, registry: HookRegistry) -> None: |
| 53 | + hooks.add_callback(StartRequestEvent, self.on_request_start) |
| 54 | + hooks.add_callback(EndRequestEvent, self.on_request_end) |
| 55 | +
|
| 56 | + agent = Agent(hooks=[MyHookProvider()]) |
| 57 | + ``` |
| 58 | + """ |
| 59 | + |
| 60 | + def register_hooks(self, registry: "HookRegistry") -> None: |
| 61 | + """Register callback functions for specific event types. |
| 62 | +
|
| 63 | + Args: |
| 64 | + registry: The hook registry to register callbacks with. |
| 65 | + """ |
| 66 | + ... |
| 67 | + |
| 68 | + |
| 69 | +class HookCallback(Protocol, Generic[TEvent]): |
| 70 | + """Protocol for callback functions that handle hook events. |
| 71 | +
|
| 72 | + Hook callbacks are functions that receive a single strongly-typed event |
| 73 | + argument and perform some action in response. They should not return |
| 74 | + values and any exceptions they raise will propagate to the caller. |
| 75 | +
|
| 76 | + Example: |
| 77 | + ```python |
| 78 | + def my_callback(event: StartRequestEvent) -> None: |
| 79 | + print(f"Request started for agent: {event.agent.name}") |
| 80 | + ``` |
| 81 | + """ |
| 82 | + |
| 83 | + def __call__(self, event: TEvent) -> None: |
| 84 | + """Handle a hook event. |
| 85 | +
|
| 86 | + Args: |
| 87 | + event: The strongly-typed event to handle. |
| 88 | + """ |
| 89 | + ... |
| 90 | + |
| 91 | + |
| 92 | +class HookRegistry: |
| 93 | + """Registry for managing hook callbacks associated with event types. |
| 94 | +
|
| 95 | + The HookRegistry maintains a mapping of event types to callback functions |
| 96 | + and provides methods for registering callbacks and invoking them when |
| 97 | + events occur. |
| 98 | +
|
| 99 | + The registry handles callback ordering, including reverse ordering for |
| 100 | + cleanup events, and provides type-safe event dispatching. |
| 101 | + """ |
| 102 | + |
| 103 | + def __init__(self) -> None: |
| 104 | + """Initialize an empty hook registry.""" |
| 105 | + self._registered_callbacks: dict[Type, list[HookCallback]] = {} |
| 106 | + |
| 107 | + def add_callback(self, event_type: Type[TEvent], callback: HookCallback[TEvent]) -> None: |
| 108 | + """Register a callback function for a specific event type. |
| 109 | +
|
| 110 | + Args: |
| 111 | + event_type: The class type of events this callback should handle. |
| 112 | + callback: The callback function to invoke when events of this type occur. |
| 113 | +
|
| 114 | + Example: |
| 115 | + ```python |
| 116 | + def my_handler(event: StartRequestEvent): |
| 117 | + print("Request started") |
| 118 | +
|
| 119 | + registry.add_callback(StartRequestEvent, my_handler) |
| 120 | + ``` |
| 121 | + """ |
| 122 | + callbacks = self._registered_callbacks.setdefault(event_type, []) |
| 123 | + callbacks.append(callback) |
| 124 | + |
| 125 | + def add_hook(self, hook: HookProvider) -> None: |
| 126 | + """Register all callbacks from a hook provider. |
| 127 | +
|
| 128 | + This method allows bulk registration of callbacks by delegating to |
| 129 | + the hook provider's register_hooks method. This is the preferred |
| 130 | + way to register multiple related callbacks. |
| 131 | +
|
| 132 | + Args: |
| 133 | + hook: The hook provider containing callbacks to register. |
| 134 | +
|
| 135 | + Example: |
| 136 | + ```python |
| 137 | + class MyHooks(HookProvider): |
| 138 | + def register_hooks(self, registry: HookRegistry): |
| 139 | + registry.add_callback(StartRequestEvent, self.on_start) |
| 140 | + registry.add_callback(EndRequestEvent, self.on_end) |
| 141 | +
|
| 142 | + registry.add_hook(MyHooks()) |
| 143 | + ``` |
| 144 | + """ |
| 145 | + hook.register_hooks(self) |
| 146 | + |
| 147 | + def invoke_callbacks(self, event: TEvent) -> None: |
| 148 | + """Invoke all registered callbacks for the given event. |
| 149 | +
|
| 150 | + This method finds all callbacks registered for the event's type and |
| 151 | + invokes them in the appropriate order. For events with is_after_callback=True, |
| 152 | + callbacks are invoked in reverse registration order. |
| 153 | +
|
| 154 | + Args: |
| 155 | + event: The event to dispatch to registered callbacks. |
| 156 | +
|
| 157 | + Raises: |
| 158 | + Any exceptions raised by callback functions will propagate to the caller. |
| 159 | +
|
| 160 | + Example: |
| 161 | + ```python |
| 162 | + event = StartRequestEvent(agent=my_agent) |
| 163 | + registry.invoke_callbacks(event) |
| 164 | + ``` |
| 165 | + """ |
| 166 | + for callback in self.get_callbacks_for(event): |
| 167 | + callback(event) |
| 168 | + |
| 169 | + def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]: |
| 170 | + """Get callbacks registered for the given event in the appropriate order. |
| 171 | +
|
| 172 | + This method returns callbacks in registration order for normal events, |
| 173 | + or reverse registration order for events that have is_after_callback=True. |
| 174 | + This enables proper cleanup ordering for teardown events. |
| 175 | +
|
| 176 | + Args: |
| 177 | + event: The event to get callbacks for. |
| 178 | +
|
| 179 | + Yields: |
| 180 | + Callback functions registered for this event type, in the appropriate order. |
| 181 | +
|
| 182 | + Example: |
| 183 | + ```python |
| 184 | + event = EndRequestEvent(agent=my_agent) |
| 185 | + for callback in registry.get_callbacks_for(event): |
| 186 | + callback(event) |
| 187 | + ``` |
| 188 | + """ |
| 189 | + event_type = type(event) |
| 190 | + |
| 191 | + callbacks = self._registered_callbacks.get(event_type, []) |
| 192 | + if event.should_reverse_callbacks: |
| 193 | + yield from reversed(callbacks) |
| 194 | + else: |
| 195 | + yield from callbacks |
0 commit comments