
Add splitting functionality to curation and SortingAnalyzer #3817


Open

wants to merge 29 commits into base: main

Changes from all commits (29 commits)
96863de
Add pydantic curation model and improve curation format and merging r…
alejoe91 Mar 11, 2025
1ce611c
Update src/spikeinterface/curation/curation_model.py
alejoe91 Mar 11, 2025
567b2b7
Merge branch 'curation-pydantic' of github.com:alejoe91/spikeinterfac…
alejoe91 Mar 11, 2025
3464987
Move pydantic to core
alejoe91 Mar 11, 2025
69c4854
Merge branch 'main' into curation-pydantic
alejoe91 Mar 24, 2025
228722c
Merge branch 'main' into curation-pydantic
alejoe91 Mar 25, 2025
677f90c
wip
alejoe91 Mar 25, 2025
86a3ab4
Enhance CurationModel: Add split_units validation
anoushkajain Mar 26, 2025
d4e0f84
Add splitting sorting to curation format
alejoe91 Mar 26, 2025
c7316bb
(wip) Add split_units to SortingAnalyzer
alejoe91 Mar 26, 2025
2e21923
Add split_units to sorting analyzer
alejoe91 Mar 26, 2025
4afdb80
Propagate SortingAnalyzer.split_units to apply_curation
alejoe91 Mar 26, 2025
62bfb7f
Extend CurationModel tests
alejoe91 Mar 27, 2025
40fe01b
Add analyzer split to curation tests
alejoe91 Mar 27, 2025
ca6f2e0
wip: add split tests in postprocessing
alejoe91 Mar 27, 2025
58b62fb
wip - modify model
alejoe91 Mar 27, 2025
dbfa315
Refactor curation model to include merges and splits
alejoe91 Mar 27, 2025
82526b0
Add merge list to tests
alejoe91 Mar 27, 2025
482f0be
Simplify and centralize conversion and checks
alejoe91 Mar 27, 2025
f122db7
Fix sortingview tests
alejoe91 Mar 27, 2025
4f14e90
Fix sortingview conversion
alejoe91 Mar 27, 2025
d7633bf
Merge branch 'main' into curation-pydantic
alejoe91 Mar 27, 2025
09a379e
Fix tests and cleanup apply_curation
alejoe91 Mar 28, 2025
d2f220a
Fix test-multi-extensions
alejoe91 Mar 28, 2025
317f87c
merge_new_unit_ids -> merge_new_unit_id
alejoe91 Mar 28, 2025
a9ed838
Merge branch 'curation-pydantic' of github.com:alejoe91/spikeinterfac…
alejoe91 Mar 28, 2025
64a3b83
conflicts
alejoe91 Mar 28, 2025
d4fa8bf
Deal with multi-segment
alejoe91 Mar 28, 2025
659ecff
Extend splitting-tests to multi-segment and mask labels
alejoe91 Mar 29, 2025
1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
"neo>=0.14.0",
"probeinterface>=0.2.23",
"packaging",
"pydantic",
]

[build-system]
7 changes: 6 additions & 1 deletion src/spikeinterface/core/__init__.py
@@ -109,7 +109,12 @@
get_chunk_with_margin,
order_channels_by_depth,
)
from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, apply_merges_to_sorting
from .sorting_tools import (
spike_vector_to_spike_trains,
random_spikes_selection,
apply_merges_to_sorting,
apply_splits_to_sorting,
)

from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_with_accumulator
from .snippets_tools import snippets_from_sorting
52 changes: 50 additions & 2 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -94,6 +94,11 @@ def _merge_extension_data(
new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_mask])
return new_data

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
new_data = dict()
new_data["random_spikes_indices"] = self.data["random_spikes_indices"].copy()
return new_data

def _get_data(self):
return self.data["random_spikes_indices"]

@@ -245,8 +250,6 @@ def _select_extension_data(self, unit_ids):
def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
new_data = dict()

waveforms = self.data["waveforms"]
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
if keep_mask is not None:
@@ -277,6 +280,11 @@

return dict(waveforms=waveforms)

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
# splitting only affects random spikes, not waveforms
new_data = dict(waveforms=self.data["waveforms"].copy())
return new_data

def get_waveforms_one_unit(self, unit_id, force_dense: bool = False):
"""
Returns the waveforms of a unit id.
@@ -556,6 +564,42 @@ def _merge_extension_data(

return new_data

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
new_data = dict()
for operator, arr in self.data.items():
# we first copy the unsplit units
new_array = np.zeros((len(new_sorting_analyzer.unit_ids), arr.shape[1], arr.shape[2]), dtype=arr.dtype)
new_analyzer_unit_ids = list(new_sorting_analyzer.unit_ids)
unsplit_unit_ids = [unit_id for unit_id in self.sorting_analyzer.unit_ids if unit_id not in split_units]
new_indices = np.array([new_analyzer_unit_ids.index(unit_id) for unit_id in unsplit_unit_ids])
old_indices = self.sorting_analyzer.sorting.ids_to_indices(unsplit_unit_ids)
new_array[new_indices, ...] = arr[old_indices, ...]

for split_unit_id, new_splits in zip(split_units, new_unit_ids):
if new_sorting_analyzer.has_extension("waveforms"):
for new_unit_id in new_splits:
split_unit_index = new_sorting_analyzer.sorting.id_to_index(new_unit_id)
wfs = new_sorting_analyzer.get_extension("waveforms").get_waveforms_one_unit(
new_unit_id, force_dense=True
)

if operator == "average":
arr = np.average(wfs, axis=0)
elif operator == "std":
arr = np.std(wfs, axis=0)
elif operator == "median":
arr = np.median(wfs, axis=0)
elif "percentile" in operator:
_, percentile = operator.split("_")
arr = np.percentile(wfs, float(percentile), axis=0)
new_array[split_unit_index, ...] = arr
else:
old_template = arr[self.sorting_analyzer.sorting.ids_to_indices([split_unit_id])[0], ...]
new_indices = np.array([new_analyzer_unit_ids.index(unit_id) for unit_id in new_splits])
new_array[new_indices, ...] = np.tile(old_template, (len(new_splits), 1, 1))
Comment on lines +597 to +599 (Member Author):

@samuelgarcia this needs to be discussed. What should we do if the waveforms extension is not there? The current behavior is copying, but we might want to force a recompute here.

new_data[operator] = new_array
return new_data

def _get_data(self, operator="average", percentile=None, outputs="numpy"):
if operator != "percentile":
key = operator
@@ -729,6 +773,10 @@ def _merge_extension_data(
# this does not depend on units
return self.data.copy()

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
# this does not depend on units
return self.data.copy()

def _run(self, verbose=False):
self.data["noise_levels"] = get_noise_levels(
self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params
187 changes: 183 additions & 4 deletions src/spikeinterface/core/sorting_tools.py
@@ -231,9 +231,15 @@ def random_spikes_selection(
return random_spikes_indices


### MERGING ZONE ###
def apply_merges_to_sorting(
sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, new_id_strategy="append"
):
sorting: BaseSorting,
merge_unit_groups: list[list[int | str]] | list[tuple[int | str]],
new_unit_ids: list[int | str] | None = None,
censor_ms: float | None = None,
return_extra: bool = False,
new_id_strategy: str = "append",
) -> NumpySorting | tuple[NumpySorting, np.ndarray, list[int | str]]:
"""
Apply a resolved representation of the merges to a sorting object.

@@ -245,9 +251,9 @@

Parameters
----------
sorting : Sorting
sorting : BaseSorting
The Sorting object to apply merges.
merge_unit_groups : list/tuple of lists/tuples
merge_unit_groups : list of lists/tuples
A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
but it can also have more (merge multiple units at once).
new_unit_ids : list | None, default: None
@@ -440,3 +446,176 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_
raise ValueError("wrong new_id_strategy")

return new_unit_ids


### SPLITTING ZONE ###
def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extra=False, new_id_strategy="append"):
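"""
Apply a resolved representation of the splits to a sorting object.

Parameters
----------
sorting : BaseSorting
The Sorting object to apply splits.
unit_splits : dict
A dict of units to split. Keys are the unit ids to split; values are lists of spike-index
groups, one per new unit, with indices concatenated over segments (the "full" form).
new_unit_ids : list of lists | None, default: None
Optional new unit_ids for the split units. If None, new ids are generated.
return_extra : bool, default: False
If True, also return the new unit ids.
new_id_strategy : "append" | "split", default: "append"
The strategy used to generate new unit ids if `new_unit_ids` is None.

Returns
-------
sorting : NumpySorting
The sorting with the splits applied (plus the new unit ids if `return_extra=True`).
"""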
spikes = sorting.to_spike_vector().copy()

# here we assume that unit_splits split_indices are already full.
# this is true when running via apply_curation

new_unit_ids = generate_unit_ids_for_split(
sorting.unit_ids, unit_splits, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
)
all_unit_ids = _get_ids_after_splitting(sorting.unit_ids, unit_splits, new_unit_ids)
all_unit_ids = list(all_unit_ids)

num_seg = sorting.get_num_segments()
seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2))
segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)]

# using this function avoids the mask approach and simplifies the algorithm a lot
spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices]
spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)

for unit_id in sorting.unit_ids:
if unit_id in unit_splits:
split_indices = unit_splits[unit_id]
new_split_ids = new_unit_ids[list(unit_splits.keys()).index(unit_id)]

for split, new_unit_id in zip(split_indices, new_split_ids):
new_unit_index = all_unit_ids.index(new_unit_id)
# split_indices are a concatenation across segments with absolute indices
# so we need to concatenate the spike indices across segments
spike_indices_unit = np.concatenate(
[spike_indices[segment_index][unit_id] for segment_index in range(num_seg)]
)
spikes["unit_index"][spike_indices_unit[split]] = new_unit_index
else:
new_unit_index = all_unit_ids.index(unit_id)
for segment_index in range(num_seg):
spike_inds = spike_indices[segment_index][unit_id]
spikes["unit_index"][spike_inds] = new_unit_index
sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids)

if return_extra:
return sorting, new_unit_ids
else:
return sorting
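
A minimal usage sketch (not part of the diff) of the new `apply_splits_to_sorting` helper, assuming a toy sorting generated with `generate_sorting`; the unit choice and split indices are illustrative only:

```python
import numpy as np
from spikeinterface.core import generate_sorting, apply_splits_to_sorting

# toy sorting with 3 units and a single segment
sorting = generate_sorting(num_units=3, durations=[10.0], seed=42)

# split the first unit into two new units by assigning its spikes
# (positions within the unit's spike train, concatenated over segments) to two groups
unit_id = sorting.unit_ids[0]
n_spikes = sorting.count_num_spikes_per_unit()[unit_id]
unit_splits = {unit_id: [np.arange(n_spikes // 2), np.arange(n_spikes // 2, n_spikes)]}

split_sorting, new_unit_ids = apply_splits_to_sorting(sorting, unit_splits, return_extra=True)
print(split_sorting.unit_ids, new_unit_ids)
```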


def generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy="append"):
"""
Function to generate new unit ids during a splitting procedure. If `new_unit_ids`
are provided, it will return these unit ids, checking that they have the same
length as the values of `unit_splits`.

Parameters
----------
old_unit_ids : np.array
The old unit_ids.
unit_splits : dict
A dict of units to split. Keys are the unit ids to split; values are the split definitions,
with one group of spike indices per new unit.
new_unit_ids : list of lists | None, default: None
Optional new unit_ids for split units. If given, each element needs to have the same length
as the corresponding `unit_splits` value. If None, new ids will be generated.
new_id_strategy : "append" | "split", default: "append"
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.

* "append" : new unit_ids are appended after max(sorting.unit_ids)
* "split" : new unit_ids are the split unit id joined with the split index by a "-".
Only works if unit_ids are str, otherwise switches to "append"

Returns
-------
new_unit_ids : list of lists
The new unit_ids associated with the splits.
"""
assert new_id_strategy in ["append", "split"], "new_id_strategy should be 'append' or 'split'"
old_unit_ids = np.asarray(old_unit_ids)

if new_unit_ids is not None:
for split_unit, new_split_ids in zip(unit_splits.values(), new_unit_ids):
# then only doing a consistency check
assert len(split_unit) == len(new_split_ids), "new_unit_ids should have the same len as unit_splits.values"
# new_unit_ids can also be part of old_unit_ids only inside the same group:
assert all(
new_split_id not in old_unit_ids for new_split_id in new_split_ids
), "new_unit_ids already exists but outside the split groups"
else:
dtype = old_unit_ids.dtype
new_unit_ids = []
for unit_to_split, split_indices in unit_splits.items():
num_splits = len(split_indices)
# select new_unit_ids greater than the max id, even greater than the numerical str ids
if new_id_strategy == "append":
if np.issubdtype(dtype, np.character):
# dtype str
if all(p.isdigit() for p in old_unit_ids):
# all strings are digits: we can generate a max
m = max(int(p) for p in old_unit_ids) + 1
new_unit_ids.append([str(m + i) for i in range(num_splits)])
else:
# we cannot automatically find new names
new_unit_ids.append([f"split{i}" for i in range(num_splits)])
else:
# dtype int
new_unit_ids.append(list(max(old_unit_ids) + 1 + np.arange(num_splits, dtype=dtype)))
old_unit_ids = np.concatenate([old_unit_ids, new_unit_ids[-1]])
elif new_id_strategy == "split":
if np.issubdtype(dtype, np.character):
new_unit_ids.append([f"{unit_to_split}-{i}" for i in np.arange(len(split_indices))])
else:
# dtype int
new_unit_ids.append(list(max(old_unit_ids) + 1 + np.arange(num_splits, dtype=dtype)))
old_unit_ids = np.concatenate([old_unit_ids, new_unit_ids[-1]])

return new_unit_ids
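
A small sketch (not part of the diff) of what the two id strategies produce, assuming string unit ids and a single unit split into two groups; the numbers are made up:

```python
import numpy as np
from spikeinterface.core.sorting_tools import generate_unit_ids_for_split

old_unit_ids = np.array(["0", "1", "2"])
# unit "1" split into two groups of spike indices (their content does not matter for id generation)
unit_splits = {"1": [np.array([0, 2, 4]), np.array([1, 3, 5])]}

print(generate_unit_ids_for_split(old_unit_ids, unit_splits, new_id_strategy="append"))
# expected: [['3', '4']]  -> ids appended after the current max
print(generate_unit_ids_for_split(old_unit_ids, unit_splits, new_id_strategy="split"))
# expected: [['1-0', '1-1']]  -> split unit id joined with the split index
```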


def _get_full_unit_splits(unit_splits, sorting):
# take care of single-list splits
full_unit_splits = {}
num_spikes = sorting.count_num_spikes_per_unit()
for unit_id, split_indices in unit_splits.items():
if not isinstance(split_indices[0], (list, np.ndarray)):
split_2 = np.arange(num_spikes[unit_id])
split_2 = split_2[~np.isin(split_2, split_indices)]
new_split_indices = [split_indices, split_2]
else:
new_split_indices = split_indices
full_unit_splits[unit_id] = new_split_indices
return full_unit_splits
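
An illustrative sketch (not part of the diff) of the single-list handling in `_get_full_unit_splits`: when only one group of indices is given for a unit, the complement is filled in as the second group. A toy sorting is assumed and the indices are arbitrary:

```python
import numpy as np
from spikeinterface.core import generate_sorting
from spikeinterface.core.sorting_tools import _get_full_unit_splits

sorting = generate_sorting(num_units=2, durations=[5.0], seed=0)
unit_id = sorting.unit_ids[0]

# only one group is given: the second group becomes "all remaining spikes of the unit"
partial_splits = {unit_id: np.array([0, 1, 2])}
full_splits = _get_full_unit_splits(partial_splits, sorting)
print([len(group) for group in full_splits[unit_id]])  # two groups covering all spikes of the unit
```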


def _get_ids_after_splitting(old_unit_ids, split_units, new_unit_ids):
"""
Function to get the list of unique unit_ids after some splits, given the new_unit_ids
assigned to the split units.

The split units are removed and their new unit_ids are appended at the end.

Parameters
----------
old_unit_ids : np.array
The old unit_ids.
split_units : dict
A dict of units to split. Keys are the unit ids to split; values are the split definitions,
with one entry per new unit.
new_unit_ids : list of lists
The new unit_ids for the split units. Each element needs to have the same length as the
corresponding `split_units` value.

Returns
-------
all_unit_ids : np.ndarray
The unit_ids that will be present after the splits.

"""
old_unit_ids = np.asarray(old_unit_ids)
dtype = old_unit_ids.dtype
if dtype.kind == "U":
# the new dtype can be longer
dtype = "U"

assert len(new_unit_ids) == len(split_units), "new_unit_ids should have the same len as split_units"
for new_unit_in_split, unit_to_split in zip(new_unit_ids, split_units.keys()):
assert len(new_unit_in_split) == len(
split_units[unit_to_split]
), "new_unit_ids should have the same len as split_units values"

all_unit_ids = list(old_unit_ids.copy())
for split_unit, split_new_units in zip(split_units, new_unit_ids):
all_unit_ids.remove(split_unit)
all_unit_ids.extend(split_new_units)
return np.array(all_unit_ids, dtype=dtype)
Loading