@@ -94,6 +94,8 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,
9494 n_trials).
9595 shifts : array, shape=(n_shifts,)
9696 Array of shifts to apply to `y` relative to `x` (can be negative).
97+ sfreq : float
98+ Sampling frequency. If not 1, lags are assumed to be given in seconds.
9799 surrogate : bool
98100 If True, estimate SD of correlation over non-matching pairs.
99101 plot : bool
@@ -133,16 +135,16 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,
133135
134136 # Calculate leave-one-out CCAs
135137 print ('Calculate CCAs...' )
136- AA = list ()
137- BB = list ()
138+ AA = []
139+ BB = []
138140 for t in tqdm (np .arange (n_trials )):
139141 # covariance of all trials except t
140142 CC = np .sum (C [..., np .arange (n_trials ) != t ], axis = - 1 , keepdims = True )
141143 if CC .ndim == 4 :
142144 CC = np .squeeze (CC , 3 )
143145
144146 # corresponding CCA
145- [ A , B , R ] = nt_cca (None , None , None , CC , xx [0 ].shape [1 ])
147+ A , B , _ = nt_cca (None , None , None , CC , xx [0 ].shape [1 ])
146148 AA .append (A )
147149 BB .append (B )
148150 del A , B
@@ -227,17 +229,17 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):
227229 independently from each page.
228230 m : int
229231 Number of channels of X.
230- thresh: float
232+ thresh : float
231233 Discard principal components below this value.
232234 sfreq : float
233235 Sampling frequency. If not 1, lags are assumed to be given in seconds.
234236
235237 Returns
236238 -------
237- A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y))
239+ A : array, shape=(n_chans_X, min(n_chans_X, n_chans_Y)[, n_lags] )
238240 Transform matrix mapping `X` to canonical space, where `n_comps` is
239241 equal to `min(n_chans_X, n_chans_Y)`.
240- B : array, shape=(n_chans_Y, n_comps)
242+ B : array, shape=(n_chans_Y, n_comps[, n_lags] )
241243 Transform matrix mapping `Y` to canonical space, where `n_comps` is
242244 equal to `min(n_chans_X, n_chans_Y)`.
243245 R : array, shape=(n_comps, n_lags)
@@ -246,16 +248,16 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):
246248 Notes
247249 -----
248250 Usage 1: CCA of X, Y
249- >> [ A, B, R] = nt_cca(X, Y) # noqa
251+ >> A, B, R = nt_cca(X, Y) # noqa
250252
251253 Usage 2: CCA of X, Y for each value of lags.
252- >> [ A, B, R] = nt_cca(X, Y, lags) # noqa
254+ >> A, B, R = nt_cca(X, Y, lags) # noqa
253255
254256 A positive lag indicates that Y is delayed relative to X.
255257
256258 Usage 3: CCA from covariance matrix
257259 >> C = [X, Y].T * [X, Y] # noqa
258- >> [ A, B, R] = nt_cca([], [], [] , C, X.shape[1]) # noqa
260+ >> A, B, R = nt_cca(None, None, None , C=C, m= X.shape[1]) # noqa
259261
260262 Use the third form to handle multiple files or large data (covariance C can
261263 be calculated chunk-by-chunk).
@@ -381,9 +383,10 @@ def whiten_nt(C, thresh=1e-12, keep=False):
381383 # break symmetry when x and y perfectly correlated (otherwise cols of x*A
382384 # and y*B are not orthogonal)
383385 d = d ** (1 - thresh )
386+ d_norm = d / np .max (d )
384387
385388 dd = np .zeros_like (d )
386- dd [d > thresh ] = (1. / d [d > thresh ])
389+ dd [d_norm > thresh ] = (1. / d [d_norm > thresh ])
387390
388391 D = np .diag (np .sqrt (dd ))
389392 W = np .dot (V , D )
0 commit comments