Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,33 @@ class TaskInstance(StrictBaseModel):
hostname: str | None = None


class AssetReferenceAssetEventDagRun(StrictBaseModel):
"""Schema for AssetModel used in AssetEventDagRunReference."""

name: str
uri: str
extra: dict


class AssetAliasReferenceAssetEventDagRun(StrictBaseModel):
"""Schema for AssetAliasModel used in AssetEventDagRunReference."""

name: str


class AssetEventDagRunReference(StrictBaseModel):
"""Schema for AssetEvent model used in DagRun."""

asset: AssetReferenceAssetEventDagRun
extra: dict
source_task_id: str | None
source_dag_id: str | None
source_run_id: str | None
source_map_index: int | None
source_aliases: list[AssetAliasReferenceAssetEventDagRun]
timestamp: UtcDateTime


class DagRun(StrictBaseModel):
"""Schema for DagRun model with minimal required fields needed for Runtime."""

Expand All @@ -261,6 +288,7 @@ class DagRun(StrictBaseModel):
clear_number: int = 0
run_type: DagRunType
conf: Annotated[dict[str, Any], Field(default_factory=dict)]
consumed_asset_events: list[AssetEventDagRunReference]


class TIRunContext(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pydantic import JsonValue
from sqlalchemy import func, tuple_, update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import select

from airflow.api_fastapi.common.db.common import SessionDep
Expand Down Expand Up @@ -173,21 +174,15 @@ def ti_run(
result = session.execute(query)
log.info("TI %s state updated: %s row(s) affected", ti_id_str, result.rowcount)

dr = session.execute(
select(
DR.run_id,
DR.dag_id,
DR.data_interval_start,
DR.data_interval_end,
DR.run_after,
DR.start_date,
DR.end_date,
DR.clear_number,
DR.run_type,
DR.conf,
DR.logical_date,
).filter_by(dag_id=ti.dag_id, run_id=ti.run_id)
).one_or_none()
dr = (
session.scalars(
select(DR)
.filter_by(dag_id=ti.dag_id, run_id=ti.run_id)
.options(joinedload(DR.consumed_asset_events))
)
.unique()
.one_or_none()
)

if not dr:
raise ValueError(f"DagRun with dag_id={ti.dag_id} and run_id={ti.run_id} not found.")
Expand Down Expand Up @@ -236,8 +231,8 @@ def ti_run(
context.next_kwargs = ti.next_kwargs

return context
except SQLAlchemyError as e:
log.error("Error marking Task Instance state as running: %s", e)
except SQLAlchemyError:
log.exception("Error marking Task Instance state as running")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
)
Expand Down
12 changes: 6 additions & 6 deletions airflow-core/src/airflow/models/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,16 +739,16 @@ class AssetEvent(Base):
)

@property
def uri(self):
return self.asset.uri
def name(self) -> str:
return self.asset.name

@property
def group(self):
return self.asset.group
def uri(self) -> str:
return self.asset.uri

@property
def name(self):
return self.asset.name
def group(self) -> str:
return self.asset.group

def __repr__(self) -> str:
args = []
Expand Down
31 changes: 20 additions & 11 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,10 +883,28 @@ def _get_template_context(
assert task_instance.task
assert task
assert task.dag
assert session

dag_run = task_instance.get_dagrun(session)
validated_params = process_params(dag, task, dag_run.conf, suppress_exception=ignore_param_exceptions)
def _get_dagrun(session: Session) -> DagRun:
dag_run = task_instance.get_dagrun(session)
if dag_run in session:
return dag_run
# The dag_run may not be attached to the session anymore since the
# code base is over-zealous with use of session.expunge_all().
# Re-attach it if the relation is not loaded so we can load it when needed.
info = inspect(dag_run)
if info.attrs.consumed_asset_events.loaded_value is not NO_VALUE:
return dag_run
# If dag_run is not flushed to db at all (e.g. CLI commands using
# in-memory objects for ad-hoc operations), just set the value manually.
if not info.has_identity:
dag_run.consumed_asset_events = []
return dag_run
return session.merge(dag_run, load=False)

dag_run = _get_dagrun(session)

validated_params = process_params(dag, task, dag_run.conf, suppress_exception=ignore_param_exceptions)
ti_context_from_server = TIRunContext(
dag_run=DagRunSDK.model_validate(dag_run, from_attributes=True),
max_tries=task_instance.max_tries,
Expand Down Expand Up @@ -916,15 +934,6 @@ def get_prev_end_date_success() -> pendulum.DateTime | None:
return timezone.coerce_datetime(_get_previous_dagrun_success().end_date)

def get_triggering_events() -> dict[str, list[AssetEvent]]:
if TYPE_CHECKING:
assert session is not None

# The dag_run may not be attached to the session anymore since the
# code base is over-zealous with use of session.expunge_all().
# Re-attach it if we get called.
nonlocal dag_run
if dag_run not in session:
dag_run = session.merge(dag_run, load=False)
asset_events = dag_run.consumed_asset_events
triggering_events: dict[str, list[AssetEvent]] = defaultdict(list)
for event in asset_events:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_ti_run_state_to_running(
"end_date": None,
"run_type": "manual",
"conf": {},
"consumed_asset_events": [],
},
"task_reschedule_count": 0,
"max_tries": max_tries,
Expand Down
1 change: 1 addition & 0 deletions devel-common/src/tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2078,6 +2078,7 @@ def _create_task_instance(
run_type=run_type, # type: ignore
run_after=run_after, # type: ignore
conf=conf,
consumed_asset_events=[],
),
task_reschedule_count=task_reschedule_count,
max_tries=task_retries if max_tries is None else max_tries,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,18 @@
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
from collections.abc import Sequence

try:
from airflow.sdk.api.datamodels._generated import TIRunContext
from airflow.sdk.api.datamodels._generated import AssetEventDagRunReference, TIRunContext
from airflow.sdk.definitions.context import Context

except ImportError:
# TODO: Remove once provider drops support for Airflow 2
# TIRunContext is only used in Airflow 3 tests
from airflow.utils.context import Context

TIRunContext = Any # type: ignore[misc, assignment]
AssetEventDagRunReference = TIRunContext = Any # type: ignore[misc, assignment]


if AIRFLOW_V_2_10_PLUS:
Expand Down Expand Up @@ -438,6 +440,7 @@ def __call__(
run_type: str = ...,
task_reschedule_count: int = ...,
conf: dict[str, Any] | None = ...,
consumed_asset_events: Sequence[AssetEventDagRunReference] = ...,
) -> TIRunContext: ...


Expand All @@ -459,6 +462,7 @@ def _make_context(
run_type: str = "manual",
task_reschedule_count: int = 0,
conf=None,
consumed_asset_events: Sequence[AssetEventDagRunReference] = (),
) -> TIRunContext:
return TIRunContext(
dag_run=DagRun(
Expand All @@ -472,6 +476,7 @@ def _make_context(
run_type=run_type, # type: ignore
run_after=run_after, # type: ignore
conf=conf, # type: ignore
consumed_asset_events=list(consumed_asset_events),
),
task_reschedule_count=task_reschedule_count,
max_tries=0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,7 @@ def mock_task_id(dag_id, task_id, try_number, logical_date, map_index):
run_type=DagRunType.MANUAL,
run_after=timezone.datetime(2023, 1, 3, 13, 1, 1),
conf=None,
consumed_asset_events=[],
),
task_reschedule_count=0,
max_tries=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,9 @@ def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]:
raise ValueError("'context' is not assigned a dictionary literal")
yield from extract_keys_from_dict(context_assignment.value)

# Handle keys added conditionally in `if x := self._ti_context_from_server`
# Handle keys added conditionally in `if from_server`
for stmt in fn_get_template_context.body:
if (
isinstance(stmt, ast.If)
and isinstance(stmt.test, ast.NamedExpr)
and isinstance(stmt.test.value, ast.Attribute)
and stmt.test.value.attr == "_ti_context_from_server"
):
if isinstance(stmt, ast.If) and isinstance(stmt.test, ast.Name) and stmt.test.id == "from_server":
for sub_stmt in stmt.body:
# Get keys from `context_from_server` assignment
if (
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
"start_date": "2021-01-01T00:00:00Z",
"run_type": DagRunType.MANUAL,
"run_after": "2021-01-01T00:00:00Z",
"consumed_asset_events": [],
},
"max_tries": 0,
"should_retry": False,
Expand Down
43 changes: 43 additions & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@
API_VERSION: Final[str] = "2025-03-26"


class AssetAliasReferenceAssetEventDagRun(BaseModel):
"""
Schema for AssetAliasModel used in AssetEventDagRunReference.
"""

model_config = ConfigDict(
extra="forbid",
)
name: Annotated[str, Field(title="Name")]


class AssetProfile(BaseModel):
"""
Profile of an asset-like object.
Expand All @@ -52,6 +63,19 @@ class AssetProfile(BaseModel):
type: Annotated[str, Field(title="Type")]


class AssetReferenceAssetEventDagRun(BaseModel):
"""
Schema for AssetModel used in AssetEventDagRunReference.
"""

model_config = ConfigDict(
extra="forbid",
)
name: Annotated[str, Field(title="Name")]
uri: Annotated[str, Field(title="Uri")]
extra: Annotated[dict[str, Any], Field(title="Extra")]


class AssetResponse(BaseModel):
"""
Asset schema for responses with fields that are needed for Runtime.
Expand Down Expand Up @@ -354,6 +378,24 @@ class TerminalTIState(str, Enum):
REMOVED = "removed"


class AssetEventDagRunReference(BaseModel):
"""
Schema for AssetEvent model used in DagRun.
"""

model_config = ConfigDict(
extra="forbid",
)
asset: AssetReferenceAssetEventDagRun
extra: Annotated[dict[str, Any], Field(title="Extra")]
source_task_id: Annotated[str | None, Field(title="Source Task Id")] = None
source_dag_id: Annotated[str | None, Field(title="Source Dag Id")] = None
source_run_id: Annotated[str | None, Field(title="Source Run Id")] = None
source_map_index: Annotated[int | None, Field(title="Source Map Index")] = None
source_aliases: Annotated[list[AssetAliasReferenceAssetEventDagRun], Field(title="Source Aliases")]
timestamp: Annotated[AwareDatetime, Field(title="Timestamp")]


class AssetEventResponse(BaseModel):
"""
Asset event schema with fields that are needed for Runtime.
Expand Down Expand Up @@ -397,6 +439,7 @@ class DagRun(BaseModel):
clear_number: Annotated[int | None, Field(title="Clear Number")] = 0
run_type: DagRunType
conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None
consumed_asset_events: Annotated[list[AssetEventDagRunReference], Field(title="Consumed Asset Events")]


class HTTPValidationError(BaseModel):
Expand Down
Loading
Loading