Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add output typehints #332

Merged
merged 10 commits into from
May 21, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def get_rejected_list(self):
return rejected

@staticmethod
def write_segmentation(segmentation_object, save_path, overwrite=True):
def write_segmentation(segmentation_object: SegmentationExtractor, save_path: PathType, overwrite: bool = True):
"""Write a segmentation object to a *.hdf5 or *.h5 file specified by save_path.

Parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_rejected_list(self):
return [a for a in range(self.get_num_rois()) if a not in ac_set]

@staticmethod
def write_segmentation(segmentation_object: SegmentationExtractor, save_path, overwrite=True):
def write_segmentation(segmentation_object: SegmentationExtractor, save_path: PathType, overwrite: bool = True):
"""Write a segmentation object to a .mat file.

Parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Suite2pSegmentationExtractor(SegmentationExtractor):
installation_mesg = "" # error message when not installed

@classmethod
def get_available_channels(cls, folder_path: PathType):
def get_available_channels(cls, folder_path: PathType) -> list[str]:
"""Get the available channel names from the folder paths produced by Suite2p.

Parameters
Expand All @@ -52,7 +52,7 @@ def get_available_channels(cls, folder_path: PathType):
return channel_names

@classmethod
def get_available_planes(cls, folder_path: PathType):
def get_available_planes(cls, folder_path: PathType) -> list[str]:
"""Get the available plane names from the folder produced by Suite2p.

Parameters
Expand Down Expand Up @@ -220,13 +220,13 @@ def _load_npy(self, file_name: str, mmap_mode=None, transpose: bool = False, req
def get_num_frames(self) -> int:
return self._num_frames

def get_accepted_list(self):
def get_accepted_list(self) -> list[int]:
return list(np.where(self.iscell[:, 0] == 1)[0])

def get_rejected_list(self):
def get_rejected_list(self) -> list[int]:
return list(np.where(self.iscell[:, 0] == 0)[0])

def _correlation_image_read(self):
def _correlation_image_read(self) -> Optional[np.ndarray]:
"""Read correlation image from ops (settings) dict.

Returns
Expand All @@ -250,11 +250,11 @@ def _correlation_image_read(self):
return img

@property
def roi_locations(self):
def roi_locations(self) -> np.ndarray:
"""Returns the center locations (x, y) of each ROI."""
return np.array([j["med"] for j in self.stat]).T.astype(int)

def get_roi_pixel_masks(self, roi_ids=None):
def get_roi_pixel_masks(self, roi_ids=None) -> list[np.ndarray]:
pixel_mask = []
for i in range(self.get_num_rois()):
pixel_mask.append(
Expand All @@ -274,7 +274,7 @@ def get_roi_pixel_masks(self, roi_ids=None):
roi_idx_ = [j[0] for i, j in enumerate(roi_idx) if i not in ele]
return [pixel_mask[i] for i in roi_idx_]

def get_image_size(self):
def get_image_size(self) -> tuple[int, int]:
return self._image_size

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def check_frame_inputs(self, frame) -> None:
if frame < 0:
raise ValueError(f"Frame index ({frame}) must be greater than or equal to 0.")

def frame_to_raw_index(self, frame):
def frame_to_raw_index(self, frame: int) -> int:
"""Convert a frame index to the raw index in the TIFF file.

Parameters
Expand Down
16 changes: 8 additions & 8 deletions src/roiextractors/multisegmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self, segmentatation_extractors_list, plane_names=None): # TODO: H
_ = [self._channel_names.extend(self._segmentations[i].get_channel_names()) for i in range(self._no_planes)]

@property
def no_planes(self):
def no_planes(self) -> int:
"""Number of planes in the recording.

Returns
Expand All @@ -115,7 +115,7 @@ def no_planes(self):
return self._no_planes

@property
def segmentations(self):
def segmentations(self) -> list[SegmentationExtractor]:
"""List of segmentation extractors (one for each plane).

Returns
Expand All @@ -128,7 +128,7 @@ def segmentations(self):
def get_num_channels(self):
return np.sum([self._segmentations[i].get_num_channels() for i in range(self._no_planes)])

