Skip to content

Commit cf119df

Browse files
committed
add robust detrending
1 parent 0aaa146 commit cf119df

4 files changed

Lines changed: 251 additions & 10 deletions

File tree

examples/example_detrend.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
import numpy as np
1616
from matplotlib.gridspec import GridSpec
1717

18-
from meegkit.detrend import regress
18+
from meegkit.detrend import regress, detrend
1919

2020
np.random.seed(9)
2121

22+
# Regression
23+
# =============================================================================
24+
2225
# Simple regression example, no weights
2326
# -----------------------------------------------------------------------------
2427
# fit random walk
@@ -72,8 +75,55 @@
7275
[b, z] = regress(y, x, w)
7376

7477
plt.figure(4)
75-
plt.plot(y, label='data')
76-
plt.plot(z, ls=':', label='fit')
78+
plt.plot(y, label='data', color='C0')
79+
plt.plot(z, ls=':', label='fit', color='C1')
7780
plt.title('Channel-wise regression')
7881
plt.legend()
82+
83+
84+
# Detrending
85+
# =============================================================================
86+
87+
# Basic example with a linear trend
88+
# -----------------------------------------------------------------------------
89+
x = np.arange(100)[:, None]
90+
x = x + np.random.randn(*x.shape)
91+
y, _, _ = detrend(x, 1)
92+
93+
plt.figure(5)
94+
plt.plot(x, label='original')
95+
plt.plot(y, label='detrended')
96+
plt.legend()
97+
98+
# Detrend biased random walk
99+
# -----------------------------------------------------------------------------
100+
x = np.cumsum(np.random.randn(1000, 1) + 0.1)
101+
y, _, _ = detrend(x, 3)
102+
103+
plt.figure(6)
104+
plt.plot(x, label='original')
105+
plt.plot(y, label='detrended')
106+
plt.legend()
107+
108+
# Detrend with weights
109+
# -----------------------------------------------------------------------------
110+
x = np.linspace(0, 100, 1000)[:, None]
111+
x = x + 3 * np.random.randn(*x.shape)
112+
113+
# introduce some strong artifact on the first 100 samples
114+
x[:100, :] = 100
115+
116+
# Detrend
117+
y, _, _ = detrend(x, 3, None)
118+
119+
# Same process but this time downweight artifactual window
120+
w = np.ones(x.shape)
121+
w[:100, :] = 0
122+
yy, _, _ = detrend(x, 3, w)
123+
124+
plt.figure(7)
125+
plt.plot(x, label='original')
126+
plt.plot(y, label='detrended - no weights')
127+
plt.plot(yy, label='detrended - weights')
128+
plt.legend()
79129
plt.show()

meegkit/detrend.py

Lines changed: 160 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,126 @@
11
"""Robust detrending."""
22
import numpy as np
33

4+
from scipy.linalg import pinv, lstsq, solve
5+
46
from .utils import demean, pca, unfold
57
from .utils.matrix import _check_weights
68

79

