Skip to content

Commit 67e7dfb

Browse files
Add support for receiving task context (#174)
1 parent bafb4a6 commit 67e7dfb

File tree

13 files changed

+174
-25
lines changed

13 files changed

+174
-25
lines changed

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,26 @@ modified_task = calculate_meaning_of_life.using(priority=10)
7878

7979
In addition to the above attributes, `run_after` can be passed to specify a specific time the task should run.
8080

81+
#### Task context
82+
83+
Sometimes the running task may need to know context about how it was enqueued. To receive the task context as an argument to your task function, pass `takes_context` to the decorator and ensure the task takes a `context` as the first argument.
84+
85+
```python
86+
from django_tasks import task, TaskContext
87+
88+
89+
@task(takes_context=True)
90+
def calculate_meaning_of_life(context: TaskContext) -> int:
91+
return 42
92+
```
93+
94+
The task context has the following attributes:
95+
96+
- `task_result`: The running task result
97+
- `attempt`: The current attempt number for the task
98+
99+
This API will be extended with additional features in future.
100+
81101
### Enqueueing tasks
82102

83103
To execute a task, call the `enqueue` method on it:

django_tasks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DEFAULT_TASK_BACKEND_ALIAS,
1717
ResultStatus,
1818
Task,
19+
TaskContext,
1920
TaskResult,
2021
task,
2122
)
@@ -31,6 +32,7 @@
3132
"ResultStatus",
3233
"Task",
3334
"TaskResult",
35+
"TaskContext",
3436
]
3537

3638

django_tasks/backends/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from django.core.checks import messages
88
from django.db import connections
99
from django.utils import timezone
10+
from django.utils.inspect import get_func_args
1011
from typing_extensions import ParamSpec
1112

1213
from django_tasks.exceptions import InvalidTaskError
@@ -61,6 +62,15 @@ def validate_task(self, task: Task) -> None:
6162
if not self.supports_async_task and iscoroutinefunction(task.func):
6263
raise InvalidTaskError("Backend does not support async tasks")
6364

