Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
5b5a855
remove session.query and update new style from airflow-models
Prab-27 Jun 27, 2025
68971f4
Update usage of XComModel.get_many method, as it now returns a Select…
Prab-27 Jun 27, 2025
bbe9f39
fix static checks
Prab-27 Jun 28, 2025
d5958d4
remove .with_entities from models-taskinstance.py
Prab-27 Jun 30, 2025
7aa757a
fix static checks - add session for session.execute in serialized_obj.py
Prab-27 Jul 1, 2025
1eaea10
fix static checks for usage of get_many method
Prab-27 Jul 3, 2025
184cf43
fix static checks
Prab-27 Jul 5, 2025
d18a371
execute select in routes-xcom.py
Prab-27 Jul 19, 2025
29f7803
get value using session.execute() and .mappings() in taskinstance-model
Prab-27 Jul 21, 2025
34c3fb9
use .mappings() in test and key value in xcom-models
Prab-27 Jul 31, 2025
2d29d2d
fix static check value for xcom and test_backend
Prab-27 Aug 1, 2025
8d3584a
remove deprecated timezone imports
Prab-27 Aug 3, 2025
f6ed198
fix tests from test_backend and test_pod
Prab-27 Aug 3, 2025
c86f88e
execute select function from tests and fix a test for xcom-deserials…
Prab-27 Aug 4, 2025
a47f1cc
remove .mappings()
Prab-27 Aug 5, 2025
5e0c6e0
fix tests
Prab-27 Aug 6, 2025
3a3cb3c
remove .mappings() from cncf/test_pod.py
Prab-27 Aug 6, 2025
8d3c0d4
fix some checks
Prab-27 Aug 7, 2025
4d1344a
add path to pre-commit hook
Prab-27 Aug 10, 2025
e581c1d
use / in path
Prab-27 Aug 10, 2025
6a80c90
use AIRFLOW_V_3_1_PLUS version in test_backend
Prab-27 Aug 12, 2025
548361b
add version and fix changes
Prab-27 Aug 16, 2025
06e3297
Resolve conflicts
Prab-27 Aug 18, 2025
37254ee
remove session.query()
Prab-27 Aug 23, 2025
7a91031
fix tests
Prab-27 Aug 26, 2025
c4cf286
apply comment
Prab-27 Aug 28, 2025
bf1d52a
add session as attribute in DagRunWaiter class
Prab-27 Aug 29, 2025
663d6c9
resolve conflict
Prab-27 Aug 30, 2025
747c730
Remove debug print statements
Prab-27 Aug 30, 2025
0d7679a
adjust import statements and remove duplicated code
Prab-27 Sep 1, 2025
6e61285
Customize the conditions
Prab-27 Sep 1, 2025
3fb04c5
rearrange tests from test_pod.py
Prab-27 Sep 2, 2025
d81f3b4
refactored conditional logic
Prab-27 Sep 2, 2025
d7d4bfc
remove unused session arguments from xcom_query function
Prab-27 Sep 2, 2025
2e68baa
remove blank lines and add select_from
Prab-27 Sep 2, 2025
432672c
fix static checks
Prab-27 Sep 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,8 @@ repos:
files: >
(?x)
^airflow-ctl.*\.py$|
^providers/fab.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^providers/fab/.*\.py$|
^task_sdk.*\.py$
pass_filenames: true
- id: update-supported-versions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def wait_dag_run_until_finished(
run_id=dag_run_id,
interval=interval,
result_task_ids=result_task_ids,
session=session,
)
return StreamingResponse(waiter.wait())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,13 @@ def get_xcom_entry(
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
session=session,
limit=1,
)

# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database.
result = xcom_query.limit(1).first()
result = session.scalars(xcom_query).first()

if result is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"XCom entry with key: `{xcom_key}` not found")
Expand Down Expand Up @@ -249,9 +248,8 @@ def create_xcom_entry(
dag_ids=dag_id,
run_id=dag_run_id,
map_indexes=request_body.map_index,
session=session,
)
result = already_existing_query.with_entities(XComModel.value).first()
result = session.execute(already_existing_query.with_only_columns(XComModel.value)).first()
if result:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import attrs
from sqlalchemy import select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.models.dagrun import DagRun
from airflow.models.xcom import XCOM_RETURN_KEY, XComModel
from airflow.utils.session import create_session_async
Expand All @@ -43,6 +44,7 @@ class DagRunWaiter:
run_id: str
interval: float
result_task_ids: list[str] | None
session: SessionDep

