Skip to content

Commit 5f012a2

Browse files
committed
Start finalising tests.
1 parent 9798c75 commit 5f012a2

File tree

2 files changed

+62
-22
lines changed

2 files changed

+62
-22
lines changed

src/spikeinterface/generation/session_displacement_generator.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@
1616
from spikeinterface.core import InjectTemplatesRecording
1717

1818

19-
# TODO: test metadata
20-
# TOOD: test new amplitude scalings
21-
# TODO: test correct unit_locations are on the sortings (part of metadata)
22-
23-
2419
def generate_session_displacement_recordings(
2520
num_units=250,
2621
recording_durations=(10, 10, 10),
@@ -131,7 +126,7 @@ def generate_session_displacement_recordings(
131126
# and scaled as required
132127
extra_outputs_dict = {
133128
"unit_locations": [],
134-
"template_array_moved": [],
129+
"templates_array_moved": [],
135130
}
136131
output_recordings = []
137132
output_sortings = []
@@ -170,7 +165,7 @@ def generate_session_displacement_recordings(
170165
)
171166

172167
# Generate the (possibly shifted, scaled) unit templates
173-
template_array_moved = generate_templates(
168+
templates_array_moved = generate_templates(
174169
channel_locations,
175170
unit_locations_moved,
176171
sampling_frequency=sampling_frequency,
@@ -179,9 +174,8 @@ def generate_session_displacement_recordings(
179174
)
180175

181176
if recording_amplitude_scalings is not None:
182-
183-
template_array_moved = _amplitude_scale_templates_in_place(
184-
template_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
177+
_amplitude_scale_templates_in_place(
178+
templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
185179
)
186180

187181
# Bring it all together in a `InjectTemplatesRecording` and
@@ -191,7 +185,7 @@ def generate_session_displacement_recordings(
191185

192186
recording = InjectTemplatesRecording(
193187
sorting=sorting,
194-
templates=template_array_moved,
188+
templates=templates_array_moved,
195189
nbefore=nbefore,
196190
amplitude_factor=None,
197191
parent_recording=noise,
@@ -208,7 +202,7 @@ def generate_session_displacement_recordings(
208202
output_recordings.append(recording)
209203
output_sortings.append(sorting)
210204
extra_outputs_dict["unit_locations"].append(unit_locations_moved)
211-
extra_outputs_dict["template_array_moved"].append(template_array_moved)
205+
extra_outputs_dict["templates_array_moved"].append(templates_array_moved)
212206

213207
if extra_outputs:
214208
return output_recordings, output_sortings, extra_outputs_dict
@@ -344,12 +338,13 @@ def _check_generate_session_displacement_arguments(
344338
if not "method" in keys or not "scalings" in keys:
345339
raise ValueError("`recording_amplitude_scalings` must be a dict " "with keys `method` and `scalings`.")
346340

347-
allowed_methods = ["by_passed_value", "by_amplitude_and_firing_rate", "by_firing_rate"]
341+
allowed_methods = ["by_passed_order", "by_amplitude_and_firing_rate", "by_firing_rate"]
348342
if not recording_amplitude_scalings["method"] in allowed_methods:
349343
raise ValueError(f"`recording_amplitude_scalings` must be one of {allowed_methods}")
350344

351345
rec_scalings = recording_amplitude_scalings["scalings"]
352346
if not len(rec_scalings) == expected_num_recs:
347+
breakpoint()
353348
raise ValueError("`recording_amplitude_scalings` 'scalings' " "must have one array per recording.")
354349

355350
if not all([len(scale) == num_units for scale in rec_scalings]):

src/spikeinterface/generation/tests/test_session_displacement_generator.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
88
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
99

10-
# TODO: test templates_array_moved are the same with
11-
# no shift, with both seed and no seed
12-
13-
# rescale units per session
14-
1510

1611
class TestSessionDisplacementGenerator:
1712
"""
@@ -97,14 +92,14 @@ def test_x_y_rigid_shifts_are_properly_set(self, options):
9792
for unit_idx in range(num_units):
9893

9994
start_pos = self._get_peak_chan_loc_in_um(
100-
extra_outputs["template_array_moved"][0][unit_idx],
95+
extra_outputs["templates_array_moved"][0][unit_idx],
10196
options["y_bin_um"],
10297
)
10398

10499
for rec_idx in range(1, options["num_recs"]):
105100

106101
new_pos = self._get_peak_chan_loc_in_um(
107-
extra_outputs["template_array_moved"][rec_idx][unit_idx], options["y_bin_um"]
102+
extra_outputs["templates_array_moved"][rec_idx][unit_idx], options["y_bin_um"]
108103
)
109104

110105
y_shift = recording_shifts[rec_idx][1]
@@ -120,7 +115,7 @@ def test_x_y_rigid_shifts_are_properly_set(self, options):
120115
for rec_idx in range(options["num_recs"]):
121116
assert np.array_equal(
122117
output_recordings[rec_idx].templates,
123-
extra_outputs["template_array_moved"][rec_idx],
118+
extra_outputs["templates_array_moved"][rec_idx],
124119
)
125120

126121
def _get_peak_chan_loc_in_um(self, template_array, y_bin_um):
@@ -275,6 +270,56 @@ def test_displacement_with_peak_detection(self, options):
275270

276271
assert np.isclose(new_pos, first_pos + y_shift, rtol=0, atol=options["y_bin_um"])
277272

273+
def test_amplitude_scalings(self, options):
274+
275+
options["kwargs"]["recording_durations"] = (10, 10)
276+
options["kwargs"]["recording_shifts"] = ((0, 0), (0, 0))
277+
options["kwargs"]["num_units"] == 5,
278+
279+
recording_amplitude_scalings = {
280+
"method": "by_passed_order",
281+
"scalings": (np.ones(5), np.array([0.1, 0.2, 0.3, 0.4, 0.5])),
282+
}
283+
284+
_, output_sortings, extra_outputs = generate_session_displacement_recordings(
285+
**options["kwargs"],
286+
recording_amplitude_scalings=recording_amplitude_scalings,
287+
)
288+
breakpoint()
289+
first, second = extra_outputs["templates_array_moved"] # TODO: own function
290+
first_min = np.min(np.min(first, axis=2), axis=1)
291+
second_min = np.min(np.min(second, axis=2), axis=1)
292+
scales = second_min / first_min
293+
294+
assert np.allclose(scales, shifts)
295+
296+
# TODO: scale based on recording output
297+
# check scaled by amplitude.
298+
299+
breakpoint()
300+
301+
def test_metadata(self, options):
302+
"""
303+
Check that metadata required to be set of generated recordings is present
304+
on all output recordings.
305+
"""
306+
output_recordings, output_sortings, extra_outputs = generate_session_displacement_recordings(
307+
**options["kwargs"], generate_noise_kwargs=dict(noise_levels=(1.0, 2.0), spatial_decay=1.0)
308+
)
309+
num_chans = output_recordings[0].get_num_channels()
310+
311+
for i in range(len(output_recordings)):
312+
assert output_recordings[i].name == "InterSessionDisplacementRecording"
313+
assert output_recordings[i]._annotations["is_filtered"] is True
314+
assert output_recordings[i].has_probe()
315+
assert np.array_equal(output_recordings[i].get_channel_gains(), np.ones(num_chans))
316+
assert np.array_equal(output_recordings[i].get_channel_offsets(), np.zeros(num_chans))
317+
318+
assert np.array_equal(
319+
output_sortings[i].get_property("gt_unit_locations"), extra_outputs["unit_locations"][i]
320+
)
321+
assert output_sortings[i].name == "InterSessionDisplacementSorting"
322+
278323
def test_same_as_generate_ground_truth_recording(self):
279324
"""
280325
It is expected that inter-session displacement randomly
@@ -302,7 +347,7 @@ def test_same_as_generate_ground_truth_recording(self):
302347
no_shift_recording, _ = generate_session_displacement_recordings(
303348
num_units=num_units,
304349
recording_durations=[duration],
305-
recording_shifts=((0, 0)),
350+
recording_shifts=((0, 0),),
306351
sampling_frequency=sampling_frequency,
307352
probe_name=probe_name,
308353
generate_probe_kwargs=generate_probe_kwargs,

0 commit comments

Comments
 (0)