Skip to content

Commit

Permalink
MAINT: Add and check type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Sep 11, 2024
1 parent 65c5b8f commit 7b2039a
Show file tree
Hide file tree
Showing 31 changed files with 365 additions and 239 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ repos:
hooks:
- id: yamllint
args: [--strict]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
hooks:
- id: mypy
21 changes: 12 additions & 9 deletions docs/source/examples/gen_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import mne_bids_pipeline
import mne_bids_pipeline.tests.datasets
from mne_bids_pipeline._config_import import _import_config
from mne_bids_pipeline.tests.datasets import DATASET_OPTIONS
from mne_bids_pipeline.tests.datasets import DATASET_OPTIONS, DATASET_OPTIONS_T
from mne_bids_pipeline.tests.test_run import TEST_SUITE

this_dir = Path(__file__).parent
Expand Down Expand Up @@ -160,7 +160,9 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
continue

assert dataset_options_key in DATASET_OPTIONS, dataset_options_key
options = DATASET_OPTIONS[dataset_options_key].copy() # we modify locally
options: DATASET_OPTIONS_T = DATASET_OPTIONS[
dataset_options_key
].copy() # we modify locally

report_str = "\n## Generated output\n\n"
example_target_dir = this_dir / dataset_name
Expand Down Expand Up @@ -228,8 +230,8 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
source_str = f"## Dataset source\n\nThis dataset was acquired from [{url}]({url})\n"

