Skip to content

Commit 3a3c0c4

Browse files
authored
Address various entity-related bugs (#109)
* Address various entity-related bugs
1 parent 0e09bb8 commit 3a3c0c4

File tree

11 files changed

+377
-23
lines changed

11 files changed

+377
-23
lines changed

durabletask/entities/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from durabletask.entities.entity_lock import EntityLock
99
from durabletask.entities.entity_context import EntityContext
1010
from durabletask.entities.entity_metadata import EntityMetadata
11+
from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException
1112

12-
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata"]
13+
__all__ = ["EntityInstanceId", "DurableEntity", "EntityLock", "EntityContext", "EntityMetadata",
14+
"EntityOperationFailedException"]
1315

1416
PACKAGE_NAME = "durabletask.entities"

durabletask/entities/entity_instance_id.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
class EntityInstanceId:
22
def __init__(self, entity: str, key: str):
3-
self.entity = entity
3+
EntityInstanceId.validate_entity_name(entity)
4+
EntityInstanceId.validate_key(key)
5+
self.entity = entity.lower()
46
self.key = key
57

68
def __str__(self) -> str:
@@ -35,8 +37,48 @@ def parse(entity_id: str) -> "EntityInstanceId":
3537
ValueError
3638
If the input string is not in the correct format.
3739
"""
40+
if not entity_id.startswith("@"):
41+
raise ValueError("Entity ID must start with '@'.")
3842
try:
3943
_, entity, key = entity_id.split("@", 2)
40-
return EntityInstanceId(entity=entity, key=key)
4144
except ValueError as ex:
42-
raise ValueError(f"Invalid entity ID format: {entity_id}", ex)
45+
raise ValueError(f"Invalid entity ID format: {entity_id}") from ex
46+
return EntityInstanceId(entity=entity, key=key)
47+
48+
@staticmethod
49+
def validate_entity_name(name: str) -> None:
50+
"""Validate that the entity name does not contain invalid characters.
51+
52+
Parameters
53+
----------
54+
name : str
55+
The entity name to validate.
56+
57+
Raises
58+
------
59+
ValueError
60+
If the name is not a valid entity name.
61+
"""
62+
if not name:
63+
raise ValueError("Entity name cannot be empty.")
64+
if "@" in name:
65+
raise ValueError("Entity name cannot contain '@' symbol.")
66+
67+
@staticmethod
68+
def validate_key(key: str) -> None:
69+
"""Validate that the entity key does not contain invalid characters.
70+
71+
Parameters
72+
----------
73+
key : str
74+
The entity key to validate.
75+
76+
Raises
77+
------
78+
ValueError
79+
If the key is not a valid entity key.
80+
"""
81+
if not key:
82+
raise ValueError("Entity key cannot be empty.")
83+
if "@" in key:
84+
raise ValueError("Entity key cannot contain '@' symbol.")
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from durabletask.internal.orchestrator_service_pb2 import TaskFailureDetails
2+
from durabletask.entities.entity_instance_id import EntityInstanceId
3+
4+
5+
class EntityOperationFailedException(Exception):
6+
"""Exception raised when an operation on an Entity Function fails."""
7+
8+
def __init__(self, entity_instance_id: EntityInstanceId, operation_name: str, failure_details: TaskFailureDetails) -> None:
9+
super().__init__()
10+
self.entity_instance_id = entity_instance_id
11+
self.operation_name = operation_name
12+
self.failure_details = failure_details
13+
14+
def __str__(self) -> str:
15+
return f"Operation '{self.operation_name}' on entity '{self.entity_instance_id}' failed with error: {self.failure_details.errorMessage}"
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Any
2+
3+
4+
class JsonEncodeOutputException(Exception):
5+
"""Custom exception type used to indicate that an orchestration result could not be JSON-encoded."""
6+
7+
def __init__(self, problem_object: Any):
8+
super().__init__()
9+
self.problem_object = problem_object
10+
11+
def __str__(self) -> str:
12+
return f"The orchestration result could not be encoded. Object details: {self.problem_object}"

durabletask/worker.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
import grpc
2020
from google.protobuf import empty_pb2
2121

22+
from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException
2223
from durabletask.internal import helpers
2324
from durabletask.internal.entity_state_shim import StateShim
2425
from durabletask.internal.helpers import new_timestamp
2526
from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext
27+
from durabletask.internal.json_encode_output_exception import JsonEncodeOutputException
2628
from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext
2729
from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub
2830
import durabletask.internal.helpers as ph
@@ -141,14 +143,12 @@ class _Registry:
141143
orchestrators: dict[str, task.Orchestrator]
142144
activities: dict[str, task.Activity]
143145
entities: dict[str, task.Entity]
144-
entity_instances: dict[str, DurableEntity]
145146
versioning: Optional[VersioningOptions] = None
146147

147148
def __init__(self):
148149
self.orchestrators = {}
149150
self.activities = {}
150151
self.entities = {}
151-
self.entity_instances = {}
152152

153153
def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
154154
if fn is None:
@@ -199,8 +199,8 @@ def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
199199
return name
200200

201201
def add_named_entity(self, name: str, fn: task.Entity) -> None:
202-
if not name:
203-
raise ValueError("A non-empty entity name is required.")
202+
name = name.lower()
203+
EntityInstanceId.validate_entity_name(name)
204204
if name in self.entities:
205205
raise ValueError(f"A '{name}' entity already exists.")
206206

@@ -829,7 +829,7 @@ def __init__(self, instance_id: str, registry: _Registry):
829829
self._pending_actions: dict[int, pb.OrchestratorAction] = {}
830830
self._pending_tasks: dict[int, task.CompletableTask] = {}
831831
# Maps entity ID to task ID
832-
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
832+
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, str, int]] = {}
833833
self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
834834
# Maps criticalSectionId to task ID
835835
self._entity_lock_id_map: dict[str, int] = {}
@@ -902,7 +902,13 @@ def set_complete(
902902
self._result = result
903903
result_json: Optional[str] = None
904904
if result is not None:
905-
result_json = result if is_result_encoded else shared.to_json(result)
905+
try:
906+
result_json = result if is_result_encoded else shared.to_json(result)
907+
except (ValueError, TypeError):
908+
self._is_complete = False
909+
self._result = None
910+
self.set_failed(JsonEncodeOutputException(result))
911+
return
906912
action = ph.new_complete_orchestration_action(
907913
self.next_sequence_number(), status, result_json
908914
)
@@ -1606,7 +1612,7 @@ def process_event(
16061612
raise TypeError("Unexpected sub-orchestration task type")
16071613
elif event.HasField("eventRaised"):
16081614
if event.eventRaised.name in ctx._entity_task_id_map:
1609-
entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
1615+
entity_id, operation, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None, None))
16101616
self._handle_entity_event_raised(ctx, event, entity_id, task_id, False)
16111617
elif event.eventRaised.name in ctx._entity_lock_task_id_map:
16121618
entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None))
@@ -1680,9 +1686,10 @@ def process_event(
16801686
)
16811687
try:
16821688
entity_id = EntityInstanceId.parse(event.entityOperationCalled.targetInstanceId.value)
1689+
operation = event.entityOperationCalled.operation
16831690
except ValueError:
16841691
raise RuntimeError(f"Could not parse entity ID from targetInstanceId '{event.entityOperationCalled.targetInstanceId.value}'")
1685-
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, entity_call_id)
1692+
ctx._entity_task_id_map[event.entityOperationCalled.requestId] = (entity_id, operation, entity_call_id)
16861693
elif event.HasField("entityOperationSignaled"):
16871694
# This history event confirms that the entity signal was successfully scheduled.
16881695
# Remove the entityOperationSignaled event from the pending action list so we don't schedule it
@@ -1743,7 +1750,7 @@ def process_event(
17431750
ctx.resume()
17441751
elif event.HasField("entityOperationCompleted"):
17451752
request_id = event.entityOperationCompleted.requestId
1746-
entity_id, task_id = ctx._entity_task_id_map.pop(request_id, (None, None))
1753+
entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None))
17471754
if not entity_id:
17481755
raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
17491756
if not task_id:
@@ -1762,10 +1769,29 @@ def process_event(
17621769
entity_task.complete(result)
17631770
ctx.resume()
17641771
elif event.HasField("entityOperationFailed"):
1765-
if not ctx.is_replaying:
1766-
self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
1767-
self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
1768-
pass
1772+
request_id = event.entityOperationFailed.requestId
1773+
entity_id, operation, task_id = ctx._entity_task_id_map.pop(request_id, (None, None, None))
1774+
if not entity_id:
1775+
raise RuntimeError(f"Could not parse entity ID from request ID '{request_id}'")
1776+
if operation is None:
1777+
raise RuntimeError(f"Could not parse operation name from request ID '{request_id}'")
1778+
if not task_id:
1779+
raise RuntimeError(f"Could not find matching task ID for entity operation with request ID '{request_id}'")
1780+
entity_task = ctx._pending_tasks.pop(task_id, None)
1781+
if not entity_task:
1782+
if not ctx.is_replaying:
1783+
self._logger.warning(
1784+
f"{ctx.instance_id}: Ignoring unexpected entityOperationFailed event with request ID = {request_id}."
1785+
)
1786+
return
1787+
failure = EntityOperationFailedException(
1788+
entity_id,
1789+
operation,
1790+
event.entityOperationFailed.failureDetails
1791+
)
1792+
ctx._entity_context.recover_lock_after_call(entity_id)
1793+
entity_task.fail(str(failure), failure)
1794+
ctx.resume()
17691795
elif event.HasField("orchestratorCompleted"):
17701796
# Added in Functions only (for some reason) and does not affect orchestrator flow
17711797
pass
@@ -1777,7 +1803,7 @@ def process_event(
17771803
if action and action.HasField("sendEntityMessage"):
17781804
if action.sendEntityMessage.HasField("entityOperationCalled"):
17791805
entity_id, event_id = self._parse_entity_event_sent_input(event)
1780-
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
1806+
ctx._entity_task_id_map[event_id] = (entity_id, action.sendEntityMessage.entityOperationCalled.operation, event.eventId)
17811807
elif action.sendEntityMessage.HasField("entityLockRequested"):
17821808
entity_id, event_id = self._parse_entity_event_sent_input(event)
17831809
ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId)
@@ -1937,11 +1963,7 @@ def execute(
19371963
ctx = EntityContext(orchestration_id, operation, state, entity_id)
19381964

19391965
if isinstance(fn, type) and issubclass(fn, DurableEntity):
1940-
if self._registry.entity_instances.get(str(entity_id), None):
1941-
entity_instance = self._registry.entity_instances[str(entity_id)]
1942-
else:
1943-
entity_instance = fn()
1944-
self._registry.entity_instances[str(entity_id)] = entity_instance
1966+
entity_instance = fn()
19451967
if not hasattr(entity_instance, operation):
19461968
raise AttributeError(f"Entity '{entity_id}' does not have operation '{operation}'")
19471969
method = getattr(entity_instance, operation)

tests/durabletask-azuremanaged/entities/__init__.py

Whitespace-only changes.

tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py renamed to tests/durabletask-azuremanaged/entities/test_dts_class_based_entities_e2e.py

File renamed without changes.

0 commit comments

Comments
 (0)