Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,18 @@ def on_celery_import_modules(*args, **kwargs):
# and deserialization for us
@app.task(name="execute_workload")
def execute_workload(input: str) -> None:
from celery.exceptions import Ignore
from pydantic import TypeAdapter

from airflow.configuration import conf
from airflow.executors import workloads
from airflow.sdk.execution_time.supervisor import supervise

try:
from airflow.sdk.exceptions import TaskAlreadyRunningError
except ImportError:
TaskAlreadyRunningError = None # type: ignore[misc,assignment]

decoder = TypeAdapter[workloads.All](workloads.All)
workload = decoder.validate_json(input)

Expand All @@ -159,15 +165,21 @@ def execute_workload(input: str) -> None:
base_url = f"http://localhost:8080{base_url}"
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"

supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
ti=workload.ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
log_path=workload.log_path,
)
try:
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
ti=workload.ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
log_path=workload.log_path,
)
except Exception as e:
if TaskAlreadyRunningError is not None and isinstance(e, TaskAlreadyRunningError):
log.info("[%s] Task already running elsewhere, ignoring redelivered message", celery_task_id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a new feature or did something like this exist in Airflow 2.x?

raise Ignore()
raise


if not AIRFLOW_V_3_0_PLUS:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,54 @@ def test_result_backend_sentinel_full_config():
result_backend_opts = default_celery.DEFAULT_CELERY_CONFIG["result_backend_transport_options"]
assert result_backend_opts["sentinel_kwargs"] == {"password": "redis_pass"}
assert result_backend_opts["master_name"] == "mymaster"


@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="execute_workload only exists in Airflow 3.0+")
def test_execute_workload_ignores_already_running_task():
    """execute_workload must translate TaskAlreadyRunningError into celery.exceptions.Ignore."""
    import importlib

    from celery.exceptions import Ignore

    try:
        from airflow.sdk.exceptions import TaskAlreadyRunningError
    except ImportError:
        pytest.skip("TaskAlreadyRunningError not available in this Airflow version")

    # A valid ExecuteTask workload payload; the supervise() call it would
    # trigger is mocked out below, so none of these values are ever executed.
    workload_json = """
    {
        "type": "ExecuteTask",
        "token": "test-token",
        "dag_rel_path": "test_dag.py",
        "bundle_info": {"name": "test-bundle", "version": null},
        "log_path": "test.log",
        "ti": {
            "id": "019bdec0-d353-7b68-abe0-5ac20fa75ad0",
            "dag_version_id": "019bdead-fdcd-78ab-a9f2-aba3b80fded2",
            "task_id": "test_task",
            "dag_id": "test_dag",
            "run_id": "test_run",
            "try_number": 1,
            "map_index": -1,
            "pool_slots": 1,
            "queue": "default",
            "priority_weight": 1
        }
    }
    """

    # Reload so module-level state is fresh, then bypass the Celery task
    # decorator and call the underlying function directly.
    importlib.reload(celery_executor_utils)
    raw_execute = celery_executor_utils.execute_workload.__wrapped__

    # Fake Celery app whose current task exposes a request id (used for logging).
    fake_task = mock.MagicMock()
    fake_task.request.id = "test-celery-task-id"
    fake_app = mock.MagicMock()
    fake_app.current_task = fake_task

    with (
        mock.patch("airflow.sdk.execution_time.supervisor.supervise") as fake_supervise,
        mock.patch.object(celery_executor_utils, "app", fake_app),
    ):
        fake_supervise.side_effect = TaskAlreadyRunningError("Task already running")

        with pytest.raises(Ignore):
            raw_execute(workload_json)
4 changes: 4 additions & 0 deletions task-sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ class TaskNotFound(AirflowException):
"""Raise when a Task is not available in the system."""


class TaskAlreadyRunningError(AirflowException):
    """Raised when a task instance is found to be running already.

    Typically seen when a message broker redelivers a task message while
    another worker is still executing it.
    """


