diff --git a/.circleci/config.yml b/.circleci/config.yml index 62e687cba..ceb51dfbf 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -297,6 +297,26 @@ jobs: paths: - ~/mne_data/eeg_matchingpennies + cache_MNE-phantom-KIT-data: + <<: *imageconfig + steps: + - attach_workspace: + at: ~/ + - restore_cache: + keys: + - data-cache-MNE-phantom-KIT-data-1 + - bash_env + - gitconfig # email address is needed for datalad + - run: + name: Get MNE-phantom-KIT-data + command: | + $DOWNLOAD_DATA MNE-phantom-KIT-data + - codecov/upload + - save_cache: + key: data-cache-MNE-phantom-KIT-data-1 + paths: + - ~/mne_data/MNE-phantom-KIT-data + cache_ERP_CORE: <<: *imageconfig steps: @@ -765,6 +785,32 @@ jobs: paths: - mne_data/derivatives/mne-bids-pipeline/eeg_matchingpennies/*/*/*.html + test_MNE-phantom-KIT-data: + <<: *imageconfig + steps: + - attach_workspace: + at: ~/ + - bash_env + - restore_cache: + keys: + - data-cache-MNE-phantom-KIT-data-1 + - run: + name: test MNE-phantom-KIT-data + command: $RUN_TESTS MNE-phantom-KIT-data + - codecov/upload + - store_test_results: + path: ./test-results + - store_artifacts: + path: ./test-results + destination: test-results + - store_artifacts: + path: /home/circleci/reports/MNE-phantom-KIT-data + destination: reports/MNE-phantom-KIT-data + - persist_to_workspace: + root: ~/ + paths: + - mne_data/derivatives/mne-bids-pipeline/MNE-phantom-KIT-data/*/*/*.html + test_ERP_CORE_N400: <<: *imageconfig resource_class: large @@ -1191,6 +1237,15 @@ workflows: - cache_eeg_matchingpennies <<: *filter_tags + - cache_MNE-phantom-KIT-data: + requires: + - setup_env + <<: *filter_tags + - test_MNE-phantom-KIT-data: + requires: + - cache_MNE-phantom-KIT-data + <<: *filter_tags + - cache_ERP_CORE: requires: - setup_env @@ -1242,6 +1297,7 @@ workflows: - test_ds003392 - test_ds004229 - test_eeg_matchingpennies + - test_MNE-phantom-KIT-data - test_ERP_CORE_N400 - test_ERP_CORE_ERN - test_ERP_CORE_LRP diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 8763aa9c0..29107ff32 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -90,7 +90,7 @@ nav: - Epoching: settings/preprocessing/epochs.md - Artifact removal: - Stimulation artifact: settings/preprocessing/stim_artifact.md - - SSP & ICA: settings/preprocessing/ssp_ica.md + - SSP, ICA, and artifact regression: settings/preprocessing/ssp_ica.md - Amplitude-based artifact rejection: settings/preprocessing/artifacts.md - Sensor-level analysis: - Condition contrasts: settings/sensor/contrasts.md @@ -116,6 +116,7 @@ nav: - examples/ds000248_no_mri.md - examples/ds003104.md - examples/eeg_matchingpennies.md + - examples/MNE-phantom-KIT-data.md - examples/ds001810.md - examples/ds000117.md - examples/ds003775.md diff --git a/docs/source/examples/gen_examples.py b/docs/source/examples/gen_examples.py index 1f2514274..b55e526d8 100755 --- a/docs/source/examples/gen_examples.py +++ b/docs/source/examples/gen_examples.py @@ -63,6 +63,8 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: key = "Maxwell filter" funcs[key] = funcs[key] or config.use_maxwell_filter funcs["Frequency filter"] = config.l_freq or config.h_freq + key = "Artifact regression" + funcs[key] = funcs[key] or (config.regress_artifact is not None) key = "SSP" funcs[key] = funcs[key] or (config.spatial_filter == "ssp") key = "ICA" @@ -144,6 +146,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: logger.warning(f"Dataset {dataset_name} has no HTML report.") continue + assert dataset_options_key in DATASET_OPTIONS, dataset_options_key options = DATASET_OPTIONS[dataset_options_key].copy() # we modify locally report_str = "\n## Generated output\n\n" @@ -200,13 +203,18 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: f"{fname.name} :fontawesome-solid-square-poll-vertical:\n\n" ) - assert sum(key in options for key in ("openneuro", "git", "web", "datalad")) == 1 + assert ( + sum(key in options for key in ("openneuro", "git", "web", "datalad", "mne")) + == 1 + ) if "openneuro" in options: url = f'https://openneuro.org/datasets/{options["openneuro"]}' elif "git" in options: url = options["git"] elif "web" in options: url = options["web"] + elif "mne" in options: + url = f"https://mne.tools/dev/generated/mne.datasets.{options['mne']}.data_path.html" # noqa: E501 else: assert "datalad" in options # guaranteed above url = "" @@ -246,7 +254,9 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: # TODO: For things like ERP_CORE_ERN, decoding_csp are not populated # properly by the root config - config_path = root / "tests" / "configs" / f"config_{dataset_name}.py" + config_path = ( + root / "tests" / "configs" / f"config_{dataset_name.replace('-', '_')}.py" + ) config = config_path.read_text(encoding="utf-8-sig").strip() descr_end_idx = config[2:].find('"""') config_descr = "# " + config[: descr_end_idx + 1].replace('"""', "").strip() diff --git a/docs/source/settings/preprocessing/ssp_ica.md b/docs/source/settings/preprocessing/ssp_ica.md index b132ef4bf..f25110729 100644 --- a/docs/source/settings/preprocessing/ssp_ica.md +++ b/docs/source/settings/preprocessing/ssp_ica.md @@ -11,6 +11,7 @@ tags: ::: mne_bids_pipeline._config options: members: + - regress_artifact - spatial_filter - min_ecg_epochs - min_eog_epochs diff --git a/docs/source/v1.6.md.inc b/docs/source/v1.6.md.inc index cf5596cb1..afb7835c3 100644 --- a/docs/source/v1.6.md.inc +++ b/docs/source/v1.6.md.inc @@ -2,9 +2,9 @@ ## vX.Y.0 (unreleased) -[//]: # (### :new: New features & enhancements) +:new: New features & enhancements -[//]: # (- Whatever (#000 by @whoever)) +- Added [`regress_artifact`][mne_bids_pipeline._config.regress_artifact] to allow artifact regression (e.g., of MEG reference sensors in KIT systems) (#837 by @larsoner) [//]: # (### :warning: Behavior changes) diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 652e5ebfb..e3c7626bb 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1,7 +1,8 @@ # Default settings for data processing and analysis. -from typing import Callable, Iterable, Literal, Optional, Union +from typing import Annotated, Any, Callable, Literal, Optional, Sequence, Union +from annotated_types import Ge, Interval, Len from mne import Covariance from mne_bids import BIDSPath @@ -94,7 +95,7 @@ The task to process. """ -runs: Union[Iterable, Literal["all"]] = "all" +runs: Union[Sequence, Literal["all"]] = "all" """ The runs to process. If `'all'`, will process all runs found in the BIDS dataset. @@ -143,7 +144,7 @@ The BIDS `space` entity. """ -plot_psd_for_runs: Union[Literal["all"], Iterable[str]] = "all" +plot_psd_for_runs: Union[Literal["all"], Sequence[str]] = "all" """ For which runs to add a power spectral density (PSD) plot to the generated report. This can take a considerable amount of time if you have many long @@ -151,7 +152,7 @@ plotting. """ -subjects: Union[Iterable[str], Literal["all"]] = "all" +subjects: Union[Sequence[str], Literal["all"]] = "all" """ Subjects to analyze. If `'all'`, include all subjects. To only include a subset of subjects, pass a list of their identifiers. Even @@ -171,7 +172,7 @@ ``` """ -exclude_subjects: Iterable[str] = [] +exclude_subjects: Sequence[str] = [] """ Specify subjects to exclude from analysis. The MEG empty-room mock-subject is automatically excluded from regular analysis. @@ -201,7 +202,7 @@ covariance (via `noise_cov='rest'`). """ -ch_types: Iterable[Literal["meg", "mag", "grad", "eeg"]] = [] +ch_types: Annotated[Sequence[Literal["meg", "mag", "grad", "eeg"]], Len(1, 4)] = [] """ The channel types to consider. @@ -252,7 +253,7 @@ ``` """ -eog_channels: Optional[Iterable[str]] = None +eog_channels: Optional[Sequence[str]] = None """ Specify EOG channels to use, or create virtual EOG channels. @@ -320,7 +321,7 @@ ``` """ -eeg_reference: Union[Literal["average"], str, Iterable["str"]] = "average" +eeg_reference: Union[Literal["average"], str, Sequence["str"]] = "average" """ The EEG reference to use. If `average`, will use the average reference, i.e. the average across all channels. If a string, must be the name of a single @@ -371,7 +372,7 @@ ``` """ -drop_channels: Iterable[str] = [] +drop_channels: Sequence[str] = [] """ Names of channels to remove from the data. This can be useful, for example, if you have added a new bipolar channel via `eeg_bipolar_channels` and now wish @@ -385,7 +386,7 @@ """ analyze_channels: Union[ - Literal["all"], Literal["ch_types"], Iterable["str"] + Literal["all"], Literal["ch_types"], Sequence["str"] ] = "ch_types" """ The names of the channels to analyze during ERP/ERF and time-frequency analysis @@ -789,7 +790,7 @@ Keep it `None` if no lowpass filtering should be applied. """ -notch_freq: Optional[Union[float, Iterable[float]]] = None +notch_freq: Optional[Union[float, Sequence[float]]] = None """ Notch filter frequency. More than one frequency can be supplied, e.g. to remove harmonics. Keep it `None` if no notch filter should be applied. @@ -827,7 +828,7 @@ Specifies the transition bandwidth of the notch filter. The default is `1.`. """ -notch_widths: Optional[Union[float, Iterable[float]]] = None +notch_widths: Optional[Union[float, Sequence[float]]] = None """ Specifies the width of each stop band. `None` uses the MNE default. """ @@ -931,7 +932,7 @@ window for metadata generation. """ -epochs_metadata_keep_first: Optional[Iterable[str]] = None +epochs_metadata_keep_first: Optional[Sequence[str]] = None """ Event groupings using hierarchical event descriptors (HEDs) for which to store the time of the **first** occurrence of any event of this group in a new column @@ -959,7 +960,7 @@ and `first_stimulus`. """ -epochs_metadata_keep_last: Optional[Iterable[str]] = None +epochs_metadata_keep_last: Optional[Sequence[str]] = None """ Same as `epochs_metadata_keep_first`, but for keeping the **last** occurrence of matching event types. The columns indicating the event types @@ -979,7 +980,7 @@ ``` """ # noqa: E501 -conditions: Optional[Union[Iterable[str], dict[str, str]]] = None +conditions: Optional[Union[Sequence[str], dict[str, str]]] = None """ The time-locked events based on which to create evoked responses. This can either be name of the experimental condition as specified in the @@ -1058,7 +1059,7 @@ ``` """ -contrasts: Iterable[Union[tuple[str, str], ArbitraryContrast]] = [] +contrasts: Sequence[Union[tuple[str, str], ArbitraryContrast]] = [] """ The conditions to contrast via a subtraction of ERPs / ERFs. The list elements can either be tuples or dictionaries (or a mix of both). Each element in the @@ -1125,6 +1126,24 @@ # # Currently you cannot use both. +regress_artifact: Optional[dict[str, Any]] = None +""" +Keyword arguments to pass to the `mne.preprocessing.EOGRegression` model used +in `mne.preprocessing.regress_artifact`. If `None`, no time-domain regression will +be applied. Note that any channels picked in `regress_artifact["picks_artifact"]` will +have the same time-domain filters applied to them as the experimental data. + +Artifact regression is applied before SSP or ICA. + +???+ example "Example" + For example, if you have MEG reference channel data recorded in three + miscellaneous channels, you could do: + + ```python + regress_artifact = {"picks": "meg", "picks_artifact": ["MISC 001", "MISC 002", "MISC 003"]} + ``` +""" # noqa: E501 + spatial_filter: Optional[Literal["ssp", "ica"]] = None """ Whether to use a spatial filter to detect and remove artifacts. The BIDS @@ -1516,7 +1535,7 @@ you don't need to be worried about **exactly** balancing class sizes. """ -decoding_n_splits: int = 5 +decoding_n_splits: Annotated[int, Ge(2)] = 5 """ The number of folds (also called "splits") to use in the K-fold cross-validation scheme. @@ -1577,7 +1596,7 @@ test to determine the significance of the decoding scores across participants. """ -cluster_permutation_p_threshold: float = 0.05 +cluster_permutation_p_threshold: Annotated[float, Interval(gt=0, lt=1)] = 0.05 """ The alpha level (p-value, p threshold) to use for rejecting the null hypothesis that the clusters show no significant difference between conditions. This is @@ -1609,7 +1628,7 @@ # TIME-FREQUENCY # -------------- -time_frequency_conditions: Iterable[str] = [] +time_frequency_conditions: Sequence[str] = [] """ The conditions to compute time-frequency decomposition on. diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 66fe9583a..db5487cb7 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -12,8 +12,7 @@ import matplotlib import mne import numpy as np -from pydantic import ValidationError -from pydantic.dataclasses import dataclass +from pydantic import BaseModel, ConfigDict, ValidationError from ._logging import gen_log_kwargs, logger from .typing import PathLike @@ -269,17 +268,6 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N f'ica_reject["{ch_type}"] ({ica_reject[ch_type]})' ) - if not config.ch_types: - raise ValueError("Please specify ch_types in your configuration.") - - _VALID_TYPES = ("meg", "mag", "grad", "eeg") - if any(ch_type not in _VALID_TYPES for ch_type in config.ch_types): - raise ValueError( - "Invalid channel type passed. Please adjust `ch_types` in your " - f"configuration, got {config.ch_types} but supported types are " - f"{_VALID_TYPES}" - ) - if config.noise_cov == "emptyroom" and "eeg" in config.ch_types: raise ValueError( "You requested to process data that contains EEG channels. In " @@ -312,16 +300,7 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N f"but you set baseline={bl}" ) - # check decoding parameters - if config.decoding_n_splits < 2: - raise ValueError("decoding_n_splits should be at least 2.") - # check cluster permutation parameters - if not 0 < config.cluster_permutation_p_threshold < 1: - raise ValueError( - "cluster_permutation_p_threshold should be in the (0, 1) interval." - ) - if config.cluster_n_permutations < 10 / config.cluster_permutation_p_threshold: raise ValueError( "cluster_n_permutations is not big enough to calculate " @@ -380,33 +359,30 @@ def _pydantic_validate( # https://docs.pydantic.dev/latest/usage/dataclasses/ from . import _config as root_config - annotations = copy.deepcopy(root_config.__annotations__) # just be safe - attrs = { - key: _default_factory(key, val) - for key, val in root_config.__dict__.items() - if key in annotations - } - # everything should be type annotated, make sure they are - asym = set(attrs).symmetric_difference(set(annotations)) - assert asym == set(), asym + # Modify annotations to add nested strict parsing + annotations = dict() + attrs = dict() + for key, annot in root_config.__annotations__.items(): + annotations[key] = annot + attrs[key] = _default_factory(key, root_config.__dict__[key]) name = "user configuration" if config_path is not None: name += f" from {config_path}" - UserConfig = type( - name, - (object,), - {"__annotations__": annotations, **attrs}, - ) - dataclass_config = dict( + model_config = ConfigDict( arbitrary_types_allowed=False, validate_assignment=True, strict=True, # do not allow float for int for example + extra="forbid", + ) + UserConfig = type( + name, + (BaseModel,), + {"__annotations__": annotations, "model_config": model_config, **attrs}, ) - UserConfig = dataclass(config=dataclass_config)(UserConfig) # Now use pydantic to automagically validate user_vals = {key: val for key, val in config.__dict__.items() if key in annotations} try: - UserConfig(**user_vals) + UserConfig.model_validate(user_vals) except ValidationError as err: raise ValueError(str(err)) from None diff --git a/mne_bids_pipeline/_download.py b/mne_bids_pipeline/_download.py index 45de893ed..46cf17e7a 100644 --- a/mne_bids_pipeline/_download.py +++ b/mne_bids_pipeline/_download.py @@ -77,13 +77,24 @@ def _download_from_web(*, ds_name: str, ds_path: Path): (path / f"{ds_name}.zip").unlink() +def _download_via_mne(*, ds_name: str, ds_path: Path): + assert ds_path.stem == ds_name, ds_path + getattr(mne.datasets, DATASET_OPTIONS[ds_name]["mne"]).data_path( + ds_path.parent, + verbose=True, + ) + + def _download(*, ds_name: str, ds_path: Path): options = DATASET_OPTIONS[ds_name] openneuro_name = options.get("openneuro", "") git_url = options.get("git", "") osf_node = options.get("osf", "") web_url = options.get("web", "") - assert sum(bool(x) for x in (openneuro_name, git_url, osf_node, web_url)) == 1 + mne_mod = options.get("mne", "") + assert ( + sum(bool(x) for x in (openneuro_name, git_url, osf_node, web_url, mne_mod)) == 1 + ) if openneuro_name: download_func = _download_via_openneuro @@ -91,6 +102,8 @@ def _download(*, ds_name: str, ds_path: Path): download_func = _download_via_datalad elif osf_node: raise RuntimeError("OSF downloads are currently not supported.") + elif mne_mod: + download_func = _download_via_mne else: assert web_url download_func = _download_from_web diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py index d7f22240d..be892576b 100644 --- a/mne_bids_pipeline/_import_data.py +++ b/mne_bids_pipeline/_import_data.py @@ -452,7 +452,6 @@ def import_er_data( cfg=cfg, bids_path_bads=bids_path_er_bads_in, ) - raw_er.pick("meg", exclude=[]) # Don't deal with ref for now (initial data quality / auto bad step) if bids_path_ref_in is None: @@ -530,7 +529,7 @@ def _get_bids_path_in( session: Optional[str], run: Optional[str], task: Optional[str], - kind: Literal["orig", "sss"] = "orig", + kind: Literal["orig", "sss", "filt"] = "orig", ) -> BIDSPath: # b/c can be used before this is updated path_kwargs = dict( @@ -544,13 +543,13 @@ def _get_bids_path_in( datatype=get_datatype(config=cfg), check=False, ) - if kind == "sss": + if kind != "orig": + assert kind in ("sss", "filt"), kind path_kwargs["root"] = cfg.deriv_root path_kwargs["suffix"] = "raw" path_kwargs["extension"] = ".fif" - path_kwargs["processing"] = "sss" + path_kwargs["processing"] = kind else: - assert kind == "orig", kind path_kwargs["root"] = cfg.bids_root path_kwargs["suffix"] = None path_kwargs["extension"] = None @@ -566,7 +565,7 @@ def _get_run_path( session: Optional[str], run: Optional[str], task: Optional[str], - kind: Literal["orig", "sss"], + kind: Literal["orig", "sss", "filt"], add_bads: Optional[bool] = None, allow_missing: bool = False, key: Optional[str] = None, @@ -594,7 +593,7 @@ def _get_rest_path( cfg: SimpleNamespace, subject: str, session: Optional[str], - kind: Literal["orig", "sss"], + kind: Literal["orig", "sss", "filt"], add_bads: Optional[bool] = None, ) -> dict: if not (cfg.process_rest and not cfg.task_is_rest): @@ -616,13 +615,14 @@ def _get_noise_path( cfg: SimpleNamespace, subject: str, session: Optional[str], - kind: Literal["orig", "sss"], + kind: Literal["orig", "sss", "filt"], mf_reference_run: Optional[str], add_bads: Optional[bool] = None, ) -> dict: if not (cfg.process_empty_room and get_datatype(config=cfg) == "meg"): return dict() - if kind == "sss": + if kind != "orig": + assert kind in ("sss", "filt") raw_fname = _get_bids_path_in( cfg=cfg, subject=subject, @@ -661,7 +661,7 @@ def _get_run_rest_noise_path( session: Optional[str], run: Optional[str], task: Optional[str], - kind: Literal["orig", "sss"], + kind: Literal["orig", "sss", "filt"], mf_reference_run: Optional[str], add_bads: Optional[bool] = None, ) -> dict: @@ -705,7 +705,7 @@ def _path_dict( cfg: SimpleNamespace, bids_path_in: BIDSPath, add_bads: Optional[bool] = None, - kind: Literal["orig", "sss"], + kind: Literal["orig", "sss", "filt"], allow_missing: bool, key: Optional[str] = None, ) -> dict: @@ -805,3 +805,14 @@ def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict: runs=get_runs(config=config, subject=subject), # XXX needs to accept session! **_bids_kwargs(config=config), ) + + +def _get_run_type( + run: Optional[str], + task: Optional[str], +) -> str: + if run is None and task in ("noise", "rest"): + run_type = dict(rest="resting-state", noise="empty-room")[task] + else: + run_type = "experimental" + return run_type diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index ed514925d..80f2f1962 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -68,14 +68,13 @@ def _open_report( yield report finally: try: - msg = "Adding config and sys info to report" - logger.info(**gen_log_kwargs(message=msg)) _finalize( report=report, exec_params=exec_params, subject=subject, session=session, run=run, + task=task, ) except Exception as exc: logger.warning(f"Failed: {exc}") @@ -506,12 +505,17 @@ def _finalize( subject: str, session: Optional[str], run: Optional[str], + task: Optional[str], ) -> None: """Add system information and the pipeline configuration to the report.""" # ensure they are always appended titles = ["Configuration file", "System information"] for title in titles: report.remove(title=title, remove_all=True) + # Print this exactly once + if _cached_sys_info.cache_info()[-1] == 0: # never run + msg = "Adding config and sys info to report" + logger.info(**gen_log_kwargs(message=msg)) # No longer need replace=True in these report.add_code( code=exec_params.config_path, diff --git a/mne_bids_pipeline/_run.py b/mne_bids_pipeline/_run.py index 128b876ed..c7e46267b 100644 --- a/mne_bids_pipeline/_run.py +++ b/mne_bids_pipeline/_run.py @@ -225,13 +225,18 @@ def wrapper(*args, **kwargs): for key, (fname, this_hash) in out_files_hashes.items(): fname = pathlib.Path(fname) if not fname.exists(): - msg = "Output file missing, will recompute …" + msg = ( + f"Output file missing {str(fname)}, " "will recompute …" + ) emoji = "🧩" bad_out_files = True break got_hash = hash_(key, fname, kind="out")[1] if this_hash != got_hash: - msg = "Output file hash mismatch, will recompute …" + msg = ( + f"Output file hash mismatch for {str(fname)}, " + "will recompute …" + ) emoji = "🚫" bad_out_files = True break diff --git a/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py b/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py index a44a1c70e..fd9c6c874 100644 --- a/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py +++ b/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py @@ -20,6 +20,8 @@ import mne import numpy as np +from mne.io.pick import _picks_to_idx +from mne.preprocessing import EOGRegression from ..._config_utils import ( get_runs_tasks, @@ -28,6 +30,7 @@ ) from ..._import_data import ( _get_run_rest_noise_path, + _get_run_type, _import_data_kwargs, import_er_data, import_experimental_data, @@ -69,6 +72,7 @@ def notch_filter( trans_bandwidth: Union[float, Literal["auto"]], notch_widths: Optional[Union[float, Iterable[float]]], run_type: Literal["experimental", "empty-room", "resting-state"], + picks: Optional[np.ndarray], ) -> None: """Filter data channels (MEG and EEG).""" if freqs is None: @@ -86,6 +90,7 @@ def notch_filter( trans_bandwidth=trans_bandwidth, notch_widths=notch_widths, n_jobs=1, + picks=picks, ) @@ -100,6 +105,7 @@ def bandpass_filter( l_trans_bandwidth: Union[float, Literal["auto"]], h_trans_bandwidth: Union[float, Literal["auto"]], run_type: Literal["experimental", "empty-room", "resting-state"], + picks: Optional[np.ndarray], ) -> None: """Filter data channels (MEG and EEG).""" if l_freq is not None and h_freq is None: @@ -122,6 +128,7 @@ def bandpass_filter( l_trans_bandwidth=l_trans_bandwidth, h_trans_bandwidth=h_trans_bandwidth, n_jobs=1, + picks=picks, ) @@ -161,14 +168,10 @@ def filter_data( bids_path_in = in_files.pop(in_key) bids_path_bads_in = in_files.pop(f"{in_key}-bads", None) - if run is None and task in ("noise", "rest"): - run_type = dict(rest="resting-state", noise="empty-room")[task] - else: - run_type = "experimental" - + run_type = _get_run_type(run=run, task=task) + msg = f"Reading {run_type} recording: " f"{bids_path_in.basename}" + logger.info(**gen_log_kwargs(message=msg)) if cfg.use_maxwell_filter: - msg = f"Reading {run_type} recording: " f"{bids_path_in.basename}" - logger.info(**gen_log_kwargs(message=msg)) raw = mne.io.read_raw_fif(bids_path_in) elif run is None and task == "noise": raw = import_er_data( @@ -191,6 +194,8 @@ def filter_data( out_files[in_key] = bids_path_in.copy().update( root=cfg.deriv_root, + subject=subject, # save under subject's directory so all files are there + session=session, processing="filt", extension=".fif", suffix="raw", @@ -200,6 +205,18 @@ def filter_data( check=False, ) + if cfg.regress_artifact is None: + picks = None + else: + # Need to figure out the correct picks to use + model = EOGRegression(**cfg.regress_artifact) + picks_regress = _picks_to_idx( + raw.info, model.picks, none="data", exclude=model.exclude + ) + picks_artifact = _picks_to_idx(raw.info, model.picks_artifact) + picks_data = _picks_to_idx(raw.info, "data", exclude=()) # raw.filter default + picks = np.unique(np.r_[picks_regress, picks_artifact, picks_data]) + raw.load_data() notch_filter( raw=raw, @@ -211,6 +228,7 @@ def filter_data( trans_bandwidth=cfg.notch_trans_bandwidth, notch_widths=cfg.notch_widths, run_type=run_type, + picks=picks, ) bandpass_filter( raw=raw, @@ -223,6 +241,7 @@ def filter_data( h_trans_bandwidth=cfg.h_trans_bandwidth, l_trans_bandwidth=cfg.l_trans_bandwidth, run_type=run_type, + picks=picks, ) resample( raw=raw, @@ -287,6 +306,7 @@ def get_config( notch_trans_bandwidth=config.notch_trans_bandwidth, notch_widths=config.notch_widths, raw_resample_sfreq=config.raw_resample_sfreq, + regress_artifact=config.regress_artifact, **_import_data_kwargs(config=config, subject=subject), ) return cfg diff --git a/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py b/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py new file mode 100644 index 000000000..8a2b2a0f6 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py @@ -0,0 +1,172 @@ +"""Run Signal Subspace Projections (SSP) for artifact correction. + +These are often also referred to as PCA vectors. +""" + +from types import SimpleNamespace +from typing import Optional + +import mne +from mne.io.pick import _picks_to_idx +from mne.preprocessing import EOGRegression + +from ..._config_utils import ( + get_runs_tasks, + get_sessions, + get_subjects, +) +from ..._import_data import _get_run_rest_noise_path, _get_run_type, _import_data_kwargs +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _add_raw, _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs + + +def get_input_fnames_regress_artifact( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + run: str, + task: Optional[str], +) -> dict: + """Get paths of files required by regress_artifact function.""" + out = _get_run_rest_noise_path( + cfg=cfg, + subject=subject, + session=session, + run=run, + task=task, + kind="filt", + mf_reference_run=cfg.mf_reference_run, + ) + assert len(out) + return out + + +@failsafe_run( + get_input_fnames=get_input_fnames_regress_artifact, +) +def run_regress_artifact( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + run: str, + task: Optional[str], + in_files: dict, +) -> dict: + model = EOGRegression(proj=False, **cfg.regress_artifact) + out_files = dict() + in_key = f"raw_task-{task}_run-{run}" + bids_path_in = in_files.pop(in_key) + out_files[in_key] = bids_path_in.copy().update(processing="regress") + run_type = _get_run_type(run=run, task=task) + msg = f"Reading {run_type} recording: " f"{bids_path_in.basename}" + logger.info(**gen_log_kwargs(message=msg)) + raw = mne.io.read_raw_fif(bids_path_in).load_data() + projs = raw.info["projs"] + raw.del_proj() + model.fit(raw) + all_types = raw.get_channel_types() + picks = _picks_to_idx(raw.info, model.picks, none="data", exclude=model.exclude) + ch_types = set(all_types[pick] for pick in picks) + del picks + out_files["regress"] = bids_path_in.copy().update( + processing=None, + split=None, + run=None, + suffix="regress", + extension=".h5", + ) + model.apply(raw, copy=False) + if projs: + raw.add_proj(projs) + raw.save(out_files[in_key], overwrite=True) + _update_for_splits(out_files, in_key) + model.save(out_files["regress"], overwrite=True) + assert len(in_files) == 0, in_files.keys() + + # Report + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) as report: + msg = "Adding regressed raw data to report" + logger.info(**gen_log_kwargs(message=msg)) + figs, captions = list(), list() + for kind in ("mag", "grad", "eeg"): + if kind not in ch_types: + continue + figs.append(model.plot(ch_type=kind)) + captions.append(f"Run {run}: {kind}") + if figs: + report.add_figure( + fig=figs, + caption=captions, + title="Regression weights", + tags=("raw", f"run-{run}", "regression"), + replace=True, + ) + _add_raw( + cfg=cfg, + report=report, + bids_path_in=out_files[in_key], + title="Raw (regression)", + tags=("regression",), + raw=raw, + ) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, +) -> SimpleNamespace: + cfg = SimpleNamespace( + regress_artifact=config.regress_artifact, + **_import_data_kwargs(config=config, subject=subject), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Run artifact regression.""" + if config.regress_artifact is None: + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + parallel, run_func = parallel_func( + run_regress_artifact, exec_params=config.exec_params + ) + + logs = parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + for run, task in get_runs_tasks( + config=config, + subject=subject, + session=session, + ) + ) + + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py index 00346df25..7bfef3c56 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py @@ -253,7 +253,7 @@ def get_input_fnames_run_ica( for run in cfg.runs: key = f"raw_run-{run}" in_files[key] = bids_basename.copy().update( - run=run, processing="filt", suffix="raw" + run=run, processing=cfg.processing, suffix="raw" ) _update_for_splits(in_files, key, single=True) return in_files @@ -614,6 +614,7 @@ def get_config( eog_channels=config.eog_channels, rest_epochs_duration=config.rest_epochs_duration, rest_epochs_overlap=config.rest_epochs_overlap, + processing="filt" if config.regress_artifact is None else "regress", **_bids_kwargs(config=config), ) return cfg diff --git a/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py b/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py index 46b88ee90..7aa0e97de 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py +++ b/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py @@ -47,7 +47,7 @@ def get_input_fnames_run_ssp( for run in cfg.runs: key = f"raw_run-{run}" in_files[key] = bids_basename.copy().update( - run=run, processing="filt", suffix="raw" + run=run, processing=cfg.processing, suffix="raw" ) _update_for_splits(in_files, key, single=True) return in_files @@ -66,7 +66,7 @@ def run_ssp( ) -> dict: import matplotlib.pyplot as plt - # compute SSP on first run of raw + # compute SSP on all runs of raw raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] # when saving proj, use run=None @@ -229,6 +229,7 @@ def get_config( epochs_decim=config.epochs_decim, use_maxwell_filter=config.use_maxwell_filter, runs=get_runs(config=config, subject=subject), + processing="filt" if config.regress_artifact is None else "regress", **_bids_kwargs(config=config), ) return cfg diff --git a/mne_bids_pipeline/steps/preprocessing/_05_make_epochs.py b/mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py similarity index 100% rename from mne_bids_pipeline/steps/preprocessing/_05_make_epochs.py rename to mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py diff --git a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py similarity index 99% rename from mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py rename to mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py index c24d8e015..f4b999cc8 100644 --- a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py @@ -1,4 +1,4 @@ -"""Apply ICA and obtain the cleaned epochs. +"""Apply ICA and obtain the cleaned epochs and raw data. Blinks and ECG artifacts are automatically detected and the corresponding ICA components are removed from the data. diff --git a/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py b/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py similarity index 96% rename from mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py rename to mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py index 9b1a83fc9..b1eda9cd1 100644 --- a/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py +++ b/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py @@ -1,4 +1,4 @@ -"""Apply SSP projections and obtain the cleaned epochs. +"""Apply SSP projections and obtain the cleaned epochs and raw data. Blinks and ECG artifacts are automatically detected and the corresponding SSP projections components are removed from the data. @@ -57,8 +57,6 @@ def apply_ssp( session: Optional[str], in_files: dict, ) -> dict: - # load epochs to reject ICA components - # compute SSP on first run of raw out_files = dict() out_files["epochs"] = ( in_files["epochs"].copy().update(processing="ssp", split=None, check=False) diff --git a/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py b/mne_bids_pipeline/steps/preprocessing/_09_ptp_reject.py similarity index 100% rename from mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py rename to mne_bids_pipeline/steps/preprocessing/_09_ptp_reject.py diff --git a/mne_bids_pipeline/steps/preprocessing/__init__.py b/mne_bids_pipeline/steps/preprocessing/__init__.py index 686b7cf27..07d65224a 100644 --- a/mne_bids_pipeline/steps/preprocessing/__init__.py +++ b/mne_bids_pipeline/steps/preprocessing/__init__.py @@ -5,12 +5,13 @@ _02_head_pos, _03_maxfilter, _04_frequency_filter, - _05_make_epochs, + _05_regress_artifact, _06a_run_ica, _06b_run_ssp, - _07a_apply_ica, - _07b_apply_ssp, - _08_ptp_reject, + _07_make_epochs, + _08a_apply_ica, + _08b_apply_ssp, + _09_ptp_reject, ) _STEPS = ( @@ -18,10 +19,11 @@ _02_head_pos, _03_maxfilter, _04_frequency_filter, - _05_make_epochs, + _05_regress_artifact, _06a_run_ica, _06b_run_ssp, - _07a_apply_ica, - _07b_apply_ssp, - _08_ptp_reject, + _07_make_epochs, + _08a_apply_ica, + _08b_apply_ssp, + _09_ptp_reject, ) diff --git a/mne_bids_pipeline/tests/configs/config_MNE_phantom_KIT_data.py b/mne_bids_pipeline/tests/configs/config_MNE_phantom_KIT_data.py new file mode 100644 index 000000000..ef3347a53 --- /dev/null +++ b/mne_bids_pipeline/tests/configs/config_MNE_phantom_KIT_data.py @@ -0,0 +1,28 @@ +""" +KIT phantom data. + +https://mne.tools/dev/documentation/datasets.html#kit-phantom-dataset +""" + +study_name = "MNE-phantom-KIT-data" +bids_root = "~/mne_data/MNE-phantom-KIT-data" +deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/MNE-phantom-KIT-data" +task = "phantom" +ch_types = ["meg"] + +# Preprocessing +l_freq = None +h_freq = 40.0 +regress_artifact = dict( + picks="meg", picks_artifact=["MISC 001", "MISC 002", "MISC 003"] +) + +# Epochs +epochs_tmin = -0.08 +epochs_tmax = 0.18 +epochs_decim = 10 # 2000->200 Hz +baseline = (None, 0) +conditions = ["dip01", "dip13", "dip25", "dip37", "dip49"] + +# Decoding +decode = True # should be very good performance diff --git a/mne_bids_pipeline/tests/datasets.py b/mne_bids_pipeline/tests/datasets.py index f96a01042..c559f06ca 100644 --- a/mne_bids_pipeline/tests/datasets.py +++ b/mne_bids_pipeline/tests/datasets.py @@ -9,6 +9,7 @@ class DATASET_OPTIONS_T(TypedDict, total=False): openneuro: str # "" osf: str # "" web: str # "" + mne: str # "" include: list[str] # [] exclude: list[str] # [] hash: str # "" @@ -122,4 +123,7 @@ class DATASET_OPTIONS_T(TypedDict, total=False): "sub-emptyroom/ses-20000101", ], }, + "MNE-phantom-KIT-data": { + "mne": "phantom_kit", + }, } diff --git a/mne_bids_pipeline/tests/test_run.py b/mne_bids_pipeline/tests/test_run.py index 4eee1aa02..2e068ef70 100644 --- a/mne_bids_pipeline/tests/test_run.py +++ b/mne_bids_pipeline/tests/test_run.py @@ -124,6 +124,9 @@ class _TestOptionsT(TypedDict, total=False): "config": "config_ERP_CORE.py", "task": "P3", }, + "MNE-phantom-KIT-data": { + "config": "config_MNE_phantom_KIT_data.py", + }, } diff --git a/mne_bids_pipeline/tests/test_validation.py b/mne_bids_pipeline/tests/test_validation.py index c47432155..e99bfecf9 100644 --- a/mne_bids_pipeline/tests/test_validation.py +++ b/mne_bids_pipeline/tests/test_validation.py @@ -14,7 +14,7 @@ def test_validation(tmp_path, capsys): bad_text += f"bids_root = '{tmp_path}'\n" # no ch_types config_path.write_text(bad_text) - with pytest.raises(ValueError, match="Please specify ch_types"): + with pytest.raises(ValueError, match="Value should have at least 1 item"): _import_config(config_path=config_path) bad_text += "ch_types = ['eeg']\n" # conditions