From 08124ad83ed942d4ac286540f635caf1aee9735f Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Fri, 15 Sep 2023 14:02:25 +0200 Subject: [PATCH] allow relabeling when slicing --- eitprocessing/binreader/sequence.py | 19 ++++++++++++------- tests/test_sequence.py | 17 ++++++++++++++++- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/eitprocessing/binreader/sequence.py b/eitprocessing/binreader/sequence.py index 578cfa8cf..f84efa825 100644 --- a/eitprocessing/binreader/sequence.py +++ b/eitprocessing/binreader/sequence.py @@ -294,7 +294,7 @@ def _load_data(self, first_frame: int | None): """Needs to be implemented in child class.""" raise NotImplementedError(f"Data loading for {self.vendor} is not implemented") - def __getitem__(self, indices): + def select_by_index(self, indices: slice, label: str | None = None): if not isinstance(indices, slice): raise NotImplementedError( "Slicing only implemented using a slice object" @@ -308,27 +308,32 @@ def __getitem__(self, indices): if indices.stop is None: indices = slice(indices.start, self.nframes, indices.step) - obj = self.deepcopy() + obj = self.deepcopy() #TODO: consider to make this more efficient for large data obj.time = self.time[indices] obj.nframes = len(obj.time) - obj.framesets = {k: v[indices] for k, v in self.framesets.items()} + obj.label = f'Slice ({indices.start}-{indices.stop}) of <{self.label}>' if label is None else label - r = range(indices.start, indices.stop) + range_ = range(indices.start, indices.stop) for attr in ["events", "timing_errors", "phases"]: - setattr(obj, attr, [x for x in getattr(obj, attr) if x.index in r]) + setattr(obj, attr, [x for x in getattr(obj, attr) if x.index in range_]) for x in getattr(obj, attr): x.index -= indices.start return obj - def select_by_time( + def __getitem__(self, indices: slice): + return self.select_by_index(indices) + + + def select_by_time( #pylint: disable=too-many-arguments self, start: float | int | None = None, end: float | int | None = None, start_inclusive: bool = True, end_inclusive: bool = False, + label: str = None, ) -> Sequence: """Select subset of sequence by the `Sequence.time` information (i.e. based on the time stamp). @@ -375,7 +380,7 @@ def select_by_time( else: end_index = bisect.bisect_left(self.time, end) - 1 - return self[start_index:end_index] + return self.select_by_index(slice(start_index,end_index), label = label) def deepcopy( diff --git a/tests/test_sequence.py b/tests/test_sequence.py index fd3e28dfa..42a588c92 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -386,4 +386,19 @@ def test_relabeling( merged_timpel_2 = Sequence.merge(timpel_data, timpel_data, label = test_label) assert merged_timpel_2.label == test_label, 'incorrect label assigned when merging data with new label' - + # slicing + indices = slice(0,10) + sliced_timpel = timpel_data[indices] + assert sliced_timpel.label != timpel_data.label, 'slicing does not assign new label by default' + assert sliced_timpel.label == f'Slice ({indices.start}-{indices.stop}) of <{timpel_data.label}>', 'slicing generates unexpected default label' + sliced_timpel_2 = timpel_data.select_by_index(indices=indices, label=test_label) + assert sliced_timpel_2.label == test_label, 'incorrect label assigned when slicing data with new label' + + # select_by_time + t22 = 55825.268 + t52 = 55826.768 + time_sliced = draeger_data2.select_by_time(t22, t52+0.001) + assert time_sliced.label != draeger_data2.label, 'time slicing does not assign new label by default' + assert time_sliced.label == f'Slice (22-52) of <{draeger_data2.label}>', 'slicing generates unexpected default label' + time_sliced_2 = draeger_data2.select_by_time(t22, t52, label=test_label) + assert time_sliced_2.label == test_label, 'incorrect label assigned when time slicing data with new label'