Skip to content

Commit 7fbcc5b

Browse files
authored
Generic entity endpoint (#290)
1 parent c7f19fa commit 7fbcc5b

File tree

9 files changed

+135
-11
lines changed

9 files changed

+135
-11
lines changed

app/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class Settings(BaseSettings):
3939
LOG_CATCH: bool = True
4040
LOG_STANDARD_LOGGER: dict[str, str] = {"root": "INFO"}
4141

42-
KEYCLOAK_URL: str = "https://example.openbraininstitute.org/auth/realms/SBO"
42+
KEYCLOAK_URL: str = "https://staging.openbraininstitute.org/auth/realms/SBO"
4343
AUTH_CACHE_MAXSIZE: int = 128 # items
4444
AUTH_CACHE_MAX_TTL: int = 300 # seconds
4545
AUTH_CACHE_INFO: bool = False

app/dependencies/auth.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,14 @@ def _check_user_info(
127127
)
128128

129129
user_info_response = deserialize_response(response, model_class=UserInfoResponse)
130+
130131
is_authorized = user_info_response.is_authorized_for(
131132
virtual_lab_id=project_context.virtual_lab_id,
132133
project_id=project_context.project_id,
133134
)
135+
134136
is_service_admin = user_info_response.is_service_admin(settings.APP_NAME)
137+
135138
user_context = UserContext(
136139
profile=UserProfile.from_user_info(user_info_response),
137140
expiration=decoded.exp if decoded else None,
@@ -140,6 +143,7 @@ def _check_user_info(
140143
virtual_lab_id=project_context.virtual_lab_id,
141144
project_id=project_context.project_id,
142145
auth_error_reason=AuthErrorReason.NOT_AUTHORIZED_PROJECT if not is_authorized else None,
146+
user_project_ids=user_info_response.user_project_ids(),
143147
)
144148

145149
if not user_context.is_authorized:

app/routers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252

5353
router = APIRouter()
5454
router.include_router(root.router)
55+
56+
5557
authenticated_routers = [
5658
asset.router,
5759
brain_atlas.router,

app/routers/entity.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Entity router."""
22

33
from typing import Annotated
4+
from uuid import UUID
45

56
from fastapi import APIRouter, Query
67

78
from app.db.utils import EntityTypeWithBrainRegion
89
from app.dependencies.auth import UserContextDep
910
from app.dependencies.common import InBrainRegionDep
1011
from app.dependencies.db import SessionDep
11-
from app.schemas.entity import EntityCountRead
12+
from app.schemas.entity import EntityCountRead, EntityRead
1213
from app.service import entity as entity_service
1314

1415
router = APIRouter(
@@ -36,3 +37,12 @@ def count_entities_by_type(
3637
entity_types=types,
3738
in_brain_region=in_brain_region,
3839
)
40+
41+
42+
@router.get("/{id_}")
43+
def read_one(
44+
id_: UUID,
45+
user_context: UserContextDep,
46+
db: SessionDep,
47+
) -> EntityRead:
48+
return entity_service.read_one(id_, db, user_context)

app/schemas/auth.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Self
23
from uuid import UUID
34

@@ -105,6 +106,19 @@ def is_authorized_for(self, virtual_lab_id: UUID | None, project_id: UUID | None
105106
]
106107
)
107108

109+
def user_project_ids(self) -> list[UUID]:
110+
"""Return the the list if project_ids the user is authorized for."""
111+
pattern = r"/proj/[0-9a-fA-F-]+/([0-9a-fA-F-]+)/(admin|member)"
112+
113+
project_ids: set[UUID] = set()
114+
115+
for s in self.groups:
116+
match = re.match(pattern, s)
117+
if match:
118+
project_ids.add(UUID(match.group(1)))
119+
120+
return list(project_ids)
121+
108122

109123
class UserProfile(BaseModel):
110124
"""User profile representing a keycloak user."""
@@ -145,6 +159,7 @@ class UserContext(UserContextBase):
145159

146160
virtual_lab_id: UUID | None = None
147161
project_id: UUID | None = None
162+
user_project_ids: list[UUID] = []
148163

149164

150165
class UserContextWithProjectId(UserContextBase):

app/service/entity.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import uuid
22

33
import sqlalchemy as sa
4+
from sqlalchemy.orm import Session
45

56
import app.queries.entity
67
from app.db.auth import constrain_to_accessible_entities
@@ -10,13 +11,12 @@
1011
ENTITY_TYPE_TO_CLASS,
1112
EntityTypeWithBrainRegion,
1213
)
13-
from app.dependencies.auth import UserContextDep
1414
from app.dependencies.common import InBrainRegionDep
15-
from app.dependencies.db import SessionDep
15+
from app.errors import ensure_result
1616
from app.filters.brain_region import get_family_query
1717
from app.repository.group import RepositoryGroup
1818
from app.schemas.auth import UserContext, UserContextWithProjectId
19-
from app.schemas.entity import EntityCountRead
19+
from app.schemas.entity import EntityCountRead, EntityRead
2020

2121

2222
def get_readable_entity(
@@ -55,8 +55,8 @@ def get_writable_entity(
5555

5656
def count_entities_by_type(
5757
*,
58-
user_context: UserContextDep,
59-
db: SessionDep,
58+
user_context: UserContext,
59+
db: Session,
6060
entity_types: list[EntityTypeWithBrainRegion],
6161
in_brain_region: InBrainRegionDep,
6262
) -> EntityCountRead:
@@ -114,3 +114,20 @@ def count_entities_by_type(
114114
results = EntityCountRead.model_validate(data)
115115

116116
return results
117+
118+
119+
def read_one(
120+
id_: uuid.UUID,
121+
db: Session,
122+
user_context: UserContext,
123+
) -> EntityRead:
124+
with ensure_result(f"Entity {id_} not found or forbidden"):
125+
query = sa.select(Entity).where(
126+
Entity.id == id_,
127+
sa.or_(
128+
Entity.authorized_public.is_(True),
129+
Entity.authorized_project_id.in_(user_context.user_project_ids),
130+
),
131+
)
132+
row = db.execute(query).unique().scalar_one()
133+
return EntityRead.model_validate(row)

tests/conftest.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,9 @@ def user_context_user_1():
118118
expiration=None,
119119
is_authorized=True,
120120
is_service_admin=False,
121-
virtual_lab_id=VIRTUAL_LAB_ID,
122-
project_id=PROJECT_ID,
121+
virtual_lab_id=UUID(VIRTUAL_LAB_ID),
122+
project_id=UUID(PROJECT_ID),
123+
user_project_ids=[UUID(PROJECT_ID)],
123124
)
124125

125126

@@ -134,8 +135,8 @@ def user_context_user_2():
134135
expiration=None,
135136
is_authorized=True,
136137
is_service_admin=False,
137-
virtual_lab_id=UNRELATED_VIRTUAL_LAB_ID,
138-
project_id=UNRELATED_PROJECT_ID,
138+
virtual_lab_id=UUID(UNRELATED_VIRTUAL_LAB_ID),
139+
project_id=UUID(UNRELATED_PROJECT_ID),
139140
)
140141

141142

tests/test_auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def test_user_verified_ok(httpx_mock, request_mock, is_admin, is_token_jwt, proj
9696
is_service_admin=is_admin,
9797
virtual_lab_id=project_context.virtual_lab_id,
9898
project_id=project_context.project_id,
99+
user_project_ids=[uuid.UUID(PROJECT_ID)],
99100
)
100101

101102

tests/test_entity.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from app.db.model import IonChannelModel
2+
from app.schemas.morphology import ReconstructionMorphologyRead
23

34
from .utils import (
45
PROJECT_ID,
@@ -11,6 +12,79 @@
1112
ROUTE = "/entity"
1213

1314

15+
def test_get_entity(client, brain_region_id, species_id, strain_id, license_id):
16+
morph = assert_request(
17+
client.post,
18+
url="/reconstruction-morphology",
19+
json={
20+
"brain_region_id": str(brain_region_id),
21+
"species_id": str(species_id),
22+
"strain_id": str(strain_id),
23+
"description": "Test morph",
24+
"name": "Test morph",
25+
"location": {"x": 10, "y": 20, "z": 30},
26+
"legacy_id": ["Test Legacy ID"],
27+
"license_id": str(license_id),
28+
},
29+
).json()
30+
31+
data = assert_request(client.get, url=f"{ROUTE}/{morph['id']}").json()
32+
33+
assert data["type"] == "reconstruction_morphology"
34+
35+
36+
def test_get_entity_no_auth(
37+
client, client_user_2, brain_region_id, species_id, strain_id, license_id
38+
):
39+
morph = assert_request(
40+
client_user_2.post,
41+
url="/reconstruction-morphology",
42+
json={
43+
"brain_region_id": str(brain_region_id),
44+
"species_id": str(species_id),
45+
"strain_id": str(strain_id),
46+
"description": "Test morph",
47+
"name": "Test morph",
48+
"location": {"x": 10, "y": 20, "z": 30},
49+
"legacy_id": ["Test Legacy ID"],
50+
"license_id": str(license_id),
51+
},
52+
).json()
53+
54+
res = client.get(url=f"{ROUTE}/{morph['id']}")
55+
56+
assert res.status_code == 404
57+
58+
59+
def test_public_unrelated_project_accessible(
60+
client, client_user_2, brain_region_id, species_id, strain_id, license_id
61+
):
62+
morph = assert_request(
63+
client_user_2.post,
64+
url="/reconstruction-morphology",
65+
json={
66+
"authorized_public": True,
67+
"brain_region_id": str(brain_region_id),
68+
"species_id": str(species_id),
69+
"strain_id": str(strain_id),
70+
"description": "Test morph",
71+
"name": "Test morph",
72+
"location": {"x": 10, "y": 20, "z": 30},
73+
"legacy_id": ["Test Legacy ID"],
74+
"license_id": str(license_id),
75+
},
76+
).json()
77+
78+
data = assert_request(client.get, url=f"{ROUTE}/{morph['id']}").json()
79+
assert data["type"] == "reconstruction_morphology"
80+
81+
morph_detail = assert_request(
82+
client.get, url=f"/reconstruction-morphology/{morph['id']}"
83+
).json()
84+
85+
assert ReconstructionMorphologyRead.model_validate(morph_detail)
86+
87+
1488
def test_count_entities_validation_errors(client):
1589
"""Test validation errors for the count endpoint."""
1690
response = client.get(f"{ROUTE}/counts")

0 commit comments

Comments
 (0)