Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/decorators/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
if TYPE_CHECKING:
from typing_extensions import TypeAlias

from airflow.models.baseoperator import TaskPreExecuteHook
from airflow.sdk.definitions.baseoperator import TaskPreExecuteHook
from airflow.sdk.definitions.context import Context

BoolConditionFunc: TypeAlias = Callable[[Context], bool]
Expand Down
38 changes: 11 additions & 27 deletions airflow-core/src/airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,7 @@
from collections.abc import Collection, Iterable, Iterator
from datetime import datetime, timedelta
from functools import singledispatchmethod
from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
Callable,
TypeVar,
)
from typing import TYPE_CHECKING, Any

import methodtools
import pendulum
Expand All @@ -62,7 +56,6 @@
cross_downstream as cross_downstream,
get_merged_defaults as get_merged_defaults,
)
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
Expand All @@ -73,8 +66,6 @@
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.context import context_get_outlet_events
from airflow.utils.operator_helpers import ExecutionCallableRunner
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
Expand All @@ -86,16 +77,11 @@

from airflow.models.dag import DAG as SchedulerDAG
from airflow.models.operator import Operator
from airflow.sdk import BaseOperatorLink
from airflow.sdk import BaseOperatorLink, Context
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.triggers.base import StartTriggerArgs

TaskPreExecuteHook = Callable[[Context], None]
TaskPostExecuteHook = Callable[[Context, Any], None]

T = TypeVar("T", bound=FunctionType)

logger = logging.getLogger("airflow.models.baseoperator.BaseOperator")


Expand Down Expand Up @@ -338,20 +324,12 @@ def say_hello_world(**context):
start_trigger_args: StartTriggerArgs | None = None
start_from_trigger: bool = False

def __init__(
self,
pre_execute=None,
post_execute=None,
**kwargs,
):
def __init__(self, **kwargs):
if start_date := kwargs.get("start_date", None):
kwargs["start_date"] = timezone.convert_to_utc(start_date)

if end_date := kwargs.get("end_date", None):
kwargs["end_date"] = timezone.convert_to_utc(end_date)
super().__init__(**kwargs)
self._pre_execute_hook = pre_execute
self._post_execute_hook = post_execute

