Skip to content

Commit 5e95ee5

Browse files
nbaranbarascud-sc
andauthored
[FIX] Whitening with low amplitude data (nbara#64)
* minor docfixes * fix whitening when data is very small * Update test_cca.py * pep Co-authored-by: Nicolas Barascud <nbarascud@snapchat.com>
1 parent 154cfc5 commit 5e95ee5

File tree

3 files changed

+31
-10
lines changed

3 files changed

+31
-10
lines changed

meegkit/cca.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,
9494
n_trials).
9595
shifts : array, shape=(n_shifts,)
9696
Array of shifts to apply to `y` relative to `x` (can be negative).
97+
sfreq : float
98+
Sampling frequency. If not 1, lags are assumed to be given in seconds.
9799
surrogate : bool
98100
If True, estimate SD of correlation over non-matching pairs.
99101
plot : bool
@@ -133,16 +135,16 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,
133135

134136
# Calculate leave-one-out CCAs
135137
print('Calculate CCAs...')
136-
AA = list()
137-
BB = list()
138+
AA = []
139+
BB = []
138140
for t in tqdm(np.arange(n_trials)):
139141
# covariance of all trials except t
140142
CC = np.sum(C[..., np.arange(n_trials) != t], axis=-1, keepdims=True)
141143
if CC.ndim == 4:
142144
CC = np.squeeze(CC, 3)
143145

144146
# corresponding CCA
145-
[A, B, R] = nt_cca(None, None, None, CC, xx[0].shape[1])
147+
A, B, _ = nt_cca(None, None, None, CC, xx[0].shape[1])
146148
AA.append(A)
147149
BB.append(B)
148150
del A, B
@@ -227,17 +229,17 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):
227229
independently from each page.
228230
m : int
229231
Number of channels of X.
230-
thresh: float
232+
thresh : float
231233
Discard principal components below this value.
232234
sfreq : float
233235
Sampling frequency. If not 1, lags are assumed to be given in seconds.
234236
235237
Returns
236238
-------
237-
A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y))
239+
A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y)[, n_lags])
238240
Transform matrix mapping `X` to canonical space, where `n_comps` is
239241
equal to `min(n_chans_X, n_chans_Y)`.
240-
B : array, shape=(n_chans_Y, n_comps)
242+
B : array, shape=(n_chans_Y, n_comps[, n_lags])
241243
Transform matrix mapping `Y` to canonical space, where `n_comps` is
242244
equal to `min(n_chans_X, n_chans_Y)`.
243245
R : array, shape=(n_comps, n_lags)
@@ -246,16 +248,16 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):
246248
Notes
247249
-----
248250
Usage 1: CCA of X, Y
249-
>> [A, B, R] = nt_cca(X, Y) # noqa
251+
>> A, B, R = nt_cca(X, Y) # noqa
250252
251253
Usage 2: CCA of X, Y for each value of lags.
252-
>> [A, B, R] = nt_cca(X, Y, lags) # noqa
254+
>> A, B, R = nt_cca(X, Y, lags) # noqa
253255
254256
A positive lag indicates that Y is delayed relative to X.
255257
256258
Usage 3: CCA from covariance matrix
257259
>> C = [X, Y].T * [X, Y] # noqa
258-
>> [A, B, R] = nt_cca([], [], [], C, X.shape[1]) # noqa
260+
>> A, B, R = nt_cca(None, None, None, C=C, m=X.shape[1]) # noqa
259261
260262
Use the third form to handle multiple files or large data (covariance C can
261263
be calculated chunk-by-chunk).
@@ -381,9 +383,10 @@ def whiten_nt(C, thresh=1e-12, keep=False):
381383
# break symmetry when x and y perfectly correlated (otherwise cols of x*A
382384
# and y*B are not orthogonal)
383385
d = d ** (1 - thresh)
386+
d_norm = d / np.max(d)
384387

385388
dd = np.zeros_like(d)
386-
dd[d > thresh] = (1. / d[d > thresh])
389+
dd[d_norm > thresh] = (1. / d[d_norm > thresh])
387390

388391
D = np.diag(np.sqrt(dd))
389392
W = np.dot(V, D)

tests/data/ccadata_meg_2trials.npz

30.2 MB
Binary file not shown.

tests/test_cca.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,21 @@ def test_cca2():
7878
# plt.show()
7979

8080

81+
def test_cca_scaling():
82+
"""Test CCA with MEG data."""
83+
data = np.load('./tests/data/ccadata_meg_2trials.npz')
84+
raw = data['arr_0']
85+
env = data['arr_1']
86+
87+
# Test with scaling (unit: fT)
88+
A0, B0, R0 = nt_cca(raw * 1e15, env)
89+
90+
# Test without scaling (unit: T)
91+
A1, B1, R1 = nt_cca(raw, env)
92+
93+
np.testing.assert_almost_equal(R0, R1)
94+
95+
8196
def test_canoncorr():
8297
"""Compare with Matlab's canoncorr."""
8398
x = np.array([[16, 2, 3, 13],
@@ -130,6 +145,9 @@ def test_cca_lags():
130145
lags = np.arange(-10, 11, 1)
131146
A1, B1, R1 = nt_cca(x, y, lags)
132147

148+
assert A1.ndim == B1.ndim == 3
149+
assert A1.shape[-1] == B1.shape[-1] == lags.size
150+
133151
# import matplotlib.pyplot as plt
134152
# f, ax1 = plt.subplots(1, 1)
135153
# ax1.plot(lags, R1.T)

0 commit comments

Comments
 (0)