134 changes: 133 additions & 1 deletion src/spikeinterface/core/analyzer_extension_core.py
@@ -11,12 +11,14 @@

import warnings
import numpy as np
from collections import namedtuple

from .sortinganalyzer import AnalyzerExtension, register_result_extension
from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension
from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
from .recording_tools import get_noise_levels
from .template import Templates
from .sorting_tools import random_spikes_selection
from .job_tools import fix_job_kwargs, split_job_kwargs


class ComputeRandomSpikes(AnalyzerExtension):
@@ -806,3 +808,133 @@ def _handle_backward_compatibility_on_load(self):

register_result_extension(ComputeNoiseLevels)
compute_noise_levels = ComputeNoiseLevels.function_factory()


class BaseSpikeVectorExtension(AnalyzerExtension):
"""
Base class for spike-vector-based extensions, where the data is a numpy array with the same
length as the spike vector.
"""

extension_name = None # to be defined in subclass
need_recording = True
use_nodepipeline = True
need_job_kwargs = True
need_backward_compatibility_on_load = False
nodepipeline_variables = [] # to be defined in subclass

def _set_params(self, **kwargs):
params = kwargs.copy()
return params

def _run(self, verbose=False, **job_kwargs):
from spikeinterface.core.node_pipeline import run_node_pipeline

job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
data = run_node_pipeline(
self.sorting_analyzer.recording,
nodes,
job_kwargs=job_kwargs,
job_name=self.extension_name,
gather_mode="memory",
verbose=verbose,
)
if isinstance(data, tuple):
# this logic enables extensions to optionally compute additional data based on params
assert len(data) <= len(self.nodepipeline_variables), "Pipeline produced more outputs than expected"
else:
data = (data,)
if len(self.nodepipeline_variables) > len(data):
data_names = self.nodepipeline_variables[: len(data)]
else:
data_names = self.nodepipeline_variables
for d, name in zip(data, data_names):
self.data[name] = d

def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None):
"""
Return extension data. If the extension computes more than one `nodepipeline_variables`,
the `return_data_name` is used to specify which one to return.

Parameters
----------
outputs : "numpy" | "by_unit", default: "numpy"
How to return the data: as a single numpy array ("numpy") or grouped by segment and unit ("by_unit").
concatenated : bool, default: False
Whether to concatenate the data across segments.
return_data_name : str | None, default: None
The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
the first one is returned.

Returns
-------
numpy.ndarray | dict
The extension data. With `outputs="numpy"`, a numpy array with one value per spike, aligned with
the spike vector. With `outputs="by_unit"`, a dict mapping segment indices to dicts of unit ids to
arrays (or a single dict of unit ids to arrays if `concatenated=True`).
"""
from spikeinterface.core.sorting_tools import spike_vector_to_indices

if len(self.nodepipeline_variables) == 1:
return_data_name = self.nodepipeline_variables[0]
else:
if return_data_name is None:
return_data_name = self.nodepipeline_variables[0]
else:
assert (
return_data_name in self.nodepipeline_variables
), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"

all_data = self.data[return_data_name]
if outputs == "numpy":
return all_data
elif outputs == "by_unit":
unit_ids = self.sorting_analyzer.unit_ids
spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
data_by_units = {}
for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
data_by_units[segment_index] = {}
for unit_id in unit_ids:
inds = spike_indices[segment_index][unit_id]
data_by_units[segment_index][unit_id] = all_data[inds]

if concatenated:
data_by_units_concatenated = {
unit_id: np.concatenate([data_in_segment[unit_id] for data_in_segment in data_by_units.values()])
for unit_id in unit_ids
}
return data_by_units_concatenated

return data_by_units
else:
raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`")

def _select_extension_data(self, unit_ids):
keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))

spikes = self.sorting_analyzer.sorting.to_spike_vector()
keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices)

new_data = dict()
for data_name in self.nodepipeline_variables:
if self.data.get(data_name) is not None:
new_data[data_name] = self.data[data_name][keep_spike_mask]

return new_data

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
new_data = dict()
for data_name in self.nodepipeline_variables:
if self.data.get(data_name) is not None:
if keep_mask is None:
new_data[data_name] = self.data[data_name].copy()
else:
new_data[data_name] = self.data[data_name][keep_mask]

return new_data

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
# splitting units does not change the spike-vector-aligned data, so it can be returned as-is
return self.data.copy()
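
A minimal, self-contained sketch (not part of this diff) of the "by_unit" regrouping that `BaseSpikeVectorExtension._get_data` performs: per-spike values stored alongside the spike vector are regrouped per segment and per unit via the unit indices. The toy `unit_index_per_spike` and `per_spike_values` arrays below are invented purely for illustration.

import numpy as np

unit_ids = ["u0", "u1"]
# toy spike vector for a single segment: one unit index per spike
unit_index_per_spike = np.array([0, 1, 0, 1, 1])
# per-spike values, e.g. amplitude scalings, aligned with the spike vector
per_spike_values = np.array([0.9, 1.1, 1.0, 0.8, 1.2])

# segment_index -> {unit_id: values}, as returned by get_data(outputs="by_unit")
data_by_units = {0: {}}
for unit_index, unit_id in enumerate(unit_ids):
    inds = np.flatnonzero(unit_index_per_spike == unit_index)
    data_by_units[0][unit_id] = per_spike_values[inds]

print(data_by_units)
# {0: {'u0': array([0.9, 1. ]), 'u1': array([1.1, 0.8, 1.2])}}
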
95 changes: 8 additions & 87 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -3,18 +3,14 @@
import numpy as np

from spikeinterface.core import ChannelSparsity
from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs
from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore
from spikeinterface.core.sortinganalyzer import register_result_extension
from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension

from spikeinterface.core.template_tools import get_template_extremum_channel
from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type

from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension

from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type

from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore


class ComputeAmplitudeScalings(AnalyzerExtension):
class ComputeAmplitudeScalings(BaseSpikeVectorExtension):
"""
Computes the amplitude scalings from a SortingAnalyzer.

@@ -55,31 +51,11 @@ class ComputeAmplitudeScalings(AnalyzerExtension):
multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently.
delta_collision_ms: float, default: 2
The maximum time difference in ms before and after a spike to gather colliding spikes.
load_if_exists : bool, default: False
Whether to load precomputed spike amplitudes, if they already exist.
outputs: "concatenated" | "by_unit", default: "concatenated"
How the output should be returned
{}

Returns
-------
amplitude_scalings: np.array or list of dict
The amplitude scalings.
- If "concatenated" all amplitudes for all spikes and all units are concatenated
- If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units)
"""

extension_name = "amplitude_scalings"
depend_on = ["templates"]
need_recording = True
use_nodepipeline = True
nodepipeline_variables = ["amplitude_scalings", "collision_mask"]
need_job_kwargs = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

self.collisions = None

def _set_params(
self,
@@ -90,46 +66,14 @@ def _set_params(
handle_collisions=True,
delta_collision_ms=2,
):
params = dict(
return super()._set_params(
sparsity=sparsity,
max_dense_channels=max_dense_channels,
ms_before=ms_before,
ms_after=ms_after,
handle_collisions=handle_collisions,
delta_collision_ms=delta_collision_ms,
)
return params

def _select_extension_data(self, unit_ids):
keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))

spikes = self.sorting_analyzer.sorting.to_spike_vector()
keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices)

new_data = dict()
new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_spike_mask]
if self.params["handle_collisions"]:
new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask]
return new_data

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
new_data = dict()

if keep_mask is None:
new_data["amplitude_scalings"] = self.data["amplitude_scalings"].copy()
if self.params["handle_collisions"]:
new_data["collision_mask"] = self.data["collision_mask"].copy()
else:
new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask]
if self.params["handle_collisions"]:
new_data["collision_mask"] = self.data["collision_mask"][keep_mask]

return new_data

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
return self.data.copy()

def _get_pipeline_nodes(self):

@@ -141,6 +85,7 @@ def _get_pipeline_nodes(self):
all_templates = get_dense_templates_array(self.sorting_analyzer, return_in_uV=return_in_uV)
nbefore = _get_nbefore(self.sorting_analyzer)
nafter = all_templates.shape[1] - nbefore
templates_ext = self.sorting_analyzer.get_extension("templates")

# if ms_before / ms_after are set in params then the original templates are shorten
if self.params["ms_before"] is not None:
@@ -155,7 +100,7 @@ def _get_pipeline_nodes(self):
cut_out_after = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
assert (
cut_out_after <= nafter
), f"`ms_after` must be smaller than `ms_after` used in WaveformExractor: {we._params['ms_after']}"
), f"`ms_after` must be smaller than `ms_after` used in templates: {templates_ext.params['ms_after']}"
else:
cut_out_after = nafter

@@ -210,30 +155,6 @@ def _get_pipeline_nodes(self):
nodes = [spike_retriever_node, amplitude_scalings_node]
return nodes

def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
amp_scalings, collision_mask = run_node_pipeline(
self.sorting_analyzer.recording,
nodes,
job_kwargs=job_kwargs,
job_name="amplitude_scalings",
gather_mode="memory",
verbose=verbose,
)
self.data["amplitude_scalings"] = amp_scalings
if self.params["handle_collisions"]:
self.data["collision_mask"] = collision_mask
# TODO: make collisions "global"
# for collision in collisions:
# collisions_dict.update(collision)
# self.collisions = collisions_dict
# # Note: collisions are note in _extension_data because they are not pickable. We only store the indices
# self._extension_data["collisions"] = np.array(list(collisions_dict.keys()))

def _get_data(self):
return self.data[f"amplitude_scalings"]


register_result_extension(ComputeAmplitudeScalings)
compute_amplitude_scalings = ComputeAmplitudeScalings.function_factory()
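
A hedged usage sketch (not part of the diff) of the refactored extension, assuming a spikeinterface installation with this change applied; the synthetic recording/sorting come from `generate_ground_truth_recording` and are used only for illustration.

import spikeinterface.full as si

recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
analyzer = si.create_sorting_analyzer(sorting, recording)
analyzer.compute(["random_spikes", "waveforms", "templates"])

ext = analyzer.compute("amplitude_scalings", handle_collisions=True)

# one scaling per spike, aligned with the spike vector
scalings = ext.get_data()

# grouped per segment and per unit, or per unit only when concatenated=True
scalings_by_unit = ext.get_data(outputs="by_unit", concatenated=True)

# the second nodepipeline variable can be requested explicitly
collision_mask = ext.get_data(return_data_name="collision_mask")
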