Skip to content

Commit

Permalink
Merge 08124ad into 4464e32
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor authored Sep 15, 2023
2 parents 4464e32 + 08124ad commit f0a89b1
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 19 deletions.
88 changes: 69 additions & 19 deletions eitprocessing/binreader/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
This file contains methods related to parts of electrical impedance tomographs
as they are read.
"""

from __future__ import annotations
import bisect
import copy
import functools
Expand Down Expand Up @@ -53,6 +53,8 @@ class Sequence:
Args:
path (Path | str | List[Path | str]): path(s) to data file.
vendor (Vendor | str): vendor indicating the device used.
label (str): description of object for human interpretation.
Defaults to "Sequence_<unique_id>".
time (NDArray[float]): list of time label for each data point (can be
true time or relative time)
max_frames (int): number of frames in sequence
Expand All @@ -70,6 +72,7 @@ class Sequence:

path: Path | str | List[Path | str] = None
vendor: Vendor = None
label: str = None
time: NDArray = None
nframes: int = None
framerate: int = None
Expand All @@ -79,6 +82,9 @@ class Sequence:
phases: List[PhaseIndicator] = field(default_factory=list, repr=False)

def __post_init__(self):
if self.label is None:
self.label = f'Sequence_{id(self)}'

self._set_vendor_class()

def __len__(self) -> int:
Expand Down Expand Up @@ -121,7 +127,7 @@ def _set_vendor_class(self):
raise NotImplementedError(f"vendor {self.vendor} is not implemented")

@staticmethod
def check_equivalence(a: "Sequence", b: "Sequence"):
def check_equivalence(a: Sequence, b: Sequence):
"""Checks whether content of two Sequence objects is equivalent."""

if (a_ := a.vendor) != (b_ := b.vendor):
Expand All @@ -134,12 +140,18 @@ def check_equivalence(a: "Sequence", b: "Sequence"):
)
return True

def __add__(self, other):
def __add__(self, other: Sequence) -> Sequence:
return self.merge(self, other)

@classmethod
def merge(cls, a: "Sequence", b: "Sequence") -> "Sequence":
"""Merge two Sequence objects together."""
def merge(
cls,
a: Sequence,
b: Sequence,
label: str = None,
) -> Sequence:
"""Create a merge of two Sequence objects."""

try:
Sequence.check_equivalence(a, b)
except Exception as e:
Expand All @@ -163,9 +175,12 @@ def merge_attribute(attr: str) -> list:
item.time = time[item.index]
return a_items + b_items

label = f'Merge of ({a.label}) and <{b.label}>' if label is None else label

return cls(
path=path,
vendor=a.vendor,
label=label,
time=time,
nframes=nframes,
framerate=a.framerate,
Expand All @@ -179,16 +194,19 @@ def merge_attribute(attr: str) -> list:
def from_path( # pylint: disable=too-many-arguments, unused-argument
cls,
path: Path | str | List[Path | str],
vendor: Vendor | str,
vendor: Vendor | str = None,
label: str = None,
framerate: int = None,
first_frame: int = 0,
max_frames: int | None = None,
) -> "Sequence":
) -> Sequence:
"""Load sequence from path(s)
Args:
path (Path | str | List[Path | str]): path(s) to data file
vendor (Vendor | str): vendor indicating the device used
path (Path | str | List[Path | str]): path(s) to data file.
vendor (Vendor | str): vendor indicating the device used.
label (str): description of object for human interpretation.
Defaults to "Sequence_<unique_id>".
framerate (int, optional): framerate at which the data was recorded.
Default for Draeger: 20
Default for Timpel: 50
Expand Down Expand Up @@ -223,10 +241,11 @@ def _load_file( # pylint: disable=too-many-arguments
cls,
path: Path | str,
vendor: Vendor | str,
label: str = None,
framerate: int = None,
first_frame: int = 0,
max_frames: int | None = None,
) -> "Sequence":
) -> Sequence:
"""Method used by `from_path` that initiates the object and calls
child method for loading the data.
Expand All @@ -249,6 +268,7 @@ def _load_file( # pylint: disable=too-many-arguments
path=Path(path),
vendor=vendor,
nframes=max_frames,
label=label,
)
obj._set_vendor_class()
if framerate:
Expand All @@ -274,7 +294,7 @@ def _load_data(self, first_frame: int | None):
"""Needs to be implemented in child class."""
raise NotImplementedError(f"Data loading for {self.vendor} is not implemented")

def __getitem__(self, indices):
def select_by_index(self, indices: slice, label: str | None = None):
if not isinstance(indices, slice):
raise NotImplementedError(
"Slicing only implemented using a slice object"
Expand All @@ -288,27 +308,33 @@ def __getitem__(self, indices):
if indices.stop is None:
indices = slice(indices.start, self.nframes, indices.step)

obj = self.deepcopy()
obj = self.deepcopy() #TODO: consider to make this more efficient for large data
obj.time = self.time[indices]
obj.nframes = len(obj.time)

obj.framesets = {k: v[indices] for k, v in self.framesets.items()}
obj.label = f'Slice ({indices.start}-{indices.stop}) of <{self.label}>' if label is None else label

r = range(indices.start, indices.stop)
range_ = range(indices.start, indices.stop)
for attr in ["events", "timing_errors", "phases"]:
setattr(obj, attr, [x for x in getattr(obj, attr) if x.index in r])
setattr(obj, attr, [x for x in getattr(obj, attr) if x.index in range_])
for x in getattr(obj, attr):
x.index -= indices.start

return obj

def select_by_time(

def __getitem__(self, indices: slice):
return self.select_by_index(indices)


def select_by_time( #pylint: disable=too-many-arguments
self,
start: float | int | None = None,
end: float | int | None = None,
start_inclusive: bool = True,
end_inclusive: bool = False,
) -> "Sequence":
label: str = None,
) -> Sequence:
"""Select subset of sequence by the `Sequence.time` information (i.e.
based on the time stamp).
Expand Down Expand Up @@ -354,9 +380,33 @@ def select_by_time(
else:
end_index = bisect.bisect_left(self.time, end) - 1

return self[start_index:end_index]
return self.select_by_index(slice(start_index,end_index), label = label)

deepcopy = copy.deepcopy

def deepcopy(
self,
label: str = None,
relabel: bool = True,
) -> Sequence:
"""Create a deep copy of `Sequence` object.
Args:
label (str): Create a new `label` for the copy.
Defaults to None, which will trigger behavior described for relabel (below)
relabel (bool): If `True` (default), the label of self is re-used for the copy,
otherwise the following label is assigned f"Deepcopy of {self.label}".
Note that this setting is ignored if a label is given.
Returns:
Sequence: a deep copy of self
"""

obj = copy.deepcopy(self)
if label:
obj.label = label
elif relabel:
obj.label = f'Deepcopy of {self.label}'
return obj


@dataclass(eq=False)
Expand Down
63 changes: 63 additions & 0 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def test_load_partial(
assert Sequence.merge(draeger_first_part, draeger_second_part) == draeger_data2
assert Sequence.merge(draeger_second_part, draeger_first_part) != draeger_data2


def test_illegal_first():
for ff in [0.5, -1, 'fdw']:
with pytest.raises((TypeError, ValueError)):
Expand Down Expand Up @@ -339,3 +340,65 @@ def test_select_by_time(
start_inclusive=False,
end_inclusive=False)
assert len(sliced) == end_slicing[2]-start_slicing[2]


def test_label(
draeger_data1: DraegerSequence,
draeger_data2: DraegerSequence,
):

assert isinstance(draeger_data1.label, str), 'default label is not a string'
assert draeger_data1.label == f'Sequence_{id(draeger_data1)}', 'unexpected default label'

assert draeger_data1.label != draeger_data2.label, 'different data has identical label'

timpel_1 = Sequence.from_path(timpel_file, vendor = 'timpel')
timpel_2 = Sequence.from_path(timpel_file, vendor = 'timpel')
assert timpel_1.label != timpel_2.label, 'reloaded data has identical label'

test_label = 'test_label'
timpel_3 = Sequence.from_path(timpel_file, vendor = 'timpel', label = test_label)
timpel_4 = Sequence.from_path(timpel_file, vendor = 'timpel', label = test_label)
assert timpel_3.label == test_label, 'label attribute does not match given label'
assert timpel_3.label == timpel_4.label, 're-used test label not recognized as identical'

timpel_copy = timpel_1.deepcopy()
assert timpel_1.label != timpel_copy.label, 'deepcopied data has identical label'
timpel_copy_relabel = timpel_1.deepcopy(label = test_label)
assert timpel_1.label != timpel_copy_relabel.label, 'deepcopied data with new label has identical label'
timpel_copy_relabel = timpel_1.deepcopy(relabel = False)
assert timpel_1.label == timpel_copy_relabel.label, 'deepcopied data did not keep old label'
timpel_copy_relabel = timpel_1.deepcopy(label = test_label, relabel = False)
assert timpel_1.label != timpel_copy_relabel.label, 'combo of label and relabel not working as intended'


def test_relabeling(
timpel_data: TimpelSequence,
draeger_data2: DraegerSequence,

):
test_label = 'test label'

#merging
merged_timpel = Sequence.merge(timpel_data, timpel_data)
assert merged_timpel.label != timpel_data.label, 'merging does not assign new label by default'
assert merged_timpel.label == f'Merge of ({timpel_data.label}) and <{timpel_data.label}>', 'merging generates unexpected default label'
merged_timpel_2 = Sequence.merge(timpel_data, timpel_data, label = test_label)
assert merged_timpel_2.label == test_label, 'incorrect label assigned when merging data with new label'

# slicing
indices = slice(0,10)
sliced_timpel = timpel_data[indices]
assert sliced_timpel.label != timpel_data.label, 'slicing does not assign new label by default'
assert sliced_timpel.label == f'Slice ({indices.start}-{indices.stop}) of <{timpel_data.label}>', 'slicing generates unexpected default label'
sliced_timpel_2 = timpel_data.select_by_index(indices=indices, label=test_label)
assert sliced_timpel_2.label == test_label, 'incorrect label assigned when slicing data with new label'

# select_by_time
t22 = 55825.268
t52 = 55826.768
time_sliced = draeger_data2.select_by_time(t22, t52+0.001)
assert time_sliced.label != draeger_data2.label, 'time slicing does not assign new label by default'
assert time_sliced.label == f'Slice (22-52) of <{draeger_data2.label}>', 'slicing generates unexpected default label'
time_sliced_2 = draeger_data2.select_by_time(t22, t52, label=test_label)
assert time_sliced_2.label == test_label, 'incorrect label assigned when time slicing data with new label'

0 comments on commit f0a89b1

Please sign in to comment.