44from copy import deepcopy
55from pathlib import Path
66
7+ from entitysdk ._server_schemas import EntityType as EntityType
78from entitysdk .client import Client
89from entitysdk .downloaders .simulation import (
910 download_node_sets_file ,
1011 download_simulation_config_content ,
1112 download_spike_replay_files ,
1213)
1314from entitysdk .exception import StagingError
14- from entitysdk .models import Circuit , Simulation
15+ from entitysdk .models import Circuit , MEModel , Simulation
16+ from entitysdk .models .entity import Entity
1517from entitysdk .staging .circuit import stage_circuit
18+ from entitysdk .staging .constants import (
19+ DEFAULT_NODE_POPULATION_NAME ,
20+ DEFAULT_NODE_SET_NAME ,
21+ )
22+ from entitysdk .staging .memodel import stage_sonata_from_memodel
1623from entitysdk .types import StrOrPath
1724from entitysdk .utils .filesystem import create_dir
1825from entitysdk .utils .io import write_json
@@ -46,13 +53,7 @@ def stage_simulation(
4653 The path to the staged simulation config file.
4754 """
4855 output_dir = create_dir (output_dir ).resolve ()
49-
5056 simulation_config : dict = download_simulation_config_content (client , model = model )
51- node_sets_file : Path = download_node_sets_file (
52- client ,
53- model = model ,
54- output_path = output_dir / DEFAULT_NODE_SETS_FILENAME ,
55- )
5657 spike_paths : list [Path ] = download_spike_replay_files (
5758 client ,
5859 model = model ,
@@ -63,13 +64,48 @@ def stage_simulation(
6364 "Circuit config path was not provided. Circuit is going to be staged from metadata. "
6465 "Circuit id to be staged: %s"
6566 )
66- circuit_config_path = stage_circuit (
67+ base_entity = client .get_entity (entity_id = model .entity_id , entity_type = Entity )
68+ match base_entity .type :
69+ case EntityType .memodel :
70+ memodel = client .get_entity (entity_id = model .entity_id , entity_type = MEModel )
71+ L .info (
72+ "Staging single-cell SONATA circuit from MEModel %s" ,
73+ memodel .id ,
74+ )
75+ node_sets_file = _stage_single_cell_node_sets_file (
76+ node_set_name = simulation_config .get ("node_set" , DEFAULT_NODE_SET_NAME ),
77+ output_path = output_dir / DEFAULT_NODE_SETS_FILENAME ,
78+ )
79+ circuit_config_path = stage_sonata_from_memodel (
80+ client ,
81+ memodel = memodel ,
82+ output_dir = create_dir (output_dir / DEFAULT_CIRCUIT_DIR ),
83+ )
84+ case EntityType .circuit :
85+ circuit = client .get_entity (entity_id = model .entity_id , entity_type = Circuit )
86+ L .info (
87+ "Staging SONATA circuit from Circuit %s" ,
88+ circuit .id ,
89+ )
90+ node_sets_file = download_node_sets_file (
91+ client ,
92+ model = model ,
93+ output_path = output_dir / DEFAULT_NODE_SETS_FILENAME ,
94+ )
95+ circuit_config_path = stage_circuit (
96+ client ,
97+ model = circuit ,
98+ output_dir = create_dir (output_dir / DEFAULT_CIRCUIT_DIR ),
99+ )
100+ case _:
101+ raise StagingError (
102+ f"Simulation { model .id } references unsupported type { base_entity .type } "
103+ )
104+ else :
105+ node_sets_file = download_node_sets_file (
67106 client ,
68- model = client .get_entity (
69- entity_id = model .entity_id ,
70- entity_type = Circuit ,
71- ),
72- output_dir = create_dir (output_dir / DEFAULT_CIRCUIT_DIR ),
107+ model = model ,
108+ output_path = output_dir / DEFAULT_NODE_SETS_FILENAME ,
73109 )
74110
75111 transformed_simulation_config : dict = _transform_simulation_config (
@@ -82,17 +118,28 @@ def stage_simulation(
82118 )
83119
84120 output_simulation_config_file = output_dir / DEFAULT_SIMULATION_CONFIG_FILENAME
85-
86- write_json (
87- data = transformed_simulation_config ,
88- path = output_simulation_config_file ,
89- )
121+ write_json (data = transformed_simulation_config , path = output_simulation_config_file )
90122
91123 L .info ("Staged Simulation %s at %s" , model .id , output_dir )
92-
93124 return output_simulation_config_file
94125
95126
127+ def _stage_single_cell_node_sets_file (
128+ node_set_name : str ,
129+ output_path : Path ,
130+ ) -> Path :
131+ write_json (
132+ {
133+ node_set_name : {
134+ "population" : DEFAULT_NODE_POPULATION_NAME ,
135+ "node_id" : [0 ],
136+ }
137+ },
138+ output_path ,
139+ )
140+ return output_path
141+
142+
96143def _transform_simulation_config (
97144 simulation_config : dict ,
98145 circuit_config_path : Path ,
0 commit comments