Skip to content

Commit ecd703d

Browse files
authored
Add Client.count_workflows (#510)
Fixes #294
1 parent 50c2033 commit ecd703d

File tree

3 files changed

+170
-1
lines changed

3 files changed

+170
-1
lines changed

temporalio/client.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,30 @@ def list_workflows(
819819
)
820820
)
821821

822+
async def count_workflows(
823+
self,
824+
query: Optional[str] = None,
825+
rpc_metadata: Mapping[str, str] = {},
826+
rpc_timeout: Optional[timedelta] = None,
827+
) -> WorkflowExecutionCount:
828+
"""Count workflows.
829+
830+
Args:
831+
query: A Temporal visibility filter. See Temporal documentation
832+
concerning visibility list filters.
833+
rpc_metadata: Headers used on each RPC call. Keys here override
834+
client-level RPC metadata keys.
835+
rpc_timeout: Optional RPC deadline to set for each RPC call.
836+
837+
Returns:
838+
Count of workflows.
839+
"""
840+
return await self._impl.count_workflows(
841+
CountWorkflowsInput(
842+
query=query, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout
843+
)
844+
)
845+
822846
@overload
823847
def get_async_activity_handle(
824848
self, *, workflow_id: str, run_id: Optional[str], activity_id: str
@@ -2310,6 +2334,57 @@ class WorkflowExecutionStatus(IntEnum):
23102334
)
23112335

23122336

