Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions airflow-core/src/airflow/listeners/spec/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,28 @@ def on_task_instance_failed(
error: None | str | BaseException,
):
"""Execute when task state changes to FAIL. previous_state can be None."""


@hookspec
def on_task_instance_skipped(
    previous_state: TaskInstanceState | None,
    task_instance: RuntimeTaskInstance | TaskInstance,
):
    """
    Execute when a task instance skips itself during execution.

    This hook is called only when a task has started execution and then
    intentionally skips itself (e.g., by raising AirflowSkipException).

    Note: this hook is NOT called for tasks that were skipped by the scheduler
    before execution began, such as:

    - Skips due to trigger rules (e.g., upstream failures)
    - Skips from operators like BranchPythonOperator, ShortCircuitOperator,
      or similar mechanisms
    - Any other situation in which the scheduler decides not to schedule a
      task for execution

    For comprehensive tracking of skipped tasks, use DAG-level listeners
    (on_dag_run_success/on_dag_run_failed), which may have access to all
    task states.

    :param previous_state: Previous state of the task instance (can be None).
    :param task_instance: The task instance object (RuntimeTaskInstance when
        called from task execution context, TaskInstance when called from the
        API server).
    """
6 changes: 6 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,12 @@ def finalize(
log.exception("error calling listener")
elif state == TaskInstanceState.SKIPPED:
_run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
try:
get_listener_manager().hook.on_task_instance_skipped(
previous_state=TaskInstanceState.RUNNING, task_instance=ti
)
except Exception:
log.exception("error calling listener")
elif state == TaskInstanceState.UP_FOR_RETRY:
_run_task_state_change_callbacks(task, "on_retry_callback", context, log)
try:
Expand Down
35 changes: 35 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2996,6 +2996,10 @@ def on_task_instance_failed(self, previous_state, task_instance, error):
self.state.append(TaskInstanceState.FAILED)
self.error = error

@hookimpl
def on_task_instance_skipped(self, previous_state, task_instance):
    # Record the SKIPPED transition so tests can assert on the sequence of
    # listener notifications; the hook arguments themselves are not inspected.
    self.state.append(TaskInstanceState.SKIPPED)

@hookimpl
def before_stopping(self, component):
self.component = component
Expand Down Expand Up @@ -3146,6 +3150,37 @@ def execute(self, context):
assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED]
assert listener.error == error

def test_task_runner_calls_listeners_skipped(self, mocked_parse, mock_supervisor_comms):
    """Check that a task raising AirflowSkipException notifies the skipped listener hook.

    The registered listener is expected to have recorded RUNNING followed by
    SKIPPED once the task has been run and finalized.
    """

    class CustomOperator(BaseOperator):
        def execute(self, context):
            # Skip from inside execute(), i.e. after the task has started.
            raise AirflowSkipException("Task intentionally skipped")

    skip_listener = self.CustomListener()
    get_listener_manager().add_listener(skip_listener)

    operator = CustomOperator(
        task_id="test_task_runner_calls_listeners_skipped",
        do_xcom_push=True,
        multiple_outputs=True,
    )
    inline_dag = get_inline_dag(dag_id="test_dag", task=operator)

    ti_model = TaskInstance(
        id=uuid7(),
        task_id=operator.task_id,
        dag_id=inline_dag.dag_id,
        run_id="test_run",
        try_number=1,
        dag_version_id=uuid7(),
    )
    ti_fields = ti_model.model_dump(exclude_unset=True)
    runtime_ti = RuntimeTaskInstance.model_construct(
        **ti_fields,
        task=operator,
        start_date=timezone.utcnow(),
    )

    mock_log = mock.MagicMock()
    template_context = runtime_ti.get_template_context()

    # Execute the task, then finalize with the resulting state so the
    # state-specific listener hook fires.
    final_state, _, _ = run(runtime_ti, template_context, mock_log)
    finalize(runtime_ti, final_state, template_context, mock_log)

    expected_states = [TaskInstanceState.RUNNING, TaskInstanceState.SKIPPED]
    assert skip_listener.state == expected_states

def test_listener_access_outlet_event_on_running_and_success(self, mocked_parse, mock_supervisor_comms):
"""Test listener can access outlet events through invoking get_template_context() while task running and success"""
listener = self.CustomOutletEventsListener()
Expand Down