Skip to content

Commit

Permalink
Merge d5cdf76 into 42c6cc1
Browse files Browse the repository at this point in the history
  • Loading branch information
psomhorst authored Jun 4, 2024
2 parents 42c6cc1 + d5cdf76 commit 34e16ab
Show file tree
Hide file tree
Showing 19 changed files with 747 additions and 222 deletions.
9 changes: 9 additions & 0 deletions eitprocessing/datahandling/breath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import NamedTuple


class Breath(NamedTuple):
    """Represents a breath with a start, middle and end time.

    Note: the fields are times (in seconds since the start of the measurement),
    not frame indices.

    Attributes:
        start_time: time at which the breath starts.
        middle_time: time of the middle of the breath (presumably the
            inspiration/expiration transition — TODO confirm with detection code).
        end_time: time at which the breath ends.
    """

    start_time: float
    middle_time: float
    end_time: float
26 changes: 16 additions & 10 deletions eitprocessing/datahandling/continuousdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@

@dataclass(eq=False)
class ContinuousData(Equivalence, SelectByTime):
"""Data class for (non-EIT) data with a continuous time axis.
"""Container for data with a continuous time axis.
Continuous data is data that was continuously measured/created at a predictable rate. Therefore, continuous data is
assumed to have a predictable delta between time points. A fixed delta is not enforced for two reasons: a) some
devices have slightly varying sampling rate (but fixed around a set rate) and b) floating point arithmetic.
Continuous data is assumed to be sequential (i.e. a single data point at each time point, sorted by time).
Args:
label: Computer readable naming of the instance.
name: Human readable naming of the instance.
unit: Unit for the data.
unit: Unit of the data, if applicable.
category: Category the data falls into, e.g. 'airway pressure'.
description: Human readable extended description of the data.
parameters: Parameters used to derive this data.
Expand All @@ -32,10 +38,10 @@ class ContinuousData(Equivalence, SelectByTime):
""" # TODO: update docstring

label: str = field(compare=False)
name: str = field(compare=False)
unit: str = field(metadata={"check_equivalence": True})
category: str = field(metadata={"check_equivalence": True})
description: str = field(default="", compare=False)
name: str = field(compare=False, repr=False)
unit: str = field(metadata={"check_equivalence": True}, repr=False)
category: str = field(metadata={"check_equivalence": True}, repr=False)
description: str = field(default="", compare=False, repr=False)
parameters: dict[str, Any] = field(default_factory=dict, repr=False, metadata={"check_equivalence": True})
derived_from: Any | list[Any] = field(default_factory=list, repr=False, compare=False)
time: np.ndarray = field(kw_only=True, repr=False)
Expand Down Expand Up @@ -89,7 +95,7 @@ def copy(
return obj

def __add__(self: T, other: T) -> T:
return self.concatenate(self, other)
return self.concatenate(other)

def concatenate(self: T, other: T, newlabel: str | None = None) -> T: # noqa: D102, will be removed soon
# TODO: compare both concatenate methods and check what is needed from both and merge into one
Expand All @@ -100,7 +106,7 @@ def concatenate(self: T, other: T, newlabel: str | None = None) -> T: # noqa: D
raise ValueError(msg)

cls = self.__class__
newlabel = newlabel or f"Merge of <{self.label}> and <{other.label}>"
newlabel = newlabel or self.label

return cls(
name=self.name,
Expand Down Expand Up @@ -161,7 +167,7 @@ def lock(self, *attr: str) -> None:
"""
if not len(attr):
# default values are not allowed when using *attr, so set a default here if none is supplied
attr = ["values"]
attr = ("values",)
for attr_ in attr:
getattr(self, attr_).flags["WRITEABLE"] = False

Expand All @@ -188,7 +194,7 @@ def unlock(self, *attr: str) -> None:
"""
if not len(attr):
# default values are not allowed when using *attr, so set a default here if none is supplied
attr = ["values"]
attr = ("values",)
for attr_ in attr:
getattr(self, attr_).flags["WRITEABLE"] = True

Expand Down
51 changes: 36 additions & 15 deletions eitprocessing/datahandling/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,35 @@

from eitprocessing.datahandling.continuousdata import ContinuousData
from eitprocessing.datahandling.eitdata import EITData
from eitprocessing.datahandling.intervaldata import IntervalData
from eitprocessing.datahandling.mixins.equality import Equivalence
from eitprocessing.datahandling.mixins.slicing import HasTimeIndexer
from eitprocessing.datahandling.sparsedata import SparseData

if TYPE_CHECKING:
from typing_extensions import Self

V = TypeVar("V", EITData, ContinuousData, SparseData)
V_classes = (EITData, ContinuousData, SparseData, IntervalData)
V = TypeVar("V", *V_classes)


class DataCollection(Equivalence, UserDict, HasTimeIndexer, Generic[V]):
"""A collection of a single type of data with unique labels.
This collection functions as a dictionary in most part. When initializing, a data type has to be passed. EITData,
ContinuousData or SparseData is expected as the data type. Other types are allowed, but not supported. The objects
added to the collection need to have a `label` attribute and a `concatenate()` method.
When adding an item to the collection, the type of the value has to match the data type of the collection.
Furthermore, the key has to match the attribute 'label' attached to the value.
A DataCollection functions largely as a dictionary, but requires a data_type argument, which must be one of the data
containers existing in this package. When adding an item to the collection, the type of the value must match the
data_type of the collection. Furthermore, the key has to match the attribute 'label' attached to the value.
The convenience method `add()` adds an item by setting the key to `value.label`.
Args:
data_type: the type of data stored in this collection. Expected to be one of EITData, ContinuousData or
SparseData.
data_type: the data container stored in this collection.
"""

data_type: type

def __init__(self, data_type: type[V], *args, **kwargs):
if not any(issubclass(data_type, cls) for cls in V.__constraints__):
if not any(issubclass(data_type, cls) for cls in V_classes):
msg = f"Type {data_type} not expected to be stored in a DataCollection."
raise ValueError(msg)
self.data_type = data_type
Expand All @@ -46,7 +44,10 @@ def __setitem__(self, __key: str, __value: V) -> None:
return super().__setitem__(__key, __value)

def add(self, *item: V, overwrite: bool = False) -> None:
"""Add one or multiple item(s) to the collection."""
"""Add one or multiple item(s) to the collection.
The item is added to the collection using the item label as the key.
"""
for item_ in item:
self._check_item(item_, overwrite=overwrite)
super().__setitem__(item_.label, item_)
Expand Down Expand Up @@ -101,7 +102,7 @@ def get_derived_data(self) -> dict[str, V]:
"""Return all data that was derived from any source."""
return {k: v for k, v in self.items() if v.derived_from}

def concatenate(self: Self[V], other: Self[V]) -> Self[V]:
def concatenate(self: Self, other: Self) -> Self:
"""Concatenate this collection with an equivalent collection.
Each item of self is concatenated with the item of other with the same key.
Expand All @@ -110,7 +111,7 @@ def concatenate(self: Self[V], other: Self[V]) -> Self[V]:

concatenated = self.__class__(self.data_type)
for key in self:
concatenated[key] = self[key].concatenate(other[key])
concatenated.add(self[key].concatenate(other[key]))

return concatenated

Expand All @@ -120,9 +121,29 @@ def select_by_time(
end_time: float | None,
start_inclusive: bool = True,
end_inclusive: bool = False,
) -> Self:
) -> DataCollection:
"""Return a DataCollection containing sliced copies of the items."""
if self.data_type is IntervalData:
return DataCollection(
self.data_type,
**{
k: v.select_by_time(
start_time=start_time,
end_time=end_time,
)
for k, v in self.items()
},
)

return DataCollection(
self.data_type,
**{k: v.select_by_time(start_time, end_time, start_inclusive, end_inclusive) for k, v in self.items()},
**{
k: v.select_by_time(
start_time=start_time,
end_time=end_time,
start_inclusive=start_inclusive,
end_inclusive=end_inclusive,
)
for k, v in self.items()
},
)
83 changes: 27 additions & 56 deletions eitprocessing/datahandling/eitdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from enum import auto
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from typing import TypeVar

import numpy as np
from strenum import LowercaseStrEnum
Expand All @@ -12,15 +12,12 @@
from eitprocessing.datahandling.mixins.equality import Equivalence
from eitprocessing.datahandling.mixins.slicing import SelectByTime

if TYPE_CHECKING:
from eitprocessing.datahandling.eitdata import Vendor

T = TypeVar("T", bound="EITData")


