Skip to content

Commit 12f3942

Browse files
Track which worker executed a task (#136)
1 parent 98493f0 commit 12f3942

File tree

10 files changed

+155
-23
lines changed

10 files changed

+155
-23
lines changed

django_tasks/backends/database/management/commands/db_worker.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from django_tasks.exceptions import InvalidTaskBackendError
2424
from django_tasks.signals import task_finished, task_started
2525
from django_tasks.task import DEFAULT_QUEUE_NAME
26+
from django_tasks.utils import get_random_id
2627

2728
package_logger = logging.getLogger("django_tasks")
2829
logger = logging.getLogger("django_tasks.backends.database.db_worker")
@@ -38,6 +39,7 @@ def __init__(
3839
backend_name: str,
3940
startup_delay: bool,
4041
max_tasks: Optional[int],
42+
worker_id: str,
4143
):
4244
self.queue_names = queue_names
4345
self.process_all_queues = "*" in queue_names
@@ -51,6 +53,8 @@ def __init__(
5153
self.running_task = False
5254
self._run_tasks = 0
5355

56+
self.worker_id = worker_id
57+
5458
def shutdown(self, signum: int, frame: Optional[FrameType]) -> None:
5559
if not self.running:
5660
logger.warning(
@@ -82,11 +86,17 @@ def reset_signals(self) -> None:
8286
if hasattr(signal, "SIGQUIT"):
8387
signal.signal(signal.SIGQUIT, signal.SIG_DFL)
8488

85-
def start(self) -> None:
86-
logger.info("Starting worker for queues=%s", ",".join(self.queue_names))
89+
def run(self) -> None:
90+
self.configure_signals()
91+
92+
logger.info(
93+
"Starting worker worker_id=%s queues=%s",
94+
self.worker_id,
95+
",".join(self.queue_names),
96+
)
8797

8898
if self.startup_delay and self.interval:
89-
# Add a random small delay before starting the loop to avoid a thundering herd
99+
# Add a random small delay before starting to avoid a thundering herd
90100
time.sleep(random.random())
91101

92102
while self.running:
@@ -109,19 +119,24 @@ def start(self) -> None:
109119

110120
if task_result is not None:
111121
# "claim" the task, so it isn't run by another worker process
112-
task_result.claim()
122+
task_result.claim(self.worker_id)
113123

114124
if task_result is not None:
115125
self.run_task(task_result)
116126

117127
if self.batch and task_result is None:
118128
# If we're running in "batch" mode, terminate the loop (and thus the worker)
119-
logger.info("No more tasks to run - exiting gracefully.")
129+
logger.info(
130+
"No more tasks to run for worker_id=%s - exiting gracefully.",
131+
self.worker_id,
132+
)
120133
return None
121134

122135
if self.max_tasks is not None and self._run_tasks >= self.max_tasks:
123136
logger.info(
124-
"Run maximum tasks (%d) - exiting gracefully.", self._run_tasks
137+
"Run maximum tasks (%d) on worker=%s - exiting gracefully.",
138+
self._run_tasks,
139+
self.worker_id,
125140
)
126141
return None
127142

@@ -199,6 +214,14 @@ def valid_max_tasks(val: str) -> int:
199214
return num
200215

201216

217+
def validate_worker_id(val: str) -> str:
218+
if not val:
219+
raise ArgumentTypeError("Worker id must not be empty")
220+
if len(val) > 64:
221+
raise ArgumentTypeError("Worker ids must be shorter than 64 characters")
222+
return val
223+
224+
202225
class Command(BaseCommand):
203226
help = "Run a database background worker"
204227

@@ -249,6 +272,13 @@ def add_arguments(self, parser: ArgumentParser) -> None:
249272
type=valid_max_tasks,
250273
help="If provided, the maximum number of tasks the worker will execute before exiting.",
251274
)
275+
parser.add_argument(
276+
"--worker-id",
277+
nargs="?",
278+
type=validate_worker_id,
279+
help="Worker id. MUST be unique across worker pool (default: auto-generate)",
280+
default=get_random_id(),
281+
)
252282

253283
def configure_logging(self, verbosity: int) -> None:
254284
if verbosity == 0:
@@ -274,6 +304,7 @@ def handle(
274304
startup_delay: bool,
275305
reload: bool,
276306
max_tasks: Optional[int],
307+
worker_id: str,
277308
**options: dict,
278309
) -> None:
279310
self.configure_logging(verbosity)
@@ -291,14 +322,15 @@ def handle(
291322
backend_name=backend_name,
292323
startup_delay=startup_delay,
293324
max_tasks=max_tasks,
325+
worker_id=worker_id,
294326
)
295327

296328
if reload:
297329
if os.environ.get(DJANGO_AUTORELOAD_ENV) == "true":
298330
# Only the child process should configure its signals
299331
worker.configure_signals()
300332

301-
run_with_reloader(worker.start)
333+
run_with_reloader(worker.run)
302334
else:
303335
worker.configure_signals()
304-
worker.start()
336+
worker.run()
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Generated by Django 5.1.6 on 2025-02-16 15:55
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
(
9+
"django_tasks_database",
10+
"0017_remove_dbtaskresult_django_task_new_ordering_idx_and_more",
11+
),
12+
]
13+
14+
operations = [
15+
migrations.AddField(
16+
model_name="dbtaskresult",
17+
name="worker_id",
18+
field=models.CharField(default="", max_length=64, verbose_name="worker id"),
19+
),
20+
]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Generated by Django 4.2 on 2025-06-20 12:03
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("django_tasks_database", "0018_dbtaskresult_worker_id"),
9+
]
10+
11+
operations = [
12+
migrations.RemoveField(
13+
model_name="dbtaskresult",
14+
name="worker_id",
15+
),
16+
migrations.AddField(
17+
model_name="dbtaskresult",
18+
name="worker_ids",
19+
field=models.JSONField(default=list, verbose_name="worker id"),
20+
),
21+
]

