Skip to content

Add support for compartment reports (soma + named section) as well as spike reports #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions bluecellulab/circuit/config/sonata_simulation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
from __future__ import annotations
from functools import lru_cache
import json
import logging
from pathlib import Path
from typing import Optional

Expand All @@ -21,6 +23,8 @@

from bluepysnap import Simulation as SnapSimulation

logger = logging.getLogger(__name__)


class SonataSimulationConfig:
"""Sonata implementation of SimulationConfig protocol."""
Expand Down Expand Up @@ -74,9 +78,42 @@
result.append(ConnectionOverrides.from_sonata(conn_entry))
return result

@lru_cache(maxsize=1)
def get_compartment_sets(self) -> dict[str, dict]:
filepath = self.impl.config.get("compartment_sets_file")
if not filepath:
raise ValueError("No 'compartment_sets_file' entry found in SONATA config.")

Check warning on line 85 in bluecellulab/circuit/config/sonata_simulation_config.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/config/sonata_simulation_config.py#L85

Added line #L85 was not covered by tests
with open(filepath, 'r') as f:
return json.load(f)

@lru_cache(maxsize=1)
def get_node_sets(self) -> dict[str, dict]:
filepath = self.impl.circuit.config.get("node_sets_file")
if not filepath:
raise ValueError("No 'node_sets_file' entry found in SONATA config.")

Check warning on line 93 in bluecellulab/circuit/config/sonata_simulation_config.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/config/sonata_simulation_config.py#L93

Added line #L93 was not covered by tests
with open(filepath, 'r') as f:
return json.load(f)

@lru_cache(maxsize=1)
def get_report_entries(self) -> dict[str, dict]:
"""Returns the 'reports' dictionary from the SONATA simulation config.

Each key is a report name, and the value is its configuration.
"""
reports = self.impl.config.get("reports", {})
if not isinstance(reports, dict):
raise ValueError("Invalid format for 'reports' in SONATA config.")

Check warning on line 105 in bluecellulab/circuit/config/sonata_simulation_config.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/config/sonata_simulation_config.py#L105

Added line #L105 was not covered by tests
return reports

def connection_entries(self) -> list[ConnectionOverrides]:
return self._connection_entries() + self._connection_overrides

def report_file_path(self, report_cfg: dict, report_key: str) -> Path:
"""Resolve the full path for the report output file."""
output_dir = Path(self.output_root_path)
file_name = report_cfg.get("file_name", f"{report_key}.h5")
return output_dir / file_name

Check warning on line 115 in bluecellulab/circuit/config/sonata_simulation_config.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/config/sonata_simulation_config.py#L113-L115

Added lines #L113 - L115 were not covered by tests

@property
def base_seed(self) -> int:
return self.impl.run.random_seed
Expand Down Expand Up @@ -137,6 +174,12 @@
def output_root_path(self) -> str:
return self.impl.config["output"]["output_dir"]

@property
def spikes_file_path(self) -> Path:
output_dir = Path(self.output_root_path)
spikes_file = self.impl.config.get("output", {}).get("spikes_file", "spikes.h5")
return output_dir / spikes_file

Check warning on line 181 in bluecellulab/circuit/config/sonata_simulation_config.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/config/sonata_simulation_config.py#L179-L181

Added lines #L179 - L181 were not covered by tests

@property
def extracellular_calcium(self) -> Optional[float]:
return self.condition_parameters().extracellular_calcium
Expand Down
91 changes: 90 additions & 1 deletion bluecellulab/circuit/iotools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
from __future__ import annotations
from pathlib import Path
import logging
from typing import List

import bluepy
import numpy as np
import h5py

from bluecellulab.tools import resolve_segments
from bluecellulab.cell.cell_dict import CellDict
from bluecellulab.circuit.node_id import CellId

logger = logging.getLogger(__name__)
Expand All @@ -28,6 +31,7 @@
def parse_outdat(path: str | Path) -> dict[CellId, np.ndarray]:
"""Parse the replay spiketrains in a out.dat formatted file pointed to by
path."""
import bluepy

Check warning on line 34 in bluecellulab/circuit/iotools.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/iotools.py#L34

Added line #L34 was not covered by tests
spikes = bluepy.impl.spike_report.SpikeReport.load(path).get()
# convert Series to DataFrame with 2 columns for `groupby` operation
spike_df = spikes.to_frame().reset_index()
Expand All @@ -40,3 +44,88 @@
# convert outdat's index from int to CellId
outdat.index = [CellId("", gid) for gid in outdat.index]
return outdat.to_dict()


def write_compartment_report(
output_path: str,
cells: CellDict,
report_cfg: dict,
source_sets: dict,
source_type: str,
):
"""Write a SONATA-compatible compartment report to an HDF5 file."""
source_name = report_cfg.get("cells") if source_type == "node_set" else report_cfg.get("compartments")
source = source_sets.get(source_name)
if not source:
raise ValueError(f"{source_type} '{source_name}' not found.")

