Skip to content

Commit d019d64

Browse files
ygerpre-commit-ci[bot]samuelgarciaalejoe91
authored
Implementing sparsity for ComputeTemplates (#4212)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Garcia Samuel <sam.garcia.die@gmail.com> Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent 584164b commit d019d64

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def _run(self, verbose=False, **job_kwargs):
437437
return_in_uV = self.sorting_analyzer.return_in_uV
438438

439439
return_std = "std" in self.params["operators"]
440+
sparsity_mask = None if self.sparsity is None else self.sparsity.mask
440441
output = estimate_templates_with_accumulator(
441442
recording,
442443
some_spikes,
@@ -445,17 +446,24 @@ def _run(self, verbose=False, **job_kwargs):
445446
self.nafter,
446447
return_in_uV=return_in_uV,
447448
return_std=return_std,
449+
sparsity_mask=sparsity_mask,
448450
verbose=verbose,
449451
**job_kwargs,
450452
)
451453

452-
# Output of estimate_templates_with_accumulator is either (templates,) or (templates, stds)
453454
if return_std:
454455
templates, stds = output
455-
self.data["average"] = templates
456-
self.data["std"] = stds
456+
data = dict(average=templates, std=stds)
457457
else:
458-
self.data["average"] = output
458+
templates = output
459+
data = dict(average=templates)
460+
461+
if self.sparsity is not None:
462+
# make average and std dense again
463+
for k, arr in data.items():
464+
dense_arr = self.sparsity.densify_templates(arr)
465+
data[k] = dense_arr
466+
self.data.update(data)
459467

460468
def _compute_and_append_from_waveforms(self, operators):
461469
if not self.sorting_analyzer.has_extension("waveforms"):

src/spikeinterface/core/sparsity.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,18 @@ def sparsify_templates(self, templates_array: np.ndarray) -> np.ndarray:
247247

248248
return sparse_templates
249249

250+
def densify_templates(self, templates_array: np.ndarray) -> np.ndarray:
251+
assert templates_array.shape[0] == self.num_units
252+
253+
densified_shape = (self.num_units, templates_array.shape[1], self.num_channels)
254+
dense_templates = np.zeros(shape=densified_shape, dtype=templates_array.dtype)
255+
for unit_index, unit_id in enumerate(self.unit_ids):
256+
sparse_template = templates_array[unit_index, ...]
257+
dense_template = self.densify_waveforms(waveforms=sparse_template[np.newaxis, :, :], unit_id=unit_id)
258+
dense_templates[unit_index, :, :] = dense_template
259+
260+
return dense_templates
261+
250262
@classmethod
251263
def from_unit_id_to_channel_ids(cls, unit_id_to_channel_ids, unit_ids, channel_ids):
252264
"""

0 commit comments

Comments
 (0)