Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ class TerminalStateNonSuccess(str, Enum):
FAILED = TerminalTIState.FAILED
SKIPPED = TerminalTIState.SKIPPED
REMOVED = TerminalTIState.REMOVED
FAIL_WITHOUT_RETRY = TerminalTIState.FAIL_WITHOUT_RETRY


class TITerminalStatePayload(StrictBaseModel):
Expand Down Expand Up @@ -157,6 +156,23 @@ class TIRescheduleStatePayload(StrictBaseModel):
end_date: UtcDateTime


class TIRetryStatePayload(StrictBaseModel):
    """Request body used to move a TaskInstance into the ``up_for_retry`` state."""

    # The default is declared only in the generated JSON schema; it is left out
    # of the Python field so that Pydantic still treats ``state`` as required.
    state: Annotated[
        Literal[IntermediateTIState.UP_FOR_RETRY],
        WithJsonSchema(
            {
                "type": "string",
                "enum": [IntermediateTIState.UP_FOR_RETRY],
                "default": IntermediateTIState.UP_FOR_RETRY,
            }
        ),
    ]
    end_date: UtcDateTime


class TISkippedDownstreamTasksStatePayload(StrictBaseModel):
"""Schema for updating downstream tasks to a skipped state."""

Expand Down Expand Up @@ -185,6 +201,8 @@ def ti_state_discriminator(v: dict[str, str] | StrictBaseModel) -> str:
return "deferred"
elif state == TIState.UP_FOR_RESCHEDULE:
return "up_for_reschedule"
elif state == TIState.UP_FOR_RETRY:
return "up_for_retry"
return "_other_"


Expand All @@ -197,6 +215,7 @@ def ti_state_discriminator(v: dict[str, str] | StrictBaseModel) -> str:
Annotated[TITargetStatePayload, Tag("_other_")],
Annotated[TIDeferredStatePayload, Tag("deferred")],
Annotated[TIRescheduleStatePayload, Tag("up_for_reschedule")],
Annotated[TIRetryStatePayload, Tag("up_for_retry")],
],
Discriminator(ti_state_discriminator),
]
Expand Down Expand Up @@ -276,6 +295,9 @@ class TIRunContext(BaseModel):
xcom_keys_to_clear: Annotated[list[str], Field(default_factory=list)]
"""List of Xcom keys that need to be cleared and purged on by the worker."""

should_retry: bool
"""If the ti encounters an error, whether it should enter retry or failed state."""


