Skip to content

Commit eea210b

Browse files
Add MEModel-to-SONATA staging function (#118)
* stage memodel as sonata single-cell circuit * add h5py dependency * add unit-tests * mechanism_files added to DownloadedMEModel * Make generate_sonata_files_from_memodel protected * _generate_sonata_files_from_memodel additional changes * staging optional dependency added for tests * remove hoc_files --------- Co-authored-by: James Isbister <isbisterjb@gmail.com>
1 parent 04289c6 commit eea210b

File tree

10 files changed

+1582
-4
lines changed

10 files changed

+1582
-4
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ classifiers = [
2929
]
3030
dynamic = ["version"]
3131

32+
[project.optional-dependencies]
33+
staging = [
34+
"h5py",
35+
]
36+
3237
[project.urls]
3338
documentation = "https://entitysdk.readthedocs.io/en/stable"
3439
repository = "https://github.com/openbraininstitute/entitysdk"

src/entitysdk/downloaders/memodel.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from entitysdk.downloaders.cell_morphology import download_morphology
88
from entitysdk.downloaders.emodel import download_hoc
99
from entitysdk.downloaders.ion_channel_model import download_ion_channel_mechanism
10-
from entitysdk.exception import IteratorResultError
10+
from entitysdk.exception import IteratorResultError, StagingError
1111
from entitysdk.models.emodel import EModel
1212
from entitysdk.models.memodel import MEModel
1313
from entitysdk.schemas.memodel import DownloadedMEModel
@@ -29,6 +29,9 @@ def download_memodel(client: Client, memodel: MEModel, output_dir=".") -> Downlo
2929
)
3030

3131
hoc_path = download_hoc(client, emodel, Path(output_dir) / "hoc")
32+
if not hoc_path.exists():
33+
raise StagingError(f"HOC does not exist: {hoc_path}")
34+
3235
# only take .asc format for now.
3336
# Will take specific format when morphology_format is integrated into MEModel
3437
try:
@@ -40,9 +43,14 @@ def download_memodel(client: Client, memodel: MEModel, output_dir=".") -> Downlo
4043
client, memodel.morphology, Path(output_dir) / "morphology", "swc"
4144
)
4245
mechanisms_dir = create_dir(Path(output_dir) / "mechanisms")
46+
mechanism_files = []
4347
for ic in emodel.ion_channel_models or []:
44-
download_ion_channel_mechanism(client, ic, mechanisms_dir)
48+
ion_channel_path = download_ion_channel_mechanism(client, ic, mechanisms_dir)
49+
mechanism_files.append(ion_channel_path.name)
4550

4651
return DownloadedMEModel(
47-
hoc_path=hoc_path, mechanisms_dir=mechanisms_dir, morphology_path=morphology_path
52+
hoc_path=hoc_path,
53+
mechanisms_dir=mechanisms_dir,
54+
mechanism_files=mechanism_files,
55+
morphology_path=morphology_path,
4856
)

src/entitysdk/schemas/memodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ class DownloadedMEModel(Schema):
1010

1111
hoc_path: Path
1212
mechanisms_dir: Path
13+
mechanism_files: list[str]
1314
morphology_path: Path

src/entitysdk/staging/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
"""Staging functions."""
22

33
from entitysdk.staging.circuit import stage_circuit
4+
from entitysdk.staging.memodel import stage_sonata_from_memodel
45
from entitysdk.staging.simulation import stage_simulation
56
from entitysdk.staging.simulation_result import stage_simulation_result
67

7-
__all__ = ["stage_circuit", "stage_simulation", "stage_simulation_result"]
8+
__all__ = [
9+
"stage_circuit",
10+
"stage_sonata_from_memodel",
11+
"stage_simulation",
12+
"stage_simulation_result",
13+
]