65+
task_func_args = get_func_args(task.func)
66+
67+
if task.takes_context and (
68+
not task_func_args or task_func_args[0] != "context"
69+
):
70+
raise InvalidTaskError(
71+
"Task takes context but does not have a first argument of 'context'"
72+
)
73+
6474
if (
6575
task.priority < MIN_PRIORITY
6676
or task.priority > MAX_PRIORITY

django_tasks/backends/database/management/commands/db_worker.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from django_tasks.backends.database.utils import exclusive_transaction
2323
from django_tasks.exceptions import InvalidTaskBackendError
2424
from django_tasks.signals import task_finished, task_started
25-
from django_tasks.task import DEFAULT_QUEUE_NAME
25+
from django_tasks.task import DEFAULT_QUEUE_NAME, TaskContext
2626
from django_tasks.utils import get_random_id
2727

2828
package_logger = logging.getLogger("django_tasks")
@@ -160,8 +160,14 @@ def run_task(self, db_task_result: DBTaskResult) -> None:
160160
backend_type = task.get_backend()
161161

162162
task_started.send(sender=backend_type, task_result=task_result)
163-
164-
return_value = task.call(*task_result.args, **task_result.kwargs)
163+
if task.takes_context:
164+
return_value = task.call(
165+
TaskContext(task_result=task_result),
166+
*task_result.args,
167+
**task_result.kwargs,
168+
)
169+
else:
170+
return_value = task.call(*task_result.args, **task_result.kwargs)
165171

166172
# Setting the return and success value inside the error handling,
167173
# So errors setting it (eg JSON encode) can still be recorded

django_tasks/backends/immediate.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
import logging
22
from functools import partial
3-
from inspect import iscoroutinefunction
43
from typing import TypeVar
54

6-
from asgiref.sync import async_to_sync
75
from django.db import transaction
86
from django.utils import timezone
97
from typing_extensions import ParamSpec
108

119
from django_tasks.signals import task_enqueued, task_finished, task_started
12-
from django_tasks.task import ResultStatus, Task, TaskError, TaskResult
10+
from django_tasks.task import ResultStatus, Task, TaskContext, TaskError, TaskResult
1311
from django_tasks.utils import (
1412
get_exception_traceback,
1513
get_module_path,
@@ -43,23 +41,26 @@ def _execute_task(self, task_result: TaskResult) -> None:
4341

4442
task = task_result.task
4543

46-
calling_task_func = (
47-
async_to_sync(task.func) if iscoroutinefunction(task.func) else task.func
48-
)
49-
5044
object.__setattr__(task_result, "status", ResultStatus.RUNNING)
5145
object.__setattr__(task_result, "started_at", timezone.now())
5246
object.__setattr__(task_result, "last_attempted_at", timezone.now())
5347
task_result.worker_ids.append(self.worker_id)
5448
task_started.send(sender=type(self), task_result=task_result)
5549

5650
try:
51+
if task.takes_context:
52+
raw_return_value = task.call(
53+
TaskContext(task_result=task_result),
54+
*task_result.args,
55+
**task_result.kwargs,
56+
)
57+
else:
58+
raw_return_value = task.call(*task_result.args, **task_result.kwargs)
59+
5760
object.__setattr__(
5861
task_result,
5962
"_return_value",
60-
json_normalize(
61-
calling_task_func(*task_result.args, **task_result.kwargs)
62-
),
63+
json_normalize(raw_return_value),
6364
)
6465
except BaseException as e:
6566
# If the user tried to terminate, let them

django_tasks/backends/rq.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from django.core.checks import messages
99
from django.core.exceptions import SuspiciousOperation
1010
from django.db import transaction
11+
from django.utils.functional import cached_property
1112
from redis.client import Redis
1213
from rq.job import Callback, JobStatus
1314
from rq.job import Job as BaseJob
@@ -23,6 +24,7 @@
2324
MAX_PRIORITY,
2425
ResultStatus,
2526
Task,
27+
TaskContext,
2628
TaskError,
2729
)
2830
from django_tasks.task import TaskResult as BaseTaskResult
@@ -51,21 +53,30 @@ class TaskResult(BaseTaskResult[T]):
5153

5254
class Job(BaseJob):
5355
def perform(self) -> Any:
54-
task_result = self.into_task_result()
55-
5656
assert self.worker_name is not None
5757
self.meta.setdefault("_django_tasks_worker_ids", []).append(self.worker_name)
5858
self.save_meta() # type: ignore[no-untyped-call]
5959

60-
task_started.send(type(task_result.task.get_backend()), task_result=task_result)
60+
task_started.send(
61+
type(self.task_result.task.get_backend()), task_result=self.task_result
62+
)
6163

6264
return super().perform()
6365

6466
def _execute(self) -> Any:
6567
"""
6668
Shim RQ's `Job` to call the underlying `Task` function.
6769
"""
68-
return self.func.call(*self.args, **self.kwargs)
70+
try:
71+
if self.func.takes_context:
72+
return self.func.call(
73+
TaskContext(task_result=self.task_result), *self.args, **self.kwargs
74+
)
75+
return self.func.call(*self.args, **self.kwargs)
76+
finally:
77+
# Clear the task result cache, as it's changed now
78+
self.__dict__.pop("task_result", None)
79+
pass
6980

7081
@property
7182
def func(self) -> Task:
@@ -78,7 +89,8 @@ def func(self) -> Task:
7889

7990
return func
8091

81-
def into_task_result(self) -> TaskResult:
92+
@cached_property
93+
def task_result(self) -> TaskResult:
8294
task: Task = self.func
8395

8496
scheduled_job_registry = ScheduledJobRegistry( # type: ignore[no-untyped-call]
@@ -145,15 +157,15 @@ def failed_callback(
145157
)
146158
job.save_meta() # type: ignore[no-untyped-call]
147159

148-
task_result = job.into_task_result()
160+
task_result = job.task_result
149161

150162
object.__setattr__(task_result, "status", ResultStatus.FAILED)
151163

152164
task_finished.send(type(task_result.task.get_backend()), task_result=task_result)
153165

154166

155167
def success_callback(job: Job, connection: Optional[Redis], result: Any) -> None:
156-
task_result = job.into_task_result()
168+
task_result = job.task_result
157169

158170
object.__setattr__(task_result, "status", ResultStatus.SUCCEEDED)
159171

@@ -241,7 +253,7 @@ def get_result(self, result_id: str) -> TaskResult:
241253
if job is None:
242254
raise ResultDoesNotExist(result_id)
243255

244-
return job.into_task_result()
256+
return job.task_result
245257

246258
def check(self, **kwargs: Any) -> Iterable[messages.CheckMessage]:
247259
yield from super().check(**kwargs)

django_tasks/task.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Any,
77
Callable,
88
Generic,
9+
Literal,
910
Optional,
1011
TypeVar,
1112
Union,
@@ -17,7 +18,7 @@
1718
from django.db.models.enums import TextChoices
1819
from django.utils.module_loading import import_string
1920
from django.utils.translation import gettext_lazy as _
20-
from typing_extensions import ParamSpec, Self
21+
from typing_extensions import Concatenate, ParamSpec, Self
2122

2223
from .exceptions import ResultDoesNotExist
2324
from .utils import (
@@ -87,6 +88,11 @@ class Task(Generic[P, T]):
8788
immediately, or whatever the backend decides
8889
"""
8990

91+
takes_context: bool = False
92+
"""
93+
Whether the task receives the task context when executed.
94+
"""
95+
9096
def __post_init__(self) -> None:
9197
self.get_backend().validate_task(self)
9298

@@ -197,18 +203,37 @@ def task(
197203
queue_name: str = DEFAULT_QUEUE_NAME,
198204
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
199205
enqueue_on_commit: Optional[bool] = None,
206+
takes_context: Literal[False] = False,
200207
) -> Callable[[Callable[P, T]], Task[P, T]]: ...
201208

202209

203-
# Implementation
210+
# Decorator with context and arguments
211+
# e.g. @task(takes_context=True, ...)
212+
@overload
204213
def task(
214+
*,
215+
priority: int = DEFAULT_PRIORITY,
216+
queue_name: str = DEFAULT_QUEUE_NAME,
217+
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
218+
enqueue_on_commit: Optional[bool] = None,
219+
takes_context: Literal[True],
220+
) -> Callable[[Callable[Concatenate["TaskContext", P], T]], Task[P, T]]: ...
221+
222+
223+
# Implementation
224+
def task( # type: ignore[misc]
205225
function: Optional[Callable[P, T]] = None,
206226
*,
207227
priority: int = DEFAULT_PRIORITY,
208228
queue_name: str = DEFAULT_QUEUE_NAME,
209229
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
210230
enqueue_on_commit: Optional[bool] = None,
211-
) -> Union[Task[P, T], Callable[[Callable[P, T]], Task[P, T]]]:
231+
takes_context: bool = False,
232+
) -> Union[
233+
Task[P, T],
234+
Callable[[Callable[P, T]], Task[P, T]],
235+
Callable[[Callable[Concatenate["TaskContext", P], T]], Task[P, T]],
236+
]:
212237
"""
213238
A decorator used to create a task.
214239
"""
@@ -221,6 +246,7 @@ def wrapper(f: Callable[P, T]) -> Task[P, T]:
221246
queue_name=queue_name,
222247
backend=backend,
223248
enqueue_on_commit=enqueue_on_commit,
249+
takes_context=takes_context,
224250
)
225251

226252
if function:
@@ -330,3 +356,12 @@ async def arefresh(self) -> None:
330356

331357
for attr in TASK_REFRESH_ATTRS:
332358
object.__setattr__(self, attr, getattr(refreshed_task, attr))
359+
360+
361+
@dataclass(frozen=True)
362+
class TaskContext:
363+
task_result: TaskResult
364+
365+
@property
366+
def attempt(self) -> int:
367+
return self.task_result.attempts

tests/tasks.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import time
22
from typing import Any
33

4-
from django_tasks import task
4+
from django_tasks import TaskContext, task
55

66

77
@task()
@@ -70,3 +70,14 @@ def hang() -> None:
7070
@task()
7171
def sleep_for(seconds: float) -> None:
7272
time.sleep(seconds)
73+
74+
75+
@task(takes_context=True)
76+
def get_task_id(context: TaskContext) -> str:
77+
return context.task_result.id
78+
79+
80+
@task(takes_context=True)
81+
def test_context(context: TaskContext, attempt: int) -> None:
82+
assert isinstance(context, TaskContext)
83+
assert context.attempt == attempt

tests/tests/test_database_backend.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,23 @@ def test_max_tasks(self) -> None:
862862
self.assertEqual(statuses[ResultStatus.SUCCEEDED], 2)
863863
self.assertEqual(statuses[ResultStatus.READY], 3)
864864

865+
def test_takes_context(self) -> None:
866+
result = test_tasks.get_task_id.enqueue()
867+
868+
self.run_worker()
869+
870+
result.refresh()
871+
872+
self.assertEqual(result.return_value, result.id)
873+
874+
def test_context(self) -> None:
875+
result = test_tasks.test_context.enqueue(1)
876+
877+
self.run_worker()
878+
result.refresh()
879+
880+
self.assertEqual(result.status, ResultStatus.SUCCEEDED)
881+
865882

866883
@override_settings(
867884
TASKS={

tests/tests/test_dummy_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ def test_enqueue_on_commit_with_no_databases(self) -> None:
195195
self.assertEqual(len(errors), 1)
196196
self.assertIn("Set `ENQUEUE_ON_COMMIT` to False", errors[0].hint) # type:ignore[arg-type]
197197

198+
def test_takes_context(self) -> None:
199+
result = test_tasks.get_task_id.enqueue()
200+
self.assertEqual(result.status, ResultStatus.READY)
201+
198202

199203
class DummyBackendTransactionTestCase(TransactionTestCase):
200204
@override_settings(

0 commit comments

Comments
 (0)