Skip to content

Commit

Permalink
Refactor ExternalDagLink to not create ad hoc TaskInstances
Browse files Browse the repository at this point in the history
This operator link creates an ad hoc TaskInstance to build its output URL. Generally we moved away from this approach starting with #21285.

This PR updates the operator link to use `TaskInstance.get_task_instance()` instead.
  • Loading branch information
josh-fell committed Dec 9, 2023
1 parent 9c0b0cd commit 4b16d30
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 23 deletions.
6 changes: 4 additions & 2 deletions airflow/models/renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import FromClause

from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic


class RenderedTaskInstanceFields(Base):
Expand Down Expand Up @@ -139,7 +139,9 @@ def _redact(self):

@classmethod
@provide_session
def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
def get_templated_fields(
cls, ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
) -> dict | None:
"""
Get templated field for a TaskInstance from the RenderedTaskInstanceFields table.
Expand Down
25 changes: 21 additions & 4 deletions airflow/sensors/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
if TYPE_CHECKING:
from sqlalchemy.orm import Query, Session

from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context


Expand All @@ -57,10 +59,25 @@ class ExternalDagLink(BaseOperatorLink):

name = "External DAG"

def get_link(self, operator, dttm):
ti = TaskInstance(task=operator, execution_date=dttm)
operator.render_template_fields(ti.get_template_context())
query = {"dag_id": operator.external_dag_id, "execution_date": dttm.isoformat()}
def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
from airflow.models.renderedtifields import RenderedTaskInstanceFields

ti = TaskInstance.get_task_instance(
dag_id=ti_key.dag_id, run_id=ti_key.run_id, task_id=ti_key.task_id, map_index=ti_key.map_index
)

if TYPE_CHECKING:
assert ti is not None

template_fields = RenderedTaskInstanceFields.get_templated_fields(ti)
external_dag_id = (
template_fields["external_dag_id"] if template_fields else operator.external_dag_id # type: ignore[attr-defined]
)
query = {
"dag_id": external_dag_id,
"execution_date": ti.execution_date.isoformat(), # type: ignore[union-attr]
}

return build_airflow_url_with_query(query)


Expand Down
47 changes: 30 additions & 17 deletions tests/sensors/test_external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,28 +1052,41 @@ def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker):
op._check_for_existence(session)


def test_external_task_sensor_templated(dag_maker, app):
with dag_maker():
ExternalTaskSensor(
task_id="templated_task",
external_dag_id="dag_{{ ds }}",
external_task_id="task_{{ ds }}",
)

dagrun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE)
(instance,) = dagrun.task_instances
instance.render_templates()
@pytest.mark.parametrize(
argnames=["external_dag_id", "external_task_id", "expected_external_dag_id", "expected_external_task_id"],
argvalues=[
("dag_test", "task_test", "dag_test", "task_test"),
("dag_{{ ds }}", "task_{{ ds }}", f"dag_{DEFAULT_DATE.date()}", f"task_{DEFAULT_DATE.date()}"),
],
ids=["not_templated", "templated"],
)
def test_external_task_sensor_extra_link(
external_dag_id,
external_task_id,
expected_external_dag_id,
expected_external_task_id,
create_task_instance_of_operator,
app,
):
ti = create_task_instance_of_operator(
ExternalTaskSensor,
dag_id="external_task_sensor_extra_links_dag",
execution_date=DEFAULT_DATE,
task_id="external_task_sensor_extra_links_task",
external_dag_id=external_dag_id,
external_task_id=external_task_id,
)
ti.render_templates()

assert instance.task.external_dag_id == f"dag_{DEFAULT_DATE.date()}"
assert instance.task.external_task_id == f"task_{DEFAULT_DATE.date()}"
assert instance.task.external_task_ids == [f"task_{DEFAULT_DATE.date()}"]
assert ti.task.external_dag_id == expected_external_dag_id
assert ti.task.external_task_id == expected_external_task_id
assert ti.task.external_task_ids == [expected_external_task_id]

# Verify that the operator link uses the rendered value of ``external_dag_id``.
app.config["SERVER_NAME"] = ""
with app.app_context():
url = instance.task.get_extra_links(instance, "External DAG")
url = ti.task.get_extra_links(ti, "External DAG")

assert f"/dags/dag_{DEFAULT_DATE.date()}/grid" in url
assert f"/dags/{expected_external_dag_id}/grid" in url


class TestExternalTaskMarker:
Expand Down

0 comments on commit 4b16d30

Please sign in to comment.