Skip to content

Commit f2fe945

Browse files
romquentinnbara
andauthored
Add trial-masked detrending (nbara#41)
* add create_maksed_weight * Formatting docstring * format public function * add test + correct dimension of weights * docstring Co-authored-by: nicolas barascud <10333715+nbara@users.noreply.github.com>
1 parent 852391b commit f2fe945

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

meegkit/detrend.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,45 @@ def _plot_detrend(x, y, w):
297297
ax2.set_ylabel('ch. weights')
298298
ax2.set_xlabel('samples')
299299
plt.show()
300+
301+
302+
def create_masked_weight(x, events, tmin, tmax, sfreq):
303+
"""Output a weight matrix for trial-masked detrending.
304+
305+
Creates a (n_times, n_channels) weight matrix with masked
306+
periods (value of zero) in order to mask the trials of interest during
307+
detrending [1]_.
308+
309+
Parameters
310+
----------
311+
x : ndarray, shape=(n_times, n_channels)
312+
Raw data matrix.
313+
events : ndarray, shape=(n_events)
314+
Time samples of the events.
315+
tmin : float
316+
Start time before event (in seconds).
317+
tmax : float
318+
End time after event (in seconds).
319+
sfreq : float
320+
The sampling frequency of the data.
321+
322+
Returns
323+
-------
324+
weights : ndarray, shape=(n_times, n_channels)
325+
Weight for each channel and each time sample (zero is masked).
326+
327+
References
328+
----------
329+
.. [1] van Driel, J., Olivers, C. N., & Fahrenfort, J. J. (2021). High-pass
330+
filtering artifacts in multivariate classification of neural time series
331+
data. Journal of Neuroscience Methods, 352, 109080.
332+
333+
"""
334+
if x.ndim != 2:
335+
raise ValueError('The shape of x must be (n_times, n_channels)')
336+
337+
weights = np.ones(x.shape)
338+
for e in events:
339+
weights[int(e + tmin * sfreq): int(e + tmax * sfreq + 1), :] = 0
340+
341+
return weights

tests/test_detrend.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Test robust detrending."""
22
import numpy as np
33

4-
from meegkit.detrend import regress, detrend, reduce_ringing
4+
from meegkit.detrend import regress, detrend, reduce_ringing, create_masked_weight
55

66
from scipy.signal import butter, lfilter
77

@@ -89,6 +89,16 @@ def test_detrend(show=False):
8989
x += 2 * np.sin(2 * np.pi * np.arange(1000) / 200)[:, None]
9090
y, _, _ = detrend(x, 5, basis='sinusoids', show=True)
9191

92+
# trial-masked detrending
93+
trend = np.linspace(0, 100, 1000)[:, None]
94+
data = 3 * np.random.randn(*trend.shape)
95+
data[:100, :] = 100
96+
x = trend + data
97+
events = np.arange(30, 970, 40)
98+
tmin, tmax, sfreq = -0.2, 0.3, 20
99+
w = create_masked_weight(x, events, tmin, tmax, sfreq)
100+
y, _, _ = detrend(x, 1, w, basis='polynomials', show=show)
101+
92102

93103
def test_ringing():
94104
"""Test reduce_ringing function."""

0 commit comments

Comments
 (0)