Skip to content

Commit 71b3ca3

Browse files
committed
Run event loop after job application
1 parent 3901cb7 commit 71b3ca3

File tree

1 file changed

+31
-36
lines changed

1 file changed

+31
-36
lines changed

temporalio/worker/_workflow_instance.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
208208
self._worker_level_failure_exception_types = (
209209
det.worker_level_failure_exception_types
210210
)
211+
self._primary_task_initter: Optional[Callable[[], asyncio.Task[None]]] = None
211212
self._primary_task: Optional[asyncio.Task[None]] = None
212213
self._time_ns = 0
213214
self._cancel_requested = False
@@ -356,39 +357,22 @@ def activate(
356357
self._current_thread_id = threading.get_ident()
357358
activation_err: Optional[Exception] = None
358359
try:
359-
# Split into job sets with patches, then signals + updates, then
360-
# non-queries, then queries
361-
start_job = None
362-
job_sets: List[
363-
List[temporalio.bridge.proto.workflow_activation.WorkflowActivationJob]
364-
] = [[], [], [], []]
360+
# Apply every job, running the loop afterward
361+
is_query = False
365362
for job in act.jobs:
366-
if job.HasField("notify_has_patch"):
367-
job_sets[0].append(job)
368-
elif job.HasField("signal_workflow") or job.HasField("do_update"):
369-
job_sets[1].append(job)
370-
elif not job.HasField("query_workflow"):
371-
if job.HasField("initialize_workflow"):
372-
start_job = job.initialize_workflow
373-
job_sets[2].append(job)
374-
else:
375-
job_sets[3].append(job)
376-
377-
if start_job:
378-
self._workflow_input = self._make_workflow_input(start_job)
379-
380-
# Apply every job set, running after each set
381-
for index, job_set in enumerate(job_sets):
382-
if not job_set:
383-
continue
384-
for job in job_set:
385-
# Let errors bubble out of these to the caller to fail the task
386-
self._apply(job)
387-
388-
# Run one iteration of the loop. We do not allow conditions to
389-
# be checked in patch jobs (first index) or query jobs (last
390-
# index).
391-
self._run_once(check_conditions=index == 1 or index == 2)
363+
if job.HasField("initialize_workflow"):
364+
self._workflow_input = self._make_workflow_input(job.initialize_workflow)
365+
# Let errors bubble out of these to the caller to fail the task
366+
self._apply(job)
367+
if job.HasField("query_workflow"):
368+
is_query = True
369+
370+
# Ensure the main loop is called, and called last, if needed
371+
if self._primary_task_initter is not None and self._primary_task is None:
372+
self._primary_task_initter()
373+
# Conditions are not checked on query activations. Query activations always come without
374+
# any other jobs.
375+
self._run_once(check_conditions=not is_query)
392376
except Exception as err:
393377
# We want some errors during activation, like those that can happen
394378
# during payload conversion, to be able to fail the workflow not the
@@ -508,6 +492,15 @@ def _apply_cancel_workflow(
508492
# workflow the ability to receive the cancellation, so we must defer
509493
# this cancellation to the next iteration of the event loop.
510494
self.call_soon(self._primary_task.cancel)
495+
elif self._primary_task_initter:
496+
# If we're being cancelled before ever being started, we need to run the cancel
497+
# after initialization
498+
old_initter = self._primary_task_initter
499+
def init_then_cancel():
500+
old_initter()
501+
self.call_soon(self._primary_task.cancel)
502+
self._primary_task_initter = init_then_cancel
503+
511504

512505
def _apply_do_update(
513506
self, job: temporalio.bridge.proto.workflow_activation.DoUpdate
@@ -889,10 +882,12 @@ async def run_workflow(input: ExecuteWorkflowInput) -> None:
889882
raise RuntimeError(
890883
"Expected workflow input to be set. This is an SDK Python bug."
891884
)
892-
self._primary_task = self.create_task(
893-
self._run_top_level_workflow_function(run_workflow(self._workflow_input)),
894-
name="run",
895-
)
885+
def primary_initter():
886+
self._primary_task = self.create_task(
887+
self._run_top_level_workflow_function(run_workflow(self._workflow_input)),
888+
name="run",
889+
)
890+
self._primary_task_initter = primary_initter
896891

897892
def _apply_update_random_seed(
898893
self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed

0 commit comments

Comments
 (0)