Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/entitysdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
SlicingDirectionType,
)
from entitysdk.models.emodel import EModel
from entitysdk.models.etype import ETypeClass
from entitysdk.models.ion_channel import IonChannel
from entitysdk.models.ion_channel_model import IonChannelModel, NeuronBlock, UseIon
from entitysdk.models.ion_channel_recording import IonChannelRecording
Expand Down Expand Up @@ -59,6 +60,7 @@
"EMCellMeshType",
"EMDenseReconstructionDataset",
"EModel",
"ETypeClass",
"ETypeClassification",
"IonChannel",
"IonChannelModel",
Expand Down
2 changes: 1 addition & 1 deletion src/entitysdk/models/etype.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ class ETypeClass(Identifiable):
alt_label: Annotated[
str | None,
Field(description="The alternative label of th etype class."),
]
] = None
4 changes: 4 additions & 0 deletions src/entitysdk/staging/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Staging constants."""

DEFAULT_NODE_POPULATION_NAME = "All"
DEFAULT_NODE_SET_NAME = "All"
33 changes: 19 additions & 14 deletions src/entitysdk/staging/memodel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Staging functions for Single-Cell."""

import json
import logging
import shutil
import tempfile
Expand All @@ -12,6 +11,10 @@
from entitysdk.downloaders.memodel import DownloadedMEModel, download_memodel
from entitysdk.exception import StagingError
from entitysdk.models.memodel import MEModel
from entitysdk.staging.constants import (
DEFAULT_NODE_POPULATION_NAME,
DEFAULT_NODE_SET_NAME,
)
from entitysdk.utils.filesystem import create_dir
from entitysdk.utils.io import write_json

