Skip to content

Commit 4494708

Browse files
committed
Initial implementation
1 parent 3a3c0c4 commit 4494708

File tree

5 files changed

+245
-17
lines changed

5 files changed

+245
-17
lines changed

durabletask/client.py

Lines changed: 137 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import logging
55
import uuid
66
from dataclasses import dataclass
7-
from datetime import datetime, timezone
7+
from datetime import datetime, timedelta, timezone
88
from enum import Enum
9-
from typing import Any, Optional, Sequence, TypeVar, Union
9+
from typing import Any, List, Optional, Sequence, TypeVar, Union
1010

1111
import grpc
12-
from google.protobuf import wrappers_pb2
12+
from google.protobuf import wrappers_pb2 as pb2
1313

1414
from durabletask.entities import EntityInstanceId
1515
from durabletask.entities.entity_metadata import EntityMetadata
@@ -57,6 +57,12 @@ def raise_if_failed(self):
5757
self.failure_details)
5858

5959

60+
class PurgeInstancesResult:
61+
def __init__(self, deleted_instance_count: int, is_complete: bool):
62+
self.deleted_instance_count = deleted_instance_count
63+
self.is_complete = is_complete
64+
65+
6066
class OrchestrationFailedError(Exception):
6167
def __init__(self, message: str, failure_details: task.FailureDetails):
6268
super().__init__(message)
@@ -73,6 +79,12 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
7379

7480
state = res.orchestrationState
7581

82+
new_state = parse_orchestration_state(state)
83+
new_state.instance_id = instance_id # Override instance_id with the one from the request, to match old behavior
84+
return new_state
85+
86+
87+
def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationState:
7688
failure_details = None
7789
if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '':
7890
failure_details = task.FailureDetails(
@@ -81,7 +93,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
8193
state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None)
8294

