Feat(Stream): Use redis stream #504

Open · wants to merge 1 commit into main
41 changes: 39 additions & 2 deletions arq/connections.py
@@ -13,8 +13,16 @@
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import RedisError, WatchError

from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
from .constants import (
default_queue_name,
expires_extra_ms,
job_key_prefix,
job_message_id_prefix,
result_key_prefix,
stream_key_suffix,
)
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
from .lua_script import publish_job_lua
from .utils import timestamp_ms, to_ms, to_unix_ms

logger = logging.getLogger('arq.connections')
@@ -114,6 +122,7 @@ def __init__(
if pool_or_conn:
kwargs['connection_pool'] = pool_or_conn
self.expires_extra_ms = expires_extra_ms
self.publish_job_sha = None
super().__init__(**kwargs)

async def enqueue_job(
@@ -126,6 +135,7 @@ async def enqueue_job(
_defer_by: Union[None, int, float, timedelta] = None,
_expires: Union[None, int, float, timedelta] = None,
_job_try: Optional[int] = None,
_use_stream: bool = False,
**kwargs: Any,
) -> Optional[Job]:
"""
@@ -145,6 +155,7 @@
"""
if _queue_name is None:
_queue_name = self.default_queue_name

job_id = _job_id or uuid4().hex
job_key = job_key_prefix + job_id
if _defer_until and _defer_by:
@@ -153,6 +164,9 @@
defer_by_ms = to_ms(_defer_by)
expires_ms = to_ms(_expires)

if _use_stream is True and self.publish_job_sha is None:
self.publish_job_sha = await self.script_load(publish_job_lua) # type: ignore[no-untyped-call]

async with self.pipeline(transaction=True) as pipe:
await pipe.watch(job_key)
if await pipe.exists(job_key, result_key_prefix + job_id):
@@ -172,14 +186,37 @@
job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
pipe.multi()
pipe.psetex(job_key, expires_ms, job)
pipe.zadd(_queue_name, {job_id: score})

if _use_stream is False:
pipe.zadd(_queue_name, {job_id: score})
else:
stream_key = _queue_name + stream_key_suffix
job_message_id_key = job_message_id_prefix + job_id

pipe.evalsha(
self.publish_job_sha,
2,
# keys
stream_key,
job_message_id_key,
# args
job_id,
str(enqueue_time_ms),
str(expires_ms),
)
try:
await pipe.execute()
except WatchError:
# job got enqueued since we checked 'job_exists'
return None
return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)

async def get_stream_size(self, queue_name: str | None = None, include_delayed_tasks: bool = True) -> int:
if queue_name is None:
queue_name = self.default_queue_name

return await self.xlen(queue_name + stream_key_suffix)

async def _get_job_result(self, key: bytes) -> JobResult:
job_id = key[len(result_key_prefix) :].decode()
job = Job(job_id, self, _deserializer=self.job_deserializer)
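A minimal usage sketch of the new flag from the caller's side (the task name and URL are illustrative; worker-side consumption of the stream is outside this file):

from arq import create_pool
from arq.connections import RedisSettings

async def main() -> None:
    redis = await create_pool(RedisSettings())
    # Publish via the Lua script to 'arq:queue:stream' instead of ZADD on 'arq:queue'.
    await redis.enqueue_job('download_content', 'https://example.com', _use_stream=True)
    # XLEN of 'arq:queue:stream' — entries currently in the stream.
    print(await redis.get_stream_size())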
3 changes: 3 additions & 0 deletions arq/constants.py
@@ -1,9 +1,12 @@
default_queue_name = 'arq:queue'
job_key_prefix = 'arq:job:'
in_progress_key_prefix = 'arq:in-progress:'
job_message_id_prefix = 'arq:message-id:'
result_key_prefix = 'arq:result:'
retry_key_prefix = 'arq:retry:'
abort_jobs_ss = 'arq:abort'
stream_key_suffix = ':stream'
default_consumer_group = 'arq:consumers'
# age of items in the abort_key sorted set after which they're deleted
abort_job_max_age = 60
health_check_key_suffix = ':health-check'
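For reference, the keys the new constants compose for the default queue name (a sketch; the job id is made up):

from arq.constants import default_queue_name, job_message_id_prefix, stream_key_suffix

job_id = '4c9f'  # hypothetical job id
stream_key = default_queue_name + stream_key_suffix   # 'arq:queue:stream'
message_id_key = job_message_id_prefix + job_id        # 'arq:message-id:4c9f'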
42 changes: 35 additions & 7 deletions arq/jobs.py
@@ -9,8 +9,17 @@

from redis.asyncio import Redis

from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
from .utils import ms_to_datetime, poll, timestamp_ms
from .constants import (
abort_jobs_ss,
default_queue_name,
in_progress_key_prefix,
job_key_prefix,
job_message_id_prefix,
result_key_prefix,
stream_key_suffix,
)
from .lua_script import get_job_from_stream_lua
from .utils import _list_to_dict, ms_to_datetime, poll, timestamp_ms

logger = logging.getLogger('arq.jobs')

@@ -105,7 +114,8 @@ async def result(
async with self._redis.pipeline(transaction=True) as tr:
tr.get(result_key_prefix + self.job_id)
tr.zscore(self._queue_name, self.job_id)
v, s = await tr.execute()
tr.get(job_message_id_prefix + self.job_id)
v, s, m = await tr.execute()

if v:
info = deserialize_result(v, deserializer=self._deserializer)
@@ -115,7 +125,7 @@
raise info.result
else:
raise SerializationError(info.result)
elif s is None:
elif s is None and m is None:
raise ResultNotFound(
'Not waiting for job result because the job is not in queue. '
'Is the worker function configured to keep result?'
@@ -134,8 +144,23 @@ async def info(self) -> Optional[JobDef]:
if v:
info = deserialize_job(v, deserializer=self._deserializer)
if info:
s = await self._redis.zscore(self._queue_name, self.job_id)
info.score = None if s is None else int(s)
async with self._redis.pipeline(transaction=True) as tr:
tr.zscore(self._queue_name, self.job_id)
tr.eval(
get_job_from_stream_lua,
2,
self._queue_name + stream_key_suffix,
job_message_id_prefix + self.job_id,
)
delayed_score, job_info = await tr.execute()

if delayed_score:
info.score = int(delayed_score)
elif job_info:
_, job_info_payload = job_info
info.score = int(_list_to_dict(job_info_payload)[b'score'])
else:
info.score = None
return info

async def result_info(self) -> Optional[JobResult]:
@@ -157,12 +182,15 @@ async def status(self) -> JobStatus:
tr.exists(result_key_prefix + self.job_id)
tr.exists(in_progress_key_prefix + self.job_id)
tr.zscore(self._queue_name, self.job_id)
is_complete, is_in_progress, score = await tr.execute()
tr.exists(job_message_id_prefix + self.job_id)
is_complete, is_in_progress, score, queued = await tr.execute()

if is_complete:
return JobStatus.complete
elif is_in_progress:
return JobStatus.in_progress
elif queued:
return JobStatus.queued
elif score:
return JobStatus.deferred if score > timestamp_ms() else JobStatus.queued
else:
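From the caller's side, the jobs.py changes mean a stream-enqueued job still reports a sensible status. A rough sketch, assuming an ArqRedis pool and a hypothetical 'resize_image' task (actually running the job needs the stream-consuming worker, which is not part of this file):

from arq.jobs import JobStatus

async def check(redis) -> None:  # redis: ArqRedis
    job = await redis.enqueue_job('resize_image', 'img.png', _use_stream=True)
    assert job is not None
    # No zscore exists for stream jobs, but 'arq:message-id:<job_id>' does,
    # so the new `queued` branch fires instead of falling through to not_found.
    assert await job.status() == JobStatus.queued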
24 changes: 24 additions & 0 deletions arq/lua_script.py
@@ -0,0 +1,24 @@
publish_job_lua = """
local stream_key = KEYS[1]
local job_message_id_key = KEYS[2]
local job_id = ARGV[1]
local score = ARGV[2]
local job_message_id_expire_ms = ARGV[3]
local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score)
redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms)
return message_id
"""

get_job_from_stream_lua = """
local stream_key = KEYS[1]
local job_message_id_key = KEYS[2]
local message_id = redis.call('get', job_message_id_key)
if message_id == false then
return nil
end
local job = redis.call('xrange', stream_key, message_id, message_id)
if job == nil then
return nil
end
return job[1]
"""
4 changes: 4 additions & 0 deletions arq/utils.py
@@ -148,3 +148,7 @@ def import_string(dotted_path: str) -> Any:
return getattr(module, class_name)
except AttributeError as e:
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e


def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]:
return dict(zip(input_list[::2], input_list[1::2], strict=True))
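And a quick illustration of the helper pairing up the flat field/value list Redis returns for a stream entry (values shown are made up):

fields = [b'job_id', b'15e6c5b0', b'score', b'1717680000000']
assert _list_to_dict(fields) == {b'job_id': b'15e6c5b0', b'score': b'1717680000000'}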