diff --git a/api/task_processor/models.py b/api/task_processor/models.py index 610149ca6391..e2f6ccb10ccd 100644 --- a/api/task_processor/models.py +++ b/api/task_processor/models.py @@ -44,9 +44,12 @@ def deserialize_data(data: typing.Any): return json.loads(data) def mark_failure(self): - self.is_locked = False + self.unlock() def mark_success(self): + self.unlock() + + def unlock(self): self.is_locked = False def run(self): @@ -119,7 +122,7 @@ def mark_failure(self): self.num_failures += 1 def mark_success(self): - super().mark_failure() + super().mark_success() self.completed = True diff --git a/api/task_processor/processor.py b/api/task_processor/processor.py index 7896128d83ff..a5f6e73afd7b 100644 --- a/api/task_processor/processor.py +++ b/api/task_processor/processor.py @@ -45,7 +45,7 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]: return [] -def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTask]: +def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]: if num_tasks < 1: raise ValueError("Number of tasks to process must be at least one") @@ -55,7 +55,6 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTask]: tasks = RecurringTask.objects.get_tasks_to_process(num_tasks) if tasks: task_runs = [] - executed_tasks = [] for task in tasks: # Remove the task if it's not registered anymore @@ -65,11 +64,13 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTask]: if task.should_execute: task, task_run = _run_task(task) - executed_tasks.append(task) task_runs.append(task_run) + else: + task.unlock() - if executed_tasks: - RecurringTask.objects.bulk_update(executed_tasks, fields=["is_locked"]) + # update all tasks that were not deleted + to_update = [task for task in tasks if task.id] + RecurringTask.objects.bulk_update(to_update, fields=["is_locked"]) if task_runs: RecurringTaskRun.objects.bulk_create(task_runs) @@ -80,7 +81,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTask]: return [] -def _run_task(task: Task) -> typing.Optional[typing.Tuple[Task, TaskRun]]: +def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, TaskRun]: task_run = task.task_runs.model(started_at=timezone.now(), task=task) try: diff --git a/api/tests/unit/task_processor/test_unit_task_processor_processor.py b/api/tests/unit/task_processor/test_unit_task_processor_processor.py index bb65ec64e5d1..154daedc4bbc 100644 --- a/api/tests/unit/task_processor/test_unit_task_processor_processor.py +++ b/api/tests/unit/task_processor/test_unit_task_processor_processor.py @@ -4,6 +4,7 @@ from threading import Thread import pytest +from django.utils import timezone from organisations.models import Organisation from task_processor.decorators import ( @@ -80,7 +81,7 @@ def test_run_recurring_tasks_multiple_runs(db, run_by_processor): task_identifier = "test_unit_task_processor_processor._create_organisation" @register_recurring_task( - run_every=timedelta(milliseconds=100), args=(organisation_name,) + run_every=timedelta(milliseconds=200), args=(organisation_name,) ) def _create_organisation(organisation_name): Organisation.objects.create(name=organisation_name) @@ -89,17 +90,28 @@ def _create_organisation(organisation_name): # When first_task_runs = run_recurring_tasks() - time.sleep(0.2) - second_task_runs = run_recurring_tasks() + # run the process again before the task is scheduled to run again to ensure + # that tasks are unlocked when they are picked up by the task processor but + # not executed. + no_task_runs = run_recurring_tasks() - task_runs = first_task_runs + second_task_runs + time.sleep(0.3) + + second_task_runs = run_recurring_tasks() # Then + assert len(first_task_runs) == 1 + assert len(no_task_runs) == 0 + assert len(second_task_runs) == 1 + + # we should still only have 2 organisations, despite executing the + # `run_recurring_tasks` function 3 times. assert Organisation.objects.filter(name=organisation_name).count() == 2 - assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 2 - for task_run in task_runs: + all_task_runs = first_task_runs + second_task_runs + assert len(all_task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 2 + for task_run in all_task_runs: assert task_run.result == TaskResult.SUCCESS assert task_run.started_at assert task_run.finished_at @@ -322,6 +334,37 @@ def test_run_more_than_one_task(db): assert task.completed +def test_recurring_tasks_are_unlocked_if_picked_up_but_not_executed( + db, run_by_processor +): + # Given + @register_recurring_task(run_every=timedelta(days=1)) + def my_task(): + pass + + recurring_task = RecurringTask.objects.get( + task_identifier="test_unit_task_processor_processor.my_task" + ) + + # mimic the task having already been run so that it is next picked up, + # but not executed + now = timezone.now() + one_minute_ago = now - timedelta(minutes=1) + RecurringTaskRun.objects.create( + task=recurring_task, + started_at=one_minute_ago, + finished_at=now, + result=TaskResult.SUCCESS.name, + ) + + # When + run_recurring_tasks() + + # Then + recurring_task.refresh_from_db() + assert recurring_task.is_locked is False + + @register_task_handler() def _create_organisation(name: str): """function used to test that task is being run successfully"""