Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion autogen/events/base_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from abc import ABC
from collections.abc import Callable
from typing import Annotated, Any, Literal, Union
from typing import Annotated, Any, ClassVar, Literal, Union
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, create_model
Expand All @@ -31,6 +31,19 @@ def print(self, f: Callable[..., Any] | None = None) -> None:
"""
...

_hooks: ClassVar[dict[type["BaseEvent"], list[Callable]]] = {}

@classmethod
def register_hook(cls, event_type, func):
cls._hooks.setdefault(event_type, []).append(func)

@classmethod
def trigger_hook(cls, event):
for base, funcs in cls._hooks.items():
if isinstance(event, base):
for f in funcs:
f(event.content)


def camel2snake(name: str) -> str:
return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_")
Expand Down
59 changes: 59 additions & 0 deletions test/events/test_base_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,62 @@ class TestSingleContentParameterEvent(BaseEvent):

model = TestSingleContentParameterEvent(**expected)
assert model.model_dump() == expected


class TestBaseEventHooks:
def test_register_and_trigger_hook(self, TestEvent: type[BaseEvent], uuid: UUID) -> None:
captured = []

def hook(event: TestEvent) -> None:
captured.append(event.content)

BaseEvent.register_hook(TestEvent, hook)
event = TestEvent(uuid=uuid, sender="alice", receiver="bob", content="hello")

BaseEvent.trigger_hook(event)

assert captured == ["hello"]

def test_multiple_hooks(self, TestEvent: type[BaseEvent], uuid: UUID) -> None:
captured = []

def hook1(event: TestEvent) -> None:
captured.append("hook1:" + event.content)

def hook2(event: TestEvent) -> None:
captured.append("hook2:" + event.content)

BaseEvent.register_hook(TestEvent, hook1)
BaseEvent.register_hook(TestEvent, hook2)

event = TestEvent(uuid=uuid, sender="alice", receiver="bob", content="hello")
BaseEvent.trigger_hook(event)

assert captured == ["hook1:hello", "hook2:hello"]

def test_hooks_are_isolated_by_event_type(self, uuid: UUID) -> None:
captured_a = []
captured_b = []

@wrap_event
class EventAEvent(BaseEvent):
content: str

@wrap_event
class EventBEvent(BaseEvent):
content: str

def hook_a(event: EventAEvent) -> None:
captured_a.append("A:" + event.content)

def hook_b(event: EventBEvent) -> None:
captured_b.append("B:" + event.content)

BaseEvent.register_hook(EventAEvent, hook_a)
BaseEvent.register_hook(EventBEvent, hook_b)

BaseEvent.trigger_hook(EventAEvent(uuid=uuid, content="foo"))
BaseEvent.trigger_hook(EventBEvent(uuid=uuid, content="bar"))

assert captured_a == ["A:foo"]
assert captured_b == ["B:bar"]
Loading