
Commit b11fc14

uranusjr authored and ephraimbuddy committed
Count mapped upstreams only if all are finished (#30641)
* Fix Pydantic TI handling in XComArg.resolve()

* Count mapped upstreams only if all are finished

  An XComArg's get_task_map_length() should only return an integer once the *entire* upstream task has finished. Before this patch, however, it could count a mapped upstream even while some (or all!) of its expanded tis were still unfinished, causing its downstream to be expanded prematurely. This patch adds a check before counting upstream results to ensure all the upstreams are actually finished.

* Use SQL IN to find unfinished TIs instead

  This needs a special workaround for a NULL quirk in SQL.

(cherry picked from commit 5f2628d)
1 parent ed1af98 commit b11fc14
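
For context on the NULL quirk the message refers to: under SQL's three-valued logic, comparing NULL with anything yields NULL ("unknown"), so a plain `state IN (...)` or `state NOT IN (...)` predicate never matches rows whose state column is NULL. A minimal sketch (not part of this commit) demonstrating the behavior with Python's sqlite3 module:

```python
import sqlite3

# NULL never compares equal (or unequal) to anything, so both IN and NOT IN
# evaluate to NULL ("unknown") rather than TRUE/FALSE when the value is NULL.
conn = sqlite3.connect(":memory:")
row = conn.execute("SELECT NULL IN (1, 2), NULL NOT IN (1, 2)").fetchone()
print(row)  # (None, None) -- neither predicate matches a NULL value
```

This is why the patch pairs the IN filter with an explicit IS NULL branch instead of relying on IN alone.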

File tree

airflow/models/xcom_arg.py
tests/models/test_taskinstance.py

2 files changed: +62 / -1 lines changed

airflow/models/xcom_arg.py

Lines changed: 21 additions & 1 deletion
@@ -21,7 +21,7 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload
 
-from sqlalchemy import func
+from sqlalchemy import func, or_
 from sqlalchemy.orm import Session
 
 from airflow.exceptions import AirflowException, XComNotFound
@@ -33,6 +33,7 @@
 from airflow.utils.mixins import ResolveMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.setup_teardown import SetupTeardownContext
+from airflow.utils.state import State
 from airflow.utils.types import NOTSET, ArgNotSet
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
@@ -309,11 +310,26 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
         return super().zip(*others, fillvalue=fillvalue)
 
     def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
+        from airflow.models.taskinstance import TaskInstance
         from airflow.models.taskmap import TaskMap
         from airflow.models.xcom import XCom
 
         task = self.operator
         if isinstance(task, MappedOperator):
+            unfinished_ti_count_query = session.query(func.count(TaskInstance.map_index)).filter(
+                TaskInstance.dag_id == task.dag_id,
+                TaskInstance.run_id == run_id,
+                TaskInstance.task_id == task.task_id,
+                # Special NULL treatment is needed because 'state' can be NULL.
+                # The "IN" part would produce "NULL NOT IN ..." and eventually
+                # "NULL = NULL", which is a big no-no in SQL.
+                or_(
+                    TaskInstance.state.is_(None),
+                    TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
+                ),
+            )
+            if unfinished_ti_count_query.scalar():
+                return None  # Not all of the expanded tis are done yet.
             query = session.query(func.count(XCom.map_index)).filter(
                 XCom.dag_id == task.dag_id,
                 XCom.run_id == run_id,
@@ -332,7 +348,11 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
 
     @provide_session
     def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+        from airflow.models.taskinstance import TaskInstance
+
         ti = context["ti"]
+        assert isinstance(ti, TaskInstance), "Wait for AIP-44 implementation to complete"
+
         task_id = self.operator.task_id
         map_indexes = ti.get_relevant_upstream_map_indexes(
             self.operator,
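
To see the new predicate in isolation, here is a self-contained sketch, assuming a toy table and an in-memory SQLite engine rather than Airflow's models, that mirrors the `or_(state.is_(None), state.in_(...))` filter added above. A filter built from `in_()` alone silently skips rows whose state is still NULL:

```python
from sqlalchemy import Column, Integer, String, create_engine, func, or_
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Toy(Base):
    # Hypothetical stand-in for TaskInstance; only the 'state' column matters here.
    __tablename__ = "toy"
    id = Column(Integer, primary_key=True)
    state = Column(String)  # NULL means "not yet scheduled"


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Toy(state="success"), Toy(state="running"), Toy(state=None)])
    session.commit()

    unfinished = {"running", "queued"}

    # Without the IS NULL branch, the NULL-state row never matches IN().
    naive = session.query(func.count(Toy.id)).filter(Toy.state.in_(unfinished)).scalar()

    # NULL-safe predicate, mirroring the patched query.
    safe = (
        session.query(func.count(Toy.id))
        .filter(or_(Toy.state.is_(None), Toy.state.in_(unfinished)))
        .scalar()
    )
    print(naive, safe)  # 1 2 -- only the NULL-safe count treats the unscheduled row as unfinished
```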

tests/models/test_taskinstance.py

Lines changed: 41 additions & 0 deletions
@@ -3956,3 +3956,44 @@ def last_task():
         middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
         assert middle_ti.state == State.SCHEDULED
     assert "3 downstream tasks scheduled from follow-on schedule" in caplog.text
+
+
+def test_mini_scheduler_not_skip_mapped_downstream_until_all_upstreams_finish(dag_maker, session):
+    with dag_maker(session=session):
+
+        @task
+        def generate() -> list[list[int]]:
+            return []
+
+        @task
+        def a_sum(numbers: list[int]) -> int:
+            return sum(numbers)
+
+        @task
+        def b_double(summed: int) -> int:
+            return summed * 2
+
+        @task
+        def c_gather(result) -> None:
+            pass
+
+        static = EmptyOperator(task_id="static")
+
+        summed = a_sum.expand(numbers=generate())
+        doubled = b_double.expand(summed=summed)
+        static >> c_gather(doubled)
+
+    dr: DagRun = dag_maker.create_dagrun()
+    tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances}
+
+    static_ti = tis[("static", -1)]
+    static_ti.run(session=session)
+    static_ti.schedule_downstream_tasks(session=session)
+    # No tasks should be skipped yet!
+    assert not dr.get_task_instances([TaskInstanceState.SKIPPED], session=session)
+
+    generate_ti = tis[("generate", -1)]
+    generate_ti.run(session=session)
+    generate_ti.schedule_downstream_tasks(session=session)
+    # Now downstreams can be skipped.
+    assert dr.get_task_instances([TaskInstanceState.SKIPPED], session=session)
