@@ -12,7 +12,7 @@ class TestSessionDisplacementGenerator:
12
12
"""
13
13
This class tests the `generate_session_displacement_recordings` that
14
14
returns a recordings / sorting in which the units are shifted
15
- across sessions. This is acheived by shifting the unit locations
15
+ across sessions. This is achieved by shifting the unit locations
16
16
in both (x, y) on the generated templates that are used in
17
17
`InjectTemplatesRecording()`.
18
18
"""
@@ -136,7 +136,7 @@ def test_recordings_length(self, options):
136
136
for rec , expected_rec_length in zip (output_recordings , options ["kwargs" ]["recording_durations" ]):
137
137
assert rec .get_total_duration () == expected_rec_length
138
138
139
- def test_spike_times_across_recordings (self , options ):
139
+ def test_spike_times_and_firing_rates_across_recordings (self , options ):
140
140
"""
141
141
Check the randomisation of spike times across recordings.
142
142
When a seed is set, this is passed to `generate_sorting`
@@ -146,14 +146,17 @@ def test_spike_times_across_recordings(self, options):
146
146
"""
147
147
options ["kwargs" ]["recording_durations" ] = (10 ,) * options ["num_recs" ]
148
148
149
- output_sortings_same = generate_session_displacement_recordings (** options ["kwargs" ])[1 ]
149
+ output_sortings_same , extra_outputs_same = generate_session_displacement_recordings (** options ["kwargs" ])[1 : 3 ]
150
150
151
151
options ["kwargs" ]["seed" ] = None
152
- output_sortings_different = generate_session_displacement_recordings (** options ["kwargs" ])[1 ]
152
+ output_sortings_different , extra_outputs_different = generate_session_displacement_recordings (
153
+ ** options ["kwargs" ]
154
+ )[1 :3 ]
153
155
154
156
for unit_idx in range (options ["kwargs" ]["num_units" ]):
155
157
for rec_idx in range (1 , options ["num_recs" ]):
156
158
159
+ # Exact spike times are not preserved when seed is None
157
160
assert np .array_equal (
158
161
output_sortings_same [0 ].get_unit_spike_train (unit_idx ),
159
162
output_sortings_same [rec_idx ].get_unit_spike_train (unit_idx ),
@@ -162,6 +165,15 @@ def test_spike_times_across_recordings(self, options):
162
165
output_sortings_different [0 ].get_unit_spike_train (unit_idx ),
163
166
output_sortings_different [rec_idx ].get_unit_spike_train (unit_idx ),
164
167
)
168
+ # Firing rates should always be preserved.
169
+ assert np .array_equal (
170
+ extra_outputs_same ["firing_rates" ][0 ][unit_idx ],
171
+ extra_outputs_same ["firing_rates" ][rec_idx ][unit_idx ],
172
+ )
173
+ assert np .array_equal (
174
+ extra_outputs_different ["firing_rates" ][0 ][unit_idx ],
175
+ extra_outputs_different ["firing_rates" ][rec_idx ][unit_idx ],
176
+ )
165
177
166
178
@pytest .mark .parametrize ("dim_idx" , [0 , 1 ])
167
179
def test_x_y_shift_non_rigid (self , options , dim_idx ):
@@ -271,32 +283,70 @@ def test_displacement_with_peak_detection(self, options):
271
283
assert np .isclose (new_pos , first_pos + y_shift , rtol = 0 , atol = options ["y_bin_um" ])
272
284
273
285
def test_amplitude_scalings (self , options ):
274
-
286
+ """
287
+ Test that the templates are scaled by the passed scaling factors
288
+ in the specified order. The order can be in the passed order,
289
+ in the order of highest-to-lowest firing unit, or in the order
290
+ of (amplitude * firing_rate) (highest to lowest unit).
291
+ """
292
+ # Setup arguments to create an unshifted set of recordings
293
+ # where the templates are to be scaled with `true_scalings`
275
294
options ["kwargs" ]["recording_durations" ] = (10 , 10 )
276
295
options ["kwargs" ]["recording_shifts" ] = ((0 , 0 ), (0 , 0 ))
277
296
options ["kwargs" ]["num_units" ] == 5 ,
278
297
298
+ true_scalings = np .array ([0.1 , 0.2 , 0.3 , 0.4 , 0.5 ])
299
+
279
300
recording_amplitude_scalings = {
280
301
"method" : "by_passed_order" ,
281
- "scalings" : (np .ones (5 ), np . array ([ 0.1 , 0.2 , 0.3 , 0.4 , 0.5 ]) ),
302
+ "scalings" : (np .ones (5 ), true_scalings ),
282
303
}
283
304
284
305
_ , output_sortings , extra_outputs = generate_session_displacement_recordings (
285
306
** options ["kwargs" ],
286
307
recording_amplitude_scalings = recording_amplitude_scalings ,
287
308
)
288
- breakpoint ()
289
- first , second = extra_outputs ["templates_array_moved" ] # TODO: own function
290
- first_min = np .min (np .min (first , axis = 2 ), axis = 1 )
291
- second_min = np .min (np .min (second , axis = 2 ), axis = 1 )
292
- scales = second_min / first_min
293
309
294
- assert np .allclose (scales , shifts )
310
+ # Check that the unit templates are scaled in the order
311
+ # the scalings were passed.
312
+ test_scalings = self ._calculate_scalings_from_output (extra_outputs )
313
+ assert np .allclose (test_scalings , true_scalings )
314
+
315
+ # Now run, again applying the scalings in the order of
316
+ # unit firing rates (highest to lowest).
317
+ firing_rates = np .array ([5 , 4 , 3 , 2 , 1 ])
318
+ generate_sorting_kwargs = dict (firing_rates = firing_rates , refractory_period_ms = 4.0 )
319
+ recording_amplitude_scalings ["method" ] = "by_firing_rate"
320
+ _ , output_sortings , extra_outputs = generate_session_displacement_recordings (
321
+ ** options ["kwargs" ],
322
+ recording_amplitude_scalings = recording_amplitude_scalings ,
323
+ generate_sorting_kwargs = generate_sorting_kwargs ,
324
+ )
325
+
326
+ test_scalings = self ._calculate_scalings_from_output (extra_outputs )
327
+ assert np .allclose (test_scalings , true_scalings [np .argsort (firing_rates )])
295
328
296
- # TODO: scale based on recording output
297
- # check scaled by amplitude.
329
+ # Finally, run again applying the scalings in the order of
330
+ # unit amplitude * firing_rate
331
+ recording_amplitude_scalings ["method" ] = "by_amplitude_and_firing_rate" # TODO: method -> order
332
+ amplitudes = np .min (np .min (extra_outputs ["templates_array_moved" ][0 ], axis = 2 ), axis = 1 )
333
+ firing_rate_by_amplitude = np .argsort (amplitudes * firing_rates )
298
334
299
- breakpoint ()
335
+ _ , output_sortings , extra_outputs = generate_session_displacement_recordings (
336
+ ** options ["kwargs" ],
337
+ recording_amplitude_scalings = recording_amplitude_scalings ,
338
+ generate_sorting_kwargs = generate_sorting_kwargs ,
339
+ )
340
+
341
+ test_scalings = self ._calculate_scalings_from_output (extra_outputs )
342
+ assert np .allclose (test_scalings , true_scalings [firing_rate_by_amplitude ])
343
+
344
+ def _calculate_scalings_from_output (self , extra_outputs ):
345
+ first , second = extra_outputs ["templates_array_moved" ]
346
+ first_min = np .min (np .min (first , axis = 2 ), axis = 1 )
347
+ second_min = np .min (np .min (second , axis = 2 ), axis = 1 )
348
+ test_scalings = second_min / first_min
349
+ return test_scalings
300
350
301
351
def test_metadata (self , options ):
302
352
"""
@@ -339,7 +389,7 @@ def test_same_as_generate_ground_truth_recording(self):
339
389
generate_probe_kwargs = None
340
390
generate_unit_locations_kwargs = dict ()
341
391
generate_templates_kwargs = dict (ms_before = 1.5 , ms_after = 3 )
342
- generate_sorting_kwargs = dict ()
392
+ generate_sorting_kwargs = dict (firing_rates = 1 )
343
393
generate_noise_kwargs = dict ()
344
394
seed = 42
345
395
0 commit comments