11"""Robust detrending."""
22import numpy as np
33
4+ from scipy .linalg import pinv , lstsq , solve
5+
46from .utils import demean , pca , unfold
57from .utils .matrix import _check_weights
68
79
def detrend(x, order, w=None, basis='polynomials', threshold=3, n_iter=4):
    """Robustly remove trend.

    The data are fit to the basis using weighted least squares. After each
    fit, samples whose (weighted) residual exceeds `threshold` times the
    standard deviation of the residuals are given a weight of zero, and the
    fit is run again, for `n_iter` passes in total.

    The choice of order (and basis) determines what complexity of the trend
    that can be removed. It may be useful to first detrend with a low order
    to avoid fitting outliers, and then increase the order.

    Parameters
    ----------
    x : array, shape=(n_times, n_channels[, n_trials])
        Raw data.
    order : int
        Order of polynomial or number of sin/cosine pairs. Ignored when
        `basis` is an ndarray.
    w : array | None
        Initial weights; passed through `_check_weights(w, x)` before use.
    basis : {'polynomials', 'sinusoids'} | ndarray | None
        Basis for regression. An ndarray is used as-is as the regressor
        matrix; None is treated as 'polynomials'.
    threshold : int | float
        Threshold for outliers, in number of standard deviations (default=3).
        Must be non-zero.
    n_iter : int
        Number of fit/reweight iterations (default=4). Must be at least 1.

    Returns
    -------
    y : array
        Detrended data, reshaped to the original shape of `x`.
    w : array
        Updated weights, reshaped to the original shape of `x`.
    r : array, shape=(n_times, n_regressors)
        Basis matrix used.

    Raises
    ------
    ValueError
        If `threshold` is 0, `n_iter` is smaller than 1, or `basis` is not
        one of the supported values.

    Examples
    --------
    Fit linear trend, ignoring samples > 3*sd from it, and remove:
    >> y, w, r = detrend(x, 1)

    Fit/remove polynomial order=5 with initial weighting w, threshold = 4*sd:
    >> y, w, r = detrend(x, 5, w, threshold=4)

    Fit/remove linear then 3rd order polynomial:
    >> y, w, _ = detrend(x, 1)
    >> yy, ww, _ = detrend(y, 3, w)

    """
    if threshold == 0:
        raise ValueError('thresh=0 is not what you want...')
    if n_iter < 1:
        # Without at least one pass there is no fit to subtract.
        raise ValueError('n_iter must be >= 1')

    # check/fix sizes
    dims = x.shape
    w = _check_weights(w, x)
    x = unfold(x)
    w = unfold(w)
    n_times, _ = x.shape

    # regressors
    if isinstance(basis, np.ndarray):
        r = basis
    else:
        lin = np.linspace(-1, 1, n_times)
        if basis == 'polynomials' or basis is None:
            # One column per power: lin, lin**2, ..., lin**order
            r = np.zeros((n_times, order))
            for i in range(order):
                r[:, i] = lin ** (i + 1)
        elif basis == 'sinusoids':
            # One (sin, cos) column pair per harmonic
            r = np.zeros((n_times, order * 2))
            for i in range(order):
                phase = 2 * np.pi * (i + 1) * lin / 2
                r[:, 2 * i] = np.sin(phase)
                r[:, 2 * i + 1] = np.cos(phase)
        else:
            raise ValueError(
                "basis must be 'polynomials', 'sinusoids' or an ndarray")

    # iteratively remove trends
    # the tricky bit is to ensure that weighted means are removed before
    # calculating the regression (see regress()).
    for _ in range(n_iter):
        # weighted regression on basis
        _, y = regress(x, r, w)

        # find outliers: residuals are weighted so that samples already
        # excluded do not contribute to the std estimate's numerator
        d = x - y
        if w.any():
            d = d * w
        ww = np.ones_like(x)
        # std is taken over all samples and channels at once
        ww[np.abs(d) > threshold * np.std(d)] = 0

        # update weights: a sample stays excluded once it has been rejected
        if not w.any():
            w = ww
        else:
            w = np.amin((w, ww), axis=0)

    y = x - y
    # NOTE(review): this assumes a plain reshape is the inverse of unfold()
    # for 3D input -- confirm against unfold()'s implementation.
    y = np.reshape(y, dims)
    w = np.reshape(w, dims)

    return y, w, r