2337+
@dataclass
2338+
class WorkflowExecutionCount:
2339+
"""Representation of a count from a count workflows call."""
2340+
2341+
count: int
2342+
"""Approximate number of workflows matching the original query.
2343+
2344+
If the query had a group-by clause, this is simply the sum of all the counts
2345+
in py:attr:`groups`.
2346+
"""
2347+
2348+
groups: Sequence[WorkflowExecutionCountAggregationGroup]
2349+
"""Groups if the query had a group-by clause, or empty if not."""
2350+
2351+
@staticmethod
2352+
def _from_raw(
2353+
raw: temporalio.api.workflowservice.v1.CountWorkflowExecutionsResponse,
2354+
) -> WorkflowExecutionCount:
2355+
return WorkflowExecutionCount(
2356+
count=raw.count,
2357+
groups=[
2358+
WorkflowExecutionCountAggregationGroup._from_raw(g) for g in raw.groups
2359+
],
2360+
)
2361+
2362+
2363+
@dataclass
2364+
class WorkflowExecutionCountAggregationGroup:
2365+
"""Aggregation group if the workflow count query had a group-by clause."""
2366+
2367+
count: int
2368+
"""Approximate number of workflows matching the original query for this
2369+
group.
2370+
"""
2371+
2372+
group_values: Sequence[temporalio.common.SearchAttributeValue]
2373+
"""Search attribute values for this group."""
2374+
2375+
@staticmethod
2376+
def _from_raw(
2377+
raw: temporalio.api.workflowservice.v1.CountWorkflowExecutionsResponse.AggregationGroup,
2378+
) -> WorkflowExecutionCountAggregationGroup:
2379+
return WorkflowExecutionCountAggregationGroup(
2380+
count=raw.count,
2381+
group_values=[
2382+
temporalio.converter._decode_search_attribute_value(v)
2383+
for v in raw.group_values
2384+
],
2385+
)
2386+
2387+
23132388
class WorkflowExecutionAsyncIterator:
23142389
"""Asynchronous iterator for :py:class:`WorkflowExecution` values.
23152390
@@ -4373,6 +4448,15 @@ class ListWorkflowsInput:
43734448
rpc_timeout: Optional[timedelta]
43744449

43754450

4451+
@dataclass
4452+
class CountWorkflowsInput:
4453+
"""Input for :py:meth:`OutboundInterceptor.count_workflows`."""
4454+
4455+
query: Optional[str]
4456+
rpc_metadata: Mapping[str, str]
4457+
rpc_timeout: Optional[timedelta]
4458+
4459+
43764460
@dataclass
43774461
class QueryWorkflowInput:
43784462
"""Input for :py:meth:`OutboundInterceptor.query_workflow`."""
@@ -4669,6 +4753,12 @@ def list_workflows(
46694753
"""Called for every :py:meth:`Client.list_workflows` call."""
46704754
return self.next.list_workflows(input)
46714755

4756+
async def count_workflows(
4757+
self, input: CountWorkflowsInput
4758+
) -> WorkflowExecutionCount:
4759+
"""Called for every :py:meth:`Client.count_workflows` call."""
4760+
return await self.next.count_workflows(input)
4761+
46724762
async def query_workflow(self, input: QueryWorkflowInput) -> Any:
46734763
"""Called for every :py:meth:`WorkflowHandle.query` call."""
46744764
return await self.next.query_workflow(input)
@@ -4928,6 +5018,21 @@ def list_workflows(
49285018
) -> WorkflowExecutionAsyncIterator:
49295019
return WorkflowExecutionAsyncIterator(self._client, input)
49305020

5021+
async def count_workflows(
5022+
self, input: CountWorkflowsInput
5023+
) -> WorkflowExecutionCount:
5024+
return WorkflowExecutionCount._from_raw(
5025+
await self._client.workflow_service.count_workflow_executions(
5026+
temporalio.api.workflowservice.v1.CountWorkflowExecutionsRequest(
5027+
namespace=self._client.namespace,
5028+
query=input.query or "",
5029+
),
5030+
retry=True,
5031+
metadata=input.rpc_metadata,
5032+
timeout=input.rpc_timeout,
5033+
)
5034+
)
5035+
49315036
async def query_workflow(self, input: QueryWorkflowInput) -> Any:
49325037
req = temporalio.api.workflowservice.v1.QueryWorkflowRequest(
49335038
namespace=self._client.namespace,

temporalio/converter.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,15 @@ def decode_typed_search_attributes(
13411341
return temporalio.common.TypedSearchAttributes(pairs)
13421342

13431343

1344+
def _decode_search_attribute_value(
1345+
payload: temporalio.api.common.v1.Payload,
1346+
) -> temporalio.common.SearchAttributeValue:
1347+
val = default().payload_converter.from_payload(payload)
1348+
if isinstance(val, str) and payload.metadata.get("type") == b"Datetime":
1349+
val = _get_iso_datetime_parser()(val)
1350+
return val # type: ignore
1351+
1352+
13441353
def value_to_type(
13451354
hint: Type,
13461355
value: Any,

tests/test_client.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import uuid
55
from datetime import datetime, timedelta, timezone
6-
from typing import Any, List, Optional, Tuple
6+
from typing import Any, List, Optional, Tuple, cast
77

88
import pytest
99
from google.protobuf import json_format
@@ -60,6 +60,8 @@
6060
TaskReachabilityType,
6161
TerminateWorkflowInput,
6262
WorkflowContinuedAsNewError,
63+
WorkflowExecutionCount,
64+
WorkflowExecutionCountAggregationGroup,
6365
WorkflowExecutionStatus,
6466
WorkflowFailureError,
6567
WorkflowHandle,
@@ -569,6 +571,59 @@ async def test_list_workflows_and_fetch_history(
569571
assert actual_id_and_input == expected_id_and_input
570572

571573

574+
@workflow.defn
575+
class CountableWorkflow:
576+
@workflow.run
577+
async def run(self, wait_forever: bool) -> None:
578+
await workflow.wait_condition(lambda: not wait_forever)
579+
580+
581+
async def test_count_workflows(client: Client, env: WorkflowEnvironment):
582+
if env.supports_time_skipping:
583+
pytest.skip("Java test server doesn't support newer workflow listing")
584+
585+
# 3 workflows that complete, 2 that don't
586+
async with new_worker(client, CountableWorkflow) as worker:
587+
for _ in range(3):
588+
await client.execute_workflow(
589+
CountableWorkflow.run,
590+
False,
591+
id=f"id-{uuid.uuid4()}",
592+
task_queue=worker.task_queue,
593+
)
594+
for _ in range(2):
595+
await client.start_workflow(
596+
CountableWorkflow.run,
597+
True,
598+
id=f"id-{uuid.uuid4()}",
599+
task_queue=worker.task_queue,
600+
)
601+
602+
async def fetch_count() -> WorkflowExecutionCount:
603+
resp = await client.count_workflows(
604+
f"TaskQueue = '{worker.task_queue}' GROUP BY ExecutionStatus"
605+
)
606+
cast(List[WorkflowExecutionCountAggregationGroup], resp.groups).sort(
607+
key=lambda g: g.count
608+
)
609+
return resp
610+
611+
await assert_eq_eventually(
612+
WorkflowExecutionCount(
613+
count=5,
614+
groups=[
615+
WorkflowExecutionCountAggregationGroup(
616+
count=2, group_values=["Running"]
617+
),
618+
WorkflowExecutionCountAggregationGroup(
619+
count=3, group_values=["Completed"]
620+
),
621+
],
622+
),
623+
fetch_count,
624+
)
625+
626+
572627
def test_history_from_json():
573628
# Take proto, make JSON, convert to dict, alter some enums, confirm that it
574629
# alters the enums back and matches original history

0 commit comments

Comments
 (0)