Skip to content

Commit 1e3fcaa

Browse files
committed
[Canvas] Support special case of group(A.s() | group(B.s() | C.S()))
1 parent 759842a commit 1e3fcaa

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

celery/canvas.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from kombu.utils import cached_property, fxrange, reprcall, uuid
2222

2323
from celery._state import current_app, get_current_worker_task
24+
from celery.result import GroupResult
2425
from celery.utils.functional import (
2526
maybe_list, is_list, regen,
2627
chunks as _chunks,
@@ -368,6 +369,7 @@ def __init__(self, *tasks, **options):
368369
self, 'celery.chain', (), {'tasks': tasks}, **options
369370
)
370371
self.subtask_type = 'chain'
372+
self._frozen = None
371373

372374
def __call__(self, *args, **kwargs):
373375
if self.tasks:
@@ -387,17 +389,27 @@ def run(self, args=(), kwargs={}, group_id=None, chord=None,
387389
app = app or self.app
388390
args = (tuple(args) + tuple(self.args)
389391
if args and not self.immutable else self.args)
390-
tasks, results = self.prepare_steps(
391-
args, self.tasks, root_id, link_error, app,
392-
task_id, group_id, chord,
393-
)
392+
393+
try:
394+
tasks, results = self._frozen
395+
except (AttributeError, ValueError):
396+
tasks, results = self.prepare_steps(
397+
args, self.tasks, root_id, link_error, app,
398+
task_id, group_id, chord,
399+
)
394400
if results:
395401
# make sure we can do a link() and link_error() on a chain object.
396402
if link:
397403
tasks[-1].set(link=link)
398404
tasks[0].apply_async(**options)
399405
return results[-1]
400406

407+
def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
408+
_, results = self._frozen = self.prepare_steps(
409+
(), self.tasks, root_id, None, self.app, _id, group_id, chord,
410+
)
411+
return results[-1]
412+
401413
def prepare_steps(self, args, tasks,
402414
root_id=None, link_error=None, app=None,
403415
last_task_id=None, group_id=None, chord_body=None,
@@ -665,6 +677,16 @@ def apply_async(self, args=(), kwargs=None, add_to_parent=True,
665677
result = self.app.GroupResult(
666678
group_id, list(self._apply_tasks(tasks, producer, app, **options)),
667679
)
680+
681+
# - Special case of group(A.s() | group(B.s(), C.s()))
682+
# That is, group with single item that is a chain but the
683+
# last task in that chain is a group.
684+
#
685+
# We cannot actually support arbitrary GroupResults in chains,
686+
# but this special case we can.
687+
if len(result) == 1 and isinstance(result[0], GroupResult):
688+
result = result[0]
689+
668690
parent_task = get_current_worker_task()
669691
if add_to_parent and parent_task:
670692
parent_task.add_trail(result)

0 commit comments

Comments
 (0)