1- import inspect
21import os
32import sys
43from abc import ABC , abstractmethod
4+ from collections import defaultdict
55from functools import wraps
66from logging import getLogger
77from typing import ( # noqa: WPS235
88 TYPE_CHECKING ,
99 Any ,
10+ Awaitable ,
1011 Callable ,
1112 Coroutine ,
13+ DefaultDict ,
1214 Dict ,
1315 List ,
16+ Literal ,
1417 Optional ,
1518 TypeVar ,
1619 Union ,
1720 overload ,
1821)
1922from uuid import uuid4
2023
21- from typing_extensions import ParamSpec
24+ from typing_extensions import ParamSpec , TypeAlias
2225
26+ from taskiq .abc .middleware import TaskiqMiddleware
2327from taskiq .decor import AsyncTaskiqDecoratedTask
28+ from taskiq .events import TaskiqEvents
2429from taskiq .formatters .json_formatter import JSONFormatter
2530from taskiq .message import BrokerMessage
2631from taskiq .result_backends .dummy import DummyResultBackend
32+ from taskiq .state import TaskiqState
33+ from taskiq .utils import maybe_awaitable
2734
2835if TYPE_CHECKING :
2936 from taskiq .abc .formatter import TaskiqFormatter
30- from taskiq .abc .middleware import TaskiqMiddleware
3137 from taskiq .abc .result_backend import AsyncResultBackend
3238
3339_T = TypeVar ("_T" ) # noqa: WPS111
3440_FuncParams = ParamSpec ("_FuncParams" )
3541_ReturnType = TypeVar ("_ReturnType" )
3642
43+ EventHandler : TypeAlias = Callable [[TaskiqState ], Optional [Awaitable [None ]]]
44+
3745logger = getLogger ("taskiq" )
3846
3947
@@ -49,7 +57,7 @@ def default_id_generator() -> str:
4957 return uuid4 ().hex
5058
5159
52- class AsyncBroker (ABC ):
60+ class AsyncBroker (ABC ): # noqa: WPS230
5361 """
5462 Async broker.
5563
@@ -75,8 +83,16 @@ def __init__(
7583 self .decorator_class = AsyncTaskiqDecoratedTask
7684 self .formatter : "TaskiqFormatter" = JSONFormatter ()
7785 self .id_generator = task_id_generator
78-
79- def add_middlewares (self , middlewares : "List[TaskiqMiddleware]" ) -> None :
86+ # Every event has a list of handlers.
87+ # Every handler is a function which takes state as a first argument.
88+ # And handler can be either sync or async.
89+ self .event_handlers : DefaultDict [ # noqa: WPS234
90+ TaskiqEvents ,
91+ List [Callable [[TaskiqState ], Optional [Awaitable [None ]]]],
92+ ] = defaultdict (list )
93+ self .state = TaskiqState ()
94+
95+ def add_middlewares (self , * middlewares : "TaskiqMiddleware" ) -> None :
8096 """
8197 Add a list of middlewares.
8298
@@ -86,11 +102,23 @@ def add_middlewares(self, middlewares: "List[TaskiqMiddleware]") -> None:
86102 :param middlewares: list of middlewares.
87103 """
88104 for middleware in middlewares :
105+ if not isinstance (middleware , TaskiqMiddleware ):
106+ logger .warning (
107+ f"Middleware { middleware } is not an instance of TaskiqMiddleware. "
108+ "Skipping..." ,
109+ )
110+ continue
89111 middleware .set_broker (self )
90112 self .middlewares .append (middleware )
91113
92114 async def startup (self ) -> None :
93115 """Do something when starting broker."""
116+ event = TaskiqEvents .CLIENT_STARTUP
117+ if self .is_worker_process :
118+ event = TaskiqEvents .WORKER_STARTUP
119+
120+ for handler in self .event_handlers [event ]:
121+ await maybe_awaitable (handler (self .state ))
94122
95123 async def shutdown (self ) -> None :
96124 """
@@ -99,11 +127,13 @@ async def shutdown(self) -> None:
99127 This method is called,
100128 when broker is closig.
101129 """
102- for middleware in self .middlewares :
103- middleware_shutdown = middleware .shutdown ()
104- if inspect .isawaitable (middleware_shutdown ):
105- await middleware_shutdown
106- await self .result_backend .shutdown ()
130+ event = TaskiqEvents .CLIENT_SHUTDOWN
131+ if self .is_worker_process :
132+ event = TaskiqEvents .WORKER_SHUTDOWN
133+
134+ # Call all shutdown events.
135+ for handler in self .event_handlers [event ]:
136+ await maybe_awaitable (handler (self .state ))
107137
108138 @abstractmethod
109139 async def kick (
@@ -232,3 +262,43 @@ def inner(
232262 inner_task_name = task_name ,
233263 inner_labels = labels or {},
234264 )
265+
266+ def on_event (self , * events : TaskiqEvents ) -> Callable [[EventHandler ], EventHandler ]:
267+ """
268+ Adds event handler.
269+
270+ This function adds function to call when event occurs.
271+
272+ :param events: events to react to.
273+ :return: a decorator function.
274+ """
275+
276+ def handler (function : EventHandler ) -> EventHandler :
277+ for event in events :
278+ self .event_handlers [event ].append (function )
279+ return function
280+
281+ return handler
282+
283+ def add_event_handler (
284+ self ,
285+ event : TaskiqEvents ,
286+ handler : EventHandler ,
287+ ) -> None :
288+ """
289+ Adds event handler.
290+
291+ this function is the same as on_event.
292+
293+ >>> broker.add_event_handler(TaskiqEvents.WORKER_STARTUP, my_startup)
294+
295+ if similar to:
296+
297+ >>> @broker.on_event(TaskiqEvents.WORKER_STARTUP)
298+ >>> async def my_startup(context: Context) -> None:
299+ >>> ...
300+
301+ :param event: Event to react to.
302+ :param handler: handler to call when event is started.
303+ """
304+ self .event_handlers [event ].append (handler )
0 commit comments