Skip to content

Commit 739c9fc

Browse files
ilkilicjames-isbistereleftherioszisis
authored
Add MEModel support to stage_simulation (#138)
--------- Co-authored-by: James Isbister <isbisterjb@gmail.com> Co-authored-by: Eleftherios Zisis <eleftherios.zisis@openbraininstitute.org>
1 parent 1cf1fc2 commit 739c9fc

File tree

13 files changed

+1674
-35
lines changed

13 files changed

+1674
-35
lines changed

src/entitysdk/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
SlicingDirectionType,
2020
)
2121
from entitysdk.models.emodel import EModel
22+
from entitysdk.models.etype import ETypeClass
2223
from entitysdk.models.ion_channel import IonChannel
2324
from entitysdk.models.ion_channel_model import IonChannelModel, NeuronBlock, UseIon
2425
from entitysdk.models.ion_channel_recording import IonChannelRecording
@@ -59,6 +60,7 @@
5960
"EMCellMeshType",
6061
"EMDenseReconstructionDataset",
6162
"EModel",
63+
"ETypeClass",
6264
"ETypeClassification",
6365
"IonChannel",
6466
"IonChannelModel",

src/entitysdk/models/etype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ class ETypeClass(Identifiable):
2525
alt_label: Annotated[
2626
str | None,
2727
Field(description="The alternative label of th etype class."),
28-
]
28+
] = None

src/entitysdk/staging/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Staging constants."""
2+
3+
DEFAULT_NODE_POPULATION_NAME = "All"
4+
DEFAULT_NODE_SET_NAME = "All"

src/entitysdk/staging/memodel.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Staging functions for Single-Cell."""
22

3-
import json
43
import logging
54
import shutil
65
import tempfile
@@ -12,6 +11,10 @@
1211
from entitysdk.downloaders.memodel import DownloadedMEModel, download_memodel
1312
from entitysdk.exception import StagingError
1413
from entitysdk.models.memodel import MEModel
14+
from entitysdk.staging.constants import (
15+
DEFAULT_NODE_POPULATION_NAME,
16+
DEFAULT_NODE_SET_NAME,
17+
)
1518
from entitysdk.utils.filesystem import create_dir
1619
from entitysdk.utils.io import write_json
1720

