Skip to content

Add support for compartment reports (soma + named section) as well as spike reports #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions bluecellulab/cell/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,10 @@
nc.record(spike_vec)
self.recordings[f"spike_detector_{location}_{threshold}"] = spike_vec

def is_recording_spikes(self, location: str, threshold: float) -> bool:
key = f"spike_detector_{location}_{threshold}"
return key in self.recordings

def get_recorded_spikes(self, location: str, threshold: float = -30) -> list[float]:
"""Get recorded spikes in the current cell.

Expand Down Expand Up @@ -756,6 +760,18 @@
"""Get a vector of AIS voltage."""
return self.get_recording('self.axonal[1](0.5)._ref_v')

def add_variable_recording(self, variable: str, section, segx):
if variable == "v":
self.add_voltage_recording(section, segx)
else:
raise ValueError(f"Unsupported variable for recording: {variable}")

def get_variable_recording(self, variable: str, section, segx) -> np.ndarray:
if variable == "v":
return self.get_voltage_recording(section=section, segx=segx)
else:
raise ValueError(f"Unsupported variable '{variable}'")

Check warning on line 773 in bluecellulab/cell/core.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/cell/core.py#L773

Added line #L773 was not covered by tests

@property
def n_segments(self) -> int:
"""Get the number of segments in the cell."""
Expand Down
45 changes: 44 additions & 1 deletion bluecellulab/circuit/config/sonata_simulation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
from __future__ import annotations
from functools import lru_cache
import json
import logging
from pathlib import Path
from typing import Optional

Expand All @@ -21,6 +23,8 @@

from bluepysnap import Simulation as SnapSimulation

logger = logging.getLogger(__name__)


class SonataSimulationConfig:
"""Sonata implementation of SimulationConfig protocol."""
Expand Down Expand Up @@ -74,9 +78,42 @@
result.append(ConnectionOverrides.from_sonata(conn_entry))
return result

@lru_cache(maxsize=1)
def get_compartment_sets(self) -> dict[str, dict]:
filepath = self.impl.config.get("compartment_sets_file")
if not filepath:
raise ValueError("No 'compartment_sets_file' entry found in SONATA config.")

Check warning on line 85 in bluecellulab/circuit/config/sonata_simulation_config.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/config/sonata_simulation_config.py#L85

Added line #L85 was not covered by tests
with open(filepath, 'r') as f:
return json.load(f)

@lru_cache(maxsize=1)
def get_node_sets(self) -> dict[str, dict]:
filepath = self.impl.circuit.config.get("node_sets_file")
if not filepath:
raise ValueError("No 'node_sets_file' entry found in SONATA config.")

Check warning on line 93 in bluecellulab/circuit/config/sonata_simulation_config.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/config/sonata_simulation_config.py#L93

Added line #L93 was not covered by tests
with open(filepath, 'r') as f:
return json.load(f)

@lru_cache(maxsize=1)
def get_report_entries(self) -> dict[str, dict]:
"""Returns the 'reports' dictionary from the SONATA simulation config.

