Skip to content

Commit cbfa578

Browse files
authored
Remove MappedOperator inheritance (#53696)
1 parent 5f29f4e commit cbfa578

File tree

27 files changed

+652
-281
lines changed

27 files changed

+652
-281
lines changed

airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import structlog
2424

2525
from airflow.api_fastapi.common.parameters import state_priority
26+
from airflow.models.mappedoperator import MappedOperator
2627
from airflow.models.taskmap import TaskMap
27-
from airflow.sdk.definitions.mappedoperator import MappedOperator
2828
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup, get_task_group_children_getter
2929
from airflow.serialization.serialized_objects import SerializedBaseOperator
3030

airflow-core/src/airflow/models/dagrun.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,11 @@
102102
from airflow.serialization.serialized_objects import SerializedBaseOperator
103103
from airflow.utils.types import ArgNotSet
104104

105-
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
106-
107105
CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
108-
109-
AttributeValueType = (
106+
AttributeValueType: TypeAlias = (
110107
str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
111108
)
109+
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
112110

113111
RUN_ID_REGEX = r"^(?:manual|scheduled|asset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$"
114112

@@ -1483,15 +1481,15 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
14831481
If the ti does not need expansion, either because the task is not
14841482
mapped, or has already been expanded, *None* is returned.
14851483
"""
1484+
from airflow.models.mappedoperator import is_mapped
1485+
14861486
if TYPE_CHECKING:
14871487
assert ti.task
14881488

14891489
if ti.map_index >= 0: # Already expanded, we're good.
14901490
return None
14911491

1492-
from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
1493-
1494-
if isinstance(ti.task, TaskSDKMappedOperator):
1492+
if is_mapped(ti.task):
14951493
# If we get here, it could be that we are moving from non-mapped to mapped
14961494
# after task instance clearing or this ti is not yet expanded. Safe to clear
14971495
# the db references.
@@ -1510,7 +1508,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
15101508
revised_map_index_task_ids: set[str] = set()
15111509
for schedulable in itertools.chain(schedulable_tis, additional_tis):
15121510
if TYPE_CHECKING:
1513-
assert isinstance(schedulable.task, SerializedBaseOperator)
1511+
assert isinstance(schedulable.task, Operator)
15141512
old_state = schedulable.state
15151513
if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
15161514
old_states[schedulable.key] = old_state
@@ -1995,25 +1993,23 @@ def schedule_tis(
19951993
empty_ti_ids: list[str] = []
19961994
schedulable_ti_ids: list[str] = []
19971995
for ti in schedulable_tis:
1996+
task = ti.task
19981997
if TYPE_CHECKING:
1999-
assert isinstance(ti.task, SerializedBaseOperator)
1998+
assert isinstance(task, Operator)
20001999
if (
2001-
ti.task.inherits_from_empty_operator
2002-
and not ti.task.on_execute_callback
2003-
and not ti.task.on_success_callback
2004-
and not ti.task.outlets
2005-
and not ti.task.inlets
2000+
task.inherits_from_empty_operator
2001+
and not task.on_execute_callback
2002+
and not task.on_success_callback
2003+
and not task.outlets
2004+
and not task.inlets
20062005
):
20072006
empty_ti_ids.append(ti.id)
20082007
# check "start_trigger_args" to see whether the operator supports start execution from triggerer
20092008
# if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer
20102009
# this task.
20112010
# if not, we'll add this "ti" into "schedulable_ti_ids" and later execute it to run in the worker
2012-
elif ti.task.start_trigger_args is not None:
2013-
context = ti.get_template_context()
2014-
start_from_trigger = ti.task.expand_start_from_trigger(context=context, session=session)
2015-
2016-
if start_from_trigger:
2011+
elif task.start_trigger_args is not None:
2012+
if task.expand_start_from_trigger(context=ti.get_template_context()):
20172013
ti.start_date = timezone.utcnow()
20182014
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
20192015
ti.try_number += 1

airflow-core/src/airflow/models/expandinput.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,11 @@
1919

2020
import functools
2121
import operator
22-
from collections.abc import Iterable, Sized
22+
from collections.abc import Iterable, Mapping, Sequence, Sized
2323
from typing import TYPE_CHECKING, Any, ClassVar
2424

2525
import attrs
2626

27-
if TYPE_CHECKING:
28-
from typing import TypeGuard
29-
30-
from sqlalchemy.orm import Session
31-
32-
from airflow.models.xcom_arg import SchedulerXComArg
33-
3427
from airflow.sdk.definitions._internal.expandinput import (
3528
DictOfListsExpandInput,
3629
ListOfDictsExpandInput,
@@ -41,6 +34,18 @@
4134
is_mappable,
4235
)
4336

37+
if TYPE_CHECKING:
38+
from typing import TypeAlias, TypeGuard
39+
40+
from sqlalchemy.orm import Session
41+
42+
from airflow.models.mappedoperator import MappedOperator
43+
from airflow.models.xcom_arg import SchedulerXComArg
44+
from airflow.serialization.serialized_objects import SerializedBaseOperator
45+
46+
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
47+
48+
4449
__all__ = [
4550
"DictOfListsExpandInput",
4651
"ListOfDictsExpandInput",
@@ -111,6 +116,26 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
111116
lengths = self._get_map_lengths(run_id, session=session)
112117
return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)
113118

119+
def iter_references(self) -> Iterable[tuple[Operator, str]]:
120+
from airflow.models.referencemixin import ReferenceMixin
121+
122+
for x in self.value.values():
123+
if isinstance(x, ReferenceMixin):
124+
yield from x.iter_references()
125+
126+
127+
# To replace tedious isinstance() checks.
128+
def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
129+
from airflow.sdk.definitions.xcom_arg import XComArg
130+
131+
return not isinstance(v, (MappedArgument, XComArg))
132+
133+
134+
def _describe_type(value: Any) -> str:
135+
if value is None:
136+
return "None"
137+
return type(value).__name__
138+
114139

115140
@attrs.define
116141
class SchedulerListOfDictsExpandInput:
@@ -133,6 +158,16 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
133158
raise NotFullyPopulated({"expand_kwargs() argument"})
134159
return length
135160

161+
def iter_references(self) -> Iterable[tuple[Operator, str]]:
162+
from airflow.models.referencemixin import ReferenceMixin
163+
164+
if isinstance(self.value, ReferenceMixin):
165+
yield from self.value.iter_references()
166+
else:
167+
for x in self.value:
168+
if isinstance(x, ReferenceMixin):
169+
yield from x.iter_references()
170+
136171

137172
_EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = {
138173
"dict-of-lists": SchedulerDictOfListsExpandInput,

0 commit comments

Comments
 (0)