Skip to content

Commit

Permalink
refactor: use model diff to get the previous state of a Routine
Browse files Browse the repository at this point in the history
This change decreases ~2x/4x hits on database after updating/creating a Routine
  • Loading branch information
Hudson Medeiros authored and lucasgomide committed Jun 7, 2023
1 parent 22aabc3 commit 057dcfe
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 21 deletions.
4 changes: 2 additions & 2 deletions django_cloud_tasks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.db import models, transaction
from django.utils import timezone

from drf_kit.models import ModelDiffMixin
from django_cloud_tasks import serializers
from django_cloud_tasks.field import TaskField

Expand Down Expand Up @@ -30,7 +30,7 @@ def add_routine(self, routine: dict) -> "Routine":
return self.routines.create(**routine)


class Routine(models.Model):
class Routine(models.Model, ModelDiffMixin):
class Statuses(models.TextChoices):
PENDING = ("pending", "Pending")
SCHEDULED = ("scheduled", "Scheduled")
Expand Down
15 changes: 6 additions & 9 deletions django_cloud_tasks/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
from django_cloud_tasks import models


def _is_status_changing(instance: models.Routine) -> bool:
if not instance.pk:
return False

current_routine = models.Routine.objects.get(pk=instance.pk)
return current_routine.status != instance.status
def _is_status_changing(instance: Model) -> bool:
previous_status, current_status = instance._diff.get("status", (None, None))
return previous_status != current_status


def enqueue_next_routines(instance: models.Routine):
Expand Down Expand Up @@ -60,7 +57,7 @@ def ensure_status_machine(sender, instance: models.Routine, **kwargs):
if not _is_status_changing(instance=instance):
return

current_routine = models.Routine.objects.get(pk=instance.pk)
previous_status, _ = instance._diff.get("status", (None, None))

statuses = models.Routine.Statuses
machine_statuses = {
Expand All @@ -74,5 +71,5 @@ def ensure_status_machine(sender, instance: models.Routine, **kwargs):
}
available_statuses = machine_statuses[instance.status]

if current_routine.status not in available_statuses:
raise ValidationError(f"Status update from '{current_routine.status}' to '{instance.status}' is not allowed")
if previous_status not in available_statuses:
raise ValidationError(f"Status update from '{previous_status}' to '{instance.status}' is not allowed")
8 changes: 8 additions & 0 deletions django_cloud_tasks/tasks/routine_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def process_routine(self, routine: models.Routine):
routine.status = models.Routine.Statuses.RUNNING
routine.save(update_fields=("attempt_count", "status", "updated_at"))

# we are adding this to re-instantiate this object due to
# a bug that are happening with _diff field from ModelDiffMixin.
# the complete method called below is triggering the ensure_status_machine
# with wrong previous_status. when we call complete(), we had previous status
# scheduled, but we just changed it to running. this was raising an error:
# changing from scheduled to complete is not allowed.
routine = models.Routine(**routine._dict)

try:
logger.info(f"Routine #{routine.pk} is running")
task_response = routine.task_class(metadata=self._metadata).sync(**routine.body)
Expand Down
20 changes: 10 additions & 10 deletions sample_project/sample_app/tests/tests_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class RoutineModelTest(TestCase):
def test_fail(self):
routine = factories.RoutineWithoutSignalFactory(status="running", output=None, ends_at=None)
error = {"error": "something went wrong"}
with self.assertNumQueries(4):
with self.assertNumQueries(1):
routine.fail(output=error)
routine.refresh_from_db()
self.assertEqual("failed", routine.status)
Expand All @@ -26,7 +26,7 @@ def test_fail(self):
def test_complete(self):
routine = factories.RoutineWithoutSignalFactory(status="running", output=None, ends_at=None)
output = {"id": 42}
with self.assertNumQueries(5):
with self.assertNumQueries(2):
routine.complete(output=output)
routine.refresh_from_db()
self.assertEqual("completed", routine.status)
Expand All @@ -38,7 +38,7 @@ def test_complete(self):
def test_enqueue(self):
routine = factories.RoutineFactory()
with patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") as task:
with self.assertNumQueries(6):
with self.assertNumQueries(3):
routine.enqueue()
routine.refresh_from_db()
self.assertEqual("scheduled", routine.status)
Expand All @@ -48,7 +48,7 @@ def test_enqueue(self):
def test_revert_completed_routine(self):
routine = factories.RoutineWithoutSignalFactory(status="completed", output="{'id': 42}")
with patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") as revert_task:
with self.assertNumQueries(6):
with self.assertNumQueries(3):
routine.revert()
routine.refresh_from_db()
self.assertEqual("reverting", routine.status)
Expand All @@ -73,7 +73,7 @@ def test_enqueue_next_routines_after_completed(self):
factories.RoutineVertexFactory(routine=first_routine, next_routine=third_routine)

with patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") as task:
with self.assertNumQueries(17):
with self.assertNumQueries(8):
first_routine.status = "completed"
first_routine.save()
calls = [call(routine_id=second_routine.pk), call(routine_id=third_routine.pk)]
Expand All @@ -92,7 +92,7 @@ def test_dont_enqueue_next_routines_after_completed_when_status_dont_change(self
factories.RoutineVertexFactory(routine=first_routine, next_routine=third_routine)

with patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") as task:
with self.assertNumQueries(3):
with self.assertNumQueries(1):
first_routine.status = "completed"
first_routine.save()
task.assert_not_called()
Expand All @@ -110,7 +110,7 @@ def test_enqueue_previously_routines_after_reverted(self):
factories.RoutineVertexFactory(routine=first_routine, next_routine=third_routine)

with patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") as task:
with self.assertNumQueries(11):
with self.assertNumQueries(5):
third_routine.status = "reverted"
third_routine.save()

Expand All @@ -129,7 +129,7 @@ def test_dont_enqueue_previously_routines_after_reverted_completed_when_status_d
factories.RoutineVertexFactory(routine=first_routine, next_routine=third_routine)

with patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") as task:
with self.assertNumQueries(3):
with self.assertNumQueries(1):
third_routine.status = "reverted"
third_routine.save()

Expand Down Expand Up @@ -174,7 +174,7 @@ def test_start_pipeline(self):
factories.RoutineVertexFactory(routine=first_routine, next_routine=second_routine)

with patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") as task:
with self.assertNumQueries(13):
with self.assertNumQueries(7):
pipeline.start()
calls = [call(routine_id=first_routine.pk), call(routine_id=another_first_routine.pk)]
task.assert_has_calls(calls, any_order=True)
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_revert_pipeline(self):
factories.RoutineVertexFactory(routine=first_routine, next_routine=second_routine)

with patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") as task:
with self.assertNumQueries(13):
with self.assertNumQueries(7):
pipeline.revert()
calls = [
call(routine_id=fourth_routine.pk),
Expand Down

0 comments on commit 057dcfe

Please sign in to comment.