Skip to content

blank_staturation with a window around saturation signal #1541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/spikeinterface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .core import *

import warnings

warnings.filterwarnings("ignore", message="distutils Version classes are deprecated")
warnings.filterwarnings("ignore", message="the imp module is deprecated")

Expand Down
111 changes: 96 additions & 15 deletions src/spikeinterface/preprocessing/clip.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import numpy as np

try:
from numba import njit
HAVE_NUMBA = True
except ModuleNotFoundError as err:
HAVE_NUMBA = False

from spikeinterface.core.core_tools import define_function_from_class
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment

Expand Down Expand Up @@ -70,6 +76,10 @@ class BlankSaturationRecording(BasePreprocessor):
fill_value: float or None
The value to write instead of the saturating signal. If None, then the value is
automatically computed as the median signal value
ms_before: float (default 0)
Time (ms) to replace before the saturation signal
ms_after: float (default 0)
Time (ms) to replace after the saturation signal
num_chunks_per_segment: int (default 50)
The number of chunks per segments to consider to estimate the threshold/fill_values
chunk_size: int (default 500)
Expand All @@ -83,8 +93,14 @@ class BlankSaturationRecording(BasePreprocessor):
The filtered traces recording extractor object

"""
name = 'blank_staturation'

def __init__(self, recording, abs_threshold=None, quantile_threshold=None,
direction='upper', fill_value=None,
ms_before=0, ms_after=0,
num_chunks_per_segment=50, chunk_size=500, seed=0):


name = "blank_staturation"

def __init__(
self,
Expand Down Expand Up @@ -135,41 +151,106 @@ def __init__(

BasePreprocessor.__init__(self, recording)
for parent_segment in recording._recording_segments:
rec_segment = ClipRecordingSegment(parent_segment, a_min, value_min, a_max, value_max)
rec_segment = ClipRecordingSegment(
parent_segment, a_min, value_min, a_max, value_max,
ms_before=ms_before, ms_after=ms_after
)
self.add_recording_segment(rec_segment)

self._kwargs = dict(
recording=recording,
abs_threshold=abs_threshold,
quantile_threshold=quantile_threshold,
direction=direction,
fill_value=fill_value,
num_chunks_per_segment=num_chunks_per_segment,
chunk_size=chunk_size,
seed=seed,
)
self._kwargs = dict(recording=recording, abs_threshold=abs_threshold, ms_before=ms_before, ms_after=ms_after,
quantile_threshold=quantile_threshold, direction=direction, fill_value=fill_value,
num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size,
seed=seed)



class ClipRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, a_min, value_min, a_max, value_max):
def __init__(self, parent_recording_segment, a_min, value_min, a_max, value_max,
ms_before=0, ms_after=0):
BasePreprocessorSegment.__init__(self, parent_recording_segment)

self.a_min = a_min
self.value_min = value_min
self.a_max = a_max
self.value_max = value_max
self.ms_before = ms_before
self.ms_after = ms_after


def get_traces(self, start_frame, end_frame, channel_indices):
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
traces = traces.copy()
fs = self.parent_recording_segment.sampling_frequency

frames_before = int(self.ms_before * fs // 1000)
frames_after = int(self.ms_after * fs // 1000)

if self.a_min is not None:
traces[traces <= self.a_min] = self.value_min
traces = replace_slice_min(traces, self.a_min, frames_before, frames_after, self.value_min)

if self.a_max is not None:
traces[traces >= self.a_max] = self.value_max
traces = replace_slice_max(traces, self.a_max, frames_before, frames_after, self.value_max)

return traces

def replace_slice_min(traces, a_min, frames_before, frames_after, value_min):
if HAVE_NUMBA:
return _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min)
else:
return _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min)

def replace_slice_max(traces, a_max, frames_before, frames_after, value_max):
if HAVE_NUMBA:
return _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max)
else:
return _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, value_max)

# For loops
def _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min):
min_indices, channels = np.where(traces <= a_min)
for index, chan in zip(min_indices, channels):
traces[max(0, index - frames_before):min(len(traces), index + frames_after + 1), chan] = value_min
return traces

def _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, value_max):
max_indices, channels = np.where(traces >= a_max)
for index, chan in zip(max_indices, channels):
traces[max(0, index - frames_before):min(len(traces), index + frames_after + 1), chan] = value_max
return traces

if HAVE_NUMBA:
# Numba
@njit(cache=True)
def _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max):
m, n = traces.shape
to_clear = np.zeros(m, dtype=np.bool_)
for j in range(n):
to_clear[:] = False
for i in range(m):
if traces[i, j] >= a_max:
to_clear[
max(0, i - frames_before) : min(m, i + frames_after + 1)
] = True
for i in range(m):
if to_clear[i]:
traces[i, j] = value_max
return traces

@njit(cache=True)
def _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min):
m, n = traces.shape
to_clear = np.zeros(m, dtype=np.bool_)
for j in range(n):
to_clear[:] = False
for i in range(m):
if traces[i, j] <= a_min:
to_clear[
max(0, i - frames_before) : min(m, i + frames_after + 1)
] = True
for i in range(m):
if to_clear[i]:
traces[i, j] = value_min
return traces

clip = define_function_from_class(source_class=ClipRecording, name="clip")
blank_staturation = define_function_from_class(source_class=BlankSaturationRecording, name="blank_staturation")