Skip to content

Commit 538503b

Browse files
Merge pull request #138 from AustralianBioCommons/AAI-530-calculate-users
feat: add user calculation endpoint and updating pending/revoked logic
2 parents bdd5e86 + b29b517 commit 538503b

File tree

3 files changed

+265
-18
lines changed

3 files changed

+265
-18
lines changed

routers/admin.py

Lines changed: 149 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import logging
23
from datetime import datetime, timezone
34
from typing import Annotated, Any
@@ -86,6 +87,14 @@ def from_db_user(cls, user: BiocommonsUser) -> "BiocommonsUserResponse":
8687
)
8788

8889

90+
class UserCountsResponse(BaseModel):
91+
"""Aggregate counts for different user categories."""
92+
all: int
93+
pending: int
94+
revoked: int
95+
unverified: int
96+
97+
8998
class PaginationParams(BaseModel):
9099
"""
91100
Query parameters for paginated endpoints. Page starts at 1.
@@ -273,6 +282,10 @@ class UserQueryParams(BaseModel):
273282
274283
Each field listed here must have a {field}_query method defined.
275284
"""
285+
approval_status: ApprovalStatusEnum | None = Field(
286+
None,
287+
description="Filter by approval status across platforms and groups",
288+
)
276289
platform: PlatformEnum | None = Field(None, description="Filter by platform")
277290
platform_approval_status: ApprovalStatusEnum | None = Field(None, description="Filter by platform approval status")
278291
group: GroupEnum | None = Field(None, description="Filter by group")
@@ -283,6 +296,8 @@ class UserQueryParams(BaseModel):
283296
description="Filter users by group ('tsi',) or platform ('galaxy', 'bpa_data_portal')"
284297
)
285298
search: str | None = Field(None, description="Search users by username or email")
299+
_allowed_platforms_subquery: SelectOfScalar[Platform] | None = None
300+
_allowed_groups_subquery: SelectOfScalar[BiocommonsGroup] | None = None
286301

