Skip to content

Add PreprocessingPipeline #3438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
46 changes: 46 additions & 0 deletions doc/modules/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,52 @@ CMR, and save it to a binary file in the "/path/to/preprocessed" folder. The :co

**NOTE:** all sorters will automatically perform the saving operation internally.

The Preprocessing Pipeline
--------------------------

The module also contains the :code:`PreprocessingPipeline` object which aims to allow users to easily share pipelines across
labs. The input to create the pipeline is a dictionary of preprocessing steps whose keys are the names of the steps
and values are dictionaries of parameters. For example, to construct a pipeline consisting of highpass filtering
with a minimum frequency of 250 Hz followed by whitening with default parameters, we first make the appropriate dictionary

.. code-block:: python

from spikeinterface.preprocessing import apply_pipeline, PreprocessingPipeline

preprocessing_dict = {
'highpass_filter': {'freq_min': 250},
'whiten': {}
}

We can then pass this dictionary to the :code:`apply_pipeline` function to make a preprocessed recording

.. code-block:: python

preprocessed_recording = apply_pipeline(recording, preprocessing_dict)

Alternatively, we can construct a :code:`PreprocessingPipeline`, allowing us to investigate the pipeline before
using it.

.. code-block:: python

    preprocessing_pipeline = PreprocessingPipeline(preprocessing_dict)
# to view the pipeline:
preprocessing_pipeline

Once we have the pipeline, we can apply it to a recording in the same way as applying the dictionary

.. code-block:: python

preprocessed_recording_again = apply_pipeline(recording, preprocessing_pipeline)

To share the pipeline you have made with another lab, you can simply share the dictionary. The dictionary
can also be obtained from the pipeline object directly:

.. code-block:: python

dict_used_to_make_pipeline = preprocessing_pipeline.preprocessor_dict


Impact on recording dtype
-------------------------

Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .detect_bad_channels import detect_bad_channels
from .correct_lsb import correct_lsb

from .pipeline import apply_pipeline, PreprocessingPipeline

# for snippets
from .align_snippets import AlignSnippets
264 changes: 264 additions & 0 deletions src/spikeinterface/preprocessing/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
from __future__ import annotations

import json
import inspect
from spikeinterface.core.core_tools import is_dict_extractor
from spikeinterface.core import BaseRecording
from spikeinterface.preprocessing.preprocessinglist import preprocessor_dict, _all_preprocesser_dict

# Map preprocessing function names (e.g. "bandpass_filter") -> the preprocessing
# functions themselves, used to validate and apply pipeline steps by name.
pp_names_to_functions = {preprocessor.__name__: preprocessor for preprocessor in preprocessor_dict.values()}
# Map preprocessing function names -> the corresponding preprocessor classes, used
# to look up class-level attributes such as `_precomputable_kwarg_names`.
pp_names_to_classes = {pp_function.__name__: pp_class for pp_class, pp_function in _all_preprocesser_dict.items()}


