Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ async def dispatch(self, request: Request, call_next):
current_token = request.cookies.get(COOKIE_NAME_JWT_TOKEN)
try:
if current_token:
new_user = await self._refresh_user(current_token)
if new_user:
request.state.user = new_user
new_user, current_user = await self._refresh_user(current_token)
if user := (new_user or current_user):
request.state.user = user

response = await call_next(request)

Expand All @@ -67,9 +67,10 @@ async def dispatch(self, request: Request, call_next):
return response

@staticmethod
async def _refresh_user(current_token: str) -> BaseUser | None:
async def _refresh_user(current_token: str) -> tuple[BaseUser | None, BaseUser | None]:
try:
user = await resolve_user_from_token(current_token)
except HTTPException:
return None
return get_auth_manager().refresh_user(user=user)
return None, None

return get_auth_manager().refresh_user(user=user), user
1 change: 1 addition & 0 deletions devel-common/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ dependencies = [
"types-setuptools>=80.0.0.20250429",
"types-tabulate>=0.9.0.20240106",
"types-toml>=0.10.8.20240310",
"types-cachetools>=6.2.0.20251022",
]
"pytest" = [
# General pytest devel tools
Expand Down
1 change: 1 addition & 0 deletions providers/fab/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ PIP package Version required
``jmespath`` ``>=0.7.0; python_version < "3.13"``
``werkzeug`` ``>=2.2,<4; python_version < "3.13"``
``wtforms`` ``>=3.0,<4; python_version < "3.13"``
``cachetools`` ``>=6.0; python_version < "3.13"``
``flask_limiter`` ``>3,!=3.13,<4``
========================================== ==========================================

Expand Down
8 changes: 8 additions & 0 deletions providers/fab/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,14 @@ config:
type: integer
example: ~
default: "1"
cache_ttl:
description: |
Number of seconds after which the user cache will expire to refetch updated user and
permissions.
version_added: 3.2.0
type: integer
example: ~
default: "30"

auth-managers:
- airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager
Expand Down
1 change: 1 addition & 0 deletions providers/fab/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ dependencies = [
"jmespath>=0.7.0; python_version < '3.13'",
"werkzeug>=2.2,<4; python_version < '3.13'",
"wtforms>=3.0,<4; python_version < '3.13'",
"cachetools>=6.0; python_version < '3.13'",

# https://github.com/dpgaspar/Flask-AppBuilder/blob/release/4.6.3/setup.py#L54C8-L54C26
# with an exclusion to account for https://github.com/alisaifee/flask-limiter/issues/479
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin

from cachetools import TTLCache, cachedmethod
from connexion import FlaskApi
from fastapi import FastAPI
from fastapi.middleware.wsgi import WSGIMiddleware
Expand Down Expand Up @@ -94,7 +95,7 @@
get_fab_action_from_method_map,
get_method_from_fab_action_map,
)
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.yaml import safe_load

if TYPE_CHECKING:
Expand Down Expand Up @@ -161,6 +162,7 @@
MenuItem.XCOMS: RESOURCE_XCOM,
}

CACHE_TTL = conf.getint("fab", "cache_ttl", fallback=30)

if AIRFLOW_V_3_1_PLUS:
from airflow.providers.fab.www.security.permissions import RESOURCE_HITL_DETAIL
Expand All @@ -176,6 +178,7 @@ class FabAuthManager(BaseAuthManager[User]):
This auth manager is responsible for providing a backward compatible user management experience to users.
"""

cache: TTLCache = TTLCache(maxsize=1024, ttl=CACHE_TTL)
appbuilder: AirflowAppBuilder | None = None

def init_flask_resources(self) -> None:
Expand Down Expand Up @@ -253,9 +256,13 @@ def get_user(self) -> User:

return current_user

@property
def session(self):
return self.appbuilder.session

@cachedmethod(lambda self: self.cache, key=lambda _, token: int(token["sub"]))
def deserialize_user(self, token: dict[str, Any]) -> User:
with create_session() as session:
return session.scalars(select(User).where(User.id == int(token["sub"]))).one()
return self.session.scalars(select(User).where(User.id == int(token["sub"]))).one()

def serialize_user(self, user: User) -> dict[str, Any]:
return {"sub": str(user.id)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ def get_provider_info():
"example": None,
"default": "1",
},
"cache_ttl": {
"description": "Number of seconds after which the user cache will expire to refetch updated user and\npermissions.\n",
"version_added": "3.2.0",
"type": "integer",
"example": None,
"default": "30",
},
},
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import time
from contextlib import contextmanager, suppress
from itertools import chain
from typing import TYPE_CHECKING
Expand All @@ -34,6 +35,7 @@
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.utils.db import resetdb

from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.config import conf_vars
from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user

Expand Down Expand Up @@ -197,9 +199,24 @@ def test_get_user_from_flask_g(self, mock_current_user, minimal_app_for_auth_api
with user_set(minimal_app_for_auth_api, flask_g_user):
assert auth_manager.get_user() == flask_g_user

@conf_vars({("fab", "cache_ttl"): "1"})
def test_deserialize_user(self, flask_app, auth_manager_with_appbuilder):
"""Test user objects are cached and that the cache expires after configured TTL."""
user = create_user(flask_app, "test")
result = auth_manager_with_appbuilder.deserialize_user({"sub": str(user.id)})
with assert_queries_count(2):
result = auth_manager_with_appbuilder.deserialize_user({"sub": str(user.id)})

assert user.get_id() == result.get_id()

with assert_queries_count(0):
result = auth_manager_with_appbuilder.deserialize_user({"sub": str(user.id)})

assert user.get_id() == result.get_id()

time.sleep(1)
with assert_queries_count(2):
result = auth_manager_with_appbuilder.deserialize_user({"sub": str(user.id)})

assert user.get_id() == result.get_id()

def test_serialize_user(self, flask_app, auth_manager_with_appbuilder):
Expand Down
Loading