44import logging
55import uuid
66from dataclasses import dataclass
7- from datetime import datetime , timezone
7+ from datetime import datetime , timedelta , timezone
88from enum import Enum
9- from typing import Any , Optional , Sequence , TypeVar , Union
9+ from typing import Any , List , Optional , Sequence , TypeVar , Union
1010
1111import grpc
12- from google .protobuf import wrappers_pb2
12+ from google .protobuf import wrappers_pb2 as pb2
1313
1414from durabletask .entities import EntityInstanceId
1515from durabletask .entities .entity_metadata import EntityMetadata
@@ -57,6 +57,12 @@ def raise_if_failed(self):
5757 self .failure_details )
5858
5959
60+ class PurgeInstancesResult :
61+ def __init__ (self , deleted_instance_count : int , is_complete : bool ):
62+ self .deleted_instance_count = deleted_instance_count
63+ self .is_complete = is_complete
64+
65+
6066class OrchestrationFailedError (Exception ):
6167 def __init__ (self , message : str , failure_details : task .FailureDetails ):
6268 super ().__init__ (message )
@@ -73,6 +79,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
7379
7480 state = res .orchestrationState
7581
82+ new_state = parse_orchestration_state (state )
83+ new_state .instance_id = instance_id # Override instance_id with the one from the request, to match old behavior
84+ return new_state
85+
86+
87+ def parse_orchestration_state (state : pb .OrchestrationState ) -> OrchestrationState :
7688 failure_details = None
7789 if state .failureDetails .errorMessage != '' or state .failureDetails .errorType != '' :
7890 failure_details = task .FailureDetails (
@@ -81,7 +93,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
8193 state .failureDetails .stackTrace .value if not helpers .is_empty (state .failureDetails .stackTrace ) else None )
8294
8395 return OrchestrationState (
84- instance_id ,
96+ state . instanceId ,
8597 state .name ,
8698 OrchestrationStatus (state .orchestrationStatus ),
8799 state .createdTimestamp .ToDatetime (),
@@ -93,7 +105,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
93105
94106
95107class TaskHubGrpcClient :
96-
97108 def __init__ (self , * ,
98109 host_address : Optional [str ] = None ,
99110 metadata : Optional [list [tuple [str , str ]]] = None ,
@@ -136,7 +147,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
136147 req = pb .CreateInstanceRequest (
137148 name = name ,
138149 instanceId = instance_id if instance_id else uuid .uuid4 ().hex ,
139- input = wrappers_pb2 . StringValue ( value = shared .to_json (input )) if input is not None else None ,
150+ input = helpers . get_string_value ( shared .to_json (input )),
140151 scheduledStartTimestamp = helpers .new_timestamp (start_at ) if start_at else None ,
141152 version = helpers .get_string_value (version if version else self .default_version ),
142153 orchestrationIdReusePolicy = reuse_id_policy ,
@@ -152,6 +163,54 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
152163 res : pb .GetInstanceResponse = self ._stub .GetInstance (req )
153164 return new_orchestration_state (req .instanceId , res )
154165
166+ def get_all_orchestration_states (self ,
167+ max_instance_count : Optional [int ] = None ,
168+ fetch_inputs_and_outputs : bool = False ) -> List [OrchestrationState ]:
169+ return self .get_orchestration_state_by (
170+ created_time_from = None ,
171+ created_time_to = None ,
172+ runtime_status = None ,
173+ max_instance_count = max_instance_count ,
174+ fetch_inputs_and_outputs = fetch_inputs_and_outputs
175+ )
176+
177+ def get_orchestration_state_by (self ,
178+ created_time_from : Optional [datetime ] = None ,
179+ created_time_to : Optional [datetime ] = None ,
180+ runtime_status : Optional [List [OrchestrationStatus ]] = None ,
181+ max_instance_count : Optional [int ] = None ,
182+ fetch_inputs_and_outputs : bool = False ,
183+ _continuation_token : Optional [pb2 .StringValue ] = None
184+ ) -> List [OrchestrationState ]:
185+ if max_instance_count is None :
186+ # DTS backend does not behave well with max_instance_count = None, so we set to max 32-bit signed value
187+ max_instance_count = (1 << 31 ) - 1
188+ req = pb .QueryInstancesRequest (
189+ query = pb .InstanceQuery (
190+ runtimeStatus = [status .value for status in runtime_status ] if runtime_status else None ,
191+ createdTimeFrom = helpers .new_timestamp (created_time_from ) if created_time_from else None ,
192+ createdTimeTo = helpers .new_timestamp (created_time_to ) if created_time_to else None ,
193+ maxInstanceCount = max_instance_count ,
194+ fetchInputsAndOutputs = fetch_inputs_and_outputs ,
195+ continuationToken = _continuation_token
196+ )
197+ )
198+ resp : pb .QueryInstancesResponse = self ._stub .QueryInstances (req )
199+ states = [parse_orchestration_state (res ) for res in resp .orchestrationState ]
200+ # Check the value for continuationToken - none or "0" indicates that there are no more results.
201+ if resp .continuationToken and resp .continuationToken .value and resp .continuationToken .value != "0" :
202+ self ._logger .info (f"Received continuation token with value { resp .continuationToken .value } , fetching next list of instances..." )
203+ states += self .get_orchestration_state_by (
204+ created_time_from ,
205+ created_time_to ,
206+ runtime_status ,
207+ max_instance_count ,
208+ fetch_inputs_and_outputs ,
209+ _continuation_token = resp .continuationToken
210+ )
211+ states = [state for state in states if state is not None ] # Filter out any None values
212+ return states
213+
155214 def wait_for_orchestration_start (self , instance_id : str , * ,
156215 fetch_payloads : bool = False ,
157216 timeout : int = 60 ) -> Optional [OrchestrationState ]:
@@ -199,7 +258,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
199258 req = pb .RaiseEventRequest (
200259 instanceId = instance_id ,
201260 name = event_name ,
202- input = wrappers_pb2 . StringValue ( value = shared .to_json (data )) if data else None )
261+ input = helpers . get_string_value ( shared .to_json (data )))
203262
204263 self ._logger .info (f"Raising event '{ event_name } ' for instance '{ instance_id } '." )
205264 self ._stub .RaiseEvent (req )
@@ -209,7 +268,7 @@ def terminate_orchestration(self, instance_id: str, *,
209268 recursive : bool = True ):
210269 req = pb .TerminateRequest (
211270 instanceId = instance_id ,
212- output = wrappers_pb2 . StringValue ( value = shared .to_json (output )) if output else None ,
271+ output = helpers . get_string_value ( shared .to_json (output )),
213272 recursive = recursive )
214273
215274 self ._logger .info (f"Terminating instance '{ instance_id } '." )
@@ -225,10 +284,27 @@ def resume_orchestration(self, instance_id: str):
225284 self ._logger .info (f"Resuming instance '{ instance_id } '." )
226285 self ._stub .ResumeInstance (req )
227286
228- def purge_orchestration (self , instance_id : str , recursive : bool = True ):
287+ def purge_orchestration (self , instance_id : str , recursive : bool = True ) -> PurgeInstancesResult :
229288 req = pb .PurgeInstancesRequest (instanceId = instance_id , recursive = recursive )
230289 self ._logger .info (f"Purging instance '{ instance_id } '." )
231- self ._stub .PurgeInstances (req )
290+ resp : pb .PurgeInstancesResponse = self ._stub .PurgeInstances (req )
291+ return PurgeInstancesResult (resp .deletedInstanceCount , resp .isComplete .value )
292+
293+ def purge_orchestrations_by (self ,
294+ created_time_from : Optional [datetime ] = None ,
295+ created_time_to : Optional [datetime ] = None ,
296+ runtime_status : Optional [List [OrchestrationStatus ]] = None ,
297+ recursive : bool = False ) -> PurgeInstancesResult :
298+ resp : pb .PurgeInstancesResponse = self ._stub .PurgeInstances (pb .PurgeInstancesRequest (
299+ instanceId = None ,
300+ purgeInstanceFilter = pb .PurgeInstanceFilter (
301+ createdTimeFrom = helpers .new_timestamp (created_time_from ) if created_time_from else None ,
302+ createdTimeTo = helpers .new_timestamp (created_time_to ) if created_time_to else None ,
303+ runtimeStatus = [status .value for status in runtime_status ] if runtime_status else None
304+ ),
305+ recursive = recursive
306+ ))
307+ return PurgeInstancesResult (resp .deletedInstanceCount , resp .isComplete .value )
232308
233309 def signal_entity (self ,
234310 entity_instance_id : EntityInstanceId ,
@@ -237,7 +313,7 @@ def signal_entity(self,
237313 req = pb .SignalEntityRequest (
238314 instanceId = str (entity_instance_id ),
239315 name = operation_name ,
240- input = wrappers_pb2 . StringValue ( value = shared .to_json (input )) if input else None ,
316+ input = helpers . get_string_value ( shared .to_json (input )),
241317 requestId = str (uuid .uuid4 ()),
242318 scheduledTime = None ,
243319 parentTraceContext = None ,
@@ -256,4 +332,53 @@ def get_entity(self,
256332 if not res .exists :
257333 return None
258334
259- return EntityMetadata .from_entity_response (res , include_state )
335+ return EntityMetadata .from_entity_metadata (res .entity , include_state )
336+
337+ def get_all_entities (self ,
338+ include_state : bool = True ,
339+ include_transient : bool = False ,
340+ page_size : Optional [int ] = None ) -> List [EntityMetadata ]:
341+ return self .get_entities_by (
342+ instance_id_starts_with = None ,
343+ last_modified_from = None ,
344+ last_modified_to = None ,
345+ include_state = include_state ,
346+ include_transient = include_transient ,
347+ page_size = page_size
348+ )
349+
350+ def get_entities_by (self ,
351+ instance_id_starts_with : Optional [str ] = None ,
352+ last_modified_from : Optional [datetime ] = None ,
353+ last_modified_to : Optional [datetime ] = None ,
354+ include_state : bool = True ,
355+ include_transient : bool = False ,
356+ page_size : Optional [int ] = None ,
357+ _continuation_token : Optional [pb2 .StringValue ] = None
358+ ) -> List [EntityMetadata ]:
359+ self ._logger .info (f"Getting entities" )
360+ query_request = pb .QueryEntitiesRequest (
361+ query = pb .EntityQuery (
362+ instanceIdStartsWith = helpers .get_string_value (instance_id_starts_with ),
363+ lastModifiedFrom = helpers .new_timestamp (last_modified_from ) if last_modified_from else None ,
364+ lastModifiedTo = helpers .new_timestamp (last_modified_to ) if last_modified_to else None ,
365+ includeState = include_state ,
366+ includeTransient = include_transient ,
367+ pageSize = helpers .get_int_value (page_size ),
368+ continuationToken = _continuation_token
369+ )
370+ )
371+ resp : pb .QueryEntitiesResponse = self ._stub .QueryEntities (query_request )
372+ entities = [EntityMetadata .from_entity_metadata (entity , query_request .query .includeState ) for entity in resp .entities ]
373+ if resp .continuationToken and resp .continuationToken .value != "0" :
374+ self ._logger .info (f"Received continuation token with value { resp .continuationToken .value } , fetching next page of entities..." )
375+ entities += self .get_entities_by (
376+ instance_id_starts_with = instance_id_starts_with ,
377+ last_modified_from = last_modified_from ,
378+ last_modified_to = last_modified_to ,
379+ include_state = include_state ,
380+ include_transient = include_transient ,
381+ page_size = page_size ,
382+ _continuation_token = resp .continuationToken
383+ )
384+ return entities
0 commit comments