Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 77 additions & 54 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
broker: AsyncBroker,
executor: Optional[Executor] = None,
validate_params: bool = True,
max_async_tasks: int = 20,
max_async_tasks: "Optional[int]" = None,
) -> None:
self.broker = broker
self.executor = executor
Expand All @@ -53,7 +53,14 @@ def __init__(
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
self.task_hints[task.task_name] = get_type_hints(task.original_func)
self.dependency_graphs[task.task_name] = DependencyGraph(task.original_func)
self.sem = asyncio.Semaphore(max_async_tasks)
self.sem: "Optional[asyncio.Semaphore]" = None
if max_async_tasks is not None and max_async_tasks > 0:
self.sem = asyncio.Semaphore(max_async_tasks)
else:
logger.warning(
"Setting unlimited number of async tasks "
+ "can result in undefined behavior",
)

async def callback( # noqa: C901, WPS213
self,
Expand All @@ -72,62 +79,61 @@ async def callback( # noqa: C901, WPS213
:param raise_err: raise an error if cannot save result in
result_backend.
"""
async with self.sem:
logger.debug(f"Received message: {message}")
if message.task_name not in self.broker.available_tasks:
logger.warning(
'task "%s" is not found. Maybe you forgot to import it?',
message.task_name,
)
return
logger.debug(
"Function for task %s is resolved. Executing...",
logger.debug(f"Received message: {message}")
if message.task_name not in self.broker.available_tasks:
logger.warning(
'task "%s" is not found. Maybe you forgot to import it?',
message.task_name,
)
try:
taskiq_msg = self.broker.formatter.loads(message=message)
except Exception as exc:
logger.warning(
"Cannot parse message: %s. Skipping execution.\n %s",
message,
exc,
exc_info=True,
return
logger.debug(
"Function for task %s is resolved. Executing...",
message.task_name,
)
try:
taskiq_msg = self.broker.formatter.loads(message=message)
except Exception as exc:
logger.warning(
"Cannot parse message: %s. Skipping execution.\n %s",
message,
exc,
exc_info=True,
)
return
for middleware in self.broker.middlewares:
if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute:
taskiq_msg = await maybe_awaitable(
middleware.pre_execute(
taskiq_msg,
),
)
return
for middleware in self.broker.middlewares:
if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute:
taskiq_msg = await maybe_awaitable(
middleware.pre_execute(
taskiq_msg,
),
)

logger.info(
"Executing task %s with ID: %s",
taskiq_msg.task_name,
taskiq_msg.task_id,
)
result = await self.run_task(
target=self.broker.available_tasks[message.task_name].original_func,
message=taskiq_msg,
logger.info(
"Executing task %s with ID: %s",
taskiq_msg.task_name,
taskiq_msg.task_id,
)
result = await self.run_task(
target=self.broker.available_tasks[message.task_name].original_func,
message=taskiq_msg,
)
for middleware in self.broker.middlewares:
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
try:
await self.broker.result_backend.set_result(message.task_id, result)
except Exception as exc:
logger.exception(
"Can't set result in result backend. Cause: %s",
exc,
exc_info=True,
)
for middleware in self.broker.middlewares:
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
try:
await self.broker.result_backend.set_result(message.task_id, result)
except Exception as exc:
logger.exception(
"Can't set result in result backend. Cause: %s",
exc,
exc_info=True,
)
if raise_err:
raise exc
if raise_err:
raise exc

for middleware in self.broker.middlewares:
if middleware.__class__.post_save != TaskiqMiddleware.post_save:
await maybe_awaitable(middleware.post_save(taskiq_msg, result))
for middleware in self.broker.middlewares:
if middleware.__class__.post_save != TaskiqMiddleware.post_save:
await maybe_awaitable(middleware.post_save(taskiq_msg, result))

async def run_task( # noqa: C901, WPS210
self,
Expand Down Expand Up @@ -232,11 +238,28 @@ async def listen(self) -> None: # pragma: no cover
It uses listen() method of an AsyncBroker
to get new messages from queues.
"""
logger.debug("Runing startup event.")
await self.broker.startup()
logger.info("Listening started.")
tasks = set()

def task_cb(task: "asyncio.Task[Any]") -> None:
"""
Callback for tasks.

This function used to remove task
from the list of active tasks and release
the semaphore, so other tasks can use it.

:param task: finished task
"""
tasks.discard(task)
if self.sem is not None:
self.sem.release()

async for message in self.broker.listen():
# Waits for semaphore to be released.
if self.sem is not None:
await self.sem.acquire()
task = asyncio.create_task(self.callback(message=message, raise_err=False))
tasks.add(task)

Expand All @@ -245,4 +268,4 @@ async def listen(self) -> None: # pragma: no cover
# Because python's GC can silently cancel task
# and it considered to be Hisenbug.
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
task.add_done_callback(tasks.discard)
task.add_done_callback(task_cb)
48 changes: 37 additions & 11 deletions tests/cli/worker/test_receiver.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,42 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar

import pytest
from taskiq_dependencies import Depends

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.message import BrokerMessage, TaskiqMessage
from taskiq.receiver import Receiver
from taskiq.result import TaskiqResult

_T = TypeVar("_T")


class BrokerForTests(InMemoryBroker):
def __init__(
self,
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: Optional[Callable[[], str]] = None,
) -> None:
super().__init__(
result_backend=result_backend,
task_id_generator=task_id_generator,
)
self.to_send: "List[TaskiqMessage]" = []

async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
for message in self.to_send:
yield self.formatter.dumps(message)


def get_receiver(
broker: Optional[AsyncBroker] = None,
no_parse: bool = False,
max_async_tasks: int = 10,
max_async_tasks: Optional[int] = None,
) -> Receiver:
"""
Returns receiver with custom broker and args.
Expand Down Expand Up @@ -247,7 +267,8 @@ def test_func(tes_val: MyTestClass = Depends()) -> int:
@pytest.mark.anyio
async def test_callback_semaphore() -> None:
"""Test that callback funcion semaphore works well."""
broker = InMemoryBroker()
max_async_tasks = 3
broker = BrokerForTests()
sem_num = 0

@broker.task
Expand All @@ -257,18 +278,23 @@ async def task_sem() -> int:
await asyncio.sleep(1)
return 1

receiver = get_receiver(broker, max_async_tasks=3)

broker_message = broker.formatter.dumps(
broker.to_send = [
TaskiqMessage(
task_id="test_sem",
task_name=task_sem.task_name,
labels={},
args=[],
kwargs=[],
),
)
tasks = [asyncio.create_task(receiver.callback(broker_message)) for _ in range(5)]
)
for _ in range(max_async_tasks + 2)
]

# broker_message = broker.formatter.dumps(
# )
receiver = get_receiver(broker, max_async_tasks=3)

listen_task = asyncio.create_task(receiver.listen())
await asyncio.sleep(0.3)
assert sem_num == 3
await asyncio.gather(*tasks)
assert sem_num == max_async_tasks
await listen_task
assert sem_num == max_async_tasks + 2