287302
def _fields(self):
288303
return (name for name in self.__pydantic_fields__.keys()
@@ -295,6 +310,11 @@ def model_post_init(self, context: Any) -> None:
295310
for field_name in self._fields():
296311
if not hasattr(self, f"{field_name}_query"):
297312
raise NotImplementedError(f"Missing query method for field '{field_name}'")
313+
if self.approval_status and (self.platform_approval_status or self.group_approval_status):
314+
raise HTTPException(
315+
status_code=400,
316+
detail="approval_status cannot be used with platform_approval_status or group_approval_status",
317+
)
298318

299319
def get_base_query(self):
300320
"""
@@ -304,26 +324,41 @@ def get_base_query(self):
304324
select(BiocommonsUser)
305325
)
306326

307-
def get_admin_permissions_query(self, admin_roles: list[str]):
327+
def _set_allowed_resource_subqueries(self, admin_roles: list[str]) -> None:
308328
"""
309-
Get the query for only returning users the admin has permission to view/manage,
310-
based on group/platform roles
329+
Cache allowed platform/group subqueries for reuse within this instance.
330+
"""
331+
allowed_platforms_subquery, allowed_groups_subquery = self.get_allowed_resource_subqueries(admin_roles)
332+
self._allowed_platforms_subquery = allowed_platforms_subquery
333+
self._allowed_groups_subquery = allowed_groups_subquery
334+
335+
def get_allowed_resource_subqueries(self, admin_roles: list[str]):
336+
"""
337+
Return subqueries for platform/group IDs the admin has access to.
311338
"""
312339
allowed_platforms_subquery = (
313340
select(Platform.id)
314341
.join(Platform.admin_roles)
315342
.where(Auth0Role.name.in_(admin_roles))
316343
)
317-
platform_access_condition = BiocommonsUser.id.in_(
318-
select(PlatformMembership.user_id).where(
319-
PlatformMembership.platform_id.in_(allowed_platforms_subquery)
320-
)
321-
)
322344
allowed_groups_subquery = (
323345
select(BiocommonsGroup.group_id)
324346
.join(BiocommonsGroup.admin_roles)
325347
.where(Auth0Role.name.in_(admin_roles))
326348
)
349+
return allowed_platforms_subquery, allowed_groups_subquery
350+
351+
def get_admin_permissions_query(self, admin_roles: list[str]):
352+
"""
353+
Get the query for only returning users the admin has permission to view/manage,
354+
based on group/platform roles
355+
"""
356+
allowed_platforms_subquery, allowed_groups_subquery = self.get_allowed_resource_subqueries(admin_roles)
357+
platform_access_condition = BiocommonsUser.id.in_(
358+
select(PlatformMembership.user_id).where(
359+
PlatformMembership.platform_id.in_(allowed_platforms_subquery)
360+
)
361+
)
327362
group_access_condition = BiocommonsUser.id.in_(
328363
select(GroupMembership.user_id).where(
329364
GroupMembership.group_id.in_(allowed_groups_subquery)
@@ -335,17 +370,18 @@ def get_complete_query(self, admin_roles: list[str], pagination: PaginationParam
335370
"""
336371
Return a full user query, with permissions from admin roles applied
337372
"""
373+
self._set_allowed_resource_subqueries(admin_roles)
338374
return (
339375
self.get_base_query()
340376
.where(
341377
self.get_admin_permissions_query(admin_roles),
342-
*self.get_query_conditions())
378+
*self.get_query_conditions(admin_roles))
343379
.distinct()
344380
.offset(pagination.start_index)
345381
.limit(pagination.per_page)
346382
)
347383

348-
def get_query_conditions(self):
384+
def get_query_conditions(self, admin_roles: list[str] | None = None):
349385
"""
350386
Returns a list of SQLAlchemy queries for the filters that have been set.
351387
The queries can be passed to where().
@@ -358,7 +394,11 @@ def get_query_conditions(self):
358394
if field_value is not None:
359395
method_name = f"{field_name}_query"
360396
query_method = getattr(self, method_name)
361-
condition = query_method()
397+
params = inspect.signature(query_method).parameters
398+
if admin_roles is not None and "admin_roles" in params:
399+
condition = query_method(admin_roles)
400+
else:
401+
condition = query_method()
362402
# conditions may be None for interacting queries like platform
363403
# and platform_approval_status
364404
if condition is not None:
@@ -406,6 +446,10 @@ def platform_approval_status_query(self):
406446
platform_status_query = select(PlatformMembership.user_id).where(
407447
PlatformMembership.approval_status == self.platform_approval_status
408448
)
449+
if self._allowed_platforms_subquery is not None:
450+
platform_status_query = platform_status_query.where(
451+
PlatformMembership.platform_id.in_(self._allowed_platforms_subquery)
452+
)
409453
return BiocommonsUser.id.in_(platform_status_query)
410454

411455
def group_query(self):
@@ -415,11 +459,54 @@ def group_query(self):
415459
return BiocommonsUser.id.in_(group_query)
416460

417461
def group_approval_status_query(self):
462+
"""
463+
Filter by group approval status. This intentionally does not scope the
464+
subquery to the admin's group permissions because get_admin_permissions_query
465+
already enforces visibility. That allows platform admins (who may not be
466+
group admins) to still see group-status results for the users they manage.
467+
"""
418468
group_status_query = select(GroupMembership.user_id).where(
419469
GroupMembership.approval_status == self.group_approval_status
420470
)
421471
return BiocommonsUser.id.in_(group_status_query)
422472

473+
def approval_status_query(self, admin_roles: list[str] | None = None):
474+
"""
475+
Filter by approval status across platforms and groups.
476+
"""
477+
if self._allowed_platforms_subquery is None or self._allowed_groups_subquery is None:
478+
if admin_roles is None:
479+
raise ValueError("Allowed resource subqueries must be set before calling approval_status_query")
480+
self._set_allowed_resource_subqueries(admin_roles)
481+
482+
platform_status_query = select(PlatformMembership.user_id).where(
483+
PlatformMembership.platform_id.in_(self._allowed_platforms_subquery),
484+
PlatformMembership.approval_status == self.approval_status,
485+
)
486+
group_status_query = select(GroupMembership.user_id).where(
487+
GroupMembership.approval_status == self.approval_status,
488+
)
489+
return or_(
490+
BiocommonsUser.id.in_(platform_status_query),
491+
BiocommonsUser.id.in_(group_status_query),
492+
)
493+
494+
def get_count(self, db_session: Session, admin_roles: list[str]) -> int:
495+
"""
496+
Count distinct users matching the current filters and admin permissions.
497+
"""
498+
self._set_allowed_resource_subqueries(admin_roles)
499+
query = (
500+
self.get_base_query()
501+
.where(
502+
self.get_admin_permissions_query(admin_roles),
503+
*self.get_query_conditions(admin_roles),
504+
)
505+
.distinct()
506+
)
507+
count_statement = select(func.count()).select_from(query.subquery())
508+
return db_session.exec(count_statement).one()
509+
423510
def email_verified_query(self):
424511
return BiocommonsUser.email_verified.is_(self.email_verified)
425512

@@ -487,6 +574,47 @@ def get_users(db_session: Annotated[Session, Depends(get_db_session)],
487574
return [BiocommonsUserResponse.from_db_user(user) for user in users]
488575

489576

577+
@router.get(
578+
"/users/counts",
579+
response_model=UserCountsResponse,
580+
)
581+
def get_user_counts(
582+
db_session: Annotated[Session, Depends(get_db_session)],
583+
admin_user: Annotated[SessionUser, Depends(get_session_user)],
584+
query_params: Annotated[UserQueryParams, Depends()],
585+
):
586+
"""
587+
Get aggregate counts for all, pending, revoked, and unverified users.
588+
Applies the same filtering and permission checks as the /users endpoint.
589+
"""
590+
query_params.check_missing_ids(db_session)
591+
admin_roles = admin_user.access_token.biocommons_roles
592+
base_params = query_params.model_dump()
593+
594+
def count_with(overrides: dict[str, object] | None = None) -> int:
595+
params_data = {**base_params, **(overrides or {})}
596+
params = UserQueryParams(**params_data)
597+
return params.get_count(
598+
db_session=db_session,
599+
admin_roles=admin_roles,
600+
)
601+
602+
return UserCountsResponse(
603+
all=count_with(),
604+
pending=count_with({
605+
"approval_status": ApprovalStatusEnum.PENDING,
606+
"platform_approval_status": None,
607+
"group_approval_status": None,
608+
}),
609+
revoked=count_with({
610+
"approval_status": ApprovalStatusEnum.REVOKED,
611+
"platform_approval_status": None,
612+
"group_approval_status": None,
613+
}),
614+
unverified=count_with({"email_verified": False}),
615+
)
616+
617+
490618
# NOTE: This must appear before /users/{user_id} so it takes precedence
491619
@router.get(
492620
"/users/approved",
@@ -510,7 +638,11 @@ def get_pending_users(db_session: Annotated[Session, Depends(get_db_session)],
510638
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
511639
user_query = get_filtered_user_query(
512640
admin_user=admin_user,
513-
user_query=UserQueryParams(platform_approval_status=ApprovalStatusEnum.PENDING),
641+
user_query=UserQueryParams(
642+
approval_status=ApprovalStatusEnum.PENDING,
643+
platform_approval_status=None,
644+
group_approval_status=None,
645+
),
514646
pagination=pagination,
515647
)
516648
users = db_session.exec(user_query).all()
@@ -524,7 +656,11 @@ def get_revoked_users(db_session: Annotated[Session, Depends(get_db_session)],
524656
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
525657
user_query = get_filtered_user_query(
526658
admin_user=admin_user,
527-
user_query=UserQueryParams(platform_approval_status=ApprovalStatusEnum.REVOKED),
659+
user_query=UserQueryParams(
660+
approval_status=ApprovalStatusEnum.REVOKED,
661+
platform_approval_status=None,
662+
group_approval_status=None,
663+
),
528664
pagination=pagination,
529665
)
530666
users = db_session.exec(user_query).all()

tests/db/datagen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def is_deleted(cls) -> bool:
7979
def _create_user_with_platform_membership(db_session: Session, platform_id: PlatformEnum,
8080
approval_status=ApprovalStatusEnum.APPROVED,
8181
commit=True, **kwargs):
82+
if "email_verified" not in kwargs:
83+
kwargs["email_verified"] = True
8284
user = BiocommonsUserFactory.build(**kwargs)
8385
membership = PlatformMembershipFactory.create_sync(
8486
platform_id=platform_id,
@@ -105,6 +107,8 @@ def _users_with_platform_membership(n: int, db_session: Session, platform_id: Pl
105107
def _create_user_with_group_membership(db_session: Session, group_id: str,
106108
approval_status=ApprovalStatusEnum.APPROVED,
107109
commit=True, **kwargs):
110+
if "email_verified" not in kwargs:
111+
kwargs["email_verified"] = True
108112
user = BiocommonsUserFactory.build(**kwargs)
109113
membership = GroupMembershipFactory.create_sync(
110114
group_id=group_id,

0 commit comments

Comments
 (0)