Skip to content

Commit 19b74fd

Browse files
author
Andrew Godwin
authored
Add support for labelling DAG edges (#15142)
This adds support for putting human-readable labels on edges in the DAG between Tasks, as well as making the underlying framework for that generic enough that future metadata could be added if desired.
1 parent 56a0371 commit 19b74fd

File tree

14 files changed

+444
-17
lines changed

14 files changed

+444
-17
lines changed

airflow/models/baseoperator.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
6767
from airflow.utils import timezone
6868
from airflow.utils.decorators import apply_defaults
69+
from airflow.utils.edgemodifier import EdgeModifier
6970
from airflow.utils.helpers import validate_key
7071
from airflow.utils.log.logging_mixin import LoggingMixin
7172
from airflow.utils.operator_resources import Resources
@@ -1205,6 +1206,7 @@ def _set_relatives(
12051206
self,
12061207
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
12071208
upstream: bool = False,
1209+
edge_modifier: Optional[EdgeModifier] = None,
12081210
) -> None:
12091211
"""Sets relatives for the task or task list."""
12101212
if not isinstance(task_or_task_list, Sequence):
@@ -1259,23 +1261,35 @@ def _set_relatives(
12591261
if upstream:
12601262
task.add_only_new(task.get_direct_relative_ids(upstream=False), self.task_id)
12611263
self.add_only_new(self._upstream_task_ids, task.task_id)
1264+
if edge_modifier:
1265+
edge_modifier.add_edge_info(self.dag, task.task_id, self.task_id)
12621266
else:
12631267
self.add_only_new(self._downstream_task_ids, task.task_id)
12641268
task.add_only_new(task.get_direct_relative_ids(upstream=True), self.task_id)
1269+
if edge_modifier:
1270+
edge_modifier.add_edge_info(self.dag, self.task_id, task.task_id)
12651271

1266-
def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None:
1272+
def set_downstream(
1273+
self,
1274+
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
1275+
edge_modifier: Optional[EdgeModifier] = None,
1276+
) -> None:
12671277
"""
12681278
Set a task or a task list to be directly downstream from the current
12691279
task. Required by TaskMixin.
12701280
"""
1271-
self._set_relatives(task_or_task_list, upstream=False)
1281+
self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)
12721282

1273-
def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]) -> None:
1283+
def set_upstream(
1284+
self,
1285+
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
1286+
edge_modifier: Optional[EdgeModifier] = None,
1287+
) -> None:
12741288
"""
12751289
Set a task or a task list to be directly upstream from the current
12761290
task. Required by TaskMixin.
12771291
"""
1278-
self._set_relatives(task_or_task_list, upstream=True)
1292+
self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)
12791293

12801294
@property
12811295
def output(self):

airflow/models/dag.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
from airflow.utils.session import provide_session
7474
from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
7575
from airflow.utils.state import State
76-
from airflow.utils.types import DagRunType
76+
from airflow.utils.types import DagRunType, EdgeInfoType
7777

7878
if TYPE_CHECKING:
7979
from airflow.utils.task_group import TaskGroup
@@ -349,6 +349,11 @@ def __init__(
349349
self.on_success_callback = on_success_callback
350350
self.on_failure_callback = on_failure_callback
351351

352+
# Keeps track of any extra edge metadata (sparse; will not contain all
353+
# edges, so do not iterate over it for that). Outer key is upstream
354+
# task ID, inner key is downstream task ID.
355+
self.edge_info: Dict[str, Dict[str, EdgeInfoType]] = {}
356+
352357
# To keep it in parity with Serialized DAGs
353358
# and identify if DAG has on_*_callback without actually storing them in Serialized JSON
354359
self.has_on_success_callback = self.on_success_callback is not None
@@ -2050,6 +2055,24 @@ def get_serialized_fields(cls):
20502055
}
20512056
return cls.__serialized_fields
20522057

2058+
def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
2059+
"""
2060+
Returns edge information for the given pair of tasks if present, and
2061+
None if there is no information.
2062+
"""
2063+
# Note - older serialized DAGs may not have edge_info being a dict at all
2064+
if self.edge_info:
2065+
return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, {})
2066+
else:
2067+
return {}
2068+
2069+
def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType):
2070+
"""
2071+
Sets the given edge information on the DAG. Note that this will overwrite,
2072+
rather than merge with, existing info.
2073+
"""
2074+
self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info
2075+
20532076

20542077
class DagTag(Base):
20552078
"""A tag name per dag, to allow quick filtering in the DAG view."""

