Skip to content

Commit 540421c

Browse files
committed
PR Feedback
1 parent dd2548a commit 540421c

File tree

3 files changed

+139
-146
lines changed

3 files changed

+139
-146
lines changed

durabletask/client.py

Lines changed: 45 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,27 @@ def raise_if_failed(self):
5757
self.failure_details)
5858

5959

60+
@dataclass
61+
class OrchestrationQuery:
62+
created_time_from: Optional[datetime] = None
63+
created_time_to: Optional[datetime] = None
64+
runtime_status: Optional[List[OrchestrationStatus]] = None
65+
# Some backends don't respond well with max_instance_count = None, so we use the integer limit for non-paginated
66+
# results instead.
67+
max_instance_count: Optional[int] = (1 << 31) - 1
68+
fetch_inputs_and_outputs: bool = False
69+
70+
71+
@dataclass
72+
class EntityQuery:
73+
instance_id_starts_with: Optional[str] = None
74+
last_modified_from: Optional[datetime] = None
75+
last_modified_to: Optional[datetime] = None
76+
include_state: bool = True
77+
include_transient: bool = False
78+
page_size: Optional[int] = None
79+
80+
6081
@dataclass
6182
class PurgeInstancesResult:
6283
deleted_instance_count: int
@@ -170,46 +191,24 @@ def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = Tr
170191
return new_orchestration_state(req.instanceId, res)
171192

172193
def get_all_orchestration_states(self,
173-
max_instance_count: Optional[int] = None,
174-
fetch_inputs_and_outputs: bool = False) -> List[OrchestrationState]:
175-
return self.get_orchestration_state_by(
176-
created_time_from=None,
177-
created_time_to=None,
178-
runtime_status=None,
179-
max_instance_count=max_instance_count,
180-
fetch_inputs_and_outputs=fetch_inputs_and_outputs
181-
)
194+
orchestration_query: Optional[OrchestrationQuery] = None
195+
) -> List[OrchestrationState]:
196+
if orchestration_query is None:
197+
orchestration_query = OrchestrationQuery()
198+
_continuation_token = None
182199

183-
def get_orchestration_state_by(self,
184-
created_time_from: Optional[datetime] = None,
185-
created_time_to: Optional[datetime] = None,
186-
runtime_status: Optional[List[OrchestrationStatus]] = None,
187-
max_instance_count: Optional[int] = None,
188-
fetch_inputs_and_outputs: bool = False,
189-
_continuation_token: Optional[pb2.StringValue] = None
190-
) -> List[OrchestrationState]:
191-
if max_instance_count is None:
192-
# Some backends do not behave well with max_instance_count = None, so we set to max 32-bit signed value
193-
max_instance_count = (1 << 31) - 1
194-
195-
self._logger.info(f"Querying orchestration instances with filters - "
196-
f"created_time_from={created_time_from}, "
197-
f"created_time_to={created_time_to}, "
198-
f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
199-
f"max_instance_count={max_instance_count}, "
200-
f"fetch_inputs_and_outputs={fetch_inputs_and_outputs}, "
201-
f"continuation_token={_continuation_token.value if _continuation_token else None}")
200+
self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")
202201

203202
states = []
204203

