Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,6 @@ def say_hello_world(**context):
start_trigger_args: StartTriggerArgs | None = None
start_from_trigger: bool = False

on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
Expand All @@ -349,7 +348,6 @@ def __init__(
self,
pre_execute=None,
post_execute=None,
on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
Expand All @@ -364,7 +362,6 @@ def __init__(
super().__init__(**kwargs)
self._pre_execute_hook = pre_execute
self._post_execute_hook = post_execute
self.on_execute_callback = on_execute_callback
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_skipped_callback = on_skipped_callback
Expand Down Expand Up @@ -393,7 +390,6 @@ def get_serialized_fields(cls):
return TaskSDKBaseOperator.get_serialized_fields() | {
"start_trigger_args",
"start_from_trigger",
"on_execute_callback",
"on_failure_callback",
"on_success_callback",
"on_retry_callback",
Expand Down
13 changes: 10 additions & 3 deletions task-sdk/src/airflow/sdk/definitions/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import db_safe_priority

C = TypeVar("C", bound=Callable)
T = TypeVar("T", bound=FunctionType)

if TYPE_CHECKING:
Expand Down Expand Up @@ -382,6 +383,12 @@ def wrapper(self, *args, **kwargs):
ExecutorSafeguard.test_mode = conf.getboolean("core", "unit_test_mode")


def _collect_callbacks(callbacks: C | Collection[C]) -> list[C]:
if isinstance(callbacks, Collection):
return list(callbacks)
return [callbacks]


class BaseOperatorMeta(abc.ABCMeta):
"""Metaclass of BaseOperator."""

Expand Down Expand Up @@ -805,7 +812,7 @@ def say_hello_world(**context):
pool: str = DEFAULT_POOL_NAME
pool_slots: int = DEFAULT_POOL_SLOTS
execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT
# on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
on_execute_callback: Sequence[TaskStateChangeCallback] = ()
# on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
# on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
# on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
Expand Down Expand Up @@ -959,7 +966,7 @@ def __init__(
pool_slots: int = DEFAULT_POOL_SLOTS,
sla: timedelta | None = None,
execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
# on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_execute_callback: TaskStateChangeCallback | Collection[TaskStateChangeCallback] = (),
# on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
# on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
# on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
Expand Down Expand Up @@ -1037,7 +1044,7 @@ def __init__(
self.execution_timeout = execution_timeout

# TODO:
# self.on_execute_callback = on_execute_callback
self.on_execute_callback = _collect_callbacks(on_execute_callback)
# self.on_failure_callback = on_failure_callback
# self.on_success_callback = on_success_callback
# self.on_retry_callback = on_retry_callback
Expand Down
10 changes: 8 additions & 2 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def run(
state = TerminalTIState.FAILED
return state, msg, error

result = _execute_task(context, ti)
result = _execute_task(context, ti, log)

_push_xcom_if_needed(result, ti, log)

Expand Down Expand Up @@ -789,7 +789,7 @@ def _handle_trigger_dag_run(
return msg, state


def _execute_task(context: Context, ti: RuntimeTaskInstance):
def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
"""Execute Task (optionally with a Timeout) and push Xcom results."""
from airflow.exceptions import AirflowTaskTimeout

Expand All @@ -807,6 +807,12 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance):
# Populate the context var so ExecutorSafeguard doesn't complain
ctx.run(ExecutorSafeguard.tracker.set, task)

for i, callback in enumerate(task.on_execute_callback):
try:
callback(context)
except Exception:
log.exception("Failed to run on-execute callback", index=i, callback=callback)

if task.execution_timeout:
# TODO: handle timeout in case of deferral
from airflow.utils.timeout import timeout
Expand Down
45 changes: 45 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1928,6 +1928,51 @@ def execute(self, context):
assert listener.error == error


@pytest.mark.usefixtures("mock_supervisor_comms")
class TestTaskRunnerCallsCallbacks:
def test_task_runner_calls_execute_callback(self, create_runtime_ti):
results = []

def custom_callback(context):
results.append("callback")

class CustomOperator(BaseOperator):
def execute(self, context):
results.append("execute")

task = CustomOperator(task_id="task", on_execute_callback=custom_callback)
runtime_ti = create_runtime_ti(dag_id="dag", task=task)
log = mock.MagicMock()
state, _, _ = run(runtime_ti, log)

assert state == TerminalTIState.SUCCESS
assert results == ["callback", "execute"]

def test_task_runner_not_fail_on_failed_execute_callback(self, create_runtime_ti):
results = []

def custom_callback_1(context):
results.append("callback 1")

def custom_callback_2(context):
raise Exception("sorry!")

class CustomOperator(BaseOperator):
def execute(self, context):
results.append("execute")

task = CustomOperator(task_id="task", on_execute_callback=[custom_callback_1, custom_callback_2])
runtime_ti = create_runtime_ti(dag_id="dag", task=task)
log = mock.MagicMock()
state, _, _ = run(runtime_ti, log)

assert state == TerminalTIState.SUCCESS
assert results == ["callback 1", "execute"]
assert log.exception.mock_calls == [
mock.call("Failed to run on-execute callback", index=1, callback=custom_callback_2),
]


class TestTriggerDagRunOperator:
"""Tests to verify various aspects of TriggerDagRunOperator"""

Expand Down
2 changes: 1 addition & 1 deletion tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,7 @@ def test_no_new_fields_added_to_base_operator(self):
"max_active_tis_per_dag": None,
"max_active_tis_per_dagrun": None,
"max_retry_delay": None,
"on_execute_callback": None,
"on_execute_callback": [],
"on_failure_fail_dagrun": False,
"on_failure_callback": None,
"on_retry_callback": None,
Expand Down