Skip to content

Allow for SortingAnalyzer or BaseSorter in plot_* #3941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions src/spikeinterface/widgets/isi_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import numpy as np
from warnings import warn

from spikeinterface.core import SortingAnalyzer, BaseSorting

from .base import BaseWidget, to_attr
from .utils import get_unit_colors


class ISIDistributionWidget(BaseWidget):
Expand All @@ -13,18 +14,37 @@ class ISIDistributionWidget(BaseWidget):

Parameters
----------
sorting : SortingExtractor
The sorting extractor object
unit_ids : list
List of unit ids
bins_ms : int
Bin size in ms
window_ms : float
sorting_analyzer_or_sorting : SortingAnalyzer | BaseSorting | None, default: None
The object containing the sorting information for the isi distribution plot
unit_ids : list | None, default: None
List of unit ids. If None, uses all unit ids.
window_ms : float, default: 100.0
Window size in ms

bins_ms : int, default: 1.0
Bin size in ms
sorting : SortingExtractor | None, default: None
A sorting object. Deprecated.
"""

def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs):
def __init__(
self,
sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting | None = None,
unit_ids: list | None = None,
window_ms: float = 100.0,
bin_ms: float = 1.0,
backend: str | None = None,
sorting: BaseSorting | None = None,
**backend_kwargs,
):

if sorting is not None:
# When removed, make `sorting_analyzer_or_sorting` a required argument rather than None.
deprecation_msg = "`sorting` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
sorting_analyzer_or_sorting = sorting

sorting = self.ensure_sorting(sorting_analyzer_or_sorting)

if unit_ids is None:
unit_ids = sorting.get_unit_ids()

Expand Down
58 changes: 33 additions & 25 deletions src/spikeinterface/widgets/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
from warnings import warn

from .base import BaseWidget, to_attr, default_backend_kwargs
from spikeinterface.core import SortingAnalyzer, BaseSorting
from .base import BaseWidget, to_attr
from .utils import get_some_colors


Expand Down Expand Up @@ -278,39 +279,46 @@ class RasterWidget(BaseRasterWidget):

Parameters
----------
sorting : SortingExtractor | None, default: None
A sorting object
sorting_analyzer : SortingAnalyzer | None, default: None
A sorting analyzer object
segment_index : None or int
The segment index.
unit_ids : list
List of unit ids
time_range : list
sorting_analyzer_or_sorting : SortingAnalyzer | BaseSorting | None, default: None
The object containing the sorting information for the raster plot
segment_index : None | int, default: None
The segment index. If None, uses first segment.
unit_ids : list | None, default: None
List of unit ids. If None, uses all unit ids.
time_range : list | None, default: None
List with start time and end time
color : matplotlib color
color : matplotlib color, default: "k"
The color to be used
sorting : SortingExtractor | None, default: None
A sorting object. Deprecated.
sorting_analyzer : SortingAnalyzer | None, default: None
A sorting analyzer object. Deprecated.
"""

def __init__(
self,
sorting=None,
sorting_analyzer=None,
segment_index=None,
unit_ids=None,
time_range=None,
sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting | None = None,
segment_index: int | None = None,
unit_ids: list | None = None,
time_range: list | None = None,
color="k",
backend=None,
backend: str | None = None,
sorting: BaseSorting | None = None,
sorting_analyzer: SortingAnalyzer | None = None,
**backend_kwargs,
):
if sorting is None and sorting_analyzer is None:
raise Exception("Must supply either a sorting or a sorting_analyzer")
elif sorting is not None and sorting_analyzer is not None:
raise Exception("Should supply either a sorting or a sorting_analyzer, not both")
elif sorting_analyzer is not None:
sorting = sorting_analyzer.sorting

sorting = self.ensure_sorting(sorting)

if sorting is not None:
# When removed, make `sorting_analyzer_or_sorting` a required argument rather than None.
deprecation_msg = "`sorting` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
sorting_analyzer_or_sorting = sorting
if sorting_analyzer is not None:
deprecation_msg = "`sorting_analyzer` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
sorting_analyzer_or_sorting = sorting_analyzer

sorting = self.ensure_sorting(sorting_analyzer_or_sorting)

if sorting.get_num_segments() > 1:
if segment_index is None:
Expand Down
34 changes: 23 additions & 11 deletions src/spikeinterface/widgets/unit_presence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import numpy as np
from warnings import warn

from spikeinterface.core import SortingAnalyzer, BaseSorting
from .base import BaseWidget, to_attr


Expand All @@ -11,29 +13,39 @@ class UnitPresenceWidget(BaseWidget):

Parameters
----------
sorting : SortingExtractor
The sorting extractor object
segment_index : None or int
The segment index.
sorting_analyzer_or_sorting : SortingAnalyzer | BaseSorting | None, default: None
The object containing the sorting information for the raster plot
segment_index : None or int, default: None
The segment index. If None, uses first segment.
time_range : list or None, default: None
List with start time and end time
bin_duration_s : float, default: 0.5
Bin size (in seconds) for the heat map time axis
smooth_sigma : float, default: 4.5
Sigma for the Gaussian kernel (in number of bins)
sorting : SortingExtractor | None, default: None
A sorting object. Deprecated.
"""

def __init__(
self,
sorting,
segment_index=None,
time_range=None,
bin_duration_s=0.05,
smooth_sigma=4.5,
backend=None,
sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting | None = None,
segment_index: int | None = None,
time_range: list | None = None,
bin_duration_s: float = 0.05,
smooth_sigma: float = 4.5,
backend: str | None = None,
sorting: BaseSorting | None = None,
**backend_kwargs,
):
sorting = self.ensure_sorting(sorting)

if sorting is not None:
# When removed, make `sorting_analyzer_or_sorting` a required argument rather than None.
deprecation_msg = "`sorting` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
sorting_analyzer_or_sorting = sorting

sorting = self.ensure_sorting(sorting_analyzer_or_sorting)

if segment_index is None:
nseg = sorting.get_num_segments()
Expand Down
Loading