MAINT: Add and check type hints (#995)
Co-authored-by: Daniel McCloy <dan@mccloy.info>
Co-authored-by: Richard Höchenberger <richard.hoechenberger@gmail.com>
3 people authored Oct 26, 2024
1 parent 3447017 commit c6bb948
Showing 43 changed files with 708 additions and 519 deletions.
2 changes: 1 addition & 1 deletion .github/pull_request_template.md
@@ -1,3 +1,3 @@
### Before merging …

-- [ ] Changelog has been updated (`docs/source/changes.md`)
+- [ ] Changelog has been updated (`docs/source/vX.Y.md.inc`)
1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ build/
.hypothesis/
.coverage*
junit-results.xml
+.cache/
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -18,3 +18,8 @@ repos:
hooks:
- id: yamllint
args: [--strict]

+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.9.0
+    hooks:
+      - id: mypy
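
With this hook installed, `pre-commit run mypy --all-files` type-checks the whole package before each commit. As a minimal sketch of the kind of mismatch mypy now catches (a hypothetical function, not from this repo):

```python
def add_one(x: int) -> str:  # annotated to return str...
    return x + 1  # ...but actually returns int; mypy reports:
    # error: Incompatible return value type (got "int", expected "str")
```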
30 changes: 17 additions & 13 deletions docs/source/examples/gen_examples.py
@@ -7,31 +7,32 @@
import shutil
import sys
from collections import defaultdict
-from collections.abc import Iterable
+from collections.abc import Generator, Iterable
from pathlib import Path
+from typing import Any

from tqdm import tqdm

import mne_bids_pipeline
import mne_bids_pipeline.tests.datasets
from mne_bids_pipeline._config_import import _import_config
-from mne_bids_pipeline.tests.datasets import DATASET_OPTIONS
+from mne_bids_pipeline.tests.datasets import DATASET_OPTIONS, DATASET_OPTIONS_T
from mne_bids_pipeline.tests.test_run import TEST_SUITE

this_dir = Path(__file__).parent
root = Path(mne_bids_pipeline.__file__).parent.resolve(strict=True)
logger = logging.getLogger()


-def _bool_to_icon(x: bool | Iterable) -> str:
+def _bool_to_icon(x: bool | Iterable[Any]) -> str:
if x:
return "✅"
else:
return "❌"


@contextlib.contextmanager
-def _task_context(task):
+def _task_context(task: str | None) -> Generator[None, None, None]:
old_argv = sys.argv
if task:
sys.argv = [sys.argv[0], f"--task={task}"]
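
A note on the `Generator[None, None, None]` return type above: for a `@contextlib.contextmanager` function, the three type parameters are the yield type, the send type, and the return type. A minimal self-contained sketch:

```python
import contextlib
from collections.abc import Generator


@contextlib.contextmanager
def demo_context() -> Generator[None, None, None]:
    # Yields None once, accepts nothing via send(), returns nothing.
    try:
        yield
    finally:
        pass  # cleanup would go here
```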
@@ -41,7 +42,7 @@ def _task_context(task):
sys.argv = old_argv


-def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
+def _gen_demonstrated_funcs(example_config_path: Path) -> dict[str, bool]:
"""Generate dict of demonstrated functionality based on config."""
# Here we use a defaultdict, and for keys that might vary across configs
# we should use an `funcs[key] = funcs[key] or ...` so that we effectively
@@ -160,7 +161,9 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
continue

assert dataset_options_key in DATASET_OPTIONS, dataset_options_key
-options = DATASET_OPTIONS[dataset_options_key].copy()  # we modify locally
+options: DATASET_OPTIONS_T = DATASET_OPTIONS[
+    dataset_options_key
+].copy()  # we modify locally

report_str = "\n## Generated output\n\n"
example_target_dir = this_dir / dataset_name
@@ -228,8 +231,8 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
source_str = f"## Dataset source\n\nThis dataset was acquired from [{url}]({url})\n"

if "openneuro" in options:
for key in ("include", "exclude"):
options[key] = options.get(key, [])
options.setdefault("include", [])
options.setdefault("exclude", [])
download_str = (
f'\n??? example "How to download this dataset"\n'
f" Run in your terminal:\n"
@@ -295,6 +298,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
f.write(download_str)
f.write(config_str)
f.write(report_str)
+del dataset_name, funcs

# Finally, write our examples.html file with a table of examples
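
The `del dataset_name, funcs` above and the renaming to `this_dataset_name`/`these_funcs` below address the same issue: reusing a name for a value of a different type in the same scope trips mypy, and the `del` guards against accidentally reading a stale loop variable later on. A minimal sketch of the redefinition error:

```python
funcs: dict[str, bool] = {"preprocessing": True}
for funcs in ("a", "b"):  # mypy reports:
    # error: Incompatible types in assignment (expression has type "str",
    # variable has type "dict[str, bool]")
    pass
```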

@@ -315,13 +319,13 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
with out_path.open("w", encoding="utf-8") as f:
f.write(_example_header)
header_written = False
-for dataset_name, funcs in all_demonstrated.items():
+for this_dataset_name, these_funcs in all_demonstrated.items():
if not header_written:
f.write("Dataset | " + " | ".join(funcs.keys()) + "\n")
f.write("--------|" + "|".join([":---:"] * len(funcs)) + "\n")
f.write("Dataset | " + " | ".join(these_funcs.keys()) + "\n")
f.write("--------|" + "|".join([":---:"] * len(these_funcs)) + "\n")
header_written = True
f.write(
f"[{dataset_name}]({dataset_name}.md) | "
+ " | ".join(_bool_to_icon(v) for v in funcs.values())
f"[{this_dataset_name}]({this_dataset_name}.md) | "
+ " | ".join(_bool_to_icon(v) for v in these_funcs.values())
+ "\n"
)
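
On the `setdefault` change earlier in this file: assuming `DATASET_OPTIONS_T` is a `TypedDict` (as the new annotation suggests), mypy requires string-literal keys, so indexing with a loop variable is rejected while literal-keyed `setdefault` calls check cleanly. A sketch with a hypothetical stand-in type:

```python
from typing import TypedDict


class OptionsT(TypedDict, total=False):  # hypothetical stand-in
    include: list[str]
    exclude: list[str]


options: OptionsT = {}
for key in ("include", "exclude"):
    options[key] = []  # mypy: TypedDict key must be a string literal
options.setdefault("include", [])  # OK: literal key, statically checkable
options.setdefault("exclude", [])
```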
23 changes: 12 additions & 11 deletions docs/source/features/gen_steps.py
@@ -115,12 +115,10 @@
if dir_ == "all":
continue # this is an alias
dir_module = importlib.import_module(f"mne_bids_pipeline.steps.{dir_}")
+assert dir_module.__doc__ is not None
dir_header = dir_module.__doc__.split("\n")[0].rstrip(".")
-dir_body = dir_module.__doc__.split("\n", maxsplit=1)
-if len(dir_body) > 1:
-    dir_body = dir_body[1].strip()
-else:
-    dir_body = ""
+dir_body_list = dir_module.__doc__.split("\n", maxsplit=1)
+dir_body = dir_body_list[1].strip() if len(dir_body_list) > 1 else ""
icon = icon_map[dir_header]
module_header = f"{di}. {icon} {dir_header}"
lines.append(f"## {module_header}\n")
@@ -132,6 +130,8 @@
dir_name, step_title = dir_, f"Run all {dir_header.lower()} steps."
lines.append(f"`{dir_name}` | {step_title} |")
for module in modules:
+assert module.__file__ is not None
+assert module.__doc__ is not None
step_name = f"{dir_name}/{Path(module.__file__).name}"[:-3]
step_title = module.__doc__.split("\n")[0]
lines.append(f"`{step_name}` | {step_title} |")
@@ -153,6 +153,8 @@
prev_idx = None
title_map = {}
for mi, module in enumerate(modules, 1):
+assert module.__doc__ is not None
+assert module.__name__ is not None
step_title = module.__doc__.split("\n")[0].rstrip(".")
idx = module.__name__.split(".")[-1].split("_")[1] # 01, 05a, etc.
# Need to quote the title to deal with parens, and sanitize quotes
@@ -189,12 +191,11 @@
mapped.add(idx)
a_b[ii] = f'{idx}["{title_map[idx]}"]'
overview_lines.append(f" {chr_pre}{a_b[0]} --> {chr_pre}{a_b[1]}")
-all_steps = set(
-    sum(
-        [a_b for a_b in manual_order[dir_header] if not isinstance(a_b, str)],
-        (),
-    )
-)
+all_steps_list: list[str] = list()
+for a_b in manual_order[dir_header]:
+    if not isinstance(a_b, str):
+        all_steps_list.extend(a_b)
+all_steps = set(all_steps_list)
assert mapped == all_steps, all_steps.symmetric_difference(mapped)
overview_lines.append("```\n\n</details>\n")

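
The loop above replaces `sum(list_of_tuples, ())`, which mypy cannot type precisely. An explicit loop keeps the element type; `itertools.chain` is an equally type-friendly alternative, sketched here:

```python
from itertools import chain

pairs: list[tuple[str, str]] = [("01", "02"), ("02", "03")]
# chain.from_iterable preserves the element type, so mypy infers set[str]
all_steps: set[str] = set(chain.from_iterable(pairs))
```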
1 change: 1 addition & 0 deletions docs/source/v1.9.md.inc
@@ -58,3 +58,4 @@
- Use GitHub's `dependabot` service to automatically keep GitHub Actions up-to-date. (#893 by @hoechenberger)
- Clean up some strings that our autoformatter failed to correctly merge. (#965 by @drammock)
+- Type hints are now checked using `mypy`. (#995 by @larsoner)
16 changes: 8 additions & 8 deletions mne_bids_pipeline/_config.py
@@ -73,7 +73,7 @@
Enabling interactive mode deactivates parallel processing.
"""

-sessions: list | Literal["all"] = "all"
+sessions: list[str] | Literal["all"] = "all"
"""
The sessions to process. If `'all'`, will process all sessions found in the
BIDS dataset.
@@ -95,7 +95,7 @@
Whether the task should be treated as resting-state data.
"""

-runs: Sequence | Literal["all"] = "all"
+runs: Sequence[str] | Literal["all"] = "all"
"""
The runs to process. If `'all'`, will process all runs found in the
BIDS dataset.
@@ -407,7 +407,7 @@
```
"""

-reader_extra_params: dict = {}
+reader_extra_params: dict[str, Any] = {}
"""
Parameters to be passed to `read_raw_bids()` calls when importing raw data.
@@ -891,7 +891,7 @@

# ## Epoching

-rename_events: dict = dict()
+rename_events: dict[str, str] = dict()
"""
A dictionary specifying which events in the BIDS dataset to rename upon
loading, and before processing begins.
@@ -1812,14 +1812,14 @@
```
"""

-time_frequency_crop: dict | None = None
+time_frequency_crop: dict[str, float] | None = None
"""
Period and frequency range to crop the time-frequency analysis to.
If `None`, no cropping.
???+ example "Example"
```python
-    time_frequency_crop = dict(tmin=-0.3, tmax=0.5, fmin=5, fmax=20)
+    time_frequency_crop = dict(tmin=-0.3, tmax=0.5, fmin=5., fmax=20.)
```
"""

@@ -2038,7 +2038,7 @@ def mri_landmarks_kind(bids_path):
version of MNE-BIDS-Pipeline.
"""

-depth: Annotated[float, Interval(ge=0, le=1)] | dict = 0.8
+depth: Annotated[float, Interval(ge=0, le=1)] | dict[str, Any] = 0.8
"""
If a number, it acts as the depth weighting exponent to use
(must be between `0` and `1`), with `0` meaning no depth weighting is performed.
@@ -2216,7 +2216,7 @@ def noise_cov(bids_path):
```
"""

-report_add_epochs_image_kwargs: dict | None = None
+report_add_epochs_image_kwargs: dict[str, Any] | None = None
"""
Specifies the limits for the color scales of the epochs_image in the report.
If `None`, it defaults to the current default in MNE-Python.
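
The recurring pattern in this file is parameterizing bare `dict`, `list`, and `Sequence` hints so mypy can check element types. A small sketch (hypothetical values, not from the pipeline defaults) of what the narrower hints catch:

```python
from typing import Any

rename_events: dict[str, str] = {"auditory/left": "aud_l"}  # OK
bad_events: dict[str, str] = {"stim": 1}  # mypy reports:
# error: Dict entry 0 has incompatible type "str": "int"; expected "str": "str"
reader_extra_params: dict[str, Any] = {"preload": True}  # values stay open
```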
31 changes: 19 additions & 12 deletions mne_bids_pipeline/_config_import.py
@@ -1,12 +1,13 @@
import ast
import copy
import difflib
-import importlib
+import importlib.util
import os
import pathlib
from dataclasses import field
from functools import partial
from types import SimpleNamespace
+from typing import Any

import matplotlib
import mne
@@ -48,7 +49,7 @@ def _import_config(
log=log,
)

-extra_exec_params_keys = ()
+extra_exec_params_keys: tuple[str, ...] = ()
extra_config = os.getenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", "")
if extra_config:
msg = f"With testing config: {extra_config}"
@@ -107,7 +108,7 @@ def _import_config(
return config


-def _get_default_config():
+def _get_default_config() -> SimpleNamespace:
from . import _config

# Don't use _config itself as it's mutable -- make a new object
@@ -134,15 +135,18 @@ def _update_config_from_path(
*,
config: SimpleNamespace,
config_path: PathLike,
-):
+) -> list[str]:
user_names = list()
config_path = pathlib.Path(config_path).expanduser().resolve(strict=True)
# Import configuration from an arbitrary path without having to fiddle
# with `sys.path`.
spec = importlib.util.spec_from_file_location(
name="custom_config", location=config_path
)
+assert spec is not None
+assert spec.loader is not None
custom_cfg = importlib.util.module_from_spec(spec)
+assert custom_cfg is not None
spec.loader.exec_module(custom_cfg)
for key in dir(custom_cfg):
if not key.startswith("__"):
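
The asserts above are needed because `importlib.util.spec_from_file_location` is typed as returning `ModuleSpec | None`, and `spec.loader` is likewise optional in typeshed. A condensed, self-contained sketch of the same loading pattern:

```python
import importlib.util
from pathlib import Path
from types import ModuleType


def load_module_from_path(path: Path) -> ModuleType:
    spec = importlib.util.spec_from_file_location("custom_config", path)
    assert spec is not None  # None when the path cannot be handled
    assert spec.loader is not None  # loader is Optional in typeshed
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # actually runs the config file
    return module
```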
@@ -338,7 +342,7 @@ def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None
)


-def _default_factory(key, val):
+def _default_factory(key: str, val: Any) -> Any:
# convert a default to a default factory if needed, having an explicit
# allowlist of non-empty ones
allowlist = [
@@ -347,6 +351,10 @@ def _default_factory(key, val):
["evoked"], # inverse_targets
[4, 8, 16], # autoreject_n_interpolate
]

+def default_factory() -> Any:
+    return val

for typ in (dict, list):
if isinstance(val, typ):
try:
@@ -356,18 +364,18 @@ def _default_factory(key, val):
default_factory = typ
else:
if typ is dict:
-default_factory = partial(typ, **allowlist[idx])
+default_factory = partial(typ, **allowlist[idx])  # type: ignore
else:
assert typ is list
-default_factory = partial(typ, allowlist[idx])
-return field(default_factory=default_factory)
+default_factory = partial(typ, allowlist[idx])  # type: ignore
+return field(default_factory=default_factory)  # type: ignore
return val


def _pydantic_validate(
config: SimpleNamespace,
config_path: PathLike | None,
-):
+) -> None:
"""Create dataclass from config type hints and validate with pydantic."""
# https://docs.pydantic.dev/latest/usage/dataclasses/
from . import _config as root_config
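
For context on the `default_factory` closure in `_default_factory` above: dataclasses reject mutable defaults outright, so every `dict`/`list` default must be wrapped in a zero-argument callable passed to `field(default_factory=...)`. A minimal sketch of the underlying rule (hypothetical class):

```python
from dataclasses import dataclass, field


@dataclass
class DemoConfig:  # hypothetical
    # Writing `tags: list[str] = []` here would raise at class creation:
    # ValueError: mutable default <class 'list'> for field tags is not allowed
    tags: list[str] = field(default_factory=list)
    limits: dict[str, float] = field(default_factory=lambda: {"depth": 0.8})
```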
@@ -395,12 +403,12 @@ def _pydantic_validate(
# Now use pydantic to automagically validate
user_vals = {key: val for key, val in config.__dict__.items() if key in annotations}
try:
-UserConfig.model_validate(user_vals)
+UserConfig.model_validate(user_vals)  # type: ignore[attr-defined]
except ValidationError as err:
raise ValueError(str(err)) from None


-_REMOVED_NAMES = {
+_REMOVED_NAMES: dict[str, dict[str, str | None]] = {
"debug": dict(
new_name="on_error",
instead='use on_error="debug" instead',
@@ -430,7 +438,6 @@ def _check_misspellings_removals(
) -> None:
# for each name in the user names, check if it's in the valid names but
# the correct one is not defined
-valid_names = set(valid_names)
for user_name in user_names:
if user_name not in valid_names:
# find the closest match
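
Finally, note the style of suppression above: where a `# type: ignore` is unavoidable (as with the pydantic `model_validate` call), scoping it to a specific error code such as `[attr-defined]` silences only that diagnostic and leaves everything else on the line checked. A tiny sketch:

```python
x: int = 0
x = "oops"  # type: ignore[assignment]  # only this error code is silenced;
# a bare "# type: ignore" would also hide any unrelated errors on the line
```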