Skip to content

Commit 0aaa146

Browse files
committed
add weighted regression
1 parent 2352df9 commit 0aaa146

5 files changed

Lines changed: 232 additions & 3 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# MEEGkit
22

33
[![Build Status](https://travis-ci.org/nbara/python-meegkit.svg?branch=master)](https://travis-ci.org/nbara/python-meegkit)
4+
[![codecov](https://codecov.io/gh/nbara/python-meegkit/branch/master/graph/badge.svg)](https://codecov.io/gh/nbara/python-meegkit)
45
[![Binder](https://mybinder.org/badge.svg)](https://mybinder.org/v2/gh/nbara/python-meegkit/master)
56

67
Denoising tools for M/EEG processing in Python.

examples/example_detrend.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Robust detrending examples.
2+
3+
Some toy examples to showcase usage for ``meegkit.detrend`` module.
4+
5+
Robust detrending is adapted from [1]_.
6+
7+
References
8+
----------
9+
.. [1] de Cheveigné, A., & Arzounian, D. (2018). Robust detrending,
10+
rereferencing, outlier detection, and inpainting for multichannel data.
11+
NeuroImage, 172, 903-912.
12+
13+
"""
14+
import matplotlib.pyplot as plt
15+
import numpy as np
16+
from matplotlib.gridspec import GridSpec
17+
18+
from meegkit.detrend import regress
19+
20+
np.random.seed(9)

n_times = 1000


def _show_fit(num, data, fit, title, **fit_style):
    """Draw `data` and its regression `fit` on pyplot figure number `num`."""
    plt.figure(num)
    plt.plot(data, label='data')
    plt.plot(fit, label='fit', **fit_style)
    plt.title(title)
    plt.legend()


# Simple regression example, no weights
# -----------------------------------------------------------------------------
# The data is a random walk, the regressors are the first three powers of the
# sample index (i.e. a cubic polynomial trend is fitted).
y = np.cumsum(np.random.randn(n_times, 1), axis=0)
t = np.arange(n_times)[:, None]
x = np.hstack([t, t ** 2, t ** 3])
_, z = regress(y, x)

_show_fit(1, y, z, 'No weights')

# Simple regression example, with weights
# -----------------------------------------------------------------------------
# Same regressors, new random walk, one random weight per sample.
y = np.cumsum(np.random.randn(n_times, 1), axis=0)
w = np.random.rand(*y.shape)
_, z = regress(y, x, w)

_show_fit(2, y, z, 'Weighted regression')

# Downweight 1st half of the data
# -----------------------------------------------------------------------------
# The first 500 samples get a weight of zero, the remaining ones a weight of 1.
y = np.cumsum(np.random.randn(n_times, 1), axis=0) + 1000
w = np.ones(len(y))
w[:500] = 0
_, z = regress(y, x, w)

fig = plt.figure(3, constrained_layout=True)
grid = GridSpec(3, 1, figure=fig)

# Top two thirds of the figure: data and fit
ax_fit = fig.add_subplot(grid[:2, 0])
ax_fit.plot(y, label='data')
ax_fit.plot(z, label='fit')
ax_fit.set_xticklabels('')
ax_fit.set_title('Split-wise regression')
ax_fit.legend()

# Bottom third: the weights, filled with the colour of a flat baseline line
ax_weights = fig.add_subplot(grid[2, 0])
samples = np.arange(n_times)
baseline, = ax_weights.plot(samples, np.zeros(n_times))
ax_weights.stackplot(samples, w, labels=['weights'],
                     color=baseline.get_color())
ax_weights.legend(loc=2)

# Multichannel regression
# -----------------------------------------------------------------------------
# Two random walks fitted at once with the same (uniform) weight vector.
y = np.cumsum(np.random.randn(n_times, 2), axis=0)
w = np.ones(len(y))
_, z = regress(y, x, w)

_show_fit(4, y, z, 'Channel-wise regression', ls=':')
plt.show()

meegkit/detrend.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Robust detrending."""
2+
import numpy as np
3+
4+
from .utils import demean, pca, unfold
5+
from .utils.matrix import _check_weights
6+
7+
8+
def regress(y, x, w=None, threshold=1e-7, return_mean=False):
    """Weighted regression.

    Fit `y` as a linear combination of the columns of `x`. The regressors are
    demeaned and rotated with `pca` (called with ``thresh=threshold``) before
    the fit, and the mean removed from `y` is added back to the fit.

    Parameters
    ----------
    y : array, shape=(n_times, n_chans)
        Data.
    x : array, shape=(n_times, n_regressors)
        Regressor.
    w : array | None
        Weight to apply to `y`. `w` is either a matrix of same size as `y`, or
        a column vector to be applied to each column of `y`. If no weights are
        given (or if all weights are zero), an unweighted fit is performed.
    threshold : float
        PCA threshold (default=1e-7).
    return_mean : bool
        If True, also return the signal mean prior to regression.

    Returns
    -------
    b : array
        Regression matrix, expressed in the space of the retained PCA
        components of `x` (not of the raw regressors). Its shape is
        (n_components, n_chans) for the unweighted fit and
        (n_chans, n_components) for the weighted fits. With channel-specific
        weights each row refers to that channel's own PCA basis, and rows are
        zero-padded to the largest number of components.
    z : array, shape=(n_times, n_chans)
        Regression (fit of `y`, including the mean).
    mn : array, shape=(n_times, n_chans)
        Mean that was removed from `y`. Only returned if `return_mean` is
        True.

    Raises
    ------
    ValueError
        If `x` and `y` do not have the same number of rows, or if the shape
        of `w` is compatible with neither a shared nor a per-channel weight.

    """
    # check/fix sizes
    w = _check_weights(w, y)
    n_times = y.shape[0]
    x = unfold(x)
    y = unfold(y)
    # Read the channel count after unfolding so that `y` is guaranteed 2D.
    n_chans = y.shape[1]
    if x.shape[0] != y.shape[0]:
        raise ValueError('x and y have incompatible shapes!')

    # save weighted mean
    mn = y - demean(y, w)

    # `w.any()` is False for an empty array as well as for all-zero weights;
    # both cases are handled as "no weights".
    # NOTE(review): presumably `_check_weights` returns an empty array for
    # `w=None` -- confirm.
    if not w.any():  # simple regression
        xx = demean(x)
        yy = demean(y)

        # PCA
        V, _ = pca(xx.T.dot(xx), thresh=threshold)
        xxx = xx.dot(V)
        b = _mrdivide(yy.T.dot(xxx), xxx.T.dot(xxx)).T
        z = demean(x, w).dot(V).dot(b) + mn

    else:  # weighted regression
        if w.ndim == 1:
            w = w[:, None]
        if w.shape[0] != n_times:
            raise ValueError(
                'w should have as many rows as y ({} != {})'.format(
                    w.shape[0], n_times))

        if w.shape[1] == 1:  # same weight for all channels
            if w.sum() == 0:
                # Nothing can be fitted: the fit reduces to the mean and no
                # component is retained.
                print('weights all zero')
                b = np.zeros((n_chans, 0))
                z = np.zeros(y.shape) + mn
            else:
                yy = demean(y, w) * w
                xc = demean(x, w)
                xx = xc * w
                V, _ = pca(xx.T.dot(xx), thresh=threshold)
                xxx = xx.dot(V)
                b = _mrdivide(yy.T.dot(xxx), xxx.T.dot(xxx))
                z = xc.dot(V).dot(b.T) + mn

        else:  # each channel has own weight
            if w.shape[1] != n_chans:
                raise ValueError(
                    'w should have either one column or as many columns as '
                    'y ({} != {})'.format(w.shape[1], n_chans))

            z = np.zeros(y.shape)
            coefs = []  # one 1D coefficient vector per channel
            for i in range(n_chans):
                wc = w[:, i][:, None]  # channel-specific weight
                if wc.sum() == 0:
                    # No usable sample for this channel: fit is the mean only
                    print('weights all zero for channel {}'.format(i))
                    coefs.append(np.zeros(0))
                    z[:, i] = mn[:, i]
                    continue

                yy = demean(y[:, i], wc) * wc
                # remove channel-specific-weighted mean from regressor (kept
                # in a local so that `x` is left untouched for next channels)
                xc = demean(x, wc)
                xx = xc * wc
                V, _ = pca(xx.T.dot(xx), thresh=threshold)
                xx = xx.dot(V)
                c = _mrdivide(yy.T.dot(xx), xx.T.dot(xx))

                z[:, i] = xc.dot(V.dot(c.T)).flatten() + mn[:, i]
                coefs.append(np.ravel(c))

            # The number of retained components may differ between channels:
            # zero-pad every row to the longest one.
            n_comps = max(len(c) for c in coefs)
            b = np.zeros((n_chans, n_comps))
            for i, c in enumerate(coefs):
                b[i, :len(c)] = c

    if return_mean:
        return b, z, mn
    return b, z


def _mrdivide(a, b):
    """Solve ``out @ b = a`` for `out` in the least-squares sense.

    This is matrix right-division (``a / b`` in MATLAB notation), as opposed
    to the element-wise division performed by ``a / b`` on numpy arrays.
    """
    return np.linalg.lstsq(b.T, a.T, rcond=None)[0].T

meegkit/utils/matrix.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,9 @@ def unfold(X):
436436
"""Unfold 3D X into 2D (concatenate trials)."""
437437
n_samples, n_chans, n_trials = theshapeof(X)
438438

439+
if X.shape == (n_samples,):
440+
X = X[:, None]
441+
439442
if n_trials > 1:
440443
return np.reshape(
441444
np.transpose(X, (0, 2, 1)),
@@ -603,7 +606,8 @@ def _check_weights(weights, X):
603606
warnings.warn('weights should be a list or a numpy array.')
604607
weights = np.array([])
605608

606-
if len(weights) > 0:
609+
weights = np.asanyarray(weights)
610+
if weights.size > 0:
607611
dtype = np.complex128 if np.any(np.iscomplex(weights)) else np.float64
608612
weights = np.asanyarray(weights, dtype=dtype)
609613
if weights.ndim > 3:
@@ -618,8 +622,9 @@ def _check_weights(weights, X):
618622
if X.ndim == 3 and weights.ndim == 1:
619623
weights = weights[:, np.newaxis, np.newaxis]
620624

621-
if weights.shape[1] > 1:
622-
raise ValueError("Weights array should have a single column.")
625+
if weights.ndim > 1:
626+
if weights.shape[1] > 1 and weights.shape[1] != X.shape[1]:
627+
raise ValueError("Weights array should have a single column.")
623628

624629
return weights
625630

tests/test_detrend.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Test robust detrending."""
2+
import numpy as np
3+
4+
from meegkit.detrend import regress
5+
6+
7+
def test_regress():
    """Test regression with no, shared and channel-specific weights."""
    # Local, seeded generator: reproducible data without touching the global
    # numpy random state.
    rng = np.random.RandomState(9)

    # Simple regression example, no weights
    # fit random walk
    y = np.cumsum(rng.randn(1000, 1), axis=0)
    x = np.arange(1000)[:, None]
    x = np.hstack([x, x ** 2, x ** 3])
    b, z = regress(y, x)
    assert z.shape == y.shape

    # Simple regression example, with weights
    y = np.cumsum(rng.randn(1000, 1), axis=0)
    w = rng.rand(*y.shape)
    b, z = regress(y, x, w)
    assert z.shape == y.shape

    # Downweight 1st half of the data
    y = np.cumsum(rng.randn(1000, 1), axis=0) + 1000
    w = np.ones(y.shape[0])
    w[:500] = 0
    b, z = regress(y, x, w)
    assert z.shape == y.shape

    # Multichannel regression, same weight vector for all channels
    y = np.cumsum(rng.randn(1000, 2), axis=0)
    w = np.ones(y.shape[0])
    b, z = regress(y, x, w)
    assert z.shape == (1000, 2)
    assert b.shape == (2, 1)

    # Multichannel regression, channel-specific weights
    y = np.cumsum(rng.randn(1000, 2), axis=0)
    w = np.ones(y.shape)
    w[:, 1] = .8  # assignment (was a no-op `==` comparison)
    b, z = regress(y, x, w)
    assert z.shape == (1000, 2)
    assert b.shape == (2, 1)


if __name__ == '__main__':
    import pytest
    pytest.main([__file__])

0 commit comments

Comments
 (0)