Skip to content

Commit 0737287

Browse files
committed
Add max_attempts_at_message
1 parent fec9633 commit 0737287

File tree

8 files changed

+178
-9
lines changed

8 files changed

+178
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ lint.ignore = [
159159
"ANN401", # typing.Any are disallowed in `**kwargs
160160
"PLR0913", # Too many arguments for function call
161161
"D106", # Missing docstring in public nested class
162+
"D205", # 1 blank line required between summary line and description
162163
]
163164
exclude = [".venv/"]
164165
lint.mccabe = { max-complexity = 10 }

taskiq/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from taskiq.abc.middleware import TaskiqMiddleware
99
from taskiq.abc.result_backend import AsyncResultBackend
1010
from taskiq.abc.schedule_source import ScheduleSource
11-
from taskiq.acks import AckableMessage
11+
from taskiq.acks import AckableMessage, AckableMessageWithDeliveryCount
1212
from taskiq.brokers.inmemory_broker import InMemoryBroker
1313
from taskiq.brokers.shared_broker import async_shared_broker
1414
from taskiq.brokers.zmq_broker import ZeroMQBroker
@@ -24,7 +24,7 @@
2424
TaskiqResultTimeoutError,
2525
)
2626
from taskiq.funcs import gather
27-
from taskiq.message import BrokerMessage, TaskiqMessage
27+
from taskiq.message import BrokerMessage, DeliveryCountMessage, TaskiqMessage
2828
from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware
2929
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
3030
from taskiq.result import TaskiqResult
@@ -53,6 +53,8 @@
5353
"NoResultError",
5454
"SendTaskError",
5555
"AckableMessage",
56+
"DeliveryCountMessage",
57+
"AckableMessageWithDeliveryCount",
5658
"InMemoryBroker",
5759
"ScheduleSource",
5860
"TaskiqScheduler",

taskiq/abc/broker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
self,
7878
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
7979
task_id_generator: Optional[Callable[[], str]] = None,
80+
max_attempts_at_message: Optional[int] = None,
8081
) -> None:
8182
if result_backend is None:
8283
result_backend = DummyResultBackend()
@@ -113,6 +114,7 @@ def __init__(
113114
self.state = TaskiqState()
114115
self.custom_dependency_context: Dict[Any, Any] = {}
115116
self.dependency_overrides: Dict[Any, Any] = {}
117+
self.max_attempts_at_message = max_attempts_at_message
116118
# True only if broker runs in worker process.
117119
self.is_worker_process: bool = False
118120
# True only if broker runs in scheduler process.

taskiq/acks.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import enum
22
from typing import Awaitable, Callable, Union
33

4-
from pydantic import BaseModel
4+
from taskiq.message import DeliveryCountMessage, WrappedMessage
55

66

77
@enum.unique
@@ -20,7 +20,7 @@ class AcknowledgeType(str, enum.Enum):
2020
WHEN_SAVED = "when_saved"
2121

2222

23-
class AckableMessage(BaseModel):
23+
class AckableMessage(WrappedMessage):
2424
"""
2525
Message that can be acknowledged.
2626
@@ -33,5 +33,8 @@ class AckableMessage(BaseModel):
3333
as a whole.
3434
"""
3535

36-
data: bytes
3736
ack: Callable[[], Union[None, Awaitable[None]]]
37+
38+
39+
class AckableMessageWithDeliveryCount(AckableMessage, DeliveryCountMessage):
40+
"""Message that can be acknowledged and has a delivery count."""

taskiq/cli/worker/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
148148
ack_type=args.ack_type,
149149
max_tasks_to_execute=args.max_tasks_per_child,
150150
wait_tasks_timeout=args.wait_tasks_timeout,
151+
max_attempts_at_message=broker.max_attempts_at_message,
151152
**receiver_kwargs, # type: ignore
152153
)
153154
loop.run_until_complete(receiver.listen(shutdown_event))

taskiq/message.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,15 @@ class BrokerMessage(BaseModel):
4242
task_name: str
4343
message: bytes
4444
labels: Dict[str, Any]
45+
46+
47+
class WrappedMessage(BaseModel):
48+
"""Abstraction for an incoming message in a wrapper."""
49+
50+
data: bytes
51+
52+
53+
class DeliveryCountMessage(WrappedMessage):
54+
"""Message with a present delivery count."""
55+
56+
delivery_count: Optional[int] = None

