From 4b16d306eebc1f13cc3f24b2b314f2944c52416a Mon Sep 17 00:00:00 2001 From: Josh Fell Date: Fri, 8 Dec 2023 23:32:11 -0500 Subject: [PATCH] Refactor ExternalDagLink to not create ad hoc TaskInstances This operator link creates an ad hoc TaskInstance to build its output URL. We have generally moved away from this approach since #21285. This PR updates the operator link to use `TaskInstance.get_task_instance()` instead. --- airflow/models/renderedtifields.py | 6 ++- airflow/sensors/external_task.py | 25 ++++++++++-- tests/sensors/test_external_task_sensor.py | 47 ++++++++++++++-------- 3 files changed, 55 insertions(+), 23 deletions(-) diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py index 9107d88795d4a..3880a85091877 100644 --- a/airflow/models/renderedtifields.py +++ b/airflow/models/renderedtifields.py @@ -46,7 +46,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import FromClause - from airflow.models.taskinstance import TaskInstance + from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic class RenderedTaskInstanceFields(Base): @@ -139,7 +139,9 @@ def _redact(self): @classmethod @provide_session - def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None: + def get_templated_fields( + cls, ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION + ) -> dict | None: """ Get templated field for a TaskInstance from the RenderedTaskInstanceFields table. 
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 79734d25a3cce..5a3353d916890 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -45,6 +45,8 @@ if TYPE_CHECKING: from sqlalchemy.orm import Query, Session + from airflow.models.baseoperator import BaseOperator + from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context @@ -57,10 +59,25 @@ class ExternalDagLink(BaseOperatorLink): name = "External DAG" - def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - operator.render_template_fields(ti.get_template_context()) - query = {"dag_id": operator.external_dag_id, "execution_date": dttm.isoformat()} + def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: + from airflow.models.renderedtifields import RenderedTaskInstanceFields + + ti = TaskInstance.get_task_instance( + dag_id=ti_key.dag_id, run_id=ti_key.run_id, task_id=ti_key.task_id, map_index=ti_key.map_index + ) + + if TYPE_CHECKING: + assert ti is not None + + template_fields = RenderedTaskInstanceFields.get_templated_fields(ti) + external_dag_id = ( + template_fields["external_dag_id"] if template_fields else operator.external_dag_id # type: ignore[attr-defined] + ) + query = { + "dag_id": external_dag_id, + "execution_date": ti.execution_date.isoformat(), # type: ignore[union-attr] + } + return build_airflow_url_with_query(query) diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 215ce47d56625..bdbf7b1dc5aa3 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -1052,28 +1052,41 @@ def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker): op._check_for_existence(session) -def test_external_task_sensor_templated(dag_maker, app): - with dag_maker(): - ExternalTaskSensor( - task_id="templated_task", - 
external_dag_id="dag_{{ ds }}", - external_task_id="task_{{ ds }}", - ) - - dagrun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE) - (instance,) = dagrun.task_instances - instance.render_templates() +@pytest.mark.parametrize( + argnames=["external_dag_id", "external_task_id", "expected_external_dag_id", "expected_external_task_id"], + argvalues=[ + ("dag_test", "task_test", "dag_test", "task_test"), + ("dag_{{ ds }}", "task_{{ ds }}", f"dag_{DEFAULT_DATE.date()}", f"task_{DEFAULT_DATE.date()}"), + ], + ids=["not_templated", "templated"], +) +def test_external_task_sensor_extra_link( + external_dag_id, + external_task_id, + expected_external_dag_id, + expected_external_task_id, + create_task_instance_of_operator, + app, +): + ti = create_task_instance_of_operator( + ExternalTaskSensor, + dag_id="external_task_sensor_extra_links_dag", + execution_date=DEFAULT_DATE, + task_id="external_task_sensor_extra_links_task", + external_dag_id=external_dag_id, + external_task_id=external_task_id, + ) + ti.render_templates() - assert instance.task.external_dag_id == f"dag_{DEFAULT_DATE.date()}" - assert instance.task.external_task_id == f"task_{DEFAULT_DATE.date()}" - assert instance.task.external_task_ids == [f"task_{DEFAULT_DATE.date()}"] + assert ti.task.external_dag_id == expected_external_dag_id + assert ti.task.external_task_id == expected_external_task_id + assert ti.task.external_task_ids == [expected_external_task_id] - # Verify that the operator link uses the rendered value of ``external_dag_id``. app.config["SERVER_NAME"] = "" with app.app_context(): - url = instance.task.get_extra_links(instance, "External DAG") + url = ti.task.get_extra_links(ti, "External DAG") - assert f"/dags/dag_{DEFAULT_DATE.date()}/grid" in url + assert f"/dags/{expected_external_dag_id}/grid" in url class TestExternalTaskMarker: