Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dummy segmentation extractor #176

Merged
merged 1 commit into from
Jul 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_frames(self, frame_idxs=None, channel: Optional[int] = 0) -> np.ndarray:

frames = self._video.take(indices=frame_idxs, axis=0)
if channel is not None:
frames = frames[:, :, :, channel].squeeze()
frames = frames[..., channel].squeeze()

return frames

Expand Down Expand Up @@ -178,7 +178,7 @@ def __init__(
list of ROI ids that are rejected manually or via automated rejection
channel_names: list
list of strings representing channel names
movie_dims: list
movie_dims: tuple
height x width of the movie
"""
SegmentationExtractor.__init__(self)
Expand Down Expand Up @@ -222,7 +222,7 @@ def __init__(
raise ValueError("'timeeseries' is does not exist")
elif isinstance(image_masks, np.ndarray):
NoneType = type(None)
assert isinstance(raw, np.ndarray)
assert isinstance(raw, (np.ndarray, NoneType))
assert isinstance(dff, (np.ndarray, NoneType))
assert isinstance(neuropil, (np.ndarray, NoneType))
assert isinstance(deconvolved, (np.ndarray, NoneType))
Expand Down
4 changes: 2 additions & 2 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def get_roi_ids(self) -> list:

Returns
-------
channel_ids: list
Channel list.
roi_ids: list
List of roi ids.
"""
pass

Expand Down
100 changes: 98 additions & 2 deletions src/roiextractors/testing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from collections.abc import Iterable
from typing import Tuple
from typing import Tuple, Optional

import numpy as np
from numpy.testing import assert_array_equal, assert_array_almost_equal

from .segmentationextractor import SegmentationExtractor
from .imagingextractor import ImagingExtractor

from roiextractors import NumpyImagingExtractor
from roiextractors import NumpyImagingExtractor, NumpySegmentationExtractor

from roiextractors.extraction_tools import DtypeType

NoneType = type(None)
Expand Down Expand Up @@ -50,6 +51,101 @@ def generate_dummy_imaging_extractor(
return imaging_extractor


def generate_dummy_segmentation_extractor(
num_rois: int = 10,
num_frames: int = 30,
num_rows: int = 25,
num_columns: int = 25,
sampling_frequency: float = 30.0,
has_raw_signal: bool = True,
has_dff_signal: bool = True,
has_deconvolved_signal: bool = True,
has_neuropil_signal: bool = True,
rejected_list: Optional[list] = None,
) -> SegmentationExtractor:

"""
A dummy segmentation extractor for testing. The segmentation extractor is built by feeding random data into the
`NumpySegmentationExtractor`.

Note that this dummy example is meant to be a mock object with the right shape, structure and objects but does not
contain meaningful content. That is, the image masks matrices are not plausible image mask for a roi, the raw signal
is not a meaningful biological signal and is not related appropriately to the deconvolved signal , etc.

Parameters
----------
num_rois : int, optional
number of regions of interest, by default 10.
num_frames : int, optional
_description_, by default 30
num_rows : number of frames used in the hypotethical video from which the data was extracted, optional
number of rows in the hypotethical video from which the data was extracted, by default 25.
num_columns : int, optional
numbe rof columns in the hypotethical video from which the data was extracted, by default 25.
sampling_frequency : float, optional
sampling frequency of the hypotethical video form which the data was extracted, by default 30.0.
has_raw_signal : bool, optional
whether a raw fluoresence signal is desired in the object, by default True.
has_dff_signal : bool, optional
whether a relative (df/f) fluoresence signal is desired in the object, by default True.
has_deconvolved_signal : bool, optional
whether a deconvolved signal is desired in the object, by default True.
has_neuropil_signal : bool, optional
whether a neuropil signal is desiredi n the object, by default True.
rejected_list: list, optional
A list of rejected rois, None by default.

Returns
-------
SegmentationExtractor
A segmentation extractor with random data fed into `NumpySegmentationExtractor`
"""

# Create dummy image masks
image_masks = np.random.rand(num_rows, num_columns, num_rois)
movie_dims = (num_rows, num_columns)

# Create signals
raw = np.random.rand(num_rois, num_frames) if has_raw_signal else None
dff = np.random.rand(num_rois, num_frames) if has_dff_signal else None
deconvolved = np.random.rand(num_rois, num_frames) if has_deconvolved_signal else None
neuropil = np.random.rand(num_rois, num_frames) if has_neuropil_signal else None

# Summary images
mean_image = np.random.rand(num_rows, num_columns)
correlation_image = np.random.rand(num_rows, num_columns)

# Rois
roi_ids = [id for id in range(num_rois)]
roi_locations_rows = np.random.randint(low=0, high=num_rows, size=num_rois)
roi_locations_columns = np.random.randint(low=0, high=num_columns, size=num_rois)
roi_locations = np.vstack((roi_locations_rows, roi_locations_columns))

rejected_list = rejected_list if rejected_list else None

accepeted_list = roi_ids
if rejected_list is not None:
accepeted_list = list(set(accepeted_list).difference(rejected_list))

dummy_segmentation_extractor = NumpySegmentationExtractor(
sampling_frequency=sampling_frequency,
image_masks=image_masks,
raw=raw,
dff=dff,
deconvolved=deconvolved,
neuropil=neuropil,
mean_image=mean_image,
correlation_image=correlation_image,
roi_ids=roi_ids,
roi_locations=roi_locations,
accepted_lst=accepeted_list,
rejected_list=rejected_list,
movie_dims=movie_dims,
)

return dummy_segmentation_extractor


def _assert_iterable_shape(iterable, shape):
ar = iterable if isinstance(iterable, np.ndarray) else np.array(iterable)
for ar_shape, given_shape in zip(ar.shape, shape):
Expand Down
72 changes: 72 additions & 0 deletions tests/test_internals/test_testing_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import unittest

from roiextractors.testing import generate_dummy_segmentation_extractor


class TestDummySegmentationExtractor(unittest.TestCase):
def setUp(self) -> None:
self.num_rois = 10
self.num_frames = 30
self.num_rows = 25
self.num_columns = 25
self.sampling_frequency = 30.0

self.raw = True
self.dff = True
self.deconvolved = True
self.neuropil = True

def test_default_values(self):
segmentation_extractor = generate_dummy_segmentation_extractor()

# Test basic shape
assert segmentation_extractor.get_num_rois() == self.num_rois
assert segmentation_extractor.get_num_frames() == self.num_frames
assert segmentation_extractor.get_image_size() == (self.num_rows, self.num_columns)
assert segmentation_extractor.get_sampling_frequency() == self.sampling_frequency
assert segmentation_extractor.get_roi_ids() == list(range(self.num_rois))
assert segmentation_extractor.get_accepted_list() == segmentation_extractor.get_roi_ids()
assert segmentation_extractor.get_rejected_list() == []
assert segmentation_extractor.get_roi_locations().shape == (2, self.num_rois)

# Test image masks
assert segmentation_extractor.get_roi_image_masks().shape == (self.num_rows, self.num_columns, self.num_rois)
# TO-DO Missing testing of pixel masks

# Test summary images
assert segmentation_extractor.get_image(name="mean").shape == (self.num_rows, self.num_columns)
assert segmentation_extractor.get_image(name="correlation").shape == (self.num_rows, self.num_columns)

# Test signals
assert segmentation_extractor.get_traces(name="raw").shape == (self.num_rois, self.num_frames)
assert segmentation_extractor.get_traces(name="dff").shape == (self.num_rois, self.num_frames)
assert segmentation_extractor.get_traces(name="deconvolved").shape == (self.num_rois, self.num_frames)
assert segmentation_extractor.get_traces(name="neuropil").shape == (self.num_rois, self.num_frames)

def test_passing_parameters(self):

segmentation_extractor = generate_dummy_segmentation_extractor()

# Test basic shape
assert segmentation_extractor.get_num_rois() == self.num_rois
assert segmentation_extractor.get_num_frames() == self.num_frames
assert segmentation_extractor.get_image_size() == (self.num_rows, self.num_columns)
assert segmentation_extractor.get_sampling_frequency() == self.sampling_frequency
assert segmentation_extractor.get_roi_ids() == list(range(self.num_rois))
assert segmentation_extractor.get_accepted_list() == segmentation_extractor.get_roi_ids()
assert segmentation_extractor.get_rejected_list() == []
assert segmentation_extractor.get_roi_locations().shape == (2, self.num_rois)

# Test image masks
assert segmentation_extractor.get_roi_image_masks().shape == (self.num_rows, self.num_columns, self.num_rois)
# TO-DO Missing testing of pixel masks

# Test summary images
assert segmentation_extractor.get_image(name="mean").shape == (self.num_rows, self.num_columns)
assert segmentation_extractor.get_image(name="correlation").shape == (self.num_rows, self.num_columns)

# Test signals
assert segmentation_extractor.get_traces(name="raw").shape == (self.num_rois, self.num_frames)
assert segmentation_extractor.get_traces(name="dff").shape == (self.num_rois, self.num_frames)
assert segmentation_extractor.get_traces(name="deconvolved").shape == (self.num_rois, self.num_frames)
assert segmentation_extractor.get_traces(name="neuropil").shape == (self.num_rois, self.num_frames)