Skip to content

Add cluster schedule source #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion taskiq_redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
)
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
from taskiq_redis.schedule_source import RedisScheduleSource
from taskiq_redis.schedule_source import (
RedisClusterScheduleSource,
RedisScheduleSource,
)

__all__ = [
"RedisAsyncClusterResultBackend",
Expand All @@ -14,4 +17,5 @@
"PubSubBroker",
"ListQueueClusterBroker",
"RedisScheduleSource",
"RedisClusterScheduleSource",
]
81 changes: 80 additions & 1 deletion taskiq_redis/schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, List, Optional

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio import ConnectionPool, Redis, RedisCluster
from taskiq import ScheduleSource
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
Expand Down Expand Up @@ -95,3 +95,82 @@ async def post_send(self, task: ScheduledTask) -> None:
async def shutdown(self) -> None:
    """Disconnect the connection pool when the schedule source shuts down."""
    pool = self.connection_pool
    await pool.disconnect()


class RedisClusterScheduleSource(ScheduleSource):
    """
    Source of schedules for redis cluster.

    This class allows you to store schedules in redis.
    Also it supports dynamic schedules.

    :param url: url to redis cluster.
    :param prefix: prefix for redis schedule keys.
    :param buffer_size: buffer size for redis scan.
        This is how many keys will be fetched at once.
    :param serializer: serializer for data.
        When omitted, a PickleSerializer is used.
    :param connection_kwargs: additional arguments for RedisCluster.
    """

    def __init__(
        self,
        url: str,
        prefix: str = "schedule",
        buffer_size: int = 50,
        serializer: Optional[TaskiqSerializer] = None,
        **connection_kwargs: Any,
    ) -> None:
        self.prefix = prefix
        self.redis: RedisCluster[bytes] = RedisCluster.from_url(
            url,
            **connection_kwargs,
        )
        self.buffer_size = buffer_size
        if serializer is None:
            # NOTE(review): the default serializer is presumably pickle-based,
            # so the data stored under `prefix` must come from a trusted source.
            serializer = PickleSerializer()
        self.serializer = serializer

    async def delete_schedule(self, schedule_id: str) -> None:
        """
        Remove schedule by id.

        :param schedule_id: id of the schedule to remove.
        """
        await self.redis.delete(f"{self.prefix}:{schedule_id}")  # type: ignore[attr-defined]

    async def add_schedule(self, schedule: ScheduledTask) -> None:
        """
        Add schedule to redis.

        The schedule is stored under the `<prefix>:<schedule_id>` key.

        :param schedule: schedule to add.
        """
        await self.redis.set(  # type: ignore[attr-defined]
            f"{self.prefix}:{schedule.schedule_id}",
            self.serializer.dumpb(model_dump(schedule)),
        )

    async def get_schedules(self) -> List[ScheduledTask]:
        """
        Get all schedules from redis.

        This method is used by scheduler to get all schedules.
        Keys are scanned by prefix and fetched in batches of `buffer_size`.

        :return: list of schedules.
        """
        schedules = []
        buffer = []
        async for key in self.redis.scan_iter(f"{self.prefix}:*"):  # type: ignore[attr-defined]
            buffer.append(key)
            if len(buffer) >= self.buffer_size:
                # Keys of one batch usually live in different hash slots.
                # A plain MGET is rejected by the cluster client in that case,
                # so use the variant that splits the request per slot.
                schedules.extend(
                    await self.redis.mget_nonatomic(buffer),  # type: ignore[attr-defined]
                )
                buffer = []
        if buffer:
            schedules.extend(
                await self.redis.mget_nonatomic(buffer),  # type: ignore[attr-defined]
            )
        # Empty values (e.g. keys removed between the scan and the fetch)
        # are skipped.
        return [
            model_validate(ScheduledTask, self.serializer.loadb(schedule))
            for schedule in schedules
            if schedule
        ]

    async def post_send(self, task: ScheduledTask) -> None:
        """
        Delete a task after it's completed.

        Only schedules with `time` set are removed; other schedules are kept.

        :param task: task that has just been sent.
        """
        if task.time is not None:
            await self.delete_schedule(task.schedule_id)

    async def shutdown(self) -> None:
        """Shut down the schedule source and close the cluster client."""
        # Newer redis-py releases expose aclose() and deprecate close();
        # fall back to close() so that older clients are released as well.
        closer = getattr(self.redis, "aclose", None)
        if closer is None:
            closer = self.redis.close  # type: ignore[attr-defined]
        await closer()
150 changes: 149 additions & 1 deletion tests/test_schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import datetime as dt
import uuid

import pytest
from taskiq import ScheduledTask

from taskiq_redis import RedisScheduleSource
from taskiq_redis import RedisClusterScheduleSource, RedisScheduleSource


@pytest.mark.anyio
Expand Down Expand Up @@ -56,6 +57,153 @@ async def test_post_run_cron(redis_url: str) -> None:
cron="* * * * *",
)
await source.add_schedule(schedule)
assert await source.get_schedules() == [schedule]
await source.post_send(schedule)
assert await source.get_schedules() == [schedule]
await source.shutdown()


