Various improvements in sorting components clustering + matching #3923

Open: wants to merge 24 commits into main from components_merge_templates
Commits (24):

- dd8f443  low level auto merge using template similarity for sorting components (samuelgarcia, May 14, 2025)
- 3fcd9d7  improve auto merge in tdc_clustering and circus_clustering (samuelgarcia, May 15, 2025)
- 90e0b14  improve drift-aware clustering in tdc (samuelgarcia, May 16, 2025)
- 3653c81  Merge branch 'main' of github.com:SpikeInterface/spikeinterface into … (samuelgarcia, May 16, 2025)
- 221e780  add with_template=False in BenchmarkClustering.compute_result (samuelgarcia, Jun 4, 2025)
- 0edfc77  oops (samuelgarcia, Jun 13, 2025)
- bed91f2  update tdc (samuelgarcia, Jun 16, 2025)
- ea7abd8  oops (samuelgarcia, Jun 16, 2025)
- 3ae7e6e  oops (samuelgarcia, Jun 16, 2025)
- 5c65137  fix (samuelgarcia, Jun 16, 2025)
- 811c0f1  tests (samuelgarcia, Jun 16, 2025)
- bfa6e23  Merge branch 'main' into components_merge_templates (samuelgarcia, Jun 17, 2025)
- 6a24e8b  merge main and fixes (samuelgarcia, Jun 17, 2025)
- b1ec837  Merge branch 'main' of github.com:SpikeInterface/spikeinterface into … (samuelgarcia, Jun 17, 2025)
- 0102ad9  clean (samuelgarcia, Jun 17, 2025)
- 5c9b641  Fix MatchingStudy.plot_collisions (samuelgarcia, Jun 17, 2025)
- d1ba03d  small fixes in circus-clustering (samuelgarcia, Jun 17, 2025)
- aa1a6f3  speed up the collision comparison and benchmark matching (samuelgarcia, Jun 18, 2025)
- 861f73a  wip (samuelgarcia, Jun 23, 2025)
- bef679b  Merge branch 'main' of github.com:SpikeInterface/spikeinterface into … (samuelgarcia, Jun 25, 2025)
- 6129df5  Better sparsity for analyzer in benchmarks (samuelgarcia, Jun 25, 2025)
- 1b88a97  Fix some extra_outputs in clustering methods (samuelgarcia, Jun 25, 2025)
- 2462ae5  merge conflict (samuelgarcia, Jun 25, 2025)
- 2d3d654  more fixes in comparison (samuelgarcia, Jun 25, 2025)

29 changes: 28 additions & 1 deletion src/spikeinterface/benchmark/benchmark_base.py

@@ -9,6 +9,7 @@
 import time


-from spikeinterface.core import SortingAnalyzer, NumpySorting
+from spikeinterface.core import SortingAnalyzer, ChannelSparsity, NumpySorting
 from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
 from spikeinterface import load, create_sorting_analyzer, load_sorting_analyzer

@@ -120,8 +121,29 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
                 rec, gt_sorting = data

                 if gt_sorting is not None:
+                    if "gt_unit_locations" in gt_sorting.get_property_keys():
+                        # if the true unit locations are available, use them for a better
+                        # sparsity: the true max channel is used rather than an estimated one
+                        radius_um = 100.0
+                        channel_ids = rec.channel_ids
+                        unit_ids = gt_sorting.unit_ids
+                        gt_unit_locations = gt_sorting.get_property("gt_unit_locations")
+                        channel_locations = rec.get_channel_locations()
+                        max_channel_indices = np.argmin(np.linalg.norm(gt_unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :], axis=2), axis=1)
+                        mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")
+                        distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2)
+                        for unit_ind, unit_id in enumerate(unit_ids):
+                            chan_ind = max_channel_indices[unit_ind]
+                            (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um)
+                            mask[unit_ind, chan_inds] = True
+                        sparsity = ChannelSparsity(mask, unit_ids, channel_ids)
+                        sparse = False
+                    else:
+                        sparse = True
+                        sparsity = None

                     analyzer = create_sorting_analyzer(
-                        gt_sorting, rec, sparse=True, format="binary_folder", folder=local_analyzer_folder
+                        gt_sorting, rec, sparse=sparse, sparsity=sparsity, format="binary_folder", folder=local_analyzer_folder
                     )
                     analyzer.compute("random_spikes")
                     analyzer.compute("templates")

@@ -135,6 +157,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
                     analyzer = create_sorting_analyzer(
                         gt_sorting, rec, sparse=False, format="binary_folder", folder=local_analyzer_folder
                     )
+
                 else:
                     # new case: analyzer
                     assert isinstance(data, SortingAnalyzer)

@@ -254,6 +277,8 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs):

         for key in job_keys:
             benchmark = self.create_benchmark(key)
+            if verbose:
+                print("### Run benchmark", key, "###")
             t0 = time.perf_counter()
             benchmark.run(**job_kwargs)
             t1 = time.perf_counter()

@@ -383,6 +408,8 @@ def compute_results(self, case_keys=None, verbose=False, **result_params):

         job_keys = []
         for key in case_keys:
+            if verbose:
+                print("### Compute result", key, "###")
             benchmark = self.benchmarks[key]
             assert benchmark is not None
             benchmark.compute_result(**result_params)

29 changes: 15 additions & 14 deletions src/spikeinterface/benchmark/benchmark_clustering.py

@@ -35,8 +35,8 @@ def run(self, **job_kwargs):
         )
         self.result["peak_labels"] = peak_labels

-    def compute_result(self, **result_params):
-        result_params, job_kwargs = split_job_kwargs(result_params)
+    def compute_result(self, with_template=False, **job_kwargs):
+        # result_params, job_kwargs = split_job_kwargs(result_params)
         job_kwargs = fix_job_kwargs(job_kwargs)
         self.noise = self.result["peak_labels"] < 0
         spikes = self.gt_sorting.to_spike_vector()

@@ -68,19 +68,20 @@ def compute_result(self, **result_params):
             self.result["sliced_gt_sorting"], self.result["clustering"], exhaustive_gt=self.exhaustive_gt
         )

-        sorting_analyzer = create_sorting_analyzer(
-            self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
-        )
-        sorting_analyzer.compute("random_spikes")
-        ext = sorting_analyzer.compute("templates", **job_kwargs)
-        self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates")
+        if with_template:
+            sorting_analyzer = create_sorting_analyzer(
+                self.result["sliced_gt_sorting"], self.recording, format="memory", sparse=False, **job_kwargs
+            )
+            sorting_analyzer.compute("random_spikes")
+            ext = sorting_analyzer.compute("templates", **job_kwargs)
+            self.result["sliced_gt_templates"] = ext.get_data(outputs="Templates")

-        sorting_analyzer = create_sorting_analyzer(
-            self.result["clustering"], self.recording, format="memory", sparse=False, **job_kwargs
-        )
-        sorting_analyzer.compute("random_spikes")
-        ext = sorting_analyzer.compute("templates", **job_kwargs)
-        self.result["clustering_templates"] = ext.get_data(outputs="Templates")
+            sorting_analyzer = create_sorting_analyzer(
+                self.result["clustering"], self.recording, format="memory", sparse=False, **job_kwargs
+            )
+            sorting_analyzer.compute("random_spikes")
+            ext = sorting_analyzer.compute("templates", **job_kwargs)
+            self.result["clustering_templates"] = ext.get_data(outputs="Templates")

     _run_key_saved = [("peak_labels", "npy")]
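
Since the template computation is now opt-in, a study-level call could look like the
sketch below. `study` is a hypothetical, already-created clustering study instance;
`compute_results` is the base-class method shown above, which forwards its extra
kwargs to each benchmark's `compute_result`:

    # Default: skip the (slow) sliced_gt / clustering template computation.
    study.compute_results(verbose=True)

    # Opt in to templates; remaining kwargs are treated as job kwargs.
    study.compute_results(with_template=True, n_jobs=8)

With with_template=True, both "sliced_gt_templates" and "clustering_templates" end up
in the benchmark result, as in the diff above.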

26 changes: 19 additions & 7 deletions src/spikeinterface/benchmark/benchmark_matching.py

@@ -91,23 +91,35 @@ def plot_performances_ordered(self, *args, **kwargs):

         return plot_performances_ordered(self, *args, **kwargs)

-    def plot_collisions(self, case_keys=None, figsize=None):
+    def plot_collisions(self, case_keys=None, metric="l2", mode="lines", show_legend=True, axs=None, figsize=None):
         if case_keys is None:
             case_keys = list(self.cases.keys())
         import matplotlib.pyplot as plt

-        fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
+        if axs is None:
+            fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
+            axs = axs[0, :]
+        else:
+            fig = axs[0].figure

         for count, key in enumerate(case_keys):
-            templates_array = self.get_result(key)["templates"].templates_array
+            label = self.cases[key]["label"]
+            templates_array = self.get_sorting_analyzer(key).get_extension("templates").get_templates(outputs="numpy")
+            ax = axs[count]
             plot_comparison_collision_by_similarity(
                 self.get_result(key)["gt_collision"],
                 templates_array,
-                ax=axs[0, count],
-                show_legend=True,
-                mode="lines",
-                good_only=False,
+                metric=metric,
+                ax=ax,
+                show_legend=show_legend,
+                mode=mode,
+                # good_only=False,
+                good_only=True,
             )

+            ax.set_title(label)
+
         return fig
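
With `axs` now optional, panels can be drawn into caller-owned axes. A minimal usage
sketch, assuming an existing MatchingStudy instance named `study` (hypothetical) with
exactly two cases, one per axis:

    import matplotlib.pyplot as plt

    # Pre-create one axis per case key; plot_collisions reuses them instead of
    # building its own figure, and returns the parent figure.
    fig, axs = plt.subplots(ncols=2, figsize=(12, 4))
    study.plot_collisions(axs=axs, metric="l2", mode="lines", show_legend=False)
    fig.savefig("collisions_by_similarity.png")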

@@ -38,7 +38,7 @@ def test_benchmark_clustering(create_cache_folder):
     peaks[dataset] = spikes

     cases = {}
-    for method in ["random_projections", "circus", "tdc_clustering"]:
+    for method in ["random_projections", "circus", "tdc-clustering"]:
         cases[method] = {
             "label": f"{method} on toy",
             "dataset": "toy",