class PrevSuccessfulDagRunResponse(BaseModel):
"""Schema for response with previous successful DagRun information for Task Template Context."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRetryStatePayload,
TIRunContext,
TIRuntimeCheckPayload,
TISkippedDownstreamTasksStatePayload,
Expand All @@ -50,7 +51,7 @@
from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
from airflow.utils.state import DagRunState, TaskInstanceState

router = VersionedAPIRouter(
dependencies=[
Expand Down Expand Up @@ -136,7 +137,7 @@ def ti_run(
ti_run_payload.pid,
):
log.info("Duplicate start request received from %s ", ti_run_payload.hostname)
elif previous_state != TaskInstanceState.QUEUED:
elif previous_state not in (TaskInstanceState.QUEUED, TaskInstanceState.RESTARTING):
log.warning(
"Can not start Task Instance ('%s') in invalid state: %s",
ti_id_str,
Expand Down Expand Up @@ -226,6 +227,7 @@ def ti_run(
variables=[],
connections=[],
xcom_keys_to_clear=xcom_keys,
should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries),
)

# Only set if they are non-null
Expand Down Expand Up @@ -289,23 +291,18 @@ def ti_update_state(
query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, TITerminalStatePayload):
updated_state = ti_patch_payload.state
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIRetryStatePayload):
from airflow.models.taskinstance import uuid7
from airflow.models.taskinstancehistory import TaskInstanceHistory

ti = session.get(TI, ti_id_str)
TaskInstanceHistory.record_ti(ti, session=session)
ti.try_id = uuid7()
updated_state = ti_patch_payload.state
# if we get failed, we should attempt to retry, as it is a more
# normal state. Tasks with retries are more frequent than without retries.
if ti_patch_payload.state == TerminalTIState.FAIL_WITHOUT_RETRY:
updated_state = TaskInstanceState.FAILED
elif ti_patch_payload.state == TaskInstanceState.FAILED:
if _is_eligible_to_retry(previous_state, try_number, max_tries):
from airflow.models.taskinstance import uuid7
from airflow.models.taskinstancehistory import TaskInstanceHistory

ti = session.get(TI, ti_id_str)
TaskInstanceHistory.record_ti(ti, session=session)
ti.try_id = uuid7()
updated_state = TaskInstanceState.UP_FOR_RETRY
else:
updated_state = TaskInstanceState.FAILED
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TISuccessStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
Expand Down
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ def _get_template_context(
ti_context_from_server = TIRunContext(
dag_run=DagRunSDK.model_validate(dag_run, from_attributes=True),
max_tries=task_instance.max_tries,
should_retry=task_instance.is_eligible_to_retry(),
)
runtime_ti = task_instance.to_runtime_ti(context_from_server=ti_context_from_server)

Expand Down Expand Up @@ -3196,7 +3197,7 @@ def handle_failure(
fail_fast=fail_fast,
)

def is_eligible_to_retry(self):
def is_eligible_to_retry(self) -> bool:
    """Return whether this task instance is eligible for retry."""
    # Delegates to the module-level helper, passing this instance along.
    eligible = _is_eligible_to_retry(task_instance=self)
    return eligible

Expand Down
1 change: 0 additions & 1 deletion airflow-core/src/airflow/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class TerminalTIState(str, Enum):
FAILED = "failed"
SKIPPED = "skipped" # A user can raise a AirflowSkipException from a task & it will be marked as skipped
REMOVED = "removed"
FAIL_WITHOUT_RETRY = "fail_without_retry"

def __str__(self) -> str:
return self.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,22 @@ def setup_method(self):
def teardown_method(self):
clear_db_runs()

def test_ti_run_state_to_running(self, client, session, create_task_instance, time_machine):
@pytest.mark.parametrize(
"max_tries, should_retry",
[
pytest.param(0, False, id="max_retries=0"),
pytest.param(3, True, id="should_retry"),
],
)
def test_ti_run_state_to_running(
self,
client,
session,
create_task_instance,
time_machine,
max_tries,
should_retry,
):
"""
Test that the Task Instance state is updated to running when the Task Instance is in a state where it can be
marked as running.
Expand All @@ -131,6 +146,7 @@ def test_ti_run_state_to_running(self, client, session, create_task_instance, ti
session=session,
start_date=instant,
)
ti.max_tries = max_tries
session.commit()

response = client.patch(
Expand Down Expand Up @@ -160,7 +176,8 @@ def test_ti_run_state_to_running(self, client, session, create_task_instance, ti
"conf": {},
},
"task_reschedule_count": 0,
"max_tries": 0,
"max_tries": max_tries,
"should_retry": should_retry,
"variables": [],
"connections": [],
"xcom_keys_to_clear": [],
Expand Down Expand Up @@ -210,7 +227,7 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance,
time_machine.move_to(instant, tick=False)

ti = create_task_instance(
task_id="test_ti_run_state_to_running",
task_id="test_next_kwargs_still_encoded",
state=State.QUEUED,
session=session,
start_date=instant,
Expand Down Expand Up @@ -238,6 +255,7 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance,
"dag_run": mock.ANY,
"task_reschedule_count": 0,
"max_tries": 0,
"should_retry": False,
"variables": [],
"connections": [],
"xcom_keys_to_clear": [],
Expand All @@ -248,7 +266,10 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance,
},
}

@pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState if s != State.QUEUED])
@pytest.mark.parametrize(
"initial_ti_state",
[s for s in TaskInstanceState if s not in (TaskInstanceState.QUEUED, TaskInstanceState.RESTARTING)],
)
def test_ti_run_state_conflict_if_not_queued(
self, client, session, create_task_instance, initial_ti_state
):
Expand Down Expand Up @@ -691,67 +712,17 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan
assert trs[0].task_instance.map_index == -1
assert trs[0].duration == 129600

