Skip to content

Commit 34da301

Browse files
authored
Merge branch 'main' into andystaples/minimize-entity-state-exposure
2 parents 6fa6cfa + 34eaed6 commit 34da301

File tree

3 files changed

+160
-36
lines changed

3 files changed

+160
-36
lines changed

durabletask/worker.py

Lines changed: 126 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def __init__(
346346
else:
347347
self._interceptors = None
348348

349-
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options)
349+
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)
350350

351351
@property
352352
def concurrency_options(self) -> ConcurrencyOptions:
@@ -533,27 +533,31 @@ def stream_reader():
533533
if work_item.HasField("orchestratorRequest"):
534534
self._async_worker_manager.submit_orchestration(
535535
self._execute_orchestrator,
536+
self._cancel_orchestrator,
536537
work_item.orchestratorRequest,
537538
stub,
538539
work_item.completionToken,
539540
)
540541
elif work_item.HasField("activityRequest"):
541542
self._async_worker_manager.submit_activity(
542543
self._execute_activity,
544+
self._cancel_activity,
543545
work_item.activityRequest,
544546
stub,
545547
work_item.completionToken,
546548
)
547549
elif work_item.HasField("entityRequest"):
548550
self._async_worker_manager.submit_entity_batch(
549551
self._execute_entity_batch,
552+
self._cancel_entity_batch,
550553
work_item.entityRequest,
551554
stub,
552555
work_item.completionToken,
553556
)
554557
elif work_item.HasField("entityRequestV2"):
555558
self._async_worker_manager.submit_entity_batch(
556559
self._execute_entity_batch,
560+
self._cancel_entity_batch,
557561
work_item.entityRequestV2,
558562
stub,
559563
work_item.completionToken
@@ -670,6 +674,19 @@ def _execute_orchestrator(
670674
f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
671675
)
672676

677+
def _cancel_orchestrator(
678+
self,
679+
req: pb.OrchestratorRequest,
680+
stub: stubs.TaskHubSidecarServiceStub,
681+
completionToken,
682+
):
683+
stub.AbandonTaskOrchestratorWorkItem(
684+
pb.AbandonOrchestrationTaskRequest(
685+
completionToken=completionToken
686+
)
687+
)
688+
self._logger.info(f"Cancelled orchestration task for invocation ID: {req.instanceId}")
689+
673690
def _execute_activity(
674691
self,
675692
req: pb.ActivityRequest,
@@ -703,6 +720,19 @@ def _execute_activity(
703720
f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
704721
)
705722

723+
def _cancel_activity(
724+
self,
725+
req: pb.ActivityRequest,
726+
stub: stubs.TaskHubSidecarServiceStub,
727+
completionToken,
728+
):
729+
stub.AbandonTaskActivityWorkItem(
730+
pb.AbandonActivityTaskRequest(
731+
completionToken=completionToken
732+
)
733+
)
734+
self._logger.info(f"Cancelled activity task for task ID: {req.taskId} on orchestration ID: {req.orchestrationInstance.instanceId}")
735+
706736
def _execute_entity_batch(
707737
self,
708738
req: Union[pb.EntityBatchRequest, pb.EntityRequest],
@@ -771,6 +801,19 @@ def _execute_entity_batch(
771801

772802
return batch_result
773803

804+
def _cancel_entity_batch(
805+
self,
806+
req: Union[pb.EntityBatchRequest, pb.EntityRequest],
807+
stub: stubs.TaskHubSidecarServiceStub,
808+
completionToken,
809+
):
810+
stub.AbandonTaskEntityWorkItem(
811+
pb.AbandonEntityTaskRequest(
812+
completionToken=completionToken
813+
)
814+
)
815+
self._logger.info(f"Cancelled entity batch task for instance ID: {req.instanceId}")
816+
774817

775818
class _RuntimeOrchestrationContext(task.OrchestrationContext):
776819
_generator: Optional[Generator[task.Task, Any, Any]]
@@ -1931,8 +1974,10 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool:
19311974

19321975

19331976
class _AsyncWorkerManager:
1934-
def __init__(self, concurrency_options: ConcurrencyOptions):
1977+
def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger):
19351978
self.concurrency_options = concurrency_options
1979+
self._logger = logger
1980+
19361981
self.activity_semaphore = None
19371982
self.orchestration_semaphore = None
19381983
self.entity_semaphore = None
@@ -2042,17 +2087,51 @@ async def run(self):
20422087
)
20432088

20442089
# Start background consumers for each work type
2045-
if self.activity_queue is not None and self.orchestration_queue is not None \
2046-
and self.entity_batch_queue is not None:
2047-
await asyncio.gather(
2048-
self._consume_queue(self.activity_queue, self.activity_semaphore),
2049-
self._consume_queue(
2050-
self.orchestration_queue, self.orchestration_semaphore
2051-
),
2052-
self._consume_queue(
2053-
self.entity_batch_queue, self.entity_semaphore
2090+
try:
2091+
if self.activity_queue is not None and self.orchestration_queue is not None \
2092+
and self.entity_batch_queue is not None:
2093+
await asyncio.gather(
2094+
self._consume_queue(self.activity_queue, self.activity_semaphore),
2095+
self._consume_queue(
2096+
self.orchestration_queue, self.orchestration_semaphore
2097+
),
2098+
self._consume_queue(
2099+
self.entity_batch_queue, self.entity_semaphore
2100+
)
20542101
)
2055-
)
2102+
except Exception as queue_exception:
2103+
self._logger.error(f"Shutting down worker - Uncaught error in worker manager: {queue_exception}")
2104+
while self.activity_queue is not None and not self.activity_queue.empty():
2105+
try:
2106+
func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
2107+
await self._run_func(cancellation_func, *args, **kwargs)
2108+
self._logger.error(f"Activity work item args: {args}, kwargs: {kwargs}")
2109+
except asyncio.QueueEmpty:
2110+
# Queue was empty, no cancellation needed
2111+
pass
2112+
except Exception as cancellation_exception:
2113+
self._logger.error(f"Uncaught error while cancelling activity work item: {cancellation_exception}")
2114+
while self.orchestration_queue is not None and not self.orchestration_queue.empty():
2115+
try:
2116+
func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
2117+
await self._run_func(cancellation_func, *args, **kwargs)
2118+
self._logger.error(f"Orchestration work item args: {args}, kwargs: {kwargs}")
2119+
except asyncio.QueueEmpty:
2120+
# Queue was empty, no cancellation needed
2121+
pass
2122+
except Exception as cancellation_exception:
2123+
self._logger.error(f"Uncaught error while cancelling orchestration work item: {cancellation_exception}")
2124+
while self.entity_batch_queue is not None and not self.entity_batch_queue.empty():
2125+
try:
2126+
func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
2127+
await self._run_func(cancellation_func, *args, **kwargs)
2128+
self._logger.error(f"Entity batch work item args: {args}, kwargs: {kwargs}")
2129+
except asyncio.QueueEmpty:
2130+
# Queue was empty, no cancellation needed
2131+
pass
2132+
except Exception as cancellation_exception:
2133+
self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}")
2134+
self.shutdown()
20562135

20572136
async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
20582137
# List to track running tasks
@@ -2072,19 +2151,22 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor
20722151
except asyncio.TimeoutError:
20732152
continue
20742153

2075-
func, args, kwargs = work
2154+
func, cancellation_func, args, kwargs = work
20762155
# Create a concurrent task for processing
20772156
task = asyncio.create_task(
2078-
self._process_work_item(semaphore, queue, func, args, kwargs)
2157+
self._process_work_item(semaphore, queue, func, cancellation_func, args, kwargs)
20792158
)
20802159
running_tasks.add(task)
20812160

20822161
async def _process_work_item(
2083-
self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs
2162+
self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, cancellation_func, args, kwargs
20842163
):
20852164
async with semaphore:
20862165
try:
20872166
await self._run_func(func, *args, **kwargs)
2167+
except Exception as work_exception:
2168+
self._logger.error(f"Uncaught error while processing work item, item will be abandoned: {work_exception}")
2169+
await self._run_func(cancellation_func, *args, **kwargs)
20882170
finally:
20892171
queue.task_done()
20902172

@@ -2103,26 +2185,32 @@ async def _run_func(self, func, *args, **kwargs):
21032185
self.thread_pool, lambda: func(*args, **kwargs)
21042186
)
21052187

2106-
def submit_activity(self, func, *args, **kwargs):
2107-
work_item = (func, args, kwargs)
2188+
def submit_activity(self, func, cancellation_func, *args, **kwargs):
2189+
if self._shutdown:
2190+
raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
2191+
work_item = (func, cancellation_func, args, kwargs)
21082192
self._ensure_queues_for_current_loop()
21092193
if self.activity_queue is not None:
21102194
self.activity_queue.put_nowait(work_item)
21112195
else:
21122196
# No event loop running, store in pending list
21132197
self._pending_activity_work.append(work_item)
21142198

2115-
def submit_orchestration(self, func, *args, **kwargs):
2116-
work_item = (func, args, kwargs)
2199+
def submit_orchestration(self, func, cancellation_func, *args, **kwargs):
2200+
if self._shutdown:
2201+
raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
2202+
work_item = (func, cancellation_func, args, kwargs)
21172203
self._ensure_queues_for_current_loop()
21182204
if self.orchestration_queue is not None:
21192205
self.orchestration_queue.put_nowait(work_item)
21202206
else:
21212207
# No event loop running, store in pending list
21222208
self._pending_orchestration_work.append(work_item)
21232209

2124-
def submit_entity_batch(self, func, *args, **kwargs):
2125-
work_item = (func, args, kwargs)
2210+
def submit_entity_batch(self, func, cancellation_func, *args, **kwargs):
2211+
if self._shutdown:
2212+
raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
2213+
work_item = (func, cancellation_func, args, kwargs)
21262214
self._ensure_queues_for_current_loop()
21272215
if self.entity_batch_queue is not None:
21282216
self.entity_batch_queue.put_nowait(work_item)
@@ -2134,7 +2222,7 @@ def shutdown(self):
21342222
self._shutdown = True
21352223
self.thread_pool.shutdown(wait=True)
21362224

2137-
def reset_for_new_run(self):
2225+
async def reset_for_new_run(self):
21382226
"""Reset the manager state for a new run."""
21392227
self._shutdown = False
21402228
# Clear any existing queues - they'll be recreated when needed
@@ -2143,18 +2231,28 @@ def reset_for_new_run(self):
21432231
# This ensures no items from previous runs remain
21442232
try:
21452233
while not self.activity_queue.empty():
2146-
self.activity_queue.get_nowait()
2147-
except Exception:
2148-
pass
2234+
func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
2235+
await self._run_func(cancellation_func, *args, **kwargs)
2236+
except Exception as reset_exception:
2237+
self._logger.warning(f"Error while clearing activity queue during reset: {reset_exception}")
21492238
if self.orchestration_queue is not None:
21502239
try:
21512240
while not self.orchestration_queue.empty():
2152-
self.orchestration_queue.get_nowait()
2153-
except Exception:
2154-
pass
2241+
func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
2242+
await self._run_func(cancellation_func, *args, **kwargs)
2243+
except Exception as reset_exception:
2244+
self._logger.warning(f"Error while clearing orchestration queue during reset: {reset_exception}")
2245+
if self.entity_batch_queue is not None:
2246+
try:
2247+
while not self.entity_batch_queue.empty():
2248+
func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
2249+
await self._run_func(cancellation_func, *args, **kwargs)
2250+
except Exception as reset_exception:
2251+
self._logger.warning(f"Error while clearing entity queue during reset: {reset_exception}")
21552252
# Clear pending work lists
21562253
self._pending_activity_work.clear()
21572254
self._pending_orchestration_work.clear()
2255+
self._pending_entity_batch_work.clear()
21582256

21592257

21602258
# Export public API

tests/durabletask/test_worker_concurrency_loop.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,21 @@ def dummy_orchestrator(req, stub, completionToken):
5252
time.sleep(0.1)
5353
stub.CompleteOrchestratorTask('ok')
5454

55+
def cancel_dummy_orchestrator(req, stub, completionToken):
56+
pass
57+
5558
def dummy_activity(req, stub, completionToken):
5659
time.sleep(0.1)
5760
stub.CompleteActivityTask('ok')
5861

62+
def cancel_dummy_activity(req, stub, completionToken):
63+
pass
64+
5965
# Patch the worker's _execute_orchestrator and _execute_activity
6066
worker._execute_orchestrator = dummy_orchestrator
67+
worker._cancel_orchestrator = cancel_dummy_orchestrator
6168
worker._execute_activity = dummy_activity
69+
worker._cancel_activity = cancel_dummy_activity
6270

6371
orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)]
6472
activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)]
@@ -67,9 +75,9 @@ async def run_test():
6775
# Start the worker manager's run loop in the background
6876
worker_task = asyncio.create_task(worker._async_worker_manager.run())
6977
for req in orchestrator_requests:
70-
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken())
78+
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken())
7179
for req in activity_requests:
72-
worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken())
80+
worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken())
7381
await asyncio.sleep(1.0)
7482
orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator')
7583
activity_count = sum(1 for t, _ in stub.completed if t == 'activity')
@@ -120,8 +128,8 @@ def fn(*args, **kwargs):
120128

