Skip to content

Commit 01ac8b3

Browse files
committed
feat:add BaseHookEvent for multiagent use
1 parent 3406133 commit 01ac8b3

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

src/strands/hooks/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
3939
BeforeToolCallEvent,
4040
MessageAddedEvent,
4141
)
42-
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry,MultiAgentHookEvent
42+
from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry
4343

4444
__all__ = [
4545
"AgentInitializedEvent",
@@ -54,5 +54,6 @@ def log_end(self, event: AfterInvocationEvent) -> None:
5454
"HookProvider",
5555
"HookCallback",
5656
"HookRegistry",
57-
"MultiAgentHookEvent"
57+
"HookEvent",
58+
"BaseHookEvent",
5859
]

src/strands/hooks/registry.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@
77
via hook provider objects.
88
"""
99

10-
from dataclasses import dataclass, field
10+
from dataclasses import dataclass
1111
from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar
1212

1313
if TYPE_CHECKING:
1414
from ..agent import Agent
1515

16+
1617
@dataclass
17-
class HookEventBase:
18+
class BaseHookEvent:
19+
"""Base class for all hook events."""
20+
1821
@property
1922
def should_reverse_callbacks(self) -> bool:
2023
"""Determine if callbacks for this event should be invoked in reverse order.
@@ -56,27 +59,22 @@ def __setattr__(self, name: str, value: Any) -> None:
5659

5760
raise AttributeError(f"Property {name} is not writable")
5861

59-
@dataclass
60-
class MultiAgentHookEvent(HookEventBase):
61-
pass
62-
6362

6463
@dataclass
65-
class HookEvent(HookEventBase):
66-
"""Base class for all hook events.
64+
class HookEvent(BaseHookEvent):
65+
"""Base class for single agent hook events.
6766
6867
Attributes:
6968
agent: The agent instance that triggered this event.
7069
"""
7170

72-
# agent: "Agent | None" = field(default=None, kw_only=True)
73-
agent : "Agent"
71+
agent: "Agent"
7472

7573

76-
TEvent = TypeVar("TEvent", bound=HookEventBase, contravariant=True)
74+
TEvent = TypeVar("TEvent", bound=BaseHookEvent, contravariant=True)
7775
"""Generic for adding callback handlers - contravariant to allow adding handlers which take in base classes."""
7876

79-
TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEventBase)
77+
TInvokeEvent = TypeVar("TInvokeEvent", bound=BaseHookEvent)
8078
"""Generic for invoking events - non-contravariant to enable returning events."""
8179

8280

0 commit comments

Comments
 (0)