Skip to content

Commit

Permalink
Add pre and post task hooks.
Browse files Browse the repository at this point in the history
  • Loading branch information
james-certn committed Jun 20, 2024
1 parent 9a8facd commit 71f53be
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 1 deletion.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,29 @@ JOBS = {
}
```

#### Pre & Post Task Hooks
You can also run pre task or post task hooks, which happen in the normal processing of your `Job` instances and are executed in the worker process.

Both pre and post task hooks receive your `Job` instance as their only argument. Here's an example:

```python
def my_pre_task_hook(job):
... # configure something before running your task
```

To ensure these hooks gets run, simply add a `pre_task_hook` or `post_task_hook` key (or both, if needed) to your job config like so:

```python
JOBS = {
"my_job": {
"tasks": ["project.common.jobs.my_task"],
"pre_task_hook": "project.common.jobs.my_pre_task_hook",
"post_task_hook": "project.common.jobs.my_post_task_hook",
},
}
```


### Start the worker

In another terminal:
Expand Down
2 changes: 1 addition & 1 deletion django_dbq/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.1.0"
__version__ = "3.2.0"
4 changes: 4 additions & 0 deletions django_dbq/management/commands/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def _process_job(self):
if not job:
return

job.run_pre_task_hook()

logger.info(
'Processing job: name="%s" queue="%s" id=%s state=%s next_task=%s',
job.name,
Expand Down Expand Up @@ -109,6 +111,8 @@ def _process_job(self):
logger.exception("Failed to save job: id=%s", job.pk)
raise

job.run_post_task_hook()

self.current_job = None


Expand Down
22 changes: 22 additions & 0 deletions django_dbq/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from django.utils.module_loading import import_string
from django_dbq.tasks import (
get_next_task_name,
get_pre_task_hook_name,
get_post_task_hook_name,
get_failure_hook_name,
get_creation_hook_name,
)
Expand Down Expand Up @@ -126,12 +128,32 @@ def save(self, *args, **kwargs):
def update_next_task(self):
self.next_task = get_next_task_name(self.name, self.next_task) or ""

def get_pre_task_hook_name(self):
return get_pre_task_hook_name(self.name)

def get_post_task_hook_name(self):
return get_post_task_hook_name(self.name)

def get_failure_hook_name(self):
return get_failure_hook_name(self.name)

def get_creation_hook_name(self):
return get_creation_hook_name(self.name)

def run_pre_task_hook(self):
pre_task_hook_name = self.get_pre_task_hook_name()
if pre_task_hook_name:
logger.info("Running pre_task hook %s for new job", pre_task_hook_name)
pre_task_hook_function = import_string(pre_task_hook_name)
pre_task_hook_function(self)

def run_post_task_hook(self):
post_task_hook_name = self.get_post_task_hook_name()
if post_task_hook_name:
logger.info("Running post_task hook %s for new job", post_task_hook_name)
post_task_hook_function = import_string(post_task_hook_name)
post_task_hook_function(self)

def run_creation_hook(self):
creation_hook_name = self.get_creation_hook_name()
if creation_hook_name:
Expand Down
12 changes: 12 additions & 0 deletions django_dbq/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@


TASK_LIST_KEY = "tasks"
PRE_TASK_HOOK_KEY = "pre_task_hook"
POST_TASK_HOOK_KEY = "post_task_hook"
FAILURE_HOOK_KEY = "failure_hook"
CREATION_HOOK_KEY = "creation_hook"

Expand All @@ -24,6 +26,16 @@ def get_next_task_name(job_name, current_task=None):
return None


def get_pre_task_hook_name(job_name):
"""Return the name of the pre task hook for the given job (as a string) or None"""
return settings.JOBS[job_name].get(PRE_TASK_HOOK_KEY)


def get_post_task_hook_name(job_name):
"""Return the name of the post_task hook for the given job (as a string) or None"""
return settings.JOBS[job_name].get(POST_TASK_HOOK_KEY)


def get_failure_hook_name(job_name):
"""Return the name of the failure hook for the given job (as a string) or None"""
return settings.JOBS[job_name].get(FAILURE_HOOK_KEY)
Expand Down
52 changes: 52 additions & 0 deletions django_dbq/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,25 @@ def failing_task(job):
raise Exception("uh oh")


def pre_task_hook(job):
job.workspace["output"] = "pre task hook ran"
job.workspace["job_id"] = str(job.id)


def post_task_hook(job):
job.workspace["output"] = "post task hook ran"
job.workspace["job_id"] = str(job.id)


def failure_hook(job, exception):
job.workspace["output"] = "failure hook ran"
job.workspace["exception"] = str(exception)
job.workspace["job_id"] = str(job.id)


def creation_hook(job):
job.workspace["output"] = "creation hook ran"
job.workspace["job_id"] = str(job.id)


@override_settings(JOBS={"testjob": {"tasks": ["a"]}})
Expand Down Expand Up @@ -316,6 +329,7 @@ def test_creation_hook(self):
job = Job.objects.create(name="testjob")
job = Job.objects.get()
self.assertEqual(job.workspace["output"], "creation hook ran")
self.assertEqual(job.workspace["job_id"], str(job.id))

def test_creation_hook_only_runs_on_create(self):
job = Job.objects.create(name="testjob")
Expand All @@ -326,6 +340,42 @@ def test_creation_hook_only_runs_on_create(self):
self.assertEqual(job.workspace["output"], "creation hook output removed")


@override_settings(
JOBS={
"testjob": {
"tasks": ["django_dbq.tests.failing_task"],
"pre_task_hook": "django_dbq.tests.pre_task_hook",
}
}
)
class JobPreTaskHookTestCase(TestCase):
def test_pre_task_hook(self):
job = Job.objects.create(name="testjob")
Worker("default", 1)._process_job()
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.FAILED)
self.assertEqual(job.workspace["output"], "failure hook ran")
self.assertEqual(job.workspace["job_id"], str(job.id))


@override_settings(
JOBS={
"testjob": {
"tasks": ["django_dbq.tests.failing_task"],
"post_task_hook": "django_dbq.tests.post_task_hook",
}
}
)
class JobPostTaskHookTestCase(TestCase):
def test_post_task_hook(self):
job = Job.objects.create(name="testjob")
Worker("default", 1)._process_job()
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.FAILED)
self.assertEqual(job.workspace["output"], "post task hook ran")
self.assertEqual(job.workspace["job_id"], str(job.id))


@override_settings(
JOBS={
"testjob": {
Expand All @@ -341,6 +391,8 @@ def test_failure_hook(self):
job = Job.objects.get()
self.assertEqual(job.state, Job.STATES.FAILED)
self.assertEqual(job.workspace["output"], "failure hook ran")
self.assertIn("uh oh", job.workspace["exception"])
self.assertEqual(job.workspace["job_id"], str(job.id))


@override_settings(JOBS={"testjob": {"tasks": ["a"]}})
Expand Down

0 comments on commit 71f53be

Please sign in to comment.