Skip to content

Commit 9aadfe9

Browse files
authored
Merge pull request #2976 from nicolossus/port_test_sinusoidal_gamma_generator
Port sinusoidal generator tests (SLI-2-Py)
2 parents d0532fc + e574991 commit 9aadfe9

File tree

3 files changed

+189
-570
lines changed

3 files changed

+189
-570
lines changed
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# test_sinusoidal_generators.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
"""
23+
Test basic properties of sinusoidal generators.
24+
"""
25+
26+
import nest
27+
import numpy as np
28+
import numpy.testing as nptest
29+
import pytest
30+
31+
# List of sinusoidal generator models
32+
gen_models = ["sinusoidal_poisson_generator", "sinusoidal_gamma_generator"]
33+
34+
35+
@pytest.fixture(autouse=True)
36+
def reset():
37+
nest.ResetKernel()
38+
39+
40+
@pytest.mark.parametrize("gen_model", gen_models)
41+
def test_individual_spike_trains_true_by_default(gen_model):
42+
"""
43+
Test that ``individual_spike_trains`` is true by default.
44+
"""
45+
46+
gen = nest.Create(gen_model)
47+
assert gen.individual_spike_trains
48+
49+
50+
@pytest.mark.parametrize("gen_model", gen_models)
51+
def test_set_individual_spike_trains_on_set_defaults(gen_model):
52+
"""
53+
Test whether ``individual_spike_trains`` can be set on ``SetDefaults``.
54+
"""
55+
56+
nest.SetDefaults(gen_model, {"individual_spike_trains": False})
57+
gen = nest.Create(gen_model)
58+
assert not gen.individual_spike_trains
59+
60+
61+
@pytest.mark.parametrize("gen_model", gen_models)
62+
def test_set_individual_spike_trains_on_creation(gen_model):
63+
"""
64+
Test whether ``individual_spike_trains`` can be set on model creation.
65+
"""
66+
67+
gen = nest.Create(gen_model, params={"individual_spike_trains": False})
68+
assert not gen.individual_spike_trains
69+
70+
71+
@pytest.mark.parametrize("gen_model", gen_models)
72+
def test_set_individual_spike_trains_on_copy_model(gen_model):
73+
"""
74+
Test whether the set ``individual_spike_trains`` is inherited on ``CopyModel``.
75+
"""
76+
77+
nest.CopyModel(
78+
gen_model,
79+
"sinusoidal_generator_copy",
80+
params={"individual_spike_trains": False},
81+
)
82+
gen = nest.Create("sinusoidal_generator_copy")
83+
assert not gen.individual_spike_trains
84+
85+
86+
@pytest.mark.parametrize("gen_model", gen_models)
87+
def test_set_individual_spike_trains_on_instance(gen_model):
88+
"""
89+
Test that ``individual_spike_trains`` cannot be set on an instance.
90+
"""
91+
92+
gen = nest.Create(gen_model)
93+
94+
with pytest.raises(nest.kernel.NESTErrors.BadProperty):
95+
gen.individual_spike_trains = False
96+
97+
98+
@pytest.mark.skipif_missing_threads()
99+
@pytest.mark.parametrize("gen_model", gen_models)
100+
@pytest.mark.parametrize("individual_spike_trains", [False, True])
101+
@pytest.mark.parametrize("num_threads", [1, 2])
102+
def test_sinusoidal_generator_with_spike_recorder(gen_model, num_threads, individual_spike_trains):
103+
"""
104+
Test spike recording with both true and false ``individual_spike_trains``.
105+
106+
The test builds a network with ``num_threads x 4`` parrot neurons that
107+
receives spikes from the specified sinusoidal generator. A ``spike_recorder``
108+
is connected to each parrot neuron. The test ensures that different targets
109+
(on the same or different threads) receives identical spike trains if
110+
``individual_spike_trains`` is false and different spike trains otherwise.
111+
"""
112+
113+
nest.local_num_threads = num_threads
114+
nrns_per_thread = 4
115+
total_num_nrns = num_threads * nrns_per_thread
116+
117+
nest.SetDefaults(
118+
gen_model,
119+
{
120+
"rate": 100,
121+
"amplitude": 50.0,
122+
"frequency": 10.0,
123+
"individual_spike_trains": individual_spike_trains,
124+
},
125+
)
126+
127+
parrots = nest.Create("parrot_neuron", total_num_nrns)
128+
gen = nest.Create(gen_model)
129+
srecs = nest.Create("spike_recorder", total_num_nrns)
130+
131+
nest.Connect(gen, parrots)
132+
nest.Connect(parrots, srecs, "one_to_one")
133+
134+
nest.Simulate(500.0)
135+
136+
# Nested list of recorded spike times from each sender
137+
spikes_all_nrns = srecs.get("events", "times")
138+
139+
# Check that we actually obtained a spike times array for each neuron
140+
assert len(spikes_all_nrns) == total_num_nrns
141+
142+
if individual_spike_trains:
143+
# all trains must be pairwise different
144+
assert all(
145+
not np.array_equal(left, right)
146+
for idx, left in enumerate(spikes_all_nrns[:-1])
147+
for right in spikes_all_nrns[(idx + 1) :]
148+
)
149+
else:
150+
# all trains should be equal
151+
assert all(np.array_equal(spikes_all_nrns[0], right) for right in spikes_all_nrns[1:])
152+
153+
154+
@pytest.mark.parametrize("gen_model", gen_models)
155+
def test_sinusoidal_generator_rate_profile(gen_model):
156+
"""
157+
Test recorded rate of provided sinusoidal generator against expectation.
158+
159+
The test checks that the recorded rate profile with ``multimeter`` is the
160+
same as the analytical expectation.
161+
"""
162+
163+
dc = 1.0
164+
ac = 0.5
165+
freq = 10.0
166+
phi = 2.0
167+
168+
nest.SetDefaults(
169+
gen_model,
170+
{"rate": dc, "amplitude": ac, "frequency": freq, "phase": phi / np.pi * 180},
171+
)
172+
173+
parrots = nest.Create("parrot_neuron")
174+
sspg = nest.Create(gen_model)
175+
mm = nest.Create("multimeter", {"record_from": ["rate"]})
176+
177+
nest.Connect(sspg, parrots)
178+
nest.Connect(mm, sspg)
179+
180+
nest.Simulate(100.0)
181+
182+
times = mm.events["times"]
183+
scaled_times = times * 2 * np.pi * freq / 1000
184+
shifted_times = scaled_times + phi
185+
expected_rates = np.sin(shifted_times) * ac + dc
186+
187+
actual_rates = mm.events["rate"]
188+
189+
nptest.assert_allclose(actual_rates, expected_rates)

0 commit comments

Comments
 (0)