Rework some type annotations to help the type checker #873

Draft · wants to merge 6 commits into base: main
10 changes: 8 additions & 2 deletions mne_bids_pipeline/_config_import.py
@@ -1,9 +1,10 @@
import ast
import copy
import difflib
import importlib
import importlib.util
import os
import pathlib
from collections.abc import Collection
from dataclasses import field
from functools import partial
from types import SimpleNamespace
@@ -142,6 +143,11 @@ def _update_config_from_path(
spec = importlib.util.spec_from_file_location(
name="custom_config", location=config_path
)

# help type checker
assert spec is not None
assert spec.loader is not None

custom_cfg = importlib.util.module_from_spec(spec)
spec.loader.exec_module(custom_cfg)
for key in dir(custom_cfg):
@@ -421,7 +427,7 @@ def _pydantic_validate(

def _check_misspellings_removals(
*,
valid_names: list[str],
valid_names: Collection[str],
user_names: list[str],
log: bool,
config_validation: str,
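The two asserts added in this file narrow the Optional values that `importlib.util.spec_from_file_location` can return before they are used. A minimal, self-contained sketch of that narrowing pattern (the `load_config_module` helper below is illustrative, not from the PR):

```python
import importlib.util
from types import ModuleType


def load_config_module(path: str) -> ModuleType:
    """Import a configuration file from an arbitrary filesystem path."""
    spec = importlib.util.spec_from_file_location(name="custom_config", location=path)
    # spec_from_file_location is annotated as returning ModuleSpec | None, and
    # spec.loader is itself Optional, so narrow both before use; without these
    # asserts a strict checker flags module_from_spec() and exec_module().
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
```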
31 changes: 21 additions & 10 deletions mne_bids_pipeline/_logging.py
@@ -4,13 +4,21 @@
import inspect
import logging
import os
from typing import Any, TypedDict

import rich.console
import rich.theme

from .typing import LogKwargsT


class ConsoleKwargs(TypedDict):
soft_wrap: bool
force_terminal: bool | None
legacy_windows: bool | None
theme: rich.theme.Theme


class _MBPLogger:
def __init__(self):
self._level = logging.INFO
@@ -30,12 +38,8 @@ def _console(self):
legacy_windows = os.getenv("MNE_BIDS_PIPELINE_LEGACY_WINDOWS", None)
if legacy_windows is not None:
legacy_windows = legacy_windows.lower() in ("true", "1")
kwargs = dict(
soft_wrap=True,
force_terminal=force_terminal,
legacy_windows=legacy_windows,
)
kwargs["theme"] = rich.theme.Theme(

theme = rich.theme.Theme(
dict(
default="white",
# Rule
@@ -50,6 +54,13 @@
error="red",
)
)

kwargs: ConsoleKwargs = {
"soft_wrap": True,
"force_terminal": force_terminal,
"legacy_windows": legacy_windows,
"theme": theme,
}
self.__console = rich.console.Console(**kwargs)
return self.__console

@@ -70,16 +81,16 @@ def level(self, level):
level = int(level)
self._level = level

def debug(self, msg: str, *, extra: LogKwargsT | None = None) -> None:
def debug(self, msg: str, *, extra: dict[Any, Any] | None = None) -> None:
self._log_message(kind="debug", msg=msg, **(extra or {}))

def info(self, msg: str, *, extra: LogKwargsT | None = None) -> None:
def info(self, msg: str, *, extra: dict[Any, Any] | None = None) -> None:
self._log_message(kind="info", msg=msg, **(extra or {}))

def warning(self, msg: str, *, extra: LogKwargsT | None = None) -> None:
def warning(self, msg: str, *, extra: dict[Any, Any] | None = None) -> None:
self._log_message(kind="warning", msg=msg, **(extra or {}))

def error(self, msg: str, *, extra: LogKwargsT | None = None) -> None:
def error(self, msg: str, *, extra: dict[Any, Any] | None = None) -> None:
self._log_message(kind="error", msg=msg, **(extra or {}))

def _log_message(
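The `ConsoleKwargs` TypedDict replaces a plain `dict(...)` whose inferred value type could not accommodate the `Theme` assigned into it afterwards. A small stand-alone sketch of the same pattern (the `ConnectKwargs`/`connect` names are invented for illustration; the PR applies it to `rich.console.Console`):

```python
from typing import TypedDict


class ConnectKwargs(TypedDict):
    host: str
    port: int
    timeout: float | None


def connect(*, host: str, port: int, timeout: float | None = None) -> str:
    return f"connected to {host}:{port} (timeout={timeout})"


# A plain dict literal with mixed value types is inferred as dict[str, object],
# so unpacking it into connect() fails strict checking; the TypedDict keeps the
# per-key types visible to the checker.
kwargs: ConnectKwargs = {"host": "localhost", "port": 8080, "timeout": 5.0}
print(connect(**kwargs))
```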
27 changes: 15 additions & 12 deletions mne_bids_pipeline/_report.py
@@ -32,11 +32,11 @@ def _open_report(
session: str | None,
run: str | None = None,
task: str | None = None,
fname_report: BIDSPath | None = None,
bp_report: BIDSPath | None = None,
name: str = "report",
):
if fname_report is None:
fname_report = BIDSPath(
if bp_report is None:
bp_report = BIDSPath(
subject=subject,
session=session,
# Report is across all runs, but for logging purposes it's helpful
@@ -52,24 +52,27 @@
suffix="report",
check=False,
)
fname_report = fname_report.fpath
assert fname_report.suffix == ".h5", fname_report.suffix

report_path = bp_report.fpath
del bp_report
assert report_path.suffix == ".h5", report_path.suffix

# prevent parallel file access
with FileLock(f"{fname_report}.lock"), _agg_backend():
if not fname_report.is_file():
with FileLock(f"{report_path}.lock"), _agg_backend():
if not report_path.is_file():
msg = f"Initializing {name} HDF5 file"
logger.info(**gen_log_kwargs(message=msg))
report = _gen_empty_report(
cfg=cfg,
subject=subject,
session=session,
)
report.save(fname_report)
report.save(report_path)
try:
report = mne.open_report(fname_report)
report = mne.open_report(report_path)
except Exception as exc:
raise exc.__class__(
f"Could not open {name} HDF5 file:\n{fname_report}, "
f"Could not open {name} HDF5 file:\n{report_path}, "
"Perhaps you need to delete it? Got error:\n\n"
f'{indent(traceback.format_exc(), " ")}'
) from None
@@ -87,10 +90,10 @@
)
except Exception as exc:
logger.warning(f"Failed: {exc}")
fname_report_html = fname_report.with_suffix(".html")
fname_report_html = report_path.with_suffix(".html")
msg = f"Saving {name}: {_linkfile(fname_report_html)}"
logger.info(**gen_log_kwargs(message=msg))
report.save(fname_report, overwrite=True)
report.save(report_path, overwrite=True)
report.save(fname_report_html, overwrite=True, open_browser=False)


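Renaming the parameter to `bp_report` and resolving `report_path = bp_report.fpath` up front keeps each variable at a single type (`BIDSPath` vs `pathlib.Path`) for the whole function, instead of reassigning `fname_report` from one to the other. A rough sketch of the idea using a stand-in class so it runs without `mne_bids` (names are illustrative):

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass
class FakeBIDSPath:
    """Stand-in for mne_bids.BIDSPath: an object whose .fpath is a pathlib.Path."""

    root: Path
    basename: str

    @property
    def fpath(self) -> Path:
        return self.root / f"{self.basename}.h5"


def open_report(bp_report: FakeBIDSPath) -> Path:
    # Resolve the concrete filesystem path once and drop the BIDSPath-like
    # object, so the rest of the function deals with exactly one type.
    report_path = bp_report.fpath
    del bp_report
    assert report_path.suffix == ".h5", report_path.suffix
    return report_path


print(open_report(FakeBIDSPath(root=Path("/tmp"), basename="sub-01_report")))
```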
2 changes: 2 additions & 0 deletions mne_bids_pipeline/_run.py
@@ -271,6 +271,7 @@ def wrapper(*args, **kwargs):
del out_files

if msg is not None:
assert emoji is not None # help type checker
logger.info(**gen_log_kwargs(message=msg, emoji=emoji))
if short_circuit:
return
@@ -327,6 +328,7 @@ def save_logs(*, config: SimpleNamespace, logs: list[pd.Series]) -> None:
new_val += val.__qualname__
val = "custom callable" if not new_val else new_val
val = json_tricks.dumps(val, indent=4, sort_keys=False)
assert isinstance(val, str) # help type checker
# 32767 char limit per cell (could split over lines but if something is
# this long, you'll probably get the gist from the first 32k chars)
if len(val) > 32767:
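The `isinstance` assert after `json_tricks.dumps` narrows a return value that is, presumably, not annotated as `str` in that library's stubs. A small sketch of the pattern using the standard-library `json` module (the `to_cell` helper is invented for illustration):

```python
import json


def to_cell(value: object) -> str:
    """Serialize a value for a single spreadsheet cell."""
    text = json.dumps(value, indent=4, sort_keys=False)
    # For serializers that are not annotated as returning str, an isinstance
    # assert narrows the value before len() and slicing are type-checked.
    assert isinstance(text, str)
    if len(text) > 32767:  # Excel's per-cell character limit
        text = text[:32767]
    return text


print(to_cell({"subject": "01", "runs": [1, 2, 3]}))
```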
5 changes: 4 additions & 1 deletion mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py
@@ -160,6 +160,7 @@ def run_ica(
del run

# Set an EEG reference
assert isinstance(epochs, mne.Epochs) # help type checker
if "eeg" in cfg.ch_types:
projection = True if cfg.eeg_reference == "average" else False
epochs.set_eeg_reference(cfg.eeg_reference, projection=projection)
@@ -200,8 +201,10 @@
logger.info(**gen_log_kwargs(message=msg))
epochs.drop_bad(reject=ica_reject)
ar = None

msg = "Saving ICA epochs to disk."
logger.info(**gen_log_kwargs(message=msg))
assert isinstance(epochs, mne.Epochs) # help type checker
epochs.save(
out_files["epochs"],
overwrite=True,
@@ -254,7 +257,7 @@ def run_ica(
subject=subject,
session=session,
task=cfg.task,
fname_report=out_files["report"],
bp_report=out_files["report"],
name="ICA.fit report",
) as report:
report.title = f"ICA – {report.title}"
@@ -304,7 +304,7 @@ def find_ica_artifacts(
subject=subject,
session=session,
task=cfg.task,
fname_report=out_files["report"],
bp_report=out_files["report"],
name="ICA report",
) as report:
report.add_ica(
2 changes: 1 addition & 1 deletion mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py
@@ -56,7 +56,7 @@ def _ica_paths(

def _read_ica_and_exclude(
in_files: dict,
) -> None:
) -> mne.preprocessing.ICA:
ica = read_ica(fname=in_files.pop("ica"))
tsv_data = pd.read_csv(in_files.pop("components"), sep="\t")
ica.exclude = tsv_data.loc[tsv_data["status"] == "bad", "component"].to_list()
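The last hunk fixes a return annotation: the new `-> mne.preprocessing.ICA` signature indicates `_read_ica_and_exclude` actually returns the ICA object, whereas the old `-> None` makes every caller that uses the result a type error. A tiny illustration of the effect with an invented function:

```python
def parse_port(raw: str) -> int:
    """The annotation matches what the function actually returns."""
    return int(raw)


# Had parse_port been annotated "-> None" (as _read_ica_and_exclude was), the
# checker would reject both the return statement inside it and this assignment,
# which expects an int from the call.
port: int = parse_port("8080")
print(port)
```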