@pytest.mark.parametrize(
("retries", "expected_state"),
[
(0, State.FAILED),
(None, State.FAILED),
(3, State.UP_FOR_RETRY),
],
)
def test_ti_update_state_to_failed_with_retries(
self, client, session, create_task_instance, retries, expected_state
):
def test_ti_update_state_handle_retry(self, client, session, create_task_instance):
ti = create_task_instance(
task_id="test_ti_update_state_to_retry",
state=State.RUNNING,
)

if retries is not None:
ti.max_tries = retries
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": TerminalTIState.FAILED,
"end_date": DEFAULT_END_DATE.isoformat(),
},
)

assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == expected_state
assert ti.next_method is None
assert ti.next_kwargs is None

tih = session.query(TaskInstanceHistory).where(
TaskInstanceHistory.task_id == ti.task_id, TaskInstanceHistory.task_instance_id == ti.id
)
tih_count = tih.count()
assert tih_count == (1 if retries else 0)
if retries:
tih = tih.one()
assert tih.try_id
assert tih.try_id != ti.try_id

def test_ti_update_state_when_ti_is_restarting(self, client, session, create_task_instance):
ti = create_task_instance(
task_id="test_ti_update_state_when_ti_is_restarting",
state=State.RUNNING,
)
# update state to restarting
ti.state = State.RESTARTING
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": TerminalTIState.FAILED,
"state": State.UP_FOR_RETRY,
"end_date": DEFAULT_END_DATE.isoformat(),
},
)
Expand All @@ -762,43 +733,19 @@ def test_ti_update_state_when_ti_is_restarting(self, client, session, create_tas
session.expire_all()

ti = session.get(TaskInstance, ti.id)
# restarting is always retried
assert ti.state == State.UP_FOR_RETRY
assert ti.next_method is None
assert ti.next_kwargs is None

def test_ti_update_state_when_ti_has_higher_tries_than_retries(
self, client, session, create_task_instance
):
ti = create_task_instance(
task_id="test_ti_update_state_when_ti_has_higher_tries_than_retries",
state=State.RUNNING,
)
# two maximum tries defined, but third try going on
ti.max_tries = 2
ti.try_number = 3
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": TerminalTIState.FAILED,
"end_date": DEFAULT_END_DATE.isoformat(),
},
tih = (
session.query(TaskInstanceHistory)
.where(TaskInstanceHistory.task_id == ti.task_id, TaskInstanceHistory.task_instance_id == ti.id)
.one()
)
assert tih.try_id
assert tih.try_id != ti.try_id

assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
# all retries exhausted, marking as failed
assert ti.state == State.FAILED
assert ti.next_method is None
assert ti.next_kwargs is None

def test_ti_update_state_to_failed_without_retry_table_check(self, client, session, create_task_instance):
def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
# we just want to fail in this test, no need to retry
ti = create_task_instance(
task_id="test_ti_update_state_to_failed_table_check",
Expand All @@ -810,7 +757,7 @@ def test_ti_update_state_to_failed_without_retry_table_check(self, client, sessi
response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": TerminalTIState.FAIL_WITHOUT_RETRY,
"state": TerminalTIState.FAILED,
"end_date": DEFAULT_END_DATE.isoformat(),
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def _make_context(
),
task_reschedule_count=task_reschedule_count,
max_tries=0,
should_retry=False,
)

return _make_context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,7 @@ def mock_task_id(dag_id, task_id, try_number, logical_date, map_index):
),
task_reschedule_count=0,
max_tries=1,
should_retry=False,
),
start_date=dt.datetime(2023, 1, 1, 13, 1, 1),
)
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
"run_after": "2021-01-01T00:00:00Z",
},
"max_tries": 0,
"should_retry": False,
},
)
return httpx.Response(200, json={"text": "Hello, world!"})
Expand Down
Loading
Loading