Allow run_sorter to accept dicts #4005

Open · wants to merge 6 commits into base: main
48 changes: 29 additions & 19 deletions doc/how_to/process_by_channel_group.rst
@@ -99,7 +99,8 @@ to any preprocessing function.
referenced_recording = spre.common_reference(filtered_recording)
good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording)

We can then aggregate the recordings back together using the ``aggregate_channels`` function.
Note that we do not need to do this to sort the data (see :ref:`sorting a recording by channel group`).

.. code-block:: python
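
    # a sketch of the re-aggregation step, assuming ``good_channels_recording`` is
    # the dict of per-group recordings produced above and that ``aggregate_channels``
    # has been imported from ``spikeinterface.core``; it recombines the per-group
    # recordings into a single multi-channel recording
    combined_recording = aggregate_channels(good_channels_recording)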

@@ -141,16 +142,38 @@ Sorting a Recording by Channel Group
We can also sort a recording for each channel group separately. It is not necessary to preprocess
a recording by channel group in order to sort by channel group.

There are two ways to sort a recording by channel group. First, we can simply pass the output from
our preprocessing-by-group approach above. Second, for more control, we can loop over the recordings
ourselves.

**Option 1: Automatic splitting**

Simply pass the split recording to the ``run_sorter`` function, as if it were a non-split recording.
This will return a dict of sortings, with the keys corresponding to the groups.

.. code-block:: python

split_recording = raw_recording.split_by("group")

# do preprocessing if needed
pp_recording = spre.bandpass_filter(split_recording)

dict_of_sortings = run_sorter(
    sorter_name='kilosort2',
    recording=pp_recording,
    folder='working_path'
)
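
The returned dictionary can be handled like any other mapping. A short sketch, assuming the
``dict_of_sortings`` computed above:

.. code-block:: python

    for group, sorting in dict_of_sortings.items():
        print(f"group {group}: {sorting.get_num_units()} units")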


**Option 2: Manual splitting**

In this example, we loop over all preprocessed recordings that
are grouped by channel, and apply the sorting separately. We store the
sorting objects in a dictionary for later use.

You might do this if you want extra control, e.g. to apply bespoke steps
to different groups.

.. code-block:: python

split_preprocessed_recording = preprocessed_recording.split_by("group")

sortings = {}
for group, recording in split_preprocessed_recording.items():
    sorting = run_sorter(
        sorter_name='kilosort2',
        recording=recording,
        output_folder=f"folder_KS2_group{group}"
    )
    sortings[group] = sorting

27 changes: 15 additions & 12 deletions doc/modules/sorters.rst
@@ -339,8 +339,8 @@ Running spike sorting by group is indeed a very common need.
A :py:class:`~spikeinterface.core.BaseRecording` object has the ability to split itself into a dictionary of
sub-recordings given a certain property (see :py:meth:`~spikeinterface.core.BaseRecording.split_by`).
So it is easy to loop over this dictionary and sequentially run spike sorting on these sub-recordings.
The :py:func:`~spikeinterface.sorters.run_sorter` function can also accept the dictionary returned
by :py:meth:`~spikeinterface.core.BaseRecording.split_by` and will return a dictionary of sortings.

In this example, we create a 16-channel recording with 4 tetrodes:

@@ -368,7 +368,19 @@ In this example, we create a 16-channel recording with 4 tetrodes:
# >>> [0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]


**Option 1: Automatic splitting**

.. code-block:: python

    # split the recording into a dict of sub-recordings, keyed by group
    split_recording = recording_4_tetrodes.split_by("group")

    # here the result is a dict of sortings, one per group
    dict_of_sortings = run_sorter(
        sorter_name='kilosort2',
        recording=split_recording,
        folder='working_path'
    )


**Option 2: Manual splitting**

.. code-block:: python

# split the recording by group and sort each sub-recording separately
recordings_by_group = recording_4_tetrodes.split_by("group")

sortings = {}
for group, recording in recordings_by_group.items():
    sorting = run_sorter(sorter_name='kilosort2', recording=recording, output_folder=f"folder_KS2_group{group}")
    sortings[group] = sorting
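
If a single sorting object is preferred, the per-group sortings can be recombined. A sketch,
assuming ``aggregate_units`` from ``spikeinterface.core``, which concatenates the units of
several sortings into one object:

.. code-block:: python

    from spikeinterface.core import aggregate_units

    # combine the per-group sortings into a single Sorting object
    aggregate_sorting = aggregate_units(list(sortings.values()))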


Handling multi-segment recordings
---------------------------------
82 changes: 75 additions & 7 deletions src/spikeinterface/sorters/runsorter.py
@@ -70,7 +70,7 @@
----------
sorter_name : str
The sorter name
recording : RecordingExtractor | dict of RecordingExtractor
The recording extractor to be spike sorted
folder : str or Path
Path to output folder
@@ -100,16 +100,12 @@
**sorter_params : keyword args
Spike sorter specific arguments (they can be retrieved with `get_default_sorter_params(sorter_name_or_class)`)

"""


def run_sorter(
sorter_name: str,
recording: BaseRecording | dict,
folder: Optional[str] = None,
remove_existing_folder: bool = False,
delete_output_folder: bool = False,
@@ -124,8 +120,11 @@ def run_sorter(
):
"""
Generic function to run a sorter via function approach.

{}
Returns
-------
BaseSorting | dict of BaseSorting | None
The spike sorted data (if `with_output` is True) or None (if `with_output` is False)

Examples
--------
@@ -151,6 +150,20 @@
**sorter_params,
)

if isinstance(recording, dict):

all_kwargs = dict(common_kwargs)  # copy, so the caller's kwargs are not mutated
all_kwargs.update(
dict(
docker_image=docker_image,
singularity_image=singularity_image,
delete_container_files=delete_container_files,
)
)

dict_of_sortings = _run_sorter_by_dict(recording, **all_kwargs)
return dict_of_sortings

if docker_image or singularity_image:
common_kwargs.update(dict(delete_container_files=delete_container_files))
if docker_image:
Expand Down Expand Up @@ -201,6 +214,61 @@ def run_sorter(
run_sorter.__doc__ = run_sorter.__doc__.format(_common_param_doc)


def _run_sorter_by_dict(dict_of_recordings: dict, folder: str | Path | None = None, **run_sorter_params):
"""
Applies `run_sorter` to each recording in a dict of recordings and saves
the results.
{}
Returns
-------
dict
Dictionary of `BaseSorting`s, with the same keys as the input dict of `BaseRecording`s.
"""

sorter_name = run_sorter_params["sorter_name"]
remove_existing_folder = run_sorter_params["remove_existing_folder"]

if folder is None:
    folder = Path(sorter_name + "_output")

folder = Path(folder)
# an existing folder raises an error unless `remove_existing_folder` is True (it is then reused)
folder.mkdir(exist_ok=remove_existing_folder)

# If we know how the recording was split, save this in the info file
first_recording = next(iter(dict_of_recordings.values()))
split_by_property = first_recording.get_annotation("split_by_property")
if split_by_property is None:
split_by_property = "Unknown"

dict_keys = dict_of_recordings.keys()
dict_key_types = [type(key).__name__ for key in dict_keys]

info_file = folder / "spikeinterface_info.json"
info = dict(
version=spikeinterface.__version__,
dev_mode=spikeinterface.DEV_MODE,
object="dict of Sorting",
dict_keys=list(dict_of_recordings.keys()),
dict_key_types=dict_key_types,
split_by_property=split_by_property,
)
with open(info_file, mode="w") as f:
json.dump(check_json(info), f, indent=4)

# drop the original (un-split) recording from the params: each call below
# passes one sub-recording instead
run_sorter_params.pop("recording", None)

sorter_dict = {}
for group_key, recording in dict_of_recordings.items():
    sorter_dict[group_key] = run_sorter(recording=recording, folder=folder / f"{group_key}", **run_sorter_params)

return sorter_dict


_run_sorter_by_dict.__doc__ = _run_sorter_by_dict.__doc__.format(_common_param_doc)


def run_sorter_local(
sorter_name,
recording,
43 changes: 43 additions & 0 deletions src/spikeinterface/sorters/tests/test_runsorter.py
@@ -4,6 +4,7 @@
from pathlib import Path
import shutil
from packaging.version import parse
import json

from spikeinterface import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter
@@ -45,6 +46,48 @@ def test_run_sorter_local(generate_recording, create_cache_folder):
print(sorting)


def test_run_sorter_dict(generate_recording, create_cache_folder):
recording = generate_recording
cache_folder = create_cache_folder

recording.set_property(key="split_property", values=[4, 4, "g", "g", 4, 4, 4, "g"])
dict_of_recordings = recording.split_by("split_property")

sorter_params = {"detection": {"detect_threshold": 4.9}}

output_folder = cache_folder / "sorting_tdc_local_dict"

dict_of_sortings = run_sorter(
"tridesclous2",
dict_of_recordings,
output_folder=output_folder,
remove_existing_folder=True,
delete_output_folder=False,
verbose=True,
raise_error=True,
**sorter_params,
)

assert set(dict_of_sortings.keys()) == {"g", "4"}
assert (output_folder / "g").is_dir()
assert (output_folder / "4").is_dir()

assert dict_of_sortings["g"]._recording.get_num_channels() == 3
assert dict_of_sortings["4"]._recording.get_num_channels() == 5

info_filepath = output_folder / "spikeinterface_info.json"
assert info_filepath.is_file()

with open(info_filepath) as f:
spikeinterface_info = json.load(f)

si_info_keys = spikeinterface_info.keys()
for key in ["version", "dev_mode", "object", "dict_keys", "split_by_property"]:
assert key in si_info_keys

assert spikeinterface_info["split_by_property"] == "split_property"


@pytest.mark.skipif(ON_GITHUB, reason="Docker tests don't run on github: test locally")
def test_run_sorter_docker(generate_recording, create_cache_folder):
recording = generate_recording
Expand Down