Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,6 @@ class MappedTaskGroup(TaskGroup):
def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._expand_input = expand_input
for op, _ in expand_input.iter_references():
self.set_upstream(op)

def iter_mapped_dependencies(self) -> Iterator[Operator]:
"""Upstream dependencies that provide XComs used by this mapped task group."""
Expand Down Expand Up @@ -620,6 +618,11 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
(g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
)

def __exit__(self, exc_type, exc_val, exc_tb):
for op, _ in self._expand_input.iter_references():
self.set_upstream(op)
super().__exit__(exc_type, exc_val, exc_tb)


class TaskGroupContext:
"""TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
Expand Down
46 changes: 46 additions & 0 deletions tests/decorators/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,52 @@ def tg(a, b):
assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, key="b")}


def test_task_group_expand_kwargs_with_upstream(dag_maker, session, caplog):
with dag_maker() as dag:

@dag.task
def t1():
return [{"a": 1}, {"a": 2}]

@task_group("tg1")
def tg1(a, b):
@dag.task()
def t2():
return [a, b]

t2()

tg1.expand_kwargs(t1())

dr = dag_maker.create_dagrun()
dr.task_instance_scheduling_decisions()
assert "Cannot expand" not in caplog.text
assert "missing upstream values: ['expand_kwargs() argument']" not in caplog.text


def test_task_group_expand_with_upstream(dag_maker, session, caplog):
with dag_maker() as dag:

@dag.task
def t1():
return [1, 2, 3]

@task_group("tg1")
def tg1(a, b):
@dag.task()
def t2():
return [a, b]

t2()

tg1.partial(a=1).expand(b=t1())

dr = dag_maker.create_dagrun()
dr.task_instance_scheduling_decisions()
assert "Cannot expand" not in caplog.text
assert "missing upstream values: ['b']" not in caplog.text


def test_override_dag_default_args():
@dag(
dag_id="test_dag",
Expand Down