Expand Down Expand Up @@ -143,7 +146,7 @@ def create_nodes_file(
output_file.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(output_file, "w") as f:
nodes = f.create_group("nodes")
population = nodes.create_group("All")
population = nodes.create_group(DEFAULT_NODE_POPULATION_NAME)
population.create_dataset("node_type_id", (1,), dtype="int64")[0] = -1
group_0 = population.create_group("0")

Expand Down Expand Up @@ -192,12 +195,15 @@ def create_nodes_file(
L.debug(f"Successfully created file at {output_file}")


def create_circuit_config(output_path: Path, node_population_name: str = "All"):
def create_circuit_config(
output_path: Path,
node_population_name: str = DEFAULT_NODE_POPULATION_NAME,
):
"""Create a SONATA circuit_config.json for a single cell.

Args:
output_path (str): Directory where circuit_config.json will be written.
node_population_name (str): Name of the node population (default: 'All').
output_path: Directory where circuit_config.json will be written.
node_population_name: Name of the node population.
"""
config = {
"manifest": {"$BASE_DIR": "."},
Expand All @@ -219,25 +225,24 @@ def create_circuit_config(output_path: Path, node_population_name: str = "All"):
"edges": [],
},
}
config_path = output_path / "circuit_config.json"
with open(config_path, "w") as f:
json.dump(config, f, indent=2)
config_path = output_path / DEFAULT_CIRCUIT_CONFIG_FILENAME
write_json(data=config, path=config_path, indent=2)
L.debug(f"Successfully created circuit_config.json at {config_path}")


def create_node_sets_file(
output_file: Path,
node_population_name: str = "All",
node_set_name: str = "All",
node_population_name: str = DEFAULT_NODE_POPULATION_NAME,
node_set_name: str = DEFAULT_NODE_SET_NAME,
node_id: int = 0,
):
"""Create a node_sets.json file for a single cell.

Args:
output_file (Path): Output file path for node_sets.json.
node_population_name (str): Name of the node population (default: 'All').
node_set_name (str): Name of the node set (default: 'All').
node_id (int): Node ID to include (default: 0).
output_file: Output file path for node_sets.json.
node_population_name: Name of the node population.
node_set_name: Name of the node set (default: MEMODEL_CIRCUIT_STAGING_NODE_SET_NAME).
node_id: Node ID to include (default: 0).
"""
node_sets = {node_set_name: {"population": node_population_name, "node_id": [node_id]}}
write_json(node_sets, output_file)
Expand Down
85 changes: 66 additions & 19 deletions src/entitysdk/staging/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,22 @@
from copy import deepcopy
from pathlib import Path

from entitysdk._server_schemas import EntityType as EntityType
from entitysdk.client import Client
from entitysdk.downloaders.simulation import (
download_node_sets_file,
download_simulation_config_content,
download_spike_replay_files,
)
from entitysdk.exception import StagingError
from entitysdk.models import Circuit, Simulation
from entitysdk.models import Circuit, MEModel, Simulation
from entitysdk.models.entity import Entity
from entitysdk.staging.circuit import stage_circuit
from entitysdk.staging.constants import (
DEFAULT_NODE_POPULATION_NAME,
DEFAULT_NODE_SET_NAME,
)
from entitysdk.staging.memodel import stage_sonata_from_memodel
from entitysdk.types import StrOrPath
from entitysdk.utils.filesystem import create_dir
from entitysdk.utils.io import write_json
Expand Down Expand Up @@ -46,13 +53,7 @@ def stage_simulation(
The path to the staged simulation config file.
"""
output_dir = create_dir(output_dir).resolve()

simulation_config: dict = download_simulation_config_content(client, model=model)
node_sets_file: Path = download_node_sets_file(
client,
model=model,
output_path=output_dir / DEFAULT_NODE_SETS_FILENAME,
)
spike_paths: list[Path] = download_spike_replay_files(
client,
model=model,
Expand All @@ -63,13 +64,48 @@ def stage_simulation(
"Circuit config path was not provided. Circuit is going to be staged from metadata. "
"Circuit id to be staged: %s"
)
circuit_config_path = stage_circuit(
base_entity = client.get_entity(entity_id=model.entity_id, entity_type=Entity)
match base_entity.type:
case EntityType.memodel:
memodel = client.get_entity(entity_id=model.entity_id, entity_type=MEModel)
L.info(
"Staging single-cell SONATA circuit from MEModel %s",
memodel.id,
)
node_sets_file = _stage_single_cell_node_sets_file(
node_set_name=simulation_config.get("node_set", DEFAULT_NODE_SET_NAME),
output_path=output_dir / DEFAULT_NODE_SETS_FILENAME,
)
circuit_config_path = stage_sonata_from_memodel(
client,
memodel=memodel,
output_dir=create_dir(output_dir / DEFAULT_CIRCUIT_DIR),
)
case EntityType.circuit:
circuit = client.get_entity(entity_id=model.entity_id, entity_type=Circuit)
L.info(
"Staging SONATA circuit from Circuit %s",
circuit.id,
)
node_sets_file = download_node_sets_file(
client,
model=model,
output_path=output_dir / DEFAULT_NODE_SETS_FILENAME,
)
circuit_config_path = stage_circuit(
client,
model=circuit,
output_dir=create_dir(output_dir / DEFAULT_CIRCUIT_DIR),
)
case _:
raise StagingError(
f"Simulation {model.id} references unsupported type {base_entity.type}"
)
else:
node_sets_file = download_node_sets_file(
client,
model=client.get_entity(
entity_id=model.entity_id,
entity_type=Circuit,
),
output_dir=create_dir(output_dir / DEFAULT_CIRCUIT_DIR),
model=model,
output_path=output_dir / DEFAULT_NODE_SETS_FILENAME,
)

transformed_simulation_config: dict = _transform_simulation_config(
Expand All @@ -82,17 +118,28 @@ def stage_simulation(
)

output_simulation_config_file = output_dir / DEFAULT_SIMULATION_CONFIG_FILENAME

write_json(
data=transformed_simulation_config,
path=output_simulation_config_file,
)
write_json(data=transformed_simulation_config, path=output_simulation_config_file)

L.info("Staged Simulation %s at %s", model.id, output_dir)

return output_simulation_config_file


def _stage_single_cell_node_sets_file(
node_set_name: str,
output_path: Path,
) -> Path:
write_json(
{
node_set_name: {
"population": DEFAULT_NODE_POPULATION_NAME,
"node_id": [0],
}
},
output_path,
)
return output_path


def _transform_simulation_config(
simulation_config: dict,
circuit_config_path: Path,
Expand Down
Loading