Skip to content
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@ cscope.out

# Merge conflict artifacts
*.orig

# local test files
local_tests/*
2 changes: 1 addition & 1 deletion pynest/nest/lib/hl_api_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def Create(model, n=1, params=None, positions=None):

if params is not None and iterable_or_parameter_in_params:
try:
SetStatus(node_ids, params)
node_ids.set(params)
except Exception:
warnings.warn(
"SetStatus() call failed, but nodes have already been "
Expand Down
21 changes: 15 additions & 6 deletions pynest/nest/lib/hl_api_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Classes defining the different PyNEST types
"""

import itertools
import json
import numbers
from math import floor, log
Expand Down Expand Up @@ -468,6 +469,13 @@ def set(self, params=None, **kwargs):

local_nodes = [self.local] if len(self) == 1 else self.local

all_nodes_ids = [self.get("global_id")] if len(self) == 1 else self.get("global_id")
only_local_nodes_ids = list(itertools.compress(all_nodes_ids, local_nodes))
selected_local_nodes = NodeCollection(only_local_nodes_ids)

if len(selected_local_nodes) == 0:
# No local nodes, nothing to do
return
if isinstance(params, dict) and "compartments" in params:
if isinstance(params["compartments"], Compartments):
params["compartments"] = params["compartments"].get_tuple()
Expand All @@ -482,15 +490,15 @@ def set(self, params=None, **kwargs):
# Adding receptors has been handled by the += operator, so we can remove the entry.
params.pop("receptors")

if isinstance(params, dict) and all(local_nodes):
node_params = self[0].get()
if isinstance(params, dict):
node_params = selected_local_nodes[0].get()
contains_list = [
is_iterable(vals) and key in node_params and not is_iterable(node_params[key])
for key, vals in params.items()
]

if any(contains_list):
temp_param = [{} for _ in range(self.__len__())]
temp_param = [{} for _ in range(len(selected_local_nodes))]

for key, vals in params.items():
if not is_iterable(vals):
Expand All @@ -501,10 +509,11 @@ def set(self, params=None, **kwargs):
temp_dict[key] = vals[i]
params = temp_param

if isinstance(params, (list, tuple)) and self.__len__() != len(params):
raise TypeError("status dict must be a dict, or a list of dicts of length {} ".format(self.__len__()))
if isinstance(params, (list, tuple)) and len(selected_local_nodes) != len(params):
n = len(selected_local_nodes)
raise TypeError("status dict must be a dict, or a list of dicts of length {} ".format(n))

sli_func("SetStatus", self._datum, params)
sli_func("SetStatus", selected_local_nodes._datum, params)

def tolist(self):
"""
Expand Down
Loading
Loading