Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ Changelog

- Add generic reader function :func:`mne.io.read_raw` that loads files based on their extensions (it wraps the underlying specific ``read_raw_xxx`` functions) by `Clemens Brunner`_

- Add ``'auto'`` option to :meth:`mne.preprocessing.ICA.find_bads_ecg` to automatically determine the threshold for the CTPS method by `Yu-Han Luo`_

Bug
~~~

Expand Down Expand Up @@ -174,3 +176,5 @@ API
- Add ``use_dev_head_trans`` parameter to :func:`mne.preprocessing.annotate_movement` to allow choosing which device-to-head transform is used to define the fixed cHPI coordinates, by `Luke Bloy`_

- The function ``mne.channels.read_dig_captrack`` will be deprecated in version 0.22 in favor of :func:`mne.channels.read_dig_captrak` to correct the spelling error: "captraCK" -> "captraK", by `Stefan Appelhoff`_

- The ``threshold`` argument in :meth:`mne.preprocessing.ICA.find_bads_ecg` defaults to ``None`` in version 0.21 but will change to ``'auto'`` in 0.22 by `Yu-Han Luo`_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to add your name to names.inc too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I had added my name to names.inc in a previous PR, so I did not in this one. Locally built documentation works as expected.

12 changes: 10 additions & 2 deletions examples/preprocessing/plot_run_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# - MEG channel selection
# - 1-30 Hz band-pass filter
# - epoching -0.2 to 0.5 seconds with respect to events
# - rejection based on peak-to-peak amplitude

data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
Expand All @@ -35,9 +36,12 @@
raw.pick_types(meg=True, eeg=False, exclude='bads', stim=True).load_data()
raw.filter(1, 30, fir_design='firwin')

# peak-to-peak amplitude rejection parameters
reject = dict(grad=4000e-13, mag=4e-12)
# longer + more epochs for more artifact exposure
events = mne.find_events(raw, stim_channel='STI 014')
epochs = mne.Epochs(raw, events, event_id=None, tmin=-0.2, tmax=0.5)
epochs = mne.Epochs(raw, events, event_id=None, tmin=-0.2, tmax=0.5,
reject=reject)

###############################################################################
# Fit ICA model using the FastICA algorithm, detect and plot components
Expand All @@ -46,10 +50,14 @@
ica = ICA(n_components=0.95, method='fastica').fit(epochs)

ecg_epochs = create_ecg_epochs(raw, tmin=-.5, tmax=.5)
ecg_inds, scores = ica.find_bads_ecg(ecg_epochs)
ecg_inds, scores = ica.find_bads_ecg(ecg_epochs, threshold='auto')

ica.plot_components(ecg_inds)

###############################################################################
# Plot properties of ECG components:
ica.plot_properties(epochs, picks=ecg_inds)

###############################################################################
# Plot the estimated source of detected ECG related components
ica.plot_sources(raw, picks=ecg_inds)
67 changes: 60 additions & 7 deletions mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from numbers import Integral
from time import time

import math
import os
import json

