Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Changelog

- Added temporal derivative distribution repair :func:`mne.preprocessing.nirs.temporal_derivative_distribution_repair` by `Robert Luke`_

- Add functionality to interpolate bad NIRS channels by `Robert Luke`_

Bug
~~~
Expand Down
4 changes: 3 additions & 1 deletion mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,8 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
.. versionadded:: 0.9.0
"""
from ..bem import _check_origin
from .interpolation import _interpolate_bads_eeg, _interpolate_bads_meg
from .interpolation import _interpolate_bads_eeg,\
_interpolate_bads_meg, _interpolate_bads_nirs

_check_preload(self, "interpolation")

Expand All @@ -1003,6 +1004,7 @@ def interpolate_bads(self, reset_bads=True, mode='accurate',
origin = _check_origin(origin, self.info)
_interpolate_bads_eeg(self, origin=origin)
_interpolate_bads_meg(self, mode=mode, origin=origin)
_interpolate_bads_nirs(self)

if reset_bads is True:
self.info['bads'] = []
Expand Down
56 changes: 56 additions & 0 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..io.pick import pick_types, pick_channels, pick_info
from ..surface import _normalize_vectors
from ..forward import _map_meg_channels
from ..utils import _check_option


def _calc_h(cosang, stiffness=4, n_legendre_terms=50):
Expand Down Expand Up @@ -213,3 +214,58 @@ def _interpolate_bads_meg(inst, mode='accurate', origin=(0., 0., 0.04),
info_to = pick_info(inst.info, picks_bad)
mapping = _map_meg_channels(info_from, info_to, mode=mode, origin=origin)
_do_interp_dots(inst, mapping, picks_good, picks_bad)


@verbose
def _interpolate_bads_nirs(inst, method='nearest', verbose=None):
"""Interpolate bad nirs channels. Simply replaces by closest non bad.

Parameters
----------
inst : mne.io.Raw, mne.Epochs or mne.Evoked
The data to interpolate. Must be preloaded.
method : str
Only the method 'nearest' is currently available. This method replaces
each bad channel with the nearest non bad channel.
%(verbose)s
"""
from scipy.spatial.distance import pdist, squareform
from mne.preprocessing.nirs import _channel_frequencies,\
_check_channels_ordered

# Returns pick of all nirs and ensures channels are correctly ordered
freqs = np.unique(_channel_frequencies(inst))
picks_nirs = _check_channels_ordered(inst, freqs)
if len(picks_nirs) == 0:
return

nirs_ch_names = [inst.info['ch_names'][p] for p in picks_nirs]
bads_nirs = [ch for ch in inst.info['bads'] if ch in nirs_ch_names]
if len(bads_nirs) == 0:
return
picks_bad = pick_channels(inst.info['ch_names'], bads_nirs, exclude=[])
bads_mask = [p in picks_bad for p in picks_nirs]

chs = [inst.info['chs'][i] for i in picks_nirs]
locs3d = np.array([ch['loc'][:3] for ch in chs])

_check_option('fnirs_method', method, ['nearest'])

if method == 'nearest':

dist = pdist(locs3d)
dist = squareform(dist)

for bad in picks_bad:
dists_to_bad = dist[bad]
# Ignore distances to self
dists_to_bad[dists_to_bad == 0] = np.inf
# Ignore distances to other bad channels
dists_to_bad[bads_mask] = np.inf
# Find closest remaining channels for same frequency
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
inst._data[bad] = inst._data[closest_idx]

inst.info['bads'] = []

return inst
21 changes: 21 additions & 0 deletions mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
import numpy as np
from numpy.testing import (assert_allclose, assert_array_equal)
import pytest
from itertools import compress

from mne import io, pick_types, pick_channels, read_events, Epochs
from mne.channels.interpolation import _make_interpolation_matrix
from mne.datasets import testing
from mne.utils import run_tests_if_main
from mne.preprocessing.nirs import optical_density, scalp_coupling_index
from mne.datasets.testing import data_path
from mne.io import read_raw_nirx

base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data')
raw_fname = op.join(base_dir, 'test_raw.fif')
Expand Down Expand Up @@ -219,4 +223,21 @@ def test_interpolation_ctf_comp():
assert raw.info['bads'] == []


@testing.requires_testing_data
def test_interpolation_nirs():
"""Test interpolating bad nirs channels."""
fname = op.join(data_path(download=False),
'NIRx', 'nirx_15_2_recording_w_overlap')
raw_intensity = read_raw_nirx(fname, preload=False)
raw_od = optical_density(raw_intensity)
sci = scalp_coupling_index(raw_od)
raw_od.info['bads'] = list(compress(raw_od.ch_names, sci < 0.5))
bad_0 = np.where([name == raw_od.info['bads'][0] for
name in raw_od.ch_names])[0][0]
bad_0_std_pre_interp = np.std(raw_od._data[bad_0])
raw_od.interpolate_bads()
assert raw_od.info['bads'] == []
assert bad_0_std_pre_interp > np.std(raw_od._data[bad_0])


run_tests_if_main()
4 changes: 2 additions & 2 deletions mne/preprocessing/nirs/nirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def short_channels(info, threshold=0.01):

def _channel_frequencies(raw):
"""Return the light frequency for each channel."""
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[], allow_empty=True)
freqs = np.empty(picks.size, int)
for ii in picks:
freqs[ii] = raw.info['chs'][ii]['loc'][9]
Expand All @@ -69,7 +69,7 @@ def _check_channels_ordered(raw, freqs):
"""Check channels followed expected fNIRS format."""
# Every second channel should be same SD pair
# and have the specified light frequencies.
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[])
picks = _picks_to_idx(raw.info, 'fnirs', exclude=[], allow_empty=True)
if len(picks) % 2 != 0:
raise ValueError(
'NIRS channels not ordered correctly. An even number of NIRS '
Expand Down