
Commit 8c5ec78

Merge branch 'release/0.5.6'

2 parents: 8601636 + c5bb440

8 files changed: +128, -13 lines

README.md

Lines changed: 6 additions & 0 deletions

@@ -71,6 +71,9 @@ Brokers parameters:
 * `result_backend` - custom result backend.
 * `queue_name` - name of the pub/sub channel in redis.
 * `max_connection_pool_size` - maximum number of connections in pool.
+* Any other keyword arguments are passed to `redis.asyncio.BlockingConnectionPool`.
+  Notably, you can use `timeout` to set a custom timeout in seconds for reconnects
+  (or set it to `None` to retry reconnects indefinitely).

 ## RedisAsyncResultBackend configuration

@@ -79,6 +82,9 @@ RedisAsyncResultBackend parameters:
 * `keep_results` - flag to not remove results from Redis after reading.
 * `result_ex_time` - expire time in seconds (not specified by default).
 * `result_px_time` - expire time in milliseconds (not specified by default).
+* Any other keyword arguments are passed to `redis.asyncio.BlockingConnectionPool`.
+  Notably, you can use `timeout` to set a custom timeout in seconds for reconnects
+  (or set it to `None` to retry reconnects indefinitely).
 > IMPORTANT: **It is highly recommended to use expire time in RedisAsyncResultBackend.**
 > If you want to add expiration, either `result_ex_time` or `result_px_time` must be set.
 >```python
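
Taken together, the README additions mean the brokers and the result backend now accept arbitrary `redis.asyncio.BlockingConnectionPool` keyword arguments. A minimal usage sketch, assuming a local Redis on the default port (the URL, pool size, and expire time are illustrative, not from this commit):

```python
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend

# Extra keyword arguments (here `timeout`) are forwarded to
# redis.asyncio.BlockingConnectionPool. `timeout=None` waits indefinitely
# for a free connection instead of giving up after the pool's default.
result_backend = RedisAsyncResultBackend(
    redis_url="redis://localhost:6379",  # assumed local Redis instance
    result_ex_time=600,  # results expire after 10 minutes
    timeout=None,
)

broker = ListQueueBroker(
    url="redis://localhost:6379",
    max_connection_pool_size=16,  # at least the number of workers + 1
    timeout=5,  # wait up to 5 seconds for a free pooled connection
).with_result_backend(result_backend)
```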

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "taskiq-redis"
-version = "0.5.5"
+version = "0.5.6"
 description = "Redis integration for taskiq"
 authors = ["taskiq-team <taskiq@norely.com>"]
 readme = "README.md"

taskiq_redis/redis_backend.py

Lines changed: 17 additions & 4 deletions

@@ -1,7 +1,7 @@
 import pickle
-from typing import Dict, Optional, TypeVar, Union
+from typing import Any, Dict, Optional, TypeVar, Union

-from redis.asyncio import ConnectionPool, Redis
+from redis.asyncio import BlockingConnectionPool, Redis
 from redis.asyncio.cluster import RedisCluster
 from taskiq import AsyncResultBackend
 from taskiq.abc.result_backend import TaskiqResult
@@ -24,6 +24,8 @@ def __init__(
         keep_results: bool = True,
         result_ex_time: Optional[int] = None,
         result_px_time: Optional[int] = None,
+        max_connection_pool_size: Optional[int] = None,
+        **connection_kwargs: Any,
     ) -> None:
         """
         Constructs a new result backend.
@@ -32,13 +34,19 @@ def __init__(
         :param keep_results: flag to not remove results from Redis after reading.
         :param result_ex_time: expire time in seconds for result.
         :param result_px_time: expire time in milliseconds for result.
+        :param max_connection_pool_size: maximum number of connections in pool.
+        :param connection_kwargs: additional arguments for redis BlockingConnectionPool.

         :raises DuplicateExpireTimeSelectedError: if result_ex_time
             and result_px_time are selected.
         :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
             and result_px_time are equal zero.
         """
-        self.redis_pool = ConnectionPool.from_url(redis_url)
+        self.redis_pool = BlockingConnectionPool.from_url(
+            url=redis_url,
+            max_connections=max_connection_pool_size,
+            **connection_kwargs,
+        )
         self.keep_results = keep_results
         self.result_ex_time = result_ex_time
         self.result_px_time = result_px_time
@@ -146,6 +154,7 @@ def __init__(
         keep_results: bool = True,
         result_ex_time: Optional[int] = None,
         result_px_time: Optional[int] = None,
+        **connection_kwargs: Any,
     ) -> None:
         """
         Constructs a new result backend.
@@ -154,13 +163,17 @@ def __init__(
         :param keep_results: flag to not remove results from Redis after reading.
         :param result_ex_time: expire time in seconds for result.
         :param result_px_time: expire time in milliseconds for result.
+        :param connection_kwargs: additional arguments for RedisCluster.

         :raises DuplicateExpireTimeSelectedError: if result_ex_time
             and result_px_time are selected.
         :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
             and result_px_time are equal zero.
         """
-        self.redis: RedisCluster[bytes] = RedisCluster.from_url(redis_url)
+        self.redis: RedisCluster[bytes] = RedisCluster.from_url(
+            redis_url,
+            **connection_kwargs,
+        )
         self.keep_results = keep_results
         self.result_ex_time = result_ex_time
         self.result_px_time = result_px_time

taskiq_redis/redis_broker.py

Lines changed: 9 additions & 5 deletions

@@ -1,7 +1,7 @@
 from logging import getLogger
 from typing import Any, AsyncGenerator, Callable, Optional, TypeVar

-from redis.asyncio import ConnectionPool, Redis
+from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis
 from taskiq.abc.broker import AsyncBroker
 from taskiq.abc.result_backend import AsyncResultBackend
 from taskiq.message import BrokerMessage
@@ -31,14 +31,16 @@ def __init__(
         :param result_backend: custom result backend.
         :param queue_name: name for a list in redis.
         :param max_connection_pool_size: maximum number of connections in pool.
-        :param connection_kwargs: additional arguments for aio-redis ConnectionPool.
+            Each worker opens its own connection. Therefore this value has to be
+            at least the number of workers + 1.
+        :param connection_kwargs: additional arguments for redis BlockingConnectionPool.
         """
         super().__init__(
             result_backend=result_backend,
             task_id_generator=task_id_generator,
         )

-        self.connection_pool: ConnectionPool = ConnectionPool.from_url(
+        self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
             url=url,
             max_connections=max_connection_pool_size,
             **connection_kwargs,
@@ -60,8 +62,9 @@ async def kick(self, message: BrokerMessage) -> None:

         :param message: message to send.
         """
+        queue_name = message.labels.get("queue_name") or self.queue_name
         async with Redis(connection_pool=self.connection_pool) as redis_conn:
-            await redis_conn.publish(self.queue_name, message.message)
+            await redis_conn.publish(queue_name, message.message)

     async def listen(self) -> AsyncGenerator[bytes, None]:
         """
@@ -95,8 +98,9 @@ async def kick(self, message: BrokerMessage) -> None:

         :param message: message to append.
         """
+        queue_name = message.labels.get("queue_name") or self.queue_name
         async with Redis(connection_pool=self.connection_pool) as redis_conn:
-            await redis_conn.lpush(self.queue_name, message.message)
+            await redis_conn.lpush(queue_name, message.message)

     async def listen(self) -> AsyncGenerator[bytes, None]:
         """

taskiq_redis/schedule_source.py

Lines changed: 3 additions & 3 deletions

@@ -1,6 +1,6 @@
 from typing import Any, List, Optional

-from redis.asyncio import ConnectionPool, Redis, RedisCluster
+from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis, RedisCluster
 from taskiq import ScheduleSource
 from taskiq.abc.serializer import TaskiqSerializer
 from taskiq.compat import model_dump, model_validate
@@ -22,7 +22,7 @@ class RedisScheduleSource(ScheduleSource):
         This is how many keys will be fetched at once.
     :param max_connection_pool_size: maximum number of connections in pool.
     :param serializer: serializer for data.
-    :param connection_kwargs: additional arguments for aio-redis ConnectionPool.
+    :param connection_kwargs: additional arguments for redis BlockingConnectionPool.
     """

     def __init__(
@@ -35,7 +35,7 @@ def __init__(
         **connection_kwargs: Any,
     ) -> None:
         self.prefix = prefix
-        self.connection_pool: ConnectionPool = ConnectionPool.from_url(
+        self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
             url=url,
             max_connections=max_connection_pool_size,
             **connection_kwargs,

tests/test_broker.py

Lines changed: 46 additions & 0 deletions

@@ -71,6 +71,29 @@ async def test_pub_sub_broker(
     await broker.shutdown()


+@pytest.mark.anyio
+async def test_pub_sub_broker_max_connections(
+    valid_broker_message: BrokerMessage,
+    redis_url: str,
+) -> None:
+    """Test PubSubBroker with connection limit set."""
+    broker = PubSubBroker(
+        url=redis_url,
+        queue_name=uuid.uuid4().hex,
+        max_connection_pool_size=4,
+        timeout=1,
+    )
+    worker_tasks = [asyncio.create_task(get_message(broker)) for _ in range(3)]
+    await asyncio.sleep(0.3)
+
+    await asyncio.gather(*[broker.kick(valid_broker_message) for _ in range(50)])
+    await asyncio.sleep(0.3)
+
+    for worker in worker_tasks:
+        worker.cancel()
+    await broker.shutdown()
+
+
 @pytest.mark.anyio
 async def test_list_queue_broker(
     valid_broker_message: BrokerMessage,
@@ -98,6 +121,29 @@ async def test_list_queue_broker(
     await broker.shutdown()


+@pytest.mark.anyio
+async def test_list_queue_broker_max_connections(
+    valid_broker_message: BrokerMessage,
+    redis_url: str,
+) -> None:
+    """Test ListQueueBroker with connection limit set."""
+    broker = ListQueueBroker(
+        url=redis_url,
+        queue_name=uuid.uuid4().hex,
+        max_connection_pool_size=4,
+        timeout=1,
+    )
+    worker_tasks = [asyncio.create_task(get_message(broker)) for _ in range(3)]
+    await asyncio.sleep(0.3)
+
+    await asyncio.gather(*[broker.kick(valid_broker_message) for _ in range(50)])
+    await asyncio.sleep(0.3)
+
+    for worker in worker_tasks:
+        worker.cancel()
+    await broker.shutdown()
+
+
 @pytest.mark.anyio
 async def test_list_queue_cluster_broker(
     valid_broker_message: BrokerMessage,

tests/test_result_backend.py

Lines changed: 33 additions & 0 deletions

@@ -1,3 +1,4 @@
+import asyncio
 import uuid

 import pytest
@@ -132,6 +133,38 @@ async def test_keep_results_after_reading(redis_url: str) -> None:
     await result_backend.shutdown()


+@pytest.mark.anyio
+async def test_set_result_max_connections(redis_url: str) -> None:
+    """
+    Tests that asynchronous backend works with connection limit.
+
+    :param redis_url: redis URL.
+    """
+    result_backend = RedisAsyncResultBackend(  # type: ignore
+        redis_url=redis_url,
+        max_connection_pool_size=1,
+        timeout=3,
+    )
+
+    task_id = uuid.uuid4().hex
+    result: "TaskiqResult[int]" = TaskiqResult(
+        is_err=True,
+        log="My Log",
+        return_value=11,
+        execution_time=112.2,
+    )
+    await result_backend.set_result(
+        task_id=task_id,
+        result=result,
+    )
+
+    async def get_result() -> None:
+        await result_backend.get_result(task_id=task_id, with_logs=True)
+
+    await asyncio.gather(*[get_result() for _ in range(10)])
+    await result_backend.shutdown()
+
+
 @pytest.mark.anyio
 async def test_set_result_success_cluster(redis_cluster_url: str) -> None:
     """

tests/test_schedule_source.py

Lines changed: 13 additions & 0 deletions

@@ -1,3 +1,4 @@
+import asyncio
 import datetime as dt
 import uuid

@@ -108,6 +109,18 @@ async def test_buffer(redis_url: str) -> None:
     await source.shutdown()


+@pytest.mark.anyio
+async def test_max_connections(redis_url: str) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisScheduleSource(
+        redis_url,
+        prefix=prefix,
+        max_connection_pool_size=1,
+        timeout=3,
+    )
+    await asyncio.gather(*[source.get_schedules() for _ in range(10)])
+
+
 @pytest.mark.anyio
 async def test_cluster_set_schedule(redis_cluster_url: str) -> None:
     prefix = uuid.uuid4().hex
