Skip to content

Commit

Permalink
MAINT: Better type checking (#1013)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
larsoner and pre-commit-ci[bot] authored Oct 29, 2024
1 parent c6bb948 commit dbbbe24
Show file tree
Hide file tree
Showing 34 changed files with 151 additions and 102 deletions.
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
files: ^(.*\.(py|yaml))$
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.0
rev: v0.7.1
hooks:
- id: ruff
args: ["--fix"]
Expand All @@ -20,6 +20,8 @@ repos:
args: [--strict]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.13.0
hooks:
- id: mypy
additional_dependencies:
- types-PyYAML
11 changes: 7 additions & 4 deletions docs/source/features/gen_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,16 @@
overview_lines.append(a_b)
continue
assert isinstance(a_b, tuple), type(a_b)
a_b = list(a_b) # allow modification
for ii, idx in enumerate(a_b):
a_b_list: list[str] = list(a_b) # allow modification
del a_b
for ii, idx in enumerate(a_b_list):
assert idx in title_map, (dir_header, idx, sorted(title_map))
if idx not in mapped:
mapped.add(idx)
a_b[ii] = f'{idx}["{title_map[idx]}"]'
overview_lines.append(f" {chr_pre}{a_b[0]} --> {chr_pre}{a_b[1]}")
a_b_list[ii] = f'{idx}["{title_map[idx]}"]'
overview_lines.append(
f" {chr_pre}{a_b_list[0]} --> {chr_pre}{a_b_list[1]}"
)
all_steps_list: list[str] = list()
for a_b in manual_order[dir_header]:
if not isinstance(a_b, str):
Expand Down
7 changes: 4 additions & 3 deletions docs/source/settings/gen_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,13 @@
)


def main():
def main() -> None:
"""Parse the configuration and generate the markdown documentation."""
print(f"Parsing {config_path} to generate settings .md files.")
# max file-level depth is 2 even though we have 3 subsection levels
levels = [None, None]
current_path, current_lines = None, list()
levels = ["", ""]
current_path: Path | None = None
current_lines: list[str] = list()
text = config_path.read_text("utf-8")
lines = text.splitlines()
lines += ["# #"] # add a dummy line to trigger the last write
Expand Down
3 changes: 1 addition & 2 deletions docs/source/v1.10.md.inc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
### :bug: Bug fixes

