@@ -66,7 +66,6 @@
 from airflow.models.xcom import XComModel
 from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
 from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup
 from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState

 if TYPE_CHECKING:
@@ -288,20 +287,20 @@ def ti_run(
 def _get_upstream_map_indexes(
     task: Operator, ti_map_index: int, run_id: str, session: SessionDep
 ) -> Iterator[tuple[str, int | list[int] | None]]:
+    task_mapped_group = task.get_closest_mapped_task_group()
     for upstream_task in task.upstream_list:
+        upstream_mapped_group = upstream_task.get_closest_mapped_task_group()
         map_indexes: int | list[int] | None
-        if not isinstance(upstream_task.task_group, MappedTaskGroup):
+        if upstream_mapped_group is None:
             # regular tasks or non-mapped task groups
             map_indexes = None
-        elif task.task_group == upstream_task.task_group:
-            # tasks in the same mapped task group
-            # the task should use the map_index as the previous task in the same mapped task group
+        elif task_mapped_group == upstream_mapped_group:
+            # tasks in the same mapped task group hierarchy
             map_indexes = ti_map_index
         else:
             # tasks not in the same mapped task group
             # the upstream mapped task group should combine the return xcom as a list and return it
             mapped_ti_count: int
-            upstream_mapped_group = upstream_task.task_group
             try:
                 # for cases that does not need to resolve xcom
                 mapped_ti_count = upstream_mapped_group.get_parse_time_mapped_ti_count()
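The route change above swaps the direct `task_group` comparison for `get_closest_mapped_task_group()`, so a task sitting inside a non-mapped group that is itself nested in a mapped task group still resolves against the same expansion as its upstream. A minimal standalone sketch of the three branches, using plain strings in place of `Operator`/`MappedTaskGroup` objects (the `resolve_map_indexes` name and the `list(range(...))` aggregation in the last branch are illustrative assumptions; the real `else` branch continues past the visible hunk):

    from __future__ import annotations


    def resolve_map_indexes(
        task_closest_group: str | None,
        upstream_closest_group: str | None,
        ti_map_index: int,
        upstream_expanded_count: int,
    ) -> int | list[int] | None:
        """Simplified stand-in for the branching in _get_upstream_map_indexes.

        The two *_closest_group arguments model get_closest_mapped_task_group()
        on the downstream and upstream task (None when there is no mapped group).
        """
        if upstream_closest_group is None:
            # regular task or non-mapped task group: no narrowing needed
            return None
        if task_closest_group == upstream_closest_group:
            # same mapped task group hierarchy: read the peer with the same map_index
            return ti_map_index
        # different mapped task group: combine every expanded XCom as a list
        # (assumption for this sketch; the real code derives the count via
        # get_parse_time_mapped_ti_count() or the run-time mapped TI count)
        return list(range(upstream_expanded_count))


    # Nested-group case from the test below: print_task[1] should pull alter_input[1].
    assert resolve_map_indexes("eg.inner", "eg.inner", 1, 3) == 1
    # A downstream task outside the mapped group would aggregate all expanded indexes.
    assert resolve_map_indexes(None, "eg.inner", 0, 3) == [0, 1, 2]
    # A plain, unmapped upstream resolves to None.
    assert resolve_map_indexes("eg.inner", None, 1, 3) is None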
@@ -301,6 +301,81 @@ def task2():
             upstream_map_indexes = response.json()["upstream_map_indexes"]
             assert upstream_map_indexes == expected_upstream_map_indexes[(ti.task_id, ti.map_index)]

+    def test_nested_mapped_task_group_upstream_indexes(self, client, dag_maker):
+        """
+        Test that upstream_map_indexes are correctly computed for tasks in nested mapped task groups.
+        """
+        with dag_maker("test_nested_mapped_tg", serialized=True):
+
+            @task
+            def alter_input(inp: str) -> str:
+                return f"{inp}_Altered"
+
+            @task
+            def print_task(orig_input: str, altered_input: str) -> str:
+                return f"orig:{orig_input},altered:{altered_input}"
+
+            @task_group
+            def inner_task_group(orig_input: str) -> None:
+                altered_input = alter_input(orig_input)
+                print_task(orig_input, altered_input)
+
+            @task_group
+            def expandable_task_group(param: str) -> None:
+                inner_task_group(param)
+
+            expandable_task_group.expand(param=["One", "Two", "Three"])
+
+        dr = dag_maker.create_dagrun()
+
+        # Set all alter_input tasks to success so print_task can run
+        for ti in dr.get_task_instances():
+            if "alter_input" in ti.task_id and ti.map_index >= 0:
+                ti.state = State.SUCCESS
+            elif "print_task" in ti.task_id and ti.map_index >= 0:
+                ti.set_state(State.QUEUED)
+        dag_maker.session.flush()
+
+        # Expected upstream_map_indexes for each print_task instance
+        expected_upstream_map_indexes = {
+            ("expandable_task_group.inner_task_group.print_task", 0): {
+                "expandable_task_group.inner_task_group.alter_input": 0
+            },
+            ("expandable_task_group.inner_task_group.print_task", 1): {
+                "expandable_task_group.inner_task_group.alter_input": 1
+            },
+            ("expandable_task_group.inner_task_group.print_task", 2): {
+                "expandable_task_group.inner_task_group.alter_input": 2
+            },
+        }
+
+        # Get only the expanded print_task instances (not the template)
+        print_task_tis = [
+            ti for ti in dr.get_task_instances() if "print_task" in ti.task_id and ti.map_index >= 0
+        ]
+
+        # Test each print_task instance
+        for ti in print_task_tis:
+            response = client.patch(
+                f"/execution/task-instances/{ti.id}/run",
+                json={
+                    "state": "running",
+                    "hostname": "random-hostname",
+                    "unixname": "random-unixname",
+                    "pid": 100,
+                    "start_date": "2024-09-30T12:00:00Z",
+                },
+            )
+
+            assert response.status_code == 200
+            upstream_map_indexes = response.json()["upstream_map_indexes"]
+            expected = expected_upstream_map_indexes[(ti.task_id, ti.map_index)]
+
+            assert upstream_map_indexes == expected, (
+                f"Task {ti.task_id}[{ti.map_index}] should have upstream_map_indexes {expected}, "
+                f"but got {upstream_map_indexes}"
+            )
+
     def test_dynamic_task_mapping_with_xcom(self, client, dag_maker, create_task_instance, session, run_task):
         """
         Test that the Task Instance upstream_map_indexes is correctly fetched when to running the Task Instances with xcom
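For reference, this is the slice of the `/run` response the new test asserts on, and how a consumer of `upstream_map_indexes` would use it. The payload shape is taken from the expected dict in the test; the loop is only an illustration, not the Task SDK's actual resolution code:

    # upstream_map_indexes for print_task with map_index=1, as asserted above
    # (other fields of the /run response omitted)
    ti_run_payload = {
        "upstream_map_indexes": {
            "expandable_task_group.inner_task_group.alter_input": 1,
        },
    }

    # A runtime consuming this payload narrows each upstream XCom pull to the
    # returned index (or list of indexes) instead of fetching every expansion.
    for upstream_task_id, map_index in ti_run_payload["upstream_map_indexes"].items():
        print(f"pull XCom from {upstream_task_id!r} at map_index={map_index}")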