Skip to content

Commit 860dbda

Browse files
authored
Merge branch 'main' into nytian/restart
2 parents 1179a95 + c658a52 commit 860dbda

16 files changed

+1074
-39
lines changed

durabletask/client.py

Lines changed: 171 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
from dataclasses import dataclass
77
from datetime import datetime, timezone
88
from enum import Enum
9-
from typing import Any, Optional, Sequence, TypeVar, Union
9+
from typing import Any, List, Optional, Sequence, TypeVar, Union
1010

1111
import grpc
12-
from google.protobuf import wrappers_pb2
1312

1413
from durabletask.entities import EntityInstanceId
1514
from durabletask.entities.entity_metadata import EntityMetadata
@@ -57,6 +56,39 @@ def raise_if_failed(self):
5756
self.failure_details)
5857

5958

59+
@dataclass
60+
class OrchestrationQuery:
61+
created_time_from: Optional[datetime] = None
62+
created_time_to: Optional[datetime] = None
63+
runtime_status: Optional[List[OrchestrationStatus]] = None
64+
# Some backends don't respond well with max_instance_count = None, so we use the integer limit for non-paginated
65+
# results instead.
66+
max_instance_count: Optional[int] = (1 << 31) - 1
67+
fetch_inputs_and_outputs: bool = False
68+
69+
70+
@dataclass
71+
class EntityQuery:
72+
instance_id_starts_with: Optional[str] = None
73+
last_modified_from: Optional[datetime] = None
74+
last_modified_to: Optional[datetime] = None
75+
include_state: bool = True
76+
include_transient: bool = False
77+
page_size: Optional[int] = None
78+
79+
80+
@dataclass
81+
class PurgeInstancesResult:
82+
deleted_instance_count: int
83+
is_complete: bool
84+
85+
86+
@dataclass
87+
class CleanEntityStorageResult:
88+
empty_entities_removed: int
89+
orphaned_locks_released: int
90+
91+
6092
class OrchestrationFailedError(Exception):
6193
def __init__(self, message: str, failure_details: task.FailureDetails):
6294
super().__init__(message)
@@ -73,6 +105,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
73105

74106
state = res.orchestrationState
75107

108+
new_state = parse_orchestration_state(state)
109+
new_state.instance_id = instance_id # Override instance_id with the one from the request, to match old behavior
110+
return new_state
111+
112+
113+
def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationState:
76114
failure_details = None
77115
if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '':
78116
failure_details = task.FailureDetails(
@@ -81,7 +119,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
81119
state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None)
82120

