Skip to content

Commit

Permalink
feat: allow customizing TaskMetadata class
Browse files Browse the repository at this point in the history
  • Loading branch information
joaodaher committed Apr 4, 2024
1 parent 56e1f80 commit 9d6f403
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 9 deletions.
25 changes: 24 additions & 1 deletion django_cloud_tasks/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from django_cloud_tasks import exceptions


PREFIX = "DJANGO_CLOUD_TASKS_"
DEFAULT_PROPAGATION_HEADERS = ["traceparent"]
DEFAULT_PROPAGATION_HEADERS_KEY = "_http_headers"
Expand Down Expand Up @@ -47,6 +46,10 @@ def __init__(self, *args, **kwargs):
name="PROPAGATED_HEADERS_KEY", default=DEFAULT_PROPAGATION_HEADERS_KEY
)

@property
def task_metadata_class(self):
return self.get_task_metadata_class()

def get_tasks(self, only_subscriber: bool = False, only_periodic: bool = False, only_demand: bool = False):
all_tasks = {
"demand": list(self.on_demand_tasks.values()),
Expand Down Expand Up @@ -80,6 +83,26 @@ def get_backup_queue_name(self, original_name: str) -> str:
default=f"{original_name}{self.delimiter}temp",
)

def get_task_metadata_class(self):
from django_cloud_tasks.tasks import TaskMetadata

metadata_class_name = self._fetch_config(
name="TASK_METADATA_CLASS",
default="django_cloud_tasks.tasks.task.TaskMetadata",
)

try:
module_name, class_name = metadata_class_name.rsplit(".", 1)
module = __import__(module_name, fromlist=[class_name])
metadata_class = getattr(module, class_name)
except (AttributeError, ImportError, ValueError) as err:
raise ImportError(f"Unable to import {metadata_class_name}") from err

if not issubclass(metadata_class, TaskMetadata):
raise ImportError(f"Class {metadata_class_name} must be a subclass of TaskMetadata")

return metadata_class

def _fetch_config(self, name: str, default: Any, as_list: bool = False) -> Any:
config_name = f"{PREFIX}{name.upper()}"

Expand Down
5 changes: 3 additions & 2 deletions django_cloud_tasks/tasks/periodic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from gcp_pilot.scheduler import CloudScheduler

from django_cloud_tasks.serializers import deserialize, serialize
from django_cloud_tasks.tasks.task import Task, get_config, TaskMetadata
from django_cloud_tasks.tasks.task import Task, get_config


class PeriodicTask(Task, abc.ABC):
Expand All @@ -15,7 +15,8 @@ def schedule(cls, **kwargs):
payload = serialize(kwargs)

if cls.eager():
eager_metadata = TaskMetadata.build_eager(task_class=cls)
task_metadata_class = get_config(name="task_metadata_class")
eager_metadata = task_metadata_class.build_eager(task_class=cls)
return cls(metadata=eager_metadata).run(**deserialize(value=payload))

return cls._get_scheduler_client().put(
Expand Down
8 changes: 6 additions & 2 deletions django_cloud_tasks/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TaskMetadata:

def __post_init__(self):
self.custom_headers = get_current_headers()
self._max_attempts = None

@classmethod
def from_headers(cls, headers: dict) -> Self:
Expand Down Expand Up @@ -301,14 +302,17 @@ def push(
api_kwargs["queue_name"] = backup_queue_name
outcome = cls._get_tasks_client().push(**api_kwargs)

return TaskMetadata.from_task_obj(task_obj=outcome)
task_metadata_class = get_config(name="task_metadata_class")
return task_metadata_class.from_task_obj(task_obj=outcome)

@classmethod
def debug(cls, task_id: str):
client = cls._get_tasks_client()
task_obj = client.get_task(queue_name=cls.queue(), task_name=task_id)
task_kwargs = json.loads(task_obj.http_request.body)
metadata = TaskMetadata.from_task_obj(task_obj=task_obj)

task_metadata_class = get_config(name="task_metadata_class")
metadata = task_metadata_class.from_task_obj(task_obj=task_obj)
return cls(metadata=metadata).run(**task_kwargs)

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions django_cloud_tasks/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from django_cloud_tasks.exceptions import TaskNotFound
from django_cloud_tasks.serializers import deserialize
from django_cloud_tasks.tasks import Task, SubscriberTask
from django_cloud_tasks.tasks.task import TaskMetadata

from django_cloud_tasks.tasks.task import TaskMetadata, get_config

logger = logging.getLogger("django_cloud_tasks")

Expand Down Expand Up @@ -53,7 +52,8 @@ def parse_input(self, request, task_class: Type[Task]) -> dict:
return deserialize(value=request.body)

def parse_metadata(self, request) -> TaskMetadata:
return TaskMetadata.from_headers(headers=dict(request.headers))
task_metadata_class = get_config(name="task_metadata_class")
return task_metadata_class.from_headers(headers=dict(request.headers))


# More info: https://cloud.google.com/pubsub/docs/push#receiving_messages
Expand Down
8 changes: 7 additions & 1 deletion sample_project/sample_app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.db.models import Model

from django_cloud_tasks.tasks import PeriodicTask, RoutineTask, SubscriberTask, Task, ModelPublisherTask
from django_cloud_tasks.tasks import PeriodicTask, RoutineTask, SubscriberTask, Task, ModelPublisherTask, TaskMetadata


class BaseAbstractTask(Task, abc.ABC):
Expand Down Expand Up @@ -104,3 +104,9 @@ def run(self, **kwargs): ...

@classmethod
def revert(cls, **kwargs): ...


class MyMetadata(TaskMetadata): ...


class MyUnsupportedMetadata: ...
23 changes: 23 additions & 0 deletions sample_project/sample_app/tests/tests_tasks/tests_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@

from django_cloud_tasks import exceptions
from django_cloud_tasks.tasks import Task, TaskMetadata, is_task_route
from django_cloud_tasks.tasks.task import get_config
from django_cloud_tasks.tests import tests_base
from django_cloud_tasks.tests.tests_base import eager_tasks
from sample_app import tasks
from django.http import HttpRequest

from sample_app.tasks import MyMetadata


class TasksTest(SimpleTestCase):
def setUp(self):
Expand Down Expand Up @@ -374,3 +377,23 @@ def test_comparable(self):

not_metadata = True
self.assertNotEqual(reference, not_metadata)

def test_custom_class(self):
with self.settings(DJANGO_CLOUD_TASKS_TASK_METADATA_CLASS="sample_app.tasks.MyMetadata"):
metadata_class = get_config("task_metadata_class")

self.assertTrue(issubclass(metadata_class, MyMetadata))

def test_custom_class_not_found(self):
with (
self.settings(DJANGO_CLOUD_TASKS_TASK_METADATA_CLASS="potato.tasks.MyMetadata"),
self.assertRaisesRegex(ImportError, "Unable to import"),
):
get_config("task_metadata_class")

def test_custom_class_unsupported(self):
with (
self.settings(DJANGO_CLOUD_TASKS_TASK_METADATA_CLASS="sample_app.tasks.MyUnsupportedMetadata"),
self.assertRaisesRegex(ImportError, "must be a subclass of TaskMetadata"),
):
get_config("task_metadata_class")

0 comments on commit 9d6f403

Please sign in to comment.