7
7
from spikeinterface .sortingcomponents .peak_detection import detect_peaks
8
8
from spikeinterface .sortingcomponents .peak_localization import localize_peaks
9
9
10
- # TODO: test templates_array_moved are the same with
11
- # no shift, with both seed and no seed
12
-
13
- # rescale units per session
14
-
15
10
16
11
class TestSessionDisplacementGenerator :
17
12
"""
@@ -97,14 +92,14 @@ def test_x_y_rigid_shifts_are_properly_set(self, options):
97
92
for unit_idx in range (num_units ):
98
93
99
94
start_pos = self ._get_peak_chan_loc_in_um (
100
- extra_outputs ["template_array_moved " ][0 ][unit_idx ],
95
+ extra_outputs ["templates_array_moved " ][0 ][unit_idx ],
101
96
options ["y_bin_um" ],
102
97
)
103
98
104
99
for rec_idx in range (1 , options ["num_recs" ]):
105
100
106
101
new_pos = self ._get_peak_chan_loc_in_um (
107
- extra_outputs ["template_array_moved " ][rec_idx ][unit_idx ], options ["y_bin_um" ]
102
+ extra_outputs ["templates_array_moved " ][rec_idx ][unit_idx ], options ["y_bin_um" ]
108
103
)
109
104
110
105
y_shift = recording_shifts [rec_idx ][1 ]
@@ -120,7 +115,7 @@ def test_x_y_rigid_shifts_are_properly_set(self, options):
120
115
for rec_idx in range (options ["num_recs" ]):
121
116
assert np .array_equal (
122
117
output_recordings [rec_idx ].templates ,
123
- extra_outputs ["template_array_moved " ][rec_idx ],
118
+ extra_outputs ["templates_array_moved " ][rec_idx ],
124
119
)
125
120
126
121
def _get_peak_chan_loc_in_um (self , template_array , y_bin_um ):
@@ -275,6 +270,56 @@ def test_displacement_with_peak_detection(self, options):
275
270
276
271
assert np .isclose (new_pos , first_pos + y_shift , rtol = 0 , atol = options ["y_bin_um" ])
277
272
273
+ def test_amplitude_scalings (self , options ):
274
+
275
+ options ["kwargs" ]["recording_durations" ] = (10 , 10 )
276
+ options ["kwargs" ]["recording_shifts" ] = ((0 , 0 ), (0 , 0 ))
277
+ options ["kwargs" ]["num_units" ] == 5 ,
278
+
279
+ recording_amplitude_scalings = {
280
+ "method" : "by_passed_order" ,
281
+ "scalings" : (np .ones (5 ), np .array ([0.1 , 0.2 , 0.3 , 0.4 , 0.5 ])),
282
+ }
283
+
284
+ _ , output_sortings , extra_outputs = generate_session_displacement_recordings (
285
+ ** options ["kwargs" ],
286
+ recording_amplitude_scalings = recording_amplitude_scalings ,
287
+ )
288
+ breakpoint ()
289
+ first , second = extra_outputs ["templates_array_moved" ] # TODO: own function
290
+ first_min = np .min (np .min (first , axis = 2 ), axis = 1 )
291
+ second_min = np .min (np .min (second , axis = 2 ), axis = 1 )
292
+ scales = second_min / first_min
293
+
294
+ assert np .allclose (scales , shifts )
295
+
296
+ # TODO: scale based on recording output
297
+ # check scaled by amplitude.
298
+
299
+ breakpoint ()
300
+
301
+ def test_metadata (self , options ):
302
+ """
303
+ Check that metadata required to be set of generated recordings is present
304
+ on all output recordings.
305
+ """
306
+ output_recordings , output_sortings , extra_outputs = generate_session_displacement_recordings (
307
+ ** options ["kwargs" ], generate_noise_kwargs = dict (noise_levels = (1.0 , 2.0 ), spatial_decay = 1.0 )
308
+ )
309
+ num_chans = output_recordings [0 ].get_num_channels ()
310
+
311
+ for i in range (len (output_recordings )):
312
+ assert output_recordings [i ].name == "InterSessionDisplacementRecording"
313
+ assert output_recordings [i ]._annotations ["is_filtered" ] is True
314
+ assert output_recordings [i ].has_probe ()
315
+ assert np .array_equal (output_recordings [i ].get_channel_gains (), np .ones (num_chans ))
316
+ assert np .array_equal (output_recordings [i ].get_channel_offsets (), np .zeros (num_chans ))
317
+
318
+ assert np .array_equal (
319
+ output_sortings [i ].get_property ("gt_unit_locations" ), extra_outputs ["unit_locations" ][i ]
320
+ )
321
+ assert output_sortings [i ].name == "InterSessionDisplacementSorting"
322
+
278
323
def test_same_as_generate_ground_truth_recording (self ):
279
324
"""
280
325
It is expected that inter-session displacement randomly
@@ -302,7 +347,7 @@ def test_same_as_generate_ground_truth_recording(self):
302
347
no_shift_recording , _ = generate_session_displacement_recordings (
303
348
num_units = num_units ,
304
349
recording_durations = [duration ],
305
- recording_shifts = ((0 , 0 )),
350
+ recording_shifts = ((0 , 0 ), ),
306
351
sampling_frequency = sampling_frequency ,
307
352
probe_name = probe_name ,
308
353
generate_probe_kwargs = generate_probe_kwargs ,
0 commit comments