Expand Down Expand Up @@ -1115,12 +1116,43 @@ def _find_bads_ch(self, inst, chs, threshold=3.0, start=None,

return labels, scores

def _get_ctps_threshold(self, pk_threshold=20):
    """Pick the Kuiper-index threshold used by the CTPS method.

    Candidate Kuiper statistics ``0.01, 0.02, ..., 0.99`` are each
    converted to a probability ``Pk``; the candidate whose ``Pk`` lies
    closest to ``10 ** (-pk_threshold)`` is returned. The data are assumed
    to have been appropriately filtered, with bad segments rejected at
    least on peak-to-peak amplitude, when/before the ICA decomposition
    was run.

    Parameters
    ----------
    pk_threshold : float
        Threshold on ``pk = -log10(Pk)``. Defaults to 20 [1].

    Returns
    -------
    threshold : float
        The selected candidate Kuiper statistic.

    References
    ----------
    [1] Dammers, J., Schiek, M., Boers, F., Silex, C., Zvyagintsev,
        M., Pietrzyk, U., Mathiak, K., 2008. Integration of amplitude
        and phase statistics for complete artifact removal in independent
        components of neuromagnetic recordings. Biomedical
        Engineering, IEEE Transactions on 55 (10), pp.2356.
    """
    # The sampling frequency stored in ``info`` plays the role of N.
    root_n = math.sqrt(self.info['sfreq'])
    scale = root_n + 0.155 + 0.24 / root_n
    candidates = np.arange(1, 100) / 100
    # In formula (13) only the k=1 term of the summation matters when k
    # gets large, so k*V*C reduces to V*C.
    lam_sq = (candidates * scale)**2
    probs = 2 * (4 * lam_sq - 1) * (np.exp(-2 * lam_sq))
    # pk = -log10(Pk), hence the pk threshold is mapped back to Pk space
    # before comparing.
    target = 10**(-pk_threshold)
    closest = np.argmin(np.abs(probs - target))
    return candidates[closest]

@verbose
def find_bads_ecg(self, inst, ch_name=None, threshold=None, start=None,
stop=None, l_freq=8, h_freq=16, method='ctps',
reject_by_annotation=True, measure="zscore",
verbose=None):
"""Detect ECG related components using correlation.
"""Detect ECG related components.

Cross-trial phase statistics (default) or Pearson correlation can be
used for detection.

.. note:: If no ECG channel is available, routine attempts to create
an artificial ECG based on cross-channel averaging.
Expand All @@ -1133,9 +1165,14 @@ def find_bads_ecg(self, inst, ch_name=None, threshold=None, start=None,
The name of the channel to use for ECG peak detection.
The argument is mandatory if the dataset contains no ECG
channels.
threshold : float
The value above which a feature is classified as outlier. If
method is 'ctps', defaults to 0.25, else defaults to 3.0.
threshold : float | str
The value above which a feature is classified as outlier. If 'auto'
and method is 'ctps', automatically compute the threshold. If
'auto' and method is 'correlation', defaults to 3.0. The default
translates to 0.25 for 'ctps' and 3.0 for 'correlation' in version
0.21 but will change to 'auto' in version 0.22.

.. versionchanged:: 0.21
start : int | float | None
First sample to include. If float, data will be interpreted as
time in seconds. If None, data will be used from the first sample.
Expand All @@ -1162,8 +1199,8 @@ def find_bads_ecg(self, inst, ch_name=None, threshold=None, start=None,

.. versionadded:: 0.14.0
measure : {'zscore', "cor"}
Which method to use for finding outliers. "zscore" (default) is
the iterated Z-scoring method, and "cor" is an absolute raw
Which method to use for finding outliers. 'zscore' (default) is
the iterated Z-scoring method, and 'cor' is an absolute raw
correlation threshold with a range of 0 to 1.

.. versionadded:: 0.21
Expand All @@ -1174,7 +1211,8 @@ def find_bads_ecg(self, inst, ch_name=None, threshold=None, start=None,
ecg_idx : list of int
The indices of ECG related components.
scores : np.ndarray of float, shape (``n_components_``)
The correlation scores.
If method is 'ctps', the normalized Kuiper index scores. If method
is 'correlation', the correlation scores.

See Also
--------
Expand All @@ -1198,7 +1236,16 @@ def find_bads_ecg(self, inst, ch_name=None, threshold=None, start=None,

if method == 'ctps':
if threshold is None:
# Announce that the default of ``threshold`` changes in 0.22. The trailing
# space after "to" keeps the concatenated literals from reading `to"auto"`.
warn('The default for "threshold" will change from None to '
     '"auto" in version 0.22. To avoid this warning, '
     'explicitly set threshold to "auto".',
     DeprecationWarning)
threshold = 0.25
elif threshold == 'auto':
# TODO: defaults to 'auto' in v0.22
threshold = self._get_ctps_threshold()
logger.info('Using threshold: %.2f for CTPS ECG detection'
% threshold)
if isinstance(inst, BaseRaw):
sources = self.get_sources(create_ecg_epochs(
inst, ch_name, l_freq=l_freq, h_freq=h_freq,
Expand All @@ -1225,6 +1272,12 @@ def find_bads_ecg(self, inst, ch_name=None, threshold=None, start=None,
self.labels_['ecg/%s' % ch_name] = list(ecg_idx)
elif method == 'correlation':
if threshold is None:
# Announce that the default of ``threshold`` changes in 0.22. The trailing
# space after "to" keeps the concatenated literals from reading `to"auto"`.
warn('The default for "threshold" will change from None to '
     '"auto" in version 0.22. To avoid this warning, '
     'explicitly set threshold to "auto".',
     DeprecationWarning)
threshold = 3.0
elif threshold == 'auto':
threshold = 3.0
self.labels_['ecg'], scores = self._find_bads_ch(
inst, [ecg], threshold=threshold, start=start, stop=stop,
Expand Down
30 changes: 19 additions & 11 deletions mne/preprocessing/tests/test_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,10 +413,15 @@ def test_ica_additional(method):
ica.fit(raw, picks=[1, 2, 3, 4, 5], start=start, stop=stop2)
_assert_ica_attributes(ica)

# check Kuiper index threshold
assert_equal(ica._get_ctps_threshold(), 0.21)
# check deprecation warning
with pytest.warns(DeprecationWarning, match='The default for "threshold"'):
ica.find_bads_ecg(raw, threshold=None)
# check passing a ch_name to find_bads_ecg
with pytest.warns(RuntimeWarning, match='longer'):
_, scores_1 = ica.find_bads_ecg(raw)
_, scores_2 = ica.find_bads_ecg(raw, raw.ch_names[1])
_, scores_1 = ica.find_bads_ecg(raw, threshold='auto')
_, scores_2 = ica.find_bads_ecg(raw, raw.ch_names[1], threshold='auto')
assert scores_1[0] != scores_2[0]

# test corrmap
Expand Down Expand Up @@ -620,21 +625,22 @@ def f(x, y):
epochs_data = epochs.get_data().copy()

with pytest.warns(RuntimeWarning, match='longer'):
idx, scores = ica.find_bads_ecg(raw, method='ctps')
idx, scores = ica.find_bads_ecg(raw, method='ctps', threshold='auto')
assert_equal(len(scores), ica.n_components_)
with pytest.warns(RuntimeWarning, match='longer'):
idx, scores = ica.find_bads_ecg(raw, method='correlation')
idx, scores = ica.find_bads_ecg(raw, method='correlation',
threshold='auto')
assert_equal(len(scores), ica.n_components_)

with pytest.warns(RuntimeWarning, match='longer'):
idx, scores = ica.find_bads_eog(raw)
assert_equal(len(scores), ica.n_components_)

idx, scores = ica.find_bads_ecg(epochs, method='ctps')
idx, scores = ica.find_bads_ecg(epochs, method='ctps', threshold='auto')

assert_equal(len(scores), ica.n_components_)
pytest.raises(ValueError, ica.find_bads_ecg, epochs.average(),
method='ctps')
method='ctps', threshold='auto')
pytest.raises(ValueError, ica.find_bads_ecg, raw,
method='crazy-coupling')

Expand All @@ -651,7 +657,8 @@ def f(x, y):
idx, scores = ica.find_bads_eog(evoked, ch_name='MEG 1441')
assert_equal(len(scores), ica.n_components_)

idx, scores = ica.find_bads_ecg(evoked, method='correlation')
idx, scores = ica.find_bads_ecg(evoked, method='correlation',
threshold='auto')
assert_equal(len(scores), ica.n_components_)

assert_array_equal(raw_data, raw[:][0])
Expand Down Expand Up @@ -725,14 +732,14 @@ def f(x, y):
ica.fit(raw, picks=picks[:5])
_assert_ica_attributes(ica)
with pytest.warns(RuntimeWarning, match='longer'):
ica.find_bads_ecg(raw)
ica.find_bads_ecg(raw, threshold='auto')
ica.find_bads_eog(epochs, ch_name='MEG 0121')
assert_array_equal(raw_data, raw[:][0])

raw.drop_channels(['MEG 0122'])
pytest.raises(RuntimeError, ica.find_bads_eog, raw)
with pytest.warns(RuntimeWarning, match='longer'):
pytest.raises(RuntimeError, ica.find_bads_ecg, raw)
pytest.raises(RuntimeError, ica.find_bads_ecg, raw, threshold='auto')


@requires_sklearn
Expand Down Expand Up @@ -1070,7 +1077,8 @@ def test_ica_labels():
for key in ('ecg', 'ref_meg', 'ecg/ECG-MAG'):
assert key not in ica.labels_

ica.find_bads_ecg(raw, l_freq=None, h_freq=None, method='correlation')
ica.find_bads_ecg(raw, l_freq=None, h_freq=None, method='correlation',
threshold='auto')
picks = list(pick_types(raw.info, meg=False, ecg=True))
for idx, ch in enumerate(picks):
assert '{}/{}/{}'.format('ecg', idx, raw.ch_names[ch]) in ica.labels_
Expand Down Expand Up @@ -1101,7 +1109,7 @@ def test_ica_labels():
assert key in ica.labels_
assert 'ecg/ECG-MAG' not in ica.labels_

ica.find_bads_ecg(raw, l_freq=None, h_freq=None)
ica.find_bads_ecg(raw, l_freq=None, h_freq=None, threshold='auto')
for key in ('ecg', 'eog', 'ref_meg', 'ecg/ECG-MAG'):
assert key in ica.labels_

Expand Down
6 changes: 4 additions & 2 deletions tutorials/preprocessing/plot_40_artifact_correction_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@

ica.exclude = []
# find which ICs match the ECG pattern
ecg_indices, ecg_scores = ica.find_bads_ecg(raw, method='correlation')
ecg_indices, ecg_scores = ica.find_bads_ecg(raw, method='correlation',
threshold='auto')
ica.exclude = ecg_indices

# barplot of ICA component "ECG match" scores
Expand Down Expand Up @@ -405,7 +406,8 @@
new_ica.fit(filt_raw)

# find which ICs match the ECG pattern
ecg_indices, ecg_scores = new_ica.find_bads_ecg(raw, method='correlation')
ecg_indices, ecg_scores = new_ica.find_bads_ecg(raw, method='correlation',
threshold='auto')
new_ica.exclude = ecg_indices

# barplot of ICA component "ECG match" scores
Expand Down