src/entitysdk/staging/memodel.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
"""Staging functions for Single-Cell."""
2+
3+
import json
4+
import logging
5+
import shutil
6+
import tempfile
7+
from pathlib import Path
8+
9+
import h5py
10+
11+
from entitysdk.client import Client
12+
from entitysdk.downloaders.memodel import DownloadedMEModel, download_memodel
13+
from entitysdk.exception import StagingError
14+
from entitysdk.models.memodel import MEModel
15+
from entitysdk.utils.filesystem import create_dir
16+
from entitysdk.utils.io import write_json
17+
18+
L = logging.getLogger(__name__)
19+
20+
DEFAULT_CIRCUIT_CONFIG_FILENAME = "circuit_config.json"
21+
22+
23+
def stage_sonata_from_memodel(
24+
client: Client,
25+
memodel: MEModel,
26+
output_dir: Path = Path("."),
27+
) -> Path:
28+
"""Stages a SONATA single-cell circuit from an MEModel entity.
29+
30+
Downloads the MEModel and converts it into SONATA circuit format.
31+
32+
Returns:
33+
Path to generated circuit_config.json (inside SONATA folder).
34+
"""
35+
with tempfile.TemporaryDirectory() as tmp_dir:
36+
downloaded_me_model = download_memodel(client, memodel=memodel, output_dir=tmp_dir)
37+
38+
if memodel.mtypes is None:
39+
raise StagingError(f"MEModel {memodel.id} has no mtypes defined.")
40+
41+
mtype = memodel.mtypes[0].pref_label
42+
43+
if memodel.calibration_result is None:
44+
raise StagingError(f"MEModel {memodel.id} has no calibration result.")
45+
46+
threshold_current = memodel.calibration_result.threshold_current
47+
holding_current = memodel.calibration_result.holding_current
48+
49+
_generate_sonata_files_from_memodel(
50+
downloaded_memodel=downloaded_me_model,
51+
output_path=output_dir,
52+
mtype=mtype,
53+
threshold_current=threshold_current,
54+
holding_current=holding_current,
55+
)
56+
57+
config_path = output_dir / DEFAULT_CIRCUIT_CONFIG_FILENAME
58+
59+
L.info("Single-Cell %s staged at %s", memodel.id, config_path)
60+
61+
return config_path
62+
63+
64+
def _generate_sonata_files_from_memodel(
65+
downloaded_memodel: DownloadedMEModel,
66+
output_path: Path,
67+
mtype: str,
68+
threshold_current: float,
69+
holding_current: float,
70+
):
71+
"""Generate SONATA single cell circuit structure from a downloaded MEModel folder.
72+
73+
Args:
74+
downloaded_memodel (DownloadedMEModel): The downloaded MEModel object.
75+
output_path (str or Path): Path to the output 'sonata' folder.
76+
mtype (str): Cell mtype.
77+
threshold_current (float): Threshold current.
78+
holding_current (float): Holding current.
79+
"""
80+
subdirs = {
81+
"hocs": output_path / "hocs",
82+
"mechanisms": output_path / "mechanisms",
83+
"morphologies": output_path / "morphologies",
84+
"network": output_path / "network",
85+
}
86+
for path in subdirs.values():
87+
create_dir(path)
88+
89+
# Copy hoc file
90+
hoc_file = downloaded_memodel.hoc_path
91+
if not downloaded_memodel.hoc_path.exists():
92+
raise FileNotFoundError(f"No HOC file found {downloaded_memodel.hoc_path}")
93+
hoc_dst = subdirs["hocs"] / hoc_file.name
94+
shutil.copy(hoc_file, hoc_dst)
95+
96+
# Copy morphology file
97+
if not downloaded_memodel.morphology_path.exists():
98+
raise FileNotFoundError(f"No morphology file found {downloaded_memodel.morphology_path}")
99+
morph_dst = subdirs["morphologies"] / downloaded_memodel.morphology_path.name
100+
shutil.copy(downloaded_memodel.morphology_path, morph_dst)
101+
102+
# Copy mechanisms
103+
for file in downloaded_memodel.mechanism_files:
104+
src_path = downloaded_memodel.mechanisms_dir / file
105+
if Path(src_path).exists():
106+
target = subdirs["mechanisms"] / file
107+
shutil.copy(src_path, target)
108+
109+
create_nodes_file(
110+
hoc_file=str(hoc_dst),
111+
morph_file=str(morph_dst),
112+
output_file=Path(str(subdirs["network"])) / "nodes.h5",
113+
mtype=mtype,
114+
threshold_current=threshold_current,
115+
holding_current=holding_current,
116+
)
117+
118+
create_circuit_config(output_path=output_path)
119+
create_node_sets_file(output_file=output_path / "node_sets.json")
120+
121+
L.debug(f"SONATA single cell circuit created at {output_path}")
122+
123+
124+
def create_nodes_file(
125+
hoc_file: str,
126+
morph_file: str,
127+
output_file: Path,
128+
mtype: str,
129+
threshold_current: float,
130+
holding_current: float,
131+
):
132+
"""Create a SONATA nodes.h5 file for a single cell population.
133+
134+
Args:
135+
hoc_file (str): Path to the hoc file.
136+
morph_file (str): Path to the morphology file.
137+
output_file (Path): Output file path for nodes.h5.
138+
mtype (str): Cell mtype.
139+
threshold_current (float): Threshold current value.
140+
holding_current (float): Holding current value.
141+
"""
142+
output_file = Path(output_file) # ensure Path type
143+
output_file.parent.mkdir(parents=True, exist_ok=True)
144+
with h5py.File(output_file, "w") as f:
145+
nodes = f.create_group("nodes")
146+
population = nodes.create_group("All")
147+
population.create_dataset("node_type_id", (1,), dtype="int64")[0] = -1
148+
group_0 = population.create_group("0")
149+
150+
# Add dynamics_params fields
151+
dynamics = group_0.create_group("dynamics_params")
152+
dynamics.create_dataset("holding_current", (1,), dtype="float32")[0] = holding_current
153+
dynamics.create_dataset("threshold_current", (1,), dtype="float32")[0] = threshold_current
154+
155+
# Standard string properties
156+
group_0.create_dataset("model_template", (1,), dtype=h5py.string_dtype())[0] = (
157+
f"hoc:{Path(hoc_file).stem}"
158+
)
159+
group_0.create_dataset("model_type", (1,), dtype="int32")[0] = 0
160+
group_0.create_dataset("morph_class", (1,), dtype="int32")[0] = 0
161+
group_0.create_dataset("morphology", (1,), dtype=h5py.string_dtype())[0] = (
162+
f"morphologies/{Path(morph_file).stem}"
163+
)
164+
group_0.create_dataset("mtype", (1,), dtype=h5py.string_dtype())[0] = mtype
165+
166+
# Coordinates and rotation
167+
for name in [
168+
"x",
169+
"y",
170+
"z",
171+
"rotation_angle_xaxis",
172+
"rotation_angle_yaxis",
173+
"rotation_angle_zaxis",
174+
]:
175+
group_0.create_dataset(name, (1,), dtype="float32")[0] = 0.0
176+
177+
# Quaternion orientation
178+
orientation = {
179+
"orientation_w": 1.0,
180+
"orientation_x": 0.0,
181+
"orientation_y": 0.0,
182+
"orientation_z": 0.0,
183+
}
184+
for name, value in orientation.items():
185+
group_0.create_dataset(name, (1,), dtype="float64")[0] = value
186+
187+
# Optional fields
188+
group_0.create_dataset("morphology_producer", (1,), dtype=h5py.string_dtype())[0] = (
189+
"biologic"
190+
)
191+
192+
L.debug(f"Successfully created file at {output_file}")
193+
194+
195+
def create_circuit_config(output_path: Path, node_population_name: str = "All"):
196+
"""Create a SONATA circuit_config.json for a single cell.
197+
198+
Args:
199+
output_path (str): Directory where circuit_config.json will be written.
200+
node_population_name (str): Name of the node population (default: 'All').
201+
"""
202+
config = {
203+
"manifest": {"$BASE_DIR": "."},
204+
"node_sets_file": "$BASE_DIR/node_sets.json",
205+
"networks": {
206+
"nodes": [
207+
{
208+
"nodes_file": "$BASE_DIR/network/nodes.h5",
209+
"populations": {
210+
node_population_name: {
211+
"type": "biophysical",
212+
"morphologies_dir": "$BASE_DIR/morphologies",
213+
"biophysical_neuron_models_dir": "$BASE_DIR/hocs",
214+
"alternate_morphologies": {"neurolucida-asc": "$BASE_DIR/"},
215+
}
216+
},
217+
}
218+
],
219+
"edges": [],
220+
},
221+
}
222+
config_path = output_path / "circuit_config.json"
223+
with open(config_path, "w") as f:
224+
json.dump(config, f, indent=2)
225+
L.debug(f"Successfully created circuit_config.json at {config_path}")
226+
227+
228+
def create_node_sets_file(
229+
output_file: Path,
230+
node_population_name: str = "All",
231+
node_set_name: str = "All",
232+
node_id: int = 0,
233+
):
234+
"""Create a node_sets.json file for a single cell.
235+
236+
Args:
237+
output_file (Path): Output file path for node_sets.json.
238+
node_population_name (str): Name of the node population (default: 'All').
239+
node_set_name (str): Name of the node set (default: 'All').
240+
node_id (int): Node ID to include (default: 0).
241+
"""
242+
node_sets = {node_set_name: {"population": node_population_name, "node_id": [node_id]}}
243+
write_json(node_sets, output_file)
244+
L.debug(f"Successfully created node_sets.json at {output_file}")

