Skip to content

Commit

Permalink
feat: 2-step publication of models to Pub/Sub
Browse files Browse the repository at this point in the history
**Context**

ModelPublisherTask takes a django instance as one of its inputs. However, django models are mutable, so the task payload must be constructed immediately, even if we want to defer publishing. This is actually already the case: the sync() method does this (even though it is actually not affected by this) and the asap() method does this too.

**Issues**

The payload construction code is not isolated (sync() and asap() duplicate code as of now) and not uniform (asap() and push() have different APIs in a ModelPublisherTask, which violates a principle that applies to most, if not all, other Tasks). Also, in case we want to defer the sync() method to the end of the current django transaction, we have no way of doing so as there is no way to capture the payload and later send it using `transaction.on_commit`

**Solution**

- We unified the `push` and `asap` method signatures in ModelPublisherTask. That means you can now do `PublisherTask.asap(obj=model)` and also `PublisherTask.push(task_kwargs={'obj': model}, queue='queue-override'}`, for example. This task no longer differs from regular tasks on this regard.
- Two-step publication (which is required for Pub/Sub to work inside transactions, by deferring actual calls until the transaction commits) is now possible using the `ModelPublisherTask.prepare()` API. It captures the payload that needs to be sent at the call site, and the process can be finalized by calling either `asap()`, `sync()` or `push()` on the result of this API call. For example:

```python
def post_save(sender, instance):
   publication = ModelPublisherTask.prepare(obj=instance, event="created")

   transaction.on_commit(lambda: publication.sync())
```
  • Loading branch information
rodrigoalmeidaee authored and joaodaher committed Apr 22, 2024
1 parent 1798fed commit 5b45488
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 19 deletions.
64 changes: 51 additions & 13 deletions django_cloud_tasks/tasks/publisher_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import abc
from dataclasses import dataclass
from typing import Type

from cachetools.func import lru_cache
from django.apps import apps
Expand Down Expand Up @@ -66,27 +68,63 @@ def _app(self) -> DjangoCloudTasksAppConfig:
return apps.get_app_config("django_cloud_tasks")


@dataclass
class PreparedModelPublication:
"""Stores the information needed to publish a model to PubSub.
Because models are mutable objects, in case we don't want to publish the event right away,
we need to store the information needed to publish right away.
"""

task_klass: Type["ModelPublisherTask"]
message: dict
attributes: dict[str, str]
topic_name: str

def get_task_kwargs(self):
return {
"message": self.message,
"attributes": self.attributes,
"topic_name": self.topic_name,
}

def sync(self):
return self.task_klass().run(**self.get_task_kwargs())

def asap(self):
return self.push()

def push(self, **kwargs):
return self.task_klass._push_prepared(prepared=self, **kwargs)


class ModelPublisherTask(PublisherTask, abc.ABC):
# Just a specialized Task that publishes a Django model to PubSub
# Since it cannot accept any random parameters, all its signatures have fixed arguments
@classmethod
def sync(cls, obj: Model, **kwargs):
message = cls.build_message_content(obj=obj, **kwargs)
attributes = cls.build_message_attributes(obj=obj, **kwargs)
topic_name = cls.topic_name(obj=obj, **kwargs)
return cls().run(message=message, attributes=attributes, topic_name=topic_name)
return cls.prepare(obj=obj, **kwargs).sync()

@classmethod
def asap(cls, obj: Model, **kwargs):
message = cls.build_message_content(obj=obj, **kwargs)
attributes = cls.build_message_attributes(obj=obj, **kwargs)
topic_name = cls.topic_name(obj=obj, **kwargs)
task_kwargs = {
"message": message,
"attributes": attributes,
"topic_name": topic_name,
}
return cls.push(task_kwargs=task_kwargs)
return cls.prepare(obj=obj, **kwargs).asap()

@classmethod
def push(cls, task_kwargs: dict, **kwargs):
return cls.prepare(**task_kwargs).push(**kwargs)

@classmethod
def _push_prepared(cls, prepared: PreparedModelPublication, **kwargs):
return super().push(task_kwargs=prepared.get_task_kwargs(), **kwargs)

@classmethod
def prepare(cls, obj: Model, **kwargs):
return PreparedModelPublication(
task_klass=cls,
message=cls.build_message_content(obj=obj, **kwargs),
attributes=cls.build_message_attributes(obj=obj, **kwargs),
topic_name=cls.topic_name(obj=obj, **kwargs),
)

