Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRetryStatePayload,
TIRunContext,
TISkippedDownstreamTasksStatePayload,
TISuccessStatePayload,
Expand Down Expand Up @@ -152,6 +153,11 @@ def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime):
body = TITerminalStatePayload(end_date=when, state=TerminalStateNonSuccess(state))
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def retry(self, id: uuid.UUID, end_date: datetime):
"""Tell the API server that this TI has failed and reached a up_for_retry state."""
body = TIRetryStatePayload(end_date=end_date)
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events):
"""Tell the API server that this TI has succeeded."""
body = TISuccessStatePayload(end_date=when, task_outlets=task_outlets, outlet_events=outlet_events)
Expand Down
9 changes: 9 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
RetryTask,
SetRenderedFields,
SetXCom,
SkipDownstreamTasks,
Expand Down Expand Up @@ -125,6 +126,7 @@
STATES_SENT_DIRECTLY = [
IntermediateTIState.DEFERRED,
IntermediateTIState.UP_FOR_RESCHEDULE,
IntermediateTIState.UP_FOR_RETRY,
TerminalTIState.SUCCESS,
]

Expand Down Expand Up @@ -913,6 +915,13 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
task_outlets=msg.task_outlets,
outlet_events=msg.outlet_events,
)
elif isinstance(msg, RetryTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self.client.task_instances.retry(
id=self.id,
end_date=msg.end_date,
)
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
if isinstance(conn, ConnectionResponse):
Expand Down
16 changes: 16 additions & 0 deletions task-sdk/tests/task_sdk/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,22 @@ def handle_request(request: httpx.Request) -> httpx.Response:
)
client.task_instances.reschedule(ti_id, msg)

def test_task_instance_up_for_retry(self):
ti_id = uuid6.uuid7()

def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti_id}/state":
actual_body = json.loads(request.read())
assert actual_body["state"] == "up_for_retry"
assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
return httpx.Response(
status_code=204,
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
client.task_instances.retry(ti_id, end_date=timezone.parse("2024-10-31T12:00:00Z"))

@pytest.mark.parametrize(
"rendered_fields",
[
Expand Down
10 changes: 10 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
RetryTask,
SetRenderedFields,
SetXCom,
SucceedTask,
Expand Down Expand Up @@ -1165,6 +1166,15 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_skipped",
),
pytest.param(
RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
b"",
"task_instances.retry",
(),
{"id": TI_ID, "end_date": timezone.parse("2024-10-31T12:00:00Z")},
"",
id="up_for_retry",
),
pytest.param(
SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}),
b"",
Expand Down
Loading