Skip to content

Commit cfdbb63

Browse files
committed
Move chain, chain_linear & cross_downstream to Task SDK
This functions are for DAG Authors to define relationship between multiple tasks in batch
1 parent bdfea80 commit cfdbb63

File tree

17 files changed

+417
-384
lines changed

17 files changed

+417
-384
lines changed

airflow/example_dags/example_asset_with_watchers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@
2121
from __future__ import annotations
2222

2323
from airflow.decorators import task
24-
from airflow.models.baseoperator import chain
2524
from airflow.models.dag import DAG
2625
from airflow.providers.standard.triggers.file import FileDeleteTrigger
27-
from airflow.sdk import Asset, AssetWatcher
26+
from airflow.sdk import Asset, AssetWatcher, chain
2827

2928
file_path = "/tmp/test"
3029

airflow/example_dags/example_bash_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
from airflow.decorators import dag, task
2323
from airflow.exceptions import AirflowSkipException
24-
from airflow.models.baseoperator import chain
2524
from airflow.providers.standard.operators.empty import EmptyOperator
25+
from airflow.sdk import chain
2626
from airflow.utils.trigger_rule import TriggerRule
2727
from airflow.utils.weekday import WeekDay
2828

airflow/example_dags/example_complex.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323

2424
import pendulum
2525

26-
from airflow.models.baseoperator import chain
2726
from airflow.models.dag import DAG
2827
from airflow.providers.standard.operators.bash import BashOperator
28+
from airflow.sdk import chain
2929