83121
return OrchestrationState(
84-
instance_id,
122+
state.instanceId,
85123
state.name,
86124
OrchestrationStatus(state.orchestrationStatus),
87125
state.createdTimestamp.ToDatetime(),
@@ -93,7 +131,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
93131

94132

95133
class TaskHubGrpcClient:
96-
97134
def __init__(self, *,
98135
host_address: Optional[str] = None,
99136
metadata: Optional[list[tuple[str, str]]] = None,
@@ -136,7 +173,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
136173
req = pb.CreateInstanceRequest(
137174
name=name,
138175
instanceId=instance_id if instance_id else uuid.uuid4().hex,
139-
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
176+
input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
140177
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
141178
version=helpers.get_string_value(version if version else self.default_version),
142179
orchestrationIdReusePolicy=reuse_id_policy,
@@ -152,6 +189,42 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
152189
res: pb.GetInstanceResponse = self._stub.GetInstance(req)
153190
return new_orchestration_state(req.instanceId, res)
154191

192+
def get_all_orchestration_states(self,
193+
orchestration_query: Optional[OrchestrationQuery] = None
194+
) -> List[OrchestrationState]:
195+
if orchestration_query is None:
196+
orchestration_query = OrchestrationQuery()
197+
_continuation_token = None
198+
199+
self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")
200+
201+
states = []
202+
203+
while True:
204+
req = pb.QueryInstancesRequest(
205+
query=pb.InstanceQuery(
206+
runtimeStatus=[status.value for status in orchestration_query.runtime_status] if orchestration_query.runtime_status else None,
207+
createdTimeFrom=helpers.new_timestamp(orchestration_query.created_time_from) if orchestration_query.created_time_from else None,
208+
createdTimeTo=helpers.new_timestamp(orchestration_query.created_time_to) if orchestration_query.created_time_to else None,
209+
maxInstanceCount=orchestration_query.max_instance_count,
210+
fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs,
211+
continuationToken=_continuation_token
212+
)
213+
)
214+
resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
215+
states += [parse_orchestration_state(res) for res in resp.orchestrationState]
216+
# Check the value for continuationToken - none or "0" indicates that there are no more results.
217+
if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
218+
self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next list of instances...")
219+
if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
220+
self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
221+
break
222+
_continuation_token = resp.continuationToken
223+
else:
224+
break
225+
226+
return states
227+
155228
def wait_for_orchestration_start(self, instance_id: str, *,
156229
fetch_payloads: bool = False,
157230
timeout: int = 60) -> Optional[OrchestrationState]:
@@ -199,7 +272,8 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
199272
req = pb.RaiseEventRequest(
200273
instanceId=instance_id,
201274
name=event_name,
202-
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)
275+
input=helpers.get_string_value(shared.to_json(data) if data is not None else None)
276+
)
203277

204278
self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
205279
self._stub.RaiseEvent(req)
@@ -209,7 +283,7 @@ def terminate_orchestration(self, instance_id: str, *,
209283
recursive: bool = True):
210284
req = pb.TerminateRequest(
211285
instanceId=instance_id,
212-
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
286+
output=helpers.get_string_value(shared.to_json(output) if output is not None else None),
213287
recursive=recursive)
214288

215289
self._logger.info(f"Terminating instance '{instance_id}'.")
@@ -245,10 +319,31 @@ def restart_orchestration(self, instance_id: str, *,
245319
res: pb.RestartInstanceResponse = self._stub.RestartInstance(req)
246320
return res.instanceId
247321

248-
def purge_orchestration(self, instance_id: str, recursive: bool = True):
322+
def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
249323
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
250324
self._logger.info(f"Purging instance '{instance_id}'.")
251-
self._stub.PurgeInstances(req)
325+
resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
326+
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
327+
328+
def purge_orchestrations_by(self,
329+
created_time_from: Optional[datetime] = None,
330+
created_time_to: Optional[datetime] = None,
331+
runtime_status: Optional[List[OrchestrationStatus]] = None,
332+
recursive: bool = False) -> PurgeInstancesResult:
333+
self._logger.info("Purging orchestrations by filter: "
334+
f"created_time_from={created_time_from}, "
335+
f"created_time_to={created_time_to}, "
336+
f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
337+
f"recursive={recursive}")
338+
resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(pb.PurgeInstancesRequest(
339+
purgeInstanceFilter=pb.PurgeInstanceFilter(
340+
createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
341+
createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
342+
runtimeStatus=[status.value for status in runtime_status] if runtime_status else None
343+
),
344+
recursive=recursive
345+
))
346+
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
252347

253348
def signal_entity(self,
254349
entity_instance_id: EntityInstanceId,
@@ -257,7 +352,7 @@ def signal_entity(self,
257352
req = pb.SignalEntityRequest(
258353
instanceId=str(entity_instance_id),
259354
name=operation_name,
260-
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
355+
input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
261356
requestId=str(uuid.uuid4()),
262357
scheduledTime=None,
263358
parentTraceContext=None,
@@ -276,4 +371,69 @@ def get_entity(self,
276371
if not res.exists:
277372
return None
278373

279-
return EntityMetadata.from_entity_response(res, include_state)
374+
return EntityMetadata.from_entity_metadata(res.entity, include_state)
375+
376+
def get_all_entities(self,
377+
entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
378+
if entity_query is None:
379+
entity_query = EntityQuery()
380+
_continuation_token = None
381+
382+
self._logger.info(f"Retrieving entities by filter: {entity_query}")
383+
384+
entities = []
385+
386+
while True:
387+
query_request = pb.QueryEntitiesRequest(
388+
query=pb.EntityQuery(
389+
instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with),
390+
lastModifiedFrom=helpers.new_timestamp(entity_query.last_modified_from) if entity_query.last_modified_from else None,
391+
lastModifiedTo=helpers.new_timestamp(entity_query.last_modified_to) if entity_query.last_modified_to else None,
392+
includeState=entity_query.include_state,
393+
includeTransient=entity_query.include_transient,
394+
pageSize=helpers.get_int_value(entity_query.page_size),
395+
continuationToken=_continuation_token
396+
)
397+
)
398+
resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
399+
entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
400+
if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
401+
self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next page of entities...")
402+
if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
403+
self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
404+
break
405+
_continuation_token = resp.continuationToken
406+
else:
407+
break
408+
return entities
409+
410+
def clean_entity_storage(self,
411+
remove_empty_entities: bool = True,
412+
release_orphaned_locks: bool = True
413+
) -> CleanEntityStorageResult:
414+
self._logger.info("Cleaning entity storage")
415+
416+
empty_entities_removed = 0
417+
orphaned_locks_released = 0
418+
_continuation_token = None
419+
420+
while True:
421+
req = pb.CleanEntityStorageRequest(
422+
removeEmptyEntities=remove_empty_entities,
423+
releaseOrphanedLocks=release_orphaned_locks,
424+
continuationToken=_continuation_token
425+
)
426+
resp: pb.CleanEntityStorageResponse = self._stub.CleanEntityStorage(req)
427+
empty_entities_removed += resp.emptyEntitiesRemoved
428+
orphaned_locks_released += resp.orphanedLocksReleased
429+
430+
if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
431+
self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, cleaning next page...")
432+
if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
433+
self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
434+
break
435+
_continuation_token = resp.continuationToken
436+
else:
437+
break
438+
439+
return CleanEntityStorageResult(empty_entities_removed, orphaned_locks_released)

durabletask/entities/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from durabletask.entities.entity_lock import EntityLock
99
from durabletask.entities.entity_context import EntityContext
1010
from durabletask.entities.entity_metadata import EntityMetadata
11+
from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException
1112

12-
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata"]
13+
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata",
14+
"EntityOperationFailedException"]
1315

1416
PACKAGE_NAME = "durabletask.entities"

durabletask/entities/entity_instance_id.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
class EntityInstanceId:
22
def __init__(self, entity: str, key: str):
3-
self.entity = entity
3+
EntityInstanceId.validate_entity_name(entity)
4+
EntityInstanceId.validate_key(key)
5+
self.entity = entity.lower()
46
self.key = key
57

68
def __str__(self) -> str:
@@ -35,8 +37,48 @@ def parse(entity_id: str) -> "EntityInstanceId":
3537
ValueError
3638
If the input string is not in the correct format.
3739
"""
40+
if not entity_id.startswith("@"):
41+
raise ValueError("Entity ID must start with '@'.")
3842
try:
3943
_, entity, key = entity_id.split("@", 2)
40-
return EntityInstanceId(entity=entity, key=key)
4144
except ValueError as ex:
42-
raise ValueError(f"Invalid entity ID format: {entity_id}", ex)
45+
raise ValueError(f"Invalid entity ID format: {entity_id}") from ex
46+
return EntityInstanceId(entity=entity, key=key)
47+
48+
@staticmethod
49+
def validate_entity_name(name: str) -> None:
50+
"""Validate that the entity name does not contain invalid characters.
51+
52+
Parameters
53+
----------
54+
name : str
55+
The entity name to validate.
56+
57+
Raises
58+
------
59+
ValueError
60+
If the name is not a valid entity name.
61+
"""
62+
if not name:
63+
raise ValueError("Entity name cannot be empty.")
64+
if "@" in name:
65+
raise ValueError("Entity name cannot contain '@' symbol.")
66+
67+
@staticmethod
68+
def validate_key(key: str) -> None:
69+
"""Validate that the entity key does not contain invalid characters.
70+
71+
Parameters
72+
----------
73+
key : str
74+
The entity key to validate.
75+
76+
Raises
77+
------
78+
ValueError
79+
If the key is not a valid entity key.
80+
"""
81+
if not key:
82+
raise ValueError("Entity key cannot be empty.")
83+
if "@" in key:
84+
raise ValueError("Entity key cannot contain '@' symbol.")

durabletask/entities/entity_metadata.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,22 @@ def __init__(self,
4444

4545
@staticmethod
4646
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
47+
return EntityMetadata.from_entity_metadata(entity_response.entity, includes_state)
48+
49+
@staticmethod
50+
def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool):
4751
try:
48-
entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
52+
entity_id = EntityInstanceId.parse(entity.instanceId)
4953
except ValueError:
5054
raise ValueError("Invalid entity instance ID in entity response.")
5155
entity_state = None
5256
if includes_state:
53-
entity_state = entity_response.entity.serializedState.value
57+
entity_state = entity.serializedState.value
5458
return EntityMetadata(
5559
id=entity_id,
56-
last_modified=entity_response.entity.lastModifiedTime.ToDatetime(timezone.utc),
57-
backlog_queue_size=entity_response.entity.backlogQueueSize,
58-
locked_by=entity_response.entity.lockedBy.value,
60+
last_modified=entity.lastModifiedTime.ToDatetime(timezone.utc),
61+
backlog_queue_size=entity.backlogQueueSize,
62+
locked_by=entity.lockedBy.value,
5963
includes_state=includes_state,
6064
state=entity_state
6165
)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from durabletask.internal.orchestrator_service_pb2 import TaskFailureDetails
2+
from durabletask.entities.entity_instance_id import EntityInstanceId
3+
4+
5+
class EntityOperationFailedException(Exception):
6+
"""Exception raised when an operation on an Entity Function fails."""
7+
8+
def __init__(self, entity_instance_id: EntityInstanceId, operation_name: str, failure_details: TaskFailureDetails) -> None:
9+
super().__init__()
10+
self.entity_instance_id = entity_instance_id
11+
self.operation_name = operation_name
12+
self.failure_details = failure_details
13+
14+
def __str__(self) -> str:
15+
return f"Operation '{self.operation_name}' on entity '{self.entity_instance_id}' failed with error: {self.failure_details.errorMessage}"

durabletask/internal/helpers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]:
184184
return wrappers_pb2.StringValue(value=val)
185185

186186

187+
def get_int_value(val: Optional[int]) -> Optional[wrappers_pb2.Int32Value]:
188+
if val is None:
189+
return None
190+
else:
191+
return wrappers_pb2.Int32Value(value=val)
192+
193+
187194
def get_string_value_or_empty(val: Optional[str]) -> wrappers_pb2.StringValue:
188195
if val is None:
189196
return wrappers_pb2.StringValue(value="")

0 commit comments

Comments
 (0)