205204
while True:
206205
req = pb.QueryInstancesRequest(
207206
query=pb.InstanceQuery(
208-
runtimeStatus=[status.value for status in runtime_status] if runtime_status else None,
209-
createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
210-
createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
211-
maxInstanceCount=max_instance_count,
212-
fetchInputsAndOutputs=fetch_inputs_and_outputs,
207+
runtimeStatus=[status.value for status in orchestration_query.runtime_status] if orchestration_query.runtime_status else None,
208+
createdTimeFrom=helpers.new_timestamp(orchestration_query.created_time_from) if orchestration_query.created_time_from else None,
209+
createdTimeTo=helpers.new_timestamp(orchestration_query.created_time_to) if orchestration_query.created_time_to else None,
210+
maxInstanceCount=orchestration_query.max_instance_count,
211+
fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs,
213212
continuationToken=_continuation_token
214213
)
215214
)
@@ -318,7 +317,6 @@ def purge_orchestrations_by(self,
318317
f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
319318
f"recursive={recursive}")
320319
resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(pb.PurgeInstancesRequest(
321-
instanceId=None,
322320
purgeInstanceFilter=pb.PurgeInstanceFilter(
323321
createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
324322
createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
@@ -357,46 +355,24 @@ def get_entity(self,
357355
return EntityMetadata.from_entity_metadata(res.entity, include_state)
358356

359357
def get_all_entities(self,
360-
include_state: bool = True,
361-
include_transient: bool = False,
362-
page_size: Optional[int] = None) -> List[EntityMetadata]:
363-
return self.get_entities_by(
364-
instance_id_starts_with=None,
365-
last_modified_from=None,
366-
last_modified_to=None,
367-
include_state=include_state,
368-
include_transient=include_transient,
369-
page_size=page_size
370-
)
358+
entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
359+
if entity_query is None:
360+
entity_query = EntityQuery()
361+
_continuation_token = None
371362

372-
def get_entities_by(self,
373-
instance_id_starts_with: Optional[str] = None,
374-
last_modified_from: Optional[datetime] = None,
375-
last_modified_to: Optional[datetime] = None,
376-
include_state: bool = True,
377-
include_transient: bool = False,
378-
page_size: Optional[int] = None,
379-
_continuation_token: Optional[pb2.StringValue] = None
380-
) -> List[EntityMetadata]:
381-
self._logger.info(f"Retrieving entities by filter: "
382-
f"instance_id_starts_with={instance_id_starts_with}, "
383-
f"last_modified_from={last_modified_from}, "
384-
f"last_modified_to={last_modified_to}, "
385-
f"include_state={include_state}, "
386-
f"include_transient={include_transient}, "
387-
f"page_size={page_size}")
363+
self._logger.info(f"Retrieving entities by filter: {entity_query}")
388364

389365
entities = []
390366

391367
while True:
392368
query_request = pb.QueryEntitiesRequest(
393369
query=pb.EntityQuery(
394-
instanceIdStartsWith=helpers.get_string_value(instance_id_starts_with),
395-
lastModifiedFrom=helpers.new_timestamp(last_modified_from) if last_modified_from else None,
396-
lastModifiedTo=helpers.new_timestamp(last_modified_to) if last_modified_to else None,
397-
includeState=include_state,
398-
includeTransient=include_transient,
399-
pageSize=helpers.get_int_value(page_size),
370+
instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with),
371+
lastModifiedFrom=helpers.new_timestamp(entity_query.last_modified_from) if entity_query.last_modified_from else None,
372+
lastModifiedTo=helpers.new_timestamp(entity_query.last_modified_to) if entity_query.last_modified_to else None,
373+
includeState=entity_query.include_state,
374+
includeTransient=entity_query.include_transient,
375+
pageSize=helpers.get_int_value(entity_query.page_size),
400376
continuationToken=_continuation_token
401377
)
402378
)
@@ -414,13 +390,13 @@ def get_entities_by(self,
414390

415391
def clean_entity_storage(self,
416392
remove_empty_entities: bool = True,
417-
release_orphaned_locks: bool = True,
418-
_continuation_token: Optional[pb2.StringValue] = None
393+
release_orphaned_locks: bool = True
419394
) -> CleanEntityStorageResult:
420395
self._logger.info("Cleaning entity storage")
421396

422397
empty_entities_removed = 0
423398
orphaned_locks_released = 0
399+
_continuation_token = None
424400

425401
while True:
426402
req = pb.CleanEntityStorageRequest(

tests/durabletask-azuremanaged/test_dts_batch_actions.py

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
import asyncio
23
import logging
34
import os
45
import time
@@ -37,7 +38,9 @@ def test_get_all_orchestration_states():
3738
c.wait_for_orchestration_completion(id, timeout=30)
3839

3940
all_orchestrations = c.get_all_orchestration_states()
40-
all_orchestrations_with_state = c.get_all_orchestration_states(fetch_inputs_and_outputs=True)
41+
query = client.OrchestrationQuery()
42+
query.fetch_inputs_and_outputs = True
43+
all_orchestrations_with_state = c.get_all_orchestration_states(query)
4144
this_orch = c.get_orchestration_state(id)
4245

4346
assert this_orch is not None
@@ -84,16 +87,16 @@ def test_get_orchestration_state_by_status():
8487
pass # Expected failure
8588

8689
# Query by completed status
87-
completed_orchestrations = c.get_orchestration_state_by(
88-
runtime_status=[client.OrchestrationStatus.COMPLETED],
89-
fetch_inputs_and_outputs=True
90-
)
90+
query = client.OrchestrationQuery()
91+
query.runtime_status = [client.OrchestrationStatus.COMPLETED]
92+
query.fetch_inputs_and_outputs = True
93+
completed_orchestrations = c.get_all_orchestration_states(query)
9194

9295
# Query by failed status
93-
failed_orchestrations = c.get_orchestration_state_by(
94-
runtime_status=[client.OrchestrationStatus.FAILED],
95-
fetch_inputs_and_outputs=True
96-
)
96+
query = client.OrchestrationQuery()
97+
query.runtime_status = [client.OrchestrationStatus.FAILED]
98+
query.fetch_inputs_and_outputs = True
99+
failed_orchestrations = c.get_all_orchestration_states(query)
97100

98101
assert len([o for o in completed_orchestrations if o.instance_id == completed_id]) == 1
99102
completed_orch = [o for o in completed_orchestrations if o.instance_id == completed_id][0]
@@ -125,17 +128,20 @@ def test_get_orchestration_state_by_time_range():
125128
after_creation = datetime.now(timezone.utc) + timedelta(seconds=5)
126129

127130
# Query by time range
128-
orchestrations_in_range = c.get_orchestration_state_by(
131+
query = client.OrchestrationQuery(
129132
created_time_from=before_creation,
130133
created_time_to=after_creation,
131134
fetch_inputs_and_outputs=True
132135
)
136+
orchestrations_in_range = c.get_all_orchestration_states(query)
133137

134138
# Query outside time range
135-
orchestrations_outside_range = c.get_orchestration_state_by(
139+
query = client.OrchestrationQuery(
136140
created_time_from=after_creation,
137-
created_time_to=after_creation + timedelta(hours=1)
141+
created_time_to=after_creation + timedelta(hours=1),
142+
fetch_inputs_and_outputs=True
138143
)
144+
orchestrations_outside_range = c.get_all_orchestration_states(query)
139145

140146
assert len([o for o in orchestrations_in_range if o.instance_id == id]) == 1
141147
assert len([o for o in orchestrations_outside_range if o.instance_id == id]) == 0
@@ -171,7 +177,8 @@ def emit(self, record):
171177
c.wait_for_orchestration_completion(id, timeout=30)
172178

173179
# Query with max_instance_count=2
174-
orchestrations = c.get_orchestration_state_by(max_instance_count=2)
180+
query = client.OrchestrationQuery(max_instance_count=2)
181+
orchestrations = c.get_all_orchestration_states(query)
175182

176183
# Should return more than 2 instances since we created at least 3
177184
assert len(orchestrations) > 2
@@ -303,16 +310,18 @@ def counter_entity(ctx: entities.EntityContext, input):
303310
# Create entity
304311
entity_id = entities.EntityInstanceId("counter_entity", "testCounter1")
305312
c.signal_entity(entity_id, "add", 5)
306-
time.sleep(2) # Wait for signal to be processed
313+
asyncio.run(asyncio.sleep(2)) # Wait for signal to be processed
307314

308315
# Get all entities without state
309-
all_entities = c.get_all_entities(include_state=False)
316+
query = client.EntityQuery(include_state=False)
317+
all_entities = c.get_all_entities(query)
310318
assert len([e for e in all_entities if e.id == entity_id]) == 1
311319
entity_without_state = [e for e in all_entities if e.id == entity_id][0]
312320
assert entity_without_state.get_state(int) is None
313321

314322
# Get all entities with state
315-
all_entities_with_state = c.get_all_entities(include_state=True)
323+
query = client.EntityQuery(include_state=True)
324+
all_entities_with_state = c.get_all_entities(query)
316325
assert len([e for e in all_entities_with_state if e.id == entity_id]) == 1
317326
entity_with_state = [e for e in all_entities_with_state if e.id == entity_id][0]
318327
assert entity_with_state.get_state(int) == 5
@@ -337,18 +346,20 @@ def counter_entity(ctx: entities.EntityContext, input):
337346

338347
c.signal_entity(entity_id_1, "set", 10)
339348
c.signal_entity(entity_id_2, "set", 20)
340-
time.sleep(2) # Wait for signals to be processed
349+
asyncio.run(asyncio.sleep(2)) # Wait for signals to be processed
341350

342351
# Query by prefix
343-
entities_prefix1 = c.get_entities_by(
352+
query = client.EntityQuery(
344353
instance_id_starts_with="@counter_entity@prefix1",
345354
include_state=True
346355
)
356+
entities_prefix1 = c.get_all_entities(query)
347357

348-
entities_prefix2 = c.get_entities_by(
358+
query = client.EntityQuery(
349359
instance_id_starts_with="@counter_entity@prefix2",
350360
include_state=True
351361
)
362+
entities_prefix2 = c.get_all_entities(query)
352363

353364
assert len([e for e in entities_prefix1 if e.id == entity_id_1]) == 1
354365
assert len([e for e in entities_prefix1 if e.id == entity_id_2]) == 0
@@ -376,22 +387,24 @@ def simple_entity(ctx: entities.EntityContext, input):
376387
# Create entity
377388
entity_id = entities.EntityInstanceId("simple_entity", "timeTestEntity")
378389
c.signal_entity(entity_id, "set", "test_value")
379-
time.sleep(2) # Wait for signal to be processed
390+
asyncio.run(asyncio.sleep(2)) # Wait for signal to be processed
380391

381392
after_creation = datetime.now(timezone.utc) + timedelta(seconds=5)
382393

383394
# Query by time range
384-
entities_in_range = c.get_entities_by(
395+
query = client.EntityQuery(
385396
last_modified_from=before_creation,
386397
last_modified_to=after_creation,
387398
include_state=True
388399
)
400+
entities_in_range = c.get_all_entities(query)
389401

390402
# Query outside time range
391-
entities_outside_range = c.get_entities_by(
403+
query = client.EntityQuery(
392404
last_modified_from=after_creation,
393405
last_modified_to=after_creation + timedelta(hours=1)
394406
)
407+
entities_outside_range = c.get_all_entities(query)
395408

396409
assert len([e for e in entities_in_range if e.id == entity_id]) == 1
397410
assert len([e for e in entities_outside_range if e.id == entity_id]) == 0
@@ -412,13 +425,14 @@ class EmptyEntity(entities.DurableEntity):
412425
# Create an entity and then delete its state to make it empty
413426
entity_id = entities.EntityInstanceId("EmptyEntity", "toClean")
414427
c.signal_entity(entity_id, "delete")
415-
time.sleep(2) # Wait for signal to be processed
428+
asyncio.run(asyncio.sleep(2)) # Wait for signal to be processed
416429

417430
# Clean entity storage
418431
result = c.clean_entity_storage(
419432
remove_empty_entities=True,
420433
release_orphaned_locks=True
421434
)
422435

423-
# Verify clean result - we expect at least the entity we just deleted to be removed
424-
assert result.empty_entities_removed >= 0
436+
# Verify clean result - DTS backend always returns 0, as it has its own mechanism for entity state purge
437+
assert result.empty_entities_removed == 0
438+
assert result.orphaned_locks_released == 0

0 commit comments

Comments
 (0)