3
3
import numpy as np
4
4
5
5
from .basepreprocessor import BasePreprocessor , BasePreprocessorSegment
6
+ from .filter import fix_dtype
6
7
from ..core import order_channels_by_depth , get_chunk_with_margin
7
8
from ..core .core_tools import define_function_from_class
8
9
@@ -47,6 +48,8 @@ class HighpassSpatialFilterRecording(BasePreprocessor):
47
48
Order of spatial butterworth filter
48
49
highpass_butter_wn : float, default: 0.01
49
50
Critical frequency (with respect to Nyquist) of spatial butterworth filter
51
+ dtype : dtype, default: None
52
+ The dtype of the output traces. If None, the dtype is the same as the input traces
50
53
51
54
Returns
52
55
-------
@@ -73,6 +76,7 @@ def __init__(
73
76
agc_window_length_s = 0.1 ,
74
77
highpass_butter_order = 3 ,
75
78
highpass_butter_wn = 0.01 ,
79
+ dtype = None ,
76
80
):
77
81
BasePreprocessor .__init__ (self , recording )
78
82
@@ -117,6 +121,8 @@ def __init__(
117
121
butter_kwargs = dict (btype = "highpass" , N = highpass_butter_order , Wn = highpass_butter_wn )
118
122
sos_filter = scipy .signal .butter (** butter_kwargs , output = "sos" )
119
123
124
+ dtype = fix_dtype (recording , dtype )
125
+
120
126
for parent_segment in recording ._recording_segments :
121
127
rec_segment = HighPassSpatialFilterSegment (
122
128
parent_segment ,
@@ -128,6 +134,7 @@ def __init__(
128
134
sos_filter ,
129
135
order_f ,
130
136
order_r ,
137
+ dtype = dtype ,
131
138
)
132
139
self .add_recording_segment (rec_segment )
133
140
@@ -155,6 +162,7 @@ def __init__(
155
162
sos_filter ,
156
163
order_f ,
157
164
order_r ,
165
+ dtype ,
158
166
):
159
167
BasePreprocessorSegment .__init__ (self , parent_recording_segment )
160
168
self .parent_recording_segment = parent_recording_segment
@@ -178,6 +186,7 @@ def __init__(
178
186
self .order_r = order_r
179
187
# get filter params
180
188
self .sos_filter = sos_filter
189
+ self .dtype = dtype
181
190
182
191
def get_traces (self , start_frame , end_frame , channel_indices ):
183
192
if channel_indices is None :
@@ -234,7 +243,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
234
243
traces = traces [left_margin :- right_margin , channel_indices ]
235
244
else :
236
245
traces = traces [left_margin :, channel_indices ]
237
- return traces
246
+ return traces . astype ( self . dtype , copy = False )
238
247
239
248
240
249
# function for API
0 commit comments