Skip to content

Commit 953c467

Browse files
committed
remove session.query from fab/src
1 parent 6c481b0 commit 953c467

File tree

1 file changed

+48
-30
lines changed
  • providers/fab/src/airflow/providers/fab/auth_manager/security_manager

1 file changed

+48
-30
lines changed

providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,8 @@ def reset_user_sessions(self, user: User) -> None:
545545
interface = self.appbuilder.get_app.session_interface
546546
session = interface.db.session
547547
user_session_model = interface.sql_session_model
548-
num_sessions = session.query(user_session_model).count()
548+
# num_sessions = session.query(user_session_model).count()
549+
num_sessions = session.scalars(select(func.count()).select_from(user_session_model)).one()
549550
if num_sessions > MAX_NUM_DATABASE_USER_SESSIONS:
550551
safe_username = escape(user.username)
551552
self._cli_safe_flash(
@@ -560,7 +561,8 @@ def reset_user_sessions(self, user: User) -> None:
560561
"warning",
561562
)
562563
else:
563-
for s in session.query(user_session_model):
564+
# for s in session.query(user_session_model):
565+
for s in session.scalars(user_session_model).all():
564566
session_details = interface.serializer.loads(want_bytes(s.data))
565567
if session_details.get("_user_id") == user.id:
566568
session.delete(s)
@@ -1274,10 +1276,11 @@ def find_role(self, name):
12741276
12751277
:param name: the role name
12761278
"""
1277-
return self.get_session.query(self.role_model).filter_by(name=name).one_or_none()
1279+
return self.get_session.execute(select(self.role_model).filter_by(name=name)).unique().one_or_none()
12781280

12791281
def get_all_roles(self):
1280-
return self.get_session.query(self.role_model).all()
1282+
# return self.get_session.query(self.role_model).all()
1283+
return self.get_session.scalars(select(self.role_model)).all()
12811284

12821285
def delete_role(self, role_name: str) -> None:
12831286
"""
@@ -1286,7 +1289,8 @@ def delete_role(self, role_name: str) -> None:
12861289
:param role_name: the name of a role in the ab_role table
12871290
"""
12881291
session = self.get_session
1289-
role = session.query(Role).filter(Role.name == role_name).first()
1292+
# role = session.query(Role).filter(Role.name == role_name).first()
1293+
role = session.execute(select(Role).where(Role.name == role_name)).first()
12901294
if role:
12911295
log.info("Deleting role '%s'", role_name)
12921296
session.delete(role)
@@ -1320,7 +1324,10 @@ def get_roles_from_keys(self, role_keys: list[str]) -> set[Role]:
13201324
return _roles
13211325

13221326
def get_public_role(self):
1323-
return self.get_session.query(self.role_model).filter_by(name=self.auth_role_public).one_or_none()
1327+
# return self.get_session.query(self.role_model).filter_by(name=self.auth_role_public).one_or_none()
1328+
return self.get_session.execute(
1329+
select(self.role_model).filter_by(name=self.auth_role_public)
1330+
).one_or_none()
13241331

13251332
"""
13261333
-----------
@@ -1377,7 +1384,8 @@ def get_user_by_id(self, pk):
13771384

13781385
def count_users(self):
13791386
"""Return the number of users in the database."""
1380-
return self.get_session.query(func.count(self.user_model.id)).scalar()
1387+
# return self.get_session.query(func.count(self.user_model.id)).scalar()
1388+
return self.get_session.execute(select(func.count(self.user_model.id))).scalar()
13811389

