Skip to content

Commit

Permalink
BUG: Fix bug with CSP computation and Maxwell filter (#890)
Browse files Browse the repository at this point in the history
Co-authored-by: Richard Höchenberger <richard.hoechenberger@gmail.com>
  • Loading branch information
larsoner and hoechenberger authored Mar 15, 2024
1 parent 1c2a081 commit 45e4c13
Show file tree
Hide file tree
Showing 19 changed files with 209 additions and 186 deletions.
2 changes: 1 addition & 1 deletion docs/source/examples/gen_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
)
if dataset_name in all_demonstrated:
logger.warning(
f"Duplicate dataset name {test_dataset_name} -> {dataset_name}, " "skipping"
f"Duplicate dataset name {test_dataset_name} -> {dataset_name}, skipping"
)
continue
del test_dataset_options, test_dataset_name
Expand Down
9 changes: 5 additions & 4 deletions docs/source/v1.8.md.inc
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
## v1.8.0 (unreleased)

[//]: # (### :new: New features & enhancements)
### :new: New features & enhancements

[//]: # (- Whatever (#000 by @whoever))
- Disabling CSP time-frequency mode is now supported by passing an empty list to [`decoding_csp_times`][mne_bids_pipeline._config.decoding_csp_times] (#890 by @whoever)

[//]: # (### :warning: Behavior changes)

Expand All @@ -14,9 +14,10 @@

[//]: # (- Whatever (#000 by @whoever))

[//]: # (### :bug: Bug fixes)
### :bug: Bug fixes

[//]: # (- Whatever (#000 by @whoever))
- Fix handling of Maxwell-filtered data in CSP (#890 by @larsoner)
- Avoid recomputation / cache miss when the same empty-room file is matched to multiple subjects (#890 by @larsoner)

### :medical_symbol: Code health

Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,7 +1611,7 @@
Must contain at least two elements. By default, 5 equally-spaced bins are
created across the non-negative time range of the epochs.
All specified time points must be contained in the epochs interval.
If `None`, do not perform **time-frequency** analysis, and only run CSP on
If an empty list, do not perform **time-frequency** analysis, and only run CSP on
**frequency** data.
???+ example "Example"
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_config_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def _check_misspellings_removals(
if user_name not in valid_names:
# find the closest match
closest_match = difflib.get_close_matches(user_name, valid_names, n=1)
msg = f"Found a variable named {repr(user_name)} in your custom " "config,"
msg = f"Found a variable named {repr(user_name)} in your custom config,"
if closest_match and closest_match[0] not in user_names:
this_msg = (
f"{msg} did you mean {repr(closest_match[0])}? "
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def get_mf_ctc_fname(
root=config.bids_root,
).meg_crosstalk_fpath
if mf_ctc_fpath is None:
raise ValueError("Could not find Maxwell Filter cross-talk " "file.")
raise ValueError("Could not find Maxwell Filter cross-talk file.")
else:
mf_ctc_fpath = pathlib.Path(config.mf_ctc_fname).expanduser().absolute()
if not mf_ctc_fpath.exists():
Expand Down
15 changes: 12 additions & 3 deletions mne_bids_pipeline/_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,19 @@ def _handle_csp_args(
)
if decoding_csp_times is None:
decoding_csp_times = np.linspace(max(0, epochs_tmin), epochs_tmax, num=6)
if len(decoding_csp_times) < 2:
raise ValueError("decoding_csp_times should contain at least 2 values.")
else:
decoding_csp_times = np.array(decoding_csp_times, float)
if decoding_csp_times.ndim != 1 or len(decoding_csp_times) == 1:
raise ValueError(
"decoding_csp_times should be 1 dimensional and contain at least 2 values "
"to define time intervals, or be empty to disable time-frequency mode, got "
f"shape {decoding_csp_times.shape}"
)
if not np.array_equal(decoding_csp_times, np.sort(decoding_csp_times)):
ValueError("decoding_csp_times should be sorted.")
time_bins = np.c_[decoding_csp_times[:-1], decoding_csp_times[1:]]
assert time_bins.ndim == 2 and time_bins.shape[1] == 2, time_bins.shape

if decoding_metric != "roc_auc":
raise ValueError(
f'CSP decoding currently only supports the "roc_auc" '
Expand Down Expand Up @@ -76,7 +85,7 @@ def _handle_csp_args(

freq_bins = list(zip(edges[:-1], edges[1:]))
freq_name_to_bins_map[freq_range_name] = freq_bins
return freq_name_to_bins_map
return freq_name_to_bins_map, time_bins


def _decoding_preproc_steps(
Expand Down
40 changes: 24 additions & 16 deletions mne_bids_pipeline/_import_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_runs,
get_task,
)
from ._io import _empty_room_match_path, _read_json
from ._io import _read_json
from ._logging import gen_log_kwargs, logger
from ._run import _update_for_splits
from .typing import PathLike
Expand Down Expand Up @@ -403,6 +403,7 @@ def import_experimental_data(
_fix_stim_artifact_func(cfg=cfg, raw=raw)

if bids_path_bads_in is not None:
run = "rest" if data_is_rest else run # improve logging
bads = _read_bads_tsv(cfg=cfg, bids_path_bads=bids_path_bads_in)
msg = f"Marking {len(bads)} channel{_pl(bads)} as bad."
logger.info(**gen_log_kwargs(message=msg))
Expand Down Expand Up @@ -585,6 +586,8 @@ def _get_run_path(
add_bads=add_bads,
kind=kind,
allow_missing=allow_missing,
subject=subject,
session=session,
)


Expand Down Expand Up @@ -651,6 +654,8 @@ def _get_noise_path(
add_bads=add_bads,
kind=kind,
allow_missing=True,
subject=subject,
session=session,
)


Expand Down Expand Up @@ -701,6 +706,12 @@ def _get_mf_reference_run_path(
)


def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath:
return run_path.copy().update(
extension=".json", suffix="emptyroommatch", root=cfg.deriv_root
)


def _path_dict(
*,
cfg: SimpleNamespace,
Expand All @@ -709,6 +720,8 @@ def _path_dict(
kind: Literal["orig", "sss", "filt"],
allow_missing: bool,
key: Optional[str] = None,
subject: str,
session: Optional[str],
) -> dict:
if add_bads is None:
add_bads = kind == "orig" and _do_mf_autobad(cfg=cfg)
Expand All @@ -719,35 +732,30 @@ def _path_dict(
if allow_missing and not in_files[key].fpath.exists():
return dict()
if add_bads:
bads_tsv_fname = _bads_path(cfg=cfg, bids_path_in=bids_path_in)
bads_tsv_fname = _bads_path(
cfg=cfg,
bids_path_in=bids_path_in,
subject=subject,
session=session,
)
if bads_tsv_fname.fpath.is_file() or not allow_missing:
in_files[f"{key}-bads"] = bads_tsv_fname
return in_files


def _auto_scores_path(
*,
cfg: SimpleNamespace,
bids_path_in: BIDSPath,
) -> BIDSPath:
return bids_path_in.copy().update(
suffix="scores",
extension=".json",
root=cfg.deriv_root,
split=None,
check=False,
)


def _bads_path(
*,
cfg: SimpleNamespace,
bids_path_in: BIDSPath,
subject: str,
session: Optional[str],
) -> BIDSPath:
return bids_path_in.copy().update(
suffix="bads",
extension=".tsv",
root=cfg.deriv_root,
subject=subject,
session=session,
split=None,
check=False,
)
Expand Down
9 changes: 0 additions & 9 deletions mne_bids_pipeline/_io.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""I/O helpers."""

from types import SimpleNamespace

import json_tricks
from mne_bids import BIDSPath

from .typing import PathLike

Expand All @@ -16,9 +13,3 @@ def _write_json(fname: PathLike, data: dict) -> None:
def _read_json(fname: PathLike) -> dict:
with open(fname, encoding="utf-8") as f:
return json_tricks.load(f)


def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath:
return run_path.copy().update(
extension=".json", suffix="emptyroommatch", root=cfg.deriv_root
)
6 changes: 4 additions & 2 deletions mne_bids_pipeline/_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,14 +594,14 @@ def add_csp_grand_average(
cond_1: str,
cond_2: str,
fname_csp_freq_results: BIDSPath,
fname_csp_cluster_results: pd.DataFrame,
fname_csp_cluster_results: Optional[pd.DataFrame],
):
"""Add CSP decoding results to the grand average report."""
import matplotlib.pyplot as plt # nested import to help joblib

# First, plot decoding scores across frequency bins (entire epochs).
section = "Decoding: CSP"
freq_name_to_bins_map = _handle_csp_args(
freq_name_to_bins_map, _ = _handle_csp_args(
cfg.decoding_csp_times,
cfg.decoding_csp_freqs,
cfg.decoding_metric,
Expand Down Expand Up @@ -684,6 +684,8 @@ def add_csp_grand_average(
)

# Now, plot decoding scores across time-frequency bins.
if fname_csp_cluster_results is None:
return
csp_cluster_results = loadmat(fname_csp_cluster_results)
fig, ax = plt.subplots(
nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True
Expand Down
11 changes: 5 additions & 6 deletions mne_bids_pipeline/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs):
# Find the limit / step where the error occurred
step_dir = pathlib.Path(__file__).parent / "steps"
tb = traceback.extract_tb(e.__traceback__)
for fi, frame in enumerate(inspect.stack()):
for fi, frame in enumerate(tb):
is_step = pathlib.Path(frame.filename).parent.parent == step_dir
del frame
if is_step:
# omit everything before the "step" dir, which will
# generally be stuff from this file and joblib
tb = tb[-fi:]
tb = tb[fi:]
break
tb = "".join(traceback.format_list(tb))

Expand Down Expand Up @@ -221,17 +221,16 @@ def wrapper(*args, **kwargs):
for key, (fname, this_hash) in out_files_hashes.items():
fname = pathlib.Path(fname)
if not fname.exists():
msg = (
f"Output file missing {str(fname)}, " "will recompute …"
)
msg = f"Output file missing: {fname}, will recompute …"
emoji = "🧩"
bad_out_files = True
break
got_hash = hash_(key, fname, kind="out")[1]
if this_hash != got_hash:
msg = (
f"Output file {self.memory_file_method} mismatch for "
f"{str(fname)}, will recompute …"
f"{fname} ({this_hash} != {got_hash}), will "
"recompute …"
)
emoji = "🚫"
bad_out_files = True
Expand Down
3 changes: 2 additions & 1 deletion mne_bids_pipeline/steps/init/_02_find_empty_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
get_sessions,
get_subjects,
)
from ..._io import _empty_room_match_path, _write_json
from ..._import_data import _empty_room_match_path
from ..._io import _write_json
from ..._logging import gen_log_kwargs, logger
from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs

Expand Down
18 changes: 11 additions & 7 deletions mne_bids_pipeline/steps/preprocessing/_01_data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
get_subjects,
)
from ..._import_data import (
_auto_scores_path,
_bads_path,
_get_mf_reference_run_path,
_get_run_rest_noise_path,
Expand Down Expand Up @@ -159,7 +158,7 @@ def _find_bads_maxwell(
elif cfg.find_noisy_channels_meg and not cfg.find_flat_channels_meg:
msg = "Finding noisy channels using Maxwell filtering."
else:
msg = "Finding flat channels and noisy channels using " "Maxwell filtering."
msg = "Finding flat channels and noisy channels using Maxwell filtering."
logger.info(**gen_log_kwargs(message=msg))

if run is None and task == "noise":
Expand Down Expand Up @@ -232,18 +231,23 @@ def _find_bads_maxwell(
logger.info(**gen_log_kwargs(message=msg))

if cfg.find_noisy_channels_meg:
out_files["auto_scores"] = _auto_scores_path(
cfg=cfg,
bids_path_in=bids_path_in,
out_files["auto_scores"] = bids_path_in.copy().update(
suffix="scores",
extension=".json",
root=cfg.deriv_root,
split=None,
check=False,
session=session,
subject=subject,
)
if not out_files["auto_scores"].fpath.parent.exists():
out_files["auto_scores"].fpath.parent.mkdir(parents=True)
_write_json(out_files["auto_scores"], auto_scores)

# Write the bad channels to disk.
out_files["bads_tsv"] = _bads_path(
cfg=cfg,
bids_path_in=bids_path_in,
subject=subject,
session=session,
)
bads_for_tsv = []
reasons = []
Expand Down
Loading

0 comments on commit 45e4c13

Please sign in to comment.