Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions contentcuration/contentcuration/tests/test_asynctask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import absolute_import

import json
import threading
import uuid

Expand Down Expand Up @@ -191,20 +190,46 @@ def test_only_create_async_task_creates_task_entry(self):
self.assertEquals(result, 42)
self.assertEquals(TaskResult.objects.filter(task_id=async_result.task_id).count(), 0)

def test_fetch_or_enqueue_task(self):
    # NOTE(review): this is the pre-PR version being removed by this diff.
    # It hand-built a TaskResult row (json-encoded kwargs) rather than
    # enqueuing a real task; the diff view truncates the rest of its body.
    expected_task = TaskResult.objects.create(
        task_id=uuid.uuid4().hex,  # hex string, matching the celery task_id format
        task_name=test_task.name,
        status=states.PENDING,  # simulates a not-yet-finished task
        user=self.user,
        task_kwargs=json.dumps({
            "is_test": True
        }),
    )
def test_enqueue_task_adds_result_with_necessary_info(self):
    """Enqueuing a task creates a TaskResult row recording the task name and encoded kwargs."""
    queued = test_task.enqueue(self.user, is_test=True)

    # Look up the backing result row; its absence is a test failure, not an error.
    result_row = TaskResult.objects.filter(task_id=queued.task_id).first()
    if result_row is None:
        self.fail('Missing task result')

    self.assertEqual(result_row.task_name, test_task.name)
    # The stored kwargs must match what the backend itself would encode.
    _, _, expected_kwargs = test_task.backend.encode_content({"is_test": True})
    self.assertEqual(result_row.task_kwargs, expected_kwargs)

def test_fetch_or_enqueue_task(self):
    """fetch_or_enqueue returns the already-queued task rather than starting a duplicate."""
    first = test_task.enqueue(self.user, is_test=True)
    second = test_task.fetch_or_enqueue(self.user, is_test=True)
    self.assertEqual(first.task_id, second.task_id)

def test_fetch_or_enqueue_task__channel_id(self):
    """A task enqueued with a UUID channel_id is matched when fetching with the same UUID."""
    cid = uuid.uuid4()
    queued = test_task.enqueue(self.user, channel_id=cid)
    fetched = test_task.fetch_or_enqueue(self.user, channel_id=cid)
    self.assertEqual(queued.task_id, fetched.task_id)

def test_fetch_or_enqueue_task__channel_id__hex(self):
    """A task enqueued with a hex-string channel_id is matched when fetching with the same hex string."""
    cid_hex = uuid.uuid4().hex
    queued = test_task.enqueue(self.user, channel_id=cid_hex)
    fetched = test_task.fetch_or_enqueue(self.user, channel_id=cid_hex)
    self.assertEqual(queued.task_id, fetched.task_id)

def test_fetch_or_enqueue_task__channel_id__hex_then_uuid(self):
    """Enqueuing with a hex channel_id then fetching with the UUID object still matches."""
    cid = uuid.uuid4()
    queued = test_task.enqueue(self.user, channel_id=cid.hex)
    # Fetch with the UUID form; normalization should make the two equivalent.
    fetched = test_task.fetch_or_enqueue(self.user, channel_id=cid)
    self.assertEqual(queued.task_id, fetched.task_id)

def test_fetch_or_enqueue_task__channel_id__uuid_then_hex(self):
    """Enqueuing with a UUID channel_id then fetching with its hex string still matches."""
    cid = uuid.uuid4()
    queued = test_task.enqueue(self.user, channel_id=cid)
    # Fetch with the hex form; normalization should make the two equivalent.
    fetched = test_task.fetch_or_enqueue(self.user, channel_id=cid.hex)
    self.assertEqual(queued.task_id, fetched.task_id)

def test_requeue_task(self):
existing_task_ids = requeue_test_task.find_ids()
self.assertEqual(len(existing_task_ids), 0)
Expand Down
28 changes: 21 additions & 7 deletions contentcuration/contentcuration/utils/celery/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ def fetch(self, task_id):
"""
return self.AsyncResult(task_id)

def fetch_match(self, task_id, **kwargs):
def _fetch_match(self, task_id, **kwargs):
"""
Gets the result object for a task, assuming it was called async, and ensures it was called with kwargs
Gets the result object for a task, assuming it was called async, and ensures it was called with kwargs;
assumes that the kwargs have already been decoded from their prepared (encoded) form
:param task_id: The hex task ID
:param kwargs: The kwargs the task was called with, which must match when fetching
:return: A CeleryAsyncResult
Expand All @@ -160,6 +161,12 @@ def fetch_match(self, task_id, **kwargs):
return async_result
return None

def _prepare_kwargs(self, kwargs):
return self.backend.encode({
key: value.hex if isinstance(value, uuid.UUID) else value
for key, value in kwargs.items()
})

def enqueue(self, user, **kwargs):
"""
Enqueues the task called with `kwargs`, and requires the user who wants to enqueue it. If `channel_id` is
Expand All @@ -176,19 +183,25 @@ def enqueue(self, user, **kwargs):
raise TypeError("All tasks must be assigned to a user.")

task_id = uuid.uuid4().hex
channel_id = kwargs.get("channel_id")
prepared_kwargs = self._prepare_kwargs(kwargs)
transcoded_kwargs = self.backend.decode(prepared_kwargs)
channel_id = transcoded_kwargs.get("channel_id")

logging.info(f"Enqueuing task:id {self.name}:{task_id} for user:channel {user.pk}:{channel_id} | {kwargs}")
logging.info(f"Enqueuing task:id {self.name}:{task_id} for user:channel {user.pk}:{channel_id} | {prepared_kwargs}")

# returns a CeleryAsyncResult
async_result = self.apply_async(
task_id=task_id,
kwargs=kwargs,
kwargs=transcoded_kwargs,
)

# ensure the result is saved to the backend (database)
self.backend.add_pending_result(async_result)

# after calling apply, we should have task result model, so get it and set our custom fields
task_result = get_task_model(self, task_id)
task_result.task_name = self.name
task_result.task_kwargs = prepared_kwargs
task_result.user = user
task_result.channel_id = channel_id
task_result.save()
Expand All @@ -207,9 +220,10 @@ def fetch_or_enqueue(self, user, **kwargs):
# if we're eagerly executing the task (synchronously), then we shouldn't check for an existing task because
# implementations probably aren't prepared to rely on an existing asynchronous task
if not self.app.conf.task_always_eager:
task_ids = self.find_incomplete_ids(**kwargs).order_by("date_created")[:1]
transcoded_kwargs = self.backend.decode(self._prepare_kwargs(kwargs))
task_ids = self.find_incomplete_ids(**transcoded_kwargs).order_by("date_created")[:1]
if task_ids:
async_result = self.fetch_match(task_ids[0], **kwargs)
async_result = self._fetch_match(task_ids[0], **transcoded_kwargs)
if async_result:
logging.info(f"Fetched matching task {self.name} for user {user.pk} with id {async_result.id} | {kwargs}")
return async_result
Expand Down