Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExtractSegmentationExtractor improvements #210

Merged
merged 15 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ extractor depending on the version of the file. [PR #170](https://github.com/cat
* Improved the `MultiImagingExtractor.get_video()` to no longer rely on `get_frames`. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Added `dtype` consistency check across `MultiImaging` components as well as a direct override method. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
* Added the `FrameSliceSegmentationExtractor` class and corresponding `Segmentation.frame_slice(...)` method. [PR #201](https://github.com/catalystneuro/neuroconv/pull/201)
* Changed the `output_struct_name` argument to optional in `ExtractSegmentationExtractor`.
to allow more flexible usage for the user and better error message when it cannot be found in the file.
For consistency, `output_struct_name` argument has been also added to the legacy extractor.
The orientation of segmentation images are transposed for consistency in image orientation (height x width). [PR #210](https://github.com/catalystneuro/roiextractors/pull/210)

### Fixes
* Fixed the reference to the proper `mov_field` in `Hdf5ImagingExtractor`. [PR #195](https://github.com/catalystneuro/neuroconv/pull/195)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Extractor for reading the segmentation data that results from calls to EXTRACT."""
from abc import ABC
from pathlib import Path
from typing import Optional

import numpy as np
from lazy_ops import DatasetView
Expand Down Expand Up @@ -34,7 +35,7 @@ def __new__(
cls,
file_path: PathType,
sampling_frequency: float,
output_struct_name: str = "output",
output_struct_name: Optional[str] = None,
):
"""Abstract class that defines which extractor class to use for a given file.
For newer versions of the EXTRACT algorithm, the extractor class redirects to
Expand All @@ -45,48 +46,81 @@ def __new__(
----------
file_path: str
The location of the folder containing the .mat file.
output_struct_name: str
The name of output struct in the .mat file, default is "output".
output_struct_name: str, optional
The name of output struct in the .mat file.
When unspecified, we check if any of the default values can be found in the file.
For newer version of extract, the default name is assumed to be "output".
For older versions the default is "extractAnalysisOutput". If none of them
can be found, it must be supplied.
sampling_frequency: float
The sampling frequency in units of Hz.
"""
self = super().__new__(cls)
self.file_path = file_path
self.output_struct_name = output_struct_name
# Check if the file is a .mat file
cls._assert_file_is_mat(self)

if output_struct_name is None:
self.output_struct_name = cls._get_default_output_struct_name_from_file(self)
else:
# Check that user-given 'output_struct_name' is in the file
self.output_struct_name = output_struct_name
cls._assert_output_struct_name_is_in_file(self)

# Check the version of the .mat file
if cls._check_extract_file_version(self):
# For newer versions of the .mat file, use the newer extractor
return NewExtractSegmentationExtractor(
file_path=file_path,
sampling_frequency=sampling_frequency,
output_struct_name=output_struct_name,
output_struct_name=self.output_struct_name,
)

# For older versions of the .mat file, use the legacy extractor
return LegacyExtractSegmentationExtractor(file_path=file_path)
return LegacyExtractSegmentationExtractor(
file_path=file_path,
output_struct_name=self.output_struct_name,
)

def _assert_file_is_mat(self):
"""Check that the file exists and is a .mat file."""
file_path = Path(self.file_path)
assert file_path.exists(), f"File {file_path} does not exist."
assert file_path.suffix == ".mat", f"File {file_path} must be a .mat file."

def _get_default_output_struct_name_from_file(self):
"""Return the default value for 'output_struct_name' when it is unspecified.
For newer version of extract, the default name is assumed to be "output".
For older versions the default is "extractAnalysisOutput".
If none of them is found, raise an error that 'output_struct_name' must be supplied."""
newer_default_output_struct_name = "output"
legacy_default_output_struct_name = "extractAnalysisOutput"
with h5py.File(name=self.file_path, mode="r") as mat_file:
if newer_default_output_struct_name in mat_file.keys():
return newer_default_output_struct_name
elif legacy_default_output_struct_name in mat_file.keys():
return legacy_default_output_struct_name
else:
raise AssertionError("The 'output_struct_name' must be supplied.")

def _assert_output_struct_name_is_in_file(self):
"""Check that 'output_struct_name' is in the file, raises an error if not."""
with h5py.File(name=self.file_path, mode="r") as mat_file:
assert (
self.output_struct_name in mat_file
), f"Output struct name '{self.output_struct_name}' not found in file."

def _check_extract_file_version(self) -> bool:
"""Check the version of the extract file.
If the file was created with a newer version of the EXTRACT algorithm, the
function will return True, otherwise it will return False."""
with h5py.File(name=self.file_path, mode="r") as mat_file:
if self.output_struct_name not in mat_file:
return False
dataset_version = mat_file[self.output_struct_name]["info"]["version"][:]
dataset_version = np.ravel(dataset_version)
# dataset_version is an HDF5 dataset of encoded characters
version_name = _decode_h5py_array(dataset_version)

return version.Version(version_name) >= version.Version("1.1.0")
return version.Version(version_name) >= version.Version("1.0.0")


class NewExtractSegmentationExtractor(SegmentationExtractor):
Expand Down Expand Up @@ -180,7 +214,7 @@ def _image_mask_extractor_read(self) -> DatasetView:
return DatasetView(self._output_struct["spatial_weights"]).lazy_transpose()

def _trace_extractor_read(self) -> DatasetView:
"""Returns the traces with a shape of number of ROIs and number of frames."""
"""Returns the traces with a shape of number of frames and number of ROIs."""
return DatasetView(self._output_struct["temporal_weights"]).lazy_transpose()

def get_accepted_list(self) -> list:
Expand All @@ -193,7 +227,7 @@ def get_accepted_list(self) -> list:
accepted_list: list
List of accepted ROIs
"""
return [roi for roi in self.get_roi_ids() if np.any(self._image_masks[roi])]
return [roi for roi in self.get_roi_ids() if np.any(self._image_masks[..., roi])]

def get_rejected_list(self) -> list:
"""
Expand Down Expand Up @@ -234,17 +268,17 @@ def get_image_size(self) -> ArrayType:
def get_images_dict(self):
"""
Returns a dictionary with key, values representing different types of Images
used in segmentation.
used in segmentation. The shape of images is height and width.
Returns
-------
images_dict: dict
dictionary with key, values representing different types of Images
"""
images_dict = super().get_images_dict()
images_dict.update(
summary_image=self._info_struct["summary_image"][:],
f_per_pixel=self._info_struct["F_per_pixel"][:],
max_image=self._info_struct["max_image"][:],
summary_image=self._info_struct["summary_image"][:].T,
f_per_pixel=self._info_struct["F_per_pixel"][:].T,
max_image=self._info_struct["max_image"][:].T,
)

return images_dict
Expand All @@ -263,16 +297,24 @@ class LegacyExtractSegmentationExtractor(SegmentationExtractor):
mode = "file"
installation_mesg = "To use extract install h5py: \n\n pip install h5py \n\n" # error message when not installed

def __init__(self, file_path: PathType):
def __init__(
self,
file_path: PathType,
output_struct_name: str = "extractAnalysisOutput",
):
"""
Parameters
----------
file_path: str
The location of the folder containing dataset.mat file.
output_struct_name: str, optional
The user has control over the names of the variables that return from `extraction(images, config)`.
When unspecified, the default is 'extractAnalysisOutput'.
"""
super().__init__()
self.file_path = file_path
self._dataset_file, self._group0 = self._file_extractor_read()
self._dataset_file = self._file_extractor_read()
self.output_struct_name = output_struct_name
self._image_masks = self._image_mask_extractor_read()
self._roi_response_raw = self._trace_extractor_read()
self._raw_movie_file_location = self._raw_datafile_read()
Expand All @@ -283,27 +325,24 @@ def __del__(self):
self._dataset_file.close()

def _file_extractor_read(self):
f = h5py.File(self.file_path, "r")
_group0_temp = list(f.keys())
_group0 = [a for a in _group0_temp if "#" not in a]
return f, _group0
return h5py.File(self.file_path, "r")

def _image_mask_extractor_read(self):
return self._dataset_file[self._group0[0]]["filters"][:].transpose([1, 2, 0])
return self._dataset_file[self.output_struct_name]["filters"][:].transpose([1, 2, 0])

def _trace_extractor_read(self):
return self._dataset_file[self._group0[0]]["traces"]
return self._dataset_file[self.output_struct_name]["traces"]

def _tot_exptime_extractor_read(self):
return self._dataset_file[self._group0[0]]["time"]["totalTime"][0][0]
return self._dataset_file[self.output_struct_name]["time"]["totalTime"][0][0]

def _summary_image_read(self):
summary_image = self._dataset_file[self._group0[0]]["info"]["summary_image"]
summary_image = self._dataset_file[self.output_struct_name]["info"]["summary_image"]
return np.array(summary_image)

def _raw_datafile_read(self):
if self._dataset_file[self._group0[0]].get("file"):
charlist = [chr(i) for i in np.squeeze(self._dataset_file[self._group0[0]]["file"][:])]
if self._dataset_file[self.output_struct_name].get("file"):
charlist = [chr(i) for i in np.squeeze(self._dataset_file[self.output_struct_name]["file"][:])]
return "".join(charlist)

def get_accepted_list(self):
Expand Down
34 changes: 16 additions & 18 deletions tests/test_extractsegmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,40 +40,38 @@ def test_extract_segmentation_extractor_file_path_is_not_a_mat_file(self):
sampling_frequency=self.sampling_frequency,
)

def test_extract_segmentation_extractor_with_default_output_struct_name(self):
"""Test that the extractor returns the NewExtractSegmentationExtractor
when the default "output" struct name is used."""
extractor = ExtractSegmentationExtractor(
file_path=self.ophys_data_path / "extract_public_output.mat",
sampling_frequency=self.sampling_frequency,
)

self.assertIsInstance(extractor, NewExtractSegmentationExtractor)
def test_extract_segmentation_extractor_user_given_output_struct_name_not_in_file(self):
"""Test that the extractor returns the expected error when a user given output
struct name is not in the file."""
file_path = self.ophys_data_path / "2014_04_01_p203_m19_check01_extractAnalysis.mat"
with self.assertRaisesWith(AssertionError, "Output struct name 'not_output' not found in file."):
ExtractSegmentationExtractor(
file_path=file_path,
sampling_frequency=self.sampling_frequency,
output_struct_name="not_output",
)

param_list = [
param(
file_path=ophys_data_path / "2014_04_01_p203_m19_check01_extractAnalysis.mat",
output_struct_name="extractAnalysisOutput",
extractor_class=LegacyExtractSegmentationExtractor,
),
param(
file_path=ophys_data_path / "extract_public_output.mat",
output_struct_name="output",
extractor_class=NewExtractSegmentationExtractor,
),
]

@parameterized.expand(
param_list,
)
def test_extract_segmentation_extractor_redirects(self, file_path, output_struct_name, extractor_class):
def test_extract_segmentation_extractor_redirects(self, file_path, extractor_class):
"""
Test that the extractor class redirects to the correct class
given the version of the .mat file.
"""
extractor = ExtractSegmentationExtractor(
file_path=file_path,
output_struct_name=output_struct_name,
sampling_frequency=self.sampling_frequency,
)

Expand Down Expand Up @@ -170,8 +168,8 @@ def test_extractor_config(self):
def test_extractor_accepted_list(self, accepted_list):
"""Test that the extractor class returns the list of accepted and rejected ROIs
correctly given the list of non-zero ROIs."""
dummy_image_mask = np.zeros((20, 50, 50))
dummy_image_mask[accepted_list, ...] = 1
dummy_image_mask = np.zeros((50, 50, 20))
dummy_image_mask[..., accepted_list] = 1

self.extractor._image_masks = dummy_image_mask

Expand All @@ -186,13 +184,13 @@ def test_extractor_get_images_dict(self):
with h5py.File(self.file_path, "r") as segmentation_file:
summary_image = DatasetView(
segmentation_file[self.output_struct_name]["info"]["summary_image"],
)[:]
)[:].T
max_image = DatasetView(
segmentation_file[self.output_struct_name]["info"]["max_image"],
)[:]
)[:].T
f_per_pixel = DatasetView(
segmentation_file[self.output_struct_name]["info"]["F_per_pixel"],
)[:]
)[:].T

images_dict = self.extractor.get_images_dict()
self.assertEqual(len(images_dict), 5)
Expand Down