38 changes: 29 additions & 9 deletions src/spikeinterface/core/tests/test_waveform_tools.py
@@ -227,16 +227,36 @@ def test_estimate_templates():

job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")

# mask with different sparsity per unit
sparsity_mask = np.ones((sorting.unit_ids.size, recording.channel_ids.size), dtype=bool)
sparsity_mask[:4, : recording.channel_ids.size // 2 - 1] = False
sparsity_mask[4:, recording.channel_ids.size // 2 :] = False

for operator in ("average", "median"):
templates = estimate_templates(
templates_array = estimate_templates(
recording, spikes, sorting.unit_ids, nbefore, nafter, operator=operator, return_in_uV=True, **job_kwargs
)
# print(templates.shape)
assert templates.shape[0] == sorting.unit_ids.size
assert templates.shape[1] == nbefore + nafter
assert templates.shape[2] == recording.get_num_channels()

assert np.any(templates != 0)
assert templates_array.shape[0] == sorting.unit_ids.size
assert templates_array.shape[1] == nbefore + nafter
assert templates_array.shape[2] == recording.get_num_channels()

assert np.any(templates_array != 0)

sparse_templates_array = estimate_templates(
recording,
spikes,
sorting.unit_ids,
nbefore,
nafter,
operator=operator,
return_in_uV=True,
sparsity_mask=sparsity_mask,
**job_kwargs,
)
n_chan = np.max(np.sum(sparsity_mask, axis=1))
assert n_chan == sparse_templates_array.shape[2]
assert np.any(sparse_templates_array == 0)

# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
@@ -247,7 +267,7 @@ def test_estimate_templates():


if __name__ == "__main__":
cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core"
test_waveform_tools(cache_folder)
test_estimate_templates_with_accumulator()
# cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core"
# test_waveform_tools(cache_folder)
# test_estimate_templates_with_accumulator()
test_estimate_templates()
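For reference outside the diff, here is a minimal, self-contained sketch of how the new `sparsity_mask` argument is meant to be called, mirroring the test above. The toy recording parameters (`num_channels=8`, `num_units=8`, the duration and `nbefore`/`nafter` values) are assumptions for the example, not values from this PR:

```python
import numpy as np

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.core.waveform_tools import estimate_templates

# Toy data; any recording/sorting pair with a spike vector works the same way.
recording, sorting = generate_ground_truth_recording(durations=[10.0], num_channels=8, num_units=8, seed=0)
spikes = sorting.to_spike_vector()
nbefore, nafter = 20, 40

# One boolean row per unit, one column per channel: True = keep that channel.
sparsity_mask = np.ones((sorting.unit_ids.size, recording.channel_ids.size), dtype=bool)
sparsity_mask[:4, : recording.channel_ids.size // 2 - 1] = False
sparsity_mask[4:, recording.channel_ids.size // 2 :] = False

templates_array = estimate_templates(
    recording,
    spikes,
    sorting.unit_ids,
    nbefore,
    nafter,
    operator="average",
    return_in_uV=True,
    sparsity_mask=sparsity_mask,
    n_jobs=2,
    chunk_duration="1s",
)
# The channel axis is padded to the densest unit's channel count.
assert templates_array.shape[2] == int(np.max(np.sum(sparsity_mask, axis=1)))
```

The sparse output is left-aligned per unit: channels beyond a unit's own count stay zero, which is why the test asserts `np.any(sparse_templates_array == 0)`.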
43 changes: 37 additions & 6 deletions src/spikeinterface/core/waveform_tools.py
@@ -744,12 +744,13 @@ def estimate_templates(
operator: str = "average",
return_scaled=None,
return_in_uV=True,
sparsity_mask=None,
job_name=None,
**job_kwargs,
):
"""
Estimate dense templates with "average" or "median".
If "average" internally estimate_templates_with_accumulator() is used to saved memory/
If "average" internally estimate_templates_with_accumulator() is used to saved memory.

Parameters
----------
@@ -770,6 +771,8 @@
return_in_uV : bool, default: True
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
traces are scaled to uV
sparsity_mask : None or array of bool, default: None
If not None, shape must be (len(unit_ids), len(channel_ids))

Returns
-------
@@ -791,7 +794,15 @@

if operator == "average":
templates_array = estimate_templates_with_accumulator(
recording, spikes, unit_ids, nbefore, nafter, return_in_uV=return_in_uV, job_name=job_name, **job_kwargs
recording,
spikes,
unit_ids,
nbefore,
nafter,
return_in_uV=return_in_uV,
sparsity_mask=sparsity_mask,
job_name=job_name,
**job_kwargs,
)
elif operator == "median":
all_waveforms, wf_array_info = extract_waveforms_to_single_buffer(
@@ -802,6 +813,7 @@
nafter,
mode="shared_memory",
return_in_uV=return_in_uV,
sparsity_mask=sparsity_mask,
copy=False,
**job_kwargs,
)
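The "median" branch only threads `sparsity_mask` through to `extract_waveforms_to_single_buffer`; the per-unit median reduction itself sits outside this hunk. A standalone sketch of that kind of reduction, under assumed shapes (this is not the library's exact code):

```python
import numpy as np

# Assumed layout: all_waveforms is (num_spikes, nbefore + nafter, num_chans)
# and unit_indices maps each spike to its unit, like spikes["unit_index"].
num_units, num_spikes, num_samples, num_chans = 3, 100, 60, 8
rng = np.random.default_rng(0)
all_waveforms = rng.normal(size=(num_spikes, num_samples, num_chans)).astype("float32")
unit_indices = rng.integers(0, num_units, size=num_spikes)

templates_array = np.zeros((num_units, num_samples, num_chans), dtype="float32")
for unit_index in range(num_units):
    unit_waveforms = all_waveforms[unit_indices == unit_index]
    if unit_waveforms.shape[0] > 0:
        # Median over spikes, kept per sample and per channel.
        templates_array[unit_index] = np.median(unit_waveforms, axis=0)
```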
@@ -828,6 +840,7 @@
nafter: int,
return_scaled=None,
return_in_uV=True,
sparsity_mask=None,
job_name=None,
return_std: bool = False,
verbose: bool = False,
@@ -859,6 +872,8 @@
return_in_uV : bool, default: True
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
traces are scaled to uV
sparsity_mask : None or array of bool, default: None
If not None, shape must be (len(unit_ids), len(channel_ids))
return_std: bool, default: False
If True, the standard deviation is also computed.

@@ -882,10 +897,14 @@
job_kwargs = fix_job_kwargs(job_kwargs)
num_worker = job_kwargs["n_jobs"]

num_chans = recording.get_num_channels()
if sparsity_mask is None:
num_chans = int(recording.get_num_channels())
else:
num_chans = int(max(np.sum(sparsity_mask, axis=1))) # This is a numpy scalar, so we cast to int
num_units = len(unit_ids)

shape = (num_worker, num_units, nbefore + nafter, num_chans)

dtype = np.dtype("float32")
waveform_accumulator_per_worker, shm = make_shared_array(shape, dtype)
shm_name = shm.name
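To make the padding rule concrete, a toy example of the `num_chans` computation when a mask is given (illustrative values only):

```python
import numpy as np

# Three units over four channels, keeping 2, 3, and 1 channels respectively.
sparsity_mask = np.array(
    [
        [True, True, False, False],
        [True, True, True, False],
        [False, False, False, True],
    ]
)
# Per-unit channel counts are [2, 3, 1]; the shared buffer is sized to the max.
num_chans = int(max(np.sum(sparsity_mask, axis=1)))
assert num_chans == 3
```

Sizing every unit's slot to the densest unit keeps the shared-memory accumulator rectangular; sparser units simply leave their trailing channel slots at zero.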
@@ -909,6 +928,7 @@
nbefore,
nafter,
return_in_uV,
sparsity_mask,
)

if job_name is None:
@@ -965,13 +985,15 @@ def _init_worker_estimate_templates(
nbefore,
nafter,
return_in_uV,
sparsity_mask,
):
worker_dict = {}
worker_dict["recording"] = recording
worker_dict["spikes"] = spikes
worker_dict["nbefore"] = nbefore
worker_dict["nafter"] = nafter
worker_dict["return_in_uV"] = return_in_uV
worker_dict["sparsity_mask"] = sparsity_mask

from multiprocessing.shared_memory import SharedMemory
import multiprocessing
@@ -1009,6 +1031,7 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dict):
waveform_squared_accumulator_per_worker = worker_dict.get("waveform_squared_accumulator_per_worker", None)
worker_index = worker_dict["worker_index"]
return_in_uV = worker_dict["return_in_uV"]
sparsity_mask = worker_dict["sparsity_mask"]

seg_size = recording.get_num_samples(segment_index=segment_index)

@@ -1040,6 +1063,14 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dict):
unit_index = spikes[spike_index]["unit_index"]
wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :]

waveform_accumulator_per_worker[worker_index, unit_index, :, :] += wf
if waveform_squared_accumulator_per_worker is not None:
waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :] += wf**2
if sparsity_mask is None:
waveform_accumulator_per_worker[worker_index, unit_index, :, :] += wf
if waveform_squared_accumulator_per_worker is not None:
waveform_squared_accumulator_per_worker[worker_index, unit_index, :, :] += wf**2

else:
mask = sparsity_mask[unit_index, :]
wf = wf[:, mask]
waveform_accumulator_per_worker[worker_index, unit_index, :, : wf.shape[1]] += wf
if waveform_squared_accumulator_per_worker is not None:
waveform_squared_accumulator_per_worker[worker_index, unit_index, :, : wf.shape[1]] += wf**2
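In isolation, the masked accumulation added here behaves like the following sketch (toy shapes; the shared-memory and multi-worker machinery is stripped out):

```python
import numpy as np

nbefore, nafter, num_units, num_chans_dense = 2, 3, 2, 4
sparsity_mask = np.array(
    [
        [True, True, False, False],
        [False, True, True, True],
    ]
)
max_chans = int(max(np.sum(sparsity_mask, axis=1)))  # 3, as in the shared buffer

accumulator = np.zeros((num_units, nbefore + nafter, max_chans), dtype="float32")

# One dense waveform snippet for unit 0, as cut from the traces.
unit_index = 0
wf = np.random.default_rng(0).normal(size=(nbefore + nafter, num_chans_dense)).astype("float32")

# Keep only this unit's channels, then write left-aligned into the buffer;
# trailing columns stay zero for units sparser than max_chans.
wf = wf[:, sparsity_mask[unit_index, :]]
accumulator[unit_index, :, : wf.shape[1]] += wf
assert np.all(accumulator[unit_index, :, wf.shape[1] :] == 0)
```

This zero padding is exactly what the new test checks with `assert np.any(sparse_templates_array == 0)`.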