Check warning on line 60 in bluecellulab/circuit/iotools.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/iotools.py#L60

Added line #L60 was not covered by tests

population = source["population"]

if source_type == "compartment_set":
compartment_nodes = source.get("compartment_set", [])
node_ids = [n[0] for n in compartment_nodes]

Check warning on line 66 in bluecellulab/circuit/iotools.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/iotools.py#L65-L66

Added lines #L65 - L66 were not covered by tests
else:
if "node_id" in source:
node_ids = source["node_id"]
else:
# Fallback: get all node IDs for the population from cells
node_ids = [node_id for (pop, node_id) in cells.keys() if pop == population]

Check warning on line 72 in bluecellulab/circuit/iotools.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/iotools.py#L72

Added line #L72 was not covered by tests
compartment_nodes = None # Not used for node_set

data_matrix: List[np.ndarray] = []
recorded_node_ids: List[int] = []
index_pointers: List[int] = [0]
element_ids: List[int] = []

for node_id in node_ids:
try:
cell = cells[(population, node_id)]
except KeyError:
continue

Check warning on line 84 in bluecellulab/circuit/iotools.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/iotools.py#L83-L84

Added lines #L83 - L84 were not covered by tests
if not cell:
continue

Check warning on line 86 in bluecellulab/circuit/iotools.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/iotools.py#L86

Added line #L86 was not covered by tests

targets = resolve_segments(cell, report_cfg, node_id, compartment_nodes, source_type)
for sec, sec_name, seg in targets:
try:
trace = cell.get_voltage_recording(section=sec, segx=seg)
data_matrix.append(trace)
recorded_node_ids.append(node_id)
element_ids.append(len(element_ids))
index_pointers.append(index_pointers[-1] + 1)
except Exception as e:
logger.warning(f"Failed recording: GID {node_id} sec {sec_name} seg {seg}: {e}")

Check warning on line 97 in bluecellulab/circuit/iotools.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/iotools.py#L96-L97

Added lines #L96 - L97 were not covered by tests

if not data_matrix:
logger.warning(f"No data recorded for report '{source_name}'. Skipping write.")
return

Check warning on line 101 in bluecellulab/circuit/iotools.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/iotools.py#L100-L101

Added lines #L100 - L101 were not covered by tests

write_sonata_report_file(
output_path, population, data_matrix, recorded_node_ids, index_pointers, element_ids, report_cfg
)


def write_sonata_report_file(
output_path, population, data_matrix, recorded_node_ids, index_pointers, element_ids, report_cfg
):
data_array = np.stack(data_matrix, axis=1)
node_ids_arr = np.array(recorded_node_ids, dtype=np.uint64)
index_ptr_arr = np.array(index_pointers, dtype=np.uint64)
element_ids_arr = np.array(element_ids, dtype=np.uint32)
time_array = np.array([
report_cfg.get("start_time", 0.0),
report_cfg.get("end_time", 0.0),
report_cfg.get("dt", 0.1)
], dtype=np.float64)

with h5py.File(output_path, "w") as f:
grp = f.require_group(f"/report/{population}")
data_ds = grp.create_dataset("data", data=data_array.astype(np.float32))
data_ds.attrs["units"] = "mV"

mapping = grp.require_group("mapping")
mapping.create_dataset("node_ids", data=node_ids_arr)
mapping.create_dataset("index_pointers", data=index_ptr_arr)
mapping.create_dataset("element_ids", data=element_ids_arr)
time_ds = mapping.create_dataset("time", data=time_array)
time_ds.attrs["units"] = "ms"
48 changes: 48 additions & 0 deletions bluecellulab/circuit_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
from bluecellulab.circuit.node_id import create_cell_id, create_cell_ids
from bluecellulab.circuit.simulation_access import BluepySimulationAccess, SimulationAccess, SonataSimulationAccess, _sample_array
from bluecellulab.importer import load_mod_files
from bluecellulab.circuit.iotools import write_compartment_report
from bluecellulab.rngsettings import RNGSettings
from bluecellulab.simulation.neuron_globals import NeuronGlobals
from bluecellulab.simulation.report import configure_all_reports
from bluecellulab.stimulus.circuit_stimulus_definitions import Noise, OrnsteinUhlenbeck, RelativeOrnsteinUhlenbeck, RelativeShotNoise, ShotNoise
import bluecellulab.stimulus.circuit_stimulus_definitions as circuit_stimulus_definitions
from bluecellulab.exceptions import BluecellulabError
Expand Down Expand Up @@ -301,6 +303,11 @@
add_linear_stimuli=add_linear_stimuli
)

