Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions airflow/models/taskmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,29 @@ def leaves(self) -> Sequence["DependencyMixin"]:
raise NotImplementedError()

@abstractmethod
def set_upstream(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
def set_upstream(
self,
other: Union["DependencyMixin", Sequence["DependencyMixin"]],
edge_modifier: Optional["EdgeModifier"] = None,
):
"""Set a task or a task list to be directly upstream from the current task."""
raise NotImplementedError()

@abstractmethod
def set_downstream(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
def set_downstream(
self,
other: Union["DependencyMixin", Sequence["DependencyMixin"]],
edge_modifier: Optional["EdgeModifier"] = None,
):
"""Set a task or a task list to be directly downstream from the current task."""
raise NotImplementedError()

def update_relative(self, other: "DependencyMixin", upstream=True) -> None:
def update_relative(
self,
other: "DependencyMixin",
upstream=True,
edge_modifier: Optional["EdgeModifier"] = None,
) -> None:
"""
Update relationship information about another TaskMixin. Default is no-op.
Override if necessary.
Expand Down Expand Up @@ -163,7 +176,7 @@ def _set_relatives(

task_list: List[Operator] = []
for task_object in task_or_task_list:
task_object.update_relative(self, not upstream)
task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
relatives = task_object.leaves if upstream else task_object.roots
for task in relatives:
if not isinstance(task, (BaseOperator, MappedOperator)):
Expand Down
107 changes: 53 additions & 54 deletions airflow/utils/edgemodifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Union

from airflow.models.taskmixin import DependencyMixin

if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskmixin import DAGNode, DependencyMixin


class EdgeModifier(DependencyMixin):
Expand All @@ -42,8 +39,8 @@ class EdgeModifier(DependencyMixin):

def __init__(self, label: Optional[str] = None):
self.label = label
self._upstream: List["BaseOperator"] = []
self._downstream: List["BaseOperator"] = []
self._upstream: List["DependencyMixin"] = []
self._downstream: List["DependencyMixin"] = []

@property
def roots(self):
Expand All @@ -53,76 +50,78 @@ def roots(self):
def leaves(self):
return self._upstream

@staticmethod
def _make_list(item_or_list):
if not isinstance(item_or_list, Sequence):
return [item_or_list]
return item_or_list

def _save_nodes(
    self,
    nodes: Union["DependencyMixin", Sequence["DependencyMixin"]],
    stream: List["DependencyMixin"],
):
    """
    Record ``nodes`` in ``stream`` (one of this modifier's node lists).

    ``TaskGroup`` and ``XComArg`` instances are appended unchanged. A
    ``DAGNode`` is appended unchanged unless it belongs to a non-root task
    group, in which case that task group is appended in its place.

    :param nodes: a single node or a sequence of nodes to record
    :param stream: the list that is appended to in place
    :raises TypeError: if a node is not a ``TaskGroup``, ``XComArg`` or
        ``DAGNode``; nodes processed before the offending one have already
        been appended by then
    """
    # Imported inside the method as in the original; presumably this avoids a
    # circular import at module load time -- confirm before hoisting.
    from airflow.models.xcom_arg import XComArg
    from airflow.utils.task_group import TaskGroup

    for node in self._make_list(nodes):
        if isinstance(node, (TaskGroup, XComArg)):
            stream.append(node)
            continue

        if not isinstance(node, DAGNode):
            raise TypeError(
                f"Cannot use edge labels with {type(node).__name__}, "
                f"only tasks, XComArg or TaskGroups"
            )

        # A node inside a non-root group is represented by its group.
        group = node.task_group
        if group and not group.is_root:
            stream.append(group)
        else:
            stream.append(node)

def set_upstream(
self, task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], chain: bool = True
self,
other: Union["DependencyMixin", Sequence["DependencyMixin"]],
edge_modifier: Optional["EdgeModifier"] = None,
):
"""
Sets the given task/list onto the upstream attribute, and then checks if
we have both sides so we can resolve the relationship.

Providing this also provides << via DependencyMixin.
"""
from airflow.models.baseoperator import BaseOperator

# Ensure we have a list, even if it's just one item
if isinstance(task_or_task_list, DependencyMixin):
task_or_task_list = [task_or_task_list]
# Unfurl it into actual operators
operators: List[BaseOperator] = []
for task in task_or_task_list:
for root in task.roots:
if not isinstance(root, BaseOperator):
raise TypeError(f"Cannot use edge labels with {type(root).__name__}, only operators")
operators.append(root)
# For each already-declared downstream, pair off with each new upstream
# item and store the edge info.
for operator in operators:
for downstream in self._downstream:
self.add_edge_info(operator.dag, operator.task_id, downstream.task_id)
if chain:
operator.set_downstream(downstream)
# Add the new tasks to our list of ones we've seen
self._upstream.extend(operators)
self._save_nodes(other, self._upstream)
for node in self._downstream:
node.set_upstream(other, edge_modifier=self)

def set_downstream(
self, task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], chain: bool = True
self,
other: Union["DependencyMixin", Sequence["DependencyMixin"]],
edge_modifier: Optional["EdgeModifier"] = None,
):
"""
Sets the given task/list onto the downstream attribute, and then checks if
we have both sides so we can resolve the relationship.

Providing this also provides >> via DependencyMixin.
"""
from airflow.models.baseoperator import BaseOperator

# Ensure we have a list, even if it's just one item
if isinstance(task_or_task_list, DependencyMixin):
task_or_task_list = [task_or_task_list]
# Unfurl it into actual operators
operators: List[BaseOperator] = []
for task in task_or_task_list:
for leaf in task.leaves:
if not isinstance(leaf, BaseOperator):
raise TypeError(f"Cannot use edge labels with {type(leaf).__name__}, only operators")
operators.append(leaf)
# Pair them off with existing
for operator in operators:
for upstream in self._upstream:
self.add_edge_info(upstream.dag, upstream.task_id, operator.task_id)
if chain:
upstream.set_downstream(operator)
# Add the new tasks to our list of ones we've seen
self._downstream.extend(operators)

def update_relative(self, other: DependencyMixin, upstream: bool = True) -> None:
self._save_nodes(other, self._downstream)
for node in self._upstream:
node.set_downstream(other, edge_modifier=self)

def update_relative(
self,
other: "DependencyMixin",
upstream: bool = True,
edge_modifier: Optional["EdgeModifier"] = None,
) -> None:
"""
Called if we're not the "main" side of a relationship; we still run the
same logic, though.
"""
if upstream:
self.set_upstream(other, chain=False)
self.set_upstream(other)
else:
self.set_downstream(other, chain=False)
self.set_downstream(other)

def add_edge_info(self, dag, upstream_id: str, downstream_id: str):
"""
Expand Down
17 changes: 15 additions & 2 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,12 @@ def label(self) -> Optional[str]:
"""group_id excluding parent's group_id used as the node label in UI."""
return self._group_id

def update_relative(self, other: DependencyMixin, upstream=True) -> None:
def update_relative(
self,
other: DependencyMixin,
upstream=True,
edge_modifier: Optional["EdgeModifier"] = None,
) -> None:
"""
Overrides TaskMixin.update_relative.

Expand All @@ -256,10 +261,18 @@ def update_relative(self, other: DependencyMixin, upstream=True) -> None:
f"or operators; received {task.__class__.__name__}"
)

# Do not set a relationship between a TaskGroup and a Label's roots
if self == task:
continue

if upstream:
self.upstream_task_ids.add(task.node_id)
if edge_modifier:
edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
else:
self.downstream_task_ids.add(task.node_id)
if edge_modifier:
edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)

def _set_relatives(
self,
Expand All @@ -282,7 +295,7 @@ def _set_relatives(
task_or_task_list = [task_or_task_list]

for task_like in task_or_task_list:
self.update_relative(task_like, upstream)
self.update_relative(task_like, upstream, edge_modifier=edge_modifier)

def __enter__(self) -> "TaskGroup":
TaskGroupContext.push_context_managed_task_group(self)
Expand Down
Loading