Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent using trigger_rule=TriggerRule.ALWAYS in a task-generated mapping within bare tasks #44751

Merged
merged 1 commit into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,12 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]):
super()._validate_arg_names(func, kwargs)

def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
if self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS and any(
[isinstance(expanded, XComArg) for expanded in map_kwargs.values()]
):
raise ValueError(
"Task-generated mapping within a task using 'expand' is not allowed with trigger rule 'always'."
)
if not map_kwargs:
raise TypeError("no arguments to expand against")
self._validate_arg_names("expand", map_kwargs)
Expand All @@ -411,6 +417,21 @@ def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
if (
self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS
and not isinstance(kwargs, XComArg)
and any(
[
isinstance(v, XComArg)
for kwarg in kwargs
if not isinstance(kwarg, XComArg)
for v in kwarg.values()
]
)
):
raise ValueError(
"Task-generated mapping within a task using 'expand_kwargs' is not allowed with trigger rule 'always'."
)
if isinstance(kwargs, Sequence):
for item in kwargs:
if not isinstance(item, (XComArg, Mapping)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ The grid view also provides visibility into your mapped tasks in the details pan

Although we show a "reduce" task here (``sum_it``) you don't have to have one, the mapped tasks will still be executed even if they have no downstream tasks.

.. warning:: ``TriggerRule.ALWAYS`` cannot be utilized in expanded tasks

Assigning ``trigger_rule=TriggerRule.ALWAYS`` in expanded tasks is forbidden, as expanded parameters will be undefined with the task's immediate execution.
This is enforced at the time of the DAG parsing, and will raise an error if you try to use it.

Task-generated Mapping
----------------------
Expand All @@ -113,6 +109,12 @@ The above examples we've shown could all be achieved with a ``for`` loop in the

The ``make_list`` task runs as a normal task and must return a list or dict (see `What data types can be expanded?`_), and then the ``consumer`` task will be called four times, once with each value in the return of ``make_list``.

.. warning:: Task-generated mapping cannot be utilized with ``TriggerRule.ALWAYS``

Assigning ``trigger_rule=TriggerRule.ALWAYS`` in task-generated mapping is not allowed, as expanded parameters are undefined with the task's immediate execution.
This is enforced at the time of the DAG parsing, for both tasks and mapped tasks groups, and will raise an error if you try to use it.
In the recent example, setting ``trigger_rule=TriggerRule.ALWAYS`` in the ``consumer`` task will raise an error since ``make_list`` is a task-generated mapping.

Repeated mapping
----------------

Expand Down
1 change: 1 addition & 0 deletions newsfragments/44751.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``TriggerRule.ALWAYS`` cannot be utilized within a task-generated mapping, either in bare tasks (fixed in this PR) or mapped task groups (fixed in PR #44368). The issue with doing so, is that the task is immediately executed without waiting for the upstreams's mapping results, which certainly leads to failure of the task. This fix avoids it by raising an exception when it is detected during DAG parsing.
4 changes: 3 additions & 1 deletion task_sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,9 @@ def __iter__(self):

for child in self.children.values():
if isinstance(child, AbstractOperator) and child.trigger_rule == TriggerRule.ALWAYS:
raise ValueError("Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'")
raise ValueError(
"Task-generated mapping within a mapped task group is not allowed with trigger rule 'always'"
)
yield from self._iter_child(child)

def iter_mapped_dependencies(self) -> Iterator[DAGNode]:
Expand Down
38 changes: 38 additions & 0 deletions tests/decorators/test_mapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,41 @@ def f(x: int, y: int) -> int:
xcoms.add(ti.xcom_pull(session=session, task_ids=ti.task_id, map_indexes=ti.map_index))

assert xcoms == {11, 12, 13}


@pytest.mark.db_test
def test_fail_task_generated_mapping_with_trigger_rule_always__exapnd(dag_maker, session):
with DAG(dag_id="d", schedule=None, start_date=DEFAULT_DATE):

@task
def get_input():
return ["world", "moon"]

@task(trigger_rule="always")
def hello(input):
print(f"Hello, {input}")

with pytest.raises(
ValueError,
match="Task-generated mapping within a task using 'expand' is not allowed with trigger rule 'always'",
):
hello.expand(input=get_input())


@pytest.mark.db_test
def test_fail_task_generated_mapping_with_trigger_rule_always__exapnd_kwargs(dag_maker, session):
with DAG(dag_id="d", schedule=None, start_date=DEFAULT_DATE):

@task
def get_input():
return ["world", "moon"]

@task(trigger_rule="always")
def hello(input, input2):
print(f"Hello, {input}, {input2}")

with pytest.raises(
ValueError,
match="Task-generated mapping within a task using 'expand_kwargs' is not allowed with trigger rule 'always'",
):
hello.expand_kwargs([{"input": get_input(), "input2": get_input()}])
5 changes: 3 additions & 2 deletions tests/decorators/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def tg():


@pytest.mark.db_test
def test_expand_fail_trigger_rule_always(dag_maker, session):
def test_fail_task_generated_mapping_with_trigger_rule_always(dag_maker, session):
@dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1))
def pipeline():
@task
Expand All @@ -151,7 +151,8 @@ def tg(param):
t1(param)

with pytest.raises(
ValueError, match="Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'"
ValueError,
match="Task-generated mapping within a mapped task group is not allowed with trigger rule 'always'",
):
tg.expand(param=get_param())

Expand Down