taskiq/receiver/receiver.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from taskiq.acks import AcknowledgeType
1414
from taskiq.context import Context
1515
from taskiq.exceptions import NoResultError
16-
from taskiq.message import TaskiqMessage
16+
from taskiq.message import DeliveryCountMessage, TaskiqMessage, WrappedMessage
1717
from taskiq.receiver.params_parser import parse_params
1818
from taskiq.result import TaskiqResult
1919
from taskiq.state import TaskiqState
@@ -58,6 +58,7 @@ def __init__(
5858
on_exit: Optional[Callable[["Receiver"], None]] = None,
5959
max_tasks_to_execute: Optional[int] = None,
6060
wait_tasks_timeout: Optional[float] = None,
61+
max_attempts_at_message: Optional[int] = None,
6162
) -> None:
6263
self.broker = broker
6364
self.executor = executor
@@ -72,6 +73,7 @@ def __init__(
7273
self.known_tasks: Set[str] = set()
7374
self.max_tasks_to_execute = max_tasks_to_execute
7475
self.wait_tasks_timeout = wait_tasks_timeout
76+
self.max_attempts_at_message = max_attempts_at_message
7577
for task in self.broker.get_all_tasks().values():
7678
self._prepare_task(task.task_name, task.original_func)
7779
self.sem: "Optional[asyncio.Semaphore]" = None
@@ -86,7 +88,7 @@ def __init__(
8688

8789
async def callback( # noqa: C901, PLR0912
8890
self,
89-
message: Union[bytes, AckableMessage],
91+
message: Union[bytes, WrappedMessage],
9092
raise_err: bool = False,
9193
) -> None:
9294
"""
@@ -101,7 +103,31 @@ async def callback( # noqa: C901, PLR0912
101103
:param raise_err: raise an error if cannot save result in
102104
result_backend.
103105
"""
104-
message_data = message.data if isinstance(message, AckableMessage) else message
106+
message_data = message.data if isinstance(message, WrappedMessage) else message
107+
108+
delivery_count = (
109+
message.delivery_count
110+
if isinstance(message, DeliveryCountMessage)
111+
else None
112+
)
113+
if (
114+
delivery_count
115+
and self.max_attempts_at_message
116+
and delivery_count >= self.max_attempts_at_message
117+
):
118+
logger.error(
119+
"Permitted number of attempts at processing message %s "
120+
"has been exhausted after %s attempts.",
121+
message_data,
122+
self.max_attempts_at_message,
123+
)
124+
if isinstance(
125+
message,
126+
AckableMessage,
127+
):
128+
await maybe_awaitable(message.ack())
129+
return
130+
105131
try:
106132
taskiq_msg = self.broker.formatter.loads(message=message_data)
107133
taskiq_msg.parse_labels()

tests/receiver/test_receiver.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99

1010
from taskiq.abc.broker import AckableMessage, AsyncBroker
1111
from taskiq.abc.middleware import TaskiqMiddleware
12+
from taskiq.acks import AckableMessageWithDeliveryCount
1213
from taskiq.brokers.inmemory_broker import InMemoryBroker
1314
from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError
14-
from taskiq.message import TaskiqMessage
15+
from taskiq.message import DeliveryCountMessage, TaskiqMessage
1516
from taskiq.receiver import Receiver
1617
from taskiq.result import TaskiqResult
1718
from tests.utils import AsyncQueueBroker
@@ -359,6 +360,127 @@ async def test_callback_unknown_task() -> None:
359360
await receiver.callback(broker_message.message)
360361

361362

363+
@pytest.mark.anyio
364+
@pytest.mark.parametrize("delivery_count", [2, None])
365+
async def test_callback_max_attempts_at_message_not_exceeded(
366+
delivery_count: Optional[int],
367+
) -> None:
368+
"""
369+
Test that callback function calls the task if `max_attempts_at_message`
370+
is not exceeded.
371+
"""
372+
broker = InMemoryBroker()
373+
called_times = 0
374+
375+
@broker.task
376+
async def my_task() -> int:
377+
nonlocal called_times
378+
called_times += 1
379+
return 1
380+
381+
receiver = get_receiver(broker)
382+
receiver.max_attempts_at_message = 3
383+
384+
broker_message = broker.formatter.dumps(
385+
TaskiqMessage(
386+
task_id="task_id",
387+
task_name=my_task.task_name,
388+
labels={},
389+
args=[],
390+
kwargs={},
391+
),
392+
)
393+
394+
await receiver.callback(
395+
DeliveryCountMessage(
396+
data=broker_message.message,
397+
delivery_count=delivery_count,
398+
),
399+
)
400+
assert called_times == 1
401+
402+
403+
@pytest.mark.anyio
404+
async def test_callback_max_attempts_at_message_exceeded() -> None:
405+
"""
406+
Test that callback function does not call the task if `max_attempts_at_message`
407+
is exceeded.
408+
"""
409+
broker = InMemoryBroker()
410+
called_times = 0
411+
412+
@broker.task
413+
async def my_task() -> int:
414+
nonlocal called_times
415+
called_times += 1
416+
return 1
417+
418+
receiver = get_receiver(broker)
419+
receiver.max_attempts_at_message = 3
420+
421+
broker_message = broker.formatter.dumps(
422+
TaskiqMessage(
423+
task_id="task_id",
424+
task_name=my_task.task_name,
425+
labels={},
426+
args=[],
427+
kwargs={},
428+
),
429+
)
430+
431+
await receiver.callback(
432+
DeliveryCountMessage(
433+
data=broker_message.message,
434+
delivery_count=3,
435+
),
436+
)
437+
assert called_times == 0
438+
439+
440+
@pytest.mark.anyio
441+
async def test_callback_max_attempts_at_message_exceeded_ackable() -> None:
442+
"""
443+
Test that callback function does not call the task if `max_attempts_at_message`
444+
is exceeded and acks the message.
445+
"""
446+
broker = InMemoryBroker()
447+
called_times = 0
448+
acked = False
449+
450+
@broker.task
451+
async def my_task() -> int:
452+
nonlocal called_times
453+
called_times += 1
454+
return 1
455+
456+
async def ack_callback() -> None:
457+
nonlocal acked
458+
acked = True
459+
460+
receiver = get_receiver(broker)
461+
receiver.max_attempts_at_message = 3
462+
463+
broker_message = broker.formatter.dumps(
464+
TaskiqMessage(
465+
task_id="task_id",
466+
task_name=my_task.task_name,
467+
labels={},
468+
args=[],
469+
kwargs={},
470+
),
471+
)
472+
473+
await receiver.callback(
474+
AckableMessageWithDeliveryCount(
475+
data=broker_message.message,
476+
delivery_count=3,
477+
ack=ack_callback,
478+
),
479+
)
480+
assert called_times == 0
481+
assert acked
482+
483+
362484
@pytest.mark.anyio
363485
async def test_custom_ctx() -> None:
364486
"""Tests that run_task can run sync tasks."""

0 commit comments

Comments
 (0)