Skip to content

Commit 4dc065f

Browse files
committed
Finalise and tidy up tests.
1 parent 5f012a2 commit 4dc065f

File tree

3 files changed

+116
-28
lines changed

3 files changed

+116
-28
lines changed

debugging/debugging_session_displacement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@
5050
"method": "by_amplitude_and_firing_rate",
5151
"scalings": scale_,
5252
},
53-
generate_sorting_kwargs=dict(firing_rates=(149, 150), refractory_period_ms=4.0),
53+
generate_sorting_kwargs=dict(firing_rates=(0, 200), refractory_period_ms=4.0),
5454
generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3),
55-
seed=44,
55+
seed=None,
5656
generate_unit_locations_kwargs=dict(
5757
margin_um=0.0, # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
5858
minimum_z=5.0,

src/spikeinterface/generation/session_displacement_generator.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,15 @@
1212
)
1313
import numpy as np
1414
from spikeinterface.generation.noise_tools import generate_noise
15-
from spikeinterface.core.generate import setup_inject_templates_recording
15+
from spikeinterface.core.generate import setup_inject_templates_recording, _ensure_firing_rates
1616
from spikeinterface.core import InjectTemplatesRecording
1717

1818

19+
# TODO: add note on what is fixed / not fixed across sessions
20+
# TODO: tests are failing because of mutable default arguments.
21+
# will need to fix this before proceeding.
22+
23+
1924
def generate_session_displacement_recordings(
2025
num_units=250,
2126
recording_durations=(10, 10, 10),
@@ -87,7 +92,8 @@ def generate_session_displacement_recordings(
8792
"scalings" - a list of numpy arrays, one for each recording, with
8893
each entry an array of length num_units holding the unit scalings.
8994
e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
90-
95+
generate_sorting_kwargs : dict
96+
Only `firing_rates` and `refractory_period_ms` are expected if passed.
9197
All other parameters are used as in from `generate_drifting_recording()`.
9298
9399
Returns
@@ -103,6 +109,19 @@ def generate_session_displacement_recordings(
103109
"templates_array_moved" : list[np.array]
104110
A list (length num records) of (num_units, num_samples, num_channels)
105111
arrays of templates that have been shifted.
112+
113+
114+
Notes
115+
-----
116+
It is important to consider what unit properties are maintained
117+
across the session. Here, all `generate_template_kwargs` are fixed
118+
across sessions, to be sure the unit properties do not change.
119+
The firing rates passed to `generate_sorting` for each unit are
120+
also fixed across sessions. When a seed is set, the exact spike times
121+
will also be fixed across recordings. otherwise, when seed is `None`
122+
the actual spike times will be different across recordings, although
123+
all other unit properties will be maintained (except any location
124+
shifting and template scaling applied).
106125
"""
107126
_check_generate_session_displacement_arguments(
108127
num_units, recording_durations, recording_shifts, recording_amplitude_scalings
@@ -120,13 +139,18 @@ def generate_session_displacement_recordings(
120139
)
121140

122141
# Fix generate template kwargs, so they are the same for every created recording.
142+
# Also fix unit firing rates across recordings.
123143
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)
124144

145+
fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed)
146+
generate_sorting_kwargs["firing_rates"] = fixed_firing_rates
147+
125148
# Start looping over parameters, creating recordings shifted
126149
# and scaled as required
127150
extra_outputs_dict = {
128151
"unit_locations": [],
129152
"templates_array_moved": [],
153+
"firing_rates": [],
130154
}
131155
output_recordings = []
132156
output_sortings = []
@@ -173,9 +197,16 @@ def generate_session_displacement_recordings(
173197
**generate_templates_kwargs,
174198
)
175199

200+
# TODO: these first amplitdues don't change per loop, but are usually not
201+
# needed...
176202
if recording_amplitude_scalings is not None:
203+
204+
first_rec_templates = (
205+
templates_array_moved if rec_idx == 0 else extra_outputs_dict["templates_array_moved"][0]
206+
)
207+
177208
_amplitude_scale_templates_in_place(
178-
templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
209+
first_rec_templates, templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
179210
)
180211

181212
# Bring it all together in a `InjectTemplatesRecording` and
@@ -203,6 +234,7 @@ def generate_session_displacement_recordings(
203234
output_sortings.append(sorting)
204235
extra_outputs_dict["unit_locations"].append(unit_locations_moved)
205236
extra_outputs_dict["templates_array_moved"].append(templates_array_moved)
237+
extra_outputs_dict["firing_rates"].append(sorting_extra_outputs["firing_rates"][0])
206238

207239
if extra_outputs:
208240
return output_recordings, output_sortings, extra_outputs_dict
@@ -255,7 +287,9 @@ def _get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_
255287
return displacement_vector, displacement_unit_factor
256288

257289

258-
def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx):
290+
def _amplitude_scale_templates_in_place(
291+
first_rec_templates, moved_templates, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
292+
):
259293
"""
260294
Scale a set of templates given a set of scaling values. The scaling
261295
values can be applied in the order passed, or instead in order of
@@ -264,9 +298,13 @@ def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_sca
264298
265299
Parameters
266300
----------
267-
templates_array : np.array
268-
A (num_units, num_samples, num_channels) array of
269-
template waveforms for all units.
301+
first_rec_templates : np.array
302+
The (num_units, num_samples, num_channels) templates array from the
303+
first recording. Scaling by amplitude scales based on the amplitudes in
304+
the first session.
305+
moved_templates : np.array
306+
A (num_units, num_samples, num_channels) array moved templates to the
307+
current recording, that will be scaled.
270308
recording_amplitude_scalings : dict
271309
see `generate_session_displacement_recordings()`.
272310
sorting_extra_outputs : dict
@@ -294,12 +332,12 @@ def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_sca
294332
firing_rates_hz = sorting_extra_outputs["firing_rates"][0]
295333

296334
if method == "by_amplitude_and_firing_rate":
297-
neg_ampl = np.min(np.min(templates_array, axis=2), axis=1)
335+
neg_ampl = np.min(np.min(first_rec_templates, axis=2), axis=1)
336+
assert np.all(neg_ampl < 0), "assumes all amplitudes are negative here."
298337
score = firing_rates_hz * neg_ampl
299338
else:
300339
score = firing_rates_hz
301340

302-
assert np.all(score < 0), "assumes all amplitudes are negative here."
303341
order_idx = np.argsort(score)
304342
ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][order_idx, np.newaxis, np.newaxis]
305343

@@ -310,7 +348,7 @@ def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_sca
310348
else:
311349
raise ValueError("`recording_amplitude_scalings['method']` not recognised.")
312350

313-
templates_array *= ordered_rec_scalings
351+
moved_templates *= ordered_rec_scalings
314352

315353

316354
def _check_generate_session_displacement_arguments(

src/spikeinterface/generation/tests/test_session_displacement_generator.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class TestSessionDisplacementGenerator:
1212
"""
1313
This class tests the `generate_session_displacement_recordings` that
1414
returns a recordings / sorting in which the units are shifted
15-
across sessions. This is acheived by shifting the unit locations
15+
across sessions. This is achieved by shifting the unit locations
1616
in both (x, y) on the generated templates that are used in
1717
`InjectTemplatesRecording()`.
1818
"""
@@ -136,7 +136,7 @@ def test_recordings_length(self, options):
136136
for rec, expected_rec_length in zip(output_recordings, options["kwargs"]["recording_durations"]):
137137
assert rec.get_total_duration() == expected_rec_length
138138

139-
def test_spike_times_across_recordings(self, options):
139+
def test_spike_times_and_firing_rates_across_recordings(self, options):
140140
"""
141141
Check the randomisation of spike times across recordings.
142142
When a seed is set, this is passed to `generate_sorting`
@@ -146,14 +146,17 @@ def test_spike_times_across_recordings(self, options):
146146
"""
147147
options["kwargs"]["recording_durations"] = (10,) * options["num_recs"]
148148

149-
output_sortings_same = generate_session_displacement_recordings(**options["kwargs"])[1]
149+
output_sortings_same, extra_outputs_same = generate_session_displacement_recordings(**options["kwargs"])[1:3]
150150

151151
options["kwargs"]["seed"] = None
152-
output_sortings_different = generate_session_displacement_recordings(**options["kwargs"])[1]
152+
output_sortings_different, extra_outputs_different = generate_session_displacement_recordings(
153+
**options["kwargs"]
154+
)[1:3]
153155

154156
for unit_idx in range(options["kwargs"]["num_units"]):
155157
for rec_idx in range(1, options["num_recs"]):
156158

159+
# Exact spike times are not preserved when seed is None
157160
assert np.array_equal(
158161
output_sortings_same[0].get_unit_spike_train(unit_idx),
159162
output_sortings_same[rec_idx].get_unit_spike_train(unit_idx),
@@ -162,6 +165,15 @@ def test_spike_times_across_recordings(self, options):
162165
output_sortings_different[0].get_unit_spike_train(unit_idx),
163166
output_sortings_different[rec_idx].get_unit_spike_train(unit_idx),
164167
)
168+
# Firing rates should always be preserved.
169+
assert np.array_equal(
170+
extra_outputs_same["firing_rates"][0][unit_idx],
171+
extra_outputs_same["firing_rates"][rec_idx][unit_idx],
172+
)
173+
assert np.array_equal(
174+
extra_outputs_different["firing_rates"][0][unit_idx],
175+
extra_outputs_different["firing_rates"][rec_idx][unit_idx],
176+
)
165177

166178
@pytest.mark.parametrize("dim_idx", [0, 1])
167179
def test_x_y_shift_non_rigid(self, options, dim_idx):
@@ -271,32 +283,70 @@ def test_displacement_with_peak_detection(self, options):
271283
assert np.isclose(new_pos, first_pos + y_shift, rtol=0, atol=options["y_bin_um"])
272284

273285
def test_amplitude_scalings(self, options):
274-
286+
"""
287+
Test that the templates are scaled by the passed scaling factors
288+
in the specified order. The order can be in the passed order,
289+
in the order of highest-to-lowest firing unit, or in the order
290+
of (amplitude * firing_rate) (highest to lowest unit).
291+
"""
292+
# Setup arguments to create an unshifted set of recordings
293+
# where the templates are to be scaled with `true_scalings`
275294
options["kwargs"]["recording_durations"] = (10, 10)
276295
options["kwargs"]["recording_shifts"] = ((0, 0), (0, 0))
277296
options["kwargs"]["num_units"] == 5,
278297

298+
true_scalings = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
299+
279300
recording_amplitude_scalings = {
280301
"method": "by_passed_order",
281-
"scalings": (np.ones(5), np.array([0.1, 0.2, 0.3, 0.4, 0.5])),
302+
"scalings": (np.ones(5), true_scalings),
282303
}
283304

284305
_, output_sortings, extra_outputs = generate_session_displacement_recordings(
285306
**options["kwargs"],
286307
recording_amplitude_scalings=recording_amplitude_scalings,
287308
)
288-
breakpoint()
289-
first, second = extra_outputs["templates_array_moved"] # TODO: own function
290-
first_min = np.min(np.min(first, axis=2), axis=1)
291-
second_min = np.min(np.min(second, axis=2), axis=1)
292-
scales = second_min / first_min
293309

294-
assert np.allclose(scales, shifts)
310+
# Check that the unit templates are scaled in the order
311+
# the scalings were passed.
312+
test_scalings = self._calculate_scalings_from_output(extra_outputs)
313+
assert np.allclose(test_scalings, true_scalings)
314+
315+
# Now run, again applying the scalings in the order of
316+
# unit firing rates (highest to lowest).
317+
firing_rates = np.array([5, 4, 3, 2, 1])
318+
generate_sorting_kwargs = dict(firing_rates=firing_rates, refractory_period_ms=4.0)
319+
recording_amplitude_scalings["method"] = "by_firing_rate"
320+
_, output_sortings, extra_outputs = generate_session_displacement_recordings(
321+
**options["kwargs"],
322+
recording_amplitude_scalings=recording_amplitude_scalings,
323+
generate_sorting_kwargs=generate_sorting_kwargs,
324+
)
325+
326+
test_scalings = self._calculate_scalings_from_output(extra_outputs)
327+
assert np.allclose(test_scalings, true_scalings[np.argsort(firing_rates)])
295328

296-
# TODO: scale based on recording output
297-
# check scaled by amplitude.
329+
# Finally, run again applying the scalings in the order of
330+
# unit amplitude * firing_rate
331+
recording_amplitude_scalings["method"] = "by_amplitude_and_firing_rate" # TODO: method -> order
332+
amplitudes = np.min(np.min(extra_outputs["templates_array_moved"][0], axis=2), axis=1)
333+
firing_rate_by_amplitude = np.argsort(amplitudes * firing_rates)
298334

299-
breakpoint()
335+
_, output_sortings, extra_outputs = generate_session_displacement_recordings(
336+
**options["kwargs"],
337+
recording_amplitude_scalings=recording_amplitude_scalings,
338+
generate_sorting_kwargs=generate_sorting_kwargs,
339+
)
340+
341+
test_scalings = self._calculate_scalings_from_output(extra_outputs)
342+
assert np.allclose(test_scalings, true_scalings[firing_rate_by_amplitude])
343+
344+
def _calculate_scalings_from_output(self, extra_outputs):
345+
first, second = extra_outputs["templates_array_moved"]
346+
first_min = np.min(np.min(first, axis=2), axis=1)
347+
second_min = np.min(np.min(second, axis=2), axis=1)
348+
test_scalings = second_min / first_min
349+
return test_scalings
300350

301351
def test_metadata(self, options):
302352
"""
@@ -339,7 +389,7 @@ def test_same_as_generate_ground_truth_recording(self):
339389
generate_probe_kwargs = None
340390
generate_unit_locations_kwargs = dict()
341391
generate_templates_kwargs = dict(ms_before=1.5, ms_after=3)
342-
generate_sorting_kwargs = dict()
392+
generate_sorting_kwargs = dict(firing_rates=1)
343393
generate_noise_kwargs = dict()
344394
seed = 42
345395

0 commit comments

Comments
 (0)