Skip to content

Commit 94bf238

Browse files
committed
Refactor test
1 parent 62d7c03 commit 94bf238

File tree

1 file changed

+62
-22
lines changed

1 file changed

+62
-22
lines changed

tests/worker/test_workflow.py

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
import warnings
1212
from abc import ABC, abstractmethod
1313
from contextlib import contextmanager
14-
from dataclasses import dataclass
14+
from dataclasses import dataclass, field
1515
from datetime import datetime, timedelta, timezone
1616
from enum import IntEnum
17+
from functools import partial
1718
from typing import (
1819
Any,
1920
Awaitable,
@@ -5159,6 +5160,19 @@ class _UnfinishedHandlersTest:
51595160
handler_type: Literal["update", "signal"]
51605161

51615162
async def run_test(self):
5163+
# If we don't capture warnings then -- since the unfinished handler warning is converted to
5164+
# an exception in the test suite -- we see WFT failures when we don't wait for handlers.
5165+
handle: asyncio.Future[WorkflowHandle] = asyncio.Future()
5166+
asyncio.create_task(
5167+
self.get_workflow_result(wait_for_handlers=False, handle_future=handle)
5168+
)
5169+
await assert_eq_eventually(
5170+
True,
5171+
partial(self.workflow_task_failed, workflow_id=(await handle).id),
5172+
timeout=timedelta(seconds=5),
5173+
interval=timedelta(seconds=1),
5174+
)
5175+
51625176
# The unfinished handler warning is issued by default,
51635177
handler_finished, warning = await self.get_workflow_result_and_warning(
51645178
wait_for_handlers=False,
@@ -5182,45 +5196,71 @@ async def run_test(self):
51825196
)
51835197
assert not handler_finished and not warning
51845198

5199+
async def workflow_task_failed(self, workflow_id: str) -> bool:
5200+
resp = await self.client.workflow_service.get_workflow_execution_history(
5201+
GetWorkflowExecutionHistoryRequest(
5202+
namespace=self.client.namespace,
5203+
execution=WorkflowExecution(workflow_id=workflow_id),
5204+
),
5205+
)
5206+
for event in reversed(resp.history.events):
5207+
if event.event_type == EventType.EVENT_TYPE_WORKFLOW_TASK_FAILED:
5208+
assert event.workflow_task_failed_event_attributes.failure.message.startswith(
5209+
f"Workflow finished while {self.handler_type} handlers are still running"
5210+
)
5211+
return True
5212+
return False
5213+
51855214
async def get_workflow_result_and_warning(
51865215
self,
51875216
wait_for_handlers: bool,
51885217
unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None,
51895218
) -> Tuple[bool, bool]:
5219+
with pytest.WarningsRecorder() as warnings:
5220+
wf_result = await self.get_workflow_result(
5221+
wait_for_handlers, unfinished_policy
5222+
)
5223+
unfinished_handler_warning_emitted = any(
5224+
issubclass(w.category, self.unfinished_handler_warning_cls)
5225+
for w in warnings
5226+
)
5227+
return wf_result, unfinished_handler_warning_emitted
5228+
5229+
async def get_workflow_result(
5230+
self,
5231+
wait_for_handlers: bool,
5232+
unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None,
5233+
handle_future: Optional[asyncio.Future[WorkflowHandle]] = None,
5234+
) -> bool:
51905235
handle = await self.client.start_workflow(
51915236
UnfinishedHandlersWorkflow.run,
51925237
arg=wait_for_handlers,
51935238
id=f"wf-{uuid.uuid4()}",
51945239
task_queue=self.worker.task_queue,
51955240
)
5241+
if handle_future:
5242+
handle_future.set_result(handle)
51965243
handler_name = f"my_{self.handler_type}"
51975244
if unfinished_policy:
51985245
handler_name += f"_{unfinished_policy.name}"
5199-
with pytest.WarningsRecorder() as warnings:
5200-
if self.handler_type == "signal":
5201-
await asyncio.gather(handle.signal(handler_name))
5202-
else:
5203-
if not wait_for_handlers:
5204-
with pytest.raises(RPCError) as err:
5205-
await asyncio.gather(
5206-
handle.execute_update(handler_name, id="my-update")
5207-
)
5208-
assert (
5209-
err.value.status == RPCStatusCode.NOT_FOUND
5210-
and "workflow execution already completed"
5211-
in str(err.value).lower()
5212-
)
5213-
else:
5246+
if self.handler_type == "signal":
5247+
await asyncio.gather(handle.signal(handler_name))
5248+
else:
5249+
if not wait_for_handlers:
5250+
with pytest.raises(RPCError) as err:
52145251
await asyncio.gather(
52155252
handle.execute_update(handler_name, id="my-update")
52165253
)
5254+
assert (
5255+
err.value.status == RPCStatusCode.NOT_FOUND
5256+
and "workflow execution already completed" in str(err.value).lower()
5257+
)
5258+
else:
5259+
await asyncio.gather(
5260+
handle.execute_update(handler_name, id="my-update")
5261+
)
52175262

5218-
wf_result = await handle.result()
5219-
unfinished_handler_warning_emitted = any(
5220-
issubclass(w.category, self.unfinished_handler_warning_cls)
5221-
for w in warnings
5222-
)
5223-
return wf_result, unfinished_handler_warning_emitted
5263+
return await handle.result()
52245264

52255265
@property
52265266
def unfinished_handler_warning_cls(self) -> Type:

0 commit comments

Comments
 (0)