MAINT: Add regression and raw saving
larsoner committed Jan 24, 2024
1 parent 742e27e commit 2911e7f
Showing 17 changed files with 285 additions and 44 deletions.
2 changes: 1 addition & 1 deletion docs/mkdocs.yml
@@ -90,7 +90,7 @@ nav:
- Epoching: settings/preprocessing/epochs.md
- Artifact removal:
- Stimulation artifact: settings/preprocessing/stim_artifact.md
- SSP & ICA: settings/preprocessing/ssp_ica.md
- SSP, ICA, and artifact regression: settings/preprocessing/ssp_ica.md
- Amplitude-based artifact rejection: settings/preprocessing/artifacts.md
- Sensor-level analysis:
- Condition contrasts: settings/sensor/contrasts.md
2 changes: 2 additions & 0 deletions docs/source/examples/gen_examples.py
@@ -61,6 +61,8 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
key = "Maxwell filter"
funcs[key] = funcs[key] or config.use_maxwell_filter
funcs["Frequency filter"] = config.l_freq or config.h_freq
key = "Artifact regression"
funcs[key] = funcs[key] or (config.regress_artifact is not None)
key = "SSP"
funcs[key] = funcs[key] or (config.spatial_filter == "ssp")
key = "ICA"
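As a side note, here is a minimal sketch of the accumulation pattern used in `_gen_demonstrated_funcs`: each feature flag latches to `True` once any example configuration enables it. `ExampleConfig` is a hypothetical stand-in for a loaded example config.

```python
# Hypothetical stand-in for one loaded example configuration module.
class ExampleConfig:
    regress_artifact = {"picks": "meg", "picks_artifact": ["MISC 001"]}
    spatial_filter = None

funcs = {"Artifact regression": False, "SSP": False}
config = ExampleConfig()

# Same latching pattern as in the diff: a flag stays True once any config sets it.
key = "Artifact regression"
funcs[key] = funcs[key] or (config.regress_artifact is not None)
key = "SSP"
funcs[key] = funcs[key] or (config.spatial_filter == "ssp")
print(funcs)  # {'Artifact regression': True, 'SSP': False}
```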
1 change: 1 addition & 0 deletions docs/source/settings/preprocessing/ssp_ica.md
@@ -11,6 +11,7 @@ tags:
::: mne_bids_pipeline._config
options:
members:
- regress_artifact
- spatial_filter
- min_ecg_epochs
- min_eog_epochs
4 changes: 2 additions & 2 deletions docs/source/v1.6.md.inc
@@ -2,9 +2,9 @@

## vX.Y.0 (unreleased)

[//]: # (### :new: New features & enhancements)
### :new: New features & enhancements

[//]: # (- Whatever (#000 by @whoever))
- Added [`regress_artifact`][mne_bids_pipeline._config.regress_artifact] to allow artifact regression (e.g., of MEG reference sensors in KIT systems) (#837 by @larsoner)

[//]: # (### :warning: Behavior changes)

26 changes: 22 additions & 4 deletions mne_bids_pipeline/_config.py
@@ -1,6 +1,6 @@
# Default settings for data processing and analysis.

from typing import Optional, Union, Iterable, List, Tuple, Dict, Callable, Literal
from typing import Optional, Union, Iterable, List, Tuple, Dict, Callable, Literal, Any

from mne import Covariance
from mne_bids import BIDSPath
@@ -1126,6 +1126,24 @@
#
# Currently you cannot use both.

regress_artifact: Optional[Dict[str, Any]] = None
"""
Keyword arguments to pass to the `mne.preprocessing.EOGRegression` model used
in `mne.preprocessing.regress_artifact`. If `None`, no time-domain regression will
be applied. Note that any channels picked in `regress_artifact["picks_artifact"]` will
have the same time-domain filters applied to them as the experimental data.
Artifact regression is applied before SSP or ICA.
???+ example "Example"
For example, if you have MEG reference channel data recorded in three
miscellaneous channels, you could do:
```python
regress_artifact = {"picks": "meg", "picks_artifact": ["MISC 001", "MISC 002", "MISC 003"]}
```
""" # noqa: E501

spatial_filter: Optional[Literal["ssp", "ica"]] = None
"""
Whether to use a spatial filter to detect and remove artifacts. The BIDS
@@ -1237,7 +1255,7 @@
"""
Peak-to-peak amplitude limits to exclude epochs from ICA fitting. This allows you to
remove strong transient artifacts from the epochs used for fitting ICA, which could
negatively affect ICA performance.
The parameter values are the same as for [`reject`][mne_bids_pipeline._config.reject],
but `"autoreject_global"` is not supported. `"autoreject_local"` here behaves
@@ -1264,7 +1282,7 @@
to **not** specify rejection thresholds for EOG and ECG channels here –
otherwise, ICA won't be able to "see" these artifacts.
???+ info
This setting is applied only to the epochs that are used for **fitting** ICA. The
goal is to make it easier for ICA to produce a good decomposition. After fitting,
ICA is applied to the epochs to be analyzed, usually with one or more components
@@ -1386,7 +1404,7 @@
If `None` (default), do not apply artifact rejection.
If a dictionary, manually specify rejection thresholds (see examples).
The thresholds provided here must be at least as stringent as those in
[`ica_reject`][mne_bids_pipeline._config.ica_reject] if using ICA. In case of
`'autoreject_global'`, thresholds for any channel that do not meet this
33 changes: 22 additions & 11 deletions mne_bids_pipeline/_import_data.py
@@ -449,7 +449,6 @@ def import_er_data(
cfg=cfg,
bids_path_bads=bids_path_er_bads_in,
)
raw_er.pick("meg", exclude=[])

# Don't deal with ref for now (initial data quality / auto bad step)
if bids_path_ref_in is None:
@@ -527,7 +526,7 @@ def _get_bids_path_in(
session: Optional[str],
run: Optional[str],
task: Optional[str],
kind: Literal["orig", "sss"] = "orig",
kind: Literal["orig", "sss", "filt"] = "orig",
) -> BIDSPath:
# b/c can be used before this is updated
path_kwargs = dict(
@@ -541,13 +540,13 @@
datatype=get_datatype(config=cfg),
check=False,
)
if kind == "sss":
if kind != "orig":
assert kind in ("sss", "filt"), kind
path_kwargs["root"] = cfg.deriv_root
path_kwargs["suffix"] = "raw"
path_kwargs["extension"] = ".fif"
path_kwargs["processing"] = "sss"
path_kwargs["processing"] = kind
else:
assert kind == "orig", kind
path_kwargs["root"] = cfg.bids_root
path_kwargs["suffix"] = None
path_kwargs["extension"] = None
@@ -563,7 +562,7 @@ def _get_run_path(
session: Optional[str],
run: Optional[str],
task: Optional[str],
kind: Literal["orig", "sss"],
kind: Literal["orig", "sss", "filt"],
add_bads: Optional[bool] = None,
allow_missing: bool = False,
key: Optional[str] = None,
@@ -591,7 +590,7 @@ def _get_rest_path(
cfg: SimpleNamespace,
subject: str,
session: Optional[str],
kind: Literal["orig", "sss"],
kind: Literal["orig", "sss", "filt"],
add_bads: Optional[bool] = None,
) -> dict:
if not (cfg.process_rest and not cfg.task_is_rest):
@@ -613,13 +612,14 @@ def _get_noise_path(
cfg: SimpleNamespace,
subject: str,
session: Optional[str],
kind: Literal["orig", "sss"],
kind: Literal["orig", "sss", "filt"],
mf_reference_run: Optional[str],
add_bads: Optional[bool] = None,
) -> dict:
if not (cfg.process_empty_room and get_datatype(config=cfg) == "meg"):
return dict()
if kind == "sss":
if kind != "orig":
assert kind in ("sss", "filt")
raw_fname = _get_bids_path_in(
cfg=cfg,
subject=subject,
@@ -658,7 +658,7 @@ def _get_run_rest_noise_path(
session: Optional[str],
run: Optional[str],
task: Optional[str],
kind: Literal["orig", "sss"],
kind: Literal["orig", "sss", "filt"],
mf_reference_run: Optional[str],
add_bads: Optional[bool] = None,
) -> dict:
@@ -702,7 +702,7 @@ def _path_dict(
cfg: SimpleNamespace,
bids_path_in: BIDSPath,
add_bads: Optional[bool] = None,
kind: Literal["orig", "sss"],
kind: Literal["orig", "sss", "filt"],
allow_missing: bool,
key: Optional[str] = None,
) -> dict:
@@ -802,3 +802,14 @@ def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict:
runs=get_runs(config=config, subject=subject), # XXX needs to accept session!
**_bids_kwargs(config=config),
)


def _get_run_type(
run: Optional[str],
task: Optional[str],
) -> str:
if run is None and task in ("noise", "rest"):
run_type = dict(rest="resting-state", noise="empty-room")[task]
else:
run_type = "experimental"
return run_type
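A quick illustration of the new helper's behavior; the private-module import is shown only for demonstration and assumes a pipeline version that includes this commit.

```python
from mne_bids_pipeline._import_data import _get_run_type

# Regular task runs are "experimental"; run-less noise/rest recordings get
# their descriptive labels.
assert _get_run_type(run="01", task="audiovisual") == "experimental"
assert _get_run_type(run=None, task="rest") == "resting-state"
assert _get_run_type(run=None, task="noise") == "empty-room"
```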
8 changes: 6 additions & 2 deletions mne_bids_pipeline/_report.py
@@ -69,14 +69,13 @@ def _open_report(
yield report
finally:
try:
msg = "Adding config and sys info to report"
logger.info(**gen_log_kwargs(message=msg))
_finalize(
report=report,
exec_params=exec_params,
subject=subject,
session=session,
run=run,
task=task,
)
except Exception as exc:
logger.warning(f"Failed: {exc}")
@@ -507,12 +506,17 @@ def _finalize(
subject: str,
session: Optional[str],
run: Optional[str],
task: Optional[str],
) -> None:
"""Add system information and the pipeline configuration to the report."""
# ensure they are always appended
titles = ["Configuration file", "System information"]
for title in titles:
report.remove(title=title, remove_all=True)
# Print this exactly once
if _cached_sys_info.cache_info()[-1] == 0: # never run
msg = "Adding config and sys info to report"
logger.info(**gen_log_kwargs(message=msg))
# No longer need replace=True in these
report.add_code(
code=exec_params.config_path,
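The `_cached_sys_info.cache_info()[-1] == 0` check above uses `functools.lru_cache` bookkeeping: `cache_info()` returns `(hits, misses, maxsize, currsize)`, so index `-1` is the current cache size. A standalone sketch of the same "log exactly once" idea, with `gather_sys_info` as a hypothetical stand-in:

```python
import functools


@functools.lru_cache(maxsize=1)
def gather_sys_info() -> str:
    # Hypothetical stand-in for the pipeline's cached system-info gathering.
    return "sys info"


# currsize == 0 means the cached function has never run, so this message is
# printed at most once per process, no matter how many reports are finalized.
if gather_sys_info.cache_info()[-1] == 0:
    print("Adding config and sys info to report")
sys_info = gather_sys_info()
```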
9 changes: 7 additions & 2 deletions mne_bids_pipeline/_run.py
@@ -225,13 +225,18 @@ def wrapper(*args, **kwargs):
for key, (fname, this_hash) in out_files_hashes.items():
fname = pathlib.Path(fname)
if not fname.exists():
msg = "Output file missing, will recompute …"
msg = (
f"Output file missing {str(fname)}, " "will recompute …"
)
emoji = "🧩"
bad_out_files = True
break
got_hash = hash_(key, fname, kind="out")[1]
if this_hash != got_hash:
msg = "Output file hash mismatch, will recompute …"
msg = (
f"Output file hash mismatch for {str(fname)}, "
"will recompute …"
)
emoji = "🚫"
bad_out_files = True
break
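A self-contained sketch of the output-hash check these improved log messages belong to; the pipeline uses its own `hash_` helper, and `sha256` here is only an assumption for illustration.

```python
import hashlib
import pathlib


def _file_hash(fname: pathlib.Path) -> str:
    # Assumed stand-in for the pipeline's hashing of output files.
    return hashlib.sha256(fname.read_bytes()).hexdigest()


def needs_recompute(out_files_hashes: dict) -> bool:
    # Mirrors the loop above: recompute if any output file is missing or its
    # stored hash no longer matches the file on disk.
    for fname, stored_hash in out_files_hashes.values():
        fname = pathlib.Path(fname)
        if not fname.exists():
            print(f"Output file missing {fname}, will recompute …")
            return True
        if _file_hash(fname) != stored_hash:
            print(f"Output file hash mismatch for {fname}, will recompute …")
            return True
    return False
```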
39 changes: 32 additions & 7 deletions mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py
@@ -19,6 +19,8 @@
from typing import Optional, Union, Literal, Iterable

import mne
from mne.io.pick import _picks_to_idx
from mne.preprocessing import EOGRegression

from ..._config_utils import (
get_sessions,
@@ -29,6 +31,7 @@
import_experimental_data,
import_er_data,
_get_run_rest_noise_path,
_get_run_type,
_import_data_kwargs,
)
from ..._logging import gen_log_kwargs, logger
@@ -68,6 +71,7 @@ def notch_filter(
trans_bandwidth: Union[float, Literal["auto"]],
notch_widths: Optional[Union[float, Iterable[float]]],
run_type: Literal["experimental", "empty-room", "resting-state"],
picks: Optional[np.ndarray],
) -> None:
"""Filter data channels (MEG and EEG)."""
if freqs is None:
@@ -85,6 +89,7 @@
trans_bandwidth=trans_bandwidth,
notch_widths=notch_widths,
n_jobs=1,
picks=picks,
)


@@ -99,6 +104,7 @@ def bandpass_filter(
l_trans_bandwidth: Union[float, Literal["auto"]],
h_trans_bandwidth: Union[float, Literal["auto"]],
run_type: Literal["experimental", "empty-room", "resting-state"],
picks: Optional[np.ndarray],
) -> None:
"""Filter data channels (MEG and EEG)."""
if l_freq is not None and h_freq is None:
@@ -121,6 +127,7 @@
l_trans_bandwidth=l_trans_bandwidth,
h_trans_bandwidth=h_trans_bandwidth,
n_jobs=1,
picks=picks,
)


@@ -160,14 +167,10 @@ def filter_data(
bids_path_in = in_files.pop(in_key)
bids_path_bads_in = in_files.pop(f"{in_key}-bads", None)

if run is None and task in ("noise", "rest"):
run_type = dict(rest="resting-state", noise="empty-room")[task]
else:
run_type = "experimental"

run_type = _get_run_type(run=run, task=task)
msg = f"Reading {run_type} recording: " f"{bids_path_in.basename}"
logger.info(**gen_log_kwargs(message=msg))
if cfg.use_maxwell_filter:
msg = f"Reading {run_type} recording: " f"{bids_path_in.basename}"
logger.info(**gen_log_kwargs(message=msg))
raw = mne.io.read_raw_fif(bids_path_in)
elif run is None and task == "noise":
raw = import_er_data(
Expand All @@ -190,6 +193,8 @@ def filter_data(

out_files[in_key] = bids_path_in.copy().update(
root=cfg.deriv_root,
subject=subject, # save under subject's directory so all files are there
session=session,
processing="filt",
extension=".fif",
suffix="raw",
Expand All @@ -199,6 +204,18 @@ def filter_data(
check=False,
)

if cfg.regress_artifact is None:
picks = None
else:
# Need to figure out the correct picks to use
model = EOGRegression(**cfg.regress_artifact)
picks_regress = _picks_to_idx(
raw.info, model.picks, none="data", exclude=model.exclude
)
picks_artifact = _picks_to_idx(raw.info, model.picks_artifact)
picks_data = _picks_to_idx(raw.info, "data", exclude=()) # raw.filter default
picks = np.unique(np.r_[picks_regress, picks_artifact, picks_data])

raw.load_data()
notch_filter(
raw=raw,
Expand All @@ -210,6 +227,7 @@ def filter_data(
trans_bandwidth=cfg.notch_trans_bandwidth,
notch_widths=cfg.notch_widths,
run_type=run_type,
picks=picks,
)
bandpass_filter(
raw=raw,
Expand All @@ -222,6 +240,7 @@ def filter_data(
h_trans_bandwidth=cfg.h_trans_bandwidth,
l_trans_bandwidth=cfg.l_trans_bandwidth,
run_type=run_type,
picks=picks,
)
resample(
raw=raw,
Expand All @@ -236,6 +255,11 @@ def filter_data(
# For example, might need to create
# derivatives/mne-bids-pipeline/sub-emptyroom/ses-20230412/meg
out_files[in_key].fpath.parent.mkdir(exist_ok=True, parents=True)
logger.info(
**gen_log_kwargs(
message=f"Writing filtered data to: {out_files[in_key].basename}"
)
)
raw.save(
out_files[in_key],
overwrite=True,
@@ -286,6 +310,7 @@ def get_config(
notch_trans_bandwidth=config.notch_trans_bandwidth,
notch_widths=config.notch_widths,
raw_resample_sfreq=config.raw_resample_sfreq,
regress_artifact=config.regress_artifact,
**_import_data_kwargs(config=config, subject=subject),
)
return cfg
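A rough sketch, using public MNE picking functions rather than the private `_picks_to_idx` helper, of why explicit `picks` are passed to the filters: the MISC reference channels used for regression must receive the same filtering as the data channels they are regressed out of. The file name and cutoff frequencies are hypothetical.

```python
import mne

# Hypothetical path; any Raw with MISC reference channels works.
raw = mne.io.read_raw_fif("sub-01_task-rest_raw.fif", preload=True)

# By default Raw.filter only touches data channels; including misc=True keeps
# the regression reference channels consistent with the filtered data.
picks = mne.pick_types(raw.info, meg=True, eeg=True, misc=True, exclude=())
raw.notch_filter(freqs=[50, 100], picks=picks, n_jobs=1)
raw.filter(l_freq=0.1, h_freq=40.0, picks=picks, n_jobs=1)
```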