Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Don't require specific order for fNIRS #10642

Merged
merged 8 commits into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ def _interpolate_bads_nirs(inst, method='nearest', exclude=(), verbose=None):
from scipy.spatial.distance import pdist, squareform
from mne.preprocessing.nirs import _validate_nirs_info

# Returns pick of all nirs and ensures channels are correctly ordered
picks_nirs = _validate_nirs_info(inst.info)
if len(picks_nirs) == 0:
if len(pick_types(inst.info, fnirs=True, exclude=())) == 0:
return

# Returns pick of all nirs and ensures channels are correctly ordered
picks_nirs = _validate_nirs_info(inst.info)
nirs_ch_names = [inst.info['ch_names'][p] for p in picks_nirs]
nirs_ch_names = [ch for ch in nirs_ch_names if ch not in exclude]
bads_nirs = [ch for ch in inst.info['bads'] if ch in nirs_ch_names]
Expand Down
25 changes: 24 additions & 1 deletion mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mne.coreg import create_default_subject
from mne.datasets import testing
from mne.fixes import has_numba, _compare_version
from mne.io import read_raw_fif, read_raw_ctf
from mne.io import read_raw_fif, read_raw_ctf, read_raw_nirx, read_raw_snirf
from mne.stats import cluster_level
from mne.utils import (_pl, _assert_no_instances, numerics, Bunch,
_check_qt_version, _TempDir)
Expand All @@ -48,6 +48,17 @@
ctf_dir = op.join(test_path, 'CTF')
fname_ctf_continuous = op.join(ctf_dir, 'testdata_ctf.ds')

nirx_path = test_path / 'NIRx'
snirf_path = test_path / 'SNIRF'
nirsport2 = nirx_path / 'nirsport_v2' / 'aurora_recording _w_short_and_acc'
nirsport2_snirf = (
snirf_path / 'NIRx' / 'NIRSport2' / '1.0.3' /
'2021-05-05_001.snirf')
nirsport2_2021_9 = nirx_path / 'nirsport_v2' / 'aurora_2021_9'
nirsport2_20219_snirf = (
snirf_path / 'NIRx' / 'NIRSport2' / '2021.9' /
'2021-10-01_002.snirf')

# data from mne.io.tests.data
base_dir = op.join(op.dirname(__file__), 'io', 'tests', 'data')
fname_raw_io = op.join(base_dir, 'test_raw.fif')
Expand Down Expand Up @@ -925,3 +936,15 @@ def run(nbexec=nbexec, code=code):

item.runtest = run
return


@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:')
@pytest.fixture(params=(
[nirsport2, nirsport2_snirf, testing._pytest_param()],
[nirsport2_2021_9, nirsport2_20219_snirf, testing._pytest_param()],
))
def nirx_snirf(request):
"""Return a (raw_nirx, raw_snirf) matched pair."""
pytest.importorskip('h5py')
return (read_raw_nirx(request.param[0], preload=True),
read_raw_snirf(request.param[1], preload=True))
13 changes: 8 additions & 5 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
# Lars Buitinck <L.J.Buitinck@uva.nl>
# License: BSD

import functools
import inspect
from math import log
import os
from pathlib import Path
import warnings

import numpy as np
Expand Down Expand Up @@ -72,10 +70,15 @@ def _median_complex(data, axis):


# helpers to get function arguments
def _get_args(function, varargs=False):
def _get_args(function, varargs=False, *,
exclude=('var_positional', 'var_keyword')):
params = inspect.signature(function).parameters
args = [key for key, param in params.items()
if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)]
# As of Python 3.10:
# https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
# POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD, VAR_POSITIONAL, KEYWORD_ONLY,
# VAR_KEYWORD
exclude = set(getattr(inspect.Parameter, ex.upper()) for ex in exclude)
args = [key for key, param in params.items() if param.kind not in exclude]
if varargs:
varargs = [param.name for param in params.values()
if param.kind == param.VAR_POSITIONAL]
Expand Down
20 changes: 0 additions & 20 deletions mne/io/nirx/nirx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ..utils import _mult_cal_one
from ..constants import FIFF
from ..meas_info import create_info, _format_dig_points
from ..pick import pick_types
from ...annotations import Annotations
from ..._freesurfer import get_mni_fiducials
from ...transforms import apply_trans, _get_trans
Expand Down Expand Up @@ -459,7 +458,6 @@ def __init__(self, fname, saturated, preload=False, verbose=None):
ch_names.append(list())
annot = Annotations(onset, duration, description, ch_names=ch_names)
self.set_annotations(annot)
self.pick(picks=_nirs_sort_idx(self.info))