@@ -143,7 +146,7 @@ def create_nodes_file(
143146
output_file.parent.mkdir(parents=True, exist_ok=True)
144147
with h5py.File(output_file, "w") as f:
145148
nodes = f.create_group("nodes")
146-
population = nodes.create_group("All")
149+
population = nodes.create_group(DEFAULT_NODE_POPULATION_NAME)
147150
population.create_dataset("node_type_id", (1,), dtype="int64")[0] = -1
148151
group_0 = population.create_group("0")
149152

@@ -192,12 +195,15 @@ def create_nodes_file(
192195
L.debug(f"Successfully created file at {output_file}")
193196

194197

195-
def create_circuit_config(output_path: Path, node_population_name: str = "All"):
198+
def create_circuit_config(
199+
output_path: Path,
200+
node_population_name: str = DEFAULT_NODE_POPULATION_NAME,
201+
):
196202
"""Create a SONATA circuit_config.json for a single cell.
197203
198204
Args:
199-
output_path (str): Directory where circuit_config.json will be written.
200-
node_population_name (str): Name of the node population (default: 'All').
205+
output_path: Directory where circuit_config.json will be written.
206+
node_population_name: Name of the node population.
201207
"""
202208
config = {
203209
"manifest": {"$BASE_DIR": "."},
@@ -219,25 +225,24 @@ def create_circuit_config(output_path: Path, node_population_name: str = "All"):
219225
"edges": [],
220226
},
221227
}
222-
config_path = output_path / "circuit_config.json"
223-
with open(config_path, "w") as f:
224-
json.dump(config, f, indent=2)
228+
config_path = output_path / DEFAULT_CIRCUIT_CONFIG_FILENAME
229+
write_json(data=config, path=config_path, indent=2)
225230
L.debug(f"Successfully created circuit_config.json at {config_path}")
226231

227232

228233
def create_node_sets_file(
229234
output_file: Path,
230-
node_population_name: str = "All",
231-
node_set_name: str = "All",
235+
node_population_name: str = DEFAULT_NODE_POPULATION_NAME,
236+
node_set_name: str = DEFAULT_NODE_SET_NAME,
232237
node_id: int = 0,
233238
):
234239
"""Create a node_sets.json file for a single cell.
235240
236241
Args:
237-
output_file (Path): Output file path for node_sets.json.
238-
node_population_name (str): Name of the node population (default: 'All').
239-
node_set_name (str): Name of the node set (default: 'All').
240-
node_id (int): Node ID to include (default: 0).
242+
output_file: Output file path for node_sets.json.
243+
node_population_name: Name of the node population.
244+
node_set_name: Name of the node set (default: MEMODEL_CIRCUIT_STAGING_NODE_SET_NAME).
245+
node_id: Node ID to include (default: 0).
241246
"""
242247
node_sets = {node_set_name: {"population": node_population_name, "node_id": [node_id]}}
243248
write_json(node_sets, output_file)

src/entitysdk/staging/simulation.py

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,22 @@
44
from copy import deepcopy
55
from pathlib import Path
66

7+
from entitysdk._server_schemas import EntityType as EntityType
78
from entitysdk.client import Client
89
from entitysdk.downloaders.simulation import (
910
download_node_sets_file,
1011
download_simulation_config_content,
1112
download_spike_replay_files,
1213
)
1314
from entitysdk.exception import StagingError
14-
from entitysdk.models import Circuit, Simulation
15+
from entitysdk.models import Circuit, MEModel, Simulation
16+
from entitysdk.models.entity import Entity
1517
from entitysdk.staging.circuit import stage_circuit
18+
from entitysdk.staging.constants import (
19+
DEFAULT_NODE_POPULATION_NAME,
20+
DEFAULT_NODE_SET_NAME,
21+
)
22+
from entitysdk.staging.memodel import stage_sonata_from_memodel
1623
from entitysdk.types import StrOrPath
1724
from entitysdk.utils.filesystem import create_dir
1825
from entitysdk.utils.io import write_json
@@ -46,13 +53,7 @@ def stage_simulation(
4653
The path to the staged simulation config file.
4754
"""
4855
output_dir = create_dir(output_dir).resolve()
49-
5056
simulation_config: dict = download_simulation_config_content(client, model=model)
51-
node_sets_file: Path = download_node_sets_file(
52-
client,
53-
model=model,
54-
output_path=output_dir / DEFAULT_NODE_SETS_FILENAME,
55-
)
5657
spike_paths: list[Path] = download_spike_replay_files(
5758
client,
5859
model=model,
@@ -63,13 +64,48 @@ def stage_simulation(
6364
"Circuit config path was not provided. Circuit is going to be staged from metadata. "
6465
"Circuit id to be staged: %s"
6566
)
66-
circuit_config_path = stage_circuit(
67+
base_entity = client.get_entity(entity_id=model.entity_id, entity_type=Entity)
68+
match base_entity.type:
69+
case EntityType.memodel:
70+
memodel = client.get_entity(entity_id=model.entity_id, entity_type=MEModel)
71+
L.info(
72+
"Staging single-cell SONATA circuit from MEModel %s",
73+
memodel.id,
74+
)
75+
node_sets_file = _stage_single_cell_node_sets_file(
76+
node_set_name=simulation_config.get("node_set", DEFAULT_NODE_SET_NAME),
77+
output_path=output_dir / DEFAULT_NODE_SETS_FILENAME,
78+
)
79+
circuit_config_path = stage_sonata_from_memodel(
80+
client,
81+
memodel=memodel,
82+
output_dir=create_dir(output_dir / DEFAULT_CIRCUIT_DIR),
83+
)
84+
case EntityType.circuit:
85+
circuit = client.get_entity(entity_id=model.entity_id, entity_type=Circuit)
86+
L.info(
87+
"Staging SONATA circuit from Circuit %s",
88+
circuit.id,
89+
)
90+
node_sets_file = download_node_sets_file(
91+
client,
92+
model=model,
93+
output_path=output_dir / DEFAULT_NODE_SETS_FILENAME,
94+
)
95+
circuit_config_path = stage_circuit(
96+
client,
97+
model=circuit,
98+
output_dir=create_dir(output_dir / DEFAULT_CIRCUIT_DIR),
99+
)
100+
case _:
101+
raise StagingError(
102+
f"Simulation {model.id} references unsupported type {base_entity.type}"
103+
)
104+
else:
105+
node_sets_file = download_node_sets_file(
67106
client,
68-
model=client.get_entity(
69-
entity_id=model.entity_id,
70-
entity_type=Circuit,
71-
),
72-
output_dir=create_dir(output_dir / DEFAULT_CIRCUIT_DIR),
107+
model=model,
108+
output_path=output_dir / DEFAULT_NODE_SETS_FILENAME,
73109
)
74110

75111
transformed_simulation_config: dict = _transform_simulation_config(
@@ -82,17 +118,28 @@ def stage_simulation(
82118
)
83119

84120
output_simulation_config_file = output_dir / DEFAULT_SIMULATION_CONFIG_FILENAME
85-
86-
write_json(
87-
data=transformed_simulation_config,
88-
path=output_simulation_config_file,
89-
)
121+
write_json(data=transformed_simulation_config, path=output_simulation_config_file)
90122

91123
L.info("Staged Simulation %s at %s", model.id, output_dir)
92-
93124
return output_simulation_config_file
94125

95126

127+
def _stage_single_cell_node_sets_file(
128+
node_set_name: str,
129+
output_path: Path,
130+
) -> Path:
131+
write_json(
132+
{
133+
node_set_name: {
134+
"population": DEFAULT_NODE_POPULATION_NAME,
135+
"node_id": [0],
136+
}
137+
},
138+
output_path,
139+
)
140+
return output_path
141+
142+
96143
def _transform_simulation_config(
97144
simulation_config: dict,
98145
circuit_config_path: Path,

0 commit comments

Comments
 (0)