13821390
def add_register_user(self, username, first_name, last_name, email, password="", hashed_password=""):
13831391
"""
@@ -1409,22 +1417,32 @@ def find_user(self, username=None, email=None):
14091417
if username:
14101418
try:
14111419
if self.auth_username_ci:
1412-
return (
1420+
""" return (
14131421
self.get_session.query(self.user_model)
14141422
.filter(func.lower(self.user_model.username) == func.lower(username))
14151423
.one_or_none()
1416-
)
1417-
return (
1424+
) """
1425+
return self.get_session.execute(
1426+
select(self.user_model).where(
1427+
func.lower(self.user_model.username) == func.lower(username)
1428+
)
1429+
).one_or_none()
1430+
""" return (
14181431
self.get_session.query(self.user_model)
14191432
.filter(func.lower(self.user_model.username) == func.lower(username))
14201433
.one_or_none()
1421-
)
1434+
) """
1435+
return self.get_session.execute(
1436+
select(self.user_model).where(
1437+
func.lower(self.user_model.username) == func.lower(username)
1438+
)
1439+
).one_or_none()
14221440
except MultipleResultsFound:
14231441
log.error("Multiple results found for user %s", username)
14241442
return None
14251443
elif email:
14261444
try:
1427-
return self.get_session.query(self.user_model).filter_by(email=email).one_or_none()
1445+
return self.get_session.execute(select(self.user_model).filter_by(email=email)).one_or_none()
14281446
except MultipleResultsFound:
14291447
log.error("Multiple results found for user with email %s", email)
14301448
return None
@@ -1456,7 +1474,7 @@ def del_register_user(self, register_user):
14561474
return False
14571475

14581476
def get_all_users(self):
1459-
return self.get_session.query(self.user_model).all()
1477+
return self.get_session.scalars(select(self.user_model)).all()
14601478

14611479
def update_user_auth_stat(self, user, success=True):
14621480
"""
@@ -1496,7 +1514,7 @@ def get_action(self, name: str) -> Action:
14961514
14971515
:param name: name
14981516
"""
1499-
return self.get_session.query(self.action_model).filter_by(name=name).one_or_none()
1517+
return self.get_session.execute(select(self.action_model).filter_by(name=name)).one_or_none()
15001518

15011519
def create_action(self, name):
15021520
"""
@@ -1529,11 +1547,9 @@ def delete_action(self, name: str) -> bool:
15291547
log.warning(const.LOGMSG_WAR_SEC_DEL_PERMISSION, name)
15301548
return False
15311549
try:
1532-
perms = (
1533-
self.get_session.query(self.permission_model)
1534-
.filter(self.permission_model.action == action)
1535-
.all()
1536-
)
1550+
perms = self.get_session.scalars(
1551+
select(self.permission_model).where(self.permission_model.action == action)
1552+
).all()
15371553
if perms:
15381554
log.warning(const.LOGMSG_WAR_SEC_DEL_PERM_PVM, action, perms)
15391555
return False
@@ -1557,7 +1573,7 @@ def get_resource(self, name: str) -> Resource | None:
15571573
15581574
:param name: Name of resource
15591575
"""
1560-
return self.get_session.query(self.resource_model).filter_by(name=name).one_or_none()
1576+
return self.get_session.execute(select(self.resource_model).filter_by(name=name)).one_or_none()
15611577

15621578
def create_resource(self, name) -> Resource | None:
15631579
"""
@@ -1598,11 +1614,9 @@ def get_permission(
15981614
action = self.get_action(action_name)
15991615
resource = self.get_resource(resource_name)
16001616
if action and resource:
1601-
return (
1602-
self.get_session.query(self.permission_model)
1603-
.filter_by(action=action, resource=resource)
1604-
.one_or_none()
1605-
)
1617+
return self.get_session.execute(
1618+
select(self.permission_model).filter_by(action=action, resource=resource)
1619+
).one_or_none()
16061620
return None
16071621

16081622
def get_resource_permissions(self, resource: Resource) -> Permission:
@@ -1611,7 +1625,9 @@ def get_resource_permissions(self, resource: Resource) -> Permission:
16111625
16121626
:param resource: Object representing a single resource.
16131627
"""
1614-
return self.get_session.query(self.permission_model).filter_by(resource_id=resource.id).all()
1628+
return self.get_session.scalars(
1629+
select(self.permission_model).filter_by(resource_id=resource.id)
1630+
).all()
16151631

16161632
def create_permission(self, action_name, resource_name) -> Permission | None:
16171633
"""
@@ -1658,9 +1674,9 @@ def delete_permission(self, action_name: str, resource_name: str) -> None:
16581674
perm = self.get_permission(action_name, resource_name)
16591675
if not perm:
16601676
return
1661-
roles = (
1662-
self.get_session.query(self.role_model).filter(self.role_model.permissions.contains(perm)).first()
1663-
)
1677+
roles = self.get_session.execute(
1678+
select(self.role_model).where(self.role_model.permissions.contains(perm))
1679+
).first()
16641680
if roles:
16651681
log.warning(const.LOGMSG_WAR_SEC_DEL_PERMVIEW, resource_name, action_name, roles)
16661682
return
@@ -1669,7 +1685,9 @@ def delete_permission(self, action_name: str, resource_name: str) -> None:
16691685
self.get_session.delete(perm)
16701686
self.get_session.commit()
16711687
# if no more permission on permission view, delete permission
1672-
if not self.get_session.query(self.permission_model).filter_by(action=perm.action).all():
1688+
if not self.get_session.scalars(
1689+
select(self.permission_model).filter_by(action=perm.action)
1690+
).all():
16731691
self.delete_action(perm.action.name)
16741692
log.info(const.LOGMSG_INF_SEC_DEL_PERMVIEW, action_name, resource_name)
16751693
except Exception as e:

0 commit comments

Comments
 (0)