Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 67 additions & 2 deletions airflow-core/src/airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Annotated, Literal
from collections.abc import Mapping
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast

from pydantic import BaseModel, Field
import structlog
from pydantic import BaseModel, Field, model_validator
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.orm.exc import DetachedInstanceError

from airflow.api_fastapi.execution_api.datamodels import taskinstance as ti_datamodel # noqa: TC001
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.typing_compat import Self

log = structlog.get_logger(logger_name=__name__)


class BaseCallbackRequest(BaseModel):
"""
Expand Down Expand Up @@ -95,6 +103,63 @@ class DagRunContext(BaseModel):
dag_run: ti_datamodel.DagRun | None = None
last_ti: ti_datamodel.TaskInstance | None = None

@model_validator(mode="before")
@classmethod
def _sanitize_consumed_asset_events(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
    """
    Make ``dag_run.consumed_asset_events`` safe to read before field validation.

    When ``dag_run`` is an ORM-mapped instance, the relationship is read once and
    installed on the instance as a plain list via ``set_committed_value``. If the
    read raises ``DetachedInstanceError``, the DagRun is re-queried by id with the
    relationship eager-loaded and the reloaded events are installed instead.

    :param values: Raw input handed to the model before validation.
    :return: ``values`` itself; only the ``dag_run`` object it references may be
        mutated in place.
    """
    # A "before" validator sees the raw input, which is not guaranteed to be a
    # mapping. Anything else is passed through for pydantic to handle.
    if not isinstance(values, Mapping):
        return values

    if (dag_run := values.get("dag_run")) is None:
        return values

    # DagRunContext may receive non-ORM dag_run objects (e.g. datamodels).
    # Only apply this validator to ORM-mapped instances.
    try:
        sa_inspect(dag_run)
    except NoInspectionAvailable:
        return values

    # Relationship access may raise DetachedInstanceError; on that path, reload DagRun
    # from the DB to avoid crashing the scheduler.
    try:
        events = dag_run.consumed_asset_events
        safe_events = list(events) if events is not None else []
    except DetachedInstanceError:
        log.warning(
            "DagRunContext encountered DetachedInstanceError while accessing "
            "consumed_asset_events; reloading DagRun from DB."
        )
        # Imported lazily so the ORM models are only pulled in on the recovery path.
        from sqlalchemy import select
        from sqlalchemy.orm import selectinload

        from airflow.models.asset import AssetEvent
        from airflow.models.dagrun import DagRun
        from airflow.utils.session import create_session

        # Defensive guardrail: reload DagRun with eager-loaded relationships on
        # DetachedInstanceError to recover state without adding DB I/O to the hot path.
        safe_events = []
        with create_session() as session:
            dag_run_reloaded = session.scalar(
                select(DagRun)
                .where(DagRun.id == dag_run.id)
                .options(
                    selectinload(DagRun.consumed_asset_events).selectinload(AssetEvent.asset),
                    selectinload(DagRun.consumed_asset_events).selectinload(AssetEvent.source_aliases),
                )
            )

            if dag_run_reloaded is None:
                # The row may have been deleted since the instance was loaded.
                # Fall back to an empty list rather than failing on ``None``.
                log.warning(
                    "DagRunContext could not reload DagRun from DB; "
                    "treating consumed_asset_events as empty."
                )
            else:
                # Copy the collection while the session is still open so the read
                # never depends on attribute state after the session is closed.
                reloaded_events = dag_run_reloaded.consumed_asset_events
                if reloaded_events is not None:
                    safe_events = list(reloaded_events)

    # Install the materialised list as committed state on the original instance so
    # later reads do not trigger a lazy load.
    set_committed_value(dag_run, "consumed_asset_events", safe_events)

    return values


class DagCallbackRequest(BaseCallbackRequest):
"""A Class with information about the success/failure DAG callback to be executed."""
Expand Down
62 changes: 62 additions & 0 deletions airflow-core/tests/unit/callbacks/test_callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
EmailRequest,
TaskCallbackRequest,
)
from airflow.models import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.utils.state import State, TaskInstanceState
Expand Down Expand Up @@ -197,6 +198,67 @@ def test_dagrun_context_serialization(self):
assert deserialized.dag_run.dag_id == context.dag_run.dag_id
assert deserialized.last_ti.task_id == context.last_ti.task_id

def test_dagrun_context_detached_consumed_asset_events(self, session):
    """
    Building a DagRunContext from a detached DagRun must not fail.

    The detached instance is the failure mode under test: accessing
    consumed_asset_events on it is expected to raise DetachedInstanceError
    unless DagRunContext normalizes the relationship first.
    """
    # Persist a real ORM DagRun, then forcibly detach it from the session.
    detached_run = DagRun(
        dag_id="test_dag",
        run_id="test_run_detached",
        logical_date=timezone.utcnow(),
        state="running",
        run_type="manual",
    )
    session.add(detached_run)
    session.commit()
    session.expunge(detached_run)

    # The consumed_asset_events validation happens when DagRunContext is constructed.
    ctx = DagRunContext(dag_run=detached_run, last_ti=None)

    # Reading the relationship afterwards must not raise DetachedInstanceError,
    # and it should come back as a plain list.
    consumed = ctx.dag_run.consumed_asset_events
    assert consumed is not None
    assert isinstance(consumed, list)

def test_dagrun_context_attached_consumed_asset_events(self, session):
    """
    Building a DagRunContext from a session-attached DagRun must also work.

    consumed_asset_events should be readable afterwards and normalized to a list.
    """
    attached_run = DagRun(
        dag_id="test_dag",
        run_id="test_run_attached",
        logical_date=timezone.utcnow(),
        state="running",
        run_type="manual",
    )

    # Flush only: the instance stays attached to the session.
    session.add(attached_run)
    session.flush()

    # Construct the context while the DagRun is still attached.
    ctx = DagRunContext(dag_run=attached_run, last_ti=None)

    # Reading the relationship must not raise, and it should be a plain list.
    consumed = ctx.dag_run.consumed_asset_events
    assert consumed is not None
    assert isinstance(consumed, list)


class TestDagCallbackRequestWithContext:
def test_dag_callback_request_with_context_from_server(self):
Expand Down