121129
# Submit more work than concurrency allows
122130
for i in range(5):
123-
manager.submit_orchestration(make_work("orch", i))
124-
manager.submit_activity(make_work("act", i))
131+
manager.submit_orchestration(make_work("orch", i), lambda *a, **k: None)
132+
manager.submit_activity(make_work("act", i), lambda *a, **k: None)
125133

126134
# Run the manager loop in a thread (sync context)
127135
def run_manager():
@@ -131,6 +139,11 @@ def run_manager():
131139
t.start()
132140
time.sleep(1.5) # Let work process
133141
manager.shutdown()
142+
143+
# Ensure the queues have been started
144+
if (manager.activity_queue is None or manager.orchestration_queue is None):
145+
raise RuntimeError("Worker manager queues not initialized")
146+
134147
# Unblock the consumers by putting dummy items in the queues
135148
manager.activity_queue.put_nowait((lambda: None, (), {}))
136149
manager.orchestration_queue.put_nowait((lambda: None, (), {}))

tests/durabletask/test_worker_concurrency_loop_async.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,21 @@ async def dummy_orchestrator(req, stub, completionToken):
5050
await asyncio.sleep(0.1)
5151
stub.CompleteOrchestratorTask('ok')
5252

53+
async def cancel_dummy_orchestrator(req, stub, completionToken):
54+
pass
55+
5356
async def dummy_activity(req, stub, completionToken):
5457
await asyncio.sleep(0.1)
5558
stub.CompleteActivityTask('ok')
5659

