Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 149 additions & 13 deletions routers/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import logging
from datetime import datetime, timezone
from typing import Annotated, Any
Expand Down Expand Up @@ -86,6 +87,14 @@ def from_db_user(cls, user: BiocommonsUser) -> "BiocommonsUserResponse":
)


class UserCountsResponse(BaseModel):
"""Aggregate counts for different user categories."""
all: int
pending: int
revoked: int
unverified: int


class PaginationParams(BaseModel):
"""
Query parameters for paginated endpoints. Page starts at 1.
Expand Down Expand Up @@ -273,6 +282,10 @@ class UserQueryParams(BaseModel):

Each field listed here must have a {field}_query method defined.
"""
approval_status: ApprovalStatusEnum | None = Field(
None,
description="Filter by approval status across platforms and groups",
)
platform: PlatformEnum | None = Field(None, description="Filter by platform")
platform_approval_status: ApprovalStatusEnum | None = Field(None, description="Filter by platform approval status")
group: GroupEnum | None = Field(None, description="Filter by group")
Expand All @@ -283,6 +296,8 @@ class UserQueryParams(BaseModel):
description="Filter users by group ('tsi',) or platform ('galaxy', 'bpa_data_portal')"
)
search: str | None = Field(None, description="Search users by username or email")
_allowed_platforms_subquery: SelectOfScalar[Platform] | None = None
_allowed_groups_subquery: SelectOfScalar[BiocommonsGroup] | None = None

