Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,25 +1289,25 @@ def _handle_trigger_dag_run(
# be used when creating the extra link on the webserver.
ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

if drte.deferrable:
from airflow.providers.standard.triggers.external_task import DagStateTrigger

defer = TaskDeferred(
trigger=DagStateTrigger(
dag_id=drte.trigger_dag_id,
states=drte.allowed_states + drte.failed_states, # type: ignore[arg-type]
# Don't filter by execution_dates when run_ids is provided.
# run_id uniquely identifies a DAG run, and when reset_dag_run=True,
# drte.logical_date might be a newly calculated value that doesn't match
# the persisted logical_date in the database, causing the trigger to never find the run.
execution_dates=None,
run_ids=[drte.dag_run_id],
poll_interval=drte.poke_interval,
),
method_name="execute_complete",
)
return _defer_task(defer, ti, log)
if drte.wait_for_completion:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch @nathadfield

if drte.deferrable:
from airflow.providers.standard.triggers.external_task import DagStateTrigger

defer = TaskDeferred(
trigger=DagStateTrigger(
dag_id=drte.trigger_dag_id,
states=drte.allowed_states + drte.failed_states, # type: ignore[arg-type]
# Don't filter by execution_dates when run_ids is provided.
# run_id uniquely identifies a DAG run, and when reset_dag_run=True,
# drte.logical_date might be a newly calculated value that doesn't match
# the persisted logical_date in the database, causing the trigger to never find the run.
execution_dates=None,
run_ids=[drte.dag_run_id],
poll_interval=drte.poke_interval,
),
method_name="execute_complete",
)
return _defer_task(defer, ti, log)
while True:
log.info(
"Waiting for dag run to complete execution in allowed state.",
Expand Down Expand Up @@ -1343,6 +1343,14 @@ def _handle_trigger_dag_run(
dag_id=drte.trigger_dag_id,
state=comms_msg.state,
)
else:
# Fire-and-forget mode: wait_for_completion=False
if drte.deferrable:
log.info(
"Ignoring deferrable=True because wait_for_completion=False. "
"Task will complete immediately without waiting for the triggered DAG run.",
trigger_dag_id=drte.trigger_dag_id,
)

return _handle_current_task_success(context, ti)

Expand Down
4 changes: 2 additions & 2 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4038,7 +4038,7 @@ def test_handle_trigger_dag_run_wait_for_completion(
@pytest.mark.parametrize(
("allowed_states", "failed_states", "intermediate_state"),
[
([DagRunState.SUCCESS], None, TaskInstanceState.DEFERRED),
([DagRunState.SUCCESS], None, TaskInstanceState.SUCCESS),
],
)
def test_handle_trigger_dag_run_deferred(
Expand All @@ -4050,7 +4050,7 @@ def test_handle_trigger_dag_run_deferred(
mock_supervisor_comms,
):
"""
Test that TriggerDagRunOperator defers when the deferrable flag is set to True
Test that TriggerDagRunOperator does not defer when wait_for_completion=False
"""
from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator

Expand Down