async def _get_dag_run(self) -> DagRun:
async with create_session_async() as session:
Expand All @@ -55,7 +57,7 @@ def _serialize_xcoms(self) -> dict[str, Any]:
task_ids=self.result_task_ids,
dag_ids=self.dag_id,
)
xcom_query = xcom_query.order_by(XComModel.task_id, XComModel.map_index)
xcom_query = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index)).all()

def _group_xcoms(g: Iterator[XComModel]) -> Any:
entries = list(g)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ async def xcom_query(
run_id: str,
task_id: str,
key: str,
session: SessionDep,
map_index: Annotated[int | None, Query()] = None,
) -> Select:
query = XComModel.get_many(
Expand All @@ -85,7 +84,6 @@ async def xcom_query(
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
session=session,
)
return query

Expand Down Expand Up @@ -151,23 +149,22 @@ def get_xcom(
task_ids=task_id,
dag_ids=dag_id,
include_prior_dates=params.include_prior_dates,
session=session,
)
if params.offset is not None:
xcom_query = xcom_query.filter(XComModel.value.is_not(None)).order_by(None)
xcom_query = xcom_query.where(XComModel.value.is_not(None)).order_by(None)
if params.offset >= 0:
xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset)
else:
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset)
else:
xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
xcom_query = xcom_query.where(XComModel.map_index == params.map_index)

# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
# (which automatically deserializes using the backend), we avoid potential
# performance hits from retrieving large data files into the API server.
result = xcom_query.limit(1).first()
result = session.scalars(xcom_query).first()
if result is None:
if params.offset is None:
message = (
Expand Down Expand Up @@ -204,15 +201,14 @@ def get_mapped_xcom_by_index(
key=key,
task_ids=task_id,
dag_ids=dag_id,
session=session,
)
xcom_query = xcom_query.order_by(None)
if offset >= 0:
xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(offset)
else:
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset)

if (result := xcom_query.limit(1).first()) is None:
if (result := session.scalars(xcom_query).first()) is None:
message = (
f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
)
Expand Down Expand Up @@ -250,7 +246,6 @@ def get_mapped_xcom_by_slice(
task_ids=task_id,
dag_ids=dag_id,
include_prior_dates=params.include_prior_dates,
session=session,
)
query = query.order_by(None)