8395
return OrchestrationState(
84-
instance_id,
96+
state.instanceId,
8597
state.name,
8698
OrchestrationStatus(state.orchestrationStatus),
8799
state.createdTimestamp.ToDatetime(),
@@ -93,7 +105,6 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
93105

94106

95107
class TaskHubGrpcClient:
96-
97108
def __init__(self, *,
98109
host_address: Optional[str] = None,
99110
metadata: Optional[list[tuple[str, str]]] = None,
@@ -136,7 +147,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
136147
req = pb.CreateInstanceRequest(
137148
name=name,
138149
instanceId=instance_id if instance_id else uuid.uuid4().hex,
139-
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
150+
input=helpers.get_string_value(shared.to_json(input)),
140151
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
141152
version=helpers.get_string_value(version if version else self.default_version),
142153
orchestrationIdReusePolicy=reuse_id_policy,
@@ -152,6 +163,54 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
152163
res: pb.GetInstanceResponse = self._stub.GetInstance(req)
153164
return new_orchestration_state(req.instanceId, res)
154165

166+
def get_all_orchestration_states(self,
167+
max_instance_count: Optional[int] = None,
168+
fetch_inputs_and_outputs: bool = False) -> List[OrchestrationState]:
169+
return self.get_orchestration_state_by(
170+
created_time_from=None,
171+
created_time_to=None,
172+
runtime_status=None,
173+
max_instance_count=max_instance_count,
174+
fetch_inputs_and_outputs=fetch_inputs_and_outputs
175+
)
176+
177+
def get_orchestration_state_by(self,
178+
created_time_from: Optional[datetime] = None,
179+
created_time_to: Optional[datetime] = None,
180+
runtime_status: Optional[List[OrchestrationStatus]] = None,
181+
max_instance_count: Optional[int] = None,
182+
fetch_inputs_and_outputs: bool = False,
183+
_continuation_token: Optional[pb2.StringValue] = None
184+
) -> List[OrchestrationState]:
185+
if max_instance_count is None:
186+
# DTS backend does not behave well with max_instance_count = None, so we set to max 32-bit signed value
187+
max_instance_count = (1 << 31) - 1
188+
req = pb.QueryInstancesRequest(
189+
query=pb.InstanceQuery(
190+
runtimeStatus=[status.value for status in runtime_status] if runtime_status else None,
191+
createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
192+
createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
193+
maxInstanceCount=max_instance_count,
194+
fetchInputsAndOutputs=fetch_inputs_and_outputs,
195+
continuationToken=_continuation_token
196+
)
197+
)
198+
resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
199+
states = [parse_orchestration_state(res) for res in resp.orchestrationState]
200+
# Check the value for continuationToken - none or "0" indicates that there are no more results.
201+
if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
202+
self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next list of instances...")
203+
states += self.get_orchestration_state_by(
204+
created_time_from,
205+
created_time_to,
206+
runtime_status,
207+
max_instance_count,
208+
fetch_inputs_and_outputs,
209+
_continuation_token=resp.continuationToken
210+
)
211+
states = [state for state in states if state is not None] # Filter out any None values
212+
return states
213+
155214
def wait_for_orchestration_start(self, instance_id: str, *,
156215
fetch_payloads: bool = False,
157216
timeout: int = 60) -> Optional[OrchestrationState]:
@@ -199,7 +258,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
199258
req = pb.RaiseEventRequest(
200259
instanceId=instance_id,
201260
name=event_name,
202-
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)
261+
input=helpers.get_string_value(shared.to_json(data)))
203262

204263
self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
205264
self._stub.RaiseEvent(req)
@@ -209,7 +268,7 @@ def terminate_orchestration(self, instance_id: str, *,
209268
recursive: bool = True):
210269
req = pb.TerminateRequest(
211270
instanceId=instance_id,
212-
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
271+
output=helpers.get_string_value(shared.to_json(output)),
213272
recursive=recursive)
214273

215274
self._logger.info(f"Terminating instance '{instance_id}'.")
@@ -225,10 +284,27 @@ def resume_orchestration(self, instance_id: str):
225284
self._logger.info(f"Resuming instance '{instance_id}'.")
226285
self._stub.ResumeInstance(req)
227286

228-
def purge_orchestration(self, instance_id: str, recursive: bool = True):
287+
def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
229288
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
230289
self._logger.info(f"Purging instance '{instance_id}'.")
231-
self._stub.PurgeInstances(req)
290+
resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
291+
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
292+
293+
def purge_orchestrations_by(self,
294+
created_time_from: Optional[datetime] = None,
295+
created_time_to: Optional[datetime] = None,
296+
runtime_status: Optional[List[OrchestrationStatus]] = None,
297+
recursive: bool = False) -> PurgeInstancesResult:
298+
resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(pb.PurgeInstancesRequest(
299+
instanceId=None,
300+
purgeInstanceFilter=pb.PurgeInstanceFilter(
301+
createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
302+
createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
303+
runtimeStatus=[status.value for status in runtime_status] if runtime_status else None
304+
),
305+
recursive=recursive
306+
))
307+
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
232308

233309
def signal_entity(self,
234310
entity_instance_id: EntityInstanceId,
@@ -237,7 +313,7 @@ def signal_entity(self,
237313
req = pb.SignalEntityRequest(
238314
instanceId=str(entity_instance_id),
239315
name=operation_name,
240-
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
316+
input=helpers.get_string_value(shared.to_json(input)),
241317
requestId=str(uuid.uuid4()),
242318
scheduledTime=None,
243319
parentTraceContext=None,
@@ -256,4 +332,53 @@ def get_entity(self,
256332
if not res.exists:
257333
return None
258334

259-
return EntityMetadata.from_entity_response(res, include_state)
335+
return EntityMetadata.from_entity_metadata(res.entity, include_state)
336+
337+
def get_all_entities(self,
338+
include_state: bool = True,
339+
include_transient: bool = False,
340+
page_size: Optional[int] = None) -> List[EntityMetadata]:
341+
return self.get_entities_by(
342+
instance_id_starts_with=None,
343+
last_modified_from=None,
344+
last_modified_to=None,
345+
include_state=include_state,
346+
include_transient=include_transient,
347+
page_size=page_size
348+
)
349+
350+
def get_entities_by(self,
351+
instance_id_starts_with: Optional[str] = None,
352+
last_modified_from: Optional[datetime] = None,
353+
last_modified_to: Optional[datetime] = None,
354+
include_state: bool = True,
355+
include_transient: bool = False,
356+
page_size: Optional[int] = None,
357+
_continuation_token: Optional[pb2.StringValue] = None
358+
) -> List[EntityMetadata]:
359+
self._logger.info(f"Getting entities")
360+
query_request = pb.QueryEntitiesRequest(
361+
query=pb.EntityQuery(
362+
instanceIdStartsWith=helpers.get_string_value(instance_id_starts_with),
363+
lastModifiedFrom=helpers.new_timestamp(last_modified_from) if last_modified_from else None,
364+
lastModifiedTo=helpers.new_timestamp(last_modified_to) if last_modified_to else None,
365+
includeState=include_state,
366+
includeTransient=include_transient,
367+
pageSize=helpers.get_int_value(page_size),
368+
continuationToken=_continuation_token
369+
)
370+
)
371+
resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
372+
entities = [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
373+
if resp.continuationToken and resp.continuationToken.value != "0":
374+
self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next page of entities...")
375+
entities += self.get_entities_by(
376+
instance_id_starts_with=instance_id_starts_with,
377+
last_modified_from=last_modified_from,
378+
last_modified_to=last_modified_to,
379+
include_state=include_state,
380+
include_transient=include_transient,
381+
page_size=page_size,
382+
_continuation_token=resp.continuationToken
383+
)
384+
return entities

durabletask/entities/entity_metadata.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from datetime import datetime, timezone
22
from typing import Any, Optional, Type, TypeVar, Union, overload
3+
from warnings import deprecated
34
from durabletask.entities.entity_instance_id import EntityInstanceId
45

56
import durabletask.internal.orchestrator_service_pb2 as pb
@@ -42,20 +43,25 @@ def __init__(self,
4243
self.includes_state = includes_state
4344
self._state = state
4445

46+
@deprecated("This method is deprecated. Use 'from_entity_metadata' instead.")
4547
@staticmethod
4648
def from_entity_response(entity_response: pb.GetEntityResponse, includes_state: bool):
49+
return EntityMetadata.from_entity_metadata(entity_response.entity, includes_state)
50+
51+
@staticmethod
52+
def from_entity_metadata(entity: pb.EntityMetadata, includes_state: bool):
4753
try:
48-
entity_id = EntityInstanceId.parse(entity_response.entity.instanceId)
54+
entity_id = EntityInstanceId.parse(entity.instanceId)
4955
except ValueError:
5056
raise ValueError("Invalid entity instance ID in entity response.")
5157
entity_state = None
5258
if includes_state:
53-
entity_state = entity_response.entity.serializedState.value
59+
entity_state = entity.serializedState.value
5460
return EntityMetadata(
5561
id=entity_id,
56-
last_modified=entity_response.entity.lastModifiedTime.ToDatetime(timezone.utc),
57-
backlog_queue_size=entity_response.entity.backlogQueueSize,
58-
locked_by=entity_response.entity.lockedBy.value,
62+
last_modified=entity.lastModifiedTime.ToDatetime(timezone.utc),
63+
backlog_queue_size=entity.backlogQueueSize,
64+
locked_by=entity.lockedBy.value,
5965
includes_state=includes_state,
6066
state=entity_state
6167
)

durabletask/internal/helpers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]:
184184
return wrappers_pb2.StringValue(value=val)
185185

186186

187+
def get_int_value(val: Optional[int]) -> Optional[wrappers_pb2.Int32Value]:
188+
if val is None:
189+
return None
190+
else:
191+
return wrappers_pb2.Int32Value(value=val)
192+
193+
187194
def get_string_value_or_empty(val: Optional[str]) -> wrappers_pb2.StringValue:
188195
if val is None:
189196
return wrappers_pb2.StringValue(value="")
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
import os
3+
4+
import pytest
5+
from durabletask import client, task
6+
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
7+
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
8+
9+
# Read the environment variables
10+
taskhub_name = os.getenv("TASKHUB", "default")
11+
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
12+
13+
pytestmark = pytest.mark.dts
14+
15+
16+
def empty_orchestrator(ctx: task.OrchestrationContext, _):
17+
return "Complete"
18+
19+
20+
def test_get_all_orchestration_states():
21+
# Start a worker, which will connect to the sidecar in a background thread
22+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
23+
taskhub=taskhub_name, token_credential=None) as w:
24+
w.add_orchestrator(empty_orchestrator)
25+
w.start()
26+
27+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
28+
taskhub=taskhub_name, token_credential=None)
29+
id = c.schedule_new_orchestration(empty_orchestrator, input="Hello")
30+
c.wait_for_orchestration_completion(id, timeout=30)
31+
32+
all_orchestrations = c.get_all_orchestration_states()
33+
all_orchestrations_with_state = c.get_all_orchestration_states(fetch_inputs_and_outputs=True)
34+
this_orch = c.get_orchestration_state(id)
35+
36+
assert this_orch is not None
37+
assert this_orch.instance_id == id
38+
39+
assert all_orchestrations is not None
40+
assert len(all_orchestrations) > 1
41+
print(f"Received {len(all_orchestrations)} orchestrations")
42+
assert len([o for o in all_orchestrations if o.instance_id == id]) == 1
43+
orchestration_state = [o for o in all_orchestrations if o.instance_id == id][0]
44+
assert orchestration_state.runtime_status == client.OrchestrationStatus.COMPLETED
45+
assert orchestration_state.serialized_input is None
46+
assert orchestration_state.serialized_output is None
47+
assert orchestration_state.failure_details is None
48+
49+
assert all_orchestrations_with_state is not None
50+
assert len(all_orchestrations_with_state) > 1
51+
print(f"Received {len(all_orchestrations_with_state)} orchestrations")
52+
assert len([o for o in all_orchestrations_with_state if o.instance_id == id]) == 1
53+
orchestration_state = [o for o in all_orchestrations_with_state if o.instance_id == id][0]
54+
assert orchestration_state.runtime_status == client.OrchestrationStatus.COMPLETED
55+
assert orchestration_state.serialized_input == '"Hello"'
56+
assert orchestration_state.serialized_output == '"Complete"'
57+
assert orchestration_state.failure_details is None
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
from durabletask import client, task, worker
3+
4+
5+
def empty_orchestrator(ctx: task.OrchestrationContext, _):
6+
return "Complete"
7+
8+
9+
def test_get_all_orchestration_states():
10+
# Start a worker, which will connect to the sidecar in a background thread
11+
with worker.TaskHubGrpcWorker() as w:
12+
w.add_orchestrator(empty_orchestrator)
13+
w.start()
14+
15+
c = client.TaskHubGrpcClient()
16+
id = c.schedule_new_orchestration(empty_orchestrator, input="Hello")
17+
c.wait_for_orchestration_completion(id, timeout=30)
18+
19+
all_orchestrations = c.get_all_orchestration_states()
20+
this_orch = c.get_orchestration_state(id)
21+
22+
assert this_orch is not None
23+
assert this_orch.instance_id == id
24+
25+
assert all_orchestrations is not None
26+
assert len(all_orchestrations) > 1
27+
print(f"Received {len(all_orchestrations)} orchestrations")
28+
assert len([o for o in all_orchestrations if o.instance_id == id]) == 1
29+
orchestration_state = [o for o in all_orchestrations if o.instance_id == id][0]
30+
assert orchestration_state.runtime_status == client.OrchestrationStatus.COMPLETED
31+
assert orchestration_state.serialized_input == '"Hello"'
32+
assert orchestration_state.serialized_output == '"Complete"'
33+
assert orchestration_state.failure_details is None

0 commit comments

Comments
 (0)