Method to bypass matching and assign labels to all peaks given templates and SVD representation #3856

Open

wants to merge 47 commits into base: main

Commits (47)
8e33d08
WIP
yger Apr 8, 2025
e499302
Example of how to use SVD to estimate templates in SC2
yger Apr 8, 2025
5be94da
Patching to get a working example
yger Apr 8, 2025
2be228b
WIP
yger Apr 8, 2025
bd7c7be
WIP
yger Apr 8, 2025
1cd89f3
WIP
yger Apr 8, 2025
c28a7b6
WIP
yger Apr 8, 2025
8e455f9
WIP
yger Apr 8, 2025
6de8310
WIP
yger Apr 8, 2025
3fb5fa6
Cosmetic
yger Apr 8, 2025
5eec5e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2025
74f52bd
Patch
yger Apr 8, 2025
1388666
Merge branch 'returned_svd' of github.com:yger/spikeinterface into re…
yger Apr 8, 2025
ff10442
WIP
yger Apr 9, 2025
d0333dd
WIP
yger Apr 9, 2025
220fe03
Fix
yger Apr 9, 2025
97e8c67
WIP
yger Apr 9, 2025
83dceec
Merge branch 'main' into returned_svd
yger Apr 9, 2025
2cb39eb
WIP
yger Apr 9, 2025
c502f4d
WIP
yger Apr 9, 2025
cc4fa8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2025
4b99ff1
Merge branch 'returned_svd' into no_matching
yger Apr 9, 2025
830d8a2
Start a full clustering pipeline
yger Apr 9, 2025
f6c0342
Option to bypass matchin
yger Apr 10, 2025
0c77a48
WIP
yger Apr 10, 2025
88bcd29
Fix
yger Apr 10, 2025
7b39765
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
461819b
WIP
yger Apr 10, 2025
a69904b
Merge branch 'no_matching' of github.com:yger/spikeinterface into no_…
yger Apr 10, 2025
2f1674b
Make delete_mixtures optional
yger Apr 10, 2025
8728273
Better logs
yger Apr 10, 2025
3c854bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
d7a4b1c
Merge branch 'returned_svd' of github.com:yger/spikeinterface into re…
yger Apr 10, 2025
6a5531e
Merge branch 'returned_svd' into no_matching
yger Apr 10, 2025
c813216
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
c0ab25e
Merging
yger Apr 10, 2025
eb5dd7e
Merge branch 'main' into no_matching
yger Apr 10, 2025
b0bb5a9
Sync with main
yger Apr 16, 2025
5dcb6f7
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 22, 2025
39eb4ce
Merge branch 'main' into no_matching
yger May 6, 2025
e3cc000
WIP
yger May 6, 2025
6a20fd5
Merge branch 'no_matching' of github.com:yger/spikeinterface into no_…
yger May 6, 2025
a978989
Merge branch 'main' into no_matching
yger May 6, 2025
79b4f15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 6, 2025
27882d1
WIP
yger May 6, 2025
2445aa2
Cleaning args
yger May 7, 2025
94e36d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2025
28 changes: 23 additions & 5 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -107,7 +107,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         num_channels = recording.get_num_channels()
         ms_before = params["general"].get("ms_before", 2)
         ms_after = params["general"].get("ms_after", 2)
-        radius_um = params["general"].get("radius_um", 75)
+        radius_um = params["general"].get("radius_um", 100)
         peak_sign = params["detection"].get("peak_sign", "neg")
         templates_from_svd = params["templates_from_svd"]
         debug = params["debug"]
@@ -170,7 +170,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ## Then, we are detecting peaks with a locally_exclusive method
         detection_method = params["detection"].get("method", "matched_filtering")
         detection_params = params["detection"].get("method_kwargs", dict())
-        detection_params["radius_um"] = radius_um
+        detection_params["radius_um"] = radius_um / 2
         detection_params["exclude_sweep_ms"] = exclude_sweep_ms
         detection_params["noise_levels"] = noise_levels

@@ -219,6 +219,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             )
             detection_method = "locally_exclusive"

+        matching_method = params["matching"].get("method", "circus-omp-svd")
+        if matching_method is None:
+            # We want all peaks if we are planning to assign them to templates afterwards
+            detection_params["skip_after_n_peaks"] = None
+
         peaks = detect_peaks(recording_w, detection_method, **detection_params, **job_kwargs)
         order = np.lexsort((peaks["sample_index"], peaks["segment_index"]))
         peaks = peaks[order]
@@ -353,11 +358,24 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         else:
+            ## we should have a case to deal with clustering all peaks without matching
+            ## for small density channel counts
+            from spikeinterface.sortingcomponents.matching.tools import assign_templates_to_peaks
+
+            peak_labels = assign_templates_to_peaks(
+                recording_w,
+                peaks,
+                templates=templates,
+                svd_model=svd_model,
+                sparse_mask=sparsity_mask,
+                **job_kwargs,
+            )
+
             if verbose:
                 print("Found %d spikes" % len(peaks))

-        sorting = np.zeros(selected_peaks.size, dtype=minimum_spike_dtype)
-        sorting["sample_index"] = selected_peaks["sample_index"]
+        sorting = np.zeros(peaks.size, dtype=minimum_spike_dtype)
+        sorting["sample_index"] = peaks["sample_index"]
         sorting["unit_index"] = peak_labels
-        sorting["segment_index"] = selected_peaks["segment_index"]
+        sorting["segment_index"] = peaks["segment_index"]
         sorting = NumpySorting(sorting, sampling_frequency, templates.unit_ids)

         merging_params = params["merging"].copy()
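A minimal usage sketch of the bypass path, for reference only (not part of the diff): with the default SpyKing Circus 2 parameter structure, setting the matching method to None should route the sorter through assign_templates_to_peaks so that every detected peak is labeled by its nearest SVD template. The recording, output folder and the templates_from_svd value below are placeholders/assumptions.

# Hypothetical example, not part of this PR: run SC2 with matching disabled.
from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter

# toy recording as a stand-in for real data
recording, _ = generate_ground_truth_recording(durations=[60.0], num_units=10)

sorting = run_sorter(
    "spykingcircus2",
    recording,
    folder="sc2_no_matching",      # placeholder output folder
    matching={"method": None},     # None -> skip matching, label all peaks by nearest SVD template
    templates_from_svd=True,       # assumption: estimate templates from the SVD representation
    verbose=True,
)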
138 changes: 138 additions & 0 deletions src/spikeinterface/sortingcomponents/matching/tools.py
@@ -0,0 +1,138 @@
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
ExtractSparseWaveforms,
ExtractDenseWaveforms,
PeakRetriever,
PipelineNode,
)
from spikeinterface.sortingcomponents.waveforms.temporal_pca import (
TemporalPCAProjection,
)
from spikeinterface.core.job_tools import fix_job_kwargs
import numpy as np
from scipy.spatial.distance import cdist


class FindNearestTemplate(PipelineNode):
def __init__(
self,
recording,
pca_model,
sparsity_mask,
templates,
name="nn_templates",
return_output=True,
parents=None,
):
PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)
templates_array = templates.get_dense_templates()
n_templates = templates_array.shape[0]
num_channels = recording.get_num_channels()
        # project each dense template into the same SVD (temporal PCA) space as the extracted waveforms
        self.svd_templates = np.zeros((n_templates, pca_model.n_components, num_channels), "float32")
        for i in range(n_templates):
            self.svd_templates[i] = pca_model.transform(templates_array[i].T).T
self.sparsity_mask = sparsity_mask
self._dtype = recording.get_dtype()
self._kwargs.update(
dict(
sparsity_mask=self.sparsity_mask,
svd_templates=self.svd_templates,
)
)

def get_dtype(self):
return self._dtype

    def compute(self, traces, peaks, waveforms):
        # For every peak, compare its SVD-projected waveform with the SVD templates restricted
        # to the channels attached to the peak's main channel, and assign the closest template.
        peak_labels = np.empty(len(peaks), dtype="int64")
        for main_chan in np.unique(peaks["channel_index"]):
            (idx,) = np.nonzero(peaks["channel_index"] == main_chan)
            (chan_inds,) = np.nonzero(self.sparsity_mask[main_chan])
            local_svds = waveforms[idx][:, :, : len(chan_inds)]
            XA = local_svds.reshape(local_svds.shape[0], -1)
            XB = self.svd_templates[:, :, chan_inds].reshape(self.svd_templates.shape[0], -1)
            distances = cdist(XA, XB, metric="euclidean")
            peak_labels[idx] = np.argmin(distances, axis=1)
        return peak_labels


def assign_templates_to_peaks(
recording, peaks, svd_model, sparse_mask, templates, gather_mode="memory", **job_kwargs
) -> np.ndarray | tuple[np.ndarray, dict]:
"""
Assigns templates to peaks using a pipeline of nodes.

Parameters
----------
recording : RecordingExtractor
The recording extractor.
peaks : np.ndarray
Peaks that should be assigned to templates.
templates : Templates
The templates used for matching.
svd_model : SVDModel
The SVD model used for PCA projection.
sparse_mask : np.ndarray
The sparsity mask used to extract waveforms.
gather_mode : str
The mode for gathering results. Can be 'memory' or 'file'.
    job_kwargs : dict
        Additional keyword arguments controlling parallel processing (forwarded to run_node_pipeline).

Returns
-------
peak_labels: np.ndarray
The labels assigned to each peak.
"""

job_kwargs = fix_job_kwargs(job_kwargs)

node0 = PeakRetriever(recording, peaks)
ms_before = templates.ms_before
ms_after = templates.ms_after

if templates.are_templates_sparse():
node1 = ExtractSparseWaveforms(
recording,
parents=[node0],
return_output=False,
ms_before=ms_before,
ms_after=ms_after,
sparsity_mask=sparse_mask,
)
else:
node1 = ExtractDenseWaveforms(
recording,
parents=[node0],
return_output=False,
ms_before=ms_before,
ms_after=ms_after,
)

node2 = TemporalPCAProjection(
recording,
parents=[node0, node1],
return_output=False,
pca_model=svd_model,
)

node3 = FindNearestTemplate(
recording,
parents=[node0, node2],
return_output=True,
pca_model=svd_model,
templates=templates,
sparsity_mask=sparse_mask,
)

pipeline_nodes = [node0, node1, node2, node3]

peak_labels = run_node_pipeline(
recording,
pipeline_nodes,
job_kwargs,
        job_name="assign labels",
gather_mode=gather_mode,
squeeze_output=True,
)
return peak_labels
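For reference, a hedged sketch of calling the new helper directly, outside of SpyKing Circus 2. It assumes recording, peaks, templates, svd_model and sparsity_mask objects produced earlier in a sorting-components pipeline (peak detection, clustering, SVD fitting); only the keyword names come from the signature above.

# Hypothetical example: label already-detected peaks with their nearest SVD template.
from spikeinterface.sortingcomponents.matching.tools import assign_templates_to_peaks

peak_labels = assign_templates_to_peaks(
    recording,
    peaks,
    svd_model=svd_model,          # fitted temporal PCA/SVD model
    sparse_mask=sparsity_mask,    # boolean sparsity mask, indexed by each peak's main channel
    templates=templates,          # Templates object consistent with the SVD model
    gather_mode="memory",
    n_jobs=4,                     # any job_kwargs accepted by run_node_pipeline
    chunk_duration="1s",
)

# one integer label per peak, indexing into templates.unit_ids
assert peak_labels.shape == (len(peaks),)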