Skip to content

Commit a0c0561

Browse files
authored
[FIX] Use linalg.eigh in RESS (nbara#46)
* [FIX] Use linalg.eigh in RESS * Update dss.py * Update test_ress.py
1 parent f2fe945 commit a0c0561

File tree

4 files changed

+22
-15
lines changed

4 files changed

+22
-15
lines changed

meegkit/asr.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -363,25 +363,21 @@ def clean_windows(X, sfreq, max_bad_chans=0.2, zthresholds=[-3.5, 5],
363363

364364
# combine the three masks
365365
remove_mask = np.logical_or.reduce((mask1, mask2, mask3))
366-
removed_wins = np.where(remove_mask)
366+
removed_wins = np.where(remove_mask)[0]
367367

368368
# reconstruct the samples to remove
369369
sample_maskidx = []
370-
for i in range(len(removed_wins[0])):
370+
for i, win in enumerate(removed_wins):
371371
if i == 0:
372-
sample_maskidx = np.arange(
373-
offsets[removed_wins[0][i]], offsets[removed_wins[0][i]] + N)
372+
sample_maskidx = np.arange(offsets[win], offsets[win] + N)
374373
else:
375-
sample_maskidx = np.vstack((
376-
sample_maskidx,
377-
np.arange(offsets[removed_wins[0][i]],
378-
offsets[removed_wins[0][i]] + N)
379-
))
374+
sample_maskidx = np.r_[(sample_maskidx,
375+
np.arange(offsets[win], offsets[win] + N))]
380376

381377
# delete the bad chunks from the data
382378
sample_mask2remove = np.unique(sample_maskidx)
383379
if sample_mask2remove.size:
384-
clean = np.delete(X, sample_mask2remove, 1)
380+
clean = np.delete(X, sample_mask2remove, axis=1)
385381
sample_mask = np.ones((1, ns), dtype=bool)
386382
sample_mask[0, sample_mask2remove] = False
387383
else:

meegkit/dss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12):
2121
keep1: int
2222
Number of PCs to retain in function:`dss0` (default=all).
2323
keep2: float
24-
Ignore PCs smaller than keep2 in function:`dss0` (default=10^-12).
24+
Ignore PCs smaller than keep2 in function:`dss0` (default=1e-12).
2525
2626
Returns
2727
-------
@@ -35,7 +35,7 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12):
3535
Power per component (averaged).
3636
3737
"""
38-
n_samples, n_chans, n_trials = theshapeof(X)
38+
n_trials = theshapeof(X)[-1]
3939

4040
# if demean: # remove weighted mean
4141
# X = demean(X, weights)

meegkit/ress.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1,
99
peak_width: float = .5, neig_width: float = 1, n_keep: int = 1,
10-
return_maps: bool = False):
10+
gamma: float = 0.01, return_maps: bool = False):
1111
"""Rhythmic Entrainment Source Separation.
1212
1313
As described in [1]_.
@@ -29,6 +29,10 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1,
2929
FWHM of the neighboring frequencies (default=1).
3030
n_keep : int
3131
Number of components to keep (default=1). -1 keeps all components.
32+
gamma : float
33+
Regularization coefficient, between 0 and 1 (default=0.01, which
34+
corresponds to 1 % regularization and helps reduce numerical problems
35+
for noisy or reduced-rank matrices [2]_).
3236
return_maps : bool
3337
If True, also output mixing (to_ress) and unmixing matrices
3438
(from_ress), used to transform the data into RESS component space and
@@ -67,6 +71,9 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1,
6771
.. [1] Cohen, M. X., & Gulbinaite, R. (2017). Rhythmic entrainment source
6872
separation: Optimizing analyses of neural responses to rhythmic sensory
6973
stimulation. Neuroimage, 147, 43-56.
74+
.. [2] Cohen, M. X. (2021). A tutorial on generalized eigendecomposition
75+
for source separation in multichannel electrophysiology.
76+
ArXiv:2104.12356 [Eess, q-Bio].
7077
7178
"""
7279
n_samples, n_chans, n_trials = theshapeof(X)
@@ -82,8 +89,12 @@ def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1,
8289
fwhm=neig_width, n_harm=1))
8390
c1, _ = tscov(gaussfilt(X, sfreq, peak_freq, fwhm=peak_width, n_harm=1))
8491

92+
# add 1% regularization to avoid numerical precision problems in the GED
93+
c0 = (c01 + c02) / 2
94+
c0 = c0 * (1 - gamma) + gamma * np.trace(c0) / len(c0) * np.eye(len(c0))
95+
8596
# perform generalized eigendecomposition
86-
d, to_ress = linalg.eig(c1, (c01 + c02) / 2)
97+
d, to_ress = linalg.eigh(c1, c0)
8798
d = d.real
8899
to_ress = to_ress.real
89100

tests/test_ress.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,4 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False):
156156
if __name__ == '__main__':
157157
import pytest
158158
pytest.main([__file__])
159-
# test_ress(12, 20, 1, 1, 1, show=True)
159+
# test_ress(20, 16, 1, 1, 1, show=False)

0 commit comments

Comments
 (0)