Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import warnings
from contextlib import suppress
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand All @@ -30,7 +31,7 @@
from flask import Blueprint, current_app, g
from flask_appbuilder.const import AUTH_LDAP
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.orm import Session, joinedload

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
Expand Down Expand Up @@ -290,6 +291,12 @@ def deserialize_user(self, token: dict[str, Any]) -> User:
return self.session.scalars(select(User).where(User.id == int(token["sub"]))).one()
except NoResultFound:
raise ValueError(f"User with id {token['sub']} not found")
except SQLAlchemyError:
# Discard the poisoned scoped session so the next request gets a
# fresh connection from the pool instead of a PendingRollbackError.
with suppress(Exception):
self.session.remove()
raise

def serialize_user(self, user: User) -> dict[str, Any]:
return {"sub": str(user.id)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pytest
from flask import g
from flask_appbuilder.const import AUTH_DB, AUTH_LDAP
from sqlalchemy.exc import OperationalError, PendingRollbackError

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.common.types import MenuItem
Expand Down Expand Up @@ -959,6 +960,67 @@ def test_resetdb(
mock_init.assert_called_once()


@pytest.mark.db_test
class TestDeserializeUserSessionCleanup:
"""Test that deserialize_user cleans up the FAB scoped session on database errors.

Problem:
When the database connection drops (e.g., PostgreSQL's
``idle_in_transaction_session_timeout`` fires), the underlying connection
becomes invalid. SQLAlchemy raises ``OperationalError`` on the first request
that hits the dead connection. The scoped session then enters an invalid
state. Any subsequent request that reuses the same thread-local session
raises ``PendingRollbackError`` — permanently breaking the API server until
it is restarted.
"""

@staticmethod
def _patched_session(auth_manager, mock_session):
"""Replace the ``session`` property on *auth_manager* with *mock_session*."""
return mock.patch.object(
type(auth_manager), "session", new_callable=mock.PropertyMock, return_value=mock_session
)

@pytest.mark.parametrize(
"raised_exc",
[
OperationalError("server closed the connection unexpectedly", None, Exception()),
PendingRollbackError(
"Can't reconnect until invalid transaction is rolled back. "
"Please rollback() fully before proceeding"
),
],
ids=["operational_error", "pending_rollback_error"],
)
def test_db_error_calls_session_remove(self, auth_manager_with_appbuilder, raised_exc):
"""session.remove() is called on SQLAlchemy errors so the next request recovers."""
mock_session = MagicMock(spec=["scalars", "remove"])
mock_session.scalars.side_effect = raised_exc
auth_manager_with_appbuilder.cache.pop(99997, None)

with self._patched_session(auth_manager_with_appbuilder, mock_session):
with pytest.raises(type(raised_exc)):
auth_manager_with_appbuilder.deserialize_user({"sub": "99997"})

mock_session.remove.assert_called_once()

def test_db_error_propagates_when_session_remove_raises(self, auth_manager_with_appbuilder):
"""The original SQLAlchemyError propagates even if session.remove() itself raises."""
# Arrange — session.scalars raises the original DB error;
# session.remove raises a secondary error that must be suppressed.
original_exc = OperationalError("connection dropped", None, Exception())
mock_session = MagicMock(spec=["scalars", "remove"])
mock_session.scalars.side_effect = original_exc
mock_session.remove.side_effect = AttributeError("appbuilder gone")
auth_manager_with_appbuilder.cache.pop(99997, None)

with self._patched_session(auth_manager_with_appbuilder, mock_session):
with pytest.raises(OperationalError):
auth_manager_with_appbuilder.deserialize_user({"sub": "99997"})

mock_session.remove.assert_called_once()


class TestFabAuthManagerSessionCleanup:
"""Test session cleanup middleware in FAB auth manager FastAPI app.

Expand Down