Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix bug with CSP computation and Maxwell filter #890

Merged
merged 11 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/examples/gen_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
)
if dataset_name in all_demonstrated:
logger.warning(
f"Duplicate dataset name {test_dataset_name} -> {dataset_name}, " "skipping"
f"Duplicate dataset name {test_dataset_name} -> {dataset_name}, skipping"
)
continue
del test_dataset_options, test_dataset_name
Expand Down
5 changes: 3 additions & 2 deletions docs/source/v1.8.md.inc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

[//]: # (- Whatever (#000 by @whoever))

[//]: # (### :bug: Bug fixes)
### :bug: Bug fixes

[//]: # (- Whatever (#000 by @whoever))
- Fix handling of Maxwell filtered data in CSP (#890 by @larsoner)
- Avoid recomputation / cache miss when the same empty room file is matched to multiple subjects (#890 by @larsoner)
larsoner marked this conversation as resolved.
Show resolved Hide resolved

### :medical_symbol: Code health

Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,7 +1611,7 @@
Must contain at least two elements. By default, 5 equally-spaced bins are
created across the non-negative time range of the epochs.
All specified time points must be contained in the epochs interval.
If `None`, do not perform **time-frequency** analysis, and only run CSP on
If an empty list, do not perform **time-frequency** analysis, and only run CSP on
larsoner marked this conversation as resolved.
Show resolved Hide resolved
**frequency** data.

???+ example "Example"
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_config_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def _check_misspellings_removals(
if user_name not in valid_names:
# find the closest match
closest_match = difflib.get_close_matches(user_name, valid_names, n=1)
msg = f"Found a variable named {repr(user_name)} in your custom " "config,"
msg = f"Found a variable named {repr(user_name)} in your custom config,"
if closest_match and closest_match[0] not in user_names:
this_msg = (
f"{msg} did you mean {repr(closest_match[0])}? "
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def get_mf_ctc_fname(
root=config.bids_root,
).meg_crosstalk_fpath
if mf_ctc_fpath is None:
raise ValueError("Could not find Maxwell Filter cross-talk " "file.")
raise ValueError("Could not find Maxwell Filter cross-talk file.")
else:
mf_ctc_fpath = pathlib.Path(config.mf_ctc_fname).expanduser().absolute()
if not mf_ctc_fpath.exists():
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _handle_csp_args(
)
if decoding_csp_times is None:
decoding_csp_times = np.linspace(max(0, epochs_tmin), epochs_tmax, num=6)
if len(decoding_csp_times) < 2:
if len(decoding_csp_times) and len(decoding_csp_times) < 2:
larsoner marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("decoding_csp_times should contain at least 2 values.")
if not np.array_equal(decoding_csp_times, np.sort(decoding_csp_times)):
ValueError("decoding_csp_times should be sorted.")
Expand Down
40 changes: 24 additions & 16 deletions mne_bids_pipeline/_import_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_runs,
get_task,
)
from ._io import _empty_room_match_path, _read_json
from ._io import _read_json
from ._logging import gen_log_kwargs, logger
from ._run import _update_for_splits
from .typing import PathLike
Expand Down Expand Up @@ -403,6 +403,7 @@ def import_experimental_data(
_fix_stim_artifact_func(cfg=cfg, raw=raw)

if bids_path_bads_in is not None:
run = "rest" if data_is_rest else run # improve logging
bads = _read_bads_tsv(cfg=cfg, bids_path_bads=bids_path_bads_in)
msg = f"Marking {len(bads)} channel{_pl(bads)} as bad."
logger.info(**gen_log_kwargs(message=msg))
Expand Down Expand Up @@ -585,6 +586,8 @@ def _get_run_path(
add_bads=add_bads,
kind=kind,
allow_missing=allow_missing,
subject=subject,
session=session,
)


Expand Down Expand Up @@ -651,6 +654,8 @@ def _get_noise_path(
add_bads=add_bads,
kind=kind,
allow_missing=True,
subject=subject,
session=session,
)


Expand Down Expand Up @@ -701,6 +706,12 @@ def _get_mf_reference_run_path(
)


def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath:
return run_path.copy().update(
extension=".json", suffix="emptyroommatch", root=cfg.deriv_root
)


def _path_dict(
*,
cfg: SimpleNamespace,
Expand All @@ -709,6 +720,8 @@ def _path_dict(
kind: Literal["orig", "sss", "filt"],
allow_missing: bool,
key: Optional[str] = None,
subject: str,
session: Optional[str],
) -> dict:
if add_bads is None:
add_bads = kind == "orig" and _do_mf_autobad(cfg=cfg)
Expand All @@ -719,35 +732,30 @@ def _path_dict(
if allow_missing and not in_files[key].fpath.exists():
return dict()
if add_bads:
bads_tsv_fname = _bads_path(cfg=cfg, bids_path_in=bids_path_in)
bads_tsv_fname = _bads_path(
cfg=cfg,
bids_path_in=bids_path_in,
subject=subject,
session=session,
)
if bads_tsv_fname.fpath.is_file() or not allow_missing:
in_files[f"{key}-bads"] = bads_tsv_fname
return in_files


def _auto_scores_path(
*,
cfg: SimpleNamespace,
bids_path_in: BIDSPath,
) -> BIDSPath:
return bids_path_in.copy().update(
suffix="scores",
extension=".json",
root=cfg.deriv_root,
split=None,
check=False,
)


def _bads_path(
*,
cfg: SimpleNamespace,
bids_path_in: BIDSPath,
subject: str,
session: Optional[str],
) -> BIDSPath:
return bids_path_in.copy().update(
suffix="bads",
extension=".tsv",
root=cfg.deriv_root,
subject=subject,
session=session,
split=None,
check=False,
)
Expand Down
9 changes: 0 additions & 9 deletions mne_bids_pipeline/_io.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""I/O helpers."""

from types import SimpleNamespace

import json_tricks
from mne_bids import BIDSPath

from .typing import PathLike

Expand All @@ -16,9 +13,3 @@ def _write_json(fname: PathLike, data: dict) -> None:
def _read_json(fname: PathLike) -> dict:
with open(fname, encoding="utf-8") as f:
return json_tricks.load(f)


def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this to _import_data.py, which is a better place for it. _io.py now deals only with generic file I/O (currently just JSON) rather than our chosen (BIDS)Path structures.

return run_path.copy().update(
extension=".json", suffix="emptyroommatch", root=cfg.deriv_root
)
4 changes: 3 additions & 1 deletion mne_bids_pipeline/_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def add_csp_grand_average(
cond_1: str,
cond_2: str,
fname_csp_freq_results: BIDSPath,
fname_csp_cluster_results: pd.DataFrame,
fname_csp_cluster_results: Optional[pd.DataFrame],
):
"""Add CSP decoding results to the grand average report."""
import matplotlib.pyplot as plt # nested import to help joblib
Expand Down Expand Up @@ -684,6 +684,8 @@ def add_csp_grand_average(
)

# Now, plot decoding scores across time-frequency bins.
if fname_csp_cluster_results is None:
return
csp_cluster_results = loadmat(fname_csp_cluster_results)
fig, ax = plt.subplots(
nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True
Expand Down
11 changes: 5 additions & 6 deletions mne_bids_pipeline/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs):
# Find the limit / step where the error occurred
step_dir = pathlib.Path(__file__).parent / "steps"
tb = traceback.extract_tb(e.__traceback__)
for fi, frame in enumerate(inspect.stack()):
for fi, frame in enumerate(tb):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was not working properly and is now

is_step = pathlib.Path(frame.filename).parent.parent == step_dir
del frame
if is_step:
# omit everything before the "step" dir, which will
# generally be stuff from this file and joblib
tb = tb[-fi:]
tb = tb[fi:]
break
tb = "".join(traceback.format_list(tb))

Expand Down Expand Up @@ -221,17 +221,16 @@ def wrapper(*args, **kwargs):
for key, (fname, this_hash) in out_files_hashes.items():
fname = pathlib.Path(fname)
if not fname.exists():
msg = (
f"Output file missing {str(fname)}, " "will recompute …"
)
msg = f"Output file missing {str(fname)}, will recompute …"
larsoner marked this conversation as resolved.
Show resolved Hide resolved
emoji = "🧩"
bad_out_files = True
break
got_hash = hash_(key, fname, kind="out")[1]
if this_hash != got_hash:
msg = (
f"Output file {self.memory_file_method} mismatch for "
f"{str(fname)}, will recompute …"
f"{str(fname)} ({this_hash} != {got_hash}), will "
larsoner marked this conversation as resolved.
Show resolved Hide resolved
"recompute …"
)
emoji = "🚫"
bad_out_files = True
Expand Down
3 changes: 2 additions & 1 deletion mne_bids_pipeline/steps/init/_02_find_empty_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
get_sessions,
get_subjects,
)
from ..._io import _empty_room_match_path, _write_json
from ..._import_data import _empty_room_match_path
from ..._io import _write_json
from ..._logging import gen_log_kwargs, logger
from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs

Expand Down
18 changes: 11 additions & 7 deletions mne_bids_pipeline/steps/preprocessing/_01_data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
get_subjects,
)
from ..._import_data import (
_auto_scores_path,
_bads_path,
_get_mf_reference_run_path,
_get_run_rest_noise_path,
Expand Down Expand Up @@ -159,7 +158,7 @@ def _find_bads_maxwell(
elif cfg.find_noisy_channels_meg and not cfg.find_flat_channels_meg:
msg = "Finding noisy channels using Maxwell filtering."
else:
msg = "Finding flat channels and noisy channels using " "Maxwell filtering."
msg = "Finding flat channels and noisy channels using Maxwell filtering."
logger.info(**gen_log_kwargs(message=msg))

if run is None and task == "noise":
Expand Down Expand Up @@ -232,18 +231,23 @@ def _find_bads_maxwell(
logger.info(**gen_log_kwargs(message=msg))

if cfg.find_noisy_channels_meg:
out_files["auto_scores"] = _auto_scores_path(
cfg=cfg,
bids_path_in=bids_path_in,
out_files["auto_scores"] = bids_path_in.copy().update(
suffix="scores",
extension=".json",
root=cfg.deriv_root,
split=None,
check=False,
session=session,
subject=subject,
larsoner marked this conversation as resolved.
Show resolved Hide resolved
)
if not out_files["auto_scores"].fpath.parent.exists():
out_files["auto_scores"].fpath.parent.mkdir(parents=True)
larsoner marked this conversation as resolved.
Show resolved Hide resolved
_write_json(out_files["auto_scores"], auto_scores)

# Write the bad channels to disk.
out_files["bads_tsv"] = _bads_path(
cfg=cfg,
bids_path_in=bids_path_in,
subject=subject,
session=session,
)
bads_for_tsv = []
reasons = []
Expand Down
54 changes: 29 additions & 25 deletions mne_bids_pipeline/steps/sensor/_05_decoding_csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,15 @@ def prepare_epochs_and_y(
*, epochs: mne.BaseEpochs, contrast: tuple[str, str], cfg, fmin: float, fmax: float
) -> tuple[mne.BaseEpochs, np.ndarray]:
"""Band-pass between, sub-select the desired epochs, and prepare y."""
epochs_filt = epochs.copy().pick(["meg", "eeg"])

# We only take mag to speed up computation
# because the information is redundant between grad and mag
if cfg.datatype == "meg" and cfg.use_maxwell_filter:
epochs_filt.pick("mag")
Comment on lines -82 to -87
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to pick at all here because it's done elsewhere, and we use PCA to reduce dimensionality (so the redundancy between mag/grad doesn't matter anyway)


# filtering out the conditions we are not interested in, to ensure here we
# have a valid partition between the condition of the contrast.
#

# XXX Hack for handling epochs selection via metadata
# This also makes a copy
if contrast[0].startswith("event_name.isin"):
epochs_filt = epochs_filt[f"{contrast[0]} or {contrast[1]}"]
epochs_filt = epochs[f"{contrast[0]} or {contrast[1]}"]
else:
epochs_filt = epochs_filt[contrast]
epochs_filt = epochs[contrast]

# Filtering is costly, so do it last, after the selection of the channels
# and epochs. We know that often the filter will be longer than the signal,
Expand Down Expand Up @@ -266,8 +260,8 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
# ranges to avoid leaking of information.
time_bins = np.array(cfg.decoding_csp_times)
if time_bins.ndim == 1:
time_bins = np.array(list(zip(time_bins[:-1], time_bins[1:])))
assert time_bins.ndim == 2
time_bins = np.c_[time_bins[:-1], time_bins[1:]]
assert time_bins.ndim == 2 and time_bins.shape[1] == 2, time_bins.shape

tf_decoding_table_rows = []

Expand All @@ -292,13 +286,18 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
}
tf_decoding_table_rows.append(row)

tf_decoding_table = pd.concat(
[pd.DataFrame.from_dict(row) for row in tf_decoding_table_rows],
ignore_index=True,
)
if len(tf_decoding_table_rows):
tf_decoding_table = pd.concat(
[pd.DataFrame.from_dict(row) for row in tf_decoding_table_rows],
ignore_index=True,
)
else:
tf_decoding_table = pd.DataFrame()
del tf_decoding_table_rows

for idx, row in tf_decoding_table.iterrows():
if len(row) == 0:
break # no data
tmin = row["t_min"]
tmax = row["t_max"]
fmin = row["f_min"]
Expand Down Expand Up @@ -340,8 +339,10 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
)
with pd.ExcelWriter(fname_results) as w:
freq_decoding_table.to_excel(w, sheet_name="CSP Frequency", index=False)
tf_decoding_table.to_excel(w, sheet_name="CSP Time-Frequency", index=False)
if not tf_decoding_table.empty:
tf_decoding_table.to_excel(w, sheet_name="CSP Time-Frequency", index=False)
out_files = {"csp-excel": fname_results}
del freq_decoding_table

# Report
with _open_report(
Expand Down Expand Up @@ -381,14 +382,15 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
csp_freq_results["scores"] = csp_freq_results["scores"].apply(
lambda x: np.array(x[1:-1].split(), float)
)
csp_tf_results = pd.read_excel(
fname_decoding, sheet_name="CSP Time-Frequency"
)
csp_tf_results["scores"] = csp_tf_results["scores"].apply(
lambda x: np.array(x[1:-1].split(), float)
)
all_csp_tf_results[contrast] = csp_tf_results
del csp_tf_results
if not tf_decoding_table.empty:
csp_tf_results = pd.read_excel(
fname_decoding, sheet_name="CSP Time-Frequency"
)
csp_tf_results["scores"] = csp_tf_results["scores"].apply(
lambda x: np.array(x[1:-1].split(), float)
)
all_csp_tf_results[contrast] = csp_tf_results
del csp_tf_results

all_decoding_scores = list()
contrast_names = list()
Expand Down Expand Up @@ -497,6 +499,8 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
tags=tags,
replace=True,
)
plt.close(fig)
del fig, title
Comment on lines +488 to +489
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid warning about unclosed figures


assert len(in_files) == 0, in_files.keys()
return _prep_out_files(exec_params=exec_params, out_files=out_files)
Expand Down
Loading
Loading