- Empty room matching is now done for all sessions (previously only for the first session) for each subject. (#976 by @drammock)

- [`noise_cov_method`][mne_bids_pipeline._config.noise_cov_method] is now properly used for noise covariance estimation from raw data (#1010 by @larsoner)

- When running the pipeline with [`mf_filter_chpi`][mne_bids_pipeline._config.mf_filter_chpi] enabled (#977 by @drammock and @larsoner):

1. Emptyroom files that lack cHPI channels will now be processed (for line noise only) instead of raising an error.
Expand All @@ -33,6 +31,7 @@
### :medical_symbol: Code health

- Switch from using relative to using absolute imports (#969 by @hoechenberger)
- Enable strict type checking via mypy (#995, #1013 by @larsoner)

### :medical_symbol: Code health and infrastructure

Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_config_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def default_factory() -> Any:
else:
assert typ is list
default_factory = partial(typ, allowlist[idx]) # type: ignore
return field(default_factory=default_factory) # type: ignore
return field(default_factory=default_factory)
return val


Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def _get_step_modules() -> dict[str, tuple[ModuleType, ...]]:
return STEP_MODULES


def _bids_kwargs(*, config: SimpleNamespace) -> dict[str, Any]:
def _bids_kwargs(*, config: SimpleNamespace) -> dict[str, str | None]:
"""Get the standard BIDS config entries."""
return dict(
proc=config.proc,
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .typing import PathLike


def _write_json(fname: PathLike, data: dict[str, Any]) -> None:
def _write_json(fname: PathLike, data: dict[str, Any] | None) -> None:
with open(fname, "w", encoding="utf-8") as f:
json_tricks.dump(data, fp=f, allow_nan=True, sort_keys=False)

Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self) -> None:
@property
def _console(self) -> rich.console.Console:
try:
return self.__console # type: ignore[no-any-return,has-type]
return self.__console # type: ignore[has-type]
except AttributeError:
pass # need to instantiate it, continue

Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def setup_dask_client(*, exec_params: SimpleNamespace) -> None:
"distributed.worker.memory.spill": False,
}
)
client = Client( # type: ignore[no-untyped-call]
client = Client(
memory_limit=exec_params.dask_worker_memory_limit,
n_workers=n_workers,
threads_per_worker=1,
Expand Down
16 changes: 13 additions & 3 deletions mne_bids_pipeline/steps/freesurfer/_01_recon_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import shutil
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import Any

from mne.utils import run_subprocess

Expand All @@ -22,7 +24,13 @@
fs_bids_app = Path(__file__).parent / "contrib" / "run.py"


def run_recon(root_dir, subject, fs_bids_app, subjects_dir, session=None) -> None:
def run_recon(
root_dir: Path,
subject: str,
fs_bids_app: Any,
subjects_dir: Path,
session: str | None = None,
) -> None:
subj_dir = subjects_dir / f"sub-{subject}"
sub_ses = f"Subject {subject}"
if session is not None:
Expand Down Expand Up @@ -70,11 +78,13 @@ def run_recon(root_dir, subject, fs_bids_app, subjects_dir, session=None) -> Non
run_subprocess(cmd, env=env, verbose=logger.level)


def _has_session_specific_anat(subject, session, subjects_dir):
def _has_session_specific_anat(
subject: str, session: str | None, subjects_dir: Path
) -> bool:
return (subjects_dir / f"sub-{subject}_ses-{session}").exists()


def main(*, config) -> None:
def main(*, config: SimpleNamespace) -> None:
"""Run freesurfer recon-all command on BIDS dataset.
The script allows to run the freesurfer recon-all
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_config(
return cfg


def main(*, config) -> None:
def main(*, config: SimpleNamespace) -> None:
# Ensure we're also processing fsaverage if present
subjects = get_subjects(config)
sessions = get_sessions(config)
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def init_subject_dirs(

def get_config(
*,
config,
config: SimpleNamespace,
) -> SimpleNamespace:
cfg = SimpleNamespace(
PIPELINE_NAME=config.PIPELINE_NAME,
Expand All @@ -71,7 +71,7 @@ def get_config(
return cfg


def main(*, config):
def main(*, config: SimpleNamespace) -> None:
"""Initialize the output directories."""
init_dataset(cfg=get_config(config=config), exec_params=config.exec_params)
# Don't bother with parallelization here as I/O operations are generally
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/init/_02_find_empty_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ def find_empty_room(

def get_config(
*,
config,
config: SimpleNamespace,
) -> SimpleNamespace:
cfg = SimpleNamespace(
**_bids_kwargs(config=config),
)
return cfg


def main(*, config) -> None:
def main(*, config: SimpleNamespace) -> None:
"""Run find_empty_room."""
if not config.process_empty_room:
msg = "Skipping, process_empty_room is set to False …"
Expand Down
7 changes: 4 additions & 3 deletions mne_bids_pipeline/steps/preprocessing/_01_data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def assess_data_quality(
)
preexisting_bads = sorted(raw.info["bads"])

auto_scores: dict | None = None
auto_noisy_chs: list[str] | None = None
auto_flat_chs: list[str] | None = None
if _do_mf_autobad(cfg=cfg):
(
auto_noisy_chs,
Expand All @@ -126,8 +129,6 @@ def assess_data_quality(
raw.info["bads"] = bads
del bads
logger.info(**gen_log_kwargs(message=msg))
else:
auto_scores = auto_noisy_chs = auto_flat_chs = None
del key

# Always output the scores and bads TSV
Expand Down Expand Up @@ -241,7 +242,7 @@ def _find_bads_maxwell(
session: str | None,
run: str | None,
task: str | None,
):
) -> tuple[list[str], list[str], dict]:
if cfg.find_flat_channels_meg:
if cfg.find_noisy_channels_meg:
msg = "Finding flat channels and noisy channels using Maxwell filtering."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

def detect_bad_components(
*,
cfg,
cfg: SimpleNamespace,
which: Literal["eog", "ecg"],
epochs: mne.BaseEpochs | None,
ica: mne.preprocessing.ICA,
Expand Down
9 changes: 6 additions & 3 deletions mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from types import SimpleNamespace

import mne
import numpy as np
from mne_bids import BIDSPath

from mne_bids_pipeline._config_utils import (
Expand Down Expand Up @@ -275,7 +276,9 @@ def _add_epochs_image_kwargs(cfg: SimpleNamespace) -> dict:


# TODO: ideally we wouldn't need this anymore and could refactor the code above
def _get_events(cfg, subject, session):
def _get_events(
cfg: SimpleNamespace, subject: str, session: str | None
) -> tuple[np.ndarray, dict, float, int]:
raws_filt = []
raw_fname = BIDSPath(
subject=subject,
Expand Down Expand Up @@ -307,7 +310,7 @@ def _get_events(cfg, subject, session):

def get_config(
*,
config,
config: SimpleNamespace,
subject: str,
) -> SimpleNamespace:
cfg = SimpleNamespace(
Expand Down Expand Up @@ -337,7 +340,7 @@ def get_config(
return cfg


def main(*, config) -> None:
def main(*, config: SimpleNamespace) -> None:
"""Run epochs."""
with get_parallel_backend(config.exec_params):
parallel, run_func = parallel_func(run_epochs, exec_params=config.exec_params)
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _ica_paths(
cfg: SimpleNamespace,
subject: str,
session: str | None,
):
) -> dict[str, BIDSPath]:
bids_basename = BIDSPath(
subject=subject,
session=session,
Expand Down
17 changes: 15 additions & 2 deletions mne_bids_pipeline/steps/sensor/_05_decoding_csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ def _prepare_labels(*, epochs: mne.BaseEpochs, contrast: tuple[str, str]) -> np.


def prepare_epochs_and_y(
*, epochs: mne.BaseEpochs, contrast: tuple[str, str], cfg, fmin: float, fmax: float
*,
epochs: mne.BaseEpochs,
contrast: tuple[str, str],
cfg: SimpleNamespace,
fmin: float,
fmax: float,
) -> tuple[mne.BaseEpochs, np.ndarray]:
"""Band-pass between, sub-select the desired epochs, and prepare y."""
# filtering out the conditions we are not interested in, to ensure here we
Expand Down Expand Up @@ -223,7 +228,15 @@ def one_subject_decoding(
)
del freq_decoding_table_rows

def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=None):
def _fmt_contrast(
cond1: str,
cond2: str,
fmin: float,
fmax: float,
freq_range_name: str,
tmin: float | None = None,
tmax: float | None = None,
) -> str:
msg = (
f"Contrast: {cond1}{cond2}, "
f"{fmin:4.1f}{fmax:4.1f} Hz ({freq_range_name})"
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/sensor/_06_make_cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def retrieve_custom_cov(
return cov


def _get_cov_type(cfg):
def _get_cov_type(cfg: SimpleNamespace) -> str:
if cfg.noise_cov == "custom":
return "custom"
elif cfg.noise_cov == "rest":
Expand All @@ -216,7 +216,7 @@ def run_covariance(
subject: str,
session: str | None = None,
in_files: dict,
) -> dict:
) -> dict[str, BIDSPath]:
import matplotlib.pyplot as plt

out_files = dict()
Expand Down
6 changes: 3 additions & 3 deletions mne_bids_pipeline/steps/sensor/_99_group_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def _decoding_out_fname(
cond_2: str | None,
kind: str,
extension: str = ".mat",
):
) -> BIDSPath:
if cond_1 is None:
assert cond_2 is None
processing = ""
Expand Down Expand Up @@ -741,7 +741,7 @@ def average_csp_decoding(
cond_1: str,
cond_2: str,
in_files: dict,
):
) -> dict[str, BIDSPath]:
msg = f"Summarizing CSP results: {cond_1} - {cond_2}."
logger.info(**gen_log_kwargs(message=msg))
in_files.pop("epochs")
Expand Down Expand Up @@ -973,7 +973,7 @@ def _average_csp_time_freq(

def get_config(
*,
config,
config: SimpleNamespace,
) -> SimpleNamespace:
cfg = SimpleNamespace(
subjects=get_subjects(config),
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from mne_bids_pipeline._run import _prep_out_files, failsafe_run, save_logs


def _get_bem_params(cfg: SimpleNamespace):
def _get_bem_params(cfg: SimpleNamespace) -> tuple[str, Path, Path]:
mri_dir = Path(cfg.fs_subjects_dir) / cfg.fs_subject / "mri"
flash_dir = mri_dir / "flash" / "parameter_maps"
if cfg.bem_mri_images == "FLASH" and not flash_dir.exists():
Expand Down
2 changes: 1 addition & 1 deletion mne_bids_pipeline/steps/source/_02_make_bem_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def get_config(
return cfg


def main(*, config) -> None:
def main(*, config: SimpleNamespace) -> None:
"""Run BEM solution calculation."""
if not config.run_source_estimation:
msg = "Skipping, run_source_estimation is set to False …"
Expand Down
Loading

0 comments on commit dbbbe24

Please sign in to comment.