if "openneuro" in options:
for key in ("include", "exclude"):
options[key] = options.get(key, [])
options["include"] = options.get("include", [])
options["exclude"] = options.get("exclude", [])
download_str = (
f'\n??? example "How to download this dataset"\n'
f" Run in your terminal:\n"
Expand Down Expand Up @@ -295,6 +297,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
f.write(download_str)
f.write(config_str)
f.write(report_str)
del dataset_name, funcs

# Finally, write our examples.html file with a table of examples

Expand All @@ -315,13 +318,13 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
with out_path.open("w", encoding="utf-8") as f:
f.write(_example_header)
header_written = False
for dataset_name, funcs in all_demonstrated.items():
for this_dataset_name, these_funcs in all_demonstrated.items():
if not header_written:
f.write("Dataset | " + " | ".join(funcs.keys()) + "\n")
f.write("--------|" + "|".join([":---:"] * len(funcs)) + "\n")
f.write("Dataset | " + " | ".join(these_funcs.keys()) + "\n")
f.write("--------|" + "|".join([":---:"] * len(these_funcs)) + "\n")
header_written = True
f.write(
f"[{dataset_name}]({dataset_name}.md) | "
+ " | ".join(_bool_to_icon(v) for v in funcs.values())
f"[{this_dataset_name}]({this_dataset_name}.md) | "
+ " | ".join(_bool_to_icon(v) for v in these_funcs.values())
+ "\n"
)
24 changes: 13 additions & 11 deletions docs/source/features/gen_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,10 @@
if dir_ == "all":
continue # this is an alias
dir_module = importlib.import_module(f"mne_bids_pipeline.steps.{dir_}")
assert dir_module.__doc__ is not None
dir_header = dir_module.__doc__.split("\n")[0].rstrip(".")
dir_body = dir_module.__doc__.split("\n", maxsplit=1)
if len(dir_body) > 1:
dir_body = dir_body[1].strip()
else:
dir_body = ""
dir_body_list = dir_module.__doc__.split("\n", maxsplit=1)
dir_body = dir_body_list[1].strip() if len(dir_body_list) > 1 else ""
icon = icon_map[dir_header]
module_header = f"{di}. {icon} {dir_header}"
lines.append(f"## {module_header}\n")
Expand All @@ -132,6 +130,8 @@
dir_name, step_title = dir_, f"Run all {dir_header.lower()} steps."
lines.append(f"`{dir_name}` | {step_title} |")
for module in modules:
assert module.__file__ is not None
assert module.__doc__ is not None
step_name = f"{dir_name}/{Path(module.__file__).name}"[:-3]
step_title = module.__doc__.split("\n")[0]
lines.append(f"`{step_name}` | {step_title} |")
Expand All @@ -153,6 +153,8 @@
prev_idx = None
title_map = {}
for mi, module in enumerate(modules, 1):
assert module.__doc__ is not None
assert module.__name__ is not None
step_title = module.__doc__.split("\n")[0].rstrip(".")
idx = module.__name__.split(".")[-1].split("_")[1] # 01, 05a, etc.
# Need to quote the title to deal with parens, and sanitize quotes
Expand Down Expand Up @@ -189,12 +191,12 @@
mapped.add(idx)
a_b[ii] = f'{idx}["{title_map[idx]}"]'
overview_lines.append(f" {chr_pre}{a_b[0]} --> {chr_pre}{a_b[1]}")
all_steps = set(
sum(
[a_b for a_b in manual_order[dir_header] if not isinstance(a_b, str)],
(),
)
)
all_steps_list: list[str] = list()
for a_b in manual_order[dir_header]:
if not isinstance(a_b, str):
all_steps_list.extend(a_b)
all_steps = set(all_steps_list)
assert len(all_steps) == len(all_steps_list)
assert mapped == all_steps, all_steps.symmetric_difference(mapped)
overview_lines.append("```\n\n</details>\n")

Expand Down
12 changes: 7 additions & 5 deletions mne_bids_pipeline/_config_import.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
import copy
import difflib
import importlib
import importlib.util
import os
import pathlib
from dataclasses import field
Expand Down Expand Up @@ -48,7 +48,7 @@ def _import_config(
log=log,
)

extra_exec_params_keys = ()
extra_exec_params_keys: tuple[str, ...] = ()
extra_config = os.getenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", "")
if extra_config:
msg = f"With testing config: {extra_config}"
Expand Down Expand Up @@ -142,7 +142,10 @@ def _update_config_from_path(
spec = importlib.util.spec_from_file_location(
name="custom_config", location=config_path
)
assert spec is not None
assert spec.loader is not None
custom_cfg = importlib.util.module_from_spec(spec)
assert custom_cfg is not None
spec.loader.exec_module(custom_cfg)
for key in dir(custom_cfg):
if not key.startswith("__"):
Expand Down Expand Up @@ -395,12 +398,12 @@ def _pydantic_validate(
# Now use pydantic to automagically validate
user_vals = {key: val for key, val in config.__dict__.items() if key in annotations}
try:
UserConfig.model_validate(user_vals)
UserConfig.model_validate(user_vals) # type: ignore[attr-defined]
except ValidationError as err:
raise ValueError(str(err)) from None


_REMOVED_NAMES = {
_REMOVED_NAMES: dict[str, dict[str, str | None]] = {
"debug": dict(
new_name="on_error",
instead='use on_error="debug" instead',
Expand Down Expand Up @@ -430,7 +433,6 @@ def _check_misspellings_removals(
) -> None:
# for each name in the user names, check if it's in the valid names but
# the correct one is not defined
valid_names = set(valid_names)
for user_name in user_names:
if user_name not in valid_names:
# find the closest match
Expand Down
83 changes: 46 additions & 37 deletions mne_bids_pipeline/_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _get_datatypes_cached(root):
return mne_bids.get_datatypes(root=root)


def _get_ignore_datatypes(config: SimpleNamespace) -> tuple[str]:
def _get_ignore_datatypes(config: SimpleNamespace) -> tuple[str, ...]:
_all_datatypes: list[str] = _get_datatypes_cached(root=config.bids_root)
_ignore_datatypes = set(_all_datatypes) - set([get_datatype(config)])
return tuple(sorted(_ignore_datatypes))
Expand Down Expand Up @@ -163,7 +163,7 @@ def _get_runs_all_subjects_cached(
config = SimpleNamespace(**config_dict)
# Sometimes we check list equivalence for ch_types, so convert it back
config.ch_types = list(config.ch_types)
subj_runs = dict()
subj_runs: dict[str, list[None] | list[str]] = dict()
for subject in get_subjects(config):
# Only traverse through the current subject's directory
valid_runs_subj = _get_entity_vals_cached(
Expand All @@ -174,27 +174,27 @@ def _get_runs_all_subjects_cached(

# If we don't have any `run` entities, just set it to None, as we
# commonly do when creating a BIDSPath.
if not valid_runs_subj:
valid_runs_subj = [None]

if subject in (config.exclude_runs or {}):
valid_runs_subj = [
r for r in valid_runs_subj if r not in config.exclude_runs[subject]
]
subj_runs[subject] = valid_runs_subj
if valid_runs_subj:
if subject in (config.exclude_runs or {}):
valid_runs_subj = [
r for r in valid_runs_subj if r not in config.exclude_runs[subject]
]
subj_runs[subject] = valid_runs_subj
else:
subj_runs[subject] = [None]

return subj_runs


def get_intersect_run(config: SimpleNamespace) -> list[str]:
def get_intersect_run(config: SimpleNamespace) -> list[str | None]:
"""Return the intersection of all the runs of all subjects."""
subj_runs = get_runs_all_subjects(config)
# Do not use something like:
# list(set.intersection(*map(set, subj_runs.values())))
# as it will not preserve order. Instead just be explicit and preserve order.
# We could use "sorted", but it's probably better to use the order provided by
# the user (if they want to put `runs=["02", "01"]` etc. it's better to use "02")
all_runs = list()
all_runs: list[str | None] = list()
for runs in subj_runs.values():
for run in runs:
if run not in all_runs:
Expand Down Expand Up @@ -264,38 +264,47 @@ def get_runs_tasks(
config: SimpleNamespace,
subject: str,
session: str | None,
which: tuple[str] = ("runs", "noise", "rest"),
) -> list[tuple[str]]:
which: tuple[str, ...] = ("runs", "noise", "rest"),
) -> tuple[tuple[str | None, str | None], ...]:
"""Get (run, task) tuples for all runs plus (maybe) rest."""
from ._import_data import _get_noise_path, _get_rest_path

assert isinstance(which, tuple)
assert all(isinstance(inc, str) for inc in which)
assert all(inc in ("runs", "noise", "rest") for inc in which)
runs = list()
tasks = list()
runs: list[str | None] = list()
tasks: list[str | None] = list()
if "runs" in which:
runs.extend(get_runs(config=config, subject=subject))
tasks.extend([get_task(config=config)] * len(runs))
kwargs = dict(
cfg=config,
subject=subject,
session=session,
kind="orig",
add_bads=False,
)
if "rest" in which and _get_rest_path(**kwargs):
runs.append(None)
tasks.append("rest")
if "rest" in which:
rest_path = _get_rest_path(
cfg=config,
subject=subject,
session=session,
kind="orig",
add_bads=False,
)
if rest_path:
runs.append(None)
tasks.append("rest")
if "noise" in which:
mf_reference_run = get_mf_reference_run(config=config)
if _get_noise_path(mf_reference_run=mf_reference_run, **kwargs):
noise_path = _get_noise_path(
mf_reference_run=mf_reference_run,
cfg=config,
subject=subject,
session=session,
kind="orig",
add_bads=False,
)
if noise_path:
runs.append(None)
tasks.append("noise")
return tuple(zip(runs, tasks))


def get_mf_reference_run(config: SimpleNamespace) -> str:
def get_mf_reference_run(config: SimpleNamespace) -> str | None:
# Retrieve to run identifier (number, name) of the reference run
if config.mf_reference_run is not None:
return config.mf_reference_run
Expand All @@ -310,14 +319,13 @@ def get_mf_reference_run(config: SimpleNamespace) -> str:
f"dataset only contains the following runs: {inter_runs}"
)
raise ValueError(msg)
if inter_runs:
return inter_runs[0]
else:
if not inter_runs:
raise ValueError(
f"The intersection of runs by subjects is empty. "
f"Check the list of runs: "
f"{get_runs_all_subjects(config)}"
)
return inter_runs[0]


def get_task(config: SimpleNamespace) -> str | None:
Expand Down Expand Up @@ -374,7 +382,7 @@ def sanitize_cond_name(cond: str) -> str:


def get_mf_cal_fname(
*, config: SimpleNamespace, subject: str, session: str
*, config: SimpleNamespace, subject: str, session: str | None
) -> pathlib.Path:
if config.mf_cal_fname is None:
bids_path = BIDSPath(
Expand Down Expand Up @@ -402,7 +410,7 @@ def get_mf_cal_fname(


def get_mf_ctc_fname(
*, config: SimpleNamespace, subject: str, session: str
*, config: SimpleNamespace, subject: str, session: str | None
) -> pathlib.Path:
if config.mf_ctc_fname is None:
mf_ctc_fpath = BIDSPath(
Expand Down Expand Up @@ -451,9 +459,10 @@ def _restrict_analyze_channels(
return inst


def _get_bem_conductivity(cfg: SimpleNamespace) -> tuple[tuple[float], str]:
def _get_bem_conductivity(cfg: SimpleNamespace) -> tuple[tuple[float, ...] | None, str]:
conductivity: tuple[float, ...] | None = None # should never be used
if cfg.fs_subject in ("fsaverage", cfg.use_template_mri):
conductivity = None # should never be used
pass
tag = "5120-5120-5120"
elif "eeg" in cfg.ch_types:
conductivity = (0.3, 0.006, 0.3)
Expand Down Expand Up @@ -578,7 +587,7 @@ def get_eeg_reference(
return config.eeg_reference


def _validate_contrasts(contrasts: SimpleNamespace) -> None:
def _validate_contrasts(contrasts: list[tuple | dict]) -> None:
for contrast in contrasts:
if isinstance(contrast, tuple):
if len(contrast) != 2:
Expand All @@ -595,7 +604,7 @@ def _validate_contrasts(contrasts: SimpleNamespace) -> None:
raise ValueError("Contrasts must be tuples or well-formed dicts")


def _get_step_modules() -> dict[str, tuple[ModuleType]]:
def _get_step_modules() -> dict[str, tuple[ModuleType, ...]]:
from .steps import freesurfer, init, preprocessing, sensor, source

INIT_STEPS = init._STEPS
Expand Down
Loading

0 comments on commit 7b2039a

Please sign in to comment.