def _fields(self):
return (name for name in self.__pydantic_fields__.keys()
Expand All @@ -295,6 +310,11 @@ def model_post_init(self, context: Any) -> None:
for field_name in self._fields():
if not hasattr(self, f"{field_name}_query"):
raise NotImplementedError(f"Missing query method for field '{field_name}'")
if self.approval_status and (self.platform_approval_status or self.group_approval_status):
raise HTTPException(
status_code=400,
detail="approval_status cannot be used with platform_approval_status or group_approval_status",
)

def get_base_query(self):
"""
Expand All @@ -304,26 +324,41 @@ def get_base_query(self):
select(BiocommonsUser)
)

def get_admin_permissions_query(self, admin_roles: list[str]):
def _set_allowed_resource_subqueries(self, admin_roles: list[str]) -> None:
"""
Get the query for only returning users the admin has permission to view/manage,
based on group/platform roles
Cache allowed platform/group subqueries for reuse within this instance.
"""
allowed_platforms_subquery, allowed_groups_subquery = self.get_allowed_resource_subqueries(admin_roles)
self._allowed_platforms_subquery = allowed_platforms_subquery
self._allowed_groups_subquery = allowed_groups_subquery

def get_allowed_resource_subqueries(self, admin_roles: list[str]):
"""
Return subqueries for platform/group IDs the admin has access to.
"""
allowed_platforms_subquery = (
select(Platform.id)
.join(Platform.admin_roles)
.where(Auth0Role.name.in_(admin_roles))
)
platform_access_condition = BiocommonsUser.id.in_(
select(PlatformMembership.user_id).where(
PlatformMembership.platform_id.in_(allowed_platforms_subquery)
)
)
allowed_groups_subquery = (
select(BiocommonsGroup.group_id)
.join(BiocommonsGroup.admin_roles)
.where(Auth0Role.name.in_(admin_roles))
)
return allowed_platforms_subquery, allowed_groups_subquery

def get_admin_permissions_query(self, admin_roles: list[str]):
"""
Get the query for only returning users the admin has permission to view/manage,
based on group/platform roles
"""
allowed_platforms_subquery, allowed_groups_subquery = self.get_allowed_resource_subqueries(admin_roles)
platform_access_condition = BiocommonsUser.id.in_(
select(PlatformMembership.user_id).where(
PlatformMembership.platform_id.in_(allowed_platforms_subquery)
)
)
group_access_condition = BiocommonsUser.id.in_(
select(GroupMembership.user_id).where(
GroupMembership.group_id.in_(allowed_groups_subquery)
Expand All @@ -335,17 +370,18 @@ def get_complete_query(self, admin_roles: list[str], pagination: PaginationParam
"""
Return a full user query, with permissions from admin roles applied
"""
self._set_allowed_resource_subqueries(admin_roles)
return (
self.get_base_query()
.where(
self.get_admin_permissions_query(admin_roles),
*self.get_query_conditions())
*self.get_query_conditions(admin_roles))
.distinct()
.offset(pagination.start_index)
.limit(pagination.per_page)
)

def get_query_conditions(self):
def get_query_conditions(self, admin_roles: list[str] | None = None):
"""
Returns a list of SQLAlchemy queries for the filters that have been set.
The queries can be passed to where().
Expand All @@ -358,7 +394,11 @@ def get_query_conditions(self):
if field_value is not None:
method_name = f"{field_name}_query"
query_method = getattr(self, method_name)
condition = query_method()
params = inspect.signature(query_method).parameters
if admin_roles is not None and "admin_roles" in params:
condition = query_method(admin_roles)
else:
condition = query_method()
# conditions may be None for interacting queries like platform
# and platform_approval_status
if condition is not None:
Expand Down Expand Up @@ -406,6 +446,10 @@ def platform_approval_status_query(self):
platform_status_query = select(PlatformMembership.user_id).where(
PlatformMembership.approval_status == self.platform_approval_status
)
if self._allowed_platforms_subquery is not None:
platform_status_query = platform_status_query.where(
PlatformMembership.platform_id.in_(self._allowed_platforms_subquery)
)
return BiocommonsUser.id.in_(platform_status_query)

def group_query(self):
Expand All @@ -415,11 +459,54 @@ def group_query(self):
return BiocommonsUser.id.in_(group_query)

def group_approval_status_query(self):
"""
Filter by group approval status. This intentionally does not scope the
subquery to the admin's group permissions because get_admin_permissions_query
already enforces visibility. That allows platform admins (who may not be
group admins) to still see group-status results for the users they manage.
"""
group_status_query = select(GroupMembership.user_id).where(
GroupMembership.approval_status == self.group_approval_status
)
return BiocommonsUser.id.in_(group_status_query)

def approval_status_query(self, admin_roles: list[str] | None = None):
"""
Filter by approval status across platforms and groups.
"""
if self._allowed_platforms_subquery is None or self._allowed_groups_subquery is None:
if admin_roles is None:
raise ValueError("Allowed resource subqueries must be set before calling approval_status_query")
self._set_allowed_resource_subqueries(admin_roles)

platform_status_query = select(PlatformMembership.user_id).where(
PlatformMembership.platform_id.in_(self._allowed_platforms_subquery),
PlatformMembership.approval_status == self.approval_status,
)
group_status_query = select(GroupMembership.user_id).where(
GroupMembership.approval_status == self.approval_status,
)
return or_(
BiocommonsUser.id.in_(platform_status_query),
BiocommonsUser.id.in_(group_status_query),
)

def get_count(self, db_session: Session, admin_roles: list[str]) -> int:
"""
Count distinct users matching the current filters and admin permissions.
"""
self._set_allowed_resource_subqueries(admin_roles)
query = (
self.get_base_query()
.where(
self.get_admin_permissions_query(admin_roles),
*self.get_query_conditions(admin_roles),
)
.distinct()
)
count_statement = select(func.count()).select_from(query.subquery())
return db_session.exec(count_statement).one()

def email_verified_query(self):
return BiocommonsUser.email_verified.is_(self.email_verified)

Expand Down Expand Up @@ -487,6 +574,47 @@ def get_users(db_session: Annotated[Session, Depends(get_db_session)],
return [BiocommonsUserResponse.from_db_user(user) for user in users]


@router.get(
"/users/counts",
response_model=UserCountsResponse,
)
def get_user_counts(
db_session: Annotated[Session, Depends(get_db_session)],
admin_user: Annotated[SessionUser, Depends(get_session_user)],
query_params: Annotated[UserQueryParams, Depends()],
):
"""
Get aggregate counts for all, pending, revoked, and unverified users.
Applies the same filtering and permission checks as the /users endpoint.
"""
query_params.check_missing_ids(db_session)
admin_roles = admin_user.access_token.biocommons_roles
base_params = query_params.model_dump()

def count_with(overrides: dict[str, object] | None = None) -> int:
params_data = {**base_params, **(overrides or {})}
params = UserQueryParams(**params_data)
return params.get_count(
db_session=db_session,
admin_roles=admin_roles,
)

return UserCountsResponse(
all=count_with(),
pending=count_with({
"approval_status": ApprovalStatusEnum.PENDING,
"platform_approval_status": None,
"group_approval_status": None,
}),
revoked=count_with({
"approval_status": ApprovalStatusEnum.REVOKED,
"platform_approval_status": None,
"group_approval_status": None,
}),
unverified=count_with({"email_verified": False}),
)


# NOTE: This must appear before /users/{user_id} so it takes precedence
@router.get(
"/users/approved",
Expand All @@ -510,7 +638,11 @@ def get_pending_users(db_session: Annotated[Session, Depends(get_db_session)],
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
user_query = get_filtered_user_query(
admin_user=admin_user,
user_query=UserQueryParams(platform_approval_status=ApprovalStatusEnum.PENDING),
user_query=UserQueryParams(
approval_status=ApprovalStatusEnum.PENDING,
platform_approval_status=None,
group_approval_status=None,
),
pagination=pagination,
)
users = db_session.exec(user_query).all()
Expand All @@ -524,7 +656,11 @@ def get_revoked_users(db_session: Annotated[Session, Depends(get_db_session)],
pagination: Annotated[PaginationParams, Depends(get_pagination_params)]):
user_query = get_filtered_user_query(
admin_user=admin_user,
user_query=UserQueryParams(platform_approval_status=ApprovalStatusEnum.REVOKED),
user_query=UserQueryParams(
approval_status=ApprovalStatusEnum.REVOKED,
platform_approval_status=None,
group_approval_status=None,
),
pagination=pagination,
)
users = db_session.exec(user_query).all()
Expand Down
4 changes: 4 additions & 0 deletions tests/db/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def is_deleted(cls) -> bool:
def _create_user_with_platform_membership(db_session: Session, platform_id: PlatformEnum,
approval_status=ApprovalStatusEnum.APPROVED,
commit=True, **kwargs):
if "email_verified" not in kwargs:
kwargs["email_verified"] = True
user = BiocommonsUserFactory.build(**kwargs)
membership = PlatformMembershipFactory.create_sync(
platform_id=platform_id,
Expand All @@ -105,6 +107,8 @@ def _users_with_platform_membership(n: int, db_session: Session, platform_id: Pl
def _create_user_with_group_membership(db_session: Session, group_id: str,
approval_status=ApprovalStatusEnum.APPROVED,
commit=True, **kwargs):
if "email_verified" not in kwargs:
kwargs["email_verified"] = True
user = BiocommonsUserFactory.build(**kwargs)
membership = GroupMembershipFactory.create_sync(
group_id=group_id,
Expand Down
Loading