airflow/models/xcom_arg.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Any, Dict, List, Sequence, Union
18+
from typing import Any, Dict, List, Optional, Sequence, Union
1919

2020
from airflow.exceptions import AirflowException
2121
from airflow.models.baseoperator import BaseOperator # pylint: disable=R0401
2222
from airflow.models.taskmixin import TaskMixin
2323
from airflow.models.xcom import XCOM_RETURN_KEY
24+
from airflow.utils.edgemodifier import EdgeModifier
2425

2526

2627
class XComArg(TaskMixin):
@@ -111,13 +112,21 @@ def key(self) -> str:
111112
"""Returns keys of this XComArg"""
112113
return self._key
113114

114-
def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]):
115+
def set_upstream(
116+
self,
117+
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
118+
edge_modifier: Optional[EdgeModifier] = None,
119+
):
115120
"""Proxy to underlying operator set_upstream method. Required by TaskMixin."""
116-
self.operator.set_upstream(task_or_task_list)
121+
self.operator.set_upstream(task_or_task_list, edge_modifier)
117122

118-
def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]):
123+
def set_downstream(
124+
self,
125+
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
126+
edge_modifier: Optional[EdgeModifier] = None,
127+
):
119128
"""Proxy to underlying operator set_downstream method. Required by TaskMixin."""
120-
self.operator.set_downstream(task_or_task_list)
129+
self.operator.set_downstream(task_or_task_list, edge_modifier)
121130

