Skip to content

Commit 3ac330a

Browse files
authored
Fix changes in SSD for backward compatibility [circle deploy] (mne-tools#13327)
1 parent b4c2a3b commit 3ac330a

File tree

3 files changed

+73
-3
lines changed

3 files changed

+73
-3
lines changed

examples/decoding/ssd_spatial_filters.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
frequency band of interest and the noise covariance based on surrounding
1414
frequencies.
1515
"""
16+
1617
# Author: Denis A. Engemann <denis.engemann@gmail.com>
1718
# Victoria Peterson <victoriapeterson09@gmail.com>
1819
# License: BSD-3-Clause
@@ -82,8 +83,8 @@
8283
ssd_sources, sfreq=raw.info["sfreq"], n_fft=4096
8384
)
8485

85-
# Get spec_ratio information (already sorted).
86-
# Note that this is not necessary if sort_by_spectral_ratio=True (default).
86+
# Get spec_ratio information (already sorted)
87+
# Note that this is not necessary if sort_by_spectral_ratio=True (default)
8788
spec_ratio, sorter = ssd.get_spectral_ratio(ssd_sources)
8889

8990
# Plot spectral ratio (see Eq. 24 in Nikulin et al., 2011).

mne/decoding/ssd.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
logger,
1717
)
1818
from ._covs_ged import _ssd_estimate
19-
from ._mod_ged import _ssd_mod
19+
from ._mod_ged import _get_spectral_ratio, _ssd_mod
2020
from .base import _GEDTransformer
2121

2222

@@ -289,6 +289,37 @@ def fit_transform(self, X, y=None, **fit_params):
289289
# use parent TransformerMixin method but with custom docstring
290290
return super().fit_transform(X, y=y, **fit_params)
291291

292+
def get_spectral_ratio(self, ssd_sources):
293+
"""Get the spectal signal-to-noise ratio for each spatial filter.
294+
295+
Spectral ratio measure for best n_components selection
296+
See :footcite:`NikulinEtAl2011`, Eq. (24).
297+
298+
Parameters
299+
----------
300+
ssd_sources : array
301+
Data projected to SSD space.
302+
303+
Returns
304+
-------
305+
spec_ratio : array, shape (n_channels)
306+
Array with the sprectal ratio value for each component.
307+
sorter_spec : array, shape (n_channels)
308+
Array of indices for sorting spec_ratio.
309+
310+
References
311+
----------
312+
.. footbibliography::
313+
"""
314+
spec_ratio, sorter_spec = _get_spectral_ratio(
315+
ssd_sources=ssd_sources,
316+
sfreq=self.sfreq_,
317+
n_fft=self.n_fft_,
318+
freqs_signal=self.freqs_signal_,
319+
freqs_noise=self.freqs_noise_,
320+
)
321+
return spec_ratio, sorter_spec
322+
292323
def inverse_transform(self):
293324
"""Not implemented yet."""
294325
raise NotImplementedError("inverse_transform is not yet available.")

mne/decoding/tests/test_ssd.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,44 @@ def test_picks_arg():
570570
ssd.fit(X).transform(X)
571571

572572

573+
def test_get_spectral_ratio():
574+
"""Test that method is the same as function in _mod_ged.py."""
575+
X, _, _ = simulate_data()
576+
sf = 250
577+
n_channels = X.shape[0]
578+
info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg")
579+
580+
# Init
581+
filt_params_signal = dict(
582+
l_freq=freqs_sig[0],
583+
h_freq=freqs_sig[1],
584+
l_trans_bandwidth=1,
585+
h_trans_bandwidth=1,
586+
)
587+
filt_params_noise = dict(
588+
l_freq=freqs_noise[0],
589+
h_freq=freqs_noise[1],
590+
l_trans_bandwidth=1,
591+
h_trans_bandwidth=1,
592+
)
593+
594+
ssd = SSD(
595+
info,
596+
filt_params_signal,
597+
filt_params_noise,
598+
n_components=None,
599+
sort_by_spectral_ratio=False,
600+
)
601+
ssd.fit(X)
602+
ssd_sources = ssd.transform(X)
603+
spec_ratio_ssd, sorter_spec_ssd = ssd.get_spectral_ratio(ssd_sources)
604+
spec_ratio_ged, sorter_spec_ged = _get_spectral_ratio(
605+
ssd_sources, ssd.sfreq_, ssd.n_fft_, ssd.freqs_signal_, ssd.freqs_noise_
606+
)
607+
assert_array_equal(spec_ratio_ssd, spec_ratio_ged)
608+
assert_array_equal(sorter_spec_ssd, sorter_spec_ged)
609+
610+
573611
@pytest.mark.filterwarnings("ignore:.*invalid value encountered in divide.*")
574612
@pytest.mark.filterwarnings("ignore:.*is longer than.*")
575613
@parametrize_with_checks(

0 commit comments

Comments
 (0)