Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions airflow-core/src/airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,19 +337,6 @@ def get_outlet_defs(self):
extended/overridden by subclasses.
"""

def pre_execute(self, context: Any):
"""Execute right before self.execute() is called."""
if self._pre_execute_hook is None:
return
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.context import context_get_outlet_events

create_executable_runner(
self._pre_execute_hook,
context_get_outlet_events(context),
logger=self.log,
).run(context)

def execute(self, context: Context) -> Any:
"""
Derive when creating an operator.
Expand All @@ -360,23 +347,6 @@ def execute(self, context: Context) -> Any:
"""
raise NotImplementedError()

def post_execute(self, context: Any, result: Any = None):
"""
Execute right after self.execute() is called.

It is passed the execution context and any results returned by the operator.
"""
if self._post_execute_hook is None:
return
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.context import context_get_outlet_events

create_executable_runner(
self._post_execute_hook,
context_get_outlet_events(context),
logger=self.log,
).run(context, result)

@provide_session
def clear(
self,
Expand Down
27 changes: 25 additions & 2 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2779,6 +2779,9 @@ def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, frozenset],

def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
"""Prepare Task for Execution."""
from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.context import context_get_outlet_events

if TYPE_CHECKING:
assert self.task

Expand Down Expand Up @@ -2843,7 +2846,17 @@ def signal_handler(signum, frame):
)

# Run pre_execute callback
self.task.pre_execute(context=context)
if self.task._pre_execute_hook:
create_executable_runner(
self.task._pre_execute_hook,
context_get_outlet_events(context),
logger=self.log,
).run(context)
create_executable_runner(
self.task.pre_execute,
context_get_outlet_events(context),
logger=self.log,
).run(context)

# Run on_execute callback
self._run_execute_callback(context, self.task)
Expand Down Expand Up @@ -2877,7 +2890,17 @@ def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None)
self._rendered_map_index = _render_map_index(context, jinja_env=jinja_env)

# Run post_execute callback
self.task.post_execute(context=context, result=result)
if self.task._post_execute_hook:
create_executable_runner(
self.task._post_execute_hook,
context_get_outlet_events(context),
logger=self.log,
).run(context, result)
create_executable_runner(
self.task.post_execute,
context_get_outlet_events(context),
logger=self.log,
).run(context, result)

Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
# Same metric with tagging
Expand Down
17 changes: 0 additions & 17 deletions airflow-core/tests/unit/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import copy
from collections import defaultdict
from datetime import datetime
from unittest import mock

import pytest

Expand Down Expand Up @@ -113,22 +112,6 @@ def test_baseoperator_with_task_id_less_than_250_chars(self):
except Exception as e:
pytest.fail(f"Exception raised: {e}")

def test_pre_execute_hook(self):
hook = mock.MagicMock()

op = BaseOperator(task_id="test_task", pre_execute=hook)
op_copy = op.prepare_for_execution()
op_copy.pre_execute({})
assert hook.called

def test_post_execute_hook(self):
hook = mock.MagicMock()

op = BaseOperator(task_id="test_task", post_execute=hook)
op_copy = op.prepare_for_execution()
op_copy.post_execute({})
assert hook.called

def test_task_naive_datetime(self):
naive_datetime = DEFAULT_DATE.replace(tzinfo=None)

Expand Down
4 changes: 4 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,8 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):

if (pre_execute_hook := task._pre_execute_hook) is not None:
create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

_run_task_state_change_callbacks(task, "on_execute_callback", context, log)

Expand All @@ -993,6 +995,8 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):

if (post_execute_hook := task._post_execute_hook) is not None:
create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context)

return result

Expand Down