@pytest.mark.anyio
async def test_post_run_time(redis_url: str) -> None:
    """A schedule with `time` set is removed by `post_send`."""
    schedule_source = RedisScheduleSource(redis_url, prefix=uuid.uuid4().hex)
    one_shot = ScheduledTask(
        task_name="test_task",
        labels={},
        args=[],
        kwargs={},
        time=dt.datetime(2000, 1, 1),
    )

    await schedule_source.add_schedule(one_shot)
    stored = await schedule_source.get_schedules()
    assert stored == [one_shot]

    await schedule_source.post_send(one_shot)
    remaining = await schedule_source.get_schedules()
    assert remaining == []

    await schedule_source.shutdown()


@pytest.mark.anyio
async def test_buffer(redis_url: str) -> None:
    """With `buffer_size=1` both stored schedules are still returned."""
    schedule_source = RedisScheduleSource(
        redis_url,
        prefix=uuid.uuid4().hex,
        buffer_size=1,
    )
    first = ScheduledTask(
        task_name="test_task1",
        labels={},
        args=[],
        kwargs={},
        cron="* * * * *",
    )
    second = ScheduledTask(
        task_name="test_task2",
        labels={},
        args=[],
        kwargs={},
        cron="* * * * *",
    )

    await schedule_source.add_schedule(first)
    await schedule_source.add_schedule(second)

    fetched = await schedule_source.get_schedules()
    assert len(fetched) == 2
    # Only membership is checked, not the order of the result.
    assert first in fetched
    assert second in fetched

    await schedule_source.shutdown()


@pytest.mark.anyio
async def test_cluster_set_schedule(redis_cluster_url: str) -> None:
    """A schedule added to the cluster source is returned by `get_schedules`."""
    cluster_source = RedisClusterScheduleSource(
        redis_cluster_url,
        prefix=uuid.uuid4().hex,
    )
    cron_task = ScheduledTask(
        task_name="test_task",
        labels={},
        args=[],
        kwargs={},
        cron="* * * * *",
    )

    await cluster_source.add_schedule(cron_task)
    fetched = await cluster_source.get_schedules()
    assert fetched == [cron_task]

    await cluster_source.shutdown()


@pytest.mark.anyio
async def test_cluster_delete_schedule(redis_cluster_url: str) -> None:
    """A deleted schedule no longer appears in `get_schedules`."""
    cluster_source = RedisClusterScheduleSource(
        redis_cluster_url,
        prefix=uuid.uuid4().hex,
    )
    cron_task = ScheduledTask(
        task_name="test_task",
        labels={},
        args=[],
        kwargs={},
        cron="* * * * *",
    )

    await cluster_source.add_schedule(cron_task)
    before_delete = await cluster_source.get_schedules()
    assert before_delete == [cron_task]

    await cluster_source.delete_schedule(cron_task.schedule_id)
    after_delete = await cluster_source.get_schedules()
    # Nothing is left under this prefix once the only schedule is deleted.
    assert not after_delete

    await cluster_source.shutdown()


@pytest.mark.anyio
async def test_cluster_post_run_cron(redis_cluster_url: str) -> None:
    """A cron schedule is still stored after `post_send`."""
    cluster_source = RedisClusterScheduleSource(
        redis_cluster_url,
        prefix=uuid.uuid4().hex,
    )
    cron_task = ScheduledTask(
        task_name="test_task",
        labels={},
        args=[],
        kwargs={},
        cron="* * * * *",
    )

    await cluster_source.add_schedule(cron_task)
    before_send = await cluster_source.get_schedules()
    assert before_send == [cron_task]

    await cluster_source.post_send(cron_task)
    after_send = await cluster_source.get_schedules()
    assert after_send == [cron_task]

    await cluster_source.shutdown()


@pytest.mark.anyio
async def test_cluster_post_run_time(redis_cluster_url: str) -> None:
    """A schedule with `time` set is removed from the cluster by `post_send`."""
    cluster_source = RedisClusterScheduleSource(
        redis_cluster_url,
        prefix=uuid.uuid4().hex,
    )
    one_shot = ScheduledTask(
        task_name="test_task",
        labels={},
        args=[],
        kwargs={},
        time=dt.datetime(2000, 1, 1),
    )

    await cluster_source.add_schedule(one_shot)
    stored = await cluster_source.get_schedules()
    assert stored == [one_shot]

    await cluster_source.post_send(one_shot)
    remaining = await cluster_source.get_schedules()
    assert remaining == []

    await cluster_source.shutdown()


@pytest.mark.anyio
async def test_cluster_buffer(redis_cluster_url: str) -> None:
    """With `buffer_size=1` the cluster source still returns both schedules."""
    cluster_source = RedisClusterScheduleSource(
        redis_cluster_url,
        prefix=uuid.uuid4().hex,
        buffer_size=1,
    )
    first = ScheduledTask(
        task_name="test_task1",
        labels={},
        args=[],
        kwargs={},
        cron="* * * * *",
    )
    second = ScheduledTask(
        task_name="test_task2",
        labels={},
        args=[],
        kwargs={},
        cron="* * * * *",
    )

    await cluster_source.add_schedule(first)
    await cluster_source.add_schedule(second)

    fetched = await cluster_source.get_schedules()
    assert len(fetched) == 2
    # Only membership is checked, not the order of the result.
    assert first in fetched
    assert second in fetched

    await cluster_source.shutdown()