10+
def detrend(x, order, w=None, basis='polynomials', threshold=3, n_iter=4):
11+
"""Robustly remove trend.
12+
13+
The data are fit to the basis using weighted least squares. The weight is
14+
updated by setting samples for which the residual is greater than 'thresh'
15+
times its std to zero, and the fit is repeated at most 'niter'-1 times.
16+
17+
The choice of order (and basis) determines what complexity of the trend
18+
that can be removed. It may be useful to first detrend with a low order
19+
to avoid fitting outliers, and then increase the order.
20+
21+
Parameters
22+
----------
23+
x : array, shape=(n_times, n_channels[, n_trials])
24+
Raw data
25+
order : int
26+
Order of polynomial or number of sin/cosine pairs
27+
w: weights
28+
basis: {'polynomials', 'sinusoids'} | ndarray
29+
Basis for regression.
30+
threshold : int
31+
Threshold for outliers, in number of standard deviations (default=3).
32+
niter : int
33+
Number of iterations (default=5).
34+
35+
Returns
36+
-------
37+
y: detrended data
38+
w: updated weights
39+
r: basis matrix used
40+
41+
Examples
42+
--------
43+
Fit linear trend, ignoring samples > 3*sd from it, and remove:
44+
>> y = detrend(x, 1)
45+
46+
Fit/remove polynomial order=5 with initial weighting w, threshold = 4*sd:
47+
>> y = detrend(x, 5, w, [],4 )
48+
49+
Fit/remove linear then 3rd order polynomial:
50+
>> [y, w]= detrend(x, 1)
51+
>> [yy, ww] = detrend(y, 3)
52+
53+
"""
54+
if threshold == 0:
55+
raise ValueError('thresh=0 is not what you want...')
56+
57+
# check/fix sizes
58+
dims = x.shape
59+
w = _check_weights(w, x)
60+
x = unfold(x)
61+
w = unfold(w)
62+
n_times, n_chans = x.shape
63+
64+
# regressors
65+
if isinstance(basis, np.ndarray):
66+
r = basis
67+
else:
68+
lin = np.linspace(-1, 1, n_times)
69+
if basis == 'polynomials' or basis is None:
70+
r = np.zeros((n_times, order))
71+
for i, o in enumerate(range(1, order + 1)):
72+
r[:, i] = lin ** o
73+
elif basis == 'sinusoids':
74+
r = np.zeros((n_times, order * 2))
75+
for i, o in enumerate(range(1, order + 1)):
76+
r[:, 2 * i] = np.sin[2 * np.pi * o * lin / 2]
77+
r[:, 2 * i + 1] = np.cos[2 * np.pi * o * lin / 2]
78+
else:
79+
raise ValueError('!')
80+
81+
# iteratively remove trends
82+
# the tricky bit is to ensure that weighted means are removed before
83+
# calculating the regression (see regress()).
84+
for iIter in range(n_iter):
85+
# weighted regression on basis
86+
_, y = regress(x, r, w)
87+
88+
# find outliers
89+
d = x - y
90+
if w.any():
91+
d = d * w
92+
ww = np.ones_like(x)
93+
ww[(abs(d) > threshold * np.std(d))] = 0
94+
95+
# update weights
96+
if not w.any():
97+
w = ww
98+
else:
99+
w = np.amin((w, ww), axis=0)
100+
del ww
101+
102+
y = x - y
103+
y = np.reshape(y, dims)
104+
w = np.reshape(w, dims)
105+
106+
# if show: # don't return, just plot
107+
# figure(1)
108+
# subplot 411
109+
# plot(x)
110+
# title('raw')
111+
# subplot 412
112+
# plot(y)
113+
# title('detrended')
114+
# subplot 413
115+
# plot(x-y)
116+
# title('trend')
117+
# subplot 414
118+
# nt_imagescc(w')
119+
# title('weight')
120+
121+
return y, w, r
122+
123+
8124
def regress(y, x, w=None, threshold=1e-7, return_mean=False):
9125
"""Weighted regression.
10126
@@ -14,7 +130,7 @@ def regress(y, x, w=None, threshold=1e-7, return_mean=False):
14130
Data.
15131
x : array, shape=(n_times, n_chans)
16132
Regressor.
17-
w :
133+
w : array, shape=(n_times, n_chans)
18134
Weight to apply to `y`. `w` is either a matrix of same size as `y`, or
19135
a column vector to be applied to each column of `y`.
20136
threshold : float
@@ -49,7 +165,7 @@ def regress(y, x, w=None, threshold=1e-7, return_mean=False):
49165
# PCA
50166
V, _ = pca(xx.T.dot(xx), thresh=threshold)
51167
xxx = xx.dot(V)
52-
b = yy.T.dot(xxx) / xxx.T.dot(xxx)
168+
b = mrdivide(yy.T.dot(xxx), xxx.T.dot(xxx))
53169
b = b.T
54170
z = np.dot(demean(x, w).dot(V), b)
55171
z = z + mn
@@ -67,7 +183,7 @@ def regress(y, x, w=None, threshold=1e-7, return_mean=False):
67183
xx = demean(x, w) * w
68184
V, _ = pca(xx.T.dot(xx), thresh=threshold)
69185
xxx = xx.dot(V)
70-
b = yy.T.dot(xxx) / xxx.T.dot(xxx)
186+
b = mrdivide(yy.T.dot(xxx), xxx.T.dot(xxx))
71187

72188
z = demean(x, w).dot(V).dot(b.T)
73189
z = z + mn
@@ -89,11 +205,51 @@ def regress(y, x, w=None, threshold=1e-7, return_mean=False):
89205
xx = x * wc
90206
V, _ = pca(xx.T.dot(xx), thresh=threshold)
91207
xx = xx.dot(V)
92-
c = yy.T.dot(xx) / xx.T.dot(xx)
208+
c = mrdivide(yy.T.dot(xx), xx.T.dot(xx))
93209

94210
z[:, i] = x.dot(V.dot(c.T)).flatten()
95211
z[:, i] += mn[:, i]
96212
b[i] = c
97213
b = b[:, :V.shape[1]]
98214

99215
return b, z
216+
217+
218+
def mrdivide(A, B):
219+
r"""Matrix right-division (A/B).
220+
221+
Solves the linear system XB = A for X. We can write equivalently:
222+
223+
1) XB = A
224+
2) (XB).T = A.T
225+
3) B.T X.T = A.T
226+
227+
Therefore A/B amounts to solving B.T X.T = A.T for X.T:
228+
229+
>> mldivide(B.T, A.T).T
230+
231+
References
232+
----------
233+
.. [1] https://docs.scipy.org/doc/numpy/user/numpy-for-matlab-users.html
234+
235+
"""
236+
return mldivide(B.T, A.T).T
237+
238+
239+
def mldivide(A, B):
240+
r"""Matrix left-division (A\B).
241+
242+
Solves the AX = B for X. In other words, X minimizes norm(A*X - B), the
243+
length of the vector AX - B:
244+
- linalg.solve(A, B) if A is square
245+
- linalg.lstsq(A, B) otherwise
246+
247+
References
248+
----------
249+
.. [1] https://docs.scipy.org/doc/numpy/user/numpy-for-matlab-users.html
250+
251+
"""
252+
if A.shape[0] == A.shape[1]:
253+
return solve(A, B)
254+
else:
255+
return lstsq(A, B)

meegkit/utils/matrix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,8 @@ def fold(X, epoch_size):
435435
def unfold(X):
436436
"""Unfold 3D X into 2D (concatenate trials)."""
437437
n_samples, n_chans, n_trials = theshapeof(X)
438+
if X.size == 0:
439+
return X
438440

439441
if X.shape == (n_samples,):
440442
X = X[:, None]

tests/test_detrend.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Test robust detrending."""
22
import numpy as np
33

