Skip to content

Commit 3fd3d97

Browse files
authored
Merge pull request #3509 from JoeZiminski/add_shift_time_function
Add `shift start time` function.
2 parents 681fb01 + 469b3b0 commit 3fd3d97

File tree

2 files changed

+118
-3
lines changed

2 files changed

+118
-3
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,35 @@ def reset_times(self):
509509
rs.t_start = None
510510
rs.sampling_frequency = self.sampling_frequency
511511

512+
def shift_times(self, shift: int | float, segment_index: int | None = None) -> None:
513+
"""
514+
Shift all times by a scalar value.
515+
516+
Parameters
517+
----------
518+
shift : int | float
519+
The shift to apply. If positive, times will be increased by `shift`.
520+
e.g. shifting by 1 will be like the recording started 1 second later.
521+
If negative, the start time will be decreased i.e. as if the recording
522+
started earlier.
523+
524+
segment_index : int | None
525+
The segment on which to shift the times.
526+
If `None`, all segments will be shifted.
527+
"""
528+
if segment_index is None:
529+
segments_to_shift = range(self.get_num_segments())
530+
else:
531+
segments_to_shift = (segment_index,)
532+
533+
for idx in segments_to_shift:
534+
rs = self._recording_segments[idx]
535+
536+
if self.has_time_vector(segment_index=idx):
537+
rs.time_vector += shift
538+
else:
539+
rs.t_start += shift
540+
512541
def sample_index_to_time(self, sample_ind, segment_index=None):
513542
"""
514543
Transform sample index into time in seconds

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ class TestTimeHandling:
1515
is generated on the fly. Both time representations are tested here.
1616
"""
1717

18-
# Fixtures #####
18+
# #########################################################################
19+
# Fixtures
20+
# #########################################################################
21+
1922
@pytest.fixture(scope="session")
2023
def time_vector_recording(self):
2124
"""
@@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name):
9598
raw_recording, times_recording, all_times = time_recording_fixture
9699
return (raw_recording, times_recording, all_times)
97100

98-
# Tests #####
101+
# #########################################################################
102+
# Tests
103+
# #########################################################################
104+
99105
def test_has_time_vector(self, time_vector_recording):
100106
"""
101107
Test the `has_time_vector` function returns `False` before
@@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording
305311

306312
assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration())
307313

308-
# Helpers ####
314+
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
315+
@pytest.mark.parametrize("shift", [-123.456, 123.456])
316+
def test_shift_time_all_segments(self, request, fixture_name, shift):
317+
"""
318+
Shift the times in every segment using the `None` default, then
319+
check that every segment of the recording is shifted as expected.
320+
"""
321+
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)
322+
323+
num_segments, orig_seg_data = self._store_all_times(times_recording)
324+
325+
times_recording.shift_times(shift) # use default `segment_index=None`
326+
327+
for idx in range(num_segments):
328+
assert np.allclose(
329+
orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift, rtol=0, atol=1e-8
330+
)
331+
332+
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
333+
@pytest.mark.parametrize("shift", [-123.456, 123.456])
334+
def test_shift_times_different_segments(self, request, fixture_name, shift):
335+
"""
336+
Shift each segment separately, and check the shifted segment only
337+
is shifted as expected.
338+
"""
339+
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)
340+
341+
num_segments, orig_seg_data = self._store_all_times(times_recording)
342+
343+
# For each segment, shift the segment only and check the
344+
# times are updated as expected.
345+
for idx in range(num_segments):
346+
347+
scaler = idx + 2
348+
times_recording.shift_times(shift * scaler, segment_index=idx)
349+
350+
assert np.allclose(
351+
orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8
352+
)
353+
354+
# Just do a little check that we are not
355+
# accidentally changing some other segments,
356+
# which should remain unchanged at this point in the loop.
357+
if idx != num_segments - 1:
358+
assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1))
359+
360+
@pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"])
361+
def test_save_and_load_time_shift(self, request, fixture_name, tmp_path):
362+
"""
363+
Save the shifted data and check the shift is propagated correctly.
364+
"""
365+
_, times_recording, all_times = self._get_fixture_data(request, fixture_name)
366+
367+
shift = 100
368+
times_recording.shift_times(shift=shift)
369+
370+
times_recording.save(folder=tmp_path / "my_file")
371+
372+
loaded_recording = si.load_extractor(tmp_path / "my_file")
373+
374+
for idx in range(times_recording.get_num_segments()):
375+
assert np.array_equal(
376+
times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx)
377+
)
378+
379+
def _store_all_times(self, recording):
380+
"""
381+
Convenience function to store original times of all segments to a dict.
382+
"""
383+
num_segments = recording.get_num_segments()
384+
seg_data = {}
385+
386+
for idx in range(num_segments):
387+
seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx))
388+
389+
return num_segments, seg_data
390+
391+
# #########################################################################
392+
# Helpers
393+
# #########################################################################
394+
309395
def _check_times_match(self, recording, all_times):
310396
"""
311397
For every segment in a recording, check the `get_times()`

0 commit comments

Comments
 (0)