Skip to content

Commit b8b5b86

Browse files
Example admin update endpoint
1 parent e3c47c9 commit b8b5b86

File tree

6 files changed

+120
-13
lines changed

6 files changed

+120
-13
lines changed

app/queries/common.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,15 +334,17 @@ def router_update_one[T: BaseModel, I: Identifiable](
334334
id_: uuid.UUID,
335335
db: Session,
336336
db_model_class: type[I],
337-
user_context: UserContext,
337+
user_context: UserContext | None,
338338
json_model: BaseModel,
339339
response_schema_class: type[T],
340340
apply_operations: ApplyOperations | None = None,
341341
):
342342
query = (
343343
sa.select(db_model_class).where(db_model_class.id == id_).with_for_update(of=db_model_class)
344344
)
345-
if id_model_class := get_declaring_class(db_model_class, "authorized_project_id"):
345+
if user_context and (
346+
id_model_class := get_declaring_class(db_model_class, "authorized_project_id")
347+
):
346348
query = constrain_to_private_entities(query, user_context, db_model_class=id_model_class)
347349
if apply_operations:
348350
query = apply_operations(query)
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
from fastapi import APIRouter
22

33
import app.service.electrical_cell_recording
4+
from app.routers.admin import router as admin_router
5+
6+
ROUTE = "electrical-cell-recording"
47

58
router = APIRouter(
6-
prefix="/electrical-cell-recording",
7-
tags=["electrical-cell-recording"],
9+
prefix=f"/{ROUTE}",
10+
tags=[ROUTE],
811
)
912

1013
read_many = router.get("")(app.service.electrical_cell_recording.read_many)
1114
read_one = router.get("/{id_}")(app.service.electrical_cell_recording.read_one)
1215
create_one = router.post("")(app.service.electrical_cell_recording.create_one)
1316
update_one = router.patch("/{id_}")(app.service.electrical_cell_recording.update_one)
17+
18+
admin_update_one = admin_router.patch(f"/{ROUTE}/{{id_}}")(
19+
app.service.electrical_cell_recording.admin_update_one
20+
)

