Skip to content

Commit f66a2c2

Browse files
authored
Merge pull request #57 from taskiq-python/feature/backend-serializer
2 parents: 386b5bb + f536b4f · commit f66a2c2

File tree

7 files changed

+456
-437
lines changed

7 files changed

+456
-437
lines changed

poetry.lock

Lines changed: 389 additions & 383 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ keywords = [
2626

2727
[tool.poetry.dependencies]
2828
python = "^3.8.1"
29-
taskiq = ">=0.10.3,<1"
29+
taskiq = ">=0.11.1,<1"
3030
redis = "^5"
3131

3232
[tool.poetry.group.dev.dependencies]
@@ -40,7 +40,7 @@ fakeredis = "^2"
4040
pre-commit = "^2.20.0"
4141
pytest-xdist = { version = "^2.5.0", extras = ["psutil"] }
4242
ruff = "^0.1.0"
43-
types-redis = "^4.6.0.7"
43+
types-redis = "^4.6.0.20240425"
4444

4545
[tool.mypy]
4646
strict = true

taskiq_redis/redis_backend.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pickle
21
import sys
32
from contextlib import asynccontextmanager
43
from typing import (
@@ -15,16 +14,18 @@
1514

1615
from redis.asyncio import BlockingConnectionPool, Redis, Sentinel
1716
from redis.asyncio.cluster import RedisCluster
17+
from redis.asyncio.connection import Connection
1818
from taskiq import AsyncResultBackend
1919
from taskiq.abc.result_backend import TaskiqResult
2020
from taskiq.abc.serializer import TaskiqSerializer
21+
from taskiq.compat import model_dump, model_validate
22+
from taskiq.serializers import PickleSerializer
2123

2224
from taskiq_redis.exceptions import (
2325
DuplicateExpireTimeSelectedError,
2426
ExpireTimeMustBeMoreThanZeroError,
2527
ResultIsMissingError,
2628
)
27-
from taskiq_redis.serializer import PickleSerializer
2829

2930
if sys.version_info >= (3, 10):
3031
from typing import TypeAlias
@@ -33,8 +34,10 @@
3334

3435
if TYPE_CHECKING:
3536
_Redis: TypeAlias = Redis[bytes]
37+
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
3638
else:
3739
_Redis: TypeAlias = Redis
40+
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool
3841

3942
_ReturnType = TypeVar("_ReturnType")
4043

@@ -49,6 +52,7 @@ def __init__(
4952
result_ex_time: Optional[int] = None,
5053
result_px_time: Optional[int] = None,
5154
max_connection_pool_size: Optional[int] = None,
55+
serializer: Optional[TaskiqSerializer] = None,
5256
**connection_kwargs: Any,
5357
) -> None:
5458
"""
@@ -66,11 +70,12 @@ def __init__(
6670
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
6771
and result_px_time are equal zero.
6872
"""
69-
self.redis_pool = BlockingConnectionPool.from_url(
73+
self.redis_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
7074
url=redis_url,
7175
max_connections=max_connection_pool_size,
7276
**connection_kwargs,
7377
)
78+
self.serializer = serializer or PickleSerializer()
7479
self.keep_results = keep_results
7580
self.result_ex_time = result_ex_time
7681
self.result_px_time = result_px_time
@@ -110,9 +115,9 @@ async def set_result(
110115
:param task_id: ID of the task.
111116
:param result: TaskiqResult instance.
112117
"""
113-
redis_set_params: Dict[str, Union[str, bytes, int]] = {
118+
redis_set_params: Dict[str, Union[str, int, bytes]] = {
114119
"name": task_id,
115-
"value": pickle.dumps(result),
120+
"value": self.serializer.dumpb(model_dump(result)),
116121
}
117122
if self.result_ex_time:
118123
redis_set_params["ex"] = self.result_ex_time
@@ -159,8 +164,9 @@ async def get_result(
159164
if result_value is None:
160165
raise ResultIsMissingError
161166

162-
taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
163-
result_value,
167+
taskiq_result = model_validate(
168+
TaskiqResult[_ReturnType],
169+
self.serializer.loadb(result_value),
164170
)
165171

166172
if not with_logs:
@@ -178,6 +184,7 @@ def __init__(
178184
keep_results: bool = True,
179185
result_ex_time: Optional[int] = None,
180186
result_px_time: Optional[int] = None,
187+
serializer: Optional[TaskiqSerializer] = None,
181188
**connection_kwargs: Any,
182189
) -> None:
183190
"""
@@ -198,6 +205,7 @@ def __init__(
198205
redis_url,
199206
**connection_kwargs,
200207
)
208+
self.serializer = serializer or PickleSerializer()
201209
self.keep_results = keep_results
202210
self.result_ex_time = result_ex_time
203211
self.result_px_time = result_px_time
@@ -239,7 +247,7 @@ async def set_result(
239247
"""
240248
redis_set_params: Dict[str, Union[str, bytes, int]] = {
241249
"name": task_id,
242-
"value": pickle.dumps(result),
250+
"value": self.serializer.dumpb(model_dump(result)),
243251
}
244252
if self.result_ex_time:
245253
redis_set_params["ex"] = self.result_ex_time
@@ -283,8 +291,9 @@ async def get_result(
283291
if result_value is None:
284292
raise ResultIsMissingError
285293

286-
taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
287-
result_value,
294+
taskiq_result: TaskiqResult[_ReturnType] = model_validate(
295+
TaskiqResult[_ReturnType],
296+
self.serializer.loadb(result_value),
288297
)
289298

290299
if not with_logs:
@@ -331,9 +340,7 @@ def __init__(
331340
**connection_kwargs,
332341
)
333342
self.master_name = master_name
334-
if serializer is None:
335-
serializer = PickleSerializer()
336-
self.serializer = serializer
343+
self.serializer = serializer or PickleSerializer()
337344
self.keep_results = keep_results
338345
self.result_ex_time = result_ex_time
339346
self.result_px_time = result_px_time
@@ -375,7 +382,7 @@ async def set_result(
375382
"""
376383
redis_set_params: Dict[str, Union[str, bytes, int]] = {
377384
"name": task_id,
378-
"value": self.serializer.dumpb(result),
385+
"value": self.serializer.dumpb(model_dump(result)),
379386
}
380387
if self.result_ex_time:
381388
redis_set_params["ex"] = self.result_ex_time
@@ -422,11 +429,17 @@ async def get_result(
422429
if result_value is None:
423430
raise ResultIsMissingError
424431

425-
taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
426-
result_value,
432+
taskiq_result = model_validate(
433+
TaskiqResult[_ReturnType],
434+
self.serializer.loadb(result_value),
427435
)
428436

429437
if not with_logs:
430438
taskiq_result.log = None
431439

432440
return taskiq_result
441+
442+
async def shutdown(self) -> None:
443+
"""Shutdown sentinel connections."""
444+
for sentinel in self.sentinel.sentinels:
445+
await sentinel.aclose() # type: ignore[attr-defined]

taskiq_redis/redis_broker.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import sys
12
from logging import getLogger
2-
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar
3+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, TypeVar
34

4-
from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis
5+
from redis.asyncio import BlockingConnectionPool, Connection, Redis
56
from taskiq.abc.broker import AsyncBroker
67
from taskiq.abc.result_backend import AsyncResultBackend
78
from taskiq.message import BrokerMessage
@@ -10,6 +11,16 @@
1011

1112
logger = getLogger("taskiq.redis_broker")
1213

14+
if sys.version_info >= (3, 10):
15+
from typing import TypeAlias
16+
else:
17+
from typing_extensions import TypeAlias
18+
19+
if TYPE_CHECKING:
20+
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
21+
else:
22+
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool
23+
1324

1425
class BaseRedisBroker(AsyncBroker):
1526
"""Base broker that works with Redis."""
@@ -40,7 +51,7 @@ def __init__(
4051
task_id_generator=task_id_generator,
4152
)
4253

43-
self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
54+
self.connection_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
4455
url=url,
4556
max_connections=max_connection_pool_size,
4657
**connection_kwargs,

taskiq_redis/schedule_source.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from redis.asyncio import (
66
BlockingConnectionPool,
7-
ConnectionPool,
7+
Connection,
88
Redis,
99
RedisCluster,
1010
Sentinel,
@@ -13,8 +13,7 @@
1313
from taskiq.abc.serializer import TaskiqSerializer
1414
from taskiq.compat import model_dump, model_validate
1515
from taskiq.scheduler.scheduled_task import ScheduledTask
16-
17-
from taskiq_redis.serializer import PickleSerializer
16+
from taskiq.serializers import PickleSerializer
1817

1918
if sys.version_info >= (3, 10):
2019
from typing import TypeAlias
@@ -23,8 +22,10 @@
2322

2423
if TYPE_CHECKING:
2524
_Redis: TypeAlias = Redis[bytes]
25+
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
2626
else:
2727
_Redis: TypeAlias = Redis
28+
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool
2829

2930

3031
class RedisScheduleSource(ScheduleSource):
@@ -53,7 +54,7 @@ def __init__(
5354
**connection_kwargs: Any,
5455
) -> None:
5556
self.prefix = prefix
56-
self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
57+
self.connection_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
5758
url=url,
5859
max_connections=max_connection_pool_size,
5960
**connection_kwargs,
@@ -186,6 +187,10 @@ async def post_send(self, task: ScheduledTask) -> None:
186187
if task.time is not None:
187188
await self.delete_schedule(task.schedule_id)
188189

190+
async def shutdown(self) -> None:
191+
"""Shut down the schedule source."""
192+
await self.redis.aclose() # type: ignore[attr-defined]
193+
189194

190195
class RedisSentinelScheduleSource(ScheduleSource):
191196
"""
@@ -279,3 +284,8 @@ async def post_send(self, task: ScheduledTask) -> None:
279284
"""Delete a task after it's completed."""
280285
if task.time is not None:
281286
await self.delete_schedule(task.schedule_id)
287+
288+
async def shutdown(self) -> None:
289+
"""Shut down the schedule source."""
290+
for sentinel in self.sentinel.sentinels:
291+
await sentinel.aclose() # type: ignore[attr-defined]

taskiq_redis/serializer.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

tests/test_backend.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -90,34 +90,29 @@ async def test_success_backend_default_result(
9090

9191

9292
@pytest.mark.anyio
93-
async def test_success_backend_custom_result(
93+
async def test_error_backend_custom_result(
9494
custom_taskiq_result: TaskiqResult[_ReturnType],
9595
task_id: str,
9696
redis_url: str,
9797
) -> None:
9898
"""
9999
Tests normal behavior with custom result in TaskiqResult.
100100
101+
Setting custom class as a result should raise an error.
102+
101103
:param custom_taskiq_result: TaskiqResult with custom result.
102104
:param task_id: ID for task.
103105
:param redis_url: url to redis.
104106
"""
105107
backend: RedisAsyncResultBackend[_ReturnType] = RedisAsyncResultBackend(
106108
redis_url,
107109
)
108-
await backend.set_result(
109-
task_id=task_id,
110-
result=custom_taskiq_result,
111-
)
112-
result = await backend.get_result(task_id=task_id)
110+
with pytest.raises(ValueError):
111+
await backend.set_result(
112+
task_id=task_id,
113+
result=custom_taskiq_result,
114+
)
113115

114-
assert (
115-
result.return_value.test_arg # type: ignore
116-
== custom_taskiq_result.return_value.test_arg # type: ignore
117-
)
118-
assert result.is_err == custom_taskiq_result.is_err
119-
assert result.execution_time == custom_taskiq_result.execution_time
120-
assert result.log == custom_taskiq_result.log
121116
await backend.shutdown()
122117

123118

0 commit comments

Comments (0)