Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1662,7 +1662,7 @@ repos:
^airflow-core/src/airflow/operators/subdag\.py$|
^airflow-core/src/airflow/plugins_manager\.py$|
^airflow-core/src/airflow/providers_manager\.py$|
^airflow-core/src/airflow/serialization/dag\.py$|
^airflow-core/src/airflow/serialization/definitions/[_a-z]+\.py$|
^airflow-core/src/airflow/serialization/enums\.py$|
^airflow-core/src/airflow/serialization/helpers\.py$|
^airflow-core/src/airflow/serialization/serialized_objects\.py$|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from airflow.api_fastapi.core_api.services.ui.task_group import get_task_group_children_getter
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmap import TaskMap
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
from airflow.serialization.serialized_objects import SerializedBaseOperator

log = structlog.get_logger(logger_name=__name__)
Expand Down Expand Up @@ -78,8 +78,8 @@ def _get_aggs_for_node(detail):


def _find_aggregates(
node: TaskGroup | MappedTaskGroup | SerializedBaseOperator | TaskMap,
parent_node: TaskGroup | MappedTaskGroup | SerializedBaseOperator | TaskMap | None,
node: SerializedTaskGroup | SerializedBaseOperator | TaskMap,
parent_node: SerializedTaskGroup | SerializedBaseOperator | TaskMap | None,
ti_details: dict[str, list],
) -> Iterable[dict]:
"""Recursively fill the Task Group Map."""
Expand All @@ -98,7 +98,7 @@ def _find_aggregates(
}

return
if isinstance(node, TaskGroup):
if isinstance(node, SerializedTaskGroup):
children = []
for child in get_task_group_children_getter()(node):
for child_node in _find_aggregates(node=child, parent_node=node, ti_details=ti_details):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
from operator import methodcaller

from airflow.configuration import conf
from airflow.models.mappedoperator import MappedOperator
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.models.mappedoperator import MappedOperator, is_mapped
from airflow.serialization.serialized_objects import SerializedBaseOperator


Expand All @@ -51,14 +50,14 @@ def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
node_operator["setup_teardown_type"] = "setup"
elif task.is_teardown:
node_operator["setup_teardown_type"] = "teardown"
if isinstance(task, MappedOperator) or parent_group_is_mapped:
if is_mapped(task) or parent_group_is_mapped:
node_operator["is_mapped"] = True
return node_operator

task_group = task_item_or_group
is_mapped = isinstance(task_group, MappedTaskGroup)
mapped = is_mapped(task_group)
children = [
task_group_to_dict(child, parent_group_is_mapped=parent_group_is_mapped or is_mapped)
task_group_to_dict(child, parent_group_is_mapped=parent_group_is_mapped or mapped)
for child in get_task_group_children_getter()(task_group)
]

Expand All @@ -74,7 +73,7 @@ def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
"id": task_group.group_id,
"label": task_group.label,
"tooltip": task_group.tooltip,
"is_mapped": is_mapped,
"is_mapped": mapped,
"children": children,
"type": "task",
}
Expand All @@ -83,9 +82,9 @@ def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False):
"""Create a nested dict representation of this TaskGroup and its children used to construct the Grid."""
if isinstance(task := task_item_or_group, (MappedOperator, SerializedBaseOperator)):
is_mapped = None
if task.is_mapped or parent_group_is_mapped:
is_mapped = True
mapped = None
if parent_group_is_mapped or is_mapped(task):
mapped = True
setup_teardown_type = None
if task.is_setup is True:
setup_teardown_type = "setup"
Expand All @@ -94,22 +93,22 @@ def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False):
return {
"id": task.task_id,
"label": task.label,
"is_mapped": is_mapped,
"is_mapped": mapped,
"children": None,
"setup_teardown_type": setup_teardown_type,
}

task_group = task_item_or_group
task_group_sort = get_task_group_children_getter()
is_mapped_group = isinstance(task_group, MappedTaskGroup)
mapped = is_mapped(task_group)
children = [
task_group_to_dict_grid(x, parent_group_is_mapped=parent_group_is_mapped or is_mapped_group)
task_group_to_dict_grid(x, parent_group_is_mapped=parent_group_is_mapped or mapped)
for x in task_group_sort(task_group)
]