@dataclass(eq=False)
class EITData(SelectByTime, Equivalence):
"""Container for EIT data.
"""Container for EIT impedance data.
This class holds the pixel impedance from an EIT measurement, as well as metadata describing the measurement. The
class is meant to hold data from (part of) a singular continuous measurement.
Expand All @@ -29,31 +26,39 @@ class is meant to hold data from (part of) a singular continuous measurement.
disk.
Args:
path: the path of list of paths of the source from which data was derived.
nframes: number of frames
time:
Several convenience methods are supplied for calculating global impedance, calculating or removing baselines, etc.
path: The path or list of paths of the source from which data was derived.
nframes: Number of frames.
time: The time of each frame (since start measurement).
framerate: The (average) rate at which the frames are collection, in Hz.
vendor: The vendor of the device the data was collected with.
label: Computer readable label identifying this dataset.
name: Human readable name for the data.
pixel_impedance: Impedance values for each pixel at each frame.
""" # TODO: fix docstring

path: Path | list[Path] = field(compare=False)
nframes: int
path: str | Path | list[Path | str] = field(compare=False, repr=False)
nframes: int = field(repr=False)
time: np.ndarray = field(repr=False)
framerate: float = field(metadata={"check_equivalence": True})
vendor: Vendor = field(metadata={"check_equivalence": True})
phases: list = field(default_factory=list, repr=False)
events: list = field(default_factory=list, repr=False)
framerate: float = field(metadata={"check_equivalence": True}, repr=False)
vendor: Vendor = field(metadata={"check_equivalence": True}, repr=False)
label: str | None = field(default=None, compare=False, metadata={"check_equivalence": True})
name: str | None = field(default=None, compare=False)
name: str | None = field(default=None, compare=False, repr=False)
pixel_impedance: np.ndarray = field(repr=False, kw_only=True)

def __post_init__(self):
if not self.label:
self.label = f"{self.__class__.__name__}_{id(self)}"

self.path = self.ensure_path_list(self.path)
if len(self.path) == 1:
self.path = self.path[0]

self.name = self.name or self.label

@staticmethod
def ensure_path_list(path: str | Path | list[str | Path]) -> list[Path]:
def ensure_path_list(
path: str | Path | list[str | Path],
) -> list[Path]:
"""Return the path or paths as a list.
The path of any EITData object can be a single str/Path or a list of str/Path objects. This method returns a
Expand All @@ -64,7 +69,7 @@ def ensure_path_list(path: str | Path | list[str | Path]) -> list[Path]:
return [Path(path)]

def __add__(self: T, other: T) -> T:
return self.concatenate(self, other)
return self.concatenate(other)

def concatenate(self: T, other: T, newlabel: str | None = None) -> T: # noqa: D102, will be moved to mixin in future
# Check that data can be concatenated
Expand All @@ -79,14 +84,12 @@ def concatenate(self: T, other: T, newlabel: str | None = None) -> T: # noqa: D

return self.__class__(
vendor=self.vendor,
path=self_path + other_path,
path=[*self_path, *other_path],
label=self.label, # TODO: using newlabel leads to errors
framerate=self.framerate,
nframes=self.nframes + other.nframes,
time=np.concatenate((self.time, other.time)),
pixel_impedance=np.concatenate((self.pixel_impedance, other.pixel_impedance), axis=0),
phases=self.phases + other.phases,
events=self.events + other.events,
)

def _sliced_copy(
Expand All @@ -99,9 +102,6 @@ def _sliced_copy(
time = self.time[start_index:end_index]
nframes = len(time)

phases = list(filter(lambda p: start_index <= p.index < end_index, self.phases))
events = list(filter(lambda e: start_index <= e.index < end_index, self.events))

pixel_impedance = self.pixel_impedance[start_index:end_index, :, :]

return cls(
Expand All @@ -110,44 +110,15 @@ def _sliced_copy(
vendor=self.vendor,
time=time,
framerate=self.framerate,
phases=phases,
events=events,
label=self.label, # newlabel gives errors
pixel_impedance=pixel_impedance,
)

def __len__(self):
return self.pixel_impedance.shape[0]

@property
def global_baseline(self) -> np.ndarray:
"""Return the global baseline, i.e. the minimum pixel impedance across all pixels."""
return np.nanmin(self.pixel_impedance)

@property
def pixel_impedance_global_offset(self) -> np.ndarray:
"""Return the pixel impedance with the global baseline removed.
In the resulting array the minimum impedance across all pixels is set to 0.
"""
return self.pixel_impedance - self.global_baseline

@property
def pixel_baseline(self) -> np.ndarray:
"""Return the lowest value in each individual pixel over time."""
return np.nanmin(self.pixel_impedance, axis=0)

@property
def pixel_impedance_individual_offset(self) -> np.ndarray:
"""Return the pixel impedance with the baseline of each individual pixel removed.
Each pixel in the resulting array has a minimum value of 0.
"""
return self.pixel_impedance - self.pixel_baseline

@property
def global_impedance(self) -> np.ndarray:
"""Return the global impedance, i.e. the sum of all pixels at each frame."""
def calculate_global_impedance(self) -> np.ndarray:
    """Return the global impedance, i.e. the sum of all included pixels at each frame.

    Sums `pixel_impedance` over both pixel axes (axes 1 and 2) with `np.nansum`,
    so NaN-valued (excluded) pixels contribute zero. The result is a 1-D array
    with one global impedance value per frame (axis 0 of `pixel_impedance`).
    """
    return np.nansum(self.pixel_impedance, axis=(1, 2))


Expand Down
2 changes: 0 additions & 2 deletions eitprocessing/datahandling/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,5 @@
class Event:
"""Single time point event registered during an EIT measurement."""

index: int
time: float
marker: int
text: str
Loading

0 comments on commit 34e16ab

Please sign in to comment.