Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions providers/standard/src/airflow/providers/standard/triggers/hitl.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,26 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
"""Loop until the Human-in-the-loop response received or timeout reached."""
while True:
if self.timeout_datetime and self.timeout_datetime < utcnow():
# Fetch latest HITL detail before fallback
resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
if resp.response_received and resp.chosen_options:
# Response already received, yield success and exit
self.log.info(
"[HITL] responded_by=%s (id=%s) options=%s at %s (timeout fallback skipped)",
resp.responded_user_name,
resp.responded_user_id,
resp.chosen_options,
resp.response_at,
)
yield TriggerEvent(
HITLTriggerEventSuccessPayload(
chosen_options=resp.chosen_options,
params_input=resp.params_input or {},
timedout=False,
)
)
return

if self.defaults is None:
yield TriggerEvent(
HITLTriggerEventFailurePayload(
Expand Down
49 changes: 46 additions & 3 deletions providers/standard/tests/unit/standard/triggers/test_hitl.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ async def test_run_failed_due_to_timeout(self, mock_update, mock_supervisor_comm
)

gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await asyncio.sleep(0.3)
trigger_task = asyncio.create_task(gen.__anext__())
event = await trigger_task
assert event == TriggerEvent(
HITLTriggerEventFailurePayload(
Expand Down Expand Up @@ -118,8 +118,8 @@ async def test_run_fallback_to_default_due_to_timeout(self, mock_update, mock_lo
)

gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await asyncio.sleep(0.3)
trigger_task = asyncio.create_task(gen.__anext__())
event = await trigger_task

assert event == TriggerEvent(
Expand All @@ -130,6 +130,49 @@ async def test_run_fallback_to_default_due_to_timeout(self, mock_update, mock_lo
"[HITL] timeout reached before receiving response, fallback to default %s", ["1"]
)

@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch.object(HITLTrigger, "log")
@mock.patch("airflow.sdk.execution_time.hitl.update_hitl_detail_response")
async def test_run_should_check_response_in_timeout_handler(
self, mock_update, mock_log, mock_supervisor_comms
):
# action time only slightly before timeout
action_datetime = utcnow() + timedelta(seconds=0.1)
timeout_datetime = utcnow() + timedelta(seconds=0.1)

trigger = HITLTrigger(
defaults=["1"],
timeout_datetime=timeout_datetime,
poke_interval=5,
**default_trigger_args,
)
mock_supervisor_comms.send.return_value = HITLDetailResponse(
response_received=True,
responded_user_id="1",
responded_user_name="test",
response_at=action_datetime,
chosen_options=["2"],
params_input={},
)

gen = trigger.run()
await asyncio.sleep(0.3)
trigger_task = asyncio.create_task(gen.__anext__())
event = await trigger_task

assert event == TriggerEvent(
HITLTriggerEventSuccessPayload(chosen_options=["2"], params_input={}, timedout=False)
)

assert mock_log.info.call_args == mock.call(
"[HITL] responded_by=%s (id=%s) options=%s at %s (timeout fallback skipped)",
"test",
"1",
["2"],
action_datetime,
)

@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch.object(HITLTrigger, "log")
Expand All @@ -153,8 +196,8 @@ async def test_run(self, mock_update, mock_log, mock_supervisor_comms, time_mach
)

gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await asyncio.sleep(0.3)
trigger_task = asyncio.create_task(gen.__anext__())
event = await trigger_task
assert event == TriggerEvent(
HITLTriggerEventSuccessPayload(
Expand Down