Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions airflow-core/src/airflow/api_fastapi/common/db/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,30 @@
from sqlalchemy.sql import Select


def generate_dag_with_latest_run_query(max_run_filters: list[BaseParam], order_by: SortParam) -> Select:
def generate_dag_with_latest_run_query(
max_run_filters: list[BaseParam], order_by: SortParam, *, dag_ids: set[str] | None = None
) -> Select:
"""
Generate a query to fetch DAGs with their latest run.

:param max_run_filters: List of filters to apply to the latest run
:param order_by: Sort parameter for ordering results
:param dag_ids: Optional set of DAG IDs to limit the query to. When provided, both the main
DAG query and the subquery for finding the latest runs will be filtered to
only these DAG IDs, improving performance when users have limited DAG access.
:return: SQLAlchemy Select statement
"""
query = select(DagModel).options(selectinload(DagModel.tags))

max_run_id_query = ( # ordering by id will not always be "latest run", but it's a simplifying assumption
select(DagRun.dag_id, func.max(DagRun.id).label("max_dag_run_id"))
.group_by(DagRun.dag_id)
.subquery(name="mrq")
)
# Filter main query by dag_ids if provided
if dag_ids is not None:
query = query.where(DagModel.dag_id.in_(dag_ids or set()))

# Also filter the subquery for finding latest runs
max_run_id_query_stmt = select(DagRun.dag_id, func.max(DagRun.id).label("max_dag_run_id"))
if dag_ids is not None:
max_run_id_query_stmt = max_run_id_query_stmt.where(DagRun.dag_id.in_(dag_ids or set()))
max_run_id_query = max_run_id_query_stmt.group_by(DagRun.dag_id).subquery(name="mrq")

has_max_run_filter = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def get_dags(
last_dag_run_state,
],
order_by=order_by,
dag_ids=readable_dags_filter.value,
)

dags_select, total_entries = paginated_select(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def get_dags(
last_dag_run_state,
],
order_by=order_by,
dag_ids=readable_dags_filter.value,
)

dags_select, total_entries = paginated_select(
Expand Down
69 changes: 69 additions & 0 deletions airflow-core/tests/unit/api_fastapi/common/db/test_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,72 @@ def test_queued_runs_with_null_start_date_are_properly_joined(
"This suggests the WHERE start_date IS NOT NULL condition is excluding it."
)
assert running_dagrun_state is not None, "Running DAG should have DagRun state joined"

@pytest.mark.usefixtures("testing_dag_bundle")
def test_filters_by_dag_ids_when_provided(self, session):
    """
    Check that passing ``dag_ids`` restricts the generated query to those DAGs.

    Three DAGs are stored, each with two successful manual runs. The query is
    built once with ``dag_ids`` limited to two of them and once without the
    argument. The restricted query must return exactly the two permitted DAGs,
    each with a non-null joined run state, while the unrestricted query must
    return all three.
    """
    permitted = {"dag_accessible_1", "dag_accessible_2"}
    every_dag_id = ["dag_accessible_1", "dag_accessible_2", "dag_inaccessible_3"]

    def build_query(**extra):
        # A fresh filter list and a fresh SortParam are created on every call,
        # so the two queries never share parameter objects.
        return generate_dag_with_latest_run_query(
            max_run_filters=[],
            order_by=SortParam(allowed_attrs=["last_run_state"], model=DagModel).set_value(
                ["last_run_state"]
            ),
            **extra,
        )

    for current_id in every_dag_id:
        session.add(
            DagModel(
                dag_id=current_id,
                bundle_name="testing",
                is_stale=False,
                is_paused=False,
                fileloc=f"/tmp/{current_id}.py",
            )
        )
        # Flush the DAG row before attaching its runs.
        session.flush()

        # Two runs per DAG, on consecutive days.
        for offset in range(2):
            day = 1 + offset
            session.add(
                DagRun(
                    dag_id=current_id,
                    run_id=f"manual__{offset}",
                    run_type="manual",
                    logical_date=datetime(2024, 1, day, tzinfo=timezone.utc),
                    state=DagRunState.SUCCESS,
                    start_date=datetime(2024, 1, day, 1, tzinfo=timezone.utc),
                )
            )
    session.commit()

    # Build the restricted query first, then the unrestricted one.
    restricted_query = build_query(dag_ids=permitted)
    open_query = build_query()

    restricted_rows = session.execute(restricted_query.add_columns(DagRun.state)).fetchall()
    open_rows = session.execute(open_query.add_columns(DagRun.state)).fetchall()

    # Only the permitted DAGs come back from the restricted query.
    returned_restricted = {row[0].dag_id for row in restricted_rows}
    assert returned_restricted == permitted

    # Without the restriction every stored DAG is returned.
    returned_open = {row[0].dag_id for row in open_rows}
    assert returned_open == set(every_dag_id)

    # Each permitted DAG must have a run state joined in (second column).
    restricted_with_state = {row[0].dag_id for row in restricted_rows if row[1] is not None}
    assert restricted_with_state == permitted
Loading