Skip to content

Commit 97a6934

Browse files
fix: refactor user query logic and add general approval_status filter
1 parent 27d9431 commit 97a6934

File tree

1 file changed

+96
-135
lines changed

1 file changed

+96
-135
lines changed

routers/admin.py

Lines changed: 96 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -136,22 +136,6 @@ def strip_reason(cls, value: str | None) -> str | None:
136136
def _membership_response() -> dict[str, object]:
137137
return {"status": "ok", "updated": True}
138138

139-
def _get_allowed_resource_subqueries(admin_roles: list[str]):
140-
"""
141-
Return subqueries for platform/group IDs the admin has access to.
142-
"""
143-
allowed_platforms_subquery = (
144-
select(Platform.id)
145-
.join(Platform.admin_roles)
146-
.where(Auth0Role.name.in_(admin_roles))
147-
)
148-
allowed_groups_subquery = (
149-
select(BiocommonsGroup.group_id)
150-
.join(BiocommonsGroup.admin_roles)
151-
.where(Auth0Role.name.in_(admin_roles))
152-
)
153-
return allowed_platforms_subquery, allowed_groups_subquery
154-
155139

156140
def _approve_platform_membership(
157141
*,
@@ -297,6 +281,10 @@ class UserQueryParams(BaseModel):
297281
298282
Each field listed here must have a {field}_query method defined.
299283
"""
284+
approval_status: ApprovalStatusEnum | None = Field(
285+
None,
286+
description="Filter by approval status across platforms and groups",
287+
)
300288
platform: PlatformEnum | None = Field(None, description="Filter by platform")
301289
platform_approval_status: ApprovalStatusEnum | None = Field(None, description="Filter by platform approval status")
302290
group: GroupEnum | None = Field(None, description="Filter by group")
@@ -307,6 +295,8 @@ class UserQueryParams(BaseModel):
307295
description="Filter users by group ('tsi',) or platform ('galaxy', 'bpa_data_portal')"
308296
)
309297
search: str | None = Field(None, description="Search users by username or email")
298+
_allowed_platforms_subquery: SelectOfScalar[Platform] | None = None
299+
_allowed_groups_subquery: SelectOfScalar[BiocommonsGroup] | None = None
310300

311301
def _fields(self):
312302
return (name for name in self.__pydantic_fields__.keys()
@@ -319,6 +309,11 @@ def model_post_init(self, context: Any) -> None:
319309
for field_name in self._fields():
320310
if not hasattr(self, f"{field_name}_query"):
321311
raise NotImplementedError(f"Missing query method for field '{field_name}'")
312+
if self.approval_status and (self.platform_approval_status or self.group_approval_status):
313+
raise HTTPException(
314+
status_code=400,
315+
detail="approval_status cannot be used with platform_approval_status or group_approval_status",
316+
)
322317

323318
def get_base_query(self):
324319
"""
@@ -328,12 +323,36 @@ def get_base_query(self):
328323
select(BiocommonsUser)
329324
)
330325