def get_num_rois(self):
def get_num_rois(self) -> int:
return len(self._all_roi_ids)

def get_images(self, name="correlation_plane0"): # TODO: add get_images to base SegmentationExtractor class
Expand All @@ -147,21 +147,21 @@ def get_images(self, name="correlation_plane0"): # TODO: add get_images to base
plane_no = int(name[-1])
return self._segmentations[plane_no].get_images(name=name.split("_")[0])

def get_images_dict(self):
def get_images_dict(self) -> dict:
return_dict = dict()
for i in range(self._no_planes):
for image_name, image in self._segmentations[i].get_images_dict().items():
return_dict.update({f"{image_name}_Plane{i}": image})
return return_dict

def get_traces_dict(self):
def get_traces_dict(self) -> dict:
return_dict = dict()
for i in range(self._no_planes):
for trace_name, trace in self._segmentations[i].get_traces_dict().items():
return_dict.update({f"{trace_name}_Plane{i}": trace})
return return_dict

def get_image_size(self):
def get_image_size(self) -> tuple[int, int]:
return self._segmentations[0].get_image_size()

@concatenate_output
Expand All @@ -183,14 +183,14 @@ def get_roi_locations(self, roi_ids=None):
def get_num_frames(self):
return np.sum([self._segmentations[i].get_num_frames() for i in range(self._no_planes)])

def get_accepted_list(self):
def get_accepted_list(self) -> list[int]:
accepted_list_all = []
for i in range(self._no_planes):
ids_loop = self._segmentations[i].get_accepted_list()
accepted_list_all.extend([j for j in self._all_roi_ids if self._roi_map[j]["roi_id"] in ids_loop])
return accepted_list_all

def get_rejected_list(self):
def get_rejected_list(self) -> list[int]:
rejected_list_all = []
for i in range(self._no_planes):
ids_loop = self._segmentations[i].get_rejected_list()
Expand Down
14 changes: 7 additions & 7 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def get_traces(
start_frame: Optional[int] = None,
end_frame: Optional[int] = None,
name: str = "raw",
):
) -> ArrayType:
"""Get the traces of each ROI specified by roi_ids.

Parameters
Expand Down Expand Up @@ -268,7 +268,7 @@ def get_traces(
idxs = slice(None) if roi_ids is None else roi_idxs
return np.array(traces[start_frame:end_frame, :])[:, idxs] # numpy fancy indexing is quickest

def get_traces_dict(self):
def get_traces_dict(self) -> dict:
"""Get traces as a dictionary with key as the name of the ROiResponseSeries.

Returns
Expand All @@ -285,7 +285,7 @@ def get_traces_dict(self):
denoised=self._roi_response_denoised,
)

def get_images_dict(self):
def get_images_dict(self) -> dict:
"""Get images as a dictionary with key as the name of the ROIResponseSeries.

Returns
Expand Down Expand Up @@ -368,7 +368,7 @@ def get_num_channels(self) -> int:
"""
return len(self._channel_names)

def get_num_planes(self):
def get_num_planes(self) -> int:
"""Get the default number of planes of imaging for the segmentation extractor.

Notes
Expand Down Expand Up @@ -513,7 +513,7 @@ def get_traces(
name=name,
)

def get_traces_dict(self):
def get_traces_dict(self) -> dict:
return {
trace_name: self._parent_segmentation.get_traces(
start_frame=self._start_frame, end_frame=self._end_frame, name=trace_name
Expand All @@ -527,7 +527,7 @@ def get_image_size(self) -> Tuple[int, int]:
def get_num_frames(self) -> int:
return self._num_frames

def get_num_rois(self):
def get_num_rois(self) -> int:
return self._parent_segmentation.get_num_rois()

def get_images_dict(self) -> dict:
Expand All @@ -545,7 +545,7 @@ def get_channel_names(self) -> list:
def get_num_channels(self) -> int:
return self._parent_segmentation.get_num_channels()

def get_num_planes(self):
def get_num_planes(self) -> int:
return self._parent_segmentation.get_num_planes()

def get_roi_pixel_masks(self, roi_ids: Optional[ArrayLike] = None) -> List[np.ndarray]:
Expand Down
Loading