configure_all_reports(
cells=self.cells,
simulation_config=self.circuit_access.config
)

def _add_stimuli(self, add_noise_stimuli=False,
add_hyperpolarizing_stimuli=False,
add_relativelinear_stimuli=False,
Expand Down Expand Up @@ -779,3 +786,44 @@
record_dt=cell_kwargs['record_dt'],
template_format=cell_kwargs['template_format'],
emodel_properties=cell_kwargs['emodel_properties'])

def write_reports(self):
"""Write all reports defined in the simulation config."""
report_entries = self.circuit_access.config.get_report_entries()

for report_name, report_cfg in report_entries.items():
report_type = report_cfg.get("type", "compartment")
section = report_cfg.get("sections")

if report_type != "compartment":
raise NotImplementedError(f"Report type '{report_type}' is not supported.")

output_path = self.circuit_access.config.report_file_path(report_cfg, report_name)
if section == "compartment_set":
if report_cfg.get("cells") is not None:
raise ValueError(

Check warning on line 804 in bluecellulab/circuit_simulation.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit_simulation.py#L803-L804

Added lines #L803 - L804 were not covered by tests
"Report config error: 'cells' must not be set when using 'compartment_set' sections."
)
compartment_sets = self.circuit_access.config.get_compartment_sets()
write_compartment_report(

Check warning on line 808 in bluecellulab/circuit_simulation.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit_simulation.py#L807-L808

Added lines #L807 - L808 were not covered by tests
output_path=output_path,
cells=self.cells,
report_cfg=report_cfg,
source_sets=compartment_sets,
source_type="compartment_set"
)

else:
node_sets = self.circuit_access.config.get_node_sets()
if report_cfg.get("compartments") not in ("center", "all"):
raise ValueError(
f"Unsupported 'compartments' value '{report_cfg.get('compartments')}' "
"for node-based section recording (must be 'center' or 'all')."
)
write_compartment_report(
output_path=output_path,
cells=self.cells,
report_cfg=report_cfg,
source_sets=node_sets,
source_type="node_set"
)
81 changes: 81 additions & 0 deletions bluecellulab/simulation/report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2025 Open Brain Institute

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Report class of bluecellulab."""

import logging

from bluecellulab.tools import resolve_segments

logger = logging.getLogger(__name__)


def _configure_recording(cell, report_cfg, source, source_type, report_name):
variable = report_cfg.get("variable_name", "v")
if variable != "v":
logger.warning(f"Unsupported variable '{variable}' for report '{report_name}'")
return

node_id = cell.cell_id
compartment_nodes = source.get("compartment_set") if source_type == "compartment_set" else None

targets = resolve_segments(cell, report_cfg, node_id, compartment_nodes, source_type)
for sec, sec_name, seg in targets:
try:
cell.add_voltage_recording(section=sec, segx=seg)
except Exception as e:
logger.warning(

Check warning on line 37 in bluecellulab/simulation/report.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/simulation/report.py#L36-L37

Added lines #L36 - L37 were not covered by tests
f"Failed to record voltage at {sec_name}({seg}) on GID {node_id} for report '{report_name}': {e}"
)


def configure_all_reports(cells, simulation_config):
report_entries = simulation_config.get_report_entries()

for report_name, report_cfg in report_entries.items():
report_type = report_cfg.get("type", "compartment")
section = report_cfg.get("sections", "soma")

if report_type != "compartment":
raise NotImplementedError(f"Report type '{report_type}' is not supported.")

if section == "compartment_set":
source_type = "compartment_set"
source_sets = simulation_config.get_compartment_sets()
source_name = report_cfg.get("compartments")

Check warning on line 55 in bluecellulab/simulation/report.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/simulation/report.py#L53-L55

Added lines #L53 - L55 were not covered by tests
else:
source_type = "node_set"
source_sets = simulation_config.get_node_sets()
source_name = report_cfg.get("cells")

source = source_sets.get(source_name)
if not source:
logger.warning(f"{source_type.title()} '{source_name}' not found for report '{report_name}'")
continue

Check warning on line 64 in bluecellulab/simulation/report.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/simulation/report.py#L63-L64

Added lines #L63 - L64 were not covered by tests

population = source["population"]

if source_type == "compartment_set":
node_ids = [entry[0] for entry in source.get("compartment_set", [])]

Check warning on line 69 in bluecellulab/simulation/report.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/simulation/report.py#L69

Added line #L69 was not covered by tests
else: # node_set
if "node_id" in source:
node_ids = source["node_id"]
else:
# Fallback: use all available node IDs from this population
node_ids = [node_id for (pop, node_id) in cells.keys() if pop == population]

for node_id in node_ids:
cell = cells.get((population, node_id))
if not cell:
continue
_configure_recording(cell, report_cfg, source, source_type, report_name)
Loading
Loading