Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
from collections import defaultdict
from collections.abc import Iterator
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any
from uuid import UUID

Expand Down Expand Up @@ -55,7 +56,7 @@
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException, TaskNotFound
from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun as DR
Expand All @@ -70,6 +71,8 @@
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState

if TYPE_CHECKING:
from sqlalchemy.sql.dml import Update

from airflow.sdk.types import Operator


Expand Down Expand Up @@ -381,43 +384,74 @@ def ti_update_state(

# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_patch_payload.model_dump(exclude={"task_outlets", "outlet_events"}, exclude_unset=True)

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, TITerminalStatePayload):
updated_state = ti_patch_payload.state
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=updated_state)
try:
query, updated_state = _create_ti_state_update_query_and_update_state(
ti_patch_payload=ti_patch_payload,
ti_id_str=ti_id_str,
session=session,
query=query,
updated_state=updated_state,
dag_id=dag_id,
dag_bag=dag_bag,
)
except Exception:
# Set a task to failed in case any unexpected exception happened during task state update
log.exception("Error updating Task Instance state to %s. Set the task to failed", updated_state)
ti = session.get(TI, ti_id_str)
query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
query = query.values(state=TaskInstanceState.FAILED)
_handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)

if updated_state == TerminalTIState.FAILED:
ti = session.get(TI, ti_id_str)
ser_dag = dag_bag.get_dag(dag_id)
if ser_dag and getattr(ser_dag, "fail_fast", False):
task_dict = getattr(ser_dag, "task_dict")
task_teardown_map = {k: v.is_teardown for k, v in task_dict.items()}
_stop_remaining_tasks(task_instance=ti, task_teardown_map=task_teardown_map, session=session)

elif isinstance(ti_patch_payload, TIRetryStatePayload):
# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
try:
result = session.execute(query)
log.info("Task instance state updated", new_state=updated_state, rows_affected=result.rowcount)
except SQLAlchemyError as e:
log.error("Error updating Task Instance state", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
)


def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, dag_bag: DagBagDep) -> None:
ser_dag = dag_bag.get_dag(dag_id)
if ser_dag and getattr(ser_dag, "fail_fast", False):
task_dict = getattr(ser_dag, "task_dict")
task_teardown_map = {k: v.is_teardown for k, v in task_dict.items()}
_stop_remaining_tasks(task_instance=ti, task_teardown_map=task_teardown_map, session=session)


def _create_ti_state_update_query_and_update_state(
*,
ti_patch_payload: TIStateUpdate,
ti_id_str: str,
query: Update,
updated_state,
session: SessionDep,
dag_bag: DagBagDep,
dag_id: str,
) -> tuple[Update, TaskInstanceState]:
if isinstance(ti_patch_payload, (TITerminalStatePayload, TIRetryStatePayload, TISuccessStatePayload)):
ti = session.get(TI, ti_id_str)
updated_state = ti_patch_payload.state
ti.prepare_db_for_next_try(session)
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TISuccessStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
updated_state = ti_patch_payload.state
task_instance = session.get(TI, ti_id_str)
try:

if updated_state == TerminalTIState.FAILED:
# This is the only case needs extra handling for TITerminalStatePayload
_handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
elif isinstance(ti_patch_payload, TIRetryStatePayload):
ti.prepare_db_for_next_try(session)
elif isinstance(ti_patch_payload, TISuccessStatePayload):
TI.register_asset_changes_in_db(
task_instance,
ti,
ti_patch_payload.task_outlets, # type: ignore
ti_patch_payload.outlet_events,
session,
)
except AirflowInactiveAssetInInletOrOutletException as err:
log.error("Asset registration failed due to conflicting asset: %s", err)

query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down Expand Up @@ -468,14 +502,17 @@ def ti_update_state(
# As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
_MYSQL_TIMESTAMP_MAX = timezone.datetime(2038, 1, 19, 3, 14, 7)
if ti_patch_payload.reschedule_date > _MYSQL_TIMESTAMP_MAX:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail={
"reason": "invalid_reschedule_date",
"message": f"Cannot reschedule to {ti_patch_payload.reschedule_date.isoformat()} "
f"since it is over MySQL's TIMESTAMP storage limit.",
},
# Set a task to failed in case any unexpected exception happened during task state update
log.exception(
"Error updating Task Instance state to %s. Set the task to failed", updated_state
)
data = ti_patch_payload.model_dump(exclude={"reschedule_date"}, exclude_unset=True)
query = update(TI).where(TI.id == ti_id_str).values(data)
query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
query = query.values(state=TaskInstanceState.FAILED)
ti = session.get(TI, ti_id_str)
_handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
return query, updated_state

task_instance = session.get(TI, ti_id_str)
actual_start_date = timezone.utcnow()
Expand All @@ -494,16 +531,10 @@ def ti_update_state(
# clear the next_method and next_kwargs so that none of the retries pick them up
query = query.values(state=TaskInstanceState.UP_FOR_RESCHEDULE, next_method=None, next_kwargs=None)
updated_state = TaskInstanceState.UP_FOR_RESCHEDULE
# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
try:
result = session.execute(query)
log.info("Task instance state updated", new_state=updated_state, rows_affected=result.rowcount)
except SQLAlchemyError as e:
log.error("Error updating Task Instance state", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
)
else:
raise ValueError(f"Unexpected Payload Type {type(ti_patch_payload)}")

return query, updated_state


@ti_id_router.patch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,40 @@ def test_ti_update_state_to_success_with_asset_events(
assert event[0].asset == AssetModel(name="my-task", uri="s3://bucket/my-task", extra={})
assert event[0].extra == expected_extra

def test_ti_update_state_to_failed_with_inactive_asset(self, client, session, create_task_instance):
# inactive
asset = AssetModel(
id=1,
name="my-task-2",
uri="s3://bucket/my-task",
group="asset",
extra={},
)
session.add(asset)

ti = create_task_instance(
task_id="test_ti_update_state_to_success_with_asset_events",
start_date=DEFAULT_START_DATE,
state=State.RUNNING,
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": "success",
"end_date": DEFAULT_END_DATE.isoformat(),
"task_outlets": [{"name": "my-task-2", "uri": "s3://bucket/my-task", "type": "Asset"}],
"outlet_events": [],
},
)

assert response.status_code == 204
session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == State.FAILED

@pytest.mark.parametrize(
"outlet_events, expected_extra",
[
Expand Down Expand Up @@ -976,8 +1010,13 @@ def test_ti_update_state_reschedule_mysql_limit(
},
)

assert response.status_code == 422
assert response.json()["detail"]["reason"] == "invalid_reschedule_date"
assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == State.FAILED

def test_ti_update_state_handle_retry(self, client, session, create_task_instance):
ti = create_task_instance(
Expand Down
8 changes: 4 additions & 4 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,6 @@ def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events, re
)
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())

def defer(self, id: uuid.UUID, msg):
"""Tell the API server that this TI has been deferred."""
body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))
Expand All @@ -192,6 +188,10 @@ def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
# Create a reschedule state payload from msg
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())

def skip_downstream_tasks(self, id: uuid.UUID, msg: SkipDownstreamTasks):
"""Tell the API server to skip the downstream tasks of this TI."""
body = TISkippedDownstreamTasksStatePayload(tasks=msg.tasks)
Expand Down