Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions airflow/models/renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import FromClause

from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.sdk.types import Operator


Expand Down Expand Up @@ -155,7 +155,11 @@ def _redact(self):

@classmethod
@provide_session
def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
def get_templated_fields(
cls,
ti: TaskInstance | TaskInstanceKey,
session: Session = NEW_SESSION,
) -> dict | None:
"""
Get templated field for a TaskInstance from the RenderedTaskInstanceFields table.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import json
import time
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

from sqlalchemy import select
from sqlalchemy.orm.exc import NoResultFound
Expand Down Expand Up @@ -73,29 +73,20 @@ class TriggerDagRunLink(BaseOperatorLink):

def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.taskinstance import TaskInstance

ti = TaskInstance.get_task_instance(
dag_id=ti_key.dag_id, run_id=ti_key.run_id, task_id=ti_key.task_id, map_index=ti_key.map_index
)
if TYPE_CHECKING:
assert ti is not None
assert isinstance(operator, TriggerDagRunOperator)

template_fields = RenderedTaskInstanceFields.get_templated_fields(ti)
untemplated_trigger_dag_id = cast(TriggerDagRunOperator, operator).trigger_dag_id
if template_fields:
trigger_dag_id = template_fields.get("trigger_dag_id", untemplated_trigger_dag_id)
if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id)
else:
trigger_dag_id = untemplated_trigger_dag_id
trigger_dag_id = operator.trigger_dag_id

# Fetch the correct dag_run_id for the triggerED dag which is
# stored in xcom during execution of the triggerING task.
triggered_dag_run_id = XCom.get_value(ti_key=ti_key, key=XCOM_RUN_ID)

query = {
"dag_id": trigger_dag_id,
"dag_run_id": triggered_dag_run_id,
}
query = {"dag_id": trigger_dag_id, "dag_run_id": triggered_dag_run_id}
return build_airflow_url_with_query(query)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.external_task import WorkflowTrigger
from airflow.providers.standard.utils.sensor_helper import _get_count, _get_external_task_group_task_ids
Expand Down Expand Up @@ -63,22 +62,15 @@ class ExternalDagLink(BaseOperatorLink):
def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
from airflow.models.renderedtifields import RenderedTaskInstanceFields

ti = TaskInstance.get_task_instance(
dag_id=ti_key.dag_id, run_id=ti_key.run_id, task_id=ti_key.task_id, map_index=ti_key.map_index
)

if TYPE_CHECKING:
assert ti is not None
assert isinstance(operator, (ExternalTaskMarker, ExternalTaskSensor))

template_fields = RenderedTaskInstanceFields.get_templated_fields(ti)
external_dag_id = (
template_fields["external_dag_id"] if template_fields else operator.external_dag_id # type: ignore[attr-defined]
)
query = {
"dag_id": external_dag_id,
"logical_date": ti.logical_date.isoformat(), # type: ignore[union-attr]
}
if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
external_dag_id: str = template_fields.get("external_dag_id", operator.external_dag_id)
else:
external_dag_id = operator.external_dag_id

query = {"dag_id": external_dag_id, "run_id": ti_key.run_id}
return build_airflow_url_with_query(query)


Expand Down