Skip to content

Commit

Permalink
FIX: Partiall strict
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Sep 12, 2024
1 parent dc037e9 commit 33a18cb
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 147 deletions.
4 changes: 2 additions & 2 deletions mne_bids_pipeline/_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import functools
import pathlib
from collections.abc import Iterable
from collections.abc import Iterable, Sized
from types import ModuleType, SimpleNamespace
from typing import Any, Literal, TypeVar

Expand Down Expand Up @@ -652,7 +652,7 @@ def _do_mf_autobad(*, cfg: SimpleNamespace) -> bool:


# Adapted from MNE-Python
def _pl(x, *, non_pl="", pl="s"):
def _pl(x: int | np.generic | Sized, *, non_pl: str = "", pl: str = "s") -> str:
"""Determine if plural should be used."""
len_x = x if isinstance(x, int | np.generic) else len(x)
return non_pl if len_x == 1 else pl
Expand Down
19 changes: 11 additions & 8 deletions mne_bids_pipeline/_decoding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import mne
import numpy as np
from joblib import parallel_backend
Expand All @@ -18,22 +20,23 @@ def fit(self, *args, **kwargs):


def _handle_csp_args(
decoding_csp_times,
decoding_csp_freqs,
decoding_metric,
decoding_csp_times: list[float] | tuple[float, ...] | np.ndarray | None,
decoding_csp_freqs: dict[str, Any] | None,
decoding_metric: str,
*,
epochs_tmin,
epochs_tmax,
time_frequency_freq_min,
time_frequency_freq_max,
):
epochs_tmin: float,
epochs_tmax: float,
time_frequency_freq_min: float,
time_frequency_freq_max: float,
) -> tuple[dict[str, list[tuple[float, float]]], np.ndarray]:
_validate_type(
decoding_csp_times, (None, list, tuple, np.ndarray), "decoding_csp_times"
)
if decoding_csp_times is None:
decoding_csp_times = np.linspace(max(0, epochs_tmin), epochs_tmax, num=6)
else:
decoding_csp_times = np.array(decoding_csp_times, float)
assert isinstance(decoding_csp_times, np.ndarray)
if decoding_csp_times.ndim != 1 or len(decoding_csp_times) == 1:
raise ValueError(
"decoding_csp_times should be 1 dimensional and contain at least 2 values "
Expand Down
33 changes: 17 additions & 16 deletions mne_bids_pipeline/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,12 @@
class _ParseConfigSteps:
def __init__(self, force_empty: tuple[str, ...] | None = None) -> None:
self._force_empty = _FORCE_EMPTY if force_empty is None else force_empty
self.steps: dict[str, Any] = defaultdict(list)
steps: dict[str, Any] = defaultdict(list)

def _add_step_option(step: str, option: str) -> None:
if step not in steps[option]:
steps[option].append(step)

# Add a few helper functions
for func_extra in (
_config_utils.get_eeg_reference,
Expand Down Expand Up @@ -155,7 +160,7 @@ def __init__(self, force_empty: tuple[str, ...] | None = None) -> None:
continue
if keyword.value.attr in ("exec_params",):
continue
self._add_step_option(step, keyword.value.attr)
_add_step_option(step, keyword.value.attr)
# Also look for root-level conditionals like use_maxwell_filter
# or spatial_filter
for cond in ast.iter_child_nodes(func):
Expand All @@ -172,7 +177,7 @@ def __init__(self, force_empty: tuple[str, ...] | None = None) -> None:
assert isinstance(attr.value, ast.Name)
if attr.value.id != "config":
continue
self._add_step_option(step, attr.attr)
_add_step_option(step, attr.attr)
# Now look at get_config* functions
if not func.name.startswith("get_config"):
continue
Expand All @@ -193,14 +198,14 @@ def __init__(self, force_empty: tuple[str, ...] | None = None) -> None:
key = keyword.value.func.id
if key in _MANUAL_KWS:
for option in _MANUAL_KWS[key]:
self._add_step_option(step, option)
_add_step_option(step, option)
continue
if keyword.value.func.id == "_sanitize_callable":
assert len(keyword.value.args) == 1
assert isinstance(keyword.value.args[0], ast.Attribute)
assert isinstance(keyword.value.args[0].value, ast.Name)
assert keyword.value.args[0].value.id == "config"
self._add_step_option(step, keyword.value.args[0].attr)
_add_step_option(step, keyword.value.args[0].attr)
continue
if key not in (
"_bids_kwargs",
Expand Down Expand Up @@ -230,13 +235,13 @@ def __init__(self, force_empty: tuple[str, ...] | None = None) -> None:
attrs = _CONFIG_RE.findall(source)
assert len(attrs), f"No config.* found in source of {key}"
for attr in attrs:
self._add_step_option(step, attr)
_add_step_option(step, attr)
continue
if isinstance(keyword.value, ast.Name):
key = f"{where}:{keyword.value.id}"
if key in _MANUAL_KWS:
for option in _MANUAL_KWS[f"{where}:{keyword.value.id}"]:
self._add_step_option(step, option)
_add_step_option(step, option)
continue
raise RuntimeError(f"{where} cannot handle Name {key=}")
if isinstance(keyword.value, ast.IfExp): # conditional
Expand All @@ -251,20 +256,16 @@ def __init__(self, force_empty: tuple[str, ...] | None = None) -> None:
continue
assert isinstance(keyword.value.value, ast.Name)
assert keyword.value.value.id == "config", f"{where} {keyword.value.value.id}" # noqa: E501 # fmt: skip
self._add_step_option(step, option)
_add_step_option(step, option)
if step in _NO_CONFIG:
assert not found, f"Found unexpected get_config* in {step}"
else:
assert found, f"Could not find get_config* in {step}"
for key in self._force_empty:
self.steps[key] = list()
for key, val in self.steps.items():
steps[key] = list()
for key, val in steps.items():
assert len(val) == len(set(val)), f"{key} {val}"
self.steps = {k: tuple(v) for k, v in self.steps.items()} # no defaultdict

def _add_step_option(self, step, option):
if step not in self.steps[option]:
self.steps[option].append(step)
self.steps: dict[str, tuple[str, ...]] = {k: tuple(v) for k, v in steps.items()}

def __call__(self, option: str) -> list[str]:
def __call__(self, option: str) -> tuple[str, ...]:
return self.steps[option]
10 changes: 5 additions & 5 deletions mne_bids_pipeline/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
DEFAULT_DATA_DIR = Path("~/mne_data").expanduser()


def _download_via_openneuro(*, ds_name: str, ds_path: Path):
def _download_via_openneuro(*, ds_name: str, ds_path: Path) -> None:
import openneuro

options = DATASET_OPTIONS[ds_name]
Expand All @@ -25,7 +25,7 @@ def _download_via_openneuro(*, ds_name: str, ds_path: Path):
)


def _download_from_web(*, ds_name: str, ds_path: Path):
def _download_from_web(*, ds_name: str, ds_path: Path) -> None:
"""Retrieve Zip archives from a web URL."""
import pooch

Expand Down Expand Up @@ -55,15 +55,15 @@ def _download_from_web(*, ds_name: str, ds_path: Path):
(path / f"{ds_name}.zip").unlink()


def _download_via_mne(*, ds_name: str, ds_path: Path):
def _download_via_mne(*, ds_name: str, ds_path: Path) -> None:
assert ds_path.stem == ds_name, ds_path
getattr(mne.datasets, DATASET_OPTIONS[ds_name]["mne"]).data_path(
ds_path.parent,
verbose=True,
)


def _download(*, ds_name: str, ds_path: Path):
def _download(*, ds_name: str, ds_path: Path) -> None:
options = DATASET_OPTIONS[ds_name]
openneuro_name = options.get("openneuro", "")
web_url = options.get("web", "")
Expand All @@ -81,7 +81,7 @@ def _download(*, ds_name: str, ds_path: Path):
download_func(ds_name=ds_name, ds_path=ds_path)


def main(dataset):
def main(dataset: str | None) -> None:
"""Download the testing data."""
# Save everything 'MNE_DATA' dir ... defaults to ~/mne_data
mne_data_dir = mne.get_config(key="MNE_DATA", default=False)
Expand Down
13 changes: 8 additions & 5 deletions mne_bids_pipeline/_import_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Iterable
from types import SimpleNamespace
from typing import Literal
from typing import Any, Literal

import mne
import numpy as np
Expand Down Expand Up @@ -487,7 +487,7 @@ def import_er_data(

def _find_breaks_func(
*,
cfg,
cfg: SimpleNamespace,
raw: mne.io.BaseRaw,
subject: str,
session: str | None,
Expand Down Expand Up @@ -703,7 +703,7 @@ def _get_mf_reference_run_path(
subject: str,
session: str | None,
add_bads: bool | None = None,
) -> dict:
) -> dict[str, BIDSPath]:
return _get_run_path(
cfg=cfg,
subject=subject,
Expand Down Expand Up @@ -777,10 +777,13 @@ def _read_bads_tsv(
bids_path_bads: BIDSPath,
) -> list[str]:
bads_tsv = pd.read_csv(bids_path_bads.fpath, sep="\t", header=0)
return bads_tsv[bads_tsv.columns[0]].tolist()
out = bads_tsv[bads_tsv.columns[0]].tolist()
assert isinstance(out, list)
assert all(isinstance(o, str) for o in out)
return out


def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict:
def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict[str, Any]:
"""Get config params needed for any raw data loading."""
return dict(
# import_experimental_data / general
Expand Down
6 changes: 4 additions & 2 deletions mne_bids_pipeline/_io.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""I/O helpers."""

from typing import Any

import json_tricks

from .typing import PathLike


def _write_json(fname: PathLike, data: dict) -> None:
def _write_json(fname: PathLike, data: dict[str, Any]) -> None:
with open(fname, "w", encoding="utf-8") as f:
json_tricks.dump(data, fp=f, allow_nan=True, sort_keys=False)


def _read_json(fname: PathLike) -> dict:
def _read_json(fname: PathLike) -> Any:
with open(fname, encoding="utf-8") as f:
return json_tricks.load(f)
6 changes: 3 additions & 3 deletions mne_bids_pipeline/_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def _console(self):
self.__console = rich.console.Console(**kwargs)
return self.__console

def title(self, title):
def title(self, title: str) -> None:
# Align left with ASCTIME offset
title = f"[title]┌────────┬ {title}[/]"
self._console.rule(title=title, characters="─", style="title", align="left")

def end(self, msg=""):
def end(self, msg: str = "") -> None:
self._console.print(f"[title]└────────┴ {msg}[/]")

@property
Expand Down Expand Up @@ -167,7 +167,7 @@ def gen_log_kwargs(
return kwargs


def _linkfile(uri):
def _linkfile(uri: str) -> str:
return f"[link=file://{uri}]{uri}[/link]"


Expand Down
Loading

0 comments on commit 33a18cb

Please sign in to comment.