django_tasks/backends/database/models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class DBTaskResult(GenericBase[P, T], models.Model):
9999
priority = models.IntegerField(_("priority"), default=DEFAULT_PRIORITY)
100100

101101
task_path = models.TextField(_("task path"))
102+
worker_ids = models.JSONField(_("worker id"), default=list)
102103

103104
queue_name = models.CharField(
104105
_("queue name"), default=DEFAULT_QUEUE_NAME, max_length=32
@@ -177,6 +178,7 @@ def task_result(self) -> "TaskResult[T]":
177178
kwargs=self.args_kwargs["kwargs"],
178179
backend=self.backend_name,
179180
errors=[],
181+
worker_ids=self.worker_ids,
180182
)
181183

182184
if self.status == ResultStatus.FAILED:
@@ -206,13 +208,14 @@ def task_name(self) -> str:
206208
return self.task_path
207209

208210
@retry(backoff_delay=0)
209-
def claim(self) -> None:
211+
def claim(self, worker_id: str) -> None:
210212
"""
211213
Mark as job as being run
212214
"""
213215
self.status = ResultStatus.RUNNING
214216
self.started_at = timezone.now()
215-
self.save(update_fields=["status", "started_at"])
217+
self.worker_ids = [*self.worker_ids, worker_id]
218+
self.save(update_fields=["status", "started_at", "worker_ids"])
216219

217220
@retry()
218221
def set_succeeded(self, return_value: Any) -> None:

django_tasks/backends/dummy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def enqueue(
5252
kwargs=kwargs,
5353
backend=self.alias,
5454
errors=[],
55+
worker_ids=[],
5556
)
5657

5758
if self._get_enqueue_on_commit_for_task(task) is not False:

django_tasks/backends/immediate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
class ImmediateBackend(BaseTaskBackend):
3030
supports_async_task = True
3131

32+
def __init__(self, alias: str, params: dict):
33+
super().__init__(alias, params)
34+
35+
self.worker_id = get_random_id()
36+
3237
def _execute_task(self, task_result: TaskResult) -> None:
3338
"""
3439
Execute the task for the given `TaskResult`, mutating it with the outcome
@@ -45,6 +50,7 @@ def _execute_task(self, task_result: TaskResult) -> None:
4550
object.__setattr__(task_result, "status", ResultStatus.RUNNING)
4651
object.__setattr__(task_result, "started_at", timezone.now())
4752
object.__setattr__(task_result, "last_attempted_at", timezone.now())
53+
task_result.worker_ids.append(self.worker_id)
4854
task_started.send(sender=type(self), task_result=task_result)
4955

5056
try:
@@ -98,6 +104,7 @@ def enqueue(
98104
kwargs=kwargs,
99105
backend=self.alias,
100106
errors=[],
107+
worker_ids=[],
101108
)
102109

103110
if self._get_enqueue_on_commit_for_task(task) is not False:

django_tasks/backends/rq.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ class Job(BaseJob):
5353
def perform(self) -> Any:
5454
task_result = self.into_task_result()
5555

56+
assert self.worker_name is not None
57+
self.meta.setdefault("_django_tasks_worker_ids", []).append(self.worker_name)
58+
self.save_meta() # type: ignore[no-untyped-call]
59+
5660
task_started.send(type(task_result.task.get_backend()), task_result=task_result)
5761

5862
return super().perform()
@@ -103,6 +107,7 @@ def into_task_result(self) -> TaskResult:
103107
kwargs=self.kwargs,
104108
backend=self.meta["backend_name"],
105109
errors=[],
110+
worker_ids=self.meta.get("_django_tasks_worker_ids", []),
106111
)
107112

108113
exception_classes = self.meta.get("_django_tasks_exceptions", []).copy()
@@ -188,6 +193,7 @@ def enqueue(
188193
kwargs=kwargs,
189194
backend=self.alias,
190195
errors=[],
196+
worker_ids=[],
191197
)
192198

193199
job = queue.create_job(
@@ -209,7 +215,6 @@ def save_result() -> None:
209215
job = queue.schedule_job(job, task.run_after)
210216

211217
object.__setattr__(task_result, "enqueued_at", job.enqueued_at)
212-
213218
task_enqueued.send(type(self), task_result=task_result)
214219

215220
if self._get_enqueue_on_commit_for_task(task):

django_tasks/task.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"last_attempted_at",
4343
"status",
4444
"enqueued_at",
45+
"worker_ids",
4546
}
4647

4748

@@ -283,6 +284,9 @@ class TaskResult(Generic[T]):
283284
errors: list[TaskError]
284285
"""The errors raised when running the task"""
285286

287+
worker_ids: list[str]
288+
"""The workers which have processed the task"""
289+
286290
_return_value: Optional[T] = field(init=False, default=None)
287291

288292
@property
@@ -307,12 +311,7 @@ def is_finished(self) -> bool:
307311

308312
@property
309313
def attempts(self) -> int:
310-
attempts = len(self.errors)
311-
312-
if self.status == ResultStatus.SUCCEEDED:
313-
attempts += 1
314-
315-
return attempts
314+
return len(self.worker_ids)
316315

317316
def refresh(self) -> None:
318317
"""

django_tasks/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def get_exception_traceback(exc: BaseException) -> str:
6666

6767
def get_random_id() -> str:
6868
"""
69-
Return a random string for use as a task id.
69+
Return a random string for use as a task or worker id.
7070
7171
Whilst 64 characters is the max, just use 32 as a sensible middle-ground.
7272

0 commit comments

Comments
 (0)