app/schemas/electrical_cell_recording.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,15 @@ class ElectricalCellRecordingCreate(ElectricalCellRecordingBase, ScientificArtif
7272
pass
7373

7474

75-
ElectricalCellRecordingUpdate = make_update_schema(
76-
ElectricalCellRecordingCreate, "ElectricalCellRecordingUpdate"
75+
ElectricalCellRecordingUserUpdate = make_update_schema(
76+
ElectricalCellRecordingCreate,
77+
"ElectricalCellRecordingUserUpdate",
78+
) # pyright : ignore [reportInvalidTypeForm]
79+
80+
ElectricalCellRecordingAdminUpdate = make_update_schema(
81+
ElectricalCellRecordingCreate,
82+
"ElectricalCellRecordingAdminUpdate",
83+
excluded_fields=set(),
7784
) # pyright : ignore [reportInvalidTypeForm]
7885

7986

app/schemas/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
}
99

1010

11-
def make_update_schema(schema: type[BaseModel], new_schema_name: str):
11+
def make_update_schema(
12+
schema: type[BaseModel],
13+
new_schema_name: str | None = None,
14+
excluded_fields: set = EXCLUDED_FIELDS,
15+
):
1216
"""Create a new pydantic schema from current schema where all fields are optional.
1317
1418
In order to differentiate between the user providing a None value and an actual not set by the
@@ -27,6 +31,6 @@ def make_optional(field):
2731
fields = {
2832
name: make_optional(field)
2933
for name, field in schema.model_fields.items()
30-
if name not in EXCLUDED_FIELDS
34+
if name not in excluded_fields
3135
}
3236
return create_model(new_schema_name, **fields) # pyright: ignore reportArgumentType

app/service/electrical_cell_recording.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@
2727
)
2828
from app.queries.factory import query_params_factory
2929
from app.schemas.electrical_cell_recording import (
30+
ElectricalCellRecordingAdminUpdate,
3031
ElectricalCellRecordingCreate,
3132
ElectricalCellRecordingRead,
32-
ElectricalCellRecordingUpdate,
33+
ElectricalCellRecordingUserUpdate,
3334
)
3435
from app.schemas.types import ListResponse
3536

@@ -153,7 +154,7 @@ def update_one(
153154
user_context: UserContextDep,
154155
db: SessionDep,
155156
id_: uuid.UUID,
156-
json_model: ElectricalCellRecordingUpdate, # pyright: ignore [reportInvalidTypeForm]
157+
json_model: ElectricalCellRecordingUserUpdate, # pyright: ignore [reportInvalidTypeForm]
157158
) -> ElectricalCellRecordingRead:
158159
return router_update_one(
159160
id_=id_,
@@ -164,3 +165,19 @@ def update_one(
164165
response_schema_class=ElectricalCellRecordingRead,
165166
apply_operations=_load,
166167
)
168+
169+
170+
def admin_update_one(
171+
db: SessionDep,
172+
id_: uuid.UUID,
173+
json_model: ElectricalCellRecordingAdminUpdate, # pyright: ignore [reportInvalidTypeForm]
174+
) -> ElectricalCellRecordingRead:
175+
return router_update_one(
176+
id_=id_,
177+
db=db,
178+
db_model_class=ElectricalCellRecording,
179+
user_context=None,
180+
json_model=json_model,
181+
response_schema_class=ElectricalCellRecordingRead,
182+
apply_operations=_load,
183+
)

tests/test_electrical_cell_recording.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_create_one(
5656
assert data["etypes"] == []
5757

5858

59-
def test_update_one(client, trace_id_with_assets):
59+
def test_update_one(client, client_admin, trace_id_with_assets):
6060
new_name = "my_new_name"
6161
new_description = "my_new_description"
6262

@@ -92,8 +92,64 @@ def test_update_one(client, trace_id_with_assets):
9292
).json()
9393
assert data["temperature"] is None
9494

95+
# only admin client can hit admin endpoint
96+
data = assert_request(
97+
client.patch,
98+
url=f"{ADMIN_ROUTE}/{trace_id_with_assets}",
99+
json={
100+
"name": new_name,
101+
"description": new_description,
102+
},
103+
expected_status_code=403,
104+
).json()
105+
assert data["error_code"] == "NOT_AUTHORIZED"
106+
assert data["message"] == "Service admin role required"
107+
108+
data = assert_request(
109+
client_admin.patch,
110+
url=f"{ADMIN_ROUTE}/{trace_id_with_assets}",
111+
json={
112+
"name": new_name,
113+
"description": new_description,
114+
},
115+
).json()
95116

96-
def test_update_one__public(client, electrical_cell_recording_json_data):
117+
assert data["name"] == new_name
118+
assert data["description"] == new_description
119+
120+
# set temperature
121+
data = assert_request(
122+
client_admin.patch,
123+
url=f"{ADMIN_ROUTE}/{trace_id_with_assets}",
124+
json={
125+
"temperature": 10.0,
126+
},
127+
).json()
128+
assert data["temperature"] == 10.0
129+
130+
# unset temperature
131+
data = assert_request(
132+
client_admin.patch,
133+
url=f"{ADMIN_ROUTE}/{trace_id_with_assets}",
134+
json={
135+
"temperature": None,
136+
},
137+
).json()
138+
assert data["temperature"] is None
139+
140+
# admin is treated as regular user for regular route (no authorized project ids)
141+
data = assert_request(
142+
client_admin.patch,
143+
url=f"{ROUTE}/{trace_id_with_assets}",
144+
json={
145+
"temperature": None,
146+
},
147+
expected_status_code=404,
148+
).json()
149+
assert data["error_code"] == "ENTITY_NOT_FOUND"
150+
151+
152+
def test_user_update_one__public(client, client_admin, electrical_cell_recording_json_data):
97153
# make private entity public
98154
data = assert_request(
99155
client.post,
@@ -104,15 +160,29 @@ def test_update_one__public(client, electrical_cell_recording_json_data):
104160
},
105161
).json()
106162

163+
entity_id = data["id"]
164+
107165
# should not be allowed to update it once public
108166
data = assert_request(
109167
client.patch,
110-
url=f"{ROUTE}/{data['id']}",
168+
url=f"{ROUTE}/{entity_id}",
111169
json={"name": "foo"},
112170
expected_status_code=404,
113171
).json()
114172
assert data["error_code"] == "ENTITY_NOT_FOUND"
115173

174+
# admin has no such restrictions
175+
data = assert_request(
176+
client_admin.patch,
177+
url=f"{ADMIN_ROUTE}/{entity_id}",
178+
json={
179+
"authorized_public": False,
180+
"name": "foo",
181+
},
182+
).json()
183+
assert data["authorized_public"] is False
184+
assert data["name"] == "foo"
185+
116186

117187
def test_read_one(client, subject_id, license_id, brain_region_id, trace_id_with_assets):
118188
data = assert_request(

0 commit comments

Comments
 (0)