@@ -349,21 +349,40 @@ def _trigger_dag_af_2(self, context, run_id, parsed_logical_date):
             return
 
     def execute_complete(self, context: Context, event: tuple[str, dict[str, Any]]):
-        run_ids = event[1]["run_ids"]
+        """
+        Handle task completion after returning from a deferral.
+
+        Args:
+            context: The Airflow context dictionary.
+            event: A tuple containing the class path of the trigger and the trigger event data.
+        """
+        # Example event tuple content:
+        # (
+        #     "airflow.providers.standard.triggers.external_task.DagStateTrigger",
+        #     {
+        #         'dag_id': 'some_dag',
+        #         'states': ['success', 'failed'],
+        #         'poll_interval': 15,
+        #         'run_ids': ['manual__2025-11-19T17:49:20.907083+00:00'],
+        #         'execution_dates': [
+        #             DateTime(2025, 11, 19, 17, 49, 20, 907083, tzinfo=Timezone('UTC'))
+        #         ]
+        #     }
+        # )
+        _, event_data = event
+        run_ids = event_data["run_ids"]
         # Re-set as attribute after coming back from deferral - to be used by listeners.
         # Just a safety check on length, we should always have single run_id here.
         self.trigger_run_id = run_ids[0] if len(run_ids) == 1 else None
         if AIRFLOW_V_3_0_PLUS:
-            self._trigger_dag_run_af_3_execute_complete(event=event)
+            self._trigger_dag_run_af_3_execute_complete(event_data=event_data)
         else:
-            self._trigger_dag_run_af_2_execute_complete(event=event)
+            self._trigger_dag_run_af_2_execute_complete(event_data=event_data)
 
-    def _trigger_dag_run_af_3_execute_complete(self, event: tuple[str, dict[str, Any]]):
-        run_ids = event[1]["run_ids"]
-        event_data = event[1]
+    def _trigger_dag_run_af_3_execute_complete(self, event_data: dict[str, Any]):
         failed_run_id_conditions = []
 
-        for run_id in run_ids:
+        for run_id in event_data["run_ids"]:
             state = event_data.get(run_id)
             if state in self.failed_states:
                 failed_run_id_conditions.append(run_id)
@@ -387,10 +406,10 @@ def _trigger_dag_run_af_3_execute_complete(self, event: tuple[str, dict[str, Any
 
     @provide_session
     def _trigger_dag_run_af_2_execute_complete(
-        self, event: tuple[str, dict[str, Any]], session: Session = NEW_SESSION
+        self, event_data: dict[str, Any], session: Session = NEW_SESSION
     ):
         # This logical_date is parsed from the return trigger event
-        provided_logical_date = event[1]["execution_dates"][0]
+        provided_logical_date = event_data["execution_dates"][0]
         try:
             # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
             dag_run = session.execute(
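For reference, the reworked execute_complete now receives the full (classpath, payload) two-tuple yielded by the trigger and discards the classpath. A minimal standalone sketch of that unpacking, with hypothetical payload values modeled on the example comment in the diff above:

# Minimal sketch of the new event contract; payload values are hypothetical.
event = (
    "airflow.providers.standard.triggers.external_task.DagStateTrigger",
    {
        "dag_id": "some_dag",
        "run_ids": ["manual__2025-11-19T17:49:20.907083+00:00"],
        # On Airflow 3 the trigger also merges in a {run_id: run_state} mapping:
        "manual__2025-11-19T17:49:20.907083+00:00": "success",
    },
)

_, event_data = event  # discard the trigger classpath, keep the payload dict
run_ids = event_data["run_ids"]
trigger_run_id = run_ids[0] if len(run_ids) == 1 else None  # same length check as the operator
for run_id in run_ids:
    print(run_id, event_data.get(run_id))  # per-run terminal state (Airflow 3 only)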
@@ -226,23 +226,26 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
         elif self.execution_dates:
             runs_ids_or_dates = len(self.execution_dates)
 
+        cls_path, data = self.serialize()
+
         if AIRFLOW_V_3_0_PLUS:
-            data = await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
-            yield TriggerEvent(data)
+            data.update(  # update with {run_id: run_state} dict
+                await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
+            )
+            yield TriggerEvent((cls_path, data))
             return
         else:
             while True:
                 num_dags = await self.count_dags()
                 if num_dags == runs_ids_or_dates:
-                    yield TriggerEvent(self.serialize())
+                    yield TriggerEvent((cls_path, data))
                     return
                 await asyncio.sleep(self.poll_interval)
 
-    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict[str, typing.Any]:
+    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict[str, str]:
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        cls_path, data = self.serialize()
-
+        run_states: dict[str, str] = {}  # {run_id: run_state}
         while True:
             num_dags = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
                 dag_id=self.dag_id,
@@ -257,8 +260,8 @@ async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict
                         dag_id=self.dag_id,
                         run_id=run_id,
                     )
-                    data[run_id] = state
-                return data
+                    run_states[run_id] = state
+                return run_states
             await asyncio.sleep(self.poll_interval)
 
 if not AIRFLOW_V_3_0_PLUS:
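Both branches of run() now yield the same (classpath, data) shape, with the Airflow 3 branch first merging a {run_id: run_state} mapping into the serialized payload. A self-contained sketch of that flow; FakeDagStateTrigger and its canned values are stand-ins for illustration, not the real Airflow API:

import asyncio


class FakeDagStateTrigger:
    """Stand-in mimicking the payload flow; not the real Airflow class."""

    def serialize(self):
        cls_path = "airflow.providers.standard.triggers.external_task.DagStateTrigger"
        data = {"dag_id": "some_dag", "run_ids": ["run_1"], "poll_interval": 15}
        return cls_path, data

    async def validate_count_dags_af_3(self):
        # The real method polls until the expected runs exist, then
        # returns a {run_id: run_state} mapping.
        return {"run_1": "success"}

    async def run(self):
        cls_path, data = self.serialize()
        data.update(await self.validate_count_dags_af_3())  # merge run states into payload
        yield (cls_path, data)  # the real trigger wraps this tuple in TriggerEvent(...)


async def main():
    async for event in FakeDagStateTrigger().run():
        print(event)  # ('airflow...DagStateTrigger', {..., 'run_1': 'success'})


asyncio.run(main())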
@@ -282,6 +282,20 @@ def test_trigger_dag_run_execute_complete_re_set_run_id_attribute(self):
 
         assert operator.trigger_run_id == "run_id_1"
 
+    def test_trigger_dag_run_execute_complete_fails_with_dict_as_input_type(self):
+        operator = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id=TRIGGERED_DAG_ID,
+            wait_for_completion=True,
+            poke_interval=10,
+            failed_states=[],
+        )
+
+        with pytest.raises(ValueError, match="too many values to unpack"):
+            operator.execute_complete(
+                {}, {"dag_id": "dag_id", "run_ids": ["run_id_1"], "poll_interval": 15, "run_id_1": "success"}
+            )
+
     @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
     def test_trigger_dag_run_with_fail_when_dag_is_paused_should_fail(self):
         with pytest.raises(
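The expected ValueError in this new test comes from plain tuple unpacking: execute_complete does `_, event_data = event`, and unpacking iterates a dict's keys, so the four-key dict passed here cannot fill two targets. A standalone illustration:

payload = {"dag_id": "dag_id", "run_ids": ["run_id_1"], "poll_interval": 15, "run_id_1": "success"}

# Unpacking iterates the dict's keys; four keys cannot fill two targets.
try:
    _, event_data = payload
except ValueError as err:
    print(err)  # too many values to unpack (expected 2)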
providers/standard/tests/unit/standard/triggers/test_external_task.py: 102 additions, 0 deletions
@@ -713,6 +713,108 @@ async def test_dag_state_trigger_af_3(self, mock_get_dag_run_count, session):
         # Prevents error when task is destroyed while in "pending" state
         asyncio.get_event_loop().stop()
 
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 2 had a different implementation")
+    @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_dr_count")
+    @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_dagrun_state")
+    async def test_dag_state_trigger_af_3_return_type(
+        self, mock_get_dagrun_state, mock_get_dag_run_count, session
+    ):
+        """
+        Assert that the DagStateTrigger returns a tuple with classpath and event_data.
+        """
+        mock_get_dag_run_count.return_value = 1
+        mock_get_dagrun_state.return_value = DagRunState.SUCCESS
+        dag = DAG(f"{self.DAG_ID}_return_type", schedule=None, start_date=timezone.datetime(2022, 1, 1))
+
+        dag_run = DagRun(
+            dag_id=dag.dag_id,
+            run_type="manual",
+            run_id="external_task_run_id",
+            logical_date=timezone.datetime(2022, 1, 1),
+        )
+        dag_run.state = DagRunState.SUCCESS
+        session.add(dag_run)
+        session.commit()
+
+        trigger = DagStateTrigger(
+            dag_id=dag.dag_id,
+            states=self.STATES,
+            run_ids=["external_task_run_id"],
+            poll_interval=0.2,
+            execution_dates=[timezone.datetime(2022, 1, 1)],
+        )
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert task.done() is True
+
+        result = task.result()
+        assert isinstance(result, TriggerEvent)
+        assert result.payload == (
+            "airflow.providers.standard.triggers.external_task.DagStateTrigger",
+            {
+                "dag_id": "test_dag_state_trigger_return_type",
+                "execution_dates": [
+                    timezone.datetime(2022, 1, 1, 0, 0, tzinfo=timezone.utc),
+                ],
+                "external_task_run_id": DagRunState.SUCCESS,
+                "poll_interval": 0.2,
+                "run_ids": ["external_task_run_id"],
+                "states": ["success", "fail"],
+            },
+        )
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only AF2 implementation.")
+    async def test_dag_state_trigger_af_2_return_type(self, session):
+        """
+        Assert that the DagStateTrigger returns a tuple with classpath and event_data.
+        """
+        dag = DAG(f"{self.DAG_ID}_return_type", schedule=None, start_date=timezone.datetime(2022, 1, 1))
+
+        dag_run = DagRun(
+            dag_id=dag.dag_id,
+            run_type="manual",
+            run_id="external_task_run_id",
+            execution_date=timezone.datetime(2022, 1, 1),
+        )
+        dag_run.state = DagRunState.SUCCESS
+        session.add(dag_run)
+        session.commit()
+
+        trigger = DagStateTrigger(
+            dag_id=dag.dag_id,
+            states=self.STATES,
+            run_ids=["external_task_run_id"],
+            poll_interval=0.2,
+            execution_dates=[timezone.datetime(2022, 1, 1)],
+        )
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert task.done() is True
+
+        result = task.result()
+        assert isinstance(result, TriggerEvent)
+        assert result.payload == (
+            "airflow.providers.standard.triggers.external_task.DagStateTrigger",
+            {
+                "dag_id": "test_dag_state_trigger_return_type",
+                "execution_dates": [
+                    timezone.datetime(2022, 1, 1, 0, 0, tzinfo=timezone.utc),
+                ],
+                # 'external_task_run_id': DagRunState.SUCCESS,  # This is only appended in AF3
+                "poll_interval": 0.2,
+                "run_ids": ["external_task_run_id"],
+                "states": ["success", "fail"],
+            },
+        )
+        asyncio.get_event_loop().stop()
+
     def test_serialization(self):
         """Asserts that the DagStateTrigger correctly serializes its arguments and classpath."""
         trigger = DagStateTrigger(
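The new tests pull a single event out of the trigger's async generator by scheduling __anext__() as a task and sleeping past the poll interval. A minimal standalone version of that pattern, using a toy generator in place of DagStateTrigger:

import asyncio


async def one_shot_trigger():
    # Toy stand-in for DagStateTrigger.run(): yields a single payload tuple.
    await asyncio.sleep(0.1)
    yield ("classpath", {"run_ids": ["run_1"], "run_1": "success"})


async def main():
    # Schedule the generator's first item as a task, as the tests do.
    task = asyncio.create_task(one_shot_trigger().__anext__())
    await asyncio.sleep(0.5)  # give the trigger time to fire
    assert task.done() is True
    print(task.result())


asyncio.run(main())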