def run(
self, message: dict, topic_name: str, attributes: dict[str, str] | None, headers: dict[str, str] | None = None
Expand Down
104 changes: 98 additions & 6 deletions sample_project/sample_app/tests/tests_tasks/tests_tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from datetime import timedelta, datetime, UTC
from unittest.mock import patch
from unittest.mock import patch, ANY

from django.apps import apps
from django.test import SimpleTestCase, TestCase
Expand All @@ -14,17 +14,14 @@
from django_cloud_tasks.tasks.task import get_config
from django_cloud_tasks.tests import tests_base
from django_cloud_tasks.tests.tests_base import eager_tasks
from sample_app import tasks
from sample_app import models, tasks
from django.http import HttpRequest

from sample_app.tasks import MyMetadata


class TasksTest(SimpleTestCase):
class PatchOutputAndAuthMixin:
def setUp(self):
super().setUp()
Task._get_tasks_client.cache_clear()

patch_output = patch("django_cloud_tasks.tasks.TaskMetadata.from_task_obj")
patch_output.start()
self.addCleanup(patch_output.stop)
Expand All @@ -33,6 +30,12 @@ def setUp(self):
auth.start()
self.addCleanup(auth.stop)


class TasksTest(PatchOutputAndAuthMixin, SimpleTestCase):
def setUp(self):
super().setUp()
Task._get_tasks_client.cache_clear()

def tearDown(self):
super().tearDown()
Task._get_tasks_client.cache_clear()
Expand Down Expand Up @@ -397,3 +400,92 @@ def test_custom_class_unsupported(self):
self.assertRaisesRegex(ImportError, "must be a subclass of TaskMetadata"),
):
get_config("task_metadata_class")


class TestModelPublisherTask(PatchOutputAndAuthMixin, TestCase):
def setUp(self):
super().setUp()
patch_run = patch("django_cloud_tasks.tasks.ModelPublisherTask.run")
self.patched_run = patch_run.start()
self.addCleanup(patch_run.stop)

patch_push = patch("gcp_pilot.tasks.CloudTasks.push")
self.patched_push = patch_push.start()
self.addCleanup(patch_push.stop)

self.person = models.Person(name="Harry Potter", pk=1)
self.expected_task_kwargs = dict(
message={"id": 1, "name": "Harry Potter"},
topic_name="sample_app-person",
attributes={"any-custom-attribute": "yay!", "event": "saved"},
)

def test_sync_forward_correct_parameters(self):
tasks.PublishPersonTask.sync(obj=self.person, event="saved")
self.patched_run.assert_called_once_with(**self.expected_task_kwargs)

def test_asap_forward_correct_parameters(self):
tasks.PublishPersonTask.asap(obj=self.person, event="saved")
self.patched_push.assert_called_once_with(**self._build_expected_push(payload=self.expected_task_kwargs))

def test_push_forward_correct_parameters(self):
tasks.PublishPersonTask.push(
{
"obj": self.person,
"event": "saved",
},
queue="tasks--low",
)
self.patched_push.assert_called_once_with(
**self._build_expected_push(
payload=self.expected_task_kwargs,
queue_name="tasks--low",
)
)

def test_delayed_sync_forward_correct_parameters(self):
prepared = tasks.PublishPersonTask.prepare(obj=self.person, event="saved")
self._mutate_person()

prepared.sync()
self.patched_run.assert_called_once_with(**self.expected_task_kwargs)

def test_delayed_asap_forward_correct_parameters(self):
prepared = tasks.PublishPersonTask.prepare(obj=self.person, event="saved")
self._mutate_person()

prepared.asap()
self.patched_push.assert_called_once_with(**self._build_expected_push(payload=self.expected_task_kwargs))

def test_delayed_push_forward_correct_parameters(self):
prepared = tasks.PublishPersonTask.prepare(obj=self.person, event="saved")
self._mutate_person()

prepared.push(queue="tasks--low")
self.patched_push.assert_called_once_with(
**self._build_expected_push(
payload=self.expected_task_kwargs,
queue_name="tasks--low",
)
)

def _mutate_person(self):
self.person.pk = None
self.person.name = "Mutated Name"

def _build_expected_push(self, payload: dict, **kwargs) -> dict:
return (
dict(
queue_name="tasks",
url=ANY,
headers=ANY,
payload=json.dumps(
dict(
message=payload["message"],
attributes=payload["attributes"],
topic_name=payload["topic_name"],
)
),
)
| kwargs
)

0 comments on commit 5b45488

Please sign in to comment.