4-
from meegkit.detrend import regress
4+
from numpy.testing import assert_almost_equal
5+
6+
from meegkit.detrend import regress, detrend
57

68

79
def test_regress():
@@ -40,6 +42,37 @@ def test_regress():
4042
assert b.shape == (2, 1)
4143

4244

45+
def test_detrend():
46+
"""Test detrending."""
47+
# basic
48+
# x = np.arange(100)[:, None]
49+
# x = x + np.random.randn(*x.shape)
50+
# y, _, _ = detrend(x, 1)
51+
52+
# assert y.shape == x.shape
53+
54+
# detrend biased random walk
55+
x = np.cumsum(np.random.randn(1000, 1) + 0.1)
56+
y, _, _ = detrend(x, 3)
57+
58+
assert y.shape == x.shape
59+
60+
# weights
61+
trend = np.linspace(0, 100, 1000)[:, None]
62+
data = 3 * np.random.randn(*trend.shape)
63+
x = trend + data
64+
x[:100, :] = 100
65+
w = np.ones(x.shape)
66+
w[:100, :] = 0
67+
y, _, _ = detrend(x, 3, None)
68+
yy, _, _ = detrend(x, 3, w)
69+
70+
assert y.shape == x.shape
71+
assert yy.shape == x.shape
72+
73+
assert_almost_equal(yy[100:], data[100:], decimal=.3)
74+
4375
if __name__ == '__main__':
44-
import pytest
45-
pytest.main([__file__])
76+
# import pytest
77+
# pytest.main([__file__])
78+
test_detrend()

0 commit comments

Comments
 (0)