Skip to content

Commit

Permalink
allow relabeling when slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Sep 15, 2023
1 parent 85dc2f6 commit 08124ad
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
19 changes: 12 additions & 7 deletions eitprocessing/binreader/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def _load_data(self, first_frame: int | None):
"""Needs to be implemented in child class."""
raise NotImplementedError(f"Data loading for {self.vendor} is not implemented")

def __getitem__(self, indices):
def select_by_index(self, indices: slice, label: str | None = None):
if not isinstance(indices, slice):
raise NotImplementedError(
"Slicing only implemented using a slice object"
Expand All @@ -308,27 +308,32 @@ def __getitem__(self, indices):
if indices.stop is None:
indices = slice(indices.start, self.nframes, indices.step)

obj = self.deepcopy()
obj = self.deepcopy() #TODO: consider to make this more efficient for large data
obj.time = self.time[indices]
obj.nframes = len(obj.time)

obj.framesets = {k: v[indices] for k, v in self.framesets.items()}
obj.label = f'Slice ({indices.start}-{indices.stop}) of <{self.label}>' if label is None else label

r = range(indices.start, indices.stop)
range_ = range(indices.start, indices.stop)
for attr in ["events", "timing_errors", "phases"]:
setattr(obj, attr, [x for x in getattr(obj, attr) if x.index in r])
setattr(obj, attr, [x for x in getattr(obj, attr) if x.index in range_])
for x in getattr(obj, attr):
x.index -= indices.start

return obj


def select_by_time(
def __getitem__(self, indices: slice):
return self.select_by_index(indices)


def select_by_time( #pylint: disable=too-many-arguments
self,
start: float | int | None = None,
end: float | int | None = None,
start_inclusive: bool = True,
end_inclusive: bool = False,
label: str = None,
) -> Sequence:
"""Select subset of sequence by the `Sequence.time` information (i.e.
based on the time stamp).
Expand Down Expand Up @@ -375,7 +380,7 @@ def select_by_time(
else:
end_index = bisect.bisect_left(self.time, end) - 1

return self[start_index:end_index]
return self.select_by_index(slice(start_index,end_index), label = label)


def deepcopy(
Expand Down
17 changes: 16 additions & 1 deletion tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,4 +386,19 @@ def test_relabeling(
merged_timpel_2 = Sequence.merge(timpel_data, timpel_data, label = test_label)
assert merged_timpel_2.label == test_label, 'incorrect label assigned when merging data with new label'


# slicing
indices = slice(0,10)
sliced_timpel = timpel_data[indices]
assert sliced_timpel.label != timpel_data.label, 'slicing does not assign new label by default'
assert sliced_timpel.label == f'Slice ({indices.start}-{indices.stop}) of <{timpel_data.label}>', 'slicing generates unexpected default label'
sliced_timpel_2 = timpel_data.select_by_index(indices=indices, label=test_label)
assert sliced_timpel_2.label == test_label, 'incorrect label assigned when slicing data with new label'

# select_by_time
t22 = 55825.268
t52 = 55826.768
time_sliced = draeger_data2.select_by_time(t22, t52+0.001)
assert time_sliced.label != draeger_data2.label, 'time slicing does not assign new label by default'
assert time_sliced.label == f'Slice (22-52) of <{draeger_data2.label}>', 'slicing generates unexpected default label'
time_sliced_2 = draeger_data2.select_by_time(t22, t52, label=test_label)
assert time_sliced_2.label == test_label, 'incorrect label assigned when time slicing data with new label'

0 comments on commit 08124ad

Please sign in to comment.