class PreprocessingPipeline:
    """
    A preprocessing pipeline, containing ordered preprocessing steps.

    Parameters
    ----------
    preprocessor_dict : dict
        Dictionary whose keys are preprocessing step names (e.g. "bandpass_filter")
        and whose values are dicts of kwargs for that step.

    Raises
    ------
    TypeError
        If any key of `preprocessor_dict` is not a supported preprocessing step.

    Examples
    --------
    Generate a `PreprocessingPipeline` containing a `bandpass_filter` then a
    `common_reference` step. Then apply this to a recording

    >>> from spikeinterface.preprocessing import PreprocessingPipeline
    >>> preprocessor_dict = {'bandpass_filter': {'freq_max': 3000}, 'common_reference': {}}
    >>> my_pipeline = PreprocessingPipeline(preprocessor_dict)
    PreprocessingPipeline: Raw Recording → bandpass_filter → common_reference → Preprocessed Recording
    >>> my_pipeline.apply(recording)

    """

    def __init__(self, preprocessor_dict):
        # Validate every step name up front so a bad pipeline fails at
        # construction time rather than midway through `apply`.
        for preprocessor in preprocessor_dict:
            if preprocessor not in pp_names_to_functions:
                # Note: plain string concatenation here. The original used
                # backslash-continuations inside the f-string, which embedded
                # long runs of indentation whitespace into the error message.
                raise TypeError(
                    f"'{preprocessor}' is not supported by the `PreprocessingPipeline`. "
                    "To see the list of supported steps, run:\n"
                    "\t>>> from spikeinterface.preprocessing import preprocessor_dict\n"
                    "\t>>> print(preprocessor_dict.keys())"
                )

        self.preprocessor_dict = preprocessor_dict

    def __repr__(self):
        """Plain-text arrow diagram of the pipeline: Raw Recording → steps → Preprocessed Recording."""
        txt = "PreprocessingPipeline: \tRaw Recording \u2192 "
        for preprocessor in self.preprocessor_dict:
            txt += str(preprocessor) + " \u2192 "
        txt += "Preprocessed Recording"
        return txt

    def _repr_html_(self):
        """HTML rendering of the pipeline (used by e.g. Jupyter notebooks), with one
        collapsible <details> section per step listing its kwargs (including defaults)."""
        all_kwargs = _get_all_kwargs_and_values(self)

        # Original opening tag was malformed ("<div'>"); fixed here.
        html_text = "<div>"
        html_text += "<strong>PreprocessingPipeline</strong>"
        html_text += "<div style='border:1px solid #ccc; padding:10px;'><strong>Initial Recording</strong></div>"
        html_text += "<div style='margin: auto; text-indent: 30px;'>&#x2193;</div>"

        for preprocessor, kwargs in all_kwargs.items():
            html_text += "<details style='border:1px solid #ddd; padding:5px;'>"
            html_text += f"<summary><strong>{preprocessor}</strong></summary>"

            html_text += "<ul>"
            for kwarg, value in kwargs.items():
                html_text += f"<li><strong>{kwarg}</strong>: {value}</li>"
            html_text += "</ul>"
            html_text += "</details>"

            html_text += "<div style='margin: auto; text-indent: 30px;'>&#x2193;</div>"
        html_text += "<div style='border:1px solid #ccc; padding:10px;'><strong>Preprocessed Recording</strong></div>"
        html_text += "</div>"

        return html_text

    def apply(self, recording, ignore_precomputed_kwargs=True):
        """
        Creates a preprocessed recording by applying the `PreprocessingPipeline` to
        `recording`.

        Parameters
        ----------
        recording : RecordingExtractor
            The initial recording
        ignore_precomputed_kwargs : Bool
            Some preprocessing steps (e.g. Whitening) contain arguments which are computed
            during preprocessing. If True, we ignore these precomputed steps. If False, we
            compute when we apply the preprocessors.

        Returns
        -------
        preprocessed_recording : RecordingExtractor
            Preprocessed recording

        """

        for preprocessor_name, kwargs in self.preprocessor_dict.items():

            # The recording is supplied positionally below, so never forward
            # recording-like kwargs that may appear in provenance dicts.
            dont_include_kwargs = ["recording", "parent_recording"]

            if ignore_precomputed_kwargs:
                # Drop kwargs that were computed during a previous run (e.g. whitening
                # matrices) so they are recomputed for the new recording.
                preprocessor_class = pp_names_to_classes[preprocessor_name]
                precomputable_kwarg_names = preprocessor_class._precomputable_kwarg_names
                dont_include_kwargs += precomputable_kwarg_names

            non_rec_kwargs = {key: value for key, value in kwargs.items() if key not in dont_include_kwargs}
            # Feed each step's output into the next step.
            recording = pp_names_to_functions[preprocessor_name](recording, **non_rec_kwargs)

        return recording


def apply_pipeline(
    recording: BaseRecording,
    pipeline_or_dict: dict | PreprocessingPipeline | None = None,
    ignore_precomputed_kwargs=True,
):
    """
    Creates a preprocessed recording by applying the preprocessing steps in
    `pipeline_or_dict` to `recording`.

    Parameters
    ----------
    recording : RecordingExtractor
        The initial recording
    pipeline_or_dict : dict | PreprocessingPipeline | None, default: None
        Dictionary containing preprocessing steps and their kwargs, or a pipeline object.
        If None (or empty), the original recording is returned unchanged.
    ignore_precomputed_kwargs : Bool
        Some preprocessing steps (e.g. Whitening) contain arguments which are computed
        during preprocessing. If True, we ignore these precomputed steps. If False, we
        compute when we apply the preprocessors.

    Returns
    -------
    preprocessed_recording : RecordingExtractor
        Preprocessed recording

    Examples
    --------
    Create a preprocessed recording from a generated recording and a preprocessor_dict

    >>> from spikeinterface.preprocessing import apply_pipeline
    >>> from spikeinterface.generation import generate_recording
    >>> recording = generate_recording()
    >>> preprocessor_dict = {'bandpass_filter': {'freq_max': 3000}, 'common_reference': {}}
    >>> preprocessed_recording = apply_pipeline(recording, preprocessor_dict)
    """

    if isinstance(pipeline_or_dict, PreprocessingPipeline):
        pipeline = pipeline_or_dict
    else:
        # None (the default, replacing the original mutable `{}` default) means
        # "no preprocessing": the empty pipeline returns the recording unchanged.
        pipeline = PreprocessingPipeline(pipeline_or_dict or {})

    preprocessed_recording = pipeline.apply(recording, ignore_precomputed_kwargs)
    return preprocessed_recording


