Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions taskiq/cli/worker/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class WorkerArgs:
max_async_tasks: int = 100
receiver: str = "taskiq.receiver:Receiver"
receiver_arg: List[Tuple[str, str]] = field(default_factory=list)
max_prefetch: int = 0

@classmethod
def from_cli( # noqa: WPS213
Expand Down Expand Up @@ -168,6 +169,13 @@ def from_cli( # noqa: WPS213
default=100,
help="Maximum simultaneous async tasks per worker process. ",
)
parser.add_argument(
"--max-prefetch",
type=int,
dest="max_prefetch",
default=0,
help="Maximum prefetched tasks per worker process. ",
)

namespace = parser.parse_args(args)
return WorkerArgs(**namespace.__dict__)
1 change: 1 addition & 0 deletions taskiq/cli/worker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
executor=pool,
validate_params=not args.no_parse,
max_async_tasks=args.max_async_tasks,
max_prefetch=args.max_prefetch,
**receiver_args,
)
loop.run_until_complete(receiver.listen())
Expand Down
53 changes: 48 additions & 5 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from concurrent.futures import Executor
from logging import getLogger
from time import time
from typing import Any, Callable, Dict, Optional, get_type_hints
from typing import Any, Callable, Dict, Optional, Set, get_type_hints

import anyio
from taskiq_dependencies import DependencyGraph

from taskiq.abc.broker import AsyncBroker
Expand All @@ -17,6 +18,7 @@
from taskiq.utils import maybe_awaitable

logger = getLogger(__name__)
QUEUE_DONE = b"-1"


def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
Expand All @@ -36,12 +38,13 @@ def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
class Receiver:
"""Class that uses as a callback handler."""

def __init__(
def __init__( # noqa: WPS211
self,
broker: AsyncBroker,
executor: Optional[Executor] = None,
validate_params: bool = True,
max_async_tasks: "Optional[int]" = None,
max_prefetch: int = 0,
) -> None:
self.broker = broker
self.executor = executor
Expand All @@ -61,6 +64,7 @@ def __init__(
"Setting unlimited number of async tasks "
+ "can result in undefined behavior",
)
self.sem_prefetch = asyncio.Semaphore(max_prefetch)

async def callback( # noqa: C901, WPS213
self,
Expand Down Expand Up @@ -239,7 +243,38 @@ async def listen(self) -> None: # pragma: no cover
"""
await self.broker.startup()
logger.info("Listening started.")
tasks = set()
queue: asyncio.Queue[bytes] = asyncio.Queue()

async with anyio.create_task_group() as gr:
gr.start_soon(self.prefetcher, queue)
gr.start_soon(self.runner, queue)

async def prefetcher(self, queue: "asyncio.Queue[Any]") -> None:
"""
Prefetch tasks data.

:param queue: queue for prefetched data.
"""
iterator = self.broker.listen()

while True:
try:
await self.sem_prefetch.acquire()
message = await iterator.__anext__() # noqa: WPS609
await queue.put(message)

except StopAsyncIteration:
break

await queue.put(QUEUE_DONE)

async def runner(self, queue: "asyncio.Queue[bytes]") -> None:
"""
Run tasks.

:param queue: queue with prefetched data.
"""
tasks: Set[asyncio.Task[Any]] = set()

def task_cb(task: "asyncio.Task[Any]") -> None:
"""
Expand All @@ -255,11 +290,19 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
if self.sem is not None:
self.sem.release()

async for message in self.broker.listen():
while True:
# Waits for semaphore to be released.
if self.sem is not None:
await self.sem.acquire()
task = asyncio.create_task(self.callback(message=message, raise_err=False))

self.sem_prefetch.release()
message = await queue.get()
if message is QUEUE_DONE:
break

task = asyncio.create_task(
self.callback(message=message, raise_err=False),
)
tasks.add(task)

# We want the task to remove itself from the set when it's done.
Expand Down