122+
123+
8124def regress (y , x , w = None , threshold = 1e-7 , return_mean = False ):
9125 """Weighted regression.
10126
@@ -14,7 +130,7 @@ def regress(y, x, w=None, threshold=1e-7, return_mean=False):
14130 Data.
15131 x : array, shape=(n_times, n_chans)
16132 Regressor.
17- w :
133+ w : array, shape=(n_times, n_chans)
18134 Weight to apply to `y`. `w` is either a matrix of same size as `y`, or
19135 a column vector to be applied to each column of `y`.
20136 threshold : float
@@ -49,7 +165,7 @@ def regress(y, x, w=None, threshold=1e-7, return_mean=False):
49165 # PCA
50166 V , _ = pca (xx .T .dot (xx ), thresh = threshold )
51167 xxx = xx .dot (V )
52- b = yy .T .dot (xxx ) / xxx .T .dot (xxx )
168+ b = mrdivide ( yy .T .dot (xxx ), xxx .T .dot (xxx ) )
53169 b = b .T
54170 z = np .dot (demean (x , w ).dot (V ), b )
55171 z = z + mn
@@ -67,7 +183,7 @@ def regress(y, x, w=None, threshold=1e-7, return_mean=False):
67183 xx = demean (x , w ) * w
68184 V , _ = pca (xx .T .dot (xx ), thresh = threshold )
69185 xxx = xx .dot (V )
70- b = yy .T .dot (xxx ) / xxx .T .dot (xxx )
186+ b = mrdivide ( yy .T .dot (xxx ), xxx .T .dot (xxx ) )
71187
72188 z = demean (x , w ).dot (V ).dot (b .T )
73189 z = z + mn
@@ -89,11 +205,51 @@ def regress(y, x, w=None, threshold=1e-7, return_mean=False):
89205 xx = x * wc
90206 V , _ = pca (xx .T .dot (xx ), thresh = threshold )
91207 xx = xx .dot (V )
92- c = yy .T .dot (xx ) / xx .T .dot (xx )
208+ c = mrdivide ( yy .T .dot (xx ), xx .T .dot (xx ) )
93209
94210 z [:, i ] = x .dot (V .dot (c .T )).flatten ()
95211 z [:, i ] += mn [:, i ]
96212 b [i ] = c
97213 b = b [:, :V .shape [1 ]]
98214
99215 return b , z
216+
217+
def mrdivide(A, B):
    r"""Matrix right-division (A/B).

    Find X such that XB = A. Transposing both sides turns the problem into
    a left-division:

        XB = A  <=>  (XB).T = A.T  <=>  B.T X.T = A.T

    so X.T is the left-division of A.T by B.T, and X is its transpose:

    >> mldivide(B.T, A.T).T

    References
    ----------
    .. [1] https://docs.scipy.org/doc/numpy/user/numpy-for-matlab-users.html

    """
    # Solve the transposed system for X.T, then transpose back.
    x_transposed = mldivide(B.T, A.T)
    return x_transposed.T
237+
238+
def mldivide(A, B):
    r"""Matrix left-division (A\B).

    Solves AX = B for X:

    - if A is square, the system is solved with linalg.solve(A, B);
    - otherwise X is the least-squares solution from linalg.lstsq(A, B),
      i.e. the X that minimizes norm(AX - B).

    Parameters
    ----------
    A : array, shape=(m, n)
        Left-hand side matrix.
    B : array, shape=(m,) or (m, k)
        Right-hand side.

    Returns
    -------
    X : array, shape=(n,) or (n, k)
        Solution of the system.

    References
    ----------
    .. [1] https://docs.scipy.org/doc/numpy/user/numpy-for-matlab-users.html

    """
    if A.shape[0] == A.shape[1]:
        return solve(A, B)
    # lstsq returns (solution, residues, rank, singular values); only the
    # solution is the result of the division.
    return lstsq(A, B)[0]
0 commit comments