return {
"id": task_group.group_id,
"label": task_group.label,
"is_mapped": is_mapped_group or None,
"is_mapped": mapped or None,
"children": children or None,
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
import pendulum

from airflow.providers.standard.operators.bash import BashOperator
from airflow.sdk import DAG
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk import DAG, TaskGroup

with DAG(
dag_id="example_setup_teardown",
Expand Down
3 changes: 1 addition & 2 deletions airflow-core/src/airflow/example_dags/example_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import DAG
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk import DAG, TaskGroup

# [START howto_task_group]
with DAG(
Expand Down
4 changes: 1 addition & 3 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1676,9 +1676,7 @@ def task_filter(task: Operator) -> bool:

# Create the missing tasks, including mapped tasks
tis_to_create = self._create_tasks(
# TODO (GH-52141): task_dict in scheduler should contain scheduler
# types instead, but currently it inherits SDK's DAG.
(task for task in cast("Iterable[Operator]", dag.task_dict.values()) if task_filter(task)),
(task for task in dag.task_dict.values() if task_filter(task)),
task_creator,
session=session,
)
Expand Down
31 changes: 21 additions & 10 deletions airflow-core/src/airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import functools
import operator
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard, overload

import attrs
import methodtools
Expand All @@ -31,7 +31,7 @@
from airflow.sdk import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup
from airflow.serialization.enums import DagAttributeTypes
from airflow.serialization.serialized_objects import DEFAULT_OPERATOR_DEPS, SerializedBaseOperator
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
Expand All @@ -57,8 +57,16 @@
log = structlog.get_logger(__name__)


def is_mapped(task: Operator) -> TypeGuard[MappedOperator]:
return task.is_mapped
@overload
def is_mapped(obj: Operator) -> TypeGuard[MappedOperator]: ...


@overload
def is_mapped(obj: SerializedTaskGroup) -> TypeGuard[SerializedMappedTaskGroup]: ...


def is_mapped(obj: Operator | SerializedTaskGroup) -> TypeGuard[MappedOperator | SerializedMappedTaskGroup]:
return obj.is_mapped


@attrs.define(
Expand Down Expand Up @@ -100,8 +108,11 @@ class MappedOperator(DAGNode):
start_from_trigger: bool = False
_needs_expansion: bool = True

dag: SerializedDAG = attrs.field(init=False)
task_group: TaskGroup = attrs.field(init=False)
# TODO (GH-52141): These should contain serialized containers, but currently
# this class inherits from an SDK one.
dag: SerializedDAG = attrs.field(init=False) # type: ignore[assignment]
task_group: SerializedTaskGroup = attrs.field(init=False) # type: ignore[assignment]

start_date: pendulum.DateTime | None = attrs.field(init=False, default=None)
end_date: pendulum.DateTime | None = attrs.field(init=False, default=None)
upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
Expand Down Expand Up @@ -388,7 +399,7 @@ def _get_specified_expand_input(self) -> SchedulerExpandInput:
return getattr(self, self._expand_input_attr)

# TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:
"""
Return mapped task groups this task belongs to.

Expand All @@ -401,7 +412,7 @@ def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
yield from group.iter_mapped_task_groups()

# TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
def get_closest_mapped_task_group(self) -> SerializedMappedTaskGroup | None:
"""
Get the mapped task group "closest" to this task in the DAG.

Expand Down Expand Up @@ -504,7 +515,7 @@ def _(task: MappedOperator | TaskSDKMappedOperator, run_id: str, *, session: Ses


@get_mapped_ti_count.register
def _(group: TaskGroup, run_id: str, *, session: Session) -> int:
def _(group: SerializedTaskGroup, run_id: str, *, session: Session) -> int:
"""
Return the number of instances a task in this group should be mapped to at run time.

Expand All @@ -523,7 +534,7 @@ def _(group: TaskGroup, run_id: str, *, session: Session) -> int:

def iter_mapped_task_group_lengths(group) -> Iterator[int]:
while group is not None:
if isinstance(group, MappedTaskGroup):
if isinstance(group, SerializedMappedTaskGroup):
exp_input = group._expand_input
# TODO (GH-52141): 'group' here should be scheduler-bound and returns scheduler expand input.
if not hasattr(exp_input, "get_total_map_length"):
Expand Down
22 changes: 9 additions & 13 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@
from airflow.sdk import DAG
from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
from airflow.sdk.types import RuntimeTaskInstanceProtocol
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.context import Context

Expand Down Expand Up @@ -1534,12 +1534,9 @@ def run(
assert original_task is not None
assert original_task.dag is not None

serialized_task = SerializedDAG.deserialize_dag(
SerializedDAG.serialize_dag(original_task.dag)
).task_dict[original_task.task_id]
# TODO (GH-52141): task_dict in scheduler should contain scheduler
# types instead, but currently it inherits SDK's DAG.
self.task = cast("Operator", serialized_task)
self.task = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(original_task.dag)).task_dict[
original_task.task_id
]
res = self.check_and_change_state_before_execution(
verbose=verbose,
ignore_all_deps=ignore_all_deps,
Expand Down Expand Up @@ -2286,7 +2283,7 @@ def duration_expression_update(
)


def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> SerializedTaskGroup | None:
"""Given two operators, find their innermost common mapped task group."""
if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
return None
Expand All @@ -2295,16 +2292,15 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp
return next(common_groups, None)


def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
def _is_further_mapped_inside(operator: Operator, container: SerializedTaskGroup) -> bool:
"""Whether given operator is *further* mapped inside a task group."""
from airflow.models.mappedoperator import MappedOperator
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.models.mappedoperator import is_mapped

if isinstance(operator, MappedOperator):
if is_mapped(operator):
return True
task_group = operator.task_group
while task_group is not None and task_group.group_id != container.group_id:
if isinstance(task_group, MappedTaskGroup):
if is_mapped(task_group):
return True
task_group = task_group.parent_group
return False
Expand Down
17 changes: 17 additions & 0 deletions airflow-core/src/airflow/serialization/definitions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading
Loading