Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MNE-ICALabel #812

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/settings/preprocessing/ssp_ica.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ tags:
- ssp_ecg_channel
- ica_reject
- ica_algorithm
- ica_use_icalabel
- ica_l_freq
- ica_max_iterations
- ica_n_components
Expand Down
5 changes: 5 additions & 0 deletions docs/source/v1.5.md.inc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
This release contains a number of very important bug fixes that address problems related to decoding, time-frequency analysis, and inverse modeling.
All users are encouraged to update.

We also improved logging during parallel processing, added support for finding and repairing bad epochs via
[`autoreject`](https://autoreject.github.io), and included support for automatic labeling of ICA artifacts
via [MNE-ICALabel][https://mne.tools/mne-icalabel].

### :new: New features & enhancements

- Added `deriv_root` argument to CLI (#773 by @vferat)
Expand All @@ -22,6 +26,7 @@ All users are encouraged to update.
- Added support for "local" [`autoreject`](https://autoreject.github.io) to find (and repair) bad channels on a per-epoch
basis before submitting them to ICA fitting. This can be enabled by setting [`ica_reject`][mne_bids_pipeline._config.ica_reject]
to `"autoreject_local"`. (#810 by @hoechenberger)
- Added support for automated labeling of ICA components via [MNE-ICALabel][https://mne.tools/mne-icalabel] (#812 by @hoechenberger)
- Website documentation tables can now be sorted (e.g., to find examples that use a specific feature) (#808 by @larsoner)

[//]: # (### :warning: Behavior changes)
Expand Down
23 changes: 19 additions & 4 deletions mne_bids_pipeline/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,7 @@
"""
Peak-to-peak amplitude limits to exclude epochs from ICA fitting. This allows you to
remove strong transient artifacts from the epochs used for fitting ICA, which could
negatively affect ICA performance.
negatively affect ICA performance.

The parameter values are the same as for [`reject`][mne_bids_pipeline._config.reject],
but `"autoreject_global"` is not supported.
Expand All @@ -1262,7 +1262,7 @@
to **not** specify rejection thresholds for EOG and ECG channels here –
otherwise, ICA won't be able to "see" these artifacts.

???+ info
???+ info
This setting is applied only to the epochs that are used for **fitting** ICA. The
goal is to make it easier for ICA to produce a good decomposition. After fitting,
ICA is applied to the epochs to be analyzed, usually with one or more components
Expand Down Expand Up @@ -1367,6 +1367,20 @@
false-alarm rate increases dramatically.
"""

ica_use_icalabel: bool = False
"""
Whether to use MNE-ICALabel to automatically label ICA components. Only available for
EEG data.

!!! info
Using MNE-ICALabel mandates that you also set:
```python
eeg_reference = "average"
ica_l_freq = 1
h_freq = 100
```
"""

# Rejection based on peak-to-peak amplitude
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand All @@ -1384,7 +1398,7 @@

If `None` (default), do not apply artifact rejection.

If a dictionary, manually specify rejection thresholds (see examples).
If a dictionary, manually specify rejection thresholds (see examples).
The thresholds provided here must be at least as stringent as those in
[`ica_reject`][mne_bids_pipeline._config.ica_reject] if using ICA. In case of
`'autoreject_global'`, thresholds for any channel that do not meet this
Expand Down Expand Up @@ -1443,7 +1457,8 @@

!!! info
This setting only takes effect if [`reject`][mne_bids_pipeline._config.reject] has
been set to `"autoreject_local"`.
been set to `"autoreject_local"`. It is not applied when using
`"autoreject_global"`.

!!! info
Channels marked as globally bad in the BIDS dataset (in `*_channels.tsv)`) will not
Expand Down
14 changes: 14 additions & 0 deletions mne_bids_pipeline/_config_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,20 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N
f"but got shape {destination.shape}"
)

# MNE-ICALabel
if config.ica_use_icalabel:
if config.ica_l_freq != 1.0 or config.h_freq != 100.0:
raise ValueError(
f"When using MNE-ICALabel, you must set ica_l_freq=1 and h_freq=100, "
f"but got: ica_l_freq={config.ica_l_freq} and h_freq={config.h_freq}"
)

if config.eeg_reference != "average":
raise ValueError(
f'When using MNE-ICALabel, you must set eeg_reference="average", but '
f"got: eeg_reference={config.eeg_reference}"
)


def _default_factory(key, val):
# convert a default to a default factory if needed, having an explicit
Expand Down
106 changes: 85 additions & 21 deletions mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pandas as pd
import numpy as np
import autoreject
from mne_icalabel import label_components

import mne
from mne.report import Report
Expand Down Expand Up @@ -135,7 +136,7 @@ def make_ecg_epochs(
del raw # Free memory

if len(ecg_epochs) == 0:
msg = "No ECG events could be found. Not running ECG artifact " "detection."
msg = "No ECG events could be found. Not running ECG artifact detection."
logger.info(**gen_log_kwargs(message=msg))
ecg_epochs = None
else:
Expand Down Expand Up @@ -173,7 +174,7 @@ def make_eog_epochs(
eog_epochs = create_eog_epochs(raw, ch_name=ch_names, baseline=(None, -0.2))

if len(eog_epochs) == 0:
msg = "No EOG events could be found. Not running EOG artifact " "detection."
msg = "No EOG events could be found. Not running EOG artifact detection."
logger.warning(**gen_log_kwargs(message=msg))
eog_epochs = None
else:
Expand All @@ -184,7 +185,7 @@ def make_eog_epochs(
return eog_epochs


def detect_bad_components(
def detect_bad_components_mne(
*,
cfg,
which: Literal["eog", "ecg"],
Expand All @@ -195,7 +196,7 @@ def detect_bad_components(
session: Optional[str],
) -> Tuple[List[int], np.ndarray]:
artifact = which.upper()
msg = f"Performing automated {artifact} artifact detection …"
msg = f"Performing automated {artifact} artifact detection (MNE) …"
logger.info(**gen_log_kwargs(message=msg))

if which == "eog":
Expand Down Expand Up @@ -224,7 +225,7 @@ def detect_bad_components(
logger.warning(**gen_log_kwargs(message=warn))
else:
msg = (
f"Detected {len(inds)} {artifact}-related ICs in "
f"Detected {len(inds)} {artifact}-related independent component(s) in "
f"{len(epochs)} {artifact} epochs."
)
logger.info(**gen_log_kwargs(message=msg))
Expand Down Expand Up @@ -271,6 +272,14 @@ def run_ica(
in_files: dict,
) -> dict:
"""Run ICA."""
if cfg.ica_use_icalabel:
# The ICALabel network was trained on extended-Infomax ICA decompositions fit
# on data flltered between 1 and 100 Hz.
assert cfg.ica_algorithm in ["picard-extended_infomax", "extended_infomax"]
assert cfg.ica_l_freq == 1.0
assert cfg.h_freq == 100.0
assert cfg.eeg_reference == "average"

raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs]
bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None)
out_files = dict()
Expand Down Expand Up @@ -395,7 +404,18 @@ def run_ica(

# Set an EEG reference
if "eeg" in cfg.ch_types:
projection = True if cfg.eeg_reference == "average" else False
if cfg.ica_use_icalabel:
assert cfg.eeg_reference == "average"
projection = False # Avg. ref. needs to be applied for MNE-ICALabel
elif cfg.eeg_reference == "average":
projection = True
else:
projection = False

if not projection:
msg = "Applying average reference to EEG epochs used for ICA fitting."
logger.info(**gen_log_kwargs(message=msg))

epochs.set_eeg_reference(cfg.eeg_reference, projection=projection)

if cfg.ica_reject == "autoreject_local":
Expand Down Expand Up @@ -446,9 +466,9 @@ def run_ica(
if cfg.task is not None:
title += f", task-{cfg.task}"

# ECG and EOG component detection
# Run MNE's built-in ECG and EOG component detection
if epochs_ecg:
ecg_ics, ecg_scores = detect_bad_components(
ecg_ics, ecg_scores = detect_bad_components_mne(
cfg=cfg,
which="ecg",
epochs=epochs_ecg,
Expand All @@ -461,7 +481,7 @@ def run_ica(
ecg_ics = ecg_scores = []

if epochs_eog:
eog_ics, eog_scores = detect_bad_components(
eog_ics, eog_scores = detect_bad_components_mne(
cfg=cfg,
which="eog",
epochs=epochs_eog,
Expand All @@ -473,11 +493,34 @@ def run_ica(
else:
eog_ics = eog_scores = []

# Run MNE-ICALabel if requested.
if cfg.ica_use_icalabel:
icalabel_ics = []
icalabel_labels = []

msg = "Performing automated artifact detection (MNE-ICALabel) …"
logger.info(**gen_log_kwargs(message=msg))

label_results = label_components(inst=epochs, ica=ica, method="iclabel")
for idx, label in enumerate(label_results["labels"]):
if label not in ["brain", "other"]:
icalabel_ics.append(idx)
icalabel_labels.append(label)

msg = (
f"Detected {len(icalabel_ics)} artifact-related independent component(s) "
f"in {len(epochs)} epochs."
)
logger.info(**gen_log_kwargs(message=msg))
else:
icalabel_ics = []

ica.exclude = sorted(set(ecg_ics + eog_ics + icalabel_ics))

# Save ICA to disk.
# We also store the automatically identified ECG- and EOG-related ICs.
msg = "Saving ICA solution and detected artifacts to disk."
logger.info(**gen_log_kwargs(message=msg))
ica.exclude = sorted(set(ecg_ics + eog_ics))
ica.save(out_files["ica"], overwrite=True)
_update_for_splits(out_files, "ica")

Expand All @@ -492,15 +535,28 @@ def run_ica(
)
)

for component in ecg_ics:
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact"

for component in eog_ics:
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact"
if cfg.ica_use_icalabel:
assert len(icalabel_ics) == len(icalabel_labels)
for component, label in zip(icalabel_ics, icalabel_labels):
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[
row_idx, "status_description"
] = f"Auto-detected {label} (MNE-ICALabel)"
else:
for component in ecg_ics:
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[
row_idx, "status_description"
] = "Auto-detected ECG artifact (MNE)"

for component in eog_ics:
row_idx = tsv_data["component"] == component
tsv_data.loc[row_idx, "status"] = "bad"
tsv_data.loc[
row_idx, "status_description"
] = "Auto-detected EOG artifact (MNE)"

tsv_data.to_csv(out_files["components"], sep="\t", index=False)

Expand All @@ -510,10 +566,16 @@ def run_ica(
logger.info(**gen_log_kwargs(message=msg))

report = Report(info_fname=epochs, title=title, verbose=False)

ecg_evoked = None if epochs_ecg is None else epochs_ecg.average()
eog_evoked = None if epochs_eog is None else epochs_eog.average()
ecg_scores = None if len(ecg_scores) == 0 else ecg_scores
eog_scores = None if len(eog_scores) == 0 else eog_scores

if cfg.ica_use_icalabel:
# We didn't run MNE's scoring
ecg_scores = eog_scores = None
else:
ecg_scores = None if len(ecg_scores) == 0 else ecg_scores
eog_scores = None if len(eog_scores) == 0 else eog_scores

with _agg_backend():
if cfg.ica_reject == "autoreject_local":
Expand Down Expand Up @@ -588,10 +650,12 @@ def get_config(
ica_reject=config.ica_reject,
ica_eog_threshold=config.ica_eog_threshold,
ica_ctps_ecg_threshold=config.ica_ctps_ecg_threshold,
ica_use_icalabel=config.ica_use_icalabel,
autoreject_n_interpolate=config.autoreject_n_interpolate,
random_state=config.random_state,
ch_types=config.ch_types,
l_freq=config.l_freq,
h_freq=config.h_freq,
epochs_decim=config.epochs_decim,
raw_resample_sfreq=config.raw_resample_sfreq,
event_repeated=config.event_repeated,
Expand Down
10 changes: 8 additions & 2 deletions mne_bids_pipeline/tests/configs/config_ERP_CORE.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,20 @@
t_break_annot_start_after_previous_event = 3.0
t_break_annot_stop_before_next_event = 1.5

# Settings for autoreject and ICA
if task == "N400": # test autoreject local without ICA
spatial_filter = None
reject = "autoreject_local"
autoreject_n_interpolate = [2, 4]
elif task == "N170": # test autoreject local before ICA
elif task == "N170": # test autoreject local before ICA, and MNE-ICALabel
spatial_filter = "ica"
ica_algorithm = "picard-extended_infomax"
ica_use_icalabel = True
ica_l_freq = 1
h_freq = 100
ica_reject = "autoreject_local"
reject = "autoreject_global"
autoreject_n_interpolate = [2, 4]
autoreject_n_interpolate = [12] # Only for testing!
else:
spatial_filter = "ica"
ica_reject = dict(eeg=350e-6, eog=500e-6)
Expand Down Expand Up @@ -249,6 +254,7 @@
baseline = (None, 0)
conditions = ["stimulus/face/normal", "stimulus/car/normal"]
contrasts = [("stimulus/face/normal", "stimulus/car/normal")]
cluster_forming_t_threshold = 1.25 # Only for testing!
elif task == "P3":
rename_events = {
"response/201": "response/correct",
Expand Down
18 changes: 11 additions & 7 deletions mne_bids_pipeline/tests/test_documented.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@ def test_options_documented():
with open(root_path / "_config.py", "r") as fid:
contents = fid.read()
contents = ast.parse(contents)
in_config = [
unannotated = [
item.targets[0].id for item in contents.body if isinstance(item, ast.Assign)
]
assert unannotated == []
_config_py = [
item.target.id for item in contents.body if isinstance(item, ast.AnnAssign)
]
assert len(set(in_config)) == len(in_config)
in_config = set(in_config)
assert len(set(_config_py)) == len(_config_py)
_config_py = set(_config_py)
# ensure we clean our namespace correctly
config = _get_default_config()
config_names = set(d for d in dir(config) if not d.startswith("_"))
assert in_config == config_names
_get_default_config_names = set(d for d in dir(config) if not d.startswith("_"))
assert _config_py == _get_default_config_names
settings_path = root_path.parent / "docs" / "source" / "settings"
assert settings_path.is_dir()
in_doc = set()
Expand All @@ -51,8 +55,8 @@ def test_options_documented():
assert val not in in_doc, "Duplicate documentation"
in_doc.add(val)
what = "docs/source/settings doc"
assert in_doc.difference(in_config) == set(), f"Extra values in {what}"
assert in_config.difference(in_doc) == set(), f"Values missing from {what}"
assert in_doc.difference(_config_py) == set(), f"Extra values in {what}"
assert _config_py.difference(in_doc) == set(), f"Values missing from {what}"


def test_datasets_in_doc():
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ dependencies = [
"autoreject",
"mne[hdf5] >=1.2",
"mne-bids[full]",
"mne-icalabel",
"onnxruntime", # for mne-icalabel
"filelock",
"setuptools >=65",
]
Expand Down
Loading