# Defines the operator level extra links
operator_extra_links: Collection[BaseOperatorLink] = ()
Expand Down Expand Up @@ -411,7 +389,10 @@ def pre_execute(self, context: Any):
"""Execute right before self.execute() is called."""
if self._pre_execute_hook is None:
return
ExecutionCallableRunner(
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.context import context_get_outlet_events

create_executable_runner(
self._pre_execute_hook,
context_get_outlet_events(context),
logger=self.log,
Expand All @@ -436,7 +417,10 @@ def post_execute(self, context: Any, result: Any = None):
"""
if self._post_execute_hook is None:
return
ExecutionCallableRunner(
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.context import context_get_outlet_events

create_executable_runner(
self._post_execute_hook,
context_get_outlet_events(context),
logger=self.log,
Expand Down
6 changes: 3 additions & 3 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@
from airflow.utils.helpers import prune_dict, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import ExecutionCallableRunner
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
Expand Down Expand Up @@ -640,13 +639,14 @@ def _execute_task(task_instance: TaskInstance, context: Context, task_orig: Oper
)

def _execute_callable(context: Context, **execute_callable_kwargs):
from airflow.utils.context import context_get_outlet_events
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.context import context_get_outlet_events

try:
# Print a marker for log grouping of details before task execution
log.info("::endgroup::")

return ExecutionCallableRunner(
return create_executable_runner(
execute_callable,
context_get_outlet_events(context),
logger=log,
Expand Down
8 changes: 0 additions & 8 deletions airflow-core/src/airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@

if TYPE_CHECKING:
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.types import OutletEventAccessorsProtocol

# NOTE: Please keep this in sync with the following:
# * Context in task-sdk/src/airflow/sdk/definitions/context.py
Expand Down Expand Up @@ -176,10 +175,3 @@ def context_copy_partial(source: Context, keys: Container[str]) -> Context:
"""
new = {k: v for k, v in source.items() if k in keys}
return cast(Context, new)


def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol:
try:
return context["outlet_events"]
except KeyError:
return OutletEventAccessors()
73 changes: 1 addition & 72 deletions airflow-core/src/airflow/utils/operator_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,9 @@
from __future__ import annotations

import inspect
import logging
from collections.abc import Collection, Mapping
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar
from typing import Any, Callable, TypeVar

from airflow.typing_compat import ParamSpec
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from airflow.sdk.types import OutletEventAccessorsProtocol

P = ParamSpec("P")
R = TypeVar("R")


Expand Down Expand Up @@ -58,7 +50,6 @@ def determine(
args: Collection[Any],
kwargs: Mapping[str, Any],
) -> KeywordParameters:
import inspect
import itertools

signature = inspect.signature(func)
Expand Down Expand Up @@ -119,65 +110,3 @@ def kwargs_func(*args, **kwargs):
return func(*args, **kwargs)

return kwargs_func


class _ExecutionCallableRunner(Protocol):
@staticmethod
def run(*args, **kwargs): ...


def ExecutionCallableRunner(
func: Callable[P, R],
outlet_events: OutletEventAccessorsProtocol,
*,
logger: logging.Logger,
) -> _ExecutionCallableRunner:
"""
Run an execution callable against a task context and given arguments.

If the callable is a simple function, this simply calls it with the supplied
arguments (including the context). If the callable is a generator function,
the generator is exhausted here, with the yielded values getting fed back
into the task context automatically for execution.

This convoluted implementation of inner class with closure is so *all*
arguments passed to ``run()`` can be forwarded to the wrapped function. This
is particularly important for the argument "self", which some use cases
need to receive. This is not possible if this is implemented as a normal
class, where "self" needs to point to the ExecutionCallableRunner object.

The function name violates PEP 8 due to backward compatibility. This was
implemented as a class previously.

:meta private:
"""

class _ExecutionCallableRunnerImpl:
@staticmethod
def run(*args: P.args, **kwargs: P.kwargs) -> R:
from airflow.sdk.definitions.asset.metadata import Metadata

if not inspect.isgeneratorfunction(func):
return func(*args, **kwargs)

result: Any = NOTSET

def _run():
nonlocal result
result = yield from func(*args, **kwargs)

for metadata in _run():
if isinstance(metadata, Metadata):
outlet_events[metadata.asset].extra.update(metadata.extra)

if metadata.alias:
outlet_events[metadata.alias].add(metadata.asset, extra=metadata.extra)

continue
logger.warning("Ignoring unknown data of %r received from task", type(metadata))
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Full yielded value: %r", metadata)

return result

return _ExecutionCallableRunnerImpl
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,8 @@ def test_no_new_fields_added_to_base_operator(self):
assert fields == {
"_logger_name": None,
"_needs_expansion": None,
"_post_execute_hook": None,
"_pre_execute_hook": None,
"_task_display_name": None,
"allow_nested_operators": True,
"depends_on_past": False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@

from pendulum.datetime import DateTime

from airflow.sdk.execution_time.callback_runner import ExecutionCallableRunner
from airflow.sdk.execution_time.context import OutletEventAccessorsProtocol

try:
from airflow.sdk.definitions.context import Context
except ImportError:
# TODO: Remove once provider drops support for Airflow 2
except ImportError: # TODO: Remove once provider drops support for Airflow 2
from airflow.utils.context import Context

_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]
Expand Down Expand Up @@ -190,14 +192,22 @@ def execute(self, context: Context) -> Any:
context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
self.op_kwargs = self.determine_kwargs(context)

if AIRFLOW_V_3_0_PLUS:
from airflow.utils.context import context_get_outlet_events
# This needs to be lazy because subclasses may implement execute_callable
# by running a separate process that can't use the eager result.
def __prepare_execution() -> tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None:
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.context import context_get_outlet_events

return create_executable_runner, context_get_outlet_events(context)
if AIRFLOW_V_2_10_PLUS:
from airflow.utils.context import context_get_outlet_events # type: ignore
from airflow.utils.operator_helpers import ExecutionCallableRunner # type: ignore

self._asset_events = context_get_outlet_events(context)
elif AIRFLOW_V_2_10_PLUS:
from airflow.utils.context import context_get_outlet_events
return ExecutionCallableRunner, context_get_outlet_events(context)
return None

self._dataset_events = context_get_outlet_events(context)
self.__prepare_execution = __prepare_execution

return_value = self.execute_callable()
if self.show_return_value_in_logs:
Expand All @@ -210,19 +220,18 @@ def execute(self, context: Context) -> Any:
def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking()

__prepare_execution: Callable[[], tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None]

def execute_callable(self) -> Any:
"""
Call the python callable with the given arguments.

:return: the return value of the call.
"""
try:
from airflow.utils.operator_helpers import ExecutionCallableRunner
except ImportError:
# Handle Pre Airflow 2.10 case where ExecutionCallableRunner was not available
if (execution_preparation := self.__prepare_execution()) is None:
return self.python_callable(*self.op_args, **self.op_kwargs)
asset_events = self._asset_events if AIRFLOW_V_3_0_PLUS else self._dataset_events
runner = ExecutionCallableRunner(self.python_callable, asset_events, logger=self.log)
create_execution_runner, asset_events = execution_preparation
runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
return runner.run(*self.op_args, **self.op_kwargs)


Expand Down
16 changes: 9 additions & 7 deletions task-sdk/src/airflow/sdk/definitions/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
from airflow.typing_compat import Self
from airflow.utils.operator_resources import Resources

TaskPreExecuteHook = Callable[[Context], None]
TaskPostExecuteHook = Callable[[Context, Any], None]

__all__ = [
"BaseOperator",
"chain",
Expand Down Expand Up @@ -822,8 +825,8 @@ def say_hello_world(**context):
on_success_callback: Sequence[TaskStateChangeCallback] = ()
on_retry_callback: Sequence[TaskStateChangeCallback] = ()
on_skipped_callback: Sequence[TaskStateChangeCallback] = ()
# pre_execute: TaskPreExecuteHook | None = None
# post_execute: TaskPostExecuteHook | None = None
_pre_execute_hook: TaskPreExecuteHook | None = None
_post_execute_hook: TaskPostExecuteHook | None = None
trigger_rule: TriggerRule = DEFAULT_TRIGGER_RULE
resources: dict[str, Any] | None = None
run_as_user: str | None = None
Expand Down Expand Up @@ -981,8 +984,8 @@ def __init__(
on_success_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_skipped_callback: None | TaskStateChangeCallback | Collection[TaskStateChangeCallback] = None,
# pre_execute: TaskPreExecuteHook | None = None,
# post_execute: TaskPostExecuteHook | None = None,
pre_execute: TaskPreExecuteHook | None = None,
post_execute: TaskPostExecuteHook | None = None,
trigger_rule: str = DEFAULT_TRIGGER_RULE,
resources: dict[str, Any] | None = None,
run_as_user: str | None = None,
Expand Down Expand Up @@ -1053,14 +1056,13 @@ def __init__(
)
self.execution_timeout = execution_timeout

# TODO:
self.on_execute_callback = _collect_callbacks(on_execute_callback)
self.on_failure_callback = _collect_callbacks(on_failure_callback)
self.on_success_callback = _collect_callbacks(on_success_callback)
self.on_retry_callback = _collect_callbacks(on_retry_callback)
self.on_skipped_callback = _collect_callbacks(on_skipped_callback)
# self._pre_execute_hook = pre_execute
# self._post_execute_hook = post_execute
self._pre_execute_hook = pre_execute
self._post_execute_hook = post_execute

if start_date:
self.start_date = timezone.convert_to_utc(start_date)
Expand Down
Loading