11
11
import warnings
12
12
from abc import ABC , abstractmethod
13
13
from contextlib import contextmanager
14
- from dataclasses import dataclass
14
+ from dataclasses import dataclass , field
15
15
from datetime import datetime , timedelta , timezone
16
16
from enum import IntEnum
17
+ from functools import partial
17
18
from typing import (
18
19
Any ,
19
20
Awaitable ,
@@ -5159,6 +5160,19 @@ class _UnfinishedHandlersTest:
5159
5160
handler_type : Literal ["update" , "signal" ]
5160
5161
5161
5162
async def run_test (self ):
5163
+ # If we don't capture warnings then -- since the unfinished handler warning is converted to
5164
+ # an exception in the test suite -- we see WFT failures when we don't wait for handlers.
5165
+ handle : asyncio .Future [WorkflowHandle ] = asyncio .Future ()
5166
+ asyncio .create_task (
5167
+ self .get_workflow_result (wait_for_handlers = False , handle_future = handle )
5168
+ )
5169
+ await assert_eq_eventually (
5170
+ True ,
5171
+ partial (self .workflow_task_failed , workflow_id = (await handle ).id ),
5172
+ timeout = timedelta (seconds = 5 ),
5173
+ interval = timedelta (seconds = 1 ),
5174
+ )
5175
+
5162
5176
# The unfinished handler warning is issued by default,
5163
5177
handler_finished , warning = await self .get_workflow_result_and_warning (
5164
5178
wait_for_handlers = False ,
@@ -5182,45 +5196,71 @@ async def run_test(self):
5182
5196
)
5183
5197
assert not handler_finished and not warning
5184
5198
5199
+ async def workflow_task_failed (self , workflow_id : str ) -> bool :
5200
+ resp = await self .client .workflow_service .get_workflow_execution_history (
5201
+ GetWorkflowExecutionHistoryRequest (
5202
+ namespace = self .client .namespace ,
5203
+ execution = WorkflowExecution (workflow_id = workflow_id ),
5204
+ ),
5205
+ )
5206
+ for event in reversed (resp .history .events ):
5207
+ if event .event_type == EventType .EVENT_TYPE_WORKFLOW_TASK_FAILED :
5208
+ assert event .workflow_task_failed_event_attributes .failure .message .startswith (
5209
+ f"Workflow finished while { self .handler_type } handlers are still running"
5210
+ )
5211
+ return True
5212
+ return False
5213
+
5185
5214
async def get_workflow_result_and_warning (
5186
5215
self ,
5187
5216
wait_for_handlers : bool ,
5188
5217
unfinished_policy : Optional [workflow .HandlerUnfinishedPolicy ] = None ,
5189
5218
) -> Tuple [bool , bool ]:
5219
+ with pytest .WarningsRecorder () as warnings :
5220
+ wf_result = await self .get_workflow_result (
5221
+ wait_for_handlers , unfinished_policy
5222
+ )
5223
+ unfinished_handler_warning_emitted = any (
5224
+ issubclass (w .category , self .unfinished_handler_warning_cls )
5225
+ for w in warnings
5226
+ )
5227
+ return wf_result , unfinished_handler_warning_emitted
5228
+
5229
+ async def get_workflow_result (
5230
+ self ,
5231
+ wait_for_handlers : bool ,
5232
+ unfinished_policy : Optional [workflow .HandlerUnfinishedPolicy ] = None ,
5233
+ handle_future : Optional [asyncio .Future [WorkflowHandle ]] = None ,
5234
+ ) -> bool :
5190
5235
handle = await self .client .start_workflow (
5191
5236
UnfinishedHandlersWorkflow .run ,
5192
5237
arg = wait_for_handlers ,
5193
5238
id = f"wf-{ uuid .uuid4 ()} " ,
5194
5239
task_queue = self .worker .task_queue ,
5195
5240
)
5241
+ if handle_future :
5242
+ handle_future .set_result (handle )
5196
5243
handler_name = f"my_{ self .handler_type } "
5197
5244
if unfinished_policy :
5198
5245
handler_name += f"_{ unfinished_policy .name } "
5199
- with pytest .WarningsRecorder () as warnings :
5200
- if self .handler_type == "signal" :
5201
- await asyncio .gather (handle .signal (handler_name ))
5202
- else :
5203
- if not wait_for_handlers :
5204
- with pytest .raises (RPCError ) as err :
5205
- await asyncio .gather (
5206
- handle .execute_update (handler_name , id = "my-update" )
5207
- )
5208
- assert (
5209
- err .value .status == RPCStatusCode .NOT_FOUND
5210
- and "workflow execution already completed"
5211
- in str (err .value ).lower ()
5212
- )
5213
- else :
5246
+ if self .handler_type == "signal" :
5247
+ await asyncio .gather (handle .signal (handler_name ))
5248
+ else :
5249
+ if not wait_for_handlers :
5250
+ with pytest .raises (RPCError ) as err :
5214
5251
await asyncio .gather (
5215
5252
handle .execute_update (handler_name , id = "my-update" )
5216
5253
)
5254
+ assert (
5255
+ err .value .status == RPCStatusCode .NOT_FOUND
5256
+ and "workflow execution already completed" in str (err .value ).lower ()
5257
+ )
5258
+ else :
5259
+ await asyncio .gather (
5260
+ handle .execute_update (handler_name , id = "my-update" )
5261
+ )
5217
5262
5218
- wf_result = await handle .result ()
5219
- unfinished_handler_warning_emitted = any (
5220
- issubclass (w .category , self .unfinished_handler_warning_cls )
5221
- for w in warnings
5222
- )
5223
- return wf_result , unfinished_handler_warning_emitted
5263
+ return await handle .result ()
5224
5264
5225
5265
@property
5226
5266
def unfinished_handler_warning_cls (self ) -> Type :
0 commit comments