Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions app/db/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,67 @@
from sqlalchemy.orm import Query

from app.db.model import Entity
from app.db.utils import get_declaring_class
from app.schemas.auth import UserContext


def constrain_to_accessible_entities[Q: Query | Select](
query: Q,
project_id: UUID4 | None,
user_context: UserContext | None,
db_model_class: Any = Entity,
) -> Q:
"""Ensure a query is filtered to rows that are viewable by the user."""
query = query.where(
if not user_context: # admin or global resource
return query

# if model or alias has an authorized_project_id use it as is
if hasattr(db_model_class, "authorized_project_id"):
id_model_class = db_model_class
# otherwise look up the hierarchy to check if there is one defined there
else:
id_model_class = get_declaring_class(db_model_class, "authorized_project_id")
# global resource without authorized_project_id, always accessible
if not id_model_class:
return query

# if user passes a specific project_id, use it to constrain resources
if user_context.project_id:
return query.where(
or_(
id_model_class.authorized_public == true(),
id_model_class.authorized_project_id == user_context.project_id,
)
)

# otherwise use user_project_ids from token to check if user has access
return query.where(
or_(
db_model_class.authorized_public == true(),
db_model_class.authorized_project_id == project_id if project_id else false(),
id_model_class.authorized_public == true(),
id_model_class.authorized_project_id.in_(user_context.user_project_ids),
)
)

return query


def constrain_to_private_entities[Q: Query | Select](
query: Q,
user_context: UserContext,
db_model_class: Any = Entity,
) -> Q:
"""Ensure a query is filtered to private rows that are viewable by the user."""
# if user passes a specific project_id, use it to constrain resources
if user_context.project_id:
return query.where(
and_(
db_model_class.authorized_public == false(),
db_model_class.authorized_project_id == user_context.project_id,
)
)

# otherwise use project_ids from token to check if user has access
return query.where(
and_(
db_model_class.authorized_public == false(),
db_model_class.authorized_project_id.in_(user_context.user_project_ids)
if user_context.user_project_ids
else false(),
db_model_class.authorized_project_id.in_(user_context.user_project_ids),
)
)

Expand Down
37 changes: 11 additions & 26 deletions app/queries/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def router_read_one[T: BaseModel, I: Identifiable](
id_: uuid.UUID,
db: Session,
db_model_class: type[I],
authorized_project_id: uuid.UUID | None,
user_context: UserContext | None,
response_schema_class: type[T],
apply_operations: ApplyOperations[I] | None,
) -> T:
Expand All @@ -49,20 +49,15 @@ def router_read_one[T: BaseModel, I: Identifiable](
id_: id of the entity to read.
db: database session.
db_model_class: database model class.
authorized_project_id: id of the authorized project.
user_context: the user's context
response_schema_class: Pydantic schema class for the returned data.
apply_operations: transformer function that modifies the select query.

Returns:
the model data as a Pydantic model.
"""
query = sa.select(db_model_class).where(db_model_class.id == id_)
if authorized_project_id and (
id_model_class := get_declaring_class(db_model_class, "authorized_project_id")
):
query = constrain_to_accessible_entities(
query, authorized_project_id, db_model_class=id_model_class
)
query = constrain_to_accessible_entities(query, user_context, db_model_class)
if apply_operations:
query = apply_operations(query)
with ensure_result(error_message=f"{db_model_class.__name__} not found"):
Expand Down Expand Up @@ -235,7 +230,7 @@ def router_read_many[T: BaseModel, I: Identifiable]( # noqa: PLR0913
*,
db: Session,
db_model_class: type[I],
authorized_project_id: uuid.UUID | None,
user_context: UserContext | None,
with_search: Search[I] | None,
with_in_brain_region: InBrainRegionQuery | None,
facets: WithFacets | None,
Expand All @@ -254,7 +249,7 @@ def router_read_many[T: BaseModel, I: Identifiable]( # noqa: PLR0913
Args:
db: database session.
db_model_class: database model class.
authorized_project_id: project id for filtering the resources.
user_context: the user's context
with_search: search query (str).
with_in_brain_region: enable family queries based on BrainRegion
facets: facet query (bool).
Expand All @@ -274,12 +269,7 @@ def router_read_many[T: BaseModel, I: Identifiable]( # noqa: PLR0913
the list of model data, pagination, and facets as a Pydantic model.
"""
filter_query = sa.select(db_model_class)
if id_model_class := get_declaring_class(db_model_class, "authorized_project_id"):
filter_query = constrain_to_accessible_entities(
filter_query,
project_id=authorized_project_id,
db_model_class=id_model_class,
)
filter_query = constrain_to_accessible_entities(filter_query, user_context, db_model_class)

if apply_filter_query_operations:
filter_query = apply_filter_query_operations(filter_query)
Expand Down Expand Up @@ -387,23 +377,18 @@ def router_delete_one[T: BaseModel, I: Identifiable](
id_: uuid.UUID,
db: Session,
db_model_class: type[I],
authorized_project_id: uuid.UUID | None,
user_context: UserContext | None,
) -> dict:
"""Delete a model from the database.

Args:
id_: id of the entity to read.
db: database session.
db_model_class: database model class.
authorized_project_id: project id for filtering the resources.
user_context: the user's context
"""
query = sa.select(db_model_class).where(db_model_class.id == id_)
if authorized_project_id and (
id_model_class := get_declaring_class(db_model_class, "authorized_project_id")
):
query = constrain_to_accessible_entities(
query, authorized_project_id, db_model_class=id_model_class
)
query = constrain_to_accessible_entities(query, user_context, db_model_class)

with ensure_result(error_message=f"{db_model_class.__name__} not found"):
obj = db.execute(query).scalars().one()
Expand All @@ -427,15 +412,15 @@ def router_update_activity_one[T: BaseModel, I: Activity](
id_: uuid.UUID,
db: Session,
db_model_class: type[I],
user_context: UserContext | UserContextWithProjectId,
user_context: UserContext,
json_model: ActivityUpdate,
response_schema_class: type[T],
apply_operations: ApplyOperations | None = None,
) -> T:
query = sa.select(db_model_class).where(db_model_class.id == id_)
if id_model_class := get_declaring_class(db_model_class, "authorized_project_id"):
query = constrain_to_accessible_entities(
query, user_context.project_id, db_model_class=id_model_class
query, user_context=user_context, db_model_class=id_model_class
)
if apply_operations:
query = apply_operations(query)
Expand Down
7 changes: 4 additions & 3 deletions app/queries/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,29 @@
from app.db.auth import constrain_entity_query_to_project, constrain_to_accessible_entities
from app.db.model import Entity
from app.errors import ensure_result
from app.schemas.auth import UserContext


def get_readable_entity[T: Entity](
db: Session,
db_model_class: type[T],
entity_id: uuid.UUID,
project_id: uuid.UUID | None,
user_context: UserContext | None,
) -> T:
"""Return a specific entity by type and id, readable by the given project.

Args:
db: db session.
db_model_class: Entity subclass.
entity_id: id of the entity.
project_id: optional project id owning the entity.
user_context: optional user context

Returns:
the selected entity if it's public or owned by project_id,
or raises NoResultFound if the entity doesn't exist, or it's forbidden.
"""
query = sa.select(db_model_class).where(db_model_class.id == entity_id)
query = constrain_to_accessible_entities(query, project_id=project_id)
query = constrain_to_accessible_entities(query, user_context=user_context)
with ensure_result(f"Entity {db_model_class.__name__} {entity_id} not found or forbidden"):
return db.execute(query).scalar_one()

Expand Down
2 changes: 1 addition & 1 deletion app/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def delete_one(
id_=id_,
db=db,
db_model_class=RESOURCE_TYPE_TO_CLASS[resource_type],
authorized_project_id=None,
user_context=None,
)


Expand Down
11 changes: 9 additions & 2 deletions app/routers/ion_channel_recording.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from fastapi import APIRouter

import app.service.ion_channel_recording
from app.routers.admin import router as admin_router

ROUTE = "ion-channel-recording"

router = APIRouter(
prefix="/ion-channel-recording",
tags=["ion-channel-recording"],
prefix=f"/{ROUTE}",
tags=[ROUTE],
)

read_many = router.get("")(app.service.ion_channel_recording.read_many)
read_one = router.get("/{id_}")(app.service.ion_channel_recording.read_one)
create_one = router.post("")(app.service.ion_channel_recording.create_one)
update_one = router.patch("/{id_}")(app.service.ion_channel_recording.update_one)

admin_read_one = admin_router.get(f"/{ROUTE}/{{id_}}")(
app.service.ion_channel_recording.admin_read_one
)
2 changes: 1 addition & 1 deletion app/service/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_entity_assets(
return router_read_many(
db=repos.db,
db_model_class=db_model_class,
authorized_project_id=None,
user_context=None,
with_search=None,
with_in_brain_region=None,
facets=None,
Expand Down
2 changes: 1 addition & 1 deletion app/service/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_entity_assets(
return router_read_many(
db=repos.db,
db_model_class=db_model_class,
authorized_project_id=user_context.project_id,
user_context=user_context,
with_search=None,
with_in_brain_region=None,
facets=None,
Expand Down
10 changes: 5 additions & 5 deletions app/service/brain_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def read_many(
return app.queries.common.router_read_many(
db=db,
db_model_class=BrainAtlas,
authorized_project_id=user_context.project_id,
user_context=user_context,
with_search=None,
with_in_brain_region=None,
facets=None,
Expand All @@ -52,7 +52,7 @@ def read_one(user_context: UserContextDep, atlas_id: uuid.UUID, db: SessionDep)
id_=atlas_id,
db=db,
db_model_class=BrainAtlas,
authorized_project_id=user_context.project_id,
user_context=user_context,
response_schema_class=BrainAtlasRead,
apply_operations=_load_brain_atlas,
)
Expand All @@ -63,7 +63,7 @@ def admin_read_one(db: SessionDep, atlas_id: uuid.UUID) -> BrainAtlasRead:
id_=atlas_id,
db=db,
db_model_class=BrainAtlas,
authorized_project_id=None,
user_context=None,
response_schema_class=BrainAtlasRead,
apply_operations=_load_brain_atlas,
)
Expand All @@ -79,7 +79,7 @@ def read_many_region(
return app.queries.common.router_read_many(
db=db,
db_model_class=BrainAtlasRegion,
authorized_project_id=user_context.project_id,
user_context=user_context,
with_search=None,
with_in_brain_region=None,
facets=None,
Expand All @@ -102,7 +102,7 @@ def read_one_region(
id_=atlas_region_id,
db=db,
db_model_class=BrainAtlasRegion,
authorized_project_id=user_context.project_id,
user_context=user_context,
response_schema_class=BrainAtlasRegionRead,
apply_operations=lambda select: select.filter(
BrainAtlasRegion.brain_atlas_id == atlas_id
Expand Down
2 changes: 1 addition & 1 deletion app/service/brain_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def read_many(
return app.queries.common.router_read_many(
db=db,
db_model_class=BrainRegion,
authorized_project_id=None,
user_context=None,
with_search=None,
with_in_brain_region=None,
facets=None,
Expand Down
4 changes: 2 additions & 2 deletions app/service/brain_region_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def read_many(
return app.queries.common.router_read_many(
db=db,
db_model_class=BrainRegionHierarchy,
authorized_project_id=None,
user_context=None,
with_search=None,
with_in_brain_region=None,
facets=None,
Expand All @@ -50,7 +50,7 @@ def read_one(id_: uuid.UUID, db: SessionDep) -> BrainRegionHierarchyRead:
id_=id_,
db=db,
db_model_class=BrainRegionHierarchy,
authorized_project_id=None,
user_context=None,
response_schema_class=BrainRegionHierarchyRead,
apply_operations=_load,
)
Expand Down
Loading
Loading