Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/api_fastapi/common/db/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import TYPE_CHECKING

from sqlalchemy import func, select
from sqlalchemy.orm import selectinload

from airflow.api_fastapi.common.db.common import (
apply_filters_to_select,
Expand All @@ -33,7 +34,7 @@


def generate_dag_with_latest_run_query(max_run_filters: list[BaseParam], order_by: SortParam) -> Select:
query = select(DagModel)
query = select(DagModel).options(selectinload(DagModel.tags))

max_run_id_query = ( # ordering by id will not always be "latest run", but it's a simplifying assumption
select(DagRun.dag_id, func.max(DagRun.id).label("max_dag_run_id"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

from tests_common.test_utils.asserts import count_queries
from tests_common.test_utils.db import (
clear_db_assets,
clear_db_connections,
Expand Down Expand Up @@ -524,6 +525,71 @@ def test_get_dags_filter_has_import_errors(self, session, test_client, filter_va
assert body["total_entries"] == 1
assert [dag["dag_id"] for dag in body["dags"]] == expected_ids

def test_get_dags_no_n_plus_one_queries(self, session, test_client):
"""Test that fetching DAGs with tags doesn't trigger n+1 queries."""
num_dags = 5
for i in range(num_dags):
dag_id = f"test_dag_queries_{i}"
dag_model = DagModel(
dag_id=dag_id,
bundle_name="dag_maker",
fileloc=f"/tmp/{dag_id}.py",
is_stale=False,
)
session.add(dag_model)
session.flush()

for j in range(3):
tag = DagTag(name=f"tag_{i}_{j}", dag_id=dag_id)
session.add(tag)

session.commit()
session.expire_all()

with count_queries() as result:
response = test_client.get("/dags", params={"limit": 10})

assert response.status_code == 200
body = response.json()
dags_with_our_prefix = [d for d in body["dags"] if d["dag_id"].startswith("test_dag_queries_")]
assert len(dags_with_our_prefix) == num_dags
for dag in dags_with_our_prefix:
assert len(dag["tags"]) == 3

first_query_count = sum(result.values())

# Add more DAGs and verify query count doesn't scale linearly
for i in range(num_dags, num_dags + 3):
dag_id = f"test_dag_queries_{i}"
dag_model = DagModel(
dag_id=dag_id,
bundle_name="dag_maker",
fileloc=f"/tmp/{dag_id}.py",
is_stale=False,
)
session.add(dag_model)
session.flush()

for j in range(3):
tag = DagTag(name=f"tag_{i}_{j}", dag_id=dag_id)
session.add(tag)

session.commit()
session.expire_all()

with count_queries() as result2:
response = test_client.get("/dags", params={"limit": 15})

assert response.status_code == 200
second_query_count = sum(result2.values())

# With n+1, adding 3 DAGs would add ~3 tag queries
# Without n+1, query count should remain nearly identical
assert second_query_count - first_query_count < 3, (
f"Added 3 DAGs but query count increased by {second_query_count - first_query_count} "
f"({first_query_count} → {second_query_count}), suggesting n+1 queries for tags"
)


class TestPatchDag(TestDagEndpoint):
"""Unit tests for Patch DAG."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
from sqlalchemy.orm import Session

from airflow.models import DagRun
from airflow.models.dag import DagModel, DagTag
from airflow.models.dag_favorite import DagFavorite
from airflow.models.hitl import HITLDetail
from airflow.sdk.timezone import utcnow
from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

from tests_common.test_utils.asserts import count_queries
from unit.api_fastapi.core_api.routes.public.test_dags import (
DAG1_ID,
DAG2_ID,
Expand Down Expand Up @@ -231,6 +233,71 @@ def test_should_response_403(self, unauthorized_test_client):
response = unauthorized_test_client.get("/dags", params={})
assert response.status_code == 403

def test_get_dags_no_n_plus_one_queries(self, session, test_client):
"""Test that fetching DAGs with tags doesn't trigger n+1 queries."""
num_dags = 5
for i in range(num_dags):
dag_id = f"test_dag_queries_ui_{i}"
dag_model = DagModel(
dag_id=dag_id,
bundle_name="dag_maker",
fileloc=f"/tmp/{dag_id}.py",
is_stale=False,
)
session.add(dag_model)
session.flush()

for j in range(3):
tag = DagTag(name=f"tag_ui_{i}_{j}", dag_id=dag_id)
session.add(tag)

session.commit()
session.expire_all()

with count_queries() as result:
response = test_client.get("/dags", params={"limit": 10})

assert response.status_code == 200
body = response.json()
dags_with_our_prefix = [d for d in body["dags"] if d["dag_id"].startswith("test_dag_queries_ui_")]
assert len(dags_with_our_prefix) == num_dags
for dag in dags_with_our_prefix:
assert len(dag["tags"]) == 3

first_query_count = sum(result.values())

# Add more DAGs and verify query count doesn't scale linearly
for i in range(num_dags, num_dags + 3):
dag_id = f"test_dag_queries_ui_{i}"
dag_model = DagModel(
dag_id=dag_id,
bundle_name="dag_maker",
fileloc=f"/tmp/{dag_id}.py",
is_stale=False,
)
session.add(dag_model)
session.flush()

for j in range(3):
tag = DagTag(name=f"tag_ui_{i}_{j}", dag_id=dag_id)
session.add(tag)

session.commit()
session.expire_all()

with count_queries() as result2:
response = test_client.get("/dags", params={"limit": 15})

assert response.status_code == 200
second_query_count = sum(result2.values())

# With n+1, adding 3 DAGs would add ~3 tag queries
# Without n+1, query count should remain nearly identical
assert second_query_count - first_query_count < 3, (
f"Added 3 DAGs but query count increased by {second_query_count - first_query_count} "
f"({first_query_count} → {second_query_count}), suggesting n+1 queries for tags"
)

@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_latest_run_should_return_200(self, test_client):
response = test_client.get(f"/dags/{DAG1_ID}/latest_run")
Expand Down
Loading