122131
def resolve(self, context: Dict) -> Any:
123132
"""

airflow/serialization/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,4 @@ class DagAttributeTypes(str, Enum):
4646
TUPLE = 'tuple'
4747
POD = 'k8s.V1Pod'
4848
TASK_GROUP = 'taskgroup'
49+
EDGE_INFO = 'edgeinfo'

airflow/serialization/schema.json

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@
102102
"_task_group": {"anyOf": [
103103
{ "type": "null" },
104104
{ "$ref": "#/definitions/task_group" }
105-
]}
105+
]},
106+
"edge_info": { "$ref": "#/definitions/edge_info" }
106107
},
107108
"required": [
108109
"_dag_id",
@@ -217,6 +218,21 @@
217218
}
218219
},
219220
"additionalProperties": false
221+
},
222+
"edge_info": {
223+
"$comment": "Metadata about DAG edges",
224+
"type": "object",
225+
"additionalProperties": {
226+
"type": "object",
227+
"additionalProperties": {
228+
"type": "object",
229+
"properties": {
230+
"label": { "type": "string" }
231+
},
232+
"required": ["label"],
233+
"additionalProperties": false
234+
}
235+
}
220236
}
221237
},
222238

airflow/serialization/serialized_objects.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,9 @@ def serialize_dag(cls, dag: DAG) -> dict:
647647
serialize_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()]
648648
serialize_dag['_task_group'] = SerializedTaskGroup.serialize_task_group(dag.task_group)
649649

650+
# Edge info in the JSON exactly matches our internal structure
651+
serialize_dag["edge_info"] = dag.edge_info
652+
650653
# has_on_*_callback are only stored if the value is True, as the default is False
651654
if dag.has_on_success_callback:
652655
serialize_dag['has_on_success_callback'] = True
@@ -678,6 +681,9 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
678681
v = cls._deserialize_timedelta(v)
679682
elif k.endswith("_date"):
680683
v = cls._deserialize_datetime(v)
684+
elif k == "edge_info":
685+
# Value structure matches exactly
686+
pass
681687
elif k in cls._decorated_fields:
682688
v = cls._deserialize(v)
683689
# else use v as it is

airflow/utils/dot_renderer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.D
156156
_draw_nodes(dag.task_group, dot, states_by_task_id)
157157

158158
for edge in dag_edges(dag):
159-
dot.edge(edge["source_id"], edge["target_id"])
159+
# Gets an optional label for the edge; this will be None if none is specified.
160+
label = dag.get_edge_info(edge["source_id"], edge["target_id"]).get("label")
161+
# Add the edge to the graph with optional label
162+
# (we can just use the maybe-None label variable directly)
163+
dot.edge(edge["source_id"], edge["target_id"], label)
160164

161165
return dot

airflow/utils/edgemodifier.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from typing import Sequence, Union
19+
20+
from airflow.models.taskmixin import TaskMixin
21+
22+
23+
class EdgeModifier(TaskMixin):
24+
"""
25+
Class that represents edge information to be added between two
26+
tasks/operators. Has shorthand factory functions, like Label("hooray").
27+
28+
Current implementation supports
29+
t1 >> Label("Success route") >> t2
30+
t2 << Label("Success route") << t2
31+
32+
Note that due to the potential for use in either direction, this waits
33+
to make the actual connection between both sides until both are declared,
34+
and will do so progressively if multiple ups/downs are added.
35+
36+
This and EdgeInfo are related - an EdgeModifier is the Python object you
37+
use to add information to (potentially multiple) edges, and EdgeInfo
38+
is the representation of the information for one specific edge.
39+
"""
40+
41+
def __init__(self, label: str = None):
42+
self.label = label
43+
self._upstream = []
44+
self._downstream = []
45+
46+
@property
47+
def roots(self):
48+
"""Should return list of root operator List["BaseOperator"]"""
49+
return self._downstream
50+
51+
@property
52+
def leaves(self):
53+
"""Should return list of leaf operator List["BaseOperator"]"""
54+
return self._upstream
55+
56+
def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], chain: bool = True):
57+
"""
58+
Sets the given task/list onto the upstream attribute, and then checks if
59+
we have both sides so we can resolve the relationship.
60+
61+
Providing this also provides << via TaskMixin.
62+
"""
63+
# Ensure we have a list, even if it's just one item
64+
if not isinstance(task_or_task_list, list):
65+
task_or_task_list = [task_or_task_list]
66+
# Unfurl it into actual operators
67+
operators = []
68+
for task in task_or_task_list:
69+
operators.extend(task.roots)
70+
# For each already-declared downstream, pair off with each new upstream
71+
# item and store the edge info.
72+
for operator in operators:
73+
for downstream in self._downstream:
74+
self.add_edge_info(operator.dag, operator.task_id, downstream.task_id)
75+
if chain:
76+
operator.set_downstream(downstream)
77+
# Add the new tasks to our list of ones we've seen
78+
self._upstream.extend(operators)
79+
80+
def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], chain: bool = True):
81+
"""
82+
Sets the given task/list onto the downstream attribute, and then checks if
83+
we have both sides so we can resolve the relationship.
84+
85+
Providing this also provides >> via TaskMixin.
86+
"""
87+
# Ensure we have a list, even if it's just one item
88+
if not isinstance(task_or_task_list, list):
89+
task_or_task_list = [task_or_task_list]
90+
# Unfurl it into actual operators
91+
operators = []
92+
for task in task_or_task_list:
93+
operators.extend(task.leaves)
94+
# Pair them off with existing
95+
for operator in operators:
96+
for upstream in self._upstream:
97+
self.add_edge_info(upstream.dag, upstream.task_id, operator.task_id)
98+
if chain:
99+
upstream.set_downstream(operator)
100+
# Add the new tasks to our list of ones we've seen
101+
self._downstream.extend(operators)
102+
103+
def update_relative(self, other: "TaskMixin", upstream: bool = True) -> None:
104+
"""
105+
Called if we're not the "main" side of a relationship; we still run the
106+
same logic, though.
107+
"""
108+
if upstream:
109+
self.set_upstream(other, chain=False)
110+
else:
111+
self.set_downstream(other, chain=False)
112+
113+
def add_edge_info(self, dag, upstream_id: str, downstream_id: str):
114+
"""
115+
Adds or updates task info on the DAG for this specific pair of tasks.
116+
117+
Called either from our relationship trigger methods above, or directly
118+
by set_upstream/set_downstream in operators.
119+
"""
120+
dag.set_edge_info(upstream_id, downstream_id, {"label": self.label})
121+
122+
123+
# Factory functions
124+
def Label(label: str): # pylint: disable=C0103
125+
"""Creates an EdgeModifier that sets a human-readable label on the edge."""
126+
return EdgeModifier(label=label)

airflow/utils/types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import enum
18+
from typing import Optional
19+
20+
from airflow.typing_compat import TypedDict
1821

1922

2023
class DagRunType(str, enum.Enum):
@@ -34,3 +37,12 @@ def from_run_id(run_id: str) -> "DagRunType":
3437
if run_id and run_id.startswith(f"{run_type.value}__"):
3538
return run_type
3639
return DagRunType.MANUAL
40+
41+
42+
class EdgeInfoType(TypedDict):
43+
"""
44+
Represents extra metadata that the DAG can store about an edge,
45+
usually generated from an EdgeModifier.
46+
"""
47+
48+
label: Optional[str]

airflow/www/templates/airflow/graph.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@
686686
g.setEdge(source_id, target_id, {
687687
curve: d3.curveBasis,
688688
arrowheadClass: 'arrowhead',
689+
label: edge.label
689690
});
690691
})
691692

0 commit comments

Comments
 (0)