3030
with DAG(
3131
dag_id="example_complex",

airflow/example_dags/example_short_circuit_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import pendulum
2222

2323
from airflow.decorators import dag, task
24-
from airflow.models.baseoperator import chain
2524
from airflow.providers.standard.operators.empty import EmptyOperator
25+
from airflow.sdk import chain
2626
from airflow.utils.trigger_rule import TriggerRule
2727

2828

airflow/example_dags/example_short_circuit_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121

2222
import pendulum
2323

24-
from airflow.models.baseoperator import chain
2524
from airflow.models.dag import DAG
2625
from airflow.providers.standard.operators.empty import EmptyOperator
2726
from airflow.providers.standard.operators.python import ShortCircuitOperator
27+
from airflow.sdk import chain
2828
from airflow.utils.trigger_rule import TriggerRule
2929

3030
with DAG(

airflow/models/baseoperator.py

Lines changed: 6 additions & 267 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import functools
2727
import logging
2828
import operator
29-
from collections.abc import Collection, Iterable, Iterator, Sequence
29+
from collections.abc import Collection, Iterable, Iterator
3030
from datetime import datetime, timedelta
3131
from functools import singledispatchmethod
3232
from types import FunctionType
@@ -54,14 +54,16 @@
5454
NotMapped,
5555
)
5656
from airflow.models.taskinstance import TaskInstance, clear_task_instances
57-
from airflow.models.taskmixin import DependencyMixin
5857
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
5958
from airflow.sdk.definitions.baseoperator import (
60-
get_merged_defaults as get_merged_defaults, # Re-export for compat
59+
# Re-export for compat
60+
chain as chain,
61+
chain_linear as chain_linear,
62+
cross_downstream as cross_downstream,
63+
get_merged_defaults as get_merged_defaults,
6164
)
6265
from airflow.sdk.definitions.context import Context
6366
from airflow.sdk.definitions.dag import BaseOperator as TaskSDKBaseOperator
64-
from airflow.sdk.definitions.edges import EdgeModifier as TaskSDKEdgeModifier
6567
from airflow.sdk.definitions.mappedoperator import MappedOperator
6668
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
6769
from airflow.serialization.enums import DagAttributeTypes
@@ -72,7 +74,6 @@
7274
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
7375
from airflow.utils import timezone
7476
from airflow.utils.context import context_get_outlet_events
75-
from airflow.utils.edgemodifier import EdgeModifier
7677
from airflow.utils.operator_helpers import ExecutionCallableRunner
7778
from airflow.utils.operator_resources import Resources
7879
from airflow.utils.session import NEW_SESSION, provide_session
@@ -811,265 +812,3 @@ def iter_mapped_task_group_lengths(group) -> Iterator[int]:
811812
group = group.parent_group
812813

813814
return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group))
814-
815-
816-
def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
817-
r"""
818-
Given a number of tasks, builds a dependency chain.
819-
820-
This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, TaskGroups,
821-
or lists containing any mix of these types (or a mix in the same list). If you want to chain between two
822-
lists you must ensure they have the same length.
823-
824-
Using classic operators/sensors:
825-
826-
.. code-block:: python
827-
828-
chain(t1, [t2, t3], [t4, t5], t6)
829-
830-
is equivalent to::
831-
832-
/ -> t2 -> t4 \
833-
t1 -> t6
834-
\ -> t3 -> t5 /
835-
836-
.. code-block:: python
837-
838-
t1.set_downstream(t2)
839-
t1.set_downstream(t3)
840-
t2.set_downstream(t4)
841-
t3.set_downstream(t5)
842-
t4.set_downstream(t6)
843-
t5.set_downstream(t6)
844-
845-
Using task-decorated functions aka XComArgs:
846-
847-
.. code-block:: python
848-
849-
chain(x1(), [x2(), x3()], [x4(), x5()], x6())
850-
851-
is equivalent to::
852-
853-
/ -> x2 -> x4 \
854-
x1 -> x6
855-
\ -> x3 -> x5 /
856-
857-
.. code-block:: python
858-
859-
x1 = x1()
860-
x2 = x2()
861-
x3 = x3()
862-
x4 = x4()
863-
x5 = x5()
864-
x6 = x6()
865-
x1.set_downstream(x2)
866-
x1.set_downstream(x3)
867-
x2.set_downstream(x4)
868-
x3.set_downstream(x5)
869-
x4.set_downstream(x6)
870-
x5.set_downstream(x6)
871-
872-
Using TaskGroups:
873-
874-
.. code-block:: python
875-
876-
chain(t1, task_group1, task_group2, t2)
877-
878-
t1.set_downstream(task_group1)
879-
task_group1.set_downstream(task_group2)
880-
task_group2.set_downstream(t2)
881-
882-
883-
It is also possible to mix between classic operator/sensor, EdgeModifiers, XComArg, and TaskGroups:
884-
885-
.. code-block:: python
886-
887-
chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())
888-
889-
is equivalent to::
890-
891-
/ "branch one" -> x1 \
892-
t1 -> task_group1 -> x3
893-
\ "branch two" -> x2 /
894-
895-
.. code-block:: python
896-
897-
x1 = x1()
898-
x2 = x2()
899-
x3 = x3()
900-
label1 = Label("branch one")
901-
label2 = Label("branch two")
902-
t1.set_downstream(label1)
903-
label1.set_downstream(x1)
904-
t2.set_downstream(label2)
905-
label2.set_downstream(x2)
906-
x1.set_downstream(task_group1)
907-
x2.set_downstream(task_group1)
908-
task_group1.set_downstream(x3)
909-
910-
# or
911-
912-
x1 = x1()
913-
x2 = x2()
914-
x3 = x3()
915-
t1.set_downstream(x1, edge_modifier=Label("branch one"))
916-
t1.set_downstream(x2, edge_modifier=Label("branch two"))
917-
x1.set_downstream(task_group1)
918-
x2.set_downstream(task_group1)
919-
task_group1.set_downstream(x3)
920-
921-
922-
:param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
923-
"""
924-
for up_task, down_task in zip(tasks, tasks[1:]):
925-
if isinstance(up_task, DependencyMixin):
926-
up_task.set_downstream(down_task)
927-
continue
928-
if isinstance(down_task, DependencyMixin):
929-
down_task.set_upstream(up_task)
930-
continue
931-
if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
932-
raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
933-
up_task_list = up_task
934-
down_task_list = down_task
935-
if len(up_task_list) != len(down_task_list):
936-
raise AirflowException(
937-
f"Chain not supported for different length Iterable. "
938-
f"Got {len(up_task_list)} and {len(down_task_list)}."
939-
)
940-
for up_t, down_t in zip(up_task_list, down_task_list):
941-
up_t.set_downstream(down_t)
942-
943-
944-
def cross_downstream(
945-
from_tasks: Sequence[DependencyMixin],
946-
to_tasks: DependencyMixin | Sequence[DependencyMixin],
947-
):
948-
r"""
949-
Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
950-
951-
Using classic operators/sensors:
952-
953-
.. code-block:: python
954-
955-
cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
956-
957-
is equivalent to::
958-
959-
t1 ---> t4
960-
\ /
961-
t2 -X -> t5
962-
/ \
963-
t3 ---> t6
964-
965-
.. code-block:: python
966-
967-
t1.set_downstream(t4)
968-
t1.set_downstream(t5)
969-
t1.set_downstream(t6)
970-
t2.set_downstream(t4)
971-
t2.set_downstream(t5)
972-
t2.set_downstream(t6)
973-
t3.set_downstream(t4)
974-
t3.set_downstream(t5)
975-
t3.set_downstream(t6)
976-
977-
Using task-decorated functions aka XComArgs:
978-
979-
.. code-block:: python
980-
981-
cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
982-
983-
is equivalent to::
984-
985-
x1 ---> x4
986-
\ /
987-
x2 -X -> x5
988-
/ \
989-
x3 ---> x6
990-
991-
.. code-block:: python
992-
993-
x1 = x1()
994-
x2 = x2()
995-
x3 = x3()
996-
x4 = x4()
997-
x5 = x5()
998-
x6 = x6()
999-
x1.set_downstream(x4)
1000-
x1.set_downstream(x5)
1001-
x1.set_downstream(x6)
1002-
x2.set_downstream(x4)
1003-
x2.set_downstream(x5)
1004-
x2.set_downstream(x6)
1005-
x3.set_downstream(x4)
1006-
x3.set_downstream(x5)
1007-
x3.set_downstream(x6)
1008-
1009-
It is also possible to mix between classic operator/sensor and XComArg tasks:
1010-
1011-
.. code-block:: python
1012-
1013-
cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
1014-
1015-
is equivalent to::
1016-
1017-
t1 ---> x1
1018-
\ /
1019-
x2 -X -> t2
1020-
/ \
1021-
t3 ---> x3
1022-
1023-
.. code-block:: python
1024-
1025-
x1 = x1()
1026-
x2 = x2()
1027-
x3 = x3()
1028-
t1.set_downstream(x1)
1029-
t1.set_downstream(t2)
1030-
t1.set_downstream(x3)
1031-
x2.set_downstream(x1)
1032-
x2.set_downstream(t2)
1033-
x2.set_downstream(x3)
1034-
t3.set_downstream(x1)
1035-
t3.set_downstream(t2)
1036-
t3.set_downstream(x3)
1037-
1038-
:param from_tasks: List of tasks or XComArgs to start from.
1039-
:param to_tasks: List of tasks or XComArgs to set as downstream dependencies.
1040-
"""
1041-
for task in from_tasks:
1042-
task.set_downstream(to_tasks)
1043-
1044-
1045-
def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
1046-
"""
1047-
Simplify task dependency definition.
1048-
1049-
E.g.: suppose you want precedence like so::
1050-
1051-
╭─op2─╮ ╭─op4─╮
1052-
op1─┤ ├─├─op5─┤─op7
1053-
╰-op3─╯ ╰-op6─╯
1054-
1055-
Then you can accomplish like so::
1056-
1057-
chain_linear(op1, [op2, op3], [op4, op5, op6], op7)
1058-
1059-
:param elements: a list of operators / lists of operators
1060-
"""
1061-
if not elements:
1062-
raise ValueError("No tasks provided; nothing to do.")
1063-
prev_elem = None
1064-
deps_set = False
1065-
for curr_elem in elements:
1066-
if isinstance(curr_elem, (EdgeModifier, TaskSDKEdgeModifier)):
1067-
raise ValueError("Labels are not supported by chain_linear")
1068-
if prev_elem is not None:
1069-
for task in prev_elem:
1070-
task >> curr_elem
1071-
if not deps_set:
1072-
deps_set = True
1073-
prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
1074-
if not deps_set:
1075-
raise ValueError("No dependencies were set. Did you forget to expand with `*`?")

airflow/utils/edgemodifier.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,5 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from typing import TYPE_CHECKING
20-
21-
import airflow.sdk
22-
23-
if TYPE_CHECKING:
24-
from airflow.typing_compat import TypeAlias
25-
26-
EdgeModifier: TypeAlias = airflow.sdk.definitions.edges.EdgeModifier
27-
28-
29-
# Factory functions
30-
def Label(label: str):
31-
"""Create an EdgeModifier that sets a human-readable label on the edge."""
32-
return EdgeModifier(label=label)
19+
# Re-export for compat
20+
from airflow.sdk.definitions.edges import Label as Label

dev/perf/dags/elastic_dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from datetime import datetime, timedelta
2323
from enum import Enum
2424

25-
from airflow.models.baseoperator import chain
2625
from airflow.models.dag import DAG
2726
from airflow.providers.standard.operators.bash import BashOperator
27+
from airflow.sdk import chain
2828

2929
# DAG File used in performance tests. Its shape can be configured by environment variables.
3030
RE_TIME_DELTA = re.compile(

0 commit comments

Comments
 (0)