@@ -27,7 +27,7 @@
ConnectionHookMetaData,
StandardHookFields,
)
from airflow.sdk import Param
from airflow.serialization.definitions.param import SerializedParam

if TYPE_CHECKING:
from airflow.providers_manager import ConnectionFormWidgetInfo, HookInfo
@@ -79,7 +79,7 @@ def __init__(
for v in validators:
if isinstance(v, HookMetaService.MockEnum):
enum = {"enum": v.allowed_values}
self.param = Param(
self.param = SerializedParam(
default=default,
title=label,
description=description or None,
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/cli/commands/dag_command.py
@@ -60,6 +60,7 @@
from sqlalchemy.orm import Session

from airflow import DAG
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.timetables.base import DataInterval

DAG_DETAIL_FIELDS = {*DAGResponse.model_fields, *DAGResponse.model_computed_fields}
@@ -656,7 +657,7 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
)
).all()

dot_graph = render_dag(dag, tis=list(tis))
dot_graph = render_dag(cast("SerializedDAG", dag), tis=list(tis))
print()
if filename:
_save_dot_to_file(dot_graph, filename)
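The functional change in this hunk is narrow: `dag` is wrapped in `typing.cast` so that `render_dag` type-checks against `SerializedDAG`. A minimal sketch of the pattern, using placeholder names rather than the real Airflow types:

```python
from typing import cast


class SerializedThing:
    """Placeholder for a serialized, scheduler-side object."""

    def edge_ids(self) -> list[str]:
        return ["a -> b"]


def render_graph(thing: SerializedThing) -> list[str]:
    # Annotated against the serialized type only, like render_dag().
    return thing.edge_ids()


def caller(maybe_thing: object) -> None:
    # cast() does nothing at runtime; it only tells the type checker to
    # treat `maybe_thing` as SerializedThing from this point onward.
    print(render_graph(cast("SerializedThing", maybe_thing)))


caller(SerializedThing())
```

Note that the cast performs no validation: if `dag` were ever not a serialized DAG, the checker would not catch it here.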
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/models/connection.py
@@ -36,7 +36,6 @@
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.sdk import SecretCache
from airflow.utils.helpers import prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
@@ -519,6 +518,8 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection:

# check cache first
# enabled only if SecretCache.init() has been called first
from airflow.sdk import SecretCache

try:
uri = SecretCache.get_connection_uri(conn_id)
return Connection(conn_id=conn_id, uri=uri)
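Moving `from airflow.sdk import SecretCache` from module scope into `get_connection_from_secrets` is the deferred-import pattern: the SDK module is loaded only when the method actually runs, keeping it out of the model module's import-time dependency graph (typically to break an import cycle). A self-contained sketch with stand-in names, not the real Airflow classes:

```python
class _FakeSecretCache:
    """Stand-in for airflow.sdk.SecretCache."""

    _store = {"my_conn": "mysql://cached"}

    class NotPresent(Exception):
        pass

    @classmethod
    def get_connection_uri(cls, conn_id: str) -> str:
        try:
            return cls._store[conn_id]
        except KeyError as exc:
            raise cls.NotPresent(conn_id) from exc


def get_connection_from_secrets(conn_id: str) -> str:
    # In the real change, `from airflow.sdk import SecretCache` sits here,
    # so the import is resolved on first call rather than when the
    # connection model module is imported.
    cache = _FakeSecretCache

    try:
        return cache.get_connection_uri(conn_id)  # cache hit
    except cache.NotPresent:
        return "backend://looked-up"  # fall through to the secrets backends


print(get_connection_from_secrets("my_conn"))
print(get_connection_from_secrets("other_conn"))
```

The same relocation is applied to `SecretCache` in variable.py and to `_airflow_parsing_context_manager` in utils/cli.py below.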
26 changes: 13 additions & 13 deletions airflow-core/src/airflow/models/dag.py
@@ -21,11 +21,10 @@
from collections import defaultdict
from collections.abc import Callable, Collection
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, cast

import pendulum
import sqlalchemy_jsonfield
from dateutil.relativedelta import relativedelta
from sqlalchemy import (
Boolean,
Float,
@@ -61,7 +60,6 @@
from airflow.timetables.base import DataInterval, Timetable
from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
from airflow.timetables.simple import AssetTriggeredTimetable, NullTimetable, OnceTimetable
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks
from airflow.utils.state import DagRunState
@@ -70,6 +68,9 @@
if TYPE_CHECKING:
from typing import TypeAlias

from dateutil.relativedelta import relativedelta

from airflow.sdk import Context
from airflow.serialization.definitions.assets import (
SerializedAsset,
SerializedAssetAlias,
@@ -78,21 +79,20 @@
from airflow.serialization.definitions.dag import SerializedDAG

UKey: TypeAlias = SerializedAssetUniqueKey
DagStateChangeCallback = Callable[[Context], None]
ScheduleInterval = None | str | timedelta | relativedelta

ScheduleArg = (
ScheduleInterval
| Timetable
| "SerializedAssetBase"
| Collection["SerializedAsset" | "SerializedAssetAlias"]
)

log = logging.getLogger(__name__)

TAG_MAX_LEN = 100

DagStateChangeCallback = Callable[[Context], None]
ScheduleInterval = None | str | timedelta | relativedelta

ScheduleArg = Union[
ScheduleInterval,
Timetable,
"SerializedAssetBase",
Collection[Union["SerializedAsset", "SerializedAssetAlias"]],
]


def infer_automated_data_interval(timetable: Timetable, logical_date: datetime) -> DataInterval:
"""
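The dag.py hunks move `relativedelta` and `Context` behind `if TYPE_CHECKING:` and re-declare `ScheduleInterval`/`ScheduleArg` there using PEP 604 `|` unions instead of `typing.Union`. A small illustrative sketch of the pattern (placeholder alias names, not the exact Airflow definitions):

```python
from __future__ import annotations

from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by the type checker, never at runtime, so dateutil
    # stays out of the module's import-time dependencies.
    from dateutil.relativedelta import relativedelta

    # Aliases declared here may freely use typing-only names and | unions.
    ScheduleInterval = None | str | timedelta | relativedelta
    StateChangeCallback = Callable[[dict], None]


def describe(schedule: ScheduleInterval) -> str:
    # With `from __future__ import annotations`, annotations are lazy
    # strings, so the typing-only alias is never looked up at runtime.
    return "unscheduled" if schedule is None else str(schedule)


print(describe(None))
print(describe(timedelta(days=1)))
```

The trade-off is that anything resolving these aliases at runtime (for example `typing.get_type_hints`) would fail; they exist purely for static analysis.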
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/models/taskinstance.py
@@ -114,10 +114,10 @@
from airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun
from airflow.sdk import Context
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.definitions.mappedoperator import Operator
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
from airflow.utils.context import Context


PAST_DEPENDS_MET = "past_depends_met"
7 changes: 6 additions & 1 deletion airflow-core/src/airflow/models/variable.py
@@ -32,7 +32,6 @@
from airflow.configuration import conf, ensure_secrets_loaded
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.sdk import SecretCache
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
@@ -238,6 +237,8 @@ def set(
)

# check if the secret exists in the custom secrets' backend.
from airflow.sdk import SecretCache

Variable.check_for_write_conflict(key=key)
if serialize_json:
stored_value = json.dumps(value, indent=2)
@@ -428,6 +429,8 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non
"Multi-team mode is not configured in the Airflow environment but the task trying to delete the variable belongs to a team"
)

from airflow.sdk import SecretCache

ctx: contextlib.AbstractContextManager
if session is not None:
ctx = contextlib.nullcontext(session)
@@ -494,6 +497,8 @@ def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | N
:param team_name: Team name associated to the task trying to access the variable (if any)
:return: Variable Value
"""
from airflow.sdk import SecretCache

# Disable cache if the variable belongs to a team. We might enable it later
if not team_name:
# check cache first
@@ -30,7 +30,6 @@

from airflow.exceptions import AirflowException, NotMapped
from airflow.sdk import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions._internal.abstractoperator import DEFAULT_RETRY_DELAY_MULTIPLIER
from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
from airflow.serialization.definitions.baseoperator import DEFAULT_OPERATOR_DEPS, SerializedBaseOperator
from airflow.serialization.definitions.node import DAGNode
@@ -288,10 +287,6 @@ def retry_exponential_backoff(self) -> float:
def max_retry_delay(self) -> datetime.timedelta | float | None:
return self._get_partial_kwargs_or_operator_default("max_retry_delay")

@property
def retry_delay_multiplier(self) -> float:
return float(self.partial_kwargs.get("retry_delay_multiplier", DEFAULT_RETRY_DELAY_MULTIPLIER))

@property
def weight_rule(self) -> PriorityWeightStrategy:
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/utils/cli.py
@@ -37,7 +37,6 @@
from airflow._shared.timezones import timezone
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.exceptions import AirflowException
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils import cli_action_loggers
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
from airflow.utils.platform import getuser, is_terminal_support_colors
@@ -274,6 +273,7 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | N
dags folder.
"""
from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager

manager = DagBundlesManager()
for bundle_name in bundle_names or ():
44 changes: 15 additions & 29 deletions airflow-core/src/airflow/utils/context.py
@@ -19,23 +19,22 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast
import warnings
from typing import Any

from sqlalchemy import select

from airflow.models.asset import AssetModel
from airflow.sdk import Asset, Context
from airflow.sdk import Asset
from airflow.sdk.execution_time.context import (
ConnectionAccessor as ConnectionAccessorSDK,
OutletEventAccessors as OutletEventAccessorsSDK,
VariableAccessor as VariableAccessorSDK,
)
from airflow.serialization.definitions.notset import NOTSET, is_arg_set
from airflow.utils.deprecation_tools import DeprecatedImportWarning
from airflow.utils.session import create_session

if TYPE_CHECKING:
from collections.abc import Container

# NOTE: Please keep this in sync with the following:
# * Context in task-sdk/src/airflow/sdk/definitions/context.py
# * Table in docs/apache-airflow/templates-ref.rst
@@ -141,30 +140,17 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
return Asset(name=asset.name, uri=asset.uri, group=asset.group, extra=asset.extra)


def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
"""
Merge parameters into an existing context.

Like ``dict.update()`` , this take the same parameters, and updates
``context`` in-place.

This is implemented as a free function because the ``Context`` type is
"faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
functions.
def __getattr__(name: str):
if name in ("Context", "context_copy_partial", "context_merge"):
warnings.warn(
"Importing Context from airflow.utils.context is deprecated and will "
"be removed in the future. Please import it from airflow.sdk instead.",
DeprecatedImportWarning,
stacklevel=2,
)

:meta private:
"""
if not context:
context = Context()
import airflow.sdk.definitions.context as sdk

context.update(*args, **kwargs)
return getattr(sdk, name)


def context_copy_partial(source: Context, keys: Container[str]) -> Context:
"""
Create a context by copying items under selected keys in ``source``.

:meta private:
"""
new = {k: v for k, v in source.items() if k in keys}
return cast("Context", new)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
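Replacing `context_merge`/`context_copy_partial` with a module-level `__getattr__` is the PEP 562 deprecation-shim pattern: the old import path keeps working, but emits a warning and forwards to the new module. A self-contained sketch with placeholder names (the real shim warns with Airflow's `DeprecatedImportWarning` and forwards to `airflow.sdk.definitions.context`):

```python
import warnings


class _new_home:
    """Stand-in for the module that now owns the relocated names."""

    class Context(dict):
        """Placeholder for the relocated Context class."""


def __getattr__(name: str):
    # Invoked only when normal attribute lookup on this module fails, so
    # names still defined in the module are served as usual.
    if name == "Context":
        warnings.warn(
            "Importing Context from this module is deprecated; import it "
            "from its new home instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return _new_home.Context
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

Importing `Context` from the shim module then succeeds with a warning, while any genuinely unknown attribute still raises `AttributeError` as before.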
28 changes: 15 additions & 13 deletions airflow-core/src/airflow/utils/dag_edges.py
@@ -16,20 +16,23 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any

from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.definitions.mappedoperator import SerializedMappedOperator
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup

# Also support SDK types if possible.
try:
from airflow.sdk import TaskGroup
except ImportError:
TaskGroup = SerializedTaskGroup # type: ignore[misc]

if TYPE_CHECKING:
from airflow.sdk import DAG
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.definitions.mappedoperator import Operator
from airflow.serialization.definitions.node import DAGNode


def dag_edges(dag: DAG | SerializedDAG):
def dag_edges(dag: SerializedDAG):
"""
Create the list of edges needed to construct the Graph view.

@@ -62,9 +65,10 @@ def dag_edges(dag: DAG | SerializedDAG):

task_group_map = dag.task_group.get_task_group_dict()

def collect_edges(task_group):
def collect_edges(task_group: DAGNode) -> None:
"""Update edges_to_add and edges_to_skip according to TaskGroups."""
if isinstance(task_group, (AbstractOperator, SerializedBaseOperator, SerializedMappedOperator)):
child: DAGNode
if not isinstance(task_group, (TaskGroup, SerializedTaskGroup)):
return

for target_id in task_group.downstream_group_ids:
@@ -111,9 +115,7 @@ def collect_edges(task_group):
edges = set()
setup_teardown_edges = set()

# TODO (GH-52141): 'roots' in scheduler needs to return scheduler types
# instead, but currently it inherits SDK's DAG.
tasks_to_trace = cast("list[Operator]", dag.roots)
tasks_to_trace = dag.roots
while tasks_to_trace:
tasks_to_trace_next: list[Operator] = []
for task in tasks_to_trace:
@@ -130,7 +132,7 @@ def collect_edges(task_group):
# Build result dicts with the two ends of the edge, plus any extra metadata
# if we have it.
for source_id, target_id in sorted(edges.union(edges_to_add) - edges_to_skip):
record = {"source_id": source_id, "target_id": target_id}
record: dict[str, Any] = {"source_id": source_id, "target_id": target_id}
label = dag.get_edge_info(source_id, target_id).get("label")
if (source_id, target_id) in setup_teardown_edges:
record["is_setup_teardown"] = True
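The new module-level `try/except ImportError` lets `dag_edges` accept the SDK `TaskGroup` when the task-sdk package is importable, and quietly collapses the isinstance check to the serialized type when it is not. A self-contained sketch of the fallback (the optional module name below is made up, so the except branch is what runs here):

```python
class SerializedGroup:
    """Always-available stand-in for SerializedTaskGroup."""


try:
    # Stand-in for `from airflow.sdk import TaskGroup`; this module does
    # not exist, so the fallback below executes.
    from some_optional_sdk import TaskGroup  # type: ignore[import-not-found]
except ImportError:
    TaskGroup = SerializedGroup  # type: ignore[misc, assignment]


def is_group(node: object) -> bool:
    # With the SDK installed both flavours are accepted; without it the
    # tuple degenerates to (SerializedGroup, SerializedGroup), which is fine.
    return isinstance(node, (TaskGroup, SerializedGroup))


print(is_group(SerializedGroup()))  # True
print(is_group(object()))           # False
```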