Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions airflow/example_dags/example_dynamic_task_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example DAG demonstrating the usage of dynamic task mapping."""
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.decorators import task

with DAG(dag_id="example_dynamic_task_mapping", start_date=datetime(2022, 3, 4)) as dag:

    @task
    def add_one(x: int):
        """Return the mapped input incremented by one (runs once per mapped value)."""
        return x + 1

    @task
    def sum_it(values):
        """Aggregate all mapped results and log their sum.

        ``values`` is the lazy collection of XCom results from every mapped
        ``add_one`` task instance.
        """
        total = sum(values)
        print(f"Total was {total}")

    # expand() creates one add_one task instance per element of the list;
    # passing the mapped result to sum_it makes it a single downstream
    # task that aggregates all three outputs.
    added_values = add_one.expand(x=[1, 2, 3])
    sum_it(added_values)
1 change: 0 additions & 1 deletion docker_tests/test_docker_compose_quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from unittest import mock

import requests

from docker_tests.command_utils import run_command
from docker_tests.constants import SOURCE_ROOT
from docker_tests.docker_tests_utils import docker_image
Expand Down
23 changes: 3 additions & 20 deletions docs/apache-airflow/concepts/dynamic-task-mapping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,10 @@ Simple mapping

In its simplest form you can map over a list defined directly in your DAG file using the ``expand()`` function instead of calling your task directly.

.. code-block:: python

from datetime import datetime

from airflow import DAG
from airflow.decorators import task


with DAG(dag_id="simple_mapping", start_date=datetime(2022, 3, 4)) as dag:

@task
def add_one(x: int):
return x + 1

@task
def sum_it(values):
total = sum(values)
print(f"Total was {total}")
The following example shows a simple use of dynamic task mapping:

added_values = add_one.expand(x=[1, 2, 3])
sum_it(added_values)
.. exampleinclude:: /../../airflow/example_dags/example_dynamic_task_mapping.py
:language: python

This will show ``Total was 9`` in the task logs when executed.

Expand Down
5 changes: 2 additions & 3 deletions docs/build_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
from itertools import filterfalse, tee
from typing import Callable, Iterable, NamedTuple, TypeVar

from rich.console import Console
from tabulate import tabulate

from docs.exts.docs_build import dev_index_generator, lint_checks
from docs.exts.docs_build.code_utils import CONSOLE_WIDTH, PROVIDER_INIT_FILE
from docs.exts.docs_build.docs_builder import DOCS_DIR, AirflowDocsBuilder, get_available_packages
Expand All @@ -36,6 +33,8 @@
from docs.exts.docs_build.github_action_utils import with_group
from docs.exts.docs_build.package_filter import process_package_filters
from docs.exts.docs_build.spelling_checks import SpellingError, display_spelling_error_summary
from rich.console import Console
from tabulate import tabulate

TEXT_RED = "\033[31m"
TEXT_RESET = "\033[0m"
Expand Down
2 changes: 1 addition & 1 deletion docs/exts/docs_build/spelling_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from functools import total_ordering
from typing import NamedTuple

from docs.exts.docs_build.code_utils import CONSOLE_WIDTH
from rich.console import Console

from airflow.utils.code_utils import prepare_code_snippet
from docs.exts.docs_build.code_utils import CONSOLE_WIDTH

CURRENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
DOCS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir, os.pardir))
Expand Down
77 changes: 55 additions & 22 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pathlib import Path
from unittest import mock

import attr
import pendulum
import pytest
from dateutil.relativedelta import FR, relativedelta
Expand All @@ -42,6 +43,7 @@
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.models import DAG, Connection, DagBag, Operator
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.expandinput import EXPAND_INPUT_EMPTY
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
from airflow.models.xcom import XCOM_RETURN_KEY, XCom
Expand Down Expand Up @@ -534,32 +536,47 @@ def validate_deserialized_task(
serialized_task,
task,
):
"""Verify non-airflow operators are casted to BaseOperator."""
assert isinstance(serialized_task, SerializedBaseOperator)
"""Verify non-Airflow operators are casted to BaseOperator or MappedOperator."""
assert not isinstance(task, SerializedBaseOperator)
assert isinstance(task, BaseOperator)
assert isinstance(task, (BaseOperator, MappedOperator))

# Every task should have a task_group property -- even if it's the DAG's root task group
assert serialized_task.task_group

fields_to_check = task.get_serialized_fields() - {
# Checked separately
"_task_type",
"_operator_name",
"subdag",
# Type is excluded, so don't check it
"_log",
# List vs tuple. Check separately
"template_ext",
"template_fields",
# We store the string, real dag has the actual code
"on_failure_callback",
"on_success_callback",
"on_retry_callback",
# Checked separately
"resources",
"params",
}
if isinstance(task, BaseOperator):
assert isinstance(serialized_task, SerializedBaseOperator)
fields_to_check = task.get_serialized_fields() - {
# Checked separately
"_task_type",
"_operator_name",
"subdag",
# Type is excluded, so don't check it
"_log",
# List vs tuple. Check separately
"template_ext",
"template_fields",
# We store the string, real dag has the actual code
"on_failure_callback",
"on_success_callback",
"on_retry_callback",
# Checked separately
"resources",
}
else: # Promised to be mapped by the assert above.
assert isinstance(serialized_task, MappedOperator)
fields_to_check = {f.name for f in attr.fields(MappedOperator)}
fields_to_check -= {
# Matching logic in BaseOperator.get_serialized_fields().
"dag",
"task_group",
# List vs tuple. Check separately.
"operator_extra_links",
"template_ext",
"template_fields",
# Checked separately.
"operator_class",
"partial_kwargs",
}

assert serialized_task.task_type == task.task_type

Expand All @@ -580,9 +597,25 @@ def validate_deserialized_task(
assert serialized_task.resources == task.resources

# Ugly hack as some operators override params var in their init
if isinstance(task.params, ParamsDict):
if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict):
assert serialized_task.params.dump() == task.params.dump()

if isinstance(task, MappedOperator):
# MappedOperator.operator_class holds a backup of the serialized
# data; checking its entirety basically duplicates this validation
# function, so we just do some satiny checks.
serialized_task.operator_class["_task_type"] == type(task).__name__
serialized_task.operator_class["_operator_name"] == task._operator_name

# Serialization cleans up default values in partial_kwargs, this
# adds them back to both sides.
default_partial_kwargs = (
BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
)
serialized_partial_kwargs = {**default_partial_kwargs, **serialized_task.partial_kwargs}
original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs}
assert serialized_partial_kwargs == original_partial_kwargs

# Check that for Deserialized task, task.subdag is None for all other Operators
# except for the SubDagOperator where task.subdag is an instance of DAG object
if task.task_type == "SubDagOperator":
Expand Down