Expand Down Expand Up @@ -309,7 +304,7 @@ def get_mapped_xcom_by_slice(
else:
query = query.slice(-stop, -start)

values = [row.value for row in query.with_entities(XComModel.value)]
values = [row.value for row in session.execute(query.with_only_columns(XComModel.value)).all()]
if step != 1:
values = values[::step]
return XComSequenceSliceResponse(values)
Expand Down
9 changes: 4 additions & 5 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,13 +790,12 @@ def fetch_task_instances(

def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session):
"""Check if last N dags failed."""
dag_runs = (
session.query(DagRun)
.filter(DagRun.dag_id == dag_id)
dag_runs = session.scalars(
select(DagRun)
.where(DagRun.dag_id == dag_id)
.order_by(DagRun.logical_date.desc())
.limit(max_consecutive_failed_dag_runs)
.all()
)
).all()
""" Marking dag as paused, if needed"""
to_be_paused = len(dag_runs) >= max_consecutive_failed_dag_runs and all(
dag_run.state == DagRunState.FAILED for dag_run in dag_runs
Expand Down
7 changes: 4 additions & 3 deletions airflow-core/src/airflow/models/deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,10 @@ def prune_deadlines(cls, *, session: Session, conditions: dict[Column, Any]) ->

try:
# Get deadlines which match the provided conditions and their associated DagRuns.
deadline_dagrun_pairs = (
session.query(Deadline, DagRun).join(DagRun).filter(and_(*filter_conditions)).all()
)
deadline_dagrun_pairs = session.execute(
select(Deadline, DagRun).join(DagRun).where(and_(*filter_conditions))
).all()

except AttributeError as e:
logger.exception("Error resolving deadlines: %s", e)
raise
Expand Down
8 changes: 3 additions & 5 deletions airflow-core/src/airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,8 @@ def get_latest_serialized_dags(
"""
# Subquery to get the latest serdag per dag_id
latest_serdag_subquery = (
session.query(cls.dag_id, func.max(cls.created_at).label("created_at"))
.filter(cls.dag_id.in_(dag_ids))
select(cls.dag_id, func.max(cls.created_at).label("created_at"))
.where(cls.dag_id.in_(dag_ids))
.group_by(cls.dag_id)
.subquery()
)
Expand All @@ -501,9 +501,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
:returns: a dict of DAGs read from database
"""
latest_serialized_dag_subquery = (
session.query(cls.dag_id, func.max(cls.created_at).label("max_created"))
.group_by(cls.dag_id)
.subquery()
select(cls.dag_id, func.max(cls.created_at).label("max_created")).group_by(cls.dag_id).subquery()
)
serialized_dags = session.scalars(
select(cls).join(
Expand Down
49 changes: 29 additions & 20 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,16 +264,14 @@ def clear_task_instances(
for instance in tis:
run_ids_by_dag_id[instance.dag_id].add(instance.run_id)

drs = (
session.query(DagRun)
.filter(
drs = session.scalars(
select(DagRun).where(
or_(
and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
for dag_id, run_ids in run_ids_by_dag_id.items()
)
)
.all()
)
).all()
dag_run_state = DagRunState(dag_run_state) # Validate the state value.
for dr in drs:
if dr.state in State.finished_dr_states:
Expand Down Expand Up @@ -659,7 +657,7 @@ def get_task_instance(
session: Session = NEW_SESSION,
) -> TaskInstance | None:
query = (
session.query(TaskInstance)
select(TaskInstance)
.options(lazyload(TaskInstance.dag_run)) # lazy load dag run to avoid locking it
.filter_by(
dag_id=dag_id,
Expand All @@ -672,9 +670,9 @@ def get_task_instance(
if lock_for_update:
for attempt in run_with_db_retries(logger=cls.logger()):
with attempt:
return query.with_for_update().one_or_none()
return session.execute(query.with_for_update()).scalar_one_or_none()
else:
return query.one_or_none()
return session.execute(query).scalar_one_or_none()

return None

Expand Down Expand Up @@ -824,13 +822,13 @@ def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
if not task.downstream_task_ids:
return True

ti = session.query(func.count(TaskInstance.task_id)).filter(
ti = select(func.count(TaskInstance.task_id)).where(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(task.downstream_task_ids),
TaskInstance.run_id == self.run_id,
TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
)
count = ti[0][0]
count = session.scalar(ti)
return count == len(task.downstream_task_ids)

@provide_session
Expand Down Expand Up @@ -1005,7 +1003,9 @@ def ready_for_retry(self) -> bool:
def _get_dagrun(dag_id, run_id, session) -> DagRun:
from airflow.models.dagrun import DagRun # Avoid circular import

dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
dr = session.execute(
select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
).scalar_one()
return dr

@provide_session
Expand Down Expand Up @@ -1947,7 +1947,6 @@ def xcom_pull(
task_ids=task_ids,
map_indexes=map_indexes,
include_prior_dates=include_prior_dates,
session=session,
)

# NOTE: Since we're only fetching the value field and not the whole
Expand All @@ -1956,8 +1955,14 @@ def xcom_pull(

# We are only pulling one single task.
if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable):
first = query.with_entities(
XComModel.run_id, XComModel.task_id, XComModel.dag_id, XComModel.map_index, XComModel.value
first = session.execute(
query.with_only_columns(
XComModel.run_id,
XComModel.task_id,
XComModel.dag_id,
XComModel.map_index,
XComModel.value,
)
).first()
if first is None: # No matching XCom at all.
return default
Expand Down Expand Up @@ -1998,16 +2003,20 @@ def xcom_pull(
def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
"""Return Number of running TIs from the DB."""
# .count() is inefficient
num_running_task_instances_query = session.query(func.count()).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.state == TaskInstanceState.RUNNING,
num_running_task_instances_query = (
select(func.count())
.select_from(TaskInstance)
.where(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.state == TaskInstanceState.RUNNING,
)
)
if same_dagrun:
num_running_task_instances_query = num_running_task_instances_query.filter(
num_running_task_instances_query = num_running_task_instances_query.where(
TaskInstance.run_id == self.run_id
)
return num_running_task_instances_query.scalar()
return session.scalar(num_running_task_instances_query)

@staticmethod
def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
Expand Down
Loading