60+
async def cancel_dummy_activity(req, stub, completionToken):
61+
pass
62+
5763
# Patch the worker's _execute_orchestrator and _execute_activity
58-
grpc_worker._execute_orchestrator = dummy_orchestrator
59-
grpc_worker._execute_activity = dummy_activity
64+
grpc_worker._execute_orchestrator = dummy_orchestrator.__get__(grpc_worker, TaskHubGrpcWorker)
65+
grpc_worker._cancel_orchestrator = cancel_dummy_orchestrator.__get__(grpc_worker, TaskHubGrpcWorker)
66+
grpc_worker._execute_activity = dummy_activity.__get__(grpc_worker, TaskHubGrpcWorker)
67+
grpc_worker._cancel_activity = cancel_dummy_activity.__get__(grpc_worker, TaskHubGrpcWorker)
6068

6169
orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)]
6270
activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)]
@@ -65,10 +73,15 @@ async def run_test():
6573
# Clear stub state before each run
6674
stub.completed.clear()
6775
worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run())
76+
# Need to yield to that thread in order to let it start up on the second run
77+
startup_attempts = 0
78+
while grpc_worker._async_worker_manager._shutdown and startup_attempts < 10:
79+
await asyncio.sleep(0.1)
80+
startup_attempts += 1
6881
for req in orchestrator_requests:
69-
grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken())
82+
grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken())
7083
for req in activity_requests:
71-
grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken())
84+
grpc_worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken())
7285
await asyncio.sleep(1.0)
7386
orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator')
7487
activity_count = sum(1 for t, _ in stub.completed if t == 'activity')

0 commit comments

Comments
 (0)