Skip to content

Commit 9f201ae

Browse files
committed
Add configurable transaction management
1 parent a24e53d commit 9f201ae

File tree

13 files changed

+361
-62
lines changed

13 files changed

+361
-62
lines changed

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,23 @@ The returned `TaskResult` can be interrogated to query the current state of the
7979

8080
If the task takes arguments, these can be passed as-is to `enqueue`.
8181

82+
#### Transactions
83+
84+
By default, it's up to the backend to determine whether the task should be enqueued immediately (after calling `.enqueue`) or wait until the end of the current database transaction (if there is one).
85+
86+
This can be configured using the `ENQUEUE_ON_COMMIT` setting. `True` and `False` force the behaviour, and `None` is used to let the backend decide.
87+
88+
```python
89+
TASKS = {
90+
"default": {
91+
"BACKEND": "django_tasks.backends.immediate.ImmediateBackend",
92+
"ENQUEUE_ON_COMMIT": False
93+
}
94+
}
95+
```
96+
97+
All built-in backends default to waiting for the end of the transaction (`"ENQUEUE_ON_COMMIT": True`).
98+
8299
### Executing tasks with the database backend
83100

84101
First, you'll need to add `django_tasks.backends.database` to `INSTALLED_APPS`, and run `manage.py migrate`.

django_tasks/backends/base.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from abc import ABCMeta, abstractmethod
22
from inspect import iscoroutinefunction
3-
from typing import Any, List, TypeVar
3+
from typing import Any, Iterable, Optional, TypeVar
44

55
from asgiref.sync import sync_to_async
6-
from django.core.checks.messages import CheckMessage
6+
from django.core.checks import messages
77
from django.utils import timezone
88
from typing_extensions import ParamSpec
99

@@ -16,6 +16,9 @@
1616

1717

1818
class BaseTaskBackend(metaclass=ABCMeta):
19+
alias: str
20+
enqueue_on_commit: Optional[bool]
21+
1922
task_class = Task
2023

2124
supports_defer = False
@@ -29,6 +32,18 @@ class BaseTaskBackend(metaclass=ABCMeta):
2932

3033
def __init__(self, options: dict) -> None:
3134
self.alias = options["ALIAS"]
35+
self.enqueue_on_commit = options.get("ENQUEUE_ON_COMMIT", None)
36+
37+
def _get_enqueue_on_commit_for_task(self, task: Task) -> Optional[bool]:
38+
"""
39+
Determine the correct `enqueue_on_commit` setting to use for a given task.
40+
41+
If the task defines it, use that, otherwise, fall back to the backend.
42+
"""
43+
if isinstance(task.enqueue_on_commit, bool):
44+
return task.enqueue_on_commit
45+
46+
return self.enqueue_on_commit
3247