326+
def _set_allowed_resource_subqueries(self, admin_roles: list[str]) -> None:
327+
"""
328+
Cache allowed platform/group subqueries for reuse within this instance.
329+
"""
330+
allowed_platforms_subquery, allowed_groups_subquery = self.get_allowed_resource_subqueries(admin_roles)
331+
self._allowed_platforms_subquery = allowed_platforms_subquery
332+
self._allowed_groups_subquery = allowed_groups_subquery
333+
334+
def get_allowed_resource_subqueries(self, admin_roles: list[str]):
335+
"""
336+
Return subqueries for platform/group IDs the admin has access to.
337+
"""
338+
allowed_platforms_subquery = (
339+
select(Platform.id)
340+
.join(Platform.admin_roles)
341+
.where(Auth0Role.name.in_(admin_roles))
342+
)
343+
allowed_groups_subquery = (
344+
select(BiocommonsGroup.group_id)
345+
.join(BiocommonsGroup.admin_roles)
346+
.where(Auth0Role.name.in_(admin_roles))
347+
)
348+
return allowed_platforms_subquery, allowed_groups_subquery
349+
331350
def get_admin_permissions_query(self, admin_roles: list[str]):
332351
"""
333352
Get the query for only returning users the admin has permission to view/manage,
334353
based on group/platform roles
335354
"""
336-
allowed_platforms_subquery, allowed_groups_subquery = _get_allowed_resource_subqueries(admin_roles)
355+
allowed_platforms_subquery, allowed_groups_subquery = self.get_allowed_resource_subqueries(admin_roles)
337356
platform_access_condition = BiocommonsUser.id.in_(
338357
select(PlatformMembership.user_id).where(
339358
PlatformMembership.platform_id.in_(allowed_platforms_subquery)
@@ -350,17 +369,18 @@ def get_complete_query(self, admin_roles: list[str], pagination: PaginationParam
350369
"""
351370
Return a full user query, with permissions from admin roles applied
352371
"""
372+
self._set_allowed_resource_subqueries(admin_roles)
353373
return (
354374
self.get_base_query()
355375
.where(
356376
self.get_admin_permissions_query(admin_roles),
357-
*self.get_query_conditions())
377+
*self.get_query_conditions(admin_roles))
358378
.distinct()
359379
.offset(pagination.start_index)
360380
.limit(pagination.per_page)
361381
)
362382

363-
def get_query_conditions(self):
383+
def get_query_conditions(self, admin_roles: list[str] | None = None):
364384
"""
365385
Returns a list of SQLAlchemy queries for the filters that have been set.
366386
The queries can be passed to where().
@@ -373,7 +393,10 @@ def get_query_conditions(self):
373393
if field_value is not None:
374394
method_name = f"{field_name}_query"
375395
query_method = getattr(self, method_name)
376-
condition = query_method()
396+
try:
397+
condition = query_method(admin_roles) if admin_roles is not None else query_method()
398+
except TypeError:
399+
condition = query_method()
377400
# conditions may be None for interacting queries like platform
378401
# and platform_approval_status
379402
if condition is not None:
@@ -421,6 +444,10 @@ def platform_approval_status_query(self):
421444
platform_status_query = select(PlatformMembership.user_id).where(
422445
PlatformMembership.approval_status == self.platform_approval_status
423446
)
447+
if self._allowed_platforms_subquery is not None:
448+
platform_status_query = platform_status_query.where(
449+
PlatformMembership.platform_id.in_(self._allowed_platforms_subquery)
450+
)
424451
return BiocommonsUser.id.in_(platform_status_query)
425452

426453
def group_query(self):
@@ -435,6 +462,27 @@ def group_approval_status_query(self):
435462
)
436463
return BiocommonsUser.id.in_(group_status_query)
437464

465+
def approval_status_query(self, admin_roles: list[str] | None = None):
466+
"""
467+
Filter by approval status across platforms and groups.
468+
"""
469+
if self._allowed_platforms_subquery is None or self._allowed_groups_subquery is None:
470+
if admin_roles is None:
471+
raise ValueError("Allowed resource subqueries must be set before calling approval_status_query")
472+
self._set_allowed_resource_subqueries(admin_roles)
473+
474+
platform_status_query = select(PlatformMembership.user_id).where(
475+
PlatformMembership.platform_id.in_(self._allowed_platforms_subquery),
476+
PlatformMembership.approval_status == self.approval_status,
477+
)
478+
group_status_query = select(GroupMembership.user_id).where(
479+
GroupMembership.approval_status == self.approval_status,
480+
)
481+
return or_(
482+
BiocommonsUser.id.in_(platform_status_query),
483+
BiocommonsUser.id.in_(group_status_query),
484+
)
485+
438486
def email_verified_query(self):
439487
return BiocommonsUser.email_verified.is_(self.email_verified)
440488

@@ -485,96 +533,6 @@ def get_filtered_user_query(
485533
return user_query.get_complete_query(admin_roles, pagination)
486534

487535

488-
def _count_users_with_membership_status(
489-
*,
490-
db_session: Session,
491-
admin_roles: list[str],
492-
base_params: dict[str, object],
493-
status: ApprovalStatusEnum,
494-
) -> int:
495-
"""
496-
Count distinct users who have either a platform or group membership
497-
with the given status, limited to resources the admin can manage.
498-
"""
499-
# Remove status filters so we can apply OR logic below
500-
params_data = {**base_params, "platform_approval_status": None, "group_approval_status": None}
501-
params = UserQueryParams(**params_data)
502-
allowed_platforms_subquery, allowed_groups_subquery = _get_allowed_resource_subqueries(admin_roles)
503-
504-
base_conditions = [
505-
params.get_admin_permissions_query(admin_roles),
506-
*params.get_query_conditions(),
507-
]
508-
509-
platform_status_query = select(PlatformMembership.user_id).where(
510-
PlatformMembership.platform_id.in_(allowed_platforms_subquery),
511-
PlatformMembership.approval_status == status,
512-
)
513-
group_status_query = select(GroupMembership.user_id).where(
514-
GroupMembership.group_id.in_(allowed_groups_subquery),
515-
GroupMembership.approval_status == status,
516-
)
517-
518-
status_condition = or_(
519-
BiocommonsUser.id.in_(platform_status_query),
520-
BiocommonsUser.id.in_(group_status_query),
521-
)
522-
523-
query = (
524-
params.get_base_query()
525-
.where(status_condition, *base_conditions)
526-
.distinct()
527-
)
528-
count_statement = select(func.count()).select_from(query.subquery())
529-
return db_session.exec(count_statement).one()
530-
531-
532-
def _get_users_with_membership_status(
533-
*,
534-
db_session: Session,
535-
admin_roles: list[str],
536-
query_params: UserQueryParams,
537-
status: ApprovalStatusEnum,
538-
pagination: PaginationParams,
539-
) -> list[BiocommonsUser]:
540-
"""
541-
Return users who have either a platform or group membership with the given status,
542-
limited to resources the admin can manage.
543-
"""
544-
params = UserQueryParams(
545-
**{**query_params.model_dump(), "platform_approval_status": None, "group_approval_status": None}
546-
)
547-
allowed_platforms_subquery, allowed_groups_subquery = _get_allowed_resource_subqueries(admin_roles)
548-
549-
base_conditions = [
550-
params.get_admin_permissions_query(admin_roles),
551-
*params.get_query_conditions(),
552-
]
553-
554-
platform_status_query = select(PlatformMembership.user_id).where(
555-
PlatformMembership.platform_id.in_(allowed_platforms_subquery),
556-
PlatformMembership.approval_status == status,
557-
)
558-
group_status_query = select(GroupMembership.user_id).where(
559-
GroupMembership.group_id.in_(allowed_groups_subquery),
560-
GroupMembership.approval_status == status,
561-
)
562-
563-
status_condition = or_(
564-
BiocommonsUser.id.in_(platform_status_query),
565-
BiocommonsUser.id.in_(group_status_query),
566-
)
567-
568-
query = (
569-
params.get_base_query()
570-
.where(status_condition, *base_conditions)
571-
.distinct()
572-
.offset(pagination.start_index)
573-
.limit(pagination.per_page)
574-
)
575-
return db_session.exec(query).all()
576-
577-
578536
def _get_user_count(
579537
*,
580538
db_session: Session,
@@ -584,11 +542,12 @@ def _get_user_count(
584542
"""
585543
Count distinct users matching the provided query parameters and admin permissions.
586544
"""
545+
query_params._set_allowed_resource_subqueries(admin_roles)
587546
query = (
588547
query_params.get_base_query()
589548
.where(
590549
query_params.get_admin_permissions_query(admin_roles),
591-
*query_params.get_query_conditions(),
550+
*query_params.get_query_conditions(admin_roles),
592551
)
593552
.distinct()
594553
)
@@ -641,18 +600,16 @@ def count_with(overrides: dict[str, object] | None = None) -> int:
641600

642601
return UserCountsResponse(
643602
all=count_with(),
644-
pending=_count_users_with_membership_status(
645-
db_session=db_session,
646-
admin_roles=admin_roles,
647-
base_params=base_params,
648-
status=ApprovalStatusEnum.PENDING,
649-
),
650-
revoked=_count_users_with_membership_status(
651-
db_session=db_session,
652-
admin_roles=admin_roles,
653-
base_params=base_params,
654-
status=ApprovalStatusEnum.REVOKED,
655-
),
603+
pending=count_with({
604+
"approval_status": ApprovalStatusEnum.PENDING,
605+
"platform_approval_status": None,
606+
"group_approval_status": None,
607+
}),
608+
revoked=count_with({
609+
"approval_status": ApprovalStatusEnum.REVOKED,
610+
"platform_approval_status": None,
611+
"group_approval_status": None,
612+
}),
656613
unverified=count_with({"email_verified": False}),
657614
)
658615

@@ -678,14 +635,16 @@ def get_approved_users(db_session: Annotated[Session, Depends(get_db_session)],
678635
def get_pending_users(db_session: Annotated[Session, Depends(get_db_session)],
679636
admin_user: Annotated[SessionUser, Depends(get_session_user)],
680637
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
681-
user_query_params = UserQueryParams()
682-
users = _get_users_with_membership_status(
683-
db_session=db_session,
684-
admin_roles=admin_user.access_token.biocommons_roles,
685-
query_params=user_query_params,
686-
status=ApprovalStatusEnum.PENDING,
638+
user_query = get_filtered_user_query(
639+
admin_user=admin_user,
640+
user_query=UserQueryParams(
641+
approval_status=ApprovalStatusEnum.PENDING,
642+
platform_approval_status=None,
643+
group_approval_status=None,
644+
),
687645
pagination=pagination,
688646
)
647+
users = db_session.exec(user_query).all()
689648
return [BiocommonsUserResponse.from_db_user(user) for user in users]
690649

691650

@@ -694,14 +653,16 @@ def get_pending_users(db_session: Annotated[Session, Depends(get_db_session)],
694653
def get_revoked_users(db_session: Annotated[Session, Depends(get_db_session)],
695654
admin_user: Annotated[SessionUser, Depends(get_session_user)],
696655
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
697-
user_query_params = UserQueryParams()
698-
users = _get_users_with_membership_status(
699-
db_session=db_session,
700-
admin_roles=admin_user.access_token.biocommons_roles,
701-
query_params=user_query_params,
702-
status=ApprovalStatusEnum.REVOKED,
656+
user_query = get_filtered_user_query(
657+
admin_user=admin_user,
658+
user_query=UserQueryParams(
659+
approval_status=ApprovalStatusEnum.REVOKED,
660+
platform_approval_status=None,
661+
group_approval_status=None,
662+
),
703663
pagination=pagination,
704664
)
665+
users = db_session.exec(user_query).all()
705666
return [BiocommonsUserResponse.from_db_user(user) for user in users]
706667

707668

0 commit comments

Comments
 (0)