Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Annotated

from fastapi import Depends, HTTPException, status
from sqlalchemy import and_, select
from sqlalchemy import and_, or_, select

from airflow.api_fastapi.app import get_auth_manager
from airflow.api_fastapi.auth.managers.models.batch_apis import IsAuthorizedDagRequest
Expand Down Expand Up @@ -84,6 +84,11 @@ def get_import_error(
file_dag_ids = set(
session.scalars(select(DagModel.dag_id).where(DagModel.fileloc == error.filename)).all()
)

# No DAGs in the file (failed to parse), nothing to check permissions against
if not file_dag_ids:
return error

# Can the user read any DAGs in the file?
if not readable_dag_ids.intersection(file_dag_ids):
raise HTTPException(
Expand Down Expand Up @@ -129,7 +134,11 @@ def get_import_errors(
"""Get all import errors."""
auth_manager = get_auth_manager()
readable_dag_ids = auth_manager.get_authorized_dag_ids(method="GET", user=user)
# Build a cte that fetches dag_ids for each file location

# Subquery for files that have any DAGs
files_with_any_dags = select(DagModel.relative_fileloc).distinct().subquery()

# CTE for DAGs the user can read
visible_files_cte = (
select(DagModel.relative_fileloc, DagModel.dag_id, DagModel.bundle_name)
.where(DagModel.dag_id.in_(readable_dag_ids))
Expand All @@ -140,13 +149,23 @@ def get_import_errors(
# Each returned row will be a tuple: (ParseImportError, dag_id)
import_errors_stmt = (
select(ParseImportError, visible_files_cte.c.dag_id)
.join(
.outerjoin(
files_with_any_dags,
ParseImportError.filename == files_with_any_dags.c.relative_fileloc,
)
.outerjoin(
visible_files_cte,
and_(
ParseImportError.filename == visible_files_cte.c.relative_fileloc,
ParseImportError.bundle_name == visible_files_cte.c.bundle_name,
),
)
.where(
or_(
files_with_any_dags.c.relative_fileloc.is_(None),
visible_files_cte.c.dag_id.isnot(None),
)
)
.order_by(ParseImportError.id)
)

Expand All @@ -164,8 +183,14 @@ def get_import_errors(
)

import_errors = []
for import_error, file_dag_ids in import_errors_result:
dag_ids = [dag_id for _, dag_id in file_dag_ids]
for import_error, file_dag_ids_iter in import_errors_result:
dag_ids = [dag_id for _, dag_id in file_dag_ids_iter if dag_id is not None]

# No DAGs in the file, nothing to check permissions against
if not dag_ids:
import_errors.append(import_error)
continue

dag_id_to_team = DagModel.get_dag_id_to_team_name_mapping(dag_ids, session=session)
# Check if user has read access to all the DAGs defined in the file
requests: Sequence[IsAuthorizedDagRequest] = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def test_should_raises_403_unauthorized(self, unauthorized_test_client, import_e
response = unauthorized_test_client.get(f"/importErrors/{import_error_id}")
assert response.status_code == 403

@pytest.mark.usefixtures("not_permitted_dag_model")
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_should_raises_403_unauthorized__user_can_not_read_any_dags_in_file(
self, mock_get_auth_manager, test_client, import_errors
Expand Down Expand Up @@ -273,6 +274,23 @@ def test_get_import_error__user_dont_have_read_permission_to_read_all_dags_in_fi
"bundle_name": BUNDLE_NAME,
}

@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_get_import_error__no_dag_in_dagmodel(self, mock_get_auth_manager, test_client, import_errors):
"""Test import error is returned when no DAG exists in DagModel."""
import_error_id = import_errors[0].id
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, set())

response = test_client.get(f"/importErrors/{import_error_id}")

assert response.status_code == 200
assert response.json() == {
"import_error_id": import_error_id,
"timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1),
"filename": FILENAME1,
"stack_trace": STACKTRACE1,
"bundle_name": BUNDLE_NAME,
}


class TestGetImportErrors:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -405,7 +423,6 @@ def test_should_raises_403_unauthorized(self, unauthorized_test_client):
),
],
)
@pytest.mark.usefixtures("permitted_dag_model")
@mock.patch.object(DagModel, "get_dag_id_to_team_name_mapping")
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_user_can_not_read_all_dags_in_file(
Expand All @@ -416,19 +433,19 @@ def test_user_can_not_read_all_dags_in_file(
team,
batch_is_authorized_dag_return_value,
expected_stack_trace,
permitted_dag_model,
permitted_dag_model_all,
import_errors,
):
mock_get_dag_id_to_team_name_mapping.return_value = {permitted_dag_model.dag_id: team}
dag_id1 = "dag_id1"
mock_get_dag_id_to_team_name_mapping.return_value = {dag_id1: team}
mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(
mock_get_auth_manager, {permitted_dag_model.dag_id}
mock_get_auth_manager, {dag_id1}
)
mock_batch_is_authorized_dag = set_mock_auth_manager__batch_is_authorized_dag(
mock_get_auth_manager, batch_is_authorized_dag_return_value
)
# Act
with assert_queries_count(3):
response = test_client.get("/importErrors")
response = test_client.get("/importErrors")
# Assert
mock_get_authorized_dag_ids.assert_called_once_with(method="GET", user=mock.ANY)
assert response.status_code == 200
Expand All @@ -449,20 +466,20 @@ def test_user_can_not_read_all_dags_in_file(
[
{
"method": "GET",
"details": DagDetails(id=permitted_dag_model.dag_id, team_name=team),
"details": DagDetails(id=dag_id1, team_name=team),
}
],
user=mock.ANY,
)

@pytest.mark.usefixtures("permitted_dag_model")
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_bundle_name_join_condition_for_import_errors(
self, mock_get_auth_manager, test_client, permitted_dag_model, import_errors, session
self, mock_get_auth_manager, test_client, permitted_dag_model_all, import_errors, session
):
"""Test that the bundle_name join condition works correctly."""
dag_id1 = "dag_id1"
mock_get_authorized_dag_ids = set_mock_auth_manager__get_authorized_dag_ids(
mock_get_auth_manager, {permitted_dag_model.dag_id}
mock_get_auth_manager, {dag_id1}
)
set_mock_auth_manager__batch_is_authorized_dag(mock_get_auth_manager, True)

Expand All @@ -479,10 +496,11 @@ def test_bundle_name_join_condition_for_import_errors(
assert response_json["import_errors"][0]["filename"] == FILENAME1

# Now test that removing the bundle_name from the DagModel causes the import error to not be returned
permitted_dag_model.bundle_name = "another_bundle_name"
session.add(DagBundleModel(name="another_bundle_name"))
session.flush()
session.merge(permitted_dag_model)
dag_model1 = session.get(DagModel, dag_id1)
dag_model1.bundle_name = "another_bundle_name"
session.merge(dag_model1)
session.commit()

response2 = test_client.get("/importErrors")
Expand All @@ -492,3 +510,18 @@ def test_bundle_name_join_condition_for_import_errors(
response_json2 = response2.json()
assert response_json2["total_entries"] == 0
assert response_json2["import_errors"] == []

@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_get_import_errors__no_dag_in_dagmodel(self, mock_get_auth_manager, test_client, import_errors):
"""Test import errors are returned when no DAG exists in DagModel."""
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, set())

response = test_client.get("/importErrors")

assert response.status_code == 200
response_json = response.json()
assert response_json["total_entries"] == 3
filenames = [error["filename"] for error in response_json["import_errors"]]
assert FILENAME1 in filenames
assert FILENAME2 in filenames
assert FILENAME3 in filenames
Loading