3348
def validate_task(self, task: Task) -> None:
3449
"""
@@ -94,8 +109,8 @@ def close(self) -> None:
94109
# HACK: `close` isn't abstract, but should do nothing by default
95110
return None
96111

97-
def check(self, **kwargs: Any) -> List[CheckMessage]:
98-
raise NotImplementedError(
99-
"subclasses may provide a check() method to verify that task "
100-
"backend is configured correctly."
101-
)
112+
def check(self, **kwargs: Any) -> Iterable[messages.CheckMessage]:
113+
if self.enqueue_on_commit not in {True, False, None}:
114+
yield messages.CheckMessage(
115+
messages.ERROR, "`ENQUEUE_ON_COMMIT` must be a bool or None"
116+
)

django_tasks/backends/database/backend.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from dataclasses import asdict, dataclass
2-
from typing import TYPE_CHECKING, Any, List, TypeVar
2+
from typing import TYPE_CHECKING, Any, Iterable, TypeVar
33

44
from django.apps import apps
5-
from django.core.checks import ERROR, CheckMessage
5+
from django.core.checks import messages
66
from django.core.exceptions import ValidationError
77
from typing_extensions import ParamSpec
88

99
from django_tasks.backends.base import BaseTaskBackend
10-
from django_tasks.exceptions import ResultDoesNotExist
10+
from django_tasks.exceptions import InvalidTaskError, ResultDoesNotExist
1111
from django_tasks.task import Task
1212
from django_tasks.task import TaskResult as BaseTaskResult
1313
from django_tasks.utils import json_normalize
@@ -39,6 +39,14 @@ class DatabaseBackend(BaseTaskBackend):
3939
supports_get_result = True
4040
supports_defer = True
4141

42+
def validate_task(self, task: Task[P, T]) -> None:
43+
super().validate_task(task)
44+
45+
if task.enqueue_on_commit is False:
46+
raise InvalidTaskError(
47+
"enqueue_on_commit must be True or None when using database backend"
48+
)
49+
4250
def _task_to_db_task(
4351
self, task: Task[P, T], args: P.args, kwargs: P.kwargs
4452
) -> "DBTaskResult":
@@ -91,16 +99,21 @@ async def aget_result(self, result_id: str) -> TaskResult:
9199
except (DBTaskResult.DoesNotExist, ValidationError) as e:
92100
raise ResultDoesNotExist(result_id) from e
93101

94-
def check(self, **kwargs: Any) -> List[CheckMessage]:
95-
if not apps.is_installed("django_tasks.backends.database"):
96-
backend_name = self.__class__.__name__
102+
def check(self, **kwargs: Any) -> Iterable[messages.CheckMessage]:
103+
yield from super().check(**kwargs)
97104

98-
return [
99-
CheckMessage(
100-
ERROR,
101-
f"{backend_name} configured as django_tasks backend, but database app not installed",
102-
"Insert 'django_tasks.backends.database' in INSTALLED_APPS",
103-
)
104-
]
105+
backend_name = self.__class__.__name__
105106

106-
return []
107+
if not apps.is_installed("django_tasks.backends.database"):
108+
yield messages.CheckMessage(
109+
messages.ERROR,
110+
f"{backend_name} configured as django_tasks backend, but database app not installed",
111+
"Insert 'django_tasks.backends.database' in INSTALLED_APPS",
112+
)
113+
114+
if self.enqueue_on_commit is False:
115+
yield messages.CheckMessage(
116+
messages.WARNING,
117+
f"{backend_name} must enqueue tasks at the end of a transaction",
118+
"Ensure ENQUEUE_ON_COMMIT is True or None",
119+
)

django_tasks/backends/dummy.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from copy import deepcopy
2+
from functools import partial
23
from typing import List, TypeVar
34
from uuid import uuid4
45

6+
from django.db import transaction
57
from django.utils import timezone
68
from typing_extensions import ParamSpec
79

@@ -41,8 +43,11 @@ def enqueue(
4143
backend=self.alias,
4244
)
4345

44-
# Copy the task to prevent mutation issues
45-
self.results.append(deepcopy(result))
46+
if self._get_enqueue_on_commit_for_task(task) is not False:
47+
# Copy the task to prevent mutation issues
48+
transaction.on_commit(partial(self.results.append, deepcopy(result)))
49+
else:
50+
self.results.append(deepcopy(result))
4651

4752
return result
4853

django_tasks/backends/immediate.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from functools import partial
12
from inspect import iscoroutinefunction
23
from typing import TypeVar
34
from uuid import uuid4
45

56
from asgiref.sync import async_to_sync
7+
from django.db import transaction
68
from django.utils import timezone
79
from typing_extensions import ParamSpec
810

@@ -18,34 +20,46 @@
1820
class ImmediateBackend(BaseTaskBackend):
1921
supports_async_task = True
2022

21-
def enqueue(
22-
self, task: Task[P, T], args: P.args, kwargs: P.kwargs
23-
) -> TaskResult[T]:
24-
self.validate_task(task)
25-
23+
def _execute_task(self, task_result: TaskResult) -> None:
24+
"""
25+
Execute the task for the given `TaskResult`, mutating it with the outcome
26+
"""
2627
calling_task_func = (
27-
async_to_sync(task.func) if iscoroutinefunction(task.func) else task.func
28+
async_to_sync(task_result.task.func)
29+
if iscoroutinefunction(task_result.task.func)
30+
else task_result.task.func
2831
)
2932

30-
enqueued_at = timezone.now()
3133
try:
32-
result = json_normalize(calling_task_func(*args, **kwargs))
33-
status = ResultStatus.COMPLETE
34+
task_result._result = json_normalize(
35+
calling_task_func(*task_result.args, **task_result.kwargs)
36+
)
37+
task_result.status = ResultStatus.COMPLETE
3438
except Exception:
35-
result = None
36-
status = ResultStatus.FAILED
39+
task_result._result = None
40+
task_result.status = ResultStatus.FAILED
41+
42+
task_result.finished_at = timezone.now()
43+
44+
def enqueue(
45+
self, task: Task[P, T], args: P.args, kwargs: P.kwargs
46+
) -> TaskResult[T]:
47+
self.validate_task(task)
3748

3849
task_result = TaskResult[T](
3950
task=task,
4051
id=str(uuid4()),
41-
status=status,
42-
enqueued_at=enqueued_at,
43-
finished_at=timezone.now(),
52+
status=ResultStatus.NEW,
53+
enqueued_at=timezone.now(),
54+
finished_at=None,
4455
args=json_normalize(args),
4556
kwargs=json_normalize(kwargs),
4657
backend=self.alias,
4758
)
4859

49-
task_result._result = result
60+
if self._get_enqueue_on_commit_for_task(task) is not False:
61+
transaction.on_commit(partial(self._execute_task, task_result))
62+
else:
63+
self._execute_task(task_result)
5064

5165
return task_result

django_tasks/checks.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, List, Sequence
1+
from typing import Any, Iterable, Sequence
22

33
from django.apps.config import AppConfig
44
from django.core.checks.messages import CheckMessage
@@ -8,16 +8,11 @@
88

99
def check_tasks(
1010
app_configs: Sequence[AppConfig] = None, **kwargs: Any
11-
) -> List[CheckMessage]:
11+
) -> Iterable[CheckMessage]:
1212
"""Checks all registered task backends."""
1313

14-
errors = []
1514
for backend in tasks.all():
1615
try:
17-
backend_errors = backend.check()
16+
yield from backend.check()
1817
except NotImplementedError:
1918
pass
20-
else:
21-
errors.extend(backend_errors)
22-
23-
return errors

django_tasks/task.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ class Task(Generic[P, T]):
5656
run_after: Optional[datetime] = None
5757
"""The earliest this task will run"""
5858

59+
enqueue_on_commit: Optional[bool] = None
60+
5961
def __post_init__(self) -> None:
6062
self.get_backend().validate_task(self)
6163

@@ -164,6 +166,7 @@ def task(
164166
priority: int = 0,
165167
queue_name: str = DEFAULT_QUEUE_NAME,
166168
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
169+
enqueue_on_commit: Optional[bool] = None,
167170
) -> Callable[[Callable[P, T]], Task[P, T]]: ...
168171

169172

@@ -174,6 +177,7 @@ def task(
174177
priority: int = 0,
175178
queue_name: str = DEFAULT_QUEUE_NAME,
176179
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
180+
enqueue_on_commit: Optional[bool] = None,
177181
) -> Union[Task[P, T], Callable[[Callable[P, T]], Task[P, T]]]:
178182
"""
179183
A decorator used to create a task.
@@ -182,7 +186,11 @@ def task(
182186

183187
def wrapper(f: Callable[P, T]) -> Task[P, T]:
184188
return tasks[backend].task_class(
185-
priority=priority, func=f, queue_name=queue_name, backend=backend
189+
priority=priority,
190+
func=f,
191+
queue_name=queue_name,
192+
backend=backend,
193+
enqueue_on_commit=enqueue_on_commit,
186194
)
187195

188196
if function:

tests/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,6 @@
5858

5959
USE_TZ = True
6060

61-
if sys.argv[0] != "test":
61+
if sys.argv[1] == "runserver":
6262
DEBUG = True
6363
TASKS = {"default": {"BACKEND": "django_tasks.backends.database.DatabaseBackend"}}

tests/tasks.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,13 @@ def failing_task() -> None:
2929
@task()
3030
def exit_task() -> None:
3131
exit(1)
32+
33+
34+
@task(enqueue_on_commit=True)
35+
def enqueue_on_commit_task() -> None:
36+
pass
37+
38+
39+
@task(enqueue_on_commit=False)
40+
def never_enqueue_on_commit_task() -> None:
41+
pass

0 commit comments

Comments
 (0)