Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASE_NOTES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ Bug Fixes
- Fix CLI export to handle stdout without file descriptors (#50328)
- Fix ``DagProcessor`` stats log to show the correct parse duration (#50316)
- Fix OpenAPI schema for ``get_log`` API (#50547)
- Remove ``logical_date`` check when validating inlets and outlets (#51464)
- Guard ``ti`` update state and set task to fail if exception encountered (#51295)

Miscellaneous
"""""""""""""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
from collections import defaultdict
from collections.abc import Iterator
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any
from uuid import UUID

Expand Down Expand Up @@ -55,7 +56,7 @@
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException, TaskNotFound
from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun as DR
Expand All @@ -70,6 +71,8 @@
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState

if TYPE_CHECKING:
from sqlalchemy.sql.dml import Update

from airflow.sdk.types import Operator


Expand Down Expand Up @@ -381,43 +384,74 @@ def ti_update_state(

# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_patch_payload.model_dump(exclude={"task_outlets", "outlet_events"}, exclude_unset=True)

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, TITerminalStatePayload):
updated_state = ti_patch_payload.state
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=updated_state)
try:
query, updated_state = _create_ti_state_update_query_and_update_state(
ti_patch_payload=ti_patch_payload,
ti_id_str=ti_id_str,
session=session,
query=query,
updated_state=updated_state,
dag_id=dag_id,
dag_bag=dag_bag,
)
except Exception:
# Set the task to failed in case an unexpected exception happens during the task state update
log.exception("Error updating Task Instance state to %s. Set the task to failed", updated_state)
ti = session.get(TI, ti_id_str)
query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
query = query.values(state=TaskInstanceState.FAILED)
_handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)

if updated_state == TerminalTIState.FAILED:
ti = session.get(TI, ti_id_str)
ser_dag = dag_bag.get_dag(dag_id)
if ser_dag and getattr(ser_dag, "fail_fast", False):
task_dict = getattr(ser_dag, "task_dict")
task_teardown_map = {k: v.is_teardown for k, v in task_dict.items()}
_stop_remaining_tasks(task_instance=ti, task_teardown_map=task_teardown_map, session=session)

elif isinstance(ti_patch_payload, TIRetryStatePayload):
# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
try:
result = session.execute(query)
log.info("Task instance state updated", new_state=updated_state, rows_affected=result.rowcount)
except SQLAlchemyError as e:
log.error("Error updating Task Instance state", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
)


def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, dag_bag: DagBagDep) -> None:
    """Stop the remaining tasks of *dag_id* if its serialized DAG opts into fail-fast.

    Looks up the serialized DAG from ``dag_bag``; when it exists and has
    ``fail_fast`` enabled, builds a task-id -> is_teardown map from its
    ``task_dict`` and delegates to ``_stop_remaining_tasks``. Otherwise a no-op.
    """
    serialized_dag = dag_bag.get_dag(dag_id)
    # Guard clause: nothing to do unless the DAG is known and fail-fast is set.
    if not serialized_dag or not getattr(serialized_dag, "fail_fast", False):
        return
    teardown_by_task_id = {
        task_id: task.is_teardown for task_id, task in getattr(serialized_dag, "task_dict").items()
    }
    _stop_remaining_tasks(task_instance=ti, task_teardown_map=teardown_by_task_id, session=session)


def _create_ti_state_update_query_and_update_state(
*,
ti_patch_payload: TIStateUpdate,
ti_id_str: str,
query: Update,
updated_state,
session: SessionDep,
dag_bag: DagBagDep,
dag_id: str,
) -> tuple[Update, TaskInstanceState]:
if isinstance(ti_patch_payload, (TITerminalStatePayload, TIRetryStatePayload, TISuccessStatePayload)):
ti = session.get(TI, ti_id_str)
updated_state = ti_patch_payload.state
ti.prepare_db_for_next_try(session)
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TISuccessStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
updated_state = ti_patch_payload.state
task_instance = session.get(TI, ti_id_str)
try:

if updated_state == TerminalTIState.FAILED:
# This is the only case that needs extra handling for TITerminalStatePayload
_handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
elif isinstance(ti_patch_payload, TIRetryStatePayload):
ti.prepare_db_for_next_try(session)
elif isinstance(ti_patch_payload, TISuccessStatePayload):
TI.register_asset_changes_in_db(
task_instance,
ti,
ti_patch_payload.task_outlets, # type: ignore
ti_patch_payload.outlet_events,
session,
)
except AirflowInactiveAssetInInletOrOutletException as err:
log.error("Asset registration failed due to conflicting asset: %s", err)

query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down Expand Up @@ -468,14 +502,17 @@ def ti_update_state(
# As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
_MYSQL_TIMESTAMP_MAX = timezone.datetime(2038, 1, 19, 3, 14, 7)
if ti_patch_payload.reschedule_date > _MYSQL_TIMESTAMP_MAX:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail={
"reason": "invalid_reschedule_date",
"message": f"Cannot reschedule to {ti_patch_payload.reschedule_date.isoformat()} "
f"since it is over MySQL's TIMESTAMP storage limit.",
},
# Set the task to failed in case an unexpected exception happens during the task state update
log.exception(
"Error updating Task Instance state to %s. Set the task to failed", updated_state
)
data = ti_patch_payload.model_dump(exclude={"reschedule_date"}, exclude_unset=True)
query = update(TI).where(TI.id == ti_id_str).values(data)
query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
query = query.values(state=TaskInstanceState.FAILED)
ti = session.get(TI, ti_id_str)
_handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
return query, updated_state

task_instance = session.get(TI, ti_id_str)
actual_start_date = timezone.utcnow()
Expand All @@ -494,16 +531,10 @@ def ti_update_state(
# clear the next_method and next_kwargs so that none of the retries pick them up
query = query.values(state=TaskInstanceState.UP_FOR_RESCHEDULE, next_method=None, next_kwargs=None)
updated_state = TaskInstanceState.UP_FOR_RESCHEDULE
# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
try:
result = session.execute(query)
log.info("Task instance state updated", new_state=updated_state, rows_affected=result.rowcount)
except SQLAlchemyError as e:
log.error("Error updating Task Instance state", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
)
else:
raise ValueError(f"Unexpected Payload Type {type(ti_patch_payload)}")

return query, updated_state


@ti_id_router.patch(
Expand Down Expand Up @@ -868,7 +899,7 @@ def validate_inlets_and_outlets(
bind_contextvars(ti_id=ti_id_str)

ti = session.scalar(select(TI).where(TI.id == ti_id_str))
if not ti or not ti.logical_date:
if not ti:
log.error("Task Instance not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,40 @@ def test_ti_update_state_to_success_with_asset_events(
assert event[0].asset == AssetModel(name="my-task", uri="s3://bucket/my-task", extra={})
assert event[0].extra == expected_extra

def test_ti_update_state_to_failed_with_inactive_asset(self, client, session, create_task_instance):
# inactive
asset = AssetModel(
id=1,
name="my-task-2",
uri="s3://bucket/my-task",
group="asset",
extra={},
)
session.add(asset)

ti = create_task_instance(
task_id="test_ti_update_state_to_success_with_asset_events",
start_date=DEFAULT_START_DATE,
state=State.RUNNING,
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": "success",
"end_date": DEFAULT_END_DATE.isoformat(),
"task_outlets": [{"name": "my-task-2", "uri": "s3://bucket/my-task", "type": "Asset"}],
"outlet_events": [],
},
)

assert response.status_code == 204
session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == State.FAILED

@pytest.mark.parametrize(
"outlet_events, expected_extra",
[
Expand Down Expand Up @@ -976,8 +1010,13 @@ def test_ti_update_state_reschedule_mysql_limit(
},
)

assert response.status_code == 422
assert response.json()["detail"]["reason"] == "invalid_reschedule_date"
assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == State.FAILED

def test_ti_update_state_handle_retry(self, client, session, create_task_instance):
ti = create_task_instance(
Expand Down Expand Up @@ -2142,7 +2181,14 @@ def add_one(x):


class TestInvactiveInletsAndOutlets:
def test_ti_inactive_inlets_and_outlets(self, client, dag_maker):
@pytest.mark.parametrize(
"logical_date",
[
datetime(2025, 6, 6, tzinfo=timezone.utc),
None,
],
)
def test_ti_inactive_inlets_and_outlets(self, logical_date, client, dag_maker):
"""Test the inactive assets in inlets and outlets can be found."""
with dag_maker("test_inlets_and_outlets"):
EmptyOperator(
Expand All @@ -2154,7 +2200,7 @@ def test_ti_inactive_inlets_and_outlets(self, client, dag_maker):
],
)

dr = dag_maker.create_dagrun()
dr = dag_maker.create_dagrun(logical_date=logical_date)

task1_ti = dr.get_task_instance("task1")
response = client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
Expand All @@ -2175,7 +2221,14 @@ def test_ti_inactive_inlets_and_outlets(self, client, dag_maker):
for asset in expected_inactive_assets:
assert asset in inactive_assets

def test_ti_inactive_inlets_and_outlets_without_inactive_assets(self, client, dag_maker):
@pytest.mark.parametrize(
"logical_date",
[
datetime(2025, 6, 6, tzinfo=timezone.utc),
None,
],
)
def test_ti_inactive_inlets_and_outlets_without_inactive_assets(self, logical_date, client, dag_maker):
"""Test the task without inactive assets in its inlets or outlets returns empty list."""
with dag_maker("test_inlets_and_outlets_inactive"):
EmptyOperator(
Expand All @@ -2184,7 +2237,7 @@ def test_ti_inactive_inlets_and_outlets_without_inactive_assets(self, client, da
outlets=[Asset(name="outlet-name", uri="uri")],
)

dr = dag_maker.create_dagrun()
dr = dag_maker.create_dagrun(logical_date=logical_date)

task1_ti = dr.get_task_instance("inactive_task1")
response = client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
Expand Down
7 changes: 7 additions & 0 deletions dev/breeze/src/airflow_breeze/utils/path_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def search_upwards_for_airflow_root_path(start_from: Path) -> Path | None:
airflow_candidate_init_py = directory / "airflow-core" / "src" / "airflow" / "__init__.py"
if airflow_candidate_init_py.exists() and "airflow" in airflow_candidate_init_py.read_text().lower():
return directory
airflow_2_candidate_init_py = directory / "airflow" / "__init__.py"
if (
airflow_2_candidate_init_py.exists()
and "airflow" in airflow_2_candidate_init_py.read_text().lower()
and directory.parent.name != "src"
):
return directory
directory = directory.parent
return None

Expand Down
2 changes: 1 addition & 1 deletion providers/google/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ PIP package Version required
``python-slugify`` ``>=7.0.0``
``PyOpenSSL`` ``>=23.0.0``
``sqlalchemy-bigquery`` ``>=1.2.1``
``sqlalchemy-spanner`` ``>=1.6.2``
``sqlalchemy-spanner`` ``>=1.6.2,!=1.12.0``
``tenacity`` ``>=8.1.0``
``immutabledict`` ``>=4.2.0``
``types-protobuf`` ``!=5.29.1.20250402``
Expand Down
2 changes: 1 addition & 1 deletion providers/google/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ PIP package Version required
``python-slugify`` ``>=7.0.0``
``PyOpenSSL`` ``>=23.0.0``
``sqlalchemy-bigquery`` ``>=1.2.1``
``sqlalchemy-spanner`` ``>=1.6.2``
``sqlalchemy-spanner`` ``>=1.6.2,!=1.12.0``
``tenacity`` ``>=8.1.0``
``immutabledict`` ``>=4.2.0``
``types-protobuf`` ``!=5.29.1.20250402``
Expand Down
2 changes: 1 addition & 1 deletion providers/google/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ dependencies = [
"python-slugify>=7.0.0",
"PyOpenSSL>=23.0.0",
"sqlalchemy-bigquery>=1.2.1",
"sqlalchemy-spanner>=1.6.2",
"sqlalchemy-spanner>=1.6.2,!=1.12.0",
"tenacity>=8.1.0",
"immutabledict>=4.2.0",
# types-protobuf 5.29.1.20250402 is a partial stub package, leading to mypy complaining
Expand Down
4 changes: 2 additions & 2 deletions reproducible_build.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
release-notes-hash: bf463650df899ba29246022076670c40
source-date-epoch: 1748962238
release-notes-hash: edb6987ad849473a219f71b63e369800
source-date-epoch: 1749202198
3 changes: 2 additions & 1 deletion scripts/ci/testing/run_unit_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export COLOR_RED=$'\e[31m'
export COLOR_BLUE=$'\e[34m'
export COLOR_YELLOW=$'\e[33m'
export COLOR_RESET=$'\e[0m'
export COLOR_GREEN=$'\e[32m'

if [[ ! "$#" -eq 2 ]]; then
echo "${COLOR_RED}You must provide 2 arguments: Group, Scope!.${COLOR_RESET}"
Expand Down Expand Up @@ -107,7 +108,7 @@ function providers_tests() {
echo
exit "${RESULT}"
fi
echo "${COLOR_GREEB}Providers tests completed successfully${COLOR_RESET}"
echo "${COLOR_GREEN}Providers tests completed successfully${COLOR_RESET}"
}


Expand Down
Loading