@@ -346,7 +346,7 @@ def __init__(
346346 else :
347347 self ._interceptors = None
348348
349- self ._async_worker_manager = _AsyncWorkerManager (self ._concurrency_options )
349+ self ._async_worker_manager = _AsyncWorkerManager (self ._concurrency_options , self . _logger )
350350
351351 @property
352352 def concurrency_options (self ) -> ConcurrencyOptions :
@@ -533,27 +533,31 @@ def stream_reader():
533533 if work_item .HasField ("orchestratorRequest" ):
534534 self ._async_worker_manager .submit_orchestration (
535535 self ._execute_orchestrator ,
536+ self ._cancel_orchestrator ,
536537 work_item .orchestratorRequest ,
537538 stub ,
538539 work_item .completionToken ,
539540 )
540541 elif work_item .HasField ("activityRequest" ):
541542 self ._async_worker_manager .submit_activity (
542543 self ._execute_activity ,
544+ self ._cancel_activity ,
543545 work_item .activityRequest ,
544546 stub ,
545547 work_item .completionToken ,
546548 )
547549 elif work_item .HasField ("entityRequest" ):
548550 self ._async_worker_manager .submit_entity_batch (
549551 self ._execute_entity_batch ,
552+ self ._cancel_entity_batch ,
550553 work_item .entityRequest ,
551554 stub ,
552555 work_item .completionToken ,
553556 )
554557 elif work_item .HasField ("entityRequestV2" ):
555558 self ._async_worker_manager .submit_entity_batch (
556559 self ._execute_entity_batch ,
560+ self ._cancel_entity_batch ,
557561 work_item .entityRequestV2 ,
558562 stub ,
559563 work_item .completionToken
@@ -670,6 +674,19 @@ def _execute_orchestrator(
670674 f"Failed to deliver orchestrator response for '{ req .instanceId } ' to sidecar: { ex } "
671675 )
672676
677+ def _cancel_orchestrator (
678+ self ,
679+ req : pb .OrchestratorRequest ,
680+ stub : stubs .TaskHubSidecarServiceStub ,
681+ completionToken ,
682+ ):
683+ stub .AbandonTaskOrchestratorWorkItem (
684+ pb .AbandonOrchestrationTaskRequest (
685+ completionToken = completionToken
686+ )
687+ )
688+ self ._logger .info (f"Cancelled orchestration task for invocation ID: { req .instanceId } " )
689+
673690 def _execute_activity (
674691 self ,
675692 req : pb .ActivityRequest ,
@@ -703,6 +720,19 @@ def _execute_activity(
703720 f"Failed to deliver activity response for '{ req .name } #{ req .taskId } ' of orchestration ID '{ instance_id } ' to sidecar: { ex } "
704721 )
705722
723+ def _cancel_activity (
724+ self ,
725+ req : pb .ActivityRequest ,
726+ stub : stubs .TaskHubSidecarServiceStub ,
727+ completionToken ,
728+ ):
729+ stub .AbandonTaskActivityWorkItem (
730+ pb .AbandonActivityTaskRequest (
731+ completionToken = completionToken
732+ )
733+ )
734+ self ._logger .info (f"Cancelled activity task for task ID: { req .taskId } on orchestration ID: { req .orchestrationInstance .instanceId } " )
735+
706736 def _execute_entity_batch (
707737 self ,
708738 req : Union [pb .EntityBatchRequest , pb .EntityRequest ],
@@ -771,6 +801,19 @@ def _execute_entity_batch(
771801
772802 return batch_result
773803
804+ def _cancel_entity_batch (
805+ self ,
806+ req : Union [pb .EntityBatchRequest , pb .EntityRequest ],
807+ stub : stubs .TaskHubSidecarServiceStub ,
808+ completionToken ,
809+ ):
810+ stub .AbandonTaskEntityWorkItem (
811+ pb .AbandonEntityTaskRequest (
812+ completionToken = completionToken
813+ )
814+ )
815+ self ._logger .info (f"Cancelled entity batch task for instance ID: { req .instanceId } " )
816+
774817
775818class _RuntimeOrchestrationContext (task .OrchestrationContext ):
776819 _generator : Optional [Generator [task .Task , Any , Any ]]
@@ -1931,8 +1974,10 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool:
19311974
19321975
19331976class _AsyncWorkerManager :
1934- def __init__ (self , concurrency_options : ConcurrencyOptions ):
1977+ def __init__ (self , concurrency_options : ConcurrencyOptions , logger : logging . Logger ):
19351978 self .concurrency_options = concurrency_options
1979+ self ._logger = logger
1980+
19361981 self .activity_semaphore = None
19371982 self .orchestration_semaphore = None
19381983 self .entity_semaphore = None
@@ -2042,17 +2087,51 @@ async def run(self):
20422087 )
20432088
20442089 # Start background consumers for each work type
2045- if self .activity_queue is not None and self .orchestration_queue is not None \
2046- and self .entity_batch_queue is not None :
2047- await asyncio .gather (
2048- self ._consume_queue (self .activity_queue , self .activity_semaphore ),
2049- self ._consume_queue (
2050- self .orchestration_queue , self .orchestration_semaphore
2051- ),
2052- self ._consume_queue (
2053- self .entity_batch_queue , self .entity_semaphore
2090+ try :
2091+ if self .activity_queue is not None and self .orchestration_queue is not None \
2092+ and self .entity_batch_queue is not None :
2093+ await asyncio .gather (
2094+ self ._consume_queue (self .activity_queue , self .activity_semaphore ),
2095+ self ._consume_queue (
2096+ self .orchestration_queue , self .orchestration_semaphore
2097+ ),
2098+ self ._consume_queue (
2099+ self .entity_batch_queue , self .entity_semaphore
2100+ )
20542101 )
2055- )
2102+ except Exception as queue_exception :
2103+ self ._logger .error (f"Shutting down worker - Uncaught error in worker manager: { queue_exception } " )
2104+ while self .activity_queue is not None and not self .activity_queue .empty ():
2105+ try :
2106+ func , cancellation_func , args , kwargs = self .activity_queue .get_nowait ()
2107+ await self ._run_func (cancellation_func , * args , ** kwargs )
2108+ self ._logger .error (f"Activity work item args: { args } , kwargs: { kwargs } " )
2109+ except asyncio .QueueEmpty :
2110+ # Queue was empty, no cancellation needed
2111+ pass
2112+ except Exception as cancellation_exception :
2113+ self ._logger .error (f"Uncaught error while cancelling activity work item: { cancellation_exception } " )
2114+ while self .orchestration_queue is not None and not self .orchestration_queue .empty ():
2115+ try :
2116+ func , cancellation_func , args , kwargs = self .orchestration_queue .get_nowait ()
2117+ await self ._run_func (cancellation_func , * args , ** kwargs )
2118+ self ._logger .error (f"Orchestration work item args: { args } , kwargs: { kwargs } " )
2119+ except asyncio .QueueEmpty :
2120+ # Queue was empty, no cancellation needed
2121+ pass
2122+ except Exception as cancellation_exception :
2123+ self ._logger .error (f"Uncaught error while cancelling orchestration work item: { cancellation_exception } " )
2124+ while self .entity_batch_queue is not None and not self .entity_batch_queue .empty ():
2125+ try :
2126+ func , cancellation_func , args , kwargs = self .entity_batch_queue .get_nowait ()
2127+ await self ._run_func (cancellation_func , * args , ** kwargs )
2128+ self ._logger .error (f"Entity batch work item args: { args } , kwargs: { kwargs } " )
2129+ except asyncio .QueueEmpty :
2130+ # Queue was empty, no cancellation needed
2131+ pass
2132+ except Exception as cancellation_exception :
2133+ self ._logger .error (f"Uncaught error while cancelling entity batch work item: { cancellation_exception } " )
2134+ self .shutdown ()
20562135
20572136 async def _consume_queue (self , queue : asyncio .Queue , semaphore : asyncio .Semaphore ):
20582137 # List to track running tasks
@@ -2072,19 +2151,22 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor
20722151 except asyncio .TimeoutError :
20732152 continue
20742153
2075- func , args , kwargs = work
2154+ func , cancellation_func , args , kwargs = work
20762155 # Create a concurrent task for processing
20772156 task = asyncio .create_task (
2078- self ._process_work_item (semaphore , queue , func , args , kwargs )
2157+ self ._process_work_item (semaphore , queue , func , cancellation_func , args , kwargs )
20792158 )
20802159 running_tasks .add (task )
20812160
20822161 async def _process_work_item (
2083- self , semaphore : asyncio .Semaphore , queue : asyncio .Queue , func , args , kwargs
2162+ self , semaphore : asyncio .Semaphore , queue : asyncio .Queue , func , cancellation_func , args , kwargs
20842163 ):
20852164 async with semaphore :
20862165 try :
20872166 await self ._run_func (func , * args , ** kwargs )
2167+ except Exception as work_exception :
2168+ self ._logger .error (f"Uncaught error while processing work item, item will be abandoned: { work_exception } " )
2169+ await self ._run_func (cancellation_func , * args , ** kwargs )
20882170 finally :
20892171 queue .task_done ()
20902172
@@ -2103,26 +2185,32 @@ async def _run_func(self, func, *args, **kwargs):
21032185 self .thread_pool , lambda : func (* args , ** kwargs )
21042186 )
21052187
2106- def submit_activity (self , func , * args , ** kwargs ):
2107- work_item = (func , args , kwargs )
2188+ def submit_activity (self , func , cancellation_func , * args , ** kwargs ):
2189+ if self ._shutdown :
2190+ raise RuntimeError ("Cannot submit new work items after shutdown has been initiated." )
2191+ work_item = (func , cancellation_func , args , kwargs )
21082192 self ._ensure_queues_for_current_loop ()
21092193 if self .activity_queue is not None :
21102194 self .activity_queue .put_nowait (work_item )
21112195 else :
21122196 # No event loop running, store in pending list
21132197 self ._pending_activity_work .append (work_item )
21142198
2115- def submit_orchestration (self , func , * args , ** kwargs ):
2116- work_item = (func , args , kwargs )
2199+ def submit_orchestration (self , func , cancellation_func , * args , ** kwargs ):
2200+ if self ._shutdown :
2201+ raise RuntimeError ("Cannot submit new work items after shutdown has been initiated." )
2202+ work_item = (func , cancellation_func , args , kwargs )
21172203 self ._ensure_queues_for_current_loop ()
21182204 if self .orchestration_queue is not None :
21192205 self .orchestration_queue .put_nowait (work_item )
21202206 else :
21212207 # No event loop running, store in pending list
21222208 self ._pending_orchestration_work .append (work_item )
21232209
2124- def submit_entity_batch (self , func , * args , ** kwargs ):
2125- work_item = (func , args , kwargs )
2210+ def submit_entity_batch (self , func , cancellation_func , * args , ** kwargs ):
2211+ if self ._shutdown :
2212+ raise RuntimeError ("Cannot submit new work items after shutdown has been initiated." )
2213+ work_item = (func , cancellation_func , args , kwargs )
21262214 self ._ensure_queues_for_current_loop ()
21272215 if self .entity_batch_queue is not None :
21282216 self .entity_batch_queue .put_nowait (work_item )
@@ -2134,7 +2222,7 @@ def shutdown(self):
21342222 self ._shutdown = True
21352223 self .thread_pool .shutdown (wait = True )
21362224
2137- def reset_for_new_run (self ):
2225+ async def reset_for_new_run (self ):
21382226 """Reset the manager state for a new run."""
21392227 self ._shutdown = False
21402228 # Clear any existing queues - they'll be recreated when needed
@@ -2143,18 +2231,28 @@ def reset_for_new_run(self):
21432231 # This ensures no items from previous runs remain
21442232 try :
21452233 while not self .activity_queue .empty ():
2146- self .activity_queue .get_nowait ()
2147- except Exception :
2148- pass
2234+ func , cancellation_func , args , kwargs = self .activity_queue .get_nowait ()
2235+ await self ._run_func (cancellation_func , * args , ** kwargs )
2236+ except Exception as reset_exception :
2237+ self ._logger .warning (f"Error while clearing activity queue during reset: { reset_exception } " )
21492238 if self .orchestration_queue is not None :
21502239 try :
21512240 while not self .orchestration_queue .empty ():
2152- self .orchestration_queue .get_nowait ()
2153- except Exception :
2154- pass
2241+ func , cancellation_func , args , kwargs = self .orchestration_queue .get_nowait ()
2242+ await self ._run_func (cancellation_func , * args , ** kwargs )
2243+ except Exception as reset_exception :
2244+ self ._logger .warning (f"Error while clearing orchestration queue during reset: { reset_exception } " )
2245+ if self .entity_batch_queue is not None :
2246+ try :
2247+ while not self .entity_batch_queue .empty ():
2248+ func , cancellation_func , args , kwargs = self .entity_batch_queue .get_nowait ()
2249+ await self ._run_func (cancellation_func , * args , ** kwargs )
2250+ except Exception as reset_exception :
2251+ self ._logger .warning (f"Error while clearing entity queue during reset: { reset_exception } " )
21552252 # Clear pending work lists
21562253 self ._pending_activity_work .clear ()
21572254 self ._pending_orchestration_work .clear ()
2255+ self ._pending_entity_batch_work .clear ()
21582256
21592257
21602258# Export public API
0 commit comments