tests/unit/downloaders/test_memodel.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import os
22
import uuid
33

4+
import pytest
5+
46
from entitysdk.downloaders.memodel import download_memodel
7+
from entitysdk.exception import IteratorResultError, StagingError
58
from entitysdk.models.cell_morphology import CellMorphology
69
from entitysdk.models.emodel import EModel
710
from entitysdk.models.memodel import MEModel
@@ -200,3 +203,55 @@ def test_download_memodel(
200203
assert downloaded_memodel.morphology_path.is_file()
201204
assert downloaded_memodel.mechanisms_dir.is_dir()
202205
assert len(os.listdir(downloaded_memodel.mechanisms_dir)) == 1
206+
207+
208+
class DummyClient:
209+
def get_entity(self, entity_id, entity_type):
210+
class DummyEModel:
211+
ion_channel_models = []
212+
id = "dummy_id"
213+
214+
return DummyEModel()
215+
216+
217+
def test_download_memodel_hoc_missing(tmp_path):
218+
class DummyMEModel:
219+
emodel = type("EModel", (), {"id": "dummy_id"})()
220+
morphology = "dummy_morphology"
221+
222+
def dummy_download_hoc(client, emodel, path):
223+
return tmp_path / "nonexistent_hoc_file.hoc"
224+
225+
import entitysdk.downloaders.memodel as memodel_mod
226+
227+
memodel_mod.download_hoc = dummy_download_hoc
228+
with pytest.raises(StagingError) as excinfo:
229+
download_memodel(DummyClient(), DummyMEModel(), tmp_path)
230+
assert "HOC does not exist" in str(excinfo.value)
231+
232+
233+
def test_download_memodel_morphology_asc_fallback_to_swc(tmp_path):
234+
class DummyMEModel:
235+
emodel = type("EModel", (), {"id": "dummy_id"})()
236+
morphology = "dummy_morphology"
237+
238+
def dummy_download_hoc(client, emodel, path):
239+
hoc_file = tmp_path / "dummy.hoc"
240+
hoc_file.write_text("hoc")
241+
return hoc_file
242+
243+
def dummy_download_morphology(client, morphology, path, fmt):
244+
if fmt == "asc":
245+
raise IteratorResultError("asc not available")
246+
elif fmt == "swc":
247+
swc_file = path / "dummy.swc"
248+
swc_file.parent.mkdir(parents=True, exist_ok=True)
249+
swc_file.write_text("swc")
250+
return swc_file
251+
252+
import entitysdk.downloaders.memodel as memodel_mod
253+
254+
memodel_mod.download_hoc = dummy_download_hoc
255+
memodel_mod.download_morphology = dummy_download_morphology
256+
result = download_memodel(DummyClient(), DummyMEModel(), tmp_path)
257+
assert result.morphology_path.name == "dummy.swc"

0 commit comments

Comments
 (0)