def get_preprocessing_dict_from_json(recording_json_path):
    """
    Generates a preprocessing dict, passable to the `apply_pipeline` function and
    `PreprocessingPipeline` class, from a `recording.json` provenance file.

    Only extracts preprocessing steps which can be applied "globally" to any recording.
    Hence this does not extract `ChannelSlice` and `FrameSlice` steps. To see the
    supported list of preprocessors run
    >>> from spikeinterface.preprocessing import pp_function_to_class
    >>> print(pp_function_to_class.keys())


    Parameters
    ----------
    recording_json_path : str or Path
        Path to the `recording.json` file

    Returns
    -------
    preprocessor_dict : dict
        Dictionary containing preprocessing steps and their kwargs, in the order
        the steps were applied

    """
    # Use a context manager so the file handle is closed deterministically
    # (the original `json.load(open(...))` left it open until GC).
    with open(recording_json_path) as json_file:
        recording_json = json.load(json_file)

    pp_from_json = {}
    _load_pp_from_dict(recording_json, pp_from_json)

    # `_load_pp_from_dict` inserts the innermost (earliest-applied) extractor first
    # and the outermost (last-applied) one last, so insertion order is already the
    # application order. The original iterated `reversed(pp_from_json)`, which
    # produced the pipeline steps in reverse application order.
    pipeline_dict = {}
    for preprocessor in pp_from_json:

        # Provenance "class" entries are dotted paths; keep only the class name.
        preprocessor_class_name = preprocessor.split(".")[-1]

        preprocessor_function = preprocessor_dict.get(preprocessor_class_name)
        if preprocessor_function is None:
            # Not a globally-applicable preprocessor (e.g. ChannelSlice, or the
            # raw recording itself) - skip it.
            continue

        pp_kwargs = {
            key: value
            for key, value in pp_from_json[preprocessor].items()
            if key not in ["recording", "parent_recording"]
        }

        pipeline_dict[preprocessor_function.__name__] = pp_kwargs

    return pipeline_dict


def _load_pp_from_dict(prov_dict, kwargs_dict):
"""
Recursive function used to iterate through recording provenance dictionary, and
extract preprocessing steps and their kwargs. Based on `_load_extractor_from_dict`
from spikeinterface.core.base.
"""
new_kwargs = dict()
transform_dict_to_extractor = lambda x: _load_pp_from_dict(x) if is_dict_extractor(x) else x
for name, value in prov_dict["kwargs"].items():
if is_dict_extractor(value):
new_kwargs[name] = _load_pp_from_dict(value, kwargs_dict)
elif isinstance(value, dict):
new_kwargs[name] = {k: transform_dict_to_extractor(v) for k, v in value.items()}
elif isinstance(value, list):
new_kwargs[name] = [transform_dict_to_extractor(e) for e in value]
else:
new_kwargs[name] = value

kwargs_dict[prov_dict["class"]] = new_kwargs
return new_kwargs


def _get_all_kwargs_and_values(my_pipeline):
    """
    Get all keyword arguments and their values from a pipeline,
    including the default values.

    Parameters
    ----------
    my_pipeline : PreprocessingPipeline
        Pipeline whose `preprocessor_dict` is inspected.

    Returns
    -------
    all_kwargs : dict
        Mapping of step name -> {kwarg name -> value}, where values the user did
        not set fall back to the preprocessing function's signature default
        (None if the parameter has no default).
    """

    all_kwargs = {}
    for preprocessor in my_pipeline.preprocessor_dict:

        preprocessor_name = preprocessor.split(".")[-1]
        pp_function = pp_names_to_functions[preprocessor_name]
        signature = inspect.signature(pp_function)

        step_kwargs = {}
        for par_name, parameter in signature.parameters.items():
            # The recording is supplied at apply-time, not a pipeline kwarg.
            if par_name == "recording":
                continue

            pipeline_value = my_pipeline.preprocessor_dict[preprocessor].get(par_name)

            # Fall back to the signature default when the user did not set the
            # kwarg. Uses `Parameter.default` directly rather than the original's
            # string-parsing of `str(value).split("=")`, which broke on defaults
            # or annotations containing "=".
            if pipeline_value is None and parameter.default is not inspect.Parameter.empty:
                pipeline_value = parameter.default

            step_kwargs[par_name] = pipeline_value

        all_kwargs[preprocessor_name] = step_kwargs

    return all_kwargs
Loading
Loading