11"""Robust detrending."""
22import numpy as np
33
4+ from scipy .signal import lfilter
5+
46from .utils import demean , mrdivide , pca , unfold
57from .utils .matrix import _check_weights
8+ from .utils .sig import stmcb
69
710
811def detrend (x , order , w = None , basis = 'polynomials' , threshold = 3 , n_iter = 4 ,
@@ -209,6 +212,66 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False):
209212 return b , z
210213
211214
215+ def reduce_ringing (X , samples , order = 10 , n_samples = 100 , extra = 50 , threshold = 3 ,
216+ show = False ):
217+ """Subtract filter impulse response from signal at given samples.
218+
219+ Parameters
220+ ----------
221+ X: ndarray, shape=(n_times, n_chans[, n_trials])
222+ Data containing ringing artifacts.
223+ samples : list of ints
224+ Sample indices where to find ringing artifacts.
225+ order : int
226+ Order of polynomial trend (default=10).
227+ n_samples = 100
228+ Number of samples over which to estimate impulse response
229+ (default=100).
230+ extra : int
231+ Samples before stimulus to anchor trend (default=50).
232+ threshold: float
233+ Threshold for robust detrending (default=3).
234+
235+ Returns
236+ -------
237+ y : ndarray, shape=(n_times, n_chans[, n_trials])
238+ Clean data.
239+
240+ """
241+ NNUM = 8
242+ NDEN = 8 # number of filter coeffs
243+
244+ # remove samples too close to beginning or end
245+ samples = samples [samples > extra ]
246+ samples = samples [samples < X .shape [0 ] - n_samples ]
247+
248+ y = X .copy ()
249+ for i , s in enumerate (samples ):
250+ for c in range (X .shape [1 ]):
251+ # select portion to fit filter response, remove polynomial trend
252+ response = X [s - extra :s + n_samples , c ]
253+ # response = detrend(response, order, threshold)
254+ response = response [extra :]
255+
256+ # estimate filter parameters - helps ensure stable filter
257+ response = np .r_ [(response , np .zeros (response .shape ))]
258+ [B , A ] = stmcb (response , q = NNUM , p = NDEN , niter = 20 )
259+
260+ # estimate filter response to event
261+ pulse = np .arange (n_samples ) < 1
262+ model = lfilter (B , A , pulse )
263+ idx = s + np .arange (model .shape [0 ])
264+ y [idx , c ] = X [idx , c ] - model
265+
266+ if show :
267+ w = np .zeros ((X .shape [0 ], X .shape [1 ]))
268+ for s in samples :
269+ w [s :s + n_samples , :] = 1
270+ _plot_detrend (X , y , w )
271+
272+ return y
273+
274+
212275def _plot_detrend (x , y , w ):
213276 """Plot detrending results."""
214277 import matplotlib .pyplot as plt
@@ -228,8 +291,9 @@ def _plot_detrend(x, y, w):
228291 ax1 .legend ()
229292
230293 ax2 = f .add_subplot (gs [3 , 0 ])
231- ax2 .imshow (w .T , aspect = 'auto' , cmap = 'Greys' )
232- ax2 .set_yticks (np .arange (1 , n_chans + 1 , 1 ))
294+ ax2 .pcolormesh (w .T , cmap = 'Greys' )
295+ ax2 .set_yticks (np .arange (0 , n_chans ) + 0.5 )
296+ ax2 .set_yticklabels (['ch{}' .format (i ) for i in np .arange (n_chans )])
233297 ax2 .set_xlim (0 , n_times )
234298 ax2 .set_ylabel ('ch. weights' )
235299 ax2 .set_xlabel ('samples' )
0 commit comments