class FailFastDagInvalidTriggerRule(AirflowException):
"""Raise when a dag has 'fail_fast' enabled yet has a non-default trigger rule."""

Expand Down
22 changes: 20 additions & 2 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
XComSequenceIndexResponse,
)
from airflow.sdk.configuration import conf
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
from airflow.sdk.execution_time import comms
from airflow.sdk.execution_time.comms import (
AssetEventsResult,
Expand Down Expand Up @@ -1007,9 +1007,23 @@ def _on_child_started(
ti_context = self.client.task_instances.start(ti.id, self.pid, start_date)
self._should_retry = ti_context.should_retry
self._last_successful_heartbeat = time.monotonic()
except Exception:
except Exception as e:
# On any error kill that subprocess!
self.kill(signal.SIGKILL)

# Handle broker redelivery: task already running on another worker
if isinstance(e, ServerResponseError) and e.response.status_code == 409:
# FastAPI wraps HTTPException detail in {"detail": {...}}
detail = e.detail
if isinstance(detail, dict) and "detail" in detail:
detail = detail["detail"]
if (
isinstance(detail, dict)
and detail.get("reason") == "invalid_state"
and detail.get("previous_state") == "running"
):
Comment on lines +1018 to +1024
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is rather fragile and not a good idea to write things this way.

log.warning("Task already running, likely broker redelivery", task_instance_id=str(ti.id))
raise TaskAlreadyRunningError(f"Task {ti.id} is already running") from e
raise

msg = StartupDetails.model_construct(
Expand Down Expand Up @@ -2088,6 +2102,10 @@ def supervise(
final_state=process.final_state,
)
return exit_code
except TaskAlreadyRunningError:
# Let the executor handle this (e.g., Celery will ignore it)
log.info("Exiting due to broker redelivery", task_instance_id=str(ti.id))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"broker" as a concept only exist for Celery, so this log message doesn't make sense for other executors.

raise
finally:
if log_path and log_file_descriptor:
log_file_descriptor.close()
Expand Down
49 changes: 35 additions & 14 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,38 +720,59 @@ def test_supervise_handles_deferred_task(
} in captured_logs

def test_supervisor_handles_already_running_task(self):
"""Test that Supervisor prevents starting a Task Instance that is already running."""
"""Test that Supervisor raises TaskAlreadyRunningError for already running tasks."""
from airflow.sdk.exceptions import TaskAlreadyRunningError

ti = TaskInstance(
id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
)

# Mock API Server response indicating the TI is already running
# The API Server would return a 409 Conflict status code if the TI is not
# in a "queued" state.
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti.id}/run":
return httpx.Response(
409,
json={
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": "running",
"detail": {
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": "running",
}
},
)

return httpx.Response(status_code=204)

client = make_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(ServerResponseError, match="Server returned error") as err:
with pytest.raises(TaskAlreadyRunningError, match="already running"):
ActivitySubprocess.start(dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=ti, client=client)

assert err.value.response.status_code == 409
assert err.value.detail == {
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": "running",
}
@pytest.mark.parametrize("previous_state", ["failed", "success", "skipped"])
def test_supervisor_raises_error_for_other_invalid_states(self, previous_state):
    """A 409 invalid-state conflict whose previous state is not 'running' must surface as ServerResponseError."""
    ti = TaskInstance(
        id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
    )

    # Payload the mocked API server returns for the conflicting "run" request.
    conflict_body = {
        "detail": {
            "reason": "invalid_state",
            "message": "TI was not in a state where it could be marked as running",
            "previous_state": previous_state,
        }
    }

    def respond(request: httpx.Request) -> httpx.Response:
        # Only the "run" endpoint conflicts; every other call succeeds.
        if request.url.path != f"/task-instances/{ti.id}/run":
            return httpx.Response(status_code=204)
        return httpx.Response(409, json=conflict_body)

    client = make_client(transport=httpx.MockTransport(respond))

    with pytest.raises(ServerResponseError):
        ActivitySubprocess.start(dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=ti, client=client)

@pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"])
def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker, make_ti_context_dict):
Expand Down