diff --git a/docs/source/examples/gen_examples.py b/docs/source/examples/gen_examples.py index 922c5c569..2ae1a1b5a 100755 --- a/docs/source/examples/gen_examples.py +++ b/docs/source/examples/gen_examples.py @@ -139,7 +139,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: ) if dataset_name in all_demonstrated: logger.warning( - f"Duplicate dataset name {test_dataset_name} -> {dataset_name}, " "skipping" + f"Duplicate dataset name {test_dataset_name} -> {dataset_name}, skipping" ) continue del test_dataset_options, test_dataset_name diff --git a/docs/source/v1.8.md.inc b/docs/source/v1.8.md.inc index f9f583fb9..552d65272 100644 --- a/docs/source/v1.8.md.inc +++ b/docs/source/v1.8.md.inc @@ -1,8 +1,8 @@ ## v1.8.0 (unreleased) -[//]: # (### :new: New features & enhancements) +### :new: New features & enhancements -[//]: # (- Whatever (#000 by @whoever)) +- Disabling CSP time-frequency mode is now supported by passing an empty list to [`decoding_csp_times`][mne_bids_pipeline._config.decoding_csp_times] (#890 by @larsoner) [//]: # (### :warning: Behavior changes) @@ -14,9 +14,10 @@ [//]: # (- Whatever (#000 by @whoever)) -[//]: # (### :bug: Bug fixes) +### :bug: Bug fixes -[//]: # (- Whatever (#000 by @whoever)) +- Fix handling of Maxwell-filtered data in CSP (#890 by @larsoner) +- Avoid recomputation / cache miss when the same empty-room file is matched to multiple subjects (#890 by @larsoner) ### :medical_symbol: Code health diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 07c41ece7..0df4096e1 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1611,7 +1611,7 @@ Must contain at least two elements. By default, 5 equally-spaced bins are created across the non-negative time range of the epochs. All specified time points must be contained in the epochs interval. 
-If `None`, do not perform **time-frequency** analysis, and only run CSP on +If an empty list, do not perform **time-frequency** analysis, and only run CSP on **frequency** data. ???+ example "Example" diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 25eea4b36..98286d4ba 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -434,7 +434,7 @@ def _check_misspellings_removals( if user_name not in valid_names: # find the closest match closest_match = difflib.get_close_matches(user_name, valid_names, n=1) - msg = f"Found a variable named {repr(user_name)} in your custom " "config," + msg = f"Found a variable named {repr(user_name)} in your custom config," if closest_match and closest_match[0] not in user_names: this_msg = ( f"{msg} did you mean {repr(closest_match[0])}? " diff --git a/mne_bids_pipeline/_config_utils.py b/mne_bids_pipeline/_config_utils.py index 701cd93a3..d6bcb0ce5 100644 --- a/mne_bids_pipeline/_config_utils.py +++ b/mne_bids_pipeline/_config_utils.py @@ -403,7 +403,7 @@ def get_mf_ctc_fname( root=config.bids_root, ).meg_crosstalk_fpath if mf_ctc_fpath is None: - raise ValueError("Could not find Maxwell Filter cross-talk " "file.") + raise ValueError("Could not find Maxwell Filter cross-talk file.") else: mf_ctc_fpath = pathlib.Path(config.mf_ctc_fname).expanduser().absolute() if not mf_ctc_fpath.exists(): diff --git a/mne_bids_pipeline/_decoding.py b/mne_bids_pipeline/_decoding.py index df10d6f1f..3968fcf3c 100644 --- a/mne_bids_pipeline/_decoding.py +++ b/mne_bids_pipeline/_decoding.py @@ -34,10 +34,19 @@ def _handle_csp_args( ) if decoding_csp_times is None: decoding_csp_times = np.linspace(max(0, epochs_tmin), epochs_tmax, num=6) - if len(decoding_csp_times) < 2: - raise ValueError("decoding_csp_times should contain at least 2 values.") + else: + decoding_csp_times = np.array(decoding_csp_times, float) + if decoding_csp_times.ndim != 1 or len(decoding_csp_times) == 
1: + raise ValueError( + "decoding_csp_times should be 1 dimensional and contain at least 2 values " + "to define time intervals, or be empty to disable time-frequency mode, got " + f"shape {decoding_csp_times.shape}" + ) if not np.array_equal(decoding_csp_times, np.sort(decoding_csp_times)): ValueError("decoding_csp_times should be sorted.") + time_bins = np.c_[decoding_csp_times[:-1], decoding_csp_times[1:]] + assert time_bins.ndim == 2 and time_bins.shape[1] == 2, time_bins.shape + if decoding_metric != "roc_auc": raise ValueError( f'CSP decoding currently only supports the "roc_auc" ' @@ -76,7 +85,7 @@ def _handle_csp_args( freq_bins = list(zip(edges[:-1], edges[1:])) freq_name_to_bins_map[freq_range_name] = freq_bins - return freq_name_to_bins_map + return freq_name_to_bins_map, time_bins def _decoding_preproc_steps( diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py index c3c319f44..aaf7b56e3 100644 --- a/mne_bids_pipeline/_import_data.py +++ b/mne_bids_pipeline/_import_data.py @@ -16,7 +16,7 @@ get_runs, get_task, ) -from ._io import _empty_room_match_path, _read_json +from ._io import _read_json from ._logging import gen_log_kwargs, logger from ._run import _update_for_splits from .typing import PathLike @@ -403,6 +403,7 @@ def import_experimental_data( _fix_stim_artifact_func(cfg=cfg, raw=raw) if bids_path_bads_in is not None: + run = "rest" if data_is_rest else run # improve logging bads = _read_bads_tsv(cfg=cfg, bids_path_bads=bids_path_bads_in) msg = f"Marking {len(bads)} channel{_pl(bads)} as bad." 
logger.info(**gen_log_kwargs(message=msg)) @@ -585,6 +586,8 @@ def _get_run_path( add_bads=add_bads, kind=kind, allow_missing=allow_missing, + subject=subject, + session=session, ) @@ -651,6 +654,8 @@ def _get_noise_path( add_bads=add_bads, kind=kind, allow_missing=True, + subject=subject, + session=session, ) @@ -701,6 +706,12 @@ def _get_mf_reference_run_path( ) +def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath: + return run_path.copy().update( + extension=".json", suffix="emptyroommatch", root=cfg.deriv_root + ) + + def _path_dict( *, cfg: SimpleNamespace, @@ -709,6 +720,8 @@ def _path_dict( kind: Literal["orig", "sss", "filt"], allow_missing: bool, key: Optional[str] = None, + subject: str, + session: Optional[str], ) -> dict: if add_bads is None: add_bads = kind == "orig" and _do_mf_autobad(cfg=cfg) @@ -719,35 +732,30 @@ def _path_dict( if allow_missing and not in_files[key].fpath.exists(): return dict() if add_bads: - bads_tsv_fname = _bads_path(cfg=cfg, bids_path_in=bids_path_in) + bads_tsv_fname = _bads_path( + cfg=cfg, + bids_path_in=bids_path_in, + subject=subject, + session=session, + ) if bads_tsv_fname.fpath.is_file() or not allow_missing: in_files[f"{key}-bads"] = bads_tsv_fname return in_files -def _auto_scores_path( - *, - cfg: SimpleNamespace, - bids_path_in: BIDSPath, -) -> BIDSPath: - return bids_path_in.copy().update( - suffix="scores", - extension=".json", - root=cfg.deriv_root, - split=None, - check=False, - ) - - def _bads_path( *, cfg: SimpleNamespace, bids_path_in: BIDSPath, + subject: str, + session: Optional[str], ) -> BIDSPath: return bids_path_in.copy().update( suffix="bads", extension=".tsv", root=cfg.deriv_root, + subject=subject, + session=session, split=None, check=False, ) diff --git a/mne_bids_pipeline/_io.py b/mne_bids_pipeline/_io.py index f1a2b0ce3..dc894cb6b 100644 --- a/mne_bids_pipeline/_io.py +++ b/mne_bids_pipeline/_io.py @@ -1,9 +1,6 @@ """I/O helpers.""" -from types import SimpleNamespace 
- import json_tricks -from mne_bids import BIDSPath from .typing import PathLike @@ -16,9 +13,3 @@ def _write_json(fname: PathLike, data: dict) -> None: def _read_json(fname: PathLike) -> dict: with open(fname, encoding="utf-8") as f: return json_tricks.load(f) - - -def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath: - return run_path.copy().update( - extension=".json", suffix="emptyroommatch", root=cfg.deriv_root - ) diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index e607de98f..5a4136cef 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -594,14 +594,14 @@ def add_csp_grand_average( cond_1: str, cond_2: str, fname_csp_freq_results: BIDSPath, - fname_csp_cluster_results: pd.DataFrame, + fname_csp_cluster_results: Optional[pd.DataFrame], ): """Add CSP decoding results to the grand average report.""" import matplotlib.pyplot as plt # nested import to help joblib # First, plot decoding scores across frequency bins (entire epochs). section = "Decoding: CSP" - freq_name_to_bins_map = _handle_csp_args( + freq_name_to_bins_map, _ = _handle_csp_args( cfg.decoding_csp_times, cfg.decoding_csp_freqs, cfg.decoding_metric, @@ -684,6 +684,8 @@ def add_csp_grand_average( ) # Now, plot decoding scores across time-frequency bins. 
+ if fname_csp_cluster_results is None: + return csp_cluster_results = loadmat(fname_csp_cluster_results) fig, ax = plt.subplots( nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True diff --git a/mne_bids_pipeline/_run.py b/mne_bids_pipeline/_run.py index 73e4c6082..ca3bd1faf 100644 --- a/mne_bids_pipeline/_run.py +++ b/mne_bids_pipeline/_run.py @@ -67,13 +67,13 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): # Find the limit / step where the error occurred step_dir = pathlib.Path(__file__).parent / "steps" tb = traceback.extract_tb(e.__traceback__) - for fi, frame in enumerate(inspect.stack()): + for fi, frame in enumerate(tb): is_step = pathlib.Path(frame.filename).parent.parent == step_dir del frame if is_step: # omit everything before the "step" dir, which will # generally be stuff from this file and joblib - tb = tb[-fi:] + tb = tb[fi:] break tb = "".join(traceback.format_list(tb)) @@ -221,9 +221,7 @@ def wrapper(*args, **kwargs): for key, (fname, this_hash) in out_files_hashes.items(): fname = pathlib.Path(fname) if not fname.exists(): - msg = ( - f"Output file missing {str(fname)}, " "will recompute …" - ) + msg = f"Output file missing: {fname}, will recompute …" emoji = "🧩" bad_out_files = True break @@ -231,7 +229,8 @@ def wrapper(*args, **kwargs): if this_hash != got_hash: msg = ( f"Output file {self.memory_file_method} mismatch for " - f"{str(fname)}, will recompute …" + f"{fname} ({this_hash} != {got_hash}), will " + "recompute …" ) emoji = "🚫" bad_out_files = True diff --git a/mne_bids_pipeline/steps/init/_02_find_empty_room.py b/mne_bids_pipeline/steps/init/_02_find_empty_room.py index fcb0536c5..d56318365 100644 --- a/mne_bids_pipeline/steps/init/_02_find_empty_room.py +++ b/mne_bids_pipeline/steps/init/_02_find_empty_room.py @@ -13,7 +13,8 @@ get_sessions, get_subjects, ) -from ..._io import _empty_room_match_path, _write_json +from ..._import_data import _empty_room_match_path +from ..._io import _write_json from 
..._logging import gen_log_kwargs, logger from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs diff --git a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py index c12dd6a26..1cbeca387 100644 --- a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py +++ b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py @@ -17,7 +17,6 @@ get_subjects, ) from ..._import_data import ( - _auto_scores_path, _bads_path, _get_mf_reference_run_path, _get_run_rest_noise_path, @@ -159,7 +158,7 @@ def _find_bads_maxwell( elif cfg.find_noisy_channels_meg and not cfg.find_flat_channels_meg: msg = "Finding noisy channels using Maxwell filtering." else: - msg = "Finding flat channels and noisy channels using " "Maxwell filtering." + msg = "Finding flat channels and noisy channels using Maxwell filtering." logger.info(**gen_log_kwargs(message=msg)) if run is None and task == "noise": @@ -232,18 +231,23 @@ def _find_bads_maxwell( logger.info(**gen_log_kwargs(message=msg)) if cfg.find_noisy_channels_meg: - out_files["auto_scores"] = _auto_scores_path( - cfg=cfg, - bids_path_in=bids_path_in, + out_files["auto_scores"] = bids_path_in.copy().update( + suffix="scores", + extension=".json", + root=cfg.deriv_root, + split=None, + check=False, + session=session, + subject=subject, ) - if not out_files["auto_scores"].fpath.parent.exists(): - out_files["auto_scores"].fpath.parent.mkdir(parents=True) _write_json(out_files["auto_scores"], auto_scores) # Write the bad channels to disk. 
out_files["bads_tsv"] = _bads_path( cfg=cfg, bids_path_in=bids_path_in, + subject=subject, + session=session, ) bads_for_tsv = [] reasons = [] diff --git a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py index 9b93c0c32..ca7791fd4 100644 --- a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py +++ b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py @@ -79,21 +79,15 @@ def prepare_epochs_and_y( *, epochs: mne.BaseEpochs, contrast: tuple[str, str], cfg, fmin: float, fmax: float ) -> tuple[mne.BaseEpochs, np.ndarray]: """Band-pass between, sub-select the desired epochs, and prepare y.""" - epochs_filt = epochs.copy().pick(["meg", "eeg"]) - - # We only take mag to speed up computation - # because the information is redundant between grad and mag - if cfg.datatype == "meg" and cfg.use_maxwell_filter: - epochs_filt.pick("mag") - # filtering out the conditions we are not interested in, to ensure here we # have a valid partition between the condition of the contrast. - # + # XXX Hack for handling epochs selection via metadata + # This also makes a copy if contrast[0].startswith("event_name.isin"): - epochs_filt = epochs_filt[f"{contrast[0]} or {contrast[1]}"] + epochs_filt = epochs[f"{contrast[0]} or {contrast[1]}"] else: - epochs_filt = epochs_filt[contrast] + epochs_filt = epochs[contrast] # Filtering is costly, so do it last, after the selection of the channels # and epochs. 
We know that often the filter will be longer than the signal, @@ -190,7 +184,7 @@ def one_subject_decoding( ) # Loop over frequencies (all time points lumped together) - freq_name_to_bins_map = _handle_csp_args( + freq_name_to_bins_map, time_bins = _handle_csp_args( cfg.decoding_csp_times, cfg.decoding_csp_freqs, cfg.decoding_metric, @@ -264,11 +258,6 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non # # Note: We don't support varying time ranges for different frequency # ranges to avoid leaking of information. - time_bins = np.array(cfg.decoding_csp_times) - if time_bins.ndim == 1: - time_bins = np.array(list(zip(time_bins[:-1], time_bins[1:]))) - assert time_bins.ndim == 2 - tf_decoding_table_rows = [] for freq_range_name, freq_bins in freq_name_to_bins_map.items(): @@ -292,13 +281,18 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non } tf_decoding_table_rows.append(row) - tf_decoding_table = pd.concat( - [pd.DataFrame.from_dict(row) for row in tf_decoding_table_rows], - ignore_index=True, - ) + if len(tf_decoding_table_rows): + tf_decoding_table = pd.concat( + [pd.DataFrame.from_dict(row) for row in tf_decoding_table_rows], + ignore_index=True, + ) + else: + tf_decoding_table = pd.DataFrame() del tf_decoding_table_rows for idx, row in tf_decoding_table.iterrows(): + if len(row) == 0: + break # no data tmin = row["t_min"] tmax = row["t_max"] fmin = row["f_min"] @@ -340,8 +334,10 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non ) with pd.ExcelWriter(fname_results) as w: freq_decoding_table.to_excel(w, sheet_name="CSP Frequency", index=False) - tf_decoding_table.to_excel(w, sheet_name="CSP Time-Frequency", index=False) + if not tf_decoding_table.empty: + tf_decoding_table.to_excel(w, sheet_name="CSP Time-Frequency", index=False) out_files = {"csp-excel": fname_results} + del freq_decoding_table # Report with _open_report( @@ -350,15 +346,6 @@ def 
_fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non msg = "Adding CSP decoding results to the report." logger.info(**gen_log_kwargs(message=msg)) section = "Decoding: CSP" - freq_name_to_bins_map = _handle_csp_args( - cfg.decoding_csp_times, - cfg.decoding_csp_freqs, - cfg.decoding_metric, - epochs_tmin=cfg.epochs_tmin, - epochs_tmax=cfg.epochs_tmax, - time_frequency_freq_min=cfg.time_frequency_freq_min, - time_frequency_freq_max=cfg.time_frequency_freq_max, - ) all_csp_tf_results = dict() for contrast in cfg.decoding_contrasts: cond_1, cond_2 = contrast @@ -381,14 +368,15 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non csp_freq_results["scores"] = csp_freq_results["scores"].apply( lambda x: np.array(x[1:-1].split(), float) ) - csp_tf_results = pd.read_excel( - fname_decoding, sheet_name="CSP Time-Frequency" - ) - csp_tf_results["scores"] = csp_tf_results["scores"].apply( - lambda x: np.array(x[1:-1].split(), float) - ) - all_csp_tf_results[contrast] = csp_tf_results - del csp_tf_results + if not tf_decoding_table.empty: + csp_tf_results = pd.read_excel( + fname_decoding, sheet_name="CSP Time-Frequency" + ) + csp_tf_results["scores"] = csp_tf_results["scores"].apply( + lambda x: np.array(x[1:-1].split(), float) + ) + all_csp_tf_results[contrast] = csp_tf_results + del csp_tf_results all_decoding_scores = list() contrast_names = list() @@ -497,6 +485,8 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non tags=tags, replace=True, ) + plt.close(fig) + del fig, title assert len(in_files) == 0, in_files.keys() return _prep_out_files(exec_params=exec_params, out_files=out_files) diff --git a/mne_bids_pipeline/steps/sensor/_06_make_cov.py b/mne_bids_pipeline/steps/sensor/_06_make_cov.py index 075abe472..e3c8cdc9e 100644 --- a/mne_bids_pipeline/steps/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/steps/sensor/_06_make_cov.py @@ -184,7 +184,7 @@ def retrieve_custom_cov( check=False, 
) - msg = "Retrieving noise covariance matrix from custom user-supplied " "function" + msg = "Retrieving noise covariance matrix from custom user-supplied function" logger.info(**gen_log_kwargs(message=msg)) msg = f'Output: {out_files["cov"].basename}' logger.info(**gen_log_kwargs(message=msg)) diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index e84877683..b3747c147 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -722,18 +722,20 @@ def average_csp_decoding( all_decoding_data_time_freq = [] for key in list(in_files): fname_xlsx = in_files.pop(key) - decoding_data_freq = pd.read_excel( - fname_xlsx, - sheet_name="CSP Frequency", - dtype={"subject": str}, # don't drop trailing zeros - ) - decoding_data_time_freq = pd.read_excel( - fname_xlsx, - sheet_name="CSP Time-Frequency", - dtype={"subject": str}, # don't drop trailing zeros - ) - all_decoding_data_freq.append(decoding_data_freq) - all_decoding_data_time_freq.append(decoding_data_time_freq) + with pd.ExcelFile(fname_xlsx) as xf: + decoding_data_freq = pd.read_excel( + xf, + sheet_name="CSP Frequency", + dtype={"subject": str}, # don't drop trailing zeros + ) + all_decoding_data_freq.append(decoding_data_freq) + if "CSP Time-Frequency" in xf.sheet_names: + decoding_data_time_freq = pd.read_excel( + xf, + sheet_name="CSP Time-Frequency", + dtype={"subject": str}, # don't drop trailing zeros + ) + all_decoding_data_time_freq.append(decoding_data_time_freq) del fname_xlsx # Now calculate descriptes and bootstrap CIs. 
@@ -743,12 +745,15 @@ def average_csp_decoding( session=session, data=all_decoding_data_freq, ) - grand_average_time_freq = _average_csp_time_freq( - cfg=cfg, - subject=subject, - session=session, - data=all_decoding_data_time_freq, - ) + if len(all_decoding_data_time_freq): + grand_average_time_freq = _average_csp_time_freq( + cfg=cfg, + subject=subject, + session=session, + data=all_decoding_data_time_freq, + ) + else: + grand_average_time_freq = None out_files = dict() out_files["freq"] = _decoding_out_fname( @@ -762,17 +767,15 @@ def average_csp_decoding( ) with pd.ExcelWriter(out_files["freq"]) as w: grand_average_freq.to_excel(w, sheet_name="CSP Frequency", index=False) - grand_average_time_freq.to_excel( - w, sheet_name="CSP Time-Frequency", index=False - ) + if grand_average_time_freq is not None: + grand_average_time_freq.to_excel( + w, sheet_name="CSP Time-Frequency", index=False + ) + del grand_average_time_freq # Perform a cluster-based permutation test. subjects = cfg.subjects - time_bins = np.array(cfg.decoding_csp_times) - if time_bins.ndim == 1: - time_bins = np.array(list(zip(time_bins[:-1], time_bins[1:]))) - time_bins = pd.DataFrame(time_bins, columns=["t_min", "t_max"]) - freq_name_to_bins_map = _handle_csp_args( + freq_name_to_bins_map, time_bins = _handle_csp_args( cfg.decoding_csp_times, cfg.decoding_csp_freqs, cfg.decoding_metric, @@ -781,79 +784,84 @@ def average_csp_decoding( time_frequency_freq_min=cfg.time_frequency_freq_min, time_frequency_freq_max=cfg.time_frequency_freq_max, ) - data_for_clustering = {} - for freq_range_name in freq_name_to_bins_map: - a = np.empty( - shape=( - len(subjects), - len(time_bins), - len(freq_name_to_bins_map[freq_range_name]), + if not len(time_bins): + fname_csp_cluster_results = None + else: + time_bins = pd.DataFrame(time_bins, columns=["t_min", "t_max"]) + data_for_clustering = {} + for freq_range_name in freq_name_to_bins_map: + a = np.empty( + shape=( + len(subjects), + len(time_bins), + 
len(freq_name_to_bins_map[freq_range_name]), + ) ) + a.fill(np.nan) + data_for_clustering[freq_range_name] = a + + g = pd.concat(all_decoding_data_time_freq).groupby( + ["subject", "freq_range_name", "t_min", "t_max"] ) - a.fill(np.nan) - data_for_clustering[freq_range_name] = a - g = pd.concat(all_decoding_data_time_freq).groupby( - ["subject", "freq_range_name", "t_min", "t_max"] - ) + for (subject_, freq_range_name, t_min, t_max), df in g: + scores = df["mean_crossval_score"] + sub_idx = subjects.index(subject_) + time_bin_idx = time_bins.loc[ + (np.isclose(time_bins["t_min"], t_min)) + & (np.isclose(time_bins["t_max"], t_max)), + :, + ].index + assert len(time_bin_idx) == 1 + time_bin_idx = time_bin_idx[0] + data_for_clustering[freq_range_name][sub_idx][time_bin_idx] = scores - for (subject_, freq_range_name, t_min, t_max), df in g: - scores = df["mean_crossval_score"] - sub_idx = subjects.index(subject_) - time_bin_idx = time_bins.loc[ - (np.isclose(time_bins["t_min"], t_min)) - & (np.isclose(time_bins["t_max"], t_max)), - :, - ].index - assert len(time_bin_idx) == 1 - time_bin_idx = time_bin_idx[0] - data_for_clustering[freq_range_name][sub_idx][time_bin_idx] = scores - - if cfg.cluster_forming_t_threshold is None: - import scipy.stats - - cluster_forming_t_threshold = scipy.stats.t.ppf( - 1 - 0.05, - len(cfg.subjects) - 1, # one-sided test - ) - else: - cluster_forming_t_threshold = cfg.cluster_forming_t_threshold + if cfg.cluster_forming_t_threshold is None: + import scipy.stats - cluster_permutation_results = {} - for freq_range_name, X in data_for_clustering.items(): - if len(X) < 2: - t_vals = np.full(X.shape[1:], np.nan) - H0 = all_clusters = cluster_p_vals = np.array([]) - else: - ( - t_vals, - all_clusters, - cluster_p_vals, - H0, - ) = mne.stats.permutation_cluster_1samp_test( # noqa: E501 - X=X - 0.5, # One-sample test against zero. 
- threshold=cluster_forming_t_threshold, - n_permutations=cfg.cluster_n_permutations, - adjacency=None, # each time & freq bin connected to its neighbors - out_type="mask", - tail=1, # one-sided: significantly above chance level - seed=cfg.random_state, + cluster_forming_t_threshold = scipy.stats.t.ppf( + 1 - 0.05, + len(cfg.subjects) - 1, # one-sided test ) - n_permutations = H0.size - 1 - all_clusters = np.array(all_clusters) # preserve "empty" 0th dimension - cluster_permutation_results[freq_range_name] = { - "mean_crossval_scores": X.mean(axis=0), - "t_vals": t_vals, - "clusters": all_clusters, - "cluster_p_vals": cluster_p_vals, - "cluster_t_threshold": cluster_forming_t_threshold, - "n_permutations": n_permutations, - "time_bin_edges": cfg.decoding_csp_times, - "freq_bin_edges": cfg.decoding_csp_freqs[freq_range_name], - } - - out_files["cluster"] = out_files["freq"].copy().update(extension=".mat") - savemat(file_name=out_files["cluster"], mdict=cluster_permutation_results) + else: + cluster_forming_t_threshold = cfg.cluster_forming_t_threshold + + cluster_permutation_results = {} + for freq_range_name, X in data_for_clustering.items(): + if len(X) < 2: + t_vals = np.full(X.shape[1:], np.nan) + H0 = all_clusters = cluster_p_vals = np.array([]) + else: + ( + t_vals, + all_clusters, + cluster_p_vals, + H0, + ) = mne.stats.permutation_cluster_1samp_test( # noqa: E501 + X=X - 0.5, # One-sample test against zero. 
+ threshold=cluster_forming_t_threshold, + n_permutations=cfg.cluster_n_permutations, + adjacency=None, # each time & freq bin connected to its neighbors + out_type="mask", + tail=1, # one-sided: significantly above chance level + seed=cfg.random_state, + ) + n_permutations = H0.size - 1 + all_clusters = np.array(all_clusters) # preserve "empty" 0th dimension + cluster_permutation_results[freq_range_name] = { + "mean_crossval_scores": X.mean(axis=0), + "t_vals": t_vals, + "clusters": all_clusters, + "cluster_p_vals": cluster_p_vals, + "cluster_t_threshold": cluster_forming_t_threshold, + "n_permutations": n_permutations, + "time_bin_edges": cfg.decoding_csp_times, + "freq_bin_edges": cfg.decoding_csp_freqs[freq_range_name], + } + + out_files["cluster"] = out_files["freq"].copy().update(extension=".mat") + savemat(file_name=out_files["cluster"], mdict=cluster_permutation_results) + fname_csp_cluster_results = out_files["cluster"] assert subject == "average" with _open_report( @@ -867,7 +875,7 @@ def average_csp_decoding( cond_1=cond_1, cond_2=cond_2, fname_csp_freq_results=out_files["freq"], - fname_csp_cluster_results=out_files["cluster"], + fname_csp_cluster_results=fname_csp_cluster_results, ) return _prep_out_files(out_files=out_files, exec_params=exec_params) diff --git a/mne_bids_pipeline/steps/source/_02_make_bem_solution.py b/mne_bids_pipeline/steps/source/_02_make_bem_solution.py index 1320d6dc7..1f2947d01 100644 --- a/mne_bids_pipeline/steps/source/_02_make_bem_solution.py +++ b/mne_bids_pipeline/steps/source/_02_make_bem_solution.py @@ -99,7 +99,7 @@ def main(*, config) -> None: return if config.use_template_mri is not None: - msg = "Skipping, BEM solution computation not needed for " "MRI template …" + msg = "Skipping, BEM solution computation not needed for MRI template …" logger.info(**gen_log_kwargs(message=msg, emoji="skip")) if config.use_template_mri == "fsaverage": # Ensure we have the BEM diff --git 
a/mne_bids_pipeline/tests/configs/config_ds000117.py b/mne_bids_pipeline/tests/configs/config_ds000117.py index 14fd77499..2e49f1a4e 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000117.py +++ b/mne_bids_pipeline/tests/configs/config_ds000117.py @@ -15,6 +15,7 @@ find_flat_channels_meg = True find_noisy_channels_meg = True use_maxwell_filter = True +process_empty_room = True mf_reference_run = "02" mf_cal_fname = bids_root + "/derivatives/meg_derivatives/sss_cal.dat" diff --git a/mne_bids_pipeline/tests/configs/config_ds003392.py b/mne_bids_pipeline/tests/configs/config_ds003392.py index b8ee82d2e..37c8a46c3 100644 --- a/mne_bids_pipeline/tests/configs/config_ds003392.py +++ b/mne_bids_pipeline/tests/configs/config_ds003392.py @@ -37,6 +37,11 @@ decoding_time_generalization = True decoding_time_generalization_decim = 4 contrasts = [("incoherent", "coherent")] +decoding_csp = True +decoding_csp_times = [] +decoding_csp_freqs = { + "alpha": (8, 12), +} # Noise estimation noise_cov = "emptyroom" diff --git a/mne_bids_pipeline/tests/conftest.py b/mne_bids_pipeline/tests/conftest.py index 2ac1e9403..dbefd583c 100644 --- a/mne_bids_pipeline/tests/conftest.py +++ b/mne_bids_pipeline/tests/conftest.py @@ -62,6 +62,10 @@ def pytest_configure(config): ignore:datetime\.datetime\.utcnow.*:DeprecationWarning # pandas with no good workaround ignore:The behavior of DataFrame concatenation with empty.*:FutureWarning + # joblib on Windows sometimes + ignore:Persisting input arguments took.*:UserWarning + # matplotlib needs to update + ignore:Conversion of an array with ndim.*:DeprecationWarning """ for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip()