Skip to content

Commit 8c6f24e

Browse files
committed
add a base class example
1 parent 8122453 commit 8c6f24e

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

src/strands/hooks/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
3535
BeforeInvocationEvent,
3636
MessageAddedEvent,
3737
)
38-
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry
38+
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry,MultiAgentHookEvent
3939

4040
__all__ = [
4141
"AgentInitializedEvent",
@@ -46,4 +46,5 @@ def log_end(self, event: AfterInvocationEvent) -> None:
4646
"HookProvider",
4747
"HookCallback",
4848
"HookRegistry",
49+
"MultiAgentHookEvent"
4950
]

src/strands/hooks/registry.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,14 @@
77
via hook provider objects.
88
"""
99

10-
from dataclasses import dataclass
10+
from dataclasses import dataclass, field
1111
from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar
1212

1313
if TYPE_CHECKING:
1414
from ..agent import Agent
1515

16-
1716
@dataclass
18-
class HookEvent:
19-
"""Base class for all hook events.
20-
21-
Attributes:
22-
agent: The agent instance that triggered this event.
23-
"""
24-
25-
agent: "Agent"
26-
17+
class HookEventBase:
2718
@property
2819
def should_reverse_callbacks(self) -> bool:
2920
"""Determine if callbacks for this event should be invoked in reverse order.
@@ -65,11 +56,27 @@ def __setattr__(self, name: str, value: Any) -> None:
6556

6657
raise AttributeError(f"Property {name} is not writable")
6758

59+
@dataclass
60+
class MultiAgentHookEvent(HookEventBase):
61+
pass
62+
63+
64+
@dataclass
65+
class HookEvent(HookEventBase):
66+
"""Base class for all hook events.
67+
68+
Attributes:
69+
agent: The agent instance that triggered this event.
70+
"""
71+
72+
# agent: "Agent | None" = field(default=None, kw_only=True)
73+
agent : "Agent"
74+
6875

69-
TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True)
76+
TEvent = TypeVar("TEvent", bound=HookEventBase, contravariant=True)
7077
"""Generic for adding callback handlers - contravariant to allow adding handlers which take in base classes."""
7178

72-
TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEvent)
79+
TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEventBase)
7380
"""Generic for invoking events - non-contravariant to enable returning events."""
7481

7582

0 commit comments

Comments
 (0)