Skip to content

Commit 713892d

Browse files
Fixes for SimulationCampaign (#360)
* Add entity_id filter in SimulationCampaign * Fix SimulationCampaign simulation prefix * Add circuit filter
1 parent 4d975f6 commit 713892d

File tree

5 files changed

+103
-24
lines changed

5 files changed

+103
-24
lines changed

app/filters/simulation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
class SimulationFilterBase(NameFilterMixin, IdFilterMixin, CustomFilter):
1414
entity_id: uuid.UUID | None = None
1515
entity_id__in: list[uuid.UUID] | None = None
16-
circuit: Annotated[NestedCircuitFilter | None, NestedCircuitFilterDep] = None
1716

1817

1918
class NestedSimulationFilter(SimulationFilterBase):
19+
circuit: Annotated[
20+
NestedCircuitFilter | None,
21+
FilterDepends(with_prefix("simulation__circuit", NestedCircuitFilter)),
22+
]
23+
2024
class Constants(CustomFilter.Constants):
2125
model = Simulation
2226

@@ -25,6 +29,8 @@ class SimulationFilter(EntityFilterMixin, SimulationFilterBase):
2529
simulation_campaign_id: uuid.UUID | None = None
2630
simulation_campaign_id__in: list[uuid.UUID] | None = None
2731

32+
circuit: Annotated[NestedCircuitFilter | None, NestedCircuitFilterDep] = None
33+
2834
order_by: list[str] = ["-creation_date"] # noqa: RUF012
2935

3036
class Constants(CustomFilter.Constants):

app/filters/simulation_campaign.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,29 @@
1+
import uuid
12
from typing import Annotated
23

4+
from fastapi_filter import with_prefix
5+
36
from app.db.model import SimulationCampaign
47
from app.dependencies.filter import FilterDepends
58
from app.filters.base import CustomFilter
9+
from app.filters.circuit import NestedCircuitFilter, NestedCircuitFilterDep
610
from app.filters.common import EntityFilterMixin, NameFilterMixin
7-
from app.filters.simulation import NestedSimulationFilter, NestedSimulationFilterDep
11+
from app.filters.simulation import NestedSimulationFilter
812

913

1014
class SimulationCampaignFilter(CustomFilter, EntityFilterMixin, NameFilterMixin):
11-
simulation: Annotated[NestedSimulationFilter | None, NestedSimulationFilterDep] = None
15+
entity_id: uuid.UUID | None = None
16+
entity_id__in: list[uuid.UUID] | None = None
17+
18+
circuit: Annotated[
19+
NestedCircuitFilter | None,
20+
NestedCircuitFilterDep,
21+
] = None
22+
23+
simulation: Annotated[
24+
NestedSimulationFilter | None,
25+
FilterDepends(with_prefix("simulation", NestedSimulationFilter)),
26+
] = None
1227

1328
order_by: list[str] = ["-creation_date"] # noqa: RUF012
1429

app/queries/factory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T:
129129
"post_region": {"id": post_region_alias.id, "label": post_region_alias.name},
130130
"simulation": {"id": simulation_alias.id, "label": simulation_alias.name},
131131
"simulation.circuit": {"id": circuit_alias.id, "label": circuit_alias.name},
132+
"circuit": {"id": circuit_alias.id, "label": circuit_alias.name},
132133
"ion_channel": {"id": ion_channel_alias.id, "label": ion_channel_alias.label},
133134
"em_dense_reconstruction_dataset": {
134135
"id": em_dense_reconstruction_dataset_alias.id,
@@ -219,9 +220,10 @@ def _get_alias[T: type[Identifiable]](db_cls: T, name: str | None = None) -> T:
219220
"simulation": lambda q: q.outerjoin(
220221
simulation_alias, db_model_class.id == simulation_alias.simulation_campaign_id
221222
),
222-
"simulation.circuit": lambda q: q.join(
223+
"simulation.circuit": lambda q: q.outerjoin(
223224
circuit_alias, simulation_alias.entity_id == circuit_alias.id
224225
),
226+
"circuit": lambda q: q.join(circuit_alias, db_model_class.entity_id == circuit_alias.id),
225227
"used": lambda q: q.outerjoin(
226228
Usage, db_model_class.id == Usage.usage_activity_id
227229
).outerjoin(used_alias, Usage.usage_entity_id == used_alias.id),

app/service/simulation_campaign.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def read_many(
155155
"created_by",
156156
"updated_by",
157157
"contribution",
158+
"circuit",
158159
"simulation",
159160
"simulation.circuit",
160161
]

tests/test_simulation_campaign.py

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,7 @@ def multiple_circuits(db, brain_atlas_id, subject_id, brain_region_id, license_i
226226

227227

228228
@pytest.fixture
229-
def campaigns_with_different_circuits(
230-
db, json_data, person_id, simulation_json_data, multiple_circuits
231-
):
229+
def campaigns_with_different_circuits(db, json_data, person_id, multiple_circuits):
232230
campaigns = []
233231

234232
for i, circuit in enumerate(multiple_circuits):
@@ -238,6 +236,7 @@ def campaigns_with_different_circuits(
238236
**(
239237
json_data
240238
| {
239+
"entity_id": str(circuit.id),
241240
"name": f"campaign-circuit-{i}",
242241
"description": f"Campaign for circuit {i}",
243242
"created_by_id": person_id,
@@ -252,15 +251,14 @@ def campaigns_with_different_circuits(
252251
add_db(
253252
db,
254253
Simulation(
255-
**simulation_json_data
256-
| {
257-
"name": f"simulation-circuit-{i}",
258-
"simulation_campaign_id": campaign.id,
259-
"entity_id": circuit.id,
260-
"created_by_id": person_id,
261-
"updated_by_id": person_id,
262-
"authorized_project_id": PROJECT_ID,
263-
}
254+
name=f"simulation-{i}",
255+
description=f"simulation-description-{i}",
256+
entity_id=circuit.id,
257+
simulation_campaign_id=campaign.id,
258+
created_by_id=person_id,
259+
updated_by_id=person_id,
260+
authorized_project_id=PROJECT_ID,
261+
scan_parameters={"foo1": "bar1", "foo2": "bar2"},
264262
),
265263
)
266264

@@ -269,45 +267,64 @@ def campaigns_with_different_circuits(
269267

270268
def test_filter_by_circuit_id(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
271269
first_circuit_id = str(multiple_circuits[0].id)
272-
data = assert_request(client.get, url=ROUTE, params={"circuit__id": first_circuit_id}).json()[
270+
271+
data = assert_request(client.get, url=ROUTE, params={"entity_id": first_circuit_id}).json()[
273272
"data"
274273
]
274+
assert len(data) == 1
275+
assert data[0]["name"] == "campaign-circuit-0"
276+
277+
data = assert_request(
278+
client.get, url=ROUTE, params={"simulation__circuit__id": first_circuit_id}
279+
).json()["data"]
275280

276281
assert len(data) == 1
277282
assert data[0]["name"] == "campaign-circuit-0"
278283

279284
second_circuit_id = str(multiple_circuits[1].id)
280-
data = assert_request(client.get, url=ROUTE, params={"circuit__id": second_circuit_id}).json()[
281-
"data"
282-
]
285+
data = assert_request(
286+
client.get, url=ROUTE, params={"simulation__circuit__id": second_circuit_id}
287+
).json()["data"]
283288

284289
assert len(data) == 1
285290
assert data[0]["name"] == "campaign-circuit-1"
286291

287292

288293
def test_filter_by_circuit_name(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
289294
data = assert_request(
290-
client.get, url=ROUTE, params={"circuit__name": "micro-circuit-1"}
295+
client.get, url=ROUTE, params={"simulation__circuit__name": "micro-circuit-1"}
291296
).json()["data"]
292297

293298
assert len(data) == 1
294299
assert data[0]["name"] == "campaign-circuit-0"
295300

296301
data = assert_request(
297-
client.get, url=ROUTE, params={"circuit__name__in": "micro-circuit-2"}
302+
client.get, url=ROUTE, params={"simulation__circuit__name__in": "micro-circuit-2"}
298303
).json()["data"]
299304

300305
assert len(data) == 1
301306
assert data[0]["name"] == "campaign-circuit-1"
302307

303308

304309
def test_filter_by_circuit_scale(client, campaigns_with_different_circuits, multiple_circuits): # noqa: ARG001
310+
data = assert_request(
311+
client.get, url=ROUTE, params={"simulation__circuit__scale": CircuitScale.microcircuit}
312+
).json()["data"]
313+
314+
assert len(data) == 2
315+
305316
data = assert_request(
306317
client.get, url=ROUTE, params={"circuit__scale": CircuitScale.microcircuit}
307318
).json()["data"]
308319

309320
assert len(data) == 2
310321

322+
data = assert_request(
323+
client.get, url=ROUTE, params={"simulation__circuit__scale": CircuitScale.pair}
324+
).json()["data"]
325+
326+
assert len(data) == 1
327+
311328
data = assert_request(
312329
client.get, url=ROUTE, params={"circuit__scale": CircuitScale.pair}
313330
).json()["data"]
@@ -320,6 +337,12 @@ def test_filter_by_circuit_scale_empty(
320337
campaigns_with_different_circuits, # noqa: ARG001
321338
multiple_circuits, # noqa: ARG001
322339
):
340+
data = assert_request(
341+
client.get, url=ROUTE, params={"simulation__circuit__scale": CircuitScale.small}
342+
).json()["data"]
343+
344+
assert len(data) == 0
345+
323346
data = assert_request(
324347
client.get, url=ROUTE, params={"circuit__scale": CircuitScale.small}
325348
).json()["data"]
@@ -331,7 +354,7 @@ def test_filter_by_circuit_scale_in(client, campaigns_with_different_circuits, m
331354
data = assert_request(
332355
client.get,
333356
url=ROUTE,
334-
params={"circuit__scale__in": [CircuitScale.microcircuit, CircuitScale.pair]},
357+
params={"simulation__circuit__scale__in": [CircuitScale.microcircuit, CircuitScale.pair]},
335358
).json()["data"]
336359

337360
assert len(data) == 3
@@ -344,6 +367,16 @@ def test_filter_by_circuit_build_category(
344367
campaigns_with_different_circuits, # noqa: ARG001
345368
multiple_circuits, # noqa: ARG001
346369
):
370+
data = assert_request(
371+
client.get,
372+
url=ROUTE,
373+
params={"simulation__circuit__build_category": CircuitBuildCategory.computational_model},
374+
).json()["data"]
375+
376+
assert len(data) == 2
377+
campaign_names = {campaign["name"] for campaign in data}
378+
assert campaign_names == {"campaign-circuit-0", "campaign-circuit-2"}
379+
347380
data = assert_request(
348381
client.get,
349382
url=ROUTE,
@@ -354,6 +387,15 @@ def test_filter_by_circuit_build_category(
354387
campaign_names = {campaign["name"] for campaign in data}
355388
assert campaign_names == {"campaign-circuit-0", "campaign-circuit-2"}
356389

390+
data = assert_request(
391+
client.get,
392+
url=ROUTE,
393+
params={"simulation__circuit__build_category": CircuitBuildCategory.em_reconstruction},
394+
).json()["data"]
395+
396+
assert len(data) == 1
397+
assert data[0]["name"] == "campaign-circuit-1"
398+
357399
data = assert_request(
358400
client.get,
359401
url=ROUTE,
@@ -373,7 +415,7 @@ def test_filter_by_circuit_build_category_in(
373415
client.get,
374416
url=ROUTE,
375417
params={
376-
"circuit__build_category__in": [
418+
"simulation__circuit__build_category__in": [
377419
CircuitBuildCategory.computational_model,
378420
CircuitBuildCategory.em_reconstruction,
379421
],
@@ -383,3 +425,16 @@ def test_filter_by_circuit_build_category_in(
383425
assert len(data) == 3
384426
campaign_names = {campaign["name"] for campaign in data}
385427
assert campaign_names == {"campaign-circuit-0", "campaign-circuit-1", "campaign-circuit-2"}
428+
429+
data = assert_request(
430+
client.get,
431+
url=ROUTE,
432+
params={
433+
"circuit__build_category__in": [
434+
CircuitBuildCategory.computational_model,
435+
CircuitBuildCategory.em_reconstruction,
436+
],
437+
},
438+
).json()["data"]
439+
440+
assert len(data) == 3

0 commit comments

Comments
 (0)