@@ -15,7 +15,10 @@ class TestTimeHandling:
15
15
is generated on the fly. Both time representations are tested here.
16
16
"""
17
17
18
- # Fixtures #####
18
+ # #########################################################################
19
+ # Fixtures
20
+ # #########################################################################
21
+
19
22
@pytest .fixture (scope = "session" )
20
23
def time_vector_recording (self ):
21
24
"""
@@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name):
95
98
raw_recording , times_recording , all_times = time_recording_fixture
96
99
return (raw_recording , times_recording , all_times )
97
100
98
- # Tests #####
101
+ # #########################################################################
102
+ # Tests
103
+ # #########################################################################
104
+
99
105
def test_has_time_vector (self , time_vector_recording ):
100
106
"""
101
107
Test the `has_time_vector` function returns `False` before
@@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording
305
311
306
312
assert np .array_equal (sorting_analyzer .get_total_duration (), raw_recording .get_total_duration ())
307
313
308
- # Helpers ####
314
+ @pytest .mark .parametrize ("fixture_name" , ["time_vector_recording" , "t_start_recording" ])
315
+ @pytest .mark .parametrize ("shift" , [- 123.456 , 123.456 ])
316
+ def test_shift_time_all_segments (self , request , fixture_name , shift ):
317
+ """
318
+ Shift the times in every segment using the `None` default, then
319
+ check that every segment of the recording is shifted as expected.
320
+ """
321
+ _ , times_recording , all_times = self ._get_fixture_data (request , fixture_name )
322
+
323
+ num_segments , orig_seg_data = self ._store_all_times (times_recording )
324
+
325
+ times_recording .shift_times (shift ) # use default `segment_index=None`
326
+
327
+ for idx in range (num_segments ):
328
+ assert np .allclose (
329
+ orig_seg_data [idx ], times_recording .get_times (segment_index = idx ) - shift , rtol = 0 , atol = 1e-8
330
+ )
331
+
332
+ @pytest .mark .parametrize ("fixture_name" , ["time_vector_recording" , "t_start_recording" ])
333
+ @pytest .mark .parametrize ("shift" , [- 123.456 , 123.456 ])
334
+ def test_shift_times_different_segments (self , request , fixture_name , shift ):
335
+ """
336
+ Shift each segment separately, and check the shifted segment only
337
+ is shifted as expected.
338
+ """
339
+ _ , times_recording , all_times = self ._get_fixture_data (request , fixture_name )
340
+
341
+ num_segments , orig_seg_data = self ._store_all_times (times_recording )
342
+
343
+ # For each segment, shift the segment only and check the
344
+ # times are updated as expected.
345
+ for idx in range (num_segments ):
346
+
347
+ scaler = idx + 2
348
+ times_recording .shift_times (shift * scaler , segment_index = idx )
349
+
350
+ assert np .allclose (
351
+ orig_seg_data [idx ], times_recording .get_times (segment_index = idx ) - shift * scaler , rtol = 0 , atol = 1e-8
352
+ )
353
+
354
+ # Just do a little check that we are not
355
+ # accidentally changing some other segments,
356
+ # which should remain unchanged at this point in the loop.
357
+ if idx != num_segments - 1 :
358
+ assert np .array_equal (orig_seg_data [idx + 1 ], times_recording .get_times (segment_index = idx + 1 ))
359
+
360
+ @pytest .mark .parametrize ("fixture_name" , ["time_vector_recording" , "t_start_recording" ])
361
+ def test_save_and_load_time_shift (self , request , fixture_name , tmp_path ):
362
+ """
363
+ Save the shifted data and check the shift is propagated correctly.
364
+ """
365
+ _ , times_recording , all_times = self ._get_fixture_data (request , fixture_name )
366
+
367
+ shift = 100
368
+ times_recording .shift_times (shift = shift )
369
+
370
+ times_recording .save (folder = tmp_path / "my_file" )
371
+
372
+ loaded_recording = si .load_extractor (tmp_path / "my_file" )
373
+
374
+ for idx in range (times_recording .get_num_segments ()):
375
+ assert np .array_equal (
376
+ times_recording .get_times (segment_index = idx ), loaded_recording .get_times (segment_index = idx )
377
+ )
378
+
379
+ def _store_all_times (self , recording ):
380
+ """
381
+ Convenience function to store original times of all segments to a dict.
382
+ """
383
+ num_segments = recording .get_num_segments ()
384
+ seg_data = {}
385
+
386
+ for idx in range (num_segments ):
387
+ seg_data [idx ] = copy .deepcopy (recording .get_times (segment_index = idx ))
388
+
389
+ return num_segments , seg_data
390
+
391
+ # #########################################################################
392
+ # Helpers
393
+ # #########################################################################
394
+
309
395
def _check_times_match (self , recording , all_times ):
310
396
"""
311
397
For every segment in a recording, check the `get_times()`
0 commit comments