|
7 | 7 | via hook provider objects. |
8 | 8 | """ |
9 | 9 |
|
10 | | -from dataclasses import dataclass |
| 10 | +from dataclasses import dataclass, field |
11 | 11 | from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar |
12 | 12 |
|
13 | 13 | if TYPE_CHECKING: |
14 | 14 | from ..agent import Agent |
15 | 15 |
|
16 | | - |
17 | 16 | @dataclass |
18 | | -class HookEvent: |
19 | | - """Base class for all hook events. |
20 | | -
|
21 | | - Attributes: |
22 | | - agent: The agent instance that triggered this event. |
23 | | - """ |
24 | | - |
25 | | - agent: "Agent" |
26 | | - |
| 17 | +class HookEventBase: |
27 | 18 | @property |
28 | 19 | def should_reverse_callbacks(self) -> bool: |
29 | 20 | """Determine if callbacks for this event should be invoked in reverse order. |
@@ -65,11 +56,27 @@ def __setattr__(self, name: str, value: Any) -> None: |
65 | 56 |
|
66 | 57 | raise AttributeError(f"Property {name} is not writable") |
67 | 58 |
|
| 59 | +@dataclass |
| 60 | +class MultiAgentHookEvent(HookEventBase): |
| 61 | + pass |
| 62 | + |
| 63 | + |
| 64 | +@dataclass |
| 65 | +class HookEvent(HookEventBase): |
| 66 | + """Base class for all hook events. |
| 67 | +
|
| 68 | + Attributes: |
| 69 | + agent: The agent instance that triggered this event. |
| 70 | + """ |
| 71 | + |
| 72 | + # agent: "Agent | None" = field(default=None, kw_only=True) |
| 73 | + agent : "Agent" |
| 74 | + |
68 | 75 |
|
69 | | -TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True) |
| 76 | +TEvent = TypeVar("TEvent", bound=HookEventBase, contravariant=True) |
70 | 77 | """Generic for adding callback handlers - contravariant to allow adding handlers which take in base classes.""" |
71 | 78 |
|
72 | | -TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEvent) |
| 79 | +TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEventBase) |
73 | 80 | """Generic for invoking events - non-contravariant to enable returning events.""" |
74 | 81 |
|
75 | 82 |
|
|
0 commit comments