Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
Any,
AsyncGenerator,
Callable,
Coroutine,
Dict,
List,
NoReturn,
Optional,
TypeVar,
Union,
Expand Down Expand Up @@ -120,14 +122,17 @@ async def kick(
"""

@abstractmethod
def listen(self) -> AsyncGenerator[BrokerMessage, None]:
async def listen(
self,
callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]],
) -> None:
"""
This function listens to new messages and yields them.

This it the main point for workers.
This function is used to get new tasks from the network.

:yields: taskiq messages.
:param callback: function to call when message received.
:return: nothing.
"""

Expand Down
8 changes: 6 additions & 2 deletions taskiq/brokers/inmemory_broker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar
from typing import Any, Callable, Coroutine, Optional, TypeVar

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult
Expand Down Expand Up @@ -139,13 +139,17 @@ async def kick(self, message: BrokerMessage) -> None:
except Exception as exc:
raise ResultSetError("Cannot set result.") from exc

async def listen(self) -> AsyncGenerator[BrokerMessage, None]: # type: ignore
async def listen(
self,
callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]],
) -> None:
"""
Inmemory broker cannot listen.

This method throws RuntimeError if you call it.
Because inmemory broker cannot really listen to any of tasks.

:param callback: message callback.
:raises RuntimeError: if this method is called.
"""
raise RuntimeError("Inmemory brokers cannot listen.")
8 changes: 6 additions & 2 deletions taskiq/brokers/shared_broker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import AsyncGenerator, Optional, TypeVar
from typing import Any, Callable, Coroutine, Optional, TypeVar

from taskiq.abc.broker import AsyncBroker
from taskiq.decor import AsyncTaskiqDecoratedTask
Expand Down Expand Up @@ -59,12 +59,16 @@ async def kick(self, message: BrokerMessage) -> None:
"without setting the default_broker.",
)

async def listen(self) -> AsyncGenerator[BrokerMessage, None]: # type: ignore
async def listen(
self,
callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]],
) -> None: # type: ignore
"""
Shared broker cannot listen to tasks.

This method will throw an exception.

:param callback: message callback.
:raises TaskiqError: if called.
"""
raise TaskiqError("Shared broker cannot listen")
Expand Down
11 changes: 7 additions & 4 deletions taskiq/brokers/zmq_broker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import AsyncGenerator, Callable, Optional, TypeVar
from typing import Any, Callable, Coroutine, Optional, TypeVar

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend
Expand Down Expand Up @@ -58,12 +58,15 @@ async def kick(self, message: BrokerMessage) -> None:
with self.socket.connect(self.sub_host) as sock:
await sock.send_string(message.json())

async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
async def listen(
self,
callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]],
) -> None:
"""
Start accepting new messages.

:yield: received broker message
:param callback: function to call when message received.
"""
while True: # noqa: WPS457
with self.socket.connect(self.sub_host) as sock:
yield BrokerMessage.parse_raw(await sock.recv_string())
await callback(BrokerMessage.parse_raw(await sock.recv_string()))
94 changes: 58 additions & 36 deletions taskiq/cli/async_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from taskiq.cli.args import TaskiqArgs
from taskiq.cli.log_collector import log_collector
from taskiq.context import Context, context_updater
from taskiq.message import TaskiqMessage
from taskiq.message import BrokerMessage, TaskiqMessage
from taskiq.result import TaskiqResult
from taskiq.utils import maybe_awaitable

Expand Down Expand Up @@ -180,52 +180,53 @@ async def run_task( # noqa: C901, WPS210, WPS211
return result


async def async_listen_messages( # noqa: C901, WPS210, WPS213
broker: AsyncBroker,
cli_args: TaskiqArgs,
) -> None:
"""
This function iterates over tasks asynchronously.
class Receiver:
"""Class that uses as a callback handler."""

It uses listen() method of an AsyncBroker
to get new messages from queues.

:param broker: broker to listen to.
:param cli_args: CLI arguments for worker.
"""
logger.info("Runing startup event.")
await broker.startup()
executor = ThreadPoolExecutor(
max_workers=cli_args.max_threadpool_threads,
)
logger.info("Listening started.")
task_signatures: Dict[str, inspect.Signature] = {}
for task in broker.available_tasks.values():
def __init__(self, broker: AsyncBroker, cli_args: TaskiqArgs) -> None:
self.broker = broker
self.cli_args = cli_args
self.task_signatures: Dict[str, inspect.Signature] = {}
if not cli_args.no_parse:
task_signatures[task.task_name] = inspect.signature(task.original_func)
async for message in broker.listen():
for task in self.broker.available_tasks.values():
self.task_signatures[task.task_name] = inspect.signature(
task.original_func,
)
self.executor = ThreadPoolExecutor(
max_workers=cli_args.max_threadpool_threads,
)

async def callback(self, message: BrokerMessage) -> None: # noqa: C901
"""
Receive new message and execute tasks.

This method is used to process message,
that came from brokers.

:param message: received message.
"""
logger.debug(f"Received message: {message}")
if message.task_name not in broker.available_tasks:
if message.task_name not in self.broker.available_tasks:
logger.warning(
'task "%s" is not found. Maybe you forgot to import it?',
message.task_name,
)
continue
return
logger.debug(
"Function for task %s is resolved. Executing...",
message.task_name,
)
try:
taskiq_msg = broker.formatter.loads(message=message)
taskiq_msg = self.broker.formatter.loads(message=message)
except Exception as exc:
logger.warning(
"Cannot parse message: %s. Skipping execution.\n %s",
message,
exc,
exc_info=True,
)
continue
for middleware in broker.middlewares:
return
for middleware in self.broker.middlewares:
if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute:
taskiq_msg = await maybe_awaitable(
middleware.pre_execute(
Expand All @@ -238,23 +239,44 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213
taskiq_msg.task_name,
taskiq_msg.task_id,
)
with context_updater(Context(taskiq_msg, broker)):
with context_updater(Context(taskiq_msg, self.broker)):
result = await run_task(
target=broker.available_tasks[message.task_name].original_func,
signature=task_signatures.get(message.task_name),
target=self.broker.available_tasks[message.task_name].original_func,
signature=self.task_signatures.get(message.task_name),
message=taskiq_msg,
log_collector_format=cli_args.log_collector_format,
executor=executor,
middlewares=broker.middlewares,
log_collector_format=self.cli_args.log_collector_format,
executor=self.executor,
middlewares=self.broker.middlewares,
)
for middleware in broker.middlewares:
for middleware in self.broker.middlewares:
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
try:
await broker.result_backend.set_result(message.task_id, result)
await self.broker.result_backend.set_result(message.task_id, result)
except Exception as exc:
logger.exception(
"Can't set result in result backend. Cause: %s",
exc,
exc_info=True,
)


async def async_listen_messages(
broker: AsyncBroker,
cli_args: TaskiqArgs,
) -> None:
"""
This function iterates over tasks asynchronously.

It uses listen() method of an AsyncBroker
to get new messages from queues.

:param broker: broker to listen to.
:param cli_args: CLI arguments for worker.
"""
logger.info("Runing startup event.")
await broker.startup()
logger.info("Inicialized receiver.")
receiver = Receiver(broker, cli_args)
logger.info("Listening started.")
await broker.listen(receiver.callback)