66from dataclasses import dataclass
77from datetime import datetime , timezone
88from enum import Enum
9- from typing import Any , Optional , Sequence , TypeVar , Union
9+ from typing import Any , List , Optional , Sequence , TypeVar , Union
1010
1111import grpc
12- from google .protobuf import wrappers_pb2
1312
1413from durabletask .entities import EntityInstanceId
1514from durabletask .entities .entity_metadata import EntityMetadata
@@ -57,6 +56,39 @@ def raise_if_failed(self):
5756 self .failure_details )
5857
5958
@dataclass
class OrchestrationQuery:
    """Filter criteria for querying orchestration instances.

    Passed to TaskHubGrpcClient.get_all_orchestration_states; each field that
    is left as None places no constraint on the results.
    """
    # Only match instances created at or after this time.
    created_time_from: Optional[datetime] = None
    # Only match instances created at or before this time.
    created_time_to: Optional[datetime] = None
    # Only match instances whose runtime status is one of these values.
    runtime_status: Optional[List[OrchestrationStatus]] = None
    # Some backends don't respond well with max_instance_count = None, so we use the integer limit for non-paginated
    # results instead.
    max_instance_count: Optional[int] = (1 << 31) - 1
    # When True, serialized inputs and outputs are included in the results.
    fetch_inputs_and_outputs: bool = False
69+
@dataclass
class EntityQuery:
    """Filter criteria for querying entity instances.

    Passed to TaskHubGrpcClient.get_all_entities; fields left as None place
    no constraint on the results.
    """
    # Only match entities whose instance ID starts with this prefix.
    instance_id_starts_with: Optional[str] = None
    # Only match entities last modified at or after this time.
    last_modified_from: Optional[datetime] = None
    # Only match entities last modified at or before this time.
    last_modified_to: Optional[datetime] = None
    # When True, each entity's serialized state is included in the results.
    include_state: bool = True
    # When True, transient (e.g. not-yet-materialized) entities are included.
    # NOTE(review): exact transient semantics are backend-defined — confirm.
    include_transient: bool = False
    # Number of results per page; None lets the backend choose.
    page_size: Optional[int] = None
79+
@dataclass
class PurgeInstancesResult:
    """Outcome of a purge operation (purge_orchestration / purge_orchestrations_by)."""
    # Number of orchestration instances that were deleted.
    deleted_instance_count: int
    # Whether the purge ran to completion (False if the backend stopped early).
    is_complete: bool
85+
@dataclass
class CleanEntityStorageResult:
    """Aggregate outcome of TaskHubGrpcClient.clean_entity_storage."""
    # Total number of empty entities removed across all pages.
    empty_entities_removed: int
    # Total number of orphaned entity locks released across all pages.
    orphaned_locks_released: int
