Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2028,6 +2028,14 @@ def tg2(inp):
# and "ti_count == ancestor_ti_count" does not work, since the further
# expansion may be of length 1.
if not _is_further_mapped_inside(relative, common_ancestor):
placeholder_index = resolve_placeholder_map_index(
task=task, relative=relative, map_index=ancestor_map_index, run_id=run_id, session=session
)
# Handle cases where an upstream mapped placeholder (map_index = -1) has already
# been expanded and replaced by its successor (map_index = 0) at evaluation time.
if placeholder_index is not None:
return placeholder_index

return ancestor_map_index

# Otherwise we need a partial aggregation for values from selected task
Expand Down Expand Up @@ -2102,6 +2110,54 @@ def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]
return visited


def resolve_placeholder_map_index(
*,
task: Operator,
relative: Operator,
map_index: int,
run_id: str,
session: Session,
) -> int | None:
"""
Resolve the correct map_index for upstream dependency evaluation.

This handles the transition from map_index = -1 (pre-expansion placeholder)
to map_index = 0 (post-expansion placeholder successor).

Returns:
- 0 if the placeholder has transitioned from -1 to 0
- None if no override should be applied
"""
if map_index != -1:
return None

rows = session.execute(
select(TaskInstance.task_id, TaskInstance.map_index).where(
TaskInstance.dag_id == relative.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id.in_([task.task_id, relative.task_id]),
TaskInstance.map_index.in_([-1, 0]),
)
).all()

task_to_map_indexes: dict[str, list[int]] = defaultdict(list)
for task_id, mi in rows:
task_to_map_indexes[task_id].append(mi)

# We only rewrite when:
# 1) the current task is still using the placeholder (-1)
# 2) the upstream placeholder (-1) no longer exists
# 3) the post-expansion placeholder (0) does exist
if (
-1 in task_to_map_indexes.get(task.task_id, [])
and -1 not in task_to_map_indexes.get(relative.task_id, [])
and 0 in task_to_map_indexes.get(relative.task_id, [])
):
return 0

return None


class TaskInstanceNote(Base):
"""For storage of arbitrary notes concerning the task instance."""

Expand Down
89 changes: 89 additions & 0 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2981,6 +2981,95 @@ def g(v):
assert result == expected


def test_downstream_placeholder_handles_upstream_post_expansion(dag_maker, session):
"""
Test dynamic task mapping behavior when an upstream placeholder task
(map_index = -1) has been replaced by the first expanded task
(map_index = 0).

This verifies that trigger rule evaluation correctly resolves relevant
upstream map indexes both when referencing the original placeholder
and when referencing the first expanded task instance.
"""

with dag_maker(session=session) as dag:

@task
def get_mapping_source():
return ["one", "two", "three"]

@task
def mapped_task(x):
output = f"{x}"
return output

@task_group(prefix_group_id=False)
def the_task_group(x):
start = MockOperator(task_id="start")
upstream = mapped_task(x)

# Plain downstream inside task group (no mapping source).
downstream = MockOperator(task_id="downstream")

start >> upstream >> downstream

mapping_source = get_mapping_source()
mapped_tg = the_task_group.expand(x=mapping_source)

mapping_source >> mapped_tg

# Create DAG run and execute prerequisites.
dr = dag_maker.create_dagrun()

dag_maker.run_ti("get_mapping_source", map_index=-1, dag_run=dr, session=session)

# Force expansion of the upstream mapped task.
upstream_task = dag.get_task("mapped_task")
_, max_index = TaskMap.expand_mapped_task(
upstream_task,
dr.run_id,
session=session,
)
expanded_ti_count = max_index + 1

downstream_task = dag.get_task("downstream")

# Grab the downstream placeholder TI.
downstream_ti = dr.get_task_instance(task_id="downstream", map_index=-1, session=session)
downstream_ti.refresh_from_task(downstream_task)

result = downstream_ti.get_relevant_upstream_map_indexes(
upstream=upstream_task,
ti_count=expanded_ti_count,
session=session,
)

assert result == 0

# Now do the same for downstream expanded (map_index = 0) to ensure existing behavior is not broken.
# Force expansion of the downstream mapped task.
_, max_index = TaskMap.expand_mapped_task(
downstream_task,
dr.run_id,
session=session,
)
expanded_ti_count = max_index + 1

# Grab the first expanded downstream task. Behavior is the same for all cases where map_index >= 0.
downstream_ti = dr.get_task_instance(task_id="downstream", map_index=0, session=session)
downstream_ti.refresh_from_task(downstream_task)

result = downstream_ti.get_relevant_upstream_map_indexes(
upstream=upstream_task,
ti_count=expanded_ti_count,
session=session,
)

# Verify behavior remains unchanged once the downstream task itself
# has expanded (map_index >= 0).
assert result == 0


def test_find_relevant_relatives_with_non_mapped_task_as_tuple(dag_maker, session):
"""Test that specifying a non-mapped task as a tuple doesn't raise NotMapped exception."""
# t1 -> t2 (non-mapped) -> t3
Expand Down