Each key is a report name, and the value is its configuration.
"""
reports = self.impl.config.get("reports", {})
if not isinstance(reports, dict):
raise ValueError("Invalid format for 'reports' in SONATA config.")

Check warning on line 105 in bluecellulab/circuit/config/sonata_simulation_config.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/config/sonata_simulation_config.py#L105

Added line #L105 was not covered by tests
return reports

def connection_entries(self) -> list[ConnectionOverrides]:
return self._connection_entries() + self._connection_overrides

def report_file_path(self, report_cfg: dict, report_key: str) -> Path:
"""Resolve the full path for the report output file."""
output_dir = Path(self.output_root_path)
file_name = report_cfg.get("file_name", f"{report_key}.h5")
return output_dir / file_name

@property
def base_seed(self) -> int:
return self.impl.run.random_seed
Expand Down Expand Up @@ -135,7 +172,13 @@

@property
def output_root_path(self) -> str:
return self.impl.config["output"]["output_dir"]
return self.impl.config.get("output", {}).get("output_dir", "output")

@property
def spikes_file_path(self) -> Path:
output_dir = Path(self.output_root_path)
spikes_file = self.impl.config.get("output", {}).get("spikes_file", "spikes.h5")
return output_dir / spikes_file

@property
def extracellular_calcium(self) -> Optional[float]:
Expand Down
2 changes: 1 addition & 1 deletion bluecellulab/circuit/iotools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pathlib import Path
import logging

import bluepy
import numpy as np

from bluecellulab.circuit.node_id import CellId
Expand All @@ -28,6 +27,7 @@
def parse_outdat(path: str | Path) -> dict[CellId, np.ndarray]:
"""Parse the replay spiketrains in a out.dat formatted file pointed to by
path."""
import bluepy

Check warning on line 30 in bluecellulab/circuit/iotools.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit/iotools.py#L30

Added line #L30 was not covered by tests
spikes = bluepy.impl.spike_report.SpikeReport.load(path).get()
# convert Series to DataFrame with 2 columns for `groupby` operation
spike_df = spikes.to_frame().reset_index()
Expand Down
111 changes: 106 additions & 5 deletions bluecellulab/circuit_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

from __future__ import annotations
from collections.abc import Iterable
import os
from pathlib import Path
from typing import Optional
import logging

from collections import defaultdict
import neuron
import numpy as np
import pandas as pd
Expand All @@ -45,6 +47,7 @@
from bluecellulab.importer import load_mod_files
from bluecellulab.rngsettings import RNGSettings
from bluecellulab.simulation.neuron_globals import NeuronGlobals
from bluecellulab.simulation.report import configure_all_reports, write_compartment_report, write_sonata_spikes
from bluecellulab.stimulus.circuit_stimulus_definitions import Noise, OrnsteinUhlenbeck, RelativeOrnsteinUhlenbeck, RelativeShotNoise, ShotNoise
import bluecellulab.stimulus.circuit_stimulus_definitions as circuit_stimulus_definitions
from bluecellulab.exceptions import BluecellulabError
Expand Down Expand Up @@ -301,6 +304,16 @@
add_linear_stimuli=add_linear_stimuli
)

configure_all_reports(
cells=self.cells,
simulation_config=self.circuit_access.config
)

# add spike recordings
for cell in self.cells.values():
if not cell.is_recording_spikes("soma", threshold=self.spike_threshold):
cell.start_recording_spikes(None, location="soma", threshold=self.spike_threshold)

def _add_stimuli(self, add_noise_stimuli=False,
add_hyperpolarizing_stimuli=False,
add_relativelinear_stimuli=False,
Expand Down Expand Up @@ -458,13 +471,26 @@
@staticmethod
def merge_pre_spike_trains(*train_dicts) -> dict[CellId, np.ndarray]:
"""Merge presynaptic spike train dicts."""
filtered_dicts = [d for d in train_dicts if d not in [None, {}, [], ()]]
filtered_dicts = [d for d in train_dicts if isinstance(d, dict) and d]

if not filtered_dicts:
logger.warning("merge_pre_spike_trains: No presynaptic spike trains found.")
return {}

all_keys = set().union(*[d.keys() for d in filtered_dicts])
return {
k: np.sort(np.concatenate([d[k] for d in filtered_dicts if k in d]))
for k in all_keys
}
result = {}

for k in all_keys:
valid_arrays = []
for d in filtered_dicts:
if k in d:
val = d[k]
if isinstance(val, (np.ndarray, list)) and len(val) > 0:
valid_arrays.append(np.asarray(val))
if valid_arrays:
result[k] = np.sort(np.concatenate(valid_arrays))

return result

def _add_connections(
self,
Expand Down Expand Up @@ -646,6 +672,8 @@
forward_skip_value=forward_skip_value,
show_progress=show_progress)

self.write_reports()

def get_mainsim_voltage_trace(
self, cell_id: int | tuple[str, int], t_start=None, t_stop=None, t_step=None
) -> np.ndarray:
Expand Down Expand Up @@ -779,3 +807,76 @@
record_dt=cell_kwargs['record_dt'],
template_format=cell_kwargs['template_format'],
emodel_properties=cell_kwargs['emodel_properties'])

def write_reports(self):
"""Write all reports defined in the simulation config."""
report_entries = self.circuit_access.config.get_report_entries()

for report_name, report_cfg in report_entries.items():
report_type = report_cfg.get("type", "compartment")
section = report_cfg.get("sections")

if report_type != "compartment":
raise NotImplementedError(f"Report type '{report_type}' is not supported.")

output_path = self.circuit_access.config.report_file_path(report_cfg, report_name)
if section == "compartment_set":
if report_cfg.get("cells") is not None:
raise ValueError(

Check warning on line 825 in bluecellulab/circuit_simulation.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit_simulation.py#L824-L825

Added lines #L824 - L825 were not covered by tests
"Report config error: 'cells' must not be set when using 'compartment_set' sections."
)
compartment_sets = self.circuit_access.config.get_compartment_sets()
write_compartment_report(

Check warning on line 829 in bluecellulab/circuit_simulation.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit_simulation.py#L828-L829

Added lines #L828 - L829 were not covered by tests
report_name=report_name,
output_path=output_path,
cells=self.cells,
report_cfg=report_cfg,
source_sets=compartment_sets,
source_type="compartment_set"
)

else:
node_sets = self.circuit_access.config.get_node_sets()
if report_cfg.get("compartments") not in ("center", "all"):
raise ValueError(
f"Unsupported 'compartments' value '{report_cfg.get('compartments')}' "
"for node-based section recording (must be 'center' or 'all')."
)
write_compartment_report(
report_name=report_name,
output_path=output_path,
cells=self.cells,
report_cfg=report_cfg,
source_sets=node_sets,
source_type="node_set"
)

self.write_spike_report()

def write_spike_report(self):
"""Collect and write in-memory recorded spike times to a SONATA HDF5
file, grouped by population as required by the SONATA specification."""
output_path = self.circuit_access.config.spikes_file_path

if os.path.exists(output_path):
os.remove(output_path)

# Group spikes per population
spikes_by_population = defaultdict(dict)
for gid, cell in self.cells.items():
pop = getattr(gid, 'population_name', None)
if pop is None:
continue

Check warning on line 869 in bluecellulab/circuit_simulation.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit_simulation.py#L869

Added line #L869 was not covered by tests
try:
cell_spikes = cell.get_recorded_spikes(location="soma", threshold=self.spike_threshold)
if cell_spikes is not None:
spikes_by_population[pop][gid.id] = list(cell_spikes)
except AttributeError:
continue

Check warning on line 875 in bluecellulab/circuit_simulation.py

View check run for this annotation

Codecov / codecov/patch

bluecellulab/circuit_simulation.py#L874-L875

Added lines #L874 - L875 were not covered by tests

# Ensure we at least create empty groups for all known populations
all_populations = set(getattr(gid, 'population_name', None) for gid in self.cells.keys())

for pop in all_populations:
spikes = spikes_by_population.get(pop, {}) # May be empty
write_sonata_spikes(output_path, spikes, pop)
Loading