7 changes: 3 additions & 4 deletions task-sdk/pyproject.toml
@@ -21,7 +21,7 @@ dynamic = ["version"]
description = "Python Task SDK for Apache Airflow DAG Authors"
readme = { file = "README.md", content-type = "text/markdown" }
license-files.globs = ["LICENSE"]
-requires-python = ">=3.9, <3.13"
+requires-python = ">=3.10, <3.13"

authors = [
{name="Apache Software Foundation", email="dev@airflow.apache.org"},
@@ -38,7 +38,6 @@ classifiers = [
"Intended Audience :: System Administrators",
"Framework :: Apache Airflow",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
@@ -169,14 +168,14 @@ enum-field-as-literal='one' # When a single enum member, make it output a `Liter
input-file-type='openapi'
output-model-type='pydantic_v2.BaseModel'
output-datetime-class='AwareDatetime'
-target-python-version='3.9'
+target-python-version='3.10'
use-annotated=true
use-default=true
use-double-quotes=true
use-schema-description=true # Desc becomes class doc comment
use-standard-collections=true # list[] not List[]
use-subclass-enum=true # enum, not union of Literals
-use-union-operator=true # 3.9+annotations, not `Union[]`
+use-union-operator=true # annotations, not `Union[]`
custom-formatters = ['datamodel_code_formatter']

url = 'http://0.0.0.0:8080/execution/openapi.json'
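A side note on the regenerated client models (illustration only, not part of this diff): with target-python-version='3.10', use-union-operator=true and use-standard-collections=true, datamodel-code-generator emits PEP 604 unions and built-in generics instead of Optional/Union/List, roughly like:

    # Hypothetical generated model; class and field names are made up for illustration
    from pydantic import AwareDatetime, BaseModel

    class ExampleTIRun(BaseModel):
        start_date: AwareDatetime
        end_date: AwareDatetime | None = None  # instead of Optional[AwareDatetime]
        hostname: str | None = None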
6 changes: 3 additions & 3 deletions task-sdk/src/airflow/sdk/bases/decorator.py
@@ -21,9 +21,9 @@
import re
import textwrap
import warnings
-from collections.abc import Collection, Iterator, Mapping, Sequence
+from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
from functools import cached_property, update_wrapper
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Protocol, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, cast, overload

import attr
import typing_extensions
@@ -424,7 +424,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool =
)
if isinstance(kwargs, Sequence):
for item in kwargs:
-if not isinstance(item, (XComArg, Mapping)):
+if not isinstance(item, XComArg | Mapping):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
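A quick aside on the isinstance() changes that recur throughout this PR (sketch, not from the diff): since Python 3.10, isinstance() and issubclass() accept PEP 604 union types directly, so the tuple form and the union form are interchangeable once 3.9 support is gone:

    # Equivalent on Python 3.10+; the `|` form raises TypeError on 3.9
    print(isinstance(3, (int, float)))   # True
    print(isinstance(3, int | float))    # True
    print(isinstance("x", int | float))  # False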
14 changes: 7 additions & 7 deletions task-sdk/src/airflow/sdk/bases/sensor.py
@@ -21,9 +21,9 @@
import hashlib
import time
import traceback
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.exceptions import (
@@ -143,7 +143,7 @@ def __init__(
def _coerce_poke_interval(poke_interval: float | timedelta) -> timedelta:
if isinstance(poke_interval, timedelta):
return poke_interval
-if isinstance(poke_interval, (int, float)) and poke_interval >= 0:
+if isinstance(poke_interval, int | float) and poke_interval >= 0:
return timedelta(seconds=poke_interval)
raise AirflowException(
"Operator arg `poke_interval` must be timedelta object or a non-negative number"
@@ -153,22 +153,22 @@ def _coerce_poke_interval(poke_interval: float | timedelta) -> timedelta:
def _coerce_timeout(timeout: float | timedelta) -> timedelta:
if isinstance(timeout, timedelta):
return timeout
-if isinstance(timeout, (int, float)) and timeout >= 0:
+if isinstance(timeout, int | float) and timeout >= 0:
return timedelta(seconds=timeout)
raise AirflowException("Operator arg `timeout` must be timedelta object or a non-negative number")

@staticmethod
def _coerce_max_wait(max_wait: float | timedelta | None) -> timedelta | None:
if max_wait is None or isinstance(max_wait, timedelta):
return max_wait
-if isinstance(max_wait, (int, float)) and max_wait >= 0:
+if isinstance(max_wait, int | float) and max_wait >= 0:
return timedelta(seconds=max_wait)
raise AirflowException("Operator arg `max_wait` must be timedelta object or a non-negative number")

def _validate_input_values(self) -> None:
-if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0:
+if not isinstance(self.poke_interval, int | float) or self.poke_interval < 0:
raise AirflowException("The poke_interval must be a non-negative number")
-if not isinstance(self.timeout, (int, float)) or self.timeout < 0:
+if not isinstance(self.timeout, int | float) or self.timeout < 0:
raise AirflowException("The timeout must be a non-negative number")
if self.mode not in self.valid_modes:
raise AirflowException(
@@ -426,7 +426,7 @@ def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
for key, child in _walk_group(dag.task_group):
if key == self.node_id:
continue
-if not isinstance(child, (MappedOperator, MappedTaskGroup)):
+if not isinstance(child, MappedOperator | MappedTaskGroup):
continue
if self.node_id in child.upstream_task_ids:
yield child
@@ -62,21 +62,21 @@ def __str__(self) -> str:
def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
from airflow.sdk.definitions.xcom_arg import XComArg

-return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str)
+return isinstance(v, MappedArgument | XComArg | Mapping | Sequence) and not isinstance(v, str)


# To replace tedious isinstance() checks.
def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
from airflow.sdk.definitions.xcom_arg import XComArg

-return not isinstance(v, (MappedArgument, XComArg))
+return not isinstance(v, MappedArgument | XComArg)


# To replace tedious isinstance() checks.
def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
from airflow.sdk.definitions.xcom_arg import XComArg

-return isinstance(v, (MappedArgument, XComArg))
+return isinstance(v, MappedArgument | XComArg)


@attrs.define(kw_only=True)
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/definitions/_internal/node.py
@@ -168,7 +168,7 @@ def _set_relatives(
task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
relatives = task_object.leaves if upstream else task_object.roots
for task in relatives:
-if not isinstance(task, (BaseOperator, MappedOperator)):
+if not isinstance(task, BaseOperator | MappedOperator):
raise TypeError(
f"Relationships can only be set between Operators; received {task.__class__.__name__}"
)
5 changes: 3 additions & 2 deletions task-sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -23,7 +23,8 @@
import os
import urllib.parse
import warnings
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, Union, overload
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload

import attrs

@@ -117,7 +118,7 @@ def to_asset_alias(self) -> AssetAlias:
return AssetAlias(name=self.name)


-BaseAssetUniqueKey = Union[AssetUniqueKey, AssetAliasUniqueKey]
+BaseAssetUniqueKey = AssetUniqueKey | AssetAliasUniqueKey


def normalize_noop(parts: SplitResult) -> SplitResult:
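Worth noting for the BaseAssetUniqueKey line (illustration with stand-in classes, not Airflow code): a module-level alias assignment is evaluated at import time, so `AssetUniqueKey | AssetAliasUniqueKey` relies on the runtime `types.UnionType` support added in Python 3.10 and would fail on 3.9 even with `from __future__ import annotations`:

    import types

    class A: ...
    class B: ...

    AOrB = A | B  # evaluated eagerly; needs Python 3.10+
    assert isinstance(AOrB, types.UnionType)
    assert isinstance(A(), AOrB)  # the union also works directly in isinstance()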
17 changes: 5 additions & 12 deletions task-sdk/src/airflow/sdk/definitions/dag.py
@@ -25,15 +25,13 @@
import sys
import weakref
from collections import abc
-from collections.abc import Collection, Iterable, MutableSet
+from collections.abc import Callable, Collection, Iterable, MutableSet
from datetime import datetime, timedelta
from inspect import signature
from typing import (
TYPE_CHECKING,
Any,
-Callable,
ClassVar,
-Union,
cast,
overload,
)
@@ -93,7 +91,7 @@
DagStateChangeCallback = Callable[[Context], None]
ScheduleInterval = None | str | timedelta | relativedelta

-ScheduleArg = Union[ScheduleInterval, Timetable, BaseAsset, Collection[BaseAsset]]
+ScheduleArg = ScheduleInterval | Timetable | BaseAsset | Collection[BaseAsset]


_DAG_HASH_ATTRS = frozenset(
@@ -124,7 +122,7 @@ def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTime
return OnceTimetable()
if interval == "@continuous":
return ContinuousTimetable()
-if isinstance(interval, (timedelta, relativedelta)):
+if isinstance(interval, timedelta | relativedelta):
if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"):
return DeltaDataIntervalTimetable(interval)
return DeltaTriggerTimetable(interval)
@@ -809,7 +807,7 @@ def partial_subset(
direct_upstreams: list[Operator] = []
if include_direct_upstream:
for t in itertools.chain(matched_tasks, also_include):
-upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
+upstream = (u for u in t.upstream_list if isinstance(u, BaseOperator | MappedOperator))
direct_upstreams.extend(upstream)

# Make sure to not recursively deepcopy the dag or task_group while copying the task.
@@ -1284,12 +1282,7 @@ def _run_inline_trigger(trigger):
import asyncio

async def _run_inline_trigger_main():
-# We can replace it with `return await anext(trigger.run(), default=None)`
-# when we drop support for Python 3.9
-try:
-return await trigger.run().__anext__()
-except StopAsyncIteration:
-return None
+return await anext(trigger.run(), None)

return asyncio.run(_run_inline_trigger_main())
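The simplified trigger helper leans on the anext() builtin, which Python 3.10 added with an optional default that replaces the manual StopAsyncIteration handling; a minimal standalone sketch (not Airflow code):

    import asyncio

    async def events():
        yield "first event"

    async def main():
        agen = events()
        print(await anext(agen, None))  # "first event"
        print(await anext(agen, None))  # None -- generator exhausted, default returned

    asyncio.run(main())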

3 changes: 2 additions & 1 deletion task-sdk/src/airflow/sdk/definitions/deadline.py
@@ -17,8 +17,9 @@
from __future__ import annotations

import logging
+from collections.abc import Callable
from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING

from airflow.models.deadline import ReferenceModels
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

-from typing import Callable
+from collections.abc import Callable

from airflow.providers_manager import ProvidersManager
from airflow.sdk.bases.decorator import TaskDecorator
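These Callable import moves follow the usual guidance: the typing aliases for collection ABCs (typing.Callable included) have been deprecated since Python 3.9 in favor of their collections.abc counterparts, which work fine in annotations, especially under `from __future__ import annotations`; a minimal sketch with made-up names:

    from __future__ import annotations

    from collections.abc import Callable

    def apply_twice(fn: Callable[[int], int], value: int) -> int:
        """Apply fn to value two times."""
        return fn(fn(value))

    print(apply_twice(lambda x: x + 1, 0))  # 2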
@@ -20,9 +20,9 @@
# documentation for more details.
from __future__ import annotations

-from collections.abc import Collection, Container, Iterable, Mapping
+from collections.abc import Callable, Collection, Container, Iterable, Mapping
from datetime import timedelta
-from typing import Any, Callable, TypeVar, overload
+from typing import Any, TypeVar, overload

from docker.types import Mount
from kubernetes.client import models as k8s
@@ -16,14 +16,15 @@
# under the License.
from __future__ import annotations

+from collections.abc import Callable
from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar

from airflow.exceptions import AirflowSkipException
from airflow.sdk.bases.decorator import Task, _TaskDecorator

if TYPE_CHECKING:
-from typing_extensions import TypeAlias
+from typing import TypeAlias

from airflow.sdk.bases.operator import TaskPreExecuteHook
from airflow.sdk.definitions.context import Context
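On the TYPE_CHECKING swap above: typing.TypeAlias has been in the stdlib since Python 3.10, so the typing_extensions fallback is only needed for 3.9 and older; roughly (names here are invented for illustration):

    from collections.abc import Callable
    from typing import TypeAlias  # stdlib from 3.10; previously from typing_extensions

    UserId: TypeAlias = int
    OnUserEvent: TypeAlias = Callable[[UserId], None]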
@@ -17,7 +17,8 @@
from __future__ import annotations

import types
-from typing import TYPE_CHECKING, Callable
+from collections.abc import Callable
+from typing import TYPE_CHECKING

from airflow.exceptions import AirflowException
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
@@ -28,8 +28,8 @@
import functools
import inspect
import warnings
-from collections.abc import Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar, overload
+from collections.abc import Callable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, overload

import attr

@@ -144,7 +144,7 @@ def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode:
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode:
if isinstance(kwargs, Sequence):
for item in kwargs:
-if not isinstance(item, (XComArg, Mapping)):
+if not isinstance(item, XComArg | Mapping):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/definitions/edges.py
@@ -75,7 +75,7 @@ def _save_nodes(
from airflow.sdk.definitions.xcom_arg import XComArg

for node in self._make_list(nodes):
-if isinstance(node, (TaskGroup, XComArg, DAGNode)):
+if isinstance(node, TaskGroup | XComArg | DAGNode):
stream.append(node)
else:
raise TypeError(
20 changes: 8 additions & 12 deletions task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -21,11 +21,12 @@
import copy
import warnings
from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, ClassVar, Union
+from typing import TYPE_CHECKING, Any, ClassVar

import attrs
import methodtools

+from airflow.models.abstractoperator import TaskStateChangeCallback
from airflow.sdk.definitions._internal.abstractoperator import (
DEFAULT_EXECUTOR,
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
@@ -60,9 +61,6 @@
import jinja2 # Slow import.
import pendulum

-from airflow.models.abstractoperator import (
-TaskStateChangeCallback,
-)
from airflow.models.expandinput import (
OperatorExpandArgument,
OperatorExpandKwargsArgument,
@@ -73,7 +71,6 @@
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.definitions.xcom_arg import XComArg
-from airflow.sdk.types import Operator
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.triggers.base import StartTriggerArgs
from airflow.typing_compat import TypeGuard
@@ -82,9 +79,8 @@
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule

-TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]]
-
-ValidationSource = Union[Literal["expand"], Literal["partial"]]
+TaskStateChangeCallbackAttrType = TaskStateChangeCallback | list[TaskStateChangeCallback] | None
+ValidationSource = Literal["expand"] | Literal["partial"]


def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
@@ -144,9 +140,9 @@ def is_mappable_value(value: Any) -> TypeGuard[Collection]:

:meta private:
"""
-if not isinstance(value, (Sequence, dict)):
+if not isinstance(value, Sequence | dict):
return False
-if isinstance(value, (bytearray, bytes, str)):
+if isinstance(value, bytearray | bytes | str):
return False
return True

@@ -196,7 +192,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool =

if isinstance(kwargs, Sequence):
for item in kwargs:
-if not isinstance(item, (XComArg, Mapping)):
+if not isinstance(item, XComArg | Mapping):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
@@ -786,7 +782,7 @@ def prepare_for_execution(self) -> MappedOperator:
# we don't need to create a copy of the MappedOperator here.
return self

-def iter_mapped_dependencies(self) -> Iterator[Operator]:
+def iter_mapped_dependencies(self) -> Iterator[AbstractOperator]:
"""Upstream dependencies that provide XComs used by this task for task mapping."""
from airflow.sdk.definitions.xcom_arg import XComArg
