Draft

Changes from all commits (65 commits)
66ae344
Reducing memory footprint
yger Sep 29, 2025
11958d7
WIP
yger Sep 29, 2025
76fd5d1
WIP
yger Sep 29, 2025
2e1098a
WIP
yger Sep 29, 2025
a37b8f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
40b1f6c
Fixing tests
yger Sep 29, 2025
98ed633
Fixing tests
yger Sep 29, 2025
d7c2e89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
b844f3e
WIP
yger Sep 30, 2025
f8e3ba9
WIP
yger Sep 30, 2025
794102a
Merge branch 'memory_template_similarity' of github.com:yger/spikeint…
yger Sep 30, 2025
0aa76a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2025
9858fc6
WIP
yger Sep 30, 2025
b51432e
Merge branch 'memory_template_similarity' of github.com:yger/spikeint…
yger Sep 30, 2025
341d980
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2025
ecfee8d
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterface
yger Oct 3, 2025
4f6a7f1
WIP
yger Oct 3, 2025
76b9a7b
WIP
yger Oct 3, 2025
6a29e3f
Reducing memory footprint for large number of templates/channels
yger Oct 3, 2025
bb3421d
Merge branch 'memory_template_similarity'
yger Oct 3, 2025
5f0e02b
improve iterative_isosplit and remove warnings
samuelgarcia Oct 6, 2025
82223a9
"n_pca_features" 6 > 3
samuelgarcia Oct 6, 2025
50b4143
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Oct 8, 2025
8ccee0e
Merge branch 'main' of github.com:spikeinterface/spikeinterface
yger Oct 9, 2025
426b61d
Merge branch 'main' of github.com:yger/spikeinterface
yger Oct 10, 2025
b76552a
iterative isosplit params
samuelgarcia Oct 13, 2025
22aa5cd
oups
samuelgarcia Oct 15, 2025
61a570e
wip
samuelgarcia Oct 15, 2025
d671acc
Merge branch 'SpikeInterface:main' into main
yger Oct 16, 2025
4ec8408
various try on iterative_isosplit
samuelgarcia Oct 21, 2025
19e77fa
fix git bug
samuelgarcia Oct 22, 2025
d4a2124
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Oct 24, 2025
b6df1cb
Merge branch 'main' of github.com:spikeinterface/spikeinterface
yger Oct 26, 2025
8330277
improve isocut and tdc2
samuelgarcia Oct 27, 2025
9f9bddb
WIP
yger Oct 27, 2025
84aeb92
tdc2 improvement
samuelgarcia Oct 28, 2025
41b5d6b
WIP
yger Oct 28, 2025
ab470f5
WIP
yger Oct 28, 2025
936c31b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
1b5ef48
Alignment during merging
yger Oct 28, 2025
3524264
Merge branch 'circus2_paper' of github.com:yger/spikeinterface into c…
yger Oct 28, 2025
8bff173
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
f5869d7
WIP
yger Oct 28, 2025
17691f8
Merge branch 'circus2_paper' of github.com:yger/spikeinterface into c…
yger Oct 28, 2025
27eb077
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
4026a07
WIP
yger Oct 28, 2025
903e85e
Merge branch 'circus2_paper' of github.com:yger/spikeinterface into c…
yger Oct 28, 2025
c6f4708
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
356508d
fix nan in plot perf vs snr
samuelgarcia Oct 29, 2025
c53a887
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Oct 29, 2025
3f22619
WIP
yger Oct 29, 2025
911ea27
WIP
yger Oct 29, 2025
04f4c09
WIP
yger Oct 29, 2025
15d64dc
Merge branch 'circus2_paper' of github.com:yger/spikeinterface into m…
samuelgarcia Oct 30, 2025
88e3081
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Oct 30, 2025
9ddc138
rename spitting_tools to tersplit_tools to avoid double file with sam…
samuelgarcia Oct 30, 2025
dce3b96
compute_similarity_with_templates_array returan lags always
samuelgarcia Oct 30, 2025
055176e
tdc2 params ajustement
samuelgarcia Oct 30, 2025
9f7aa02
start lupin
samuelgarcia Oct 30, 2025
60c9782
lupin wip
samuelgarcia Oct 31, 2025
85eabb3
tdc sc versions
samuelgarcia Oct 31, 2025
cab04d7
Merge branch 'more_isosplit' of github.com:samuelgarcia/spikeinterfac…
samuelgarcia Oct 31, 2025
f546475
Clean lupin parameters
samuelgarcia Oct 31, 2025
10c064b
fix conflicts
samuelgarcia Nov 3, 2025
9388808
lupin whitten before motion correction
samuelgarcia Nov 3, 2025
335 changes: 335 additions & 0 deletions src/spikeinterface/sorters/internal/lupin.py
@@ -0,0 +1,335 @@
from __future__ import annotations

from .si_based import ComponentsBasedSorter

from copy import deepcopy

from spikeinterface.core import (
    get_noise_levels,
    NumpySorting,
    estimate_templates_with_accumulator,
    Templates,
    compute_sparsity,
)

from spikeinterface.core.job_tools import fix_job_kwargs

from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
from spikeinterface.core.basesorting import minimum_spike_dtype

from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing


import numpy as np


class LupinSorter(ComponentsBasedSorter):
    """
    Gentleman thief spike sorter.

    This sorter is composed of pieces of code and ideas stolen from everywhere: yass, tridesclous,
    spyking-circus, kilosort. It should be the best sorter we can build using
    spikeinterface.sortingcomponents.
    """

    sorter_name = "lupin"

    _default_params = {
        "apply_preprocessing": True,
        "apply_motion_correction": False,
        "motion_correction_preset": "dredge_fast",
        "clustering_ms_before": 0.3,
        "clustering_ms_after": 1.3,
        "radius_um": 120.0,
        "freq_min": 150.0,
        "freq_max": 6000.0,
        "cache_preprocessing_mode": "auto",
        "peak_sign": "neg",
        "detect_threshold": 5,
        "n_peaks_per_channel": 5000,
        "n_svd_components": 10,
        "clustering_recursive_depth": 3,
        "ms_before": 2.0,
        "ms_after": 3.0,
        "sparsity_threshold": 1.5,
        "template_min_snr": 2.5,
        "gather_mode": "memory",
        "job_kwargs": {},
        "seed": None,
        "save_array": False,
        "debug": False,
    }

    _params_description = {
        "apply_preprocessing": "Apply internal preprocessing or not",
        "apply_motion_correction": "Apply motion correction or not",
        "motion_correction_preset": "Motion correction preset",
        "clustering_ms_before": "Milliseconds before the spike peak for clustering",
        "clustering_ms_after": "Milliseconds after the spike peak for clustering",
        "radius_um": "Radius in um for channel sparsity",
        "freq_min": "Low cutoff frequency for the bandpass filter",
        "freq_max": "High cutoff frequency for the bandpass filter",
        "peak_sign": "Sign of peaks: neg/pos/both",
        "detect_threshold": "Threshold for peak detection",
        "n_peaks_per_channel": "Number of spikes per channel for clustering",
        "n_svd_components": "Number of SVD components for clustering",
        "clustering_recursive_depth": "Recursion depth of the clustering splits",
        "ms_before": "Milliseconds before the spike peak for template matching",
        "ms_after": "Milliseconds after the spike peak for template matching",
        "sparsity_threshold": "Threshold to sparsify templates before template matching",
        "template_min_snr": "Threshold to remove templates before template matching",
        "gather_mode": "How to accumulate spikes in matching: memory/npy",
        "job_kwargs": "The famous and fabulous job_kwargs",
        "seed": "Seed for the random number generator",
        "save_array": "Save or not intermediate arrays in the folder",
        "debug": "Save debug files",
    }

    handle_multi_segment = True

    @classmethod
    def get_sorter_version(cls):
        return "2025.11"

    @classmethod
    def _run_from_folder(cls, sorter_output_folder, params, verbose):

        from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms_from_recording
        from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
        from spikeinterface.sortingcomponents.peak_detection import detect_peaks
        from spikeinterface.sortingcomponents.peak_selection import select_peaks
        from spikeinterface.sortingcomponents.clustering.main import find_clusters_from_peaks, clustering_methods
        from spikeinterface.sortingcomponents.tools import remove_empty_templates
        from spikeinterface.preprocessing import correct_motion
        from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording
        from spikeinterface.sortingcomponents.tools import clean_templates

        job_kwargs = params["job_kwargs"].copy()
        job_kwargs = fix_job_kwargs(job_kwargs)
        job_kwargs["progress_bar"] = verbose

        seed = params["seed"]
        radius_um = params["radius_um"]

        recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

        num_chans = recording_raw.get_num_channels()
        sampling_frequency = recording_raw.get_sampling_frequency()

        apply_cmr = num_chans >= 32

        # preprocessing
        if params["apply_preprocessing"]:
            if params["apply_motion_correction"]:
                rec_for_motion = bandpass_filter(
                    recording_raw, freq_min=300.0, freq_max=6000.0, ftype="bessel", dtype="float32"
                )
                if apply_cmr:
                    rec_for_motion = common_reference(rec_for_motion)
                if verbose:
                    print("Start correct_motion()")
                _, motion_info = correct_motion(
                    rec_for_motion,
                    folder=sorter_output_folder / "motion",
                    output_motion_info=True,
                    preset=params["motion_correction_preset"],
                )
                if verbose:
                    print("Done correct_motion()")

            recording = bandpass_filter(
                recording_raw,
                freq_min=params["freq_min"],
                freq_max=params["freq_max"],
                ftype="bessel",
                filter_order=2,
                margin_ms=20.0,
                dtype="float32",
            )

            if apply_cmr:
                recording = common_reference(recording)

            recording = whiten(recording, dtype="float32", mode="local", radius_um=radius_um)

            if params["apply_motion_correction"]:
                interpolate_motion_kwargs = dict(
                    border_mode="force_extrapolate",
                    spatial_interpolation_method="kriging",
                    sigma_um=20.0,
                    p=2,
                )

                recording = InterpolateMotionRecording(
                    recording,
                    motion_info["motion"],
                    **interpolate_motion_kwargs,
                )

            # cache_folder is used only if the cache mode is "folder" or "zarr"
            cache_folder = sorter_output_folder / "cache_preprocessing"
            recording, cache_info = cache_preprocessing(
                recording, mode=params["cache_preprocessing_mode"], folder=cache_folder, job_kwargs=job_kwargs
            )

            noise_levels = get_noise_levels(recording, return_in_uV=False)
        else:
            recording = recording_raw
            noise_levels = get_noise_levels(recording, return_in_uV=False)
            cache_info = None

        # detection
        ms_before = params["ms_before"]
        ms_after = params["ms_after"]
        prototype, few_waveforms, few_peaks = get_prototype_and_waveforms_from_recording(
            recording,
            n_peaks=10_000,
            ms_before=ms_before,
            ms_after=ms_after,
            seed=seed,
            noise_levels=noise_levels,
            job_kwargs=job_kwargs,
        )
        detection_params = dict(
            peak_sign=params["peak_sign"],
            detect_threshold=params["detect_threshold"],
            exclude_sweep_ms=1.5,
            radius_um=radius_um / 2.0,  # half the svd radius is enough for detection
            prototype=prototype,
            ms_before=ms_before,
        )
        all_peaks = detect_peaks(
            recording, method="matched_filtering", method_kwargs=detection_params, job_kwargs=job_kwargs
        )

        if verbose:
            print(f"detect_peaks(): {len(all_peaks)} peaks found")

        # selection
        n_peaks = max(params["n_peaks_per_channel"] * num_chans, 20_000)
        peaks = select_peaks(all_peaks, method="uniform", n_peaks=n_peaks)
        if verbose:
            print(f"select_peaks(): {len(peaks)} peaks kept for clustering")

        # clustering
        clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params)
        clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"]
        clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"]
        clustering_kwargs["peaks_svd"]["radius_um"] = params["radius_um"]
        clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components"]
        clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"]
        if params["debug"]:
            clustering_kwargs["debug_folder"] = sorter_output_folder
        unit_ids, clustering_label, more_outs = find_clusters_from_peaks(
            recording,
            peaks,
            method="iterative-isosplit",
            method_kwargs=clustering_kwargs,
            extra_outputs=True,
            job_kwargs=job_kwargs,
        )
        new_peaks = peaks

        mask = clustering_label >= 0
        sorting_pre_peeler = NumpySorting.from_samples_and_labels(
            new_peaks["sample_index"][mask],
            clustering_label[mask],
            sampling_frequency,
            unit_ids=unit_ids,
        )
        if verbose:
            print(f"find_clusters_from_peaks(): {sorting_pre_peeler.unit_ids.size} clusters found")

        # templates
        nbefore = int(ms_before * sampling_frequency / 1000.0)
        nafter = int(ms_after * sampling_frequency / 1000.0)
        templates_array = estimate_templates_with_accumulator(
            recording,
            sorting_pre_peeler.to_spike_vector(),
            sorting_pre_peeler.unit_ids,
            nbefore,
            nafter,
            return_in_uV=False,
            **job_kwargs,
        )
        templates_dense = Templates(
            templates_array=templates_array,
            sampling_frequency=sampling_frequency,
            nbefore=nbefore,
            channel_ids=recording.channel_ids,
            unit_ids=sorting_pre_peeler.unit_ids,
            sparsity_mask=None,
            probe=recording.get_probe(),
            is_in_uV=False,
        )

        sparsity_threshold = params["sparsity_threshold"]
        radius_um = params["radius_um"]
        sparsity = compute_sparsity(templates_dense, method="radius", radius_um=radius_um)
        sparsity_snr = compute_sparsity(
            templates_dense,
            method="snr",
            amplitude_mode="peak_to_peak",
            noise_levels=noise_levels,
            threshold=sparsity_threshold,
        )
        sparsity.mask = sparsity.mask & sparsity_snr.mask
        templates = templates_dense.to_sparse(sparsity)

        templates = clean_templates(
            templates,
            sparsify_threshold=None,
            noise_levels=noise_levels,
            min_snr=params["template_min_snr"],
            max_jitter_ms=None,
            remove_empty=True,
        )

        # template matching
        gather_mode = params["gather_mode"]
        pipeline_kwargs = dict(gather_mode=gather_mode)
        if gather_mode == "npy":
            pipeline_kwargs["folder"] = sorter_output_folder / "matching"

        spikes = find_spikes_from_templates(
            recording,
            templates,
            method="wobble",
            method_kwargs={},
            pipeline_kwargs=pipeline_kwargs,
            job_kwargs=job_kwargs,
        )

        final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype)
        final_spikes["sample_index"] = spikes["sample_index"]
        final_spikes["unit_index"] = spikes["cluster_index"]
        final_spikes["segment_index"] = spikes["segment_index"]
        sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids)

        auto_merge = True
        analyzer_final = None
        if auto_merge:
            # TODO expose some of these parameters
            from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus

            analyzer_final = final_cleaning_circus(
                recording,
                sorting,
                templates,
                similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1},
                sparsity_overlap=0.5,
                censor_ms=3.0,
                max_distance_um=50,
                template_diff_thresh=np.arange(0.05, 0.4, 0.05),
                debug_folder=None,
                job_kwargs=job_kwargs,
            )
            sorting = NumpySorting.from_sorting(analyzer_final.sorting)

        if params["save_array"]:
            sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler")
            np.save(sorter_output_folder / "noise_levels.npy", noise_levels)
            np.save(sorter_output_folder / "all_peaks.npy", all_peaks)
            np.save(sorter_output_folder / "peaks.npy", peaks)
            np.save(sorter_output_folder / "clustering_label.npy", clustering_label)
            np.save(sorter_output_folder / "spikes.npy", spikes)
            templates.to_zarr(sorter_output_folder / "templates.zarr")
            if analyzer_final is not None:
                analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer")

        sorting = sorting.save(folder=sorter_output_folder / "sorting")

        del recording
        clean_cache_preprocessing(cache_info)

        return sorting
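
For reviewers who want to try the new sorter, here is a minimal, untested usage sketch. It relies on names visible in this PR (sorter_name = "lupin", the keys of _default_params) plus generate_ground_truth_recording, which is assumed to be available from the top-level spikeinterface package:

import spikeinterface as si
from spikeinterface.sorters import run_sorter

# toy recording only for illustration; any BaseRecording works
recording, gt_sorting = si.generate_ground_truth_recording(durations=[30.0], num_channels=32, seed=42)

# run the new internal sorter by name; keyword arguments override _default_params
sorting = run_sorter(
    "lupin",
    recording,
    folder="lupin_output",
    detect_threshold=5,
    verbose=True,
)
print(f"lupin found {sorting.unit_ids.size} units")
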
13 changes: 3 additions & 10 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -11,6 +11,7 @@
 from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
 from spikeinterface.sortingcomponents.tools import (
     cache_preprocessing,
+    clean_cache_preprocessing,
     get_shuffled_recording_slices,
     _set_optimal_chunk_size,
 )
@@ -189,7 +190,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         elif recording_w.check_serializability("pickle"):
             recording_w.dump(sorter_output_folder / "preprocessed_recording.pickle", relative_to=None)

-        recording_w = cache_preprocessing(recording_w, **job_kwargs, **params["cache_preprocessing"])
+        recording_w, cache_info = cache_preprocessing(recording_w, job_kwargs=job_kwargs, **params["cache_preprocessing"])

         ## Then, we are detecting peaks with a locally_exclusive method
         detection_method = params["detection"].get("method", "matched_filtering")
@@ -451,16 +452,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         if verbose:
             print(f"Kept {len(sorting.unit_ids)} units after final merging")

-        folder_to_delete = None
-        cache_mode = params["cache_preprocessing"].get("mode", "memory")
-        delete_cache = params["cache_preprocessing"].get("delete_cache", True)
-
-        if cache_mode in ["folder", "zarr"] and delete_cache:
-            folder_to_delete = recording_w._kwargs["folder_path"]
-
         del recording_w
-        if folder_to_delete is not None:
-            shutil.rmtree(folder_to_delete)
+        clean_cache_preprocessing(cache_info)

         sorting = sorting.save(folder=sorting_folder)
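
The deleted block above is exactly what clean_cache_preprocessing now factors out. The real helper lives in spikeinterface.sortingcomponents.tools; the sketch below is only a reconstruction from the removed lines, and the structure of cache_info is an assumption:

import shutil

def clean_cache_preprocessing(cache_info):
    # Reconstruction for illustration only -- see spikeinterface.sortingcomponents.tools
    # for the actual implementation. cache_preprocessing() is assumed to return
    # (recording, cache_info), with cache_info describing any on-disk cache.
    if cache_info is None:
        return
    # only "folder" and "zarr" modes leave something on disk to remove
    if cache_info.get("mode") in ("folder", "zarr") and cache_info.get("delete_cache", True):
        folder = cache_info.get("folder")
        if folder is not None:
            shutil.rmtree(folder)
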
18 changes: 18 additions & 0 deletions src/spikeinterface/sorters/internal/tests/test_lupin.py
@@ -0,0 +1,18 @@
import unittest

from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

from spikeinterface.sorters import LupinSorter, run_sorter

from pathlib import Path


class LupinSorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
    SorterClass = LupinSorter


if __name__ == "__main__":
    test = LupinSorterCommonTestSuite()
    test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters"
    test.setUp()
    test.test_with_run()