Skip to content

Commit 91c4e2c

Browse files
committed
Fix TriggerDagRunOperator extra_link when trigger_dag_id is templated
1 parent d43052e commit 91c4e2c

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

providers/src/airflow/providers/standard/operators/trigger_dagrun.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,28 @@ class TriggerDagRunLink(BaseOperatorLink):
6767
name = "Triggered DAG"
6868

6969
def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
70+
from airflow.models.renderedtifields import RenderedTaskInstanceFields
71+
from airflow.models.taskinstance import TaskInstance
72+
73+
ti = TaskInstance.get_task_instance(
74+
dag_id=ti_key.dag_id, run_id=ti_key.run_id, task_id=ti_key.task_id, map_index=ti_key.map_index
75+
)
76+
if TYPE_CHECKING:
77+
assert ti is not None
78+
79+
template_fields = RenderedTaskInstanceFields.get_templated_fields(ti)
80+
untemplated_trigger_dag_id = cast(TriggerDagRunOperator, operator).trigger_dag_id
81+
if template_fields:
82+
trigger_dag_id = template_fields.get("trigger_dag_id", untemplated_trigger_dag_id)
83+
else:
84+
trigger_dag_id = untemplated_trigger_dag_id
85+
7086
# Fetch the correct dag_run_id for the triggerED dag which is
7187
# stored in xcom during execution of the triggerING task.
7288
triggered_dag_run_id = XCom.get_value(ti_key=ti_key, key=XCOM_RUN_ID)
89+
7390
query = {
74-
"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id,
91+
"dag_id": trigger_dag_id,
7592
"dag_run_id": triggered_dag_run_id,
7693
}
7794
return build_airflow_url_with_query(query)

tests/operators/test_trigger_dagrun.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,27 @@ def test_trigger_dagrun_with_templated_logical_date(self, dag_maker):
271271
assert triggered_dag_run.logical_date == DEFAULT_DATE
272272
self.assert_extra_link(triggered_dag_run, task, session)
273273

274+
def test_trigger_dagrun_with_templated_trigger_dag_id(self, dag_maker):
275+
"""Test TriggerDagRunOperator with templated trigger dag id."""
276+
with dag_maker(
277+
TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True
278+
) as dag:
279+
task = TriggerDagRunOperator(
280+
task_id="__".join(["test_trigger_dagrun_with_templated_trigger_dag_id", TRIGGERED_DAG_ID]),
281+
trigger_dag_id="{{ ti.task_id.rsplit('.', 1)[-1].split('__')[-1] }}",
282+
)
283+
self.re_sync_triggered_dag_to_db(dag, dag_maker)
284+
dag_maker.create_dagrun()
285+
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
286+
287+
with create_session() as session:
288+
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
289+
assert len(dagruns) == 1
290+
triggered_dag_run = dagruns[0]
291+
assert triggered_dag_run.external_trigger
292+
assert triggered_dag_run.dag_id == TRIGGERED_DAG_ID
293+
self.assert_extra_link(triggered_dag_run, task, session)
294+
274295
def test_trigger_dagrun_operator_conf(self, dag_maker):
275296
"""Test passing conf to the triggered DagRun."""
276297
with dag_maker(

0 commit comments

Comments
 (0)