Skip to content

Commit 89968ca

Browse files
authored
Merge pull request #2864 from alejoe91/highpass-spatial-dtype
Fix highpass-spatial-filter return dtype
2 parents 8317eb5 + 8b622cc commit 89968ca

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

src/spikeinterface/preprocessing/highpass_spatial_filter.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
6+
from .filter import fix_dtype
67
from ..core import order_channels_by_depth, get_chunk_with_margin
78
from ..core.core_tools import define_function_from_class
89

@@ -47,6 +48,8 @@ class HighpassSpatialFilterRecording(BasePreprocessor):
4748
Order of spatial butterworth filter
4849
highpass_butter_wn : float, default: 0.01
4950
Critical frequency (with respect to Nyquist) of spatial butterworth filter
51+
dtype : dtype, default: None
52+
The dtype of the output traces. If None, the dtype is the same as the input traces
5053
5154
Returns
5255
-------
@@ -73,6 +76,7 @@ def __init__(
7376
agc_window_length_s=0.1,
7477
highpass_butter_order=3,
7578
highpass_butter_wn=0.01,
79+
dtype=None,
7680
):
7781
BasePreprocessor.__init__(self, recording)
7882

@@ -117,6 +121,8 @@ def __init__(
117121
butter_kwargs = dict(btype="highpass", N=highpass_butter_order, Wn=highpass_butter_wn)
118122
sos_filter = scipy.signal.butter(**butter_kwargs, output="sos")
119123

124+
dtype = fix_dtype(recording, dtype)
125+
120126
for parent_segment in recording._recording_segments:
121127
rec_segment = HighPassSpatialFilterSegment(
122128
parent_segment,
@@ -128,6 +134,7 @@ def __init__(
128134
sos_filter,
129135
order_f,
130136
order_r,
137+
dtype=dtype,
131138
)
132139
self.add_recording_segment(rec_segment)
133140

@@ -155,6 +162,7 @@ def __init__(
155162
sos_filter,
156163
order_f,
157164
order_r,
165+
dtype,
158166
):
159167
BasePreprocessorSegment.__init__(self, parent_recording_segment)
160168
self.parent_recording_segment = parent_recording_segment
@@ -178,6 +186,7 @@ def __init__(
178186
self.order_r = order_r
179187
# get filter params
180188
self.sos_filter = sos_filter
189+
self.dtype = dtype
181190

182191
def get_traces(self, start_frame, end_frame, channel_indices):
183192
if channel_indices is None:
@@ -234,7 +243,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
234243
traces = traces[left_margin:-right_margin, channel_indices]
235244
else:
236245
traces = traces[left_margin:, channel_indices]
237-
return traces
246+
return traces.astype(self.dtype, copy=False)
238247

239248

240249
# function for API

src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,34 @@ def test_highpass_spatial_filter_synthetic_data(num_channels, ntr_pad, ntr_tap,
109109
assert raw_traces.shape == si_filtered.shape
110110

111111

112+
@pytest.mark.parametrize("dtype", [np.int16, np.float32, np.float64])
113+
def test_dtype_stability(dtype):
114+
"""
115+
Check that the dtype of the recording and
116+
output data is as expected, as data is cast to float32
117+
during filtering.
118+
"""
119+
num_chan = 32
120+
si_recording = generate_recording(num_channels=num_chan, durations=[2])
121+
si_recording.set_property("gain_to_uV", np.ones(num_chan))
122+
si_recording.set_property("offset_to_uV", np.ones(num_chan))
123+
si_recording = spre.astype(si_recording, dtype)
124+
125+
assert si_recording.dtype == dtype
126+
127+
highpass_spatial_filter = spre.highpass_spatial_filter(si_recording, n_channel_pad=2)
128+
129+
assert highpass_spatial_filter.dtype == dtype
130+
131+
filtered_data_unscaled = highpass_spatial_filter.get_traces(return_scaled=False)
132+
133+
assert filtered_data_unscaled.dtype == dtype
134+
135+
filtered_data_scaled = highpass_spatial_filter.get_traces(return_scaled=True)
136+
137+
assert filtered_data_scaled.dtype == np.float32
138+
139+
112140
# ----------------------------------------------------------------------------------------------------------------------
113141
# Test Utils
114142
# ----------------------------------------------------------------------------------------------------------------------
@@ -125,7 +153,7 @@ def get_ibl_si_data():
125153
ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel
126154

127155
si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
128-
si_recording = spre.scale(si_recording, dtype="float32")
156+
si_recording = spre.astype(si_recording, dtype="float32")
129157

130158
return ibl_data, si_recording
131159

0 commit comments

Comments
 (0)