91+
6092class OrchestrationFailedError (Exception ):
6193 def __init__ (self , message : str , failure_details : task .FailureDetails ):
6294 super ().__init__ (message )
@@ -73,6 +105,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
73105
74106 state = res .orchestrationState
75107
108+ new_state = parse_orchestration_state (state )
109+ new_state .instance_id = instance_id # Override instance_id with the one from the request, to match old behavior
110+ return new_state
111+
112+
113+ def parse_orchestration_state (state : pb .OrchestrationState ) -> OrchestrationState :
76114 failure_details = None
77115 if state .failureDetails .errorMessage != '' or state .failureDetails .errorType != '' :
78116 failure_details = task .FailureDetails (
@@ -81,7 +119,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
81119 state .failureDetails .stackTrace .value if not helpers .is_empty (state .failureDetails .stackTrace ) else None )
82120
83121 return OrchestrationState (
84- instance_id ,
122+ state . instanceId ,
85123 state .name ,
86124 OrchestrationStatus (state .orchestrationStatus ),
87125 state .createdTimestamp .ToDatetime (),
@@ -93,7 +131,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
93131
94132
95133class TaskHubGrpcClient :
96-
97134 def __init__ (self , * ,
98135 host_address : Optional [str ] = None ,
99136 metadata : Optional [list [tuple [str , str ]]] = None ,
@@ -136,7 +173,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
136173 req = pb .CreateInstanceRequest (
137174 name = name ,
138175 instanceId = instance_id if instance_id else uuid .uuid4 ().hex ,
139- input = wrappers_pb2 . StringValue ( value = shared .to_json (input )) if input is not None else None ,
176+ input = helpers . get_string_value ( shared .to_json (input ) if input is not None else None ) ,
140177 scheduledStartTimestamp = helpers .new_timestamp (start_at ) if start_at else None ,
141178 version = helpers .get_string_value (version if version else self .default_version ),
142179 orchestrationIdReusePolicy = reuse_id_policy ,
@@ -152,6 +189,42 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
152189 res : pb .GetInstanceResponse = self ._stub .GetInstance (req )
153190 return new_orchestration_state (req .instanceId , res )
154191
    def get_all_orchestration_states(self,
                                     orchestration_query: Optional[OrchestrationQuery] = None
                                     ) -> List[OrchestrationState]:
        """Query orchestration instances matching the given filter.

        Follows the backend's continuation token to fetch every page of
        results before returning.

        Args:
            orchestration_query: Filter criteria; defaults to an unfiltered
                OrchestrationQuery (capped at its default max_instance_count).

        Returns:
            One OrchestrationState per matching instance.
        """
        if orchestration_query is None:
            orchestration_query = OrchestrationQuery()
        _continuation_token = None

        self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")

        states = []

        while True:
            # NOTE(review): maxInstanceCount is sent on every page request —
            # confirm whether the cap should instead apply across all pages.
            req = pb.QueryInstancesRequest(
                query=pb.InstanceQuery(
                    runtimeStatus=[status.value for status in orchestration_query.runtime_status] if orchestration_query.runtime_status else None,
                    createdTimeFrom=helpers.new_timestamp(orchestration_query.created_time_from) if orchestration_query.created_time_from else None,
                    createdTimeTo=helpers.new_timestamp(orchestration_query.created_time_to) if orchestration_query.created_time_to else None,
                    maxInstanceCount=orchestration_query.max_instance_count,
                    fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs,
                    continuationToken=_continuation_token
                )
            )
            resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
            states += [parse_orchestration_state(res) for res in resp.orchestrationState]
            # Check the value for continuationToken - none or "0" indicates that there are no more results.
            if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
                self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next list of instances...")
                # Guard against a backend that repeats the same token, which
                # would otherwise make this loop spin forever.
                if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
                    self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
                    break
                _continuation_token = resp.continuationToken
            else:
                break

        return states
227+
155228 def wait_for_orchestration_start (self , instance_id : str , * ,
156229 fetch_payloads : bool = False ,
157230 timeout : int = 60 ) -> Optional [OrchestrationState ]:
@@ -199,7 +272,8 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
199272 req = pb .RaiseEventRequest (
200273 instanceId = instance_id ,
201274 name = event_name ,
202- input = wrappers_pb2 .StringValue (value = shared .to_json (data )) if data else None )
275+ input = helpers .get_string_value (shared .to_json (data ) if data is not None else None )
276+ )
203277
204278 self ._logger .info (f"Raising event '{ event_name } ' for instance '{ instance_id } '." )
205279 self ._stub .RaiseEvent (req )
@@ -209,7 +283,7 @@ def terminate_orchestration(self, instance_id: str, *,
209283 recursive : bool = True ):
210284 req = pb .TerminateRequest (
211285 instanceId = instance_id ,
212- output = wrappers_pb2 . StringValue ( value = shared .to_json (output )) if output else None ,
286+ output = helpers . get_string_value ( shared .to_json (output ) if output is not None else None ) ,
213287 recursive = recursive )
214288
215289 self ._logger .info (f"Terminating instance '{ instance_id } '." )
@@ -245,10 +319,31 @@ def restart_orchestration(self, instance_id: str, *,
245319 res : pb .RestartInstanceResponse = self ._stub .RestartInstance (req )
246320 return res .instanceId
247321
248- def purge_orchestration (self , instance_id : str , recursive : bool = True ):
322+ def purge_orchestration (self , instance_id : str , recursive : bool = True ) -> PurgeInstancesResult :
249323 req = pb .PurgeInstancesRequest (instanceId = instance_id , recursive = recursive )
250324 self ._logger .info (f"Purging instance '{ instance_id } '." )
251- self ._stub .PurgeInstances (req )
325+ resp : pb .PurgeInstancesResponse = self ._stub .PurgeInstances (req )
326+ return PurgeInstancesResult (resp .deletedInstanceCount , resp .isComplete .value )
327+
    def purge_orchestrations_by(self,
                                created_time_from: Optional[datetime] = None,
                                created_time_to: Optional[datetime] = None,
                                runtime_status: Optional[List[OrchestrationStatus]] = None,
                                recursive: bool = False) -> PurgeInstancesResult:
        """Purge the state of every orchestration instance matching a filter.

        Args:
            created_time_from: Only purge instances created at or after this time.
            created_time_to: Only purge instances created at or before this time.
            runtime_status: Only purge instances in one of these runtime statuses.
            recursive: When True, sub-orchestration state is purged as well.

        Returns:
            PurgeInstancesResult with the deleted-instance count and whether
            the purge ran to completion.
        """
        self._logger.info("Purging orchestrations by filter: "
                          f"created_time_from={created_time_from}, "
                          f"created_time_to={created_time_to}, "
                          f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
                          f"recursive={recursive}")
        resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(pb.PurgeInstancesRequest(
            purgeInstanceFilter=pb.PurgeInstanceFilter(
                createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
                createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
                runtimeStatus=[status.value for status in runtime_status] if runtime_status else None
            ),
            recursive=recursive
        ))
        return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
252347
253348 def signal_entity (self ,
254349 entity_instance_id : EntityInstanceId ,
@@ -257,7 +352,7 @@ def signal_entity(self,
257352 req = pb .SignalEntityRequest (
258353 instanceId = str (entity_instance_id ),
259354 name = operation_name ,
260- input = wrappers_pb2 . StringValue ( value = shared .to_json (input )) if input else None ,
355+ input = helpers . get_string_value ( shared .to_json (input ) if input is not None else None ) ,
261356 requestId = str (uuid .uuid4 ()),
262357 scheduledTime = None ,
263358 parentTraceContext = None ,
@@ -276,4 +371,69 @@ def get_entity(self,
276371 if not res .exists :
277372 return None
278373
279- return EntityMetadata .from_entity_response (res , include_state )
374+ return EntityMetadata .from_entity_metadata (res .entity , include_state )
375+
    def get_all_entities(self,
                         entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
        """Query entity instances matching the given filter.

        Follows the backend's continuation token to fetch every page of
        results before returning.

        Args:
            entity_query: Filter criteria; defaults to an unfiltered EntityQuery.

        Returns:
            One EntityMetadata per matching entity.
        """
        if entity_query is None:
            entity_query = EntityQuery()
        _continuation_token = None

        self._logger.info(f"Retrieving entities by filter: {entity_query}")

        entities = []

        while True:
            query_request = pb.QueryEntitiesRequest(
                query=pb.EntityQuery(
                    instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with),
                    lastModifiedFrom=helpers.new_timestamp(entity_query.last_modified_from) if entity_query.last_modified_from else None,
                    lastModifiedTo=helpers.new_timestamp(entity_query.last_modified_to) if entity_query.last_modified_to else None,
                    includeState=entity_query.include_state,
                    includeTransient=entity_query.include_transient,
                    pageSize=helpers.get_int_value(entity_query.page_size),
                    continuationToken=_continuation_token
                )
            )
            resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
            entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
            # A missing or "0" continuation token means there are no more pages.
            if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
                self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next page of entities...")
                # Guard against a backend that repeats the same token, which
                # would otherwise make this loop spin forever.
                if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
                    self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
                    break
                _continuation_token = resp.continuationToken
            else:
                break
        return entities
409+
    def clean_entity_storage(self,
                             remove_empty_entities: bool = True,
                             release_orphaned_locks: bool = True
                             ) -> CleanEntityStorageResult:
        """Clean up entity storage, paging until the backend reports completion.

        Args:
            remove_empty_entities: When True, entities with no state are removed.
            release_orphaned_locks: When True, locks with no owner are released.

        Returns:
            CleanEntityStorageResult with totals accumulated across all pages.
        """
        self._logger.info("Cleaning entity storage")

        # Totals are accumulated across pages, since each response only
        # reports the counts for its own page.
        empty_entities_removed = 0
        orphaned_locks_released = 0
        _continuation_token = None

        while True:
            req = pb.CleanEntityStorageRequest(
                removeEmptyEntities=remove_empty_entities,
                releaseOrphanedLocks=release_orphaned_locks,
                continuationToken=_continuation_token
            )
            resp: pb.CleanEntityStorageResponse = self._stub.CleanEntityStorage(req)
            empty_entities_removed += resp.emptyEntitiesRemoved
            orphaned_locks_released += resp.orphanedLocksReleased

            # A missing or "0" continuation token means the cleanup is done.
            if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
                self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, cleaning next page...")
                # Guard against a backend that repeats the same token, which
                # would otherwise make this loop spin forever.
                if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
                    self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
                    break
                _continuation_token = resp.continuationToken
            else:
                break

        return CleanEntityStorageResult(empty_entities_removed, orphaned_locks_released)
0 commit comments