def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
"""Read a segment of data from a file.
Expand Down Expand Up @@ -512,21 +510,3 @@ def _convert_fnirs_to_head(trans, fro, to, src_locs, det_locs, ch_locs):
det_locs = apply_trans(mri_head_t, det_locs)
ch_locs = apply_trans(mri_head_t, ch_locs)
return src_locs, det_locs, ch_locs, mri_head_t


def _nirs_sort_idx(info):
# TODO: Remove any actual reordering that is done and just use this
# function to get picks to operate on in an ordered way. This should be
# done by refactoring mne.preprocessing.nirs.nirs._check_channels_ordered
# and this function to make sure the picks we obtain here are in the
# correct order.
nirs_picks = pick_types(info, fnirs=True, exclude=())
other_picks = np.setdiff1d(np.arange(info['nchan']), nirs_picks)
prefixes = [info['ch_names'][pick].split()[0] for pick in nirs_picks]
nirs_names = [info['ch_names'][pick] for pick in nirs_picks]
nirs_sorted = sorted(nirs_names,
key=lambda name: (prefixes.index(name.split()[0]),
name.split(maxsplit=1)[1]))
nirs_picks = nirs_picks[
[nirs_names.index(name) for name in nirs_sorted]]
return np.concatenate((nirs_picks, other_picks))
27 changes: 6 additions & 21 deletions mne/io/nirx/tests/test_nirx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@

from mne import pick_types
from mne.datasets.testing import data_path, requires_testing_data
from mne.io import read_raw_nirx, read_raw_snirf
from mne.utils import requires_h5py
from mne.io import read_raw_nirx
from mne.io.tests.test_raw import _test_raw_reader
from mne.preprocessing import annotate_nan
from mne.transforms import apply_trans, _get_trans
from mne.preprocessing.nirs import source_detector_distances,\
short_channels
short_channels, _reorder_nirx
from mne.io.constants import FIFF

testing_path = data_path(download=False)
Expand All @@ -46,31 +45,17 @@
testing_path, 'NIRx', 'nirsport_v1', 'nirx_15_3_recording_w_'
'saturation_on_montage_channels')

# NIRSport2 device using Aurora software and matching snirf file
# NIRSport2 device using Aurora software
nirsport2 = op.join(
testing_path, 'NIRx', 'nirsport_v2', 'aurora_recording _w_short_and_acc')
nirsport2_snirf = op.join(
testing_path, 'SNIRF', 'NIRx', 'NIRSport2', '1.0.3',
'2021-05-05_001.snirf')

nirsport2_2021_9 = op.join(
testing_path, 'NIRx', 'nirsport_v2', 'aurora_2021_9')
snirf_nirsport2_20219 = op.join(
testing_path, 'SNIRF', 'NIRx', 'NIRSport2', '2021.9',
'2021-10-01_002.snirf')


@requires_h5py
@requires_testing_data
@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:')
@pytest.mark.parametrize('fname_nirx, fname_snirf', (
[nirsport2, nirsport2_snirf],
[nirsport2_2021_9, snirf_nirsport2_20219],
))
def test_nirsport_v2_matches_snirf(fname_nirx, fname_snirf):
def test_nirsport_v2_matches_snirf(nirx_snirf):
"""Test NIRSport2 raw files return same data as snirf."""
raw = read_raw_nirx(fname_nirx, preload=True)
raw_snirf = read_raw_snirf(fname_snirf, preload=True)
raw, raw_snirf = nirx_snirf
_reorder_nirx(raw_snirf)
assert raw.ch_names == raw_snirf.ch_names

assert_allclose(raw._data, raw_snirf._data)
Expand Down
6 changes: 1 addition & 5 deletions mne/io/snirf/_snirf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..constants import FIFF
from .._digitization import _make_dig_points
from ...transforms import _frame_to_str, apply_trans
from ..nirx.nirx import _convert_fnirs_to_head, _nirs_sort_idx
from ..nirx.nirx import _convert_fnirs_to_head
from ..._freesurfer import get_mni_fiducials


Expand Down Expand Up @@ -409,10 +409,6 @@ def natural_keys(text):
annot.append(data[:, 0], 1.0, desc.decode('UTF-8'))
self.set_annotations(annot, emit_warning=False)

# MNE requires channels are paired as alternating wavelengths
if len(_validate_nirs_info(self.info, throw_errors=False)) == 0:
self.pick(picks=_nirs_sort_idx(self.info))

# Validate that the fNIRS info is correctly formatted
_validate_nirs_info(self.info)

Expand Down
46 changes: 22 additions & 24 deletions mne/io/snirf/tests/test_snirf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from mne.io import read_raw_snirf, read_raw_nirx
from mne.io.tests.test_raw import _test_raw_reader
from mne.preprocessing.nirs import (optical_density, beer_lambert_law,
short_channels, source_detector_distances)
short_channels, source_detector_distances,
_reorder_nirx)
from mne.transforms import apply_trans, _get_trans
from mne.io.constants import FIFF

Expand Down Expand Up @@ -93,23 +94,18 @@ def test_snirf_gowerlabs():
assert len(raw.ch_names) == 216
assert_allclose(raw.info['sfreq'], 10.0)
# we don't force them to be sorted according to a naive split
# (but we do force them to be interleaved, which is tested by beer_lambert
# above)
assert raw.ch_names != sorted(raw.ch_names)
# ... and this file does have a nice logical ordering already
# ... but this file does have a nice logical ordering already
print(raw.ch_names)
assert raw.ch_names == sorted(
raw.ch_names, # use a key which is (source int, detector int)
key=lambda name: (int(name.split()[0].split('_')[0][1:]),
int(name.split()[0].split('_')[1][1:])))
prefixes = [name.split()[0] for name in raw.ch_names]
# TODO: This is actually not the order on disk -- we reorder to ravel as
# S-D then freq, but gowerlabs order is freq then S-D. So hopefully soon
# we can change these lines to check that the first half of prefixes
# matches the second half of prefixes, rather than every-other matching the
# other every-other
assert prefixes[::2] == prefixes[1::2]
prefixes = prefixes[::2]
assert prefixes == ['S1_D1', 'S1_D2', 'S1_D3', 'S1_D4', 'S1_D5', 'S1_D6', 'S1_D7', 'S1_D8', 'S1_D9', 'S1_D10', 'S1_D11', 'S1_D12', 'S2_D1', 'S2_D2', 'S2_D3', 'S2_D4', 'S2_D5', 'S2_D6', 'S2_D7', 'S2_D8', 'S2_D9', 'S2_D10', 'S2_D11', 'S2_D12', 'S3_D1', 'S3_D2', 'S3_D3', 'S3_D4', 'S3_D5', 'S3_D6', 'S3_D7', 'S3_D8', 'S3_D9', 'S3_D10', 'S3_D11', 'S3_D12', 'S4_D1', 'S4_D2', 'S4_D3', 'S4_D4', 'S4_D5', 'S4_D6', 'S4_D7', 'S4_D8', 'S4_D9', 'S4_D10', 'S4_D11', 'S4_D12', 'S5_D1', 'S5_D2', 'S5_D3', 'S5_D4', 'S5_D5', 'S5_D6', 'S5_D7', 'S5_D8', 'S5_D9', 'S5_D10', 'S5_D11', 'S5_D12', 'S6_D1', 'S6_D2', 'S6_D3', 'S6_D4', 'S6_D5', 'S6_D6', 'S6_D7', 'S6_D8', 'S6_D9', 'S6_D10', 'S6_D11', 'S6_D12', 'S7_D1', 'S7_D2', 'S7_D3', 'S7_D4', 'S7_D5', 'S7_D6', 'S7_D7', 'S7_D8', 'S7_D9', 'S7_D10', 'S7_D11', 'S7_D12', 'S8_D1', 'S8_D2', 'S8_D3', 'S8_D4', 'S8_D5', 'S8_D6', 'S8_D7', 'S8_D8', 'S8_D9', 'S8_D10', 'S8_D11', 'S8_D12', 'S9_D1', 'S9_D2', 'S9_D3', 'S9_D4', 'S9_D5', 'S9_D6', 'S9_D7', 'S9_D8', 'S9_D9', 'S9_D10', 'S9_D11', 'S9_D12'] # noqa: E501
raw.ch_names,
# use a key which is (src triplet, freq, src, freq, det)
key=lambda name: (
(int(name.split()[0].split('_')[0][1:]) - 1) // 3,
int(name.split()[1]),
int(name.split()[0].split('_')[0][1:]),
int(name.split()[0].split('_')[1][1:])
))


@requires_testing_data
Expand All @@ -122,13 +118,13 @@ def test_snirf_basic():
assert raw.info['sfreq'] == 12.5

# Test channel naming
assert raw.info['ch_names'][:4] == ["S1_D1 760", "S1_D1 850",
"S1_D9 760", "S1_D9 850"]
assert raw.info['ch_names'][24:26] == ["S5_D13 760", "S5_D13 850"]
assert raw.info['ch_names'][:4] == ["S1_D1 760", "S1_D9 760",
"S2_D3 760", "S2_D10 760"]
assert raw.info['ch_names'][24:26] == ['S5_D8 850', 'S5_D13 850']

# Test frequency encoding
assert raw.info['chs'][0]['loc'][9] == 760
assert raw.info['chs'][1]['loc'][9] == 850
assert raw.info['chs'][24]['loc'][9] == 850

# Test source locations
assert_allclose([-8.6765 * 1e-2, 0.0049 * 1e-2, -2.6167 * 1e-2],
Expand Down Expand Up @@ -159,6 +155,7 @@ def test_snirf_basic():
def test_snirf_against_nirx():
"""Test against file snirf was created from."""
raw = read_raw_snirf(sfnirs_homer_103_wShort, preload=True)
_reorder_nirx(raw)
raw_orig = read_raw_nirx(sfnirs_homer_103_wShort_original, preload=True)

# Check annotations are the same
Expand Down Expand Up @@ -225,13 +222,13 @@ def test_snirf_nirsport2():
assert_almost_equal(raw.info['sfreq'], 7.6, decimal=1)

# Test channel naming
assert raw.info['ch_names'][:4] == ['S1_D1 760', 'S1_D1 850',
'S1_D3 760', 'S1_D3 850']
assert raw.info['ch_names'][24:26] == ['S6_D4 760', 'S6_D4 850']
assert raw.info['ch_names'][:4] == ['S1_D1 760', 'S1_D3 760',
'S1_D9 760', 'S1_D16 760']
assert raw.info['ch_names'][24:26] == ['S8_D15 760', 'S8_D20 760']

# Test frequency encoding
assert raw.info['chs'][0]['loc'][9] == 760
assert raw.info['chs'][1]['loc'][9] == 850
assert raw.info['chs'][-1]['loc'][9] == 850

assert sum(short_channels(raw.info)) == 16

Expand All @@ -257,6 +254,7 @@ def test_snirf_nirsport2_w_positions():
"""Test reading SNIRF files with known positions."""
raw = read_raw_snirf(nirx_nirsport2_103_2, preload=True,
optode_frame="mri")
_reorder_nirx(raw)

# Test data import
assert raw._data.shape == (40, 128)
Expand Down
6 changes: 5 additions & 1 deletion mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,13 @@ def get_score_funcs():
score_funcs.update({n: _make_xy_sfunc(f)
for n, f in xy_arg_dist_funcs
if _get_args(f) == ['u', 'v']})
# In SciPy 1.9+, pearsonr has (u, v, *, alternative='two-sided'), so we
# should just look at the positional_only and positional_or_keyword entries
exclude = ('var_positional', 'var_keyword', 'keyword_only')
score_funcs.update({n: _make_xy_sfunc(f, ndim_output=True)
for n, f in xy_arg_stats_funcs
if _get_args(f) == ['x', 'y']})
if _get_args(f, exclude=exclude) == ['x', 'y']})
assert 'pearsonr' in score_funcs
return score_funcs


Expand Down
5 changes: 3 additions & 2 deletions mne/preprocessing/nirs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

from .nirs import (short_channels, source_detector_distances,
_check_channels_ordered, _channel_frequencies,
_fnirs_check_bads, _fnirs_spread_bads, _channel_chromophore,
_validate_nirs_info, _fnirs_optode_names, _optode_position)
_fnirs_spread_bads, _channel_chromophore,
_validate_nirs_info, _fnirs_optode_names, _optode_position,
_reorder_nirx)
from ._optical_density import optical_density
from ._beer_lambert_law import beer_lambert_law
from ._scalp_coupling_index import scalp_coupling_index
Expand Down
20 changes: 10 additions & 10 deletions mne/preprocessing/nirs/_beer_lambert_law.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from ...io import BaseRaw
from ...io.constants import FIFF
from ...utils import _validate_type, warn
from ..nirs import source_detector_distances, _channel_frequencies,\
_check_channels_ordered, _channel_chromophore
from ..nirs import source_detector_distances, _validate_nirs_info


def beer_lambert_law(raw, ppf=6.):
Expand All @@ -35,8 +34,10 @@ def beer_lambert_law(raw, ppf=6.):
_validate_type(raw, BaseRaw, 'raw')
_validate_type(ppf, 'numeric', 'ppf')
ppf = float(ppf)
freqs = np.unique(_channel_frequencies(raw.info, nominal=True))
picks = _check_channels_ordered(raw.info, freqs)
picks = _validate_nirs_info(raw.info, fnirs='od', which='Beer-lambert')
# This is the one place we *really* need the actual/accurate frequencies
freqs = np.array(
[raw.info['chs'][pick]['loc'][9] for pick in picks], float)
abs_coef = _load_absorption(freqs)
distances = source_detector_distances(raw.info)
if (distances == 0).any():
Expand All @@ -49,25 +50,24 @@ def beer_lambert_law(raw, ppf=6.):
'likely due to optode locations being stored in a '
' unit other than meters.')
rename = dict()
for ii in picks[::2]:
for ii, jj in zip(picks[::2], picks[1::2]):
EL = abs_coef * distances[ii] * ppf
iEL = linalg.pinv(EL)

raw._data[[ii, ii + 1]] = iEL @ raw._data[[ii, ii + 1]] * 1e-3
raw._data[[ii, jj]] = iEL @ raw._data[[ii, jj]] * 1e-3

# Update channel information
coil_dict = dict(hbo=FIFF.FIFFV_COIL_FNIRS_HBO,
hbr=FIFF.FIFFV_COIL_FNIRS_HBR)
for ki, kind in enumerate(('hbo', 'hbr')):
ch = raw.info['chs'][ii + ki]
for ki, kind in zip((ii, jj), ('hbo', 'hbr')):
ch = raw.info['chs'][ki]
ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL)
new_name = f'{ch["ch_name"].split(" ")[0]} {kind}'
rename[ch['ch_name']] = new_name
raw.rename_channels(rename)

# Validate the format of data after transformation is valid
chroma = np.unique(_channel_chromophore(raw.info))
_check_channels_ordered(raw.info, chroma)
_validate_nirs_info(raw.info, fnirs='hb')
return raw


Expand Down
Loading