Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Improve Supervisor and Task Instance State Validation #44405

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,21 @@
from typing import TYPE_CHECKING
from unittest.mock import MagicMock

import httpx
import pytest
import structlog
from uuid6 import uuid7

from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
from airflow.sdk.execution_time.comms import ConnectionResult, GetConnection, GetVariable, VariableResult
from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise
from airflow.utils import timezone as tz

from task_sdk.tests.api.test_client import make_client

if TYPE_CHECKING:
import kgb

Expand Down Expand Up @@ -66,7 +70,7 @@ def subprocess_main():
print("I'm a short message")
sys.stdout.write("Message ")
print("stderr message", file=sys.stderr)
# We need a short sleep for the main process to process things. I worry this timining will be
# We need a short sleep for the main process to process things. I worry this timing will be
# fragile, but I can't think of a better way. This lets the stdout be read (partial line) and the
# stderr full line be read
sleep(0.1)
Expand Down Expand Up @@ -258,6 +262,38 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
"timestamp": "2024-11-07T12:34:56.078901Z",
} in captured_logs

def test_supervisor_handles_already_running_task(self):
    """Verify the Supervisor refuses to start a Task Instance that is already running."""
    ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1)

    # Body the mocked API server sends with the 409 Conflict. The same payload is
    # asserted below as the ``detail`` of the raised error.
    conflict_detail = {
        "reason": "invalid_state",
        "message": "TI was not in a state where it could be marked as running",
        "previous_state": "running",
    }
    state_path = f"/task-instances/{ti.id}/state"

    def handle_request(request: httpx.Request) -> httpx.Response:
        # Mock API Server: any request other than the TI state update succeeds
        # with an empty 204 response.
        if request.url.path != state_path:
            return httpx.Response(status_code=204)

        # The API Server would return a 409 Conflict status code if the TI is
        # not in a "queued" state; here it reports the TI as already running.
        return httpx.Response(409, json=conflict_detail)

    client = make_client(transport=httpx.MockTransport(handle_request))

    with pytest.raises(ServerResponseError, match="Server returned error") as err:
        WatchedSubprocess.start(path=os.devnull, ti=ti, client=client)

    assert err.value.response.status_code == 409
    assert err.value.detail == conflict_detail


class TestHandleRequest:
@pytest.fixture
Expand Down
13 changes: 8 additions & 5 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState

from tests_common.test_utils.db import clear_db_runs

Expand Down Expand Up @@ -79,14 +79,17 @@ def test_ti_update_state_to_running(self, client, session, create_task_instance)
assert ti.pid == 100
assert ti.start_date.isoformat() == "2024-10-31T12:00:00+00:00"

def test_ti_update_state_conflict_if_not_queued(self, client, session, create_task_instance):
@pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState if s != State.QUEUED])
def test_ti_update_state_conflict_if_not_queued(
self, client, session, create_task_instance, initial_ti_state
):
"""
Test that a 409 error is returned when the Task Instance is not in a state where it can be marked as
running. The test is parametrized over every TaskInstanceState other than QUEUED, so the Task Instance
starts in each non-queued state in turn and therefore cannot be marked as running.
"""
ti = create_task_instance(
task_id="test_ti_update_state_conflict_if_not_queued",
state=State.NONE,
state=initial_ti_state,
)
session.commit()

Expand All @@ -105,12 +108,12 @@ def test_ti_update_state_conflict_if_not_queued(self, client, session, create_ta
assert response.json() == {
"detail": {
"message": "TI was not in a state where it could be marked as running",
"previous_state": State.NONE,
"previous_state": initial_ti_state,
"reason": "invalid_state",
}
}

assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == State.NONE
assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state

@pytest.mark.parametrize(
("state", "end_date", "expected_state"),
Expand Down