Skip to content

Commit

Permalink
BUG: Fix bug with multichannel classification (#853)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored Feb 15, 2024
1 parent d76abaa commit 9f8b170
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 57 deletions.
2 changes: 2 additions & 0 deletions docs/source/v1.6.md.inc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
- Fix bug where `--no-cache` had no effect (#839 by @larsoner)
- Fix bug where the Maxwell filtering step would fail if [`find_noisy_channels_meg = False`][mne_bids_pipeline._config.find_noisy_channels_meg] was used (#847 by @larsoner)
- Fix bug where raw, empty-room, and custom noise covariances were errantly calculated on data without ICA or SSP applied (#840 by @larsoner)
- Fix bug where multiple channel types (e.g., MEG and EEG) were not handled correctly in decoding (#853 by @larsoner)
- Fix bug where the previous default for [`ica_n_components`][mne_bids_pipeline._config.ica_n_components] of `0.8` was too conservative, changed the default to `None` to match MNE-Python (#853 by @larsoner)
### :medical_symbol: Code health
Expand Down
5 changes: 3 additions & 2 deletions mne_bids_pipeline/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,7 +1343,7 @@
limit may be too low to achieve convergence.
"""

ica_n_components: Optional[Union[float, int]] = 0.8
ica_n_components: Optional[Union[float, int]] = None
"""
MNE conducts ICA as a sort of a two-step procedure: First, a PCA is run
on the data (trying to exclude zero-valued components in rank-deficient
Expand All @@ -1362,7 +1362,8 @@
explained variance less than the value specified here will be passed to
ICA.
If `None`, **all** principal components will be used.
If `None` (default), `0.999999` will be used to avoid issues when working with
rank-deficient data.
This setting may drastically alter the time required to compute ICA.
"""
Expand Down
29 changes: 29 additions & 0 deletions mne_bids_pipeline/_decoding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import Optional

import mne
import numpy as np
from joblib import parallel_backend
from mne.utils import _validate_type
from sklearn.base import BaseEstimator
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

from ._logging import gen_log_kwargs, logger


class LogReg(LogisticRegression):
"""Hack to avoid a warning with n_jobs != 1 when using dask."""
Expand Down Expand Up @@ -70,3 +77,25 @@ def _handle_csp_args(
freq_bins = list(zip(edges[:-1], edges[1:]))
freq_name_to_bins_map[freq_range_name] = freq_bins
return freq_name_to_bins_map


def _decoding_preproc_steps(
    subject: str,
    session: Optional[str],
    epochs: mne.Epochs,
    pca: bool = True,
) -> list[BaseEstimator]:
    """Assemble the preprocessing estimators placed in front of a decoder.

    The returned list always starts with an ``mne.decoding.Scaler`` built from
    ``epochs.info``. When ``pca`` is true, a whitening PCA wrapped in
    ``mne.decoding.UnsupervisedSpatialFilter`` is appended; its number of
    components is the sum of the per-channel-type ranks returned by
    ``mne.compute_rank(inst=epochs, rank="info")``.

    Parameters
    ----------
    subject : str
        Not referenced directly in this body.
        NOTE(review): presumably picked up by ``gen_log_kwargs`` for log
        context -- confirm before renaming or removing.
    session : str | None
        Same remark as for ``subject``.
    epochs : mne.Epochs
        Epochs whose ``info`` configures the scaler and whose rank determines
        the PCA dimensionality.
    pca : bool
        Whether to append the PCA step. Defaults to ``True``.

    Returns
    -------
    list of BaseEstimator
        The preprocessing steps, in the order they should be applied.
    """
    steps = [mne.decoding.Scaler(epochs.info)]
    if not pca:
        # Scaling only; the caller opted out of dimensionality reduction.
        return steps

    # One rank per channel type; the combined data rank is their sum.
    ranks = mne.compute_rank(inst=epochs, rank="info")
    rank = sum(ranks.values())
    msg = f"Reducing data dimension via PCA; new rank: {rank} (from {ranks})."
    # Keep this call directly in this function's body (not in a helper) in
    # case gen_log_kwargs relies on the caller's context.
    logger.info(**gen_log_kwargs(message=msg))

    spatial_pca = mne.decoding.UnsupervisedSpatialFilter(
        PCA(rank, whiten=True),
        average=False,
    )
    steps.append(spatial_pca)
    return steps
27 changes: 20 additions & 7 deletions mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import mne
import numpy as np
import pandas as pd
from mne.decoding import Scaler, Vectorizer
from mne.decoding import Vectorizer
from mne_bids import BIDSPath
from scipy.io import loadmat, savemat
from sklearn.model_selection import StratifiedKFold, cross_val_score
Expand All @@ -30,7 +30,7 @@
get_sessions,
get_subjects,
)
from ..._decoding import LogReg
from ..._decoding import LogReg, _decoding_preproc_steps
from ..._logging import gen_log_kwargs, logger
from ..._parallel import get_parallel_backend, parallel_func
from ..._report import (
Expand Down Expand Up @@ -113,16 +113,23 @@ def run_epochs_decoding(
# Crop to the desired analysis interval. Do it only after the concatenation to work
# around https://github.com/mne-tools/mne-python/issues/12153
epochs.crop(cfg.decoding_epochs_tmin, cfg.decoding_epochs_tmax)
# omit bad channels and reference MEG sensors
epochs.pick_types(meg=True, eeg=True, ref_meg=False, exclude="bads")
pre_steps = _decoding_preproc_steps(
subject=subject,
session=session,
epochs=epochs,
)

n_cond1 = len(epochs[epochs_conds[0]])
n_cond2 = len(epochs[epochs_conds[1]])

X = epochs.get_data(picks="data") # omit bad channels
X = epochs.get_data()
y = np.r_[np.ones(n_cond1), np.zeros(n_cond2)]

classification_pipeline = make_pipeline(
Scaler(scalings="mean"),
Vectorizer(), # So we can pass the data to scikit-learn
clf = make_pipeline(
*pre_steps,
Vectorizer(),
LogReg(
solver="liblinear", # much faster than the default
random_state=cfg.random_state,
Expand All @@ -138,7 +145,13 @@ def run_epochs_decoding(
n_splits=cfg.decoding_n_splits,
)
scores = cross_val_score(
estimator=classification_pipeline, X=X, y=y, cv=cv, scoring="roc_auc", n_jobs=1
estimator=clf,
X=X,
y=y,
cv=cv,
scoring="roc_auc",
n_jobs=1,
error_score="raise",
)

# Save the scores
Expand Down
31 changes: 26 additions & 5 deletions mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
import mne
import numpy as np
import pandas as pd
from mne.decoding import GeneralizingEstimator, SlidingEstimator, cross_val_multiscore
from mne.decoding import (
GeneralizingEstimator,
SlidingEstimator,
Vectorizer,
cross_val_multiscore,
)
from mne_bids import BIDSPath
from scipy.io import loadmat, savemat
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from ..._config_utils import (
_bids_kwargs,
Expand All @@ -34,7 +38,7 @@
get_sessions,
get_subjects,
)
from ..._decoding import LogReg
from ..._decoding import LogReg, _decoding_preproc_steps
from ..._logging import gen_log_kwargs, logger
from ..._parallel import get_parallel_backend, get_parallel_backend_name
from ..._report import (
Expand Down Expand Up @@ -122,18 +126,35 @@ def run_time_decoding(
epochs = mne.concatenate_epochs([epochs[epochs_conds[0]], epochs[epochs_conds[1]]])
n_cond1 = len(epochs[epochs_conds[0]])
n_cond2 = len(epochs[epochs_conds[1]])
epochs.pick_types(meg=True, eeg=True, ref_meg=False, exclude="bads")
# We can't use the full rank here because the number of samples can just be the
# number of epochs (which can be fewer than the number of channels)
pre_steps = _decoding_preproc_steps(
subject=subject,
session=session,
epochs=epochs,
pca=False,
)
# At some point we might want to enable this, but it's really slow and arguably
# unnecessary so let's omit it for now:
# pre_steps.append(
# mne.decoding.UnsupervisedSpatialFilter(
# PCA(n_components=0.999, whiten=True),
# )
# )

decim = cfg.decoding_time_generalization_decim
if cfg.decoding_time_generalization and decim > 1:
epochs.decimate(decim, verbose="error")

X = epochs.get_data(picks="data") # omit bad channels
X = epochs.get_data()
y = np.r_[np.ones(n_cond1), np.zeros(n_cond2)]
# ProgressBar does not work on dask, so only enable it if not using dask
verbose = get_parallel_backend_name(exec_params=exec_params) != "dask"
with get_parallel_backend(exec_params):
clf = make_pipeline(
StandardScaler(),
*pre_steps,
Vectorizer(),
LogReg(
solver="liblinear", # much faster than the default
random_state=cfg.random_state,
Expand Down
56 changes: 21 additions & 35 deletions mne_bids_pipeline/steps/sensor/_05_decoding_csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import mne
import numpy as np
import pandas as pd
from mne.decoding import CSP, UnsupervisedSpatialFilter
from mne.decoding import CSP
from mne_bids import BIDSPath
from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline

Expand All @@ -23,7 +22,7 @@
get_sessions,
get_subjects,
)
from ..._decoding import LogReg, _handle_csp_args
from ..._decoding import LogReg, _decoding_preproc_steps, _handle_csp_args
from ..._logging import gen_log_kwargs, logger
from ..._parallel import get_parallel_backend, parallel_func
from ..._report import (
Expand Down Expand Up @@ -159,30 +158,24 @@ def one_subject_decoding(
bids_path = in_files["epochs"].copy().update(processing=None, split=None)
epochs = mne.read_epochs(in_files.pop("epochs"))
_restrict_analyze_channels(epochs, cfg)
epochs.pick_types(meg=True, eeg=True, ref_meg=False, exclude="bads")

if cfg.time_frequency_subtract_evoked:
epochs.subtract_evoked()

# Perform rank reduction via PCA.
#
# Select the channel type with the smallest rank.
# Limit it to a maximum of 100.
ranks = mne.compute_rank(inst=epochs, rank="info")
ch_type_smallest_rank = min(ranks, key=ranks.get)
rank = min(ranks[ch_type_smallest_rank], 100)
del ch_type_smallest_rank, ranks

msg = f"Reducing data dimension via PCA; new rank: {rank}."
logger.info(**gen_log_kwargs(msg))
pca = UnsupervisedSpatialFilter(PCA(rank), average=False)
preproc_steps = _decoding_preproc_steps(
subject=subject,
session=session,
epochs=epochs,
)

# Classifier
csp = CSP(
n_components=4, # XXX revisit
reg=0.1, # XXX revisit
rank="info",
)
clf = make_pipeline(
*preproc_steps,
csp,
LogReg(
solver="liblinear", # much faster than the default
Expand Down Expand Up @@ -254,17 +247,11 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
epochs=epochs, contrast=contrast, fmin=fmin, fmax=fmax, cfg=cfg
)
# Get the data for all time points
X = epochs_filt.get_data(picks="data") # omit bad channels

# We apply PCA before running CSP:
# - much faster CSP processing
# - reduced risk of numerical instabilities.
X_pca = pca.fit_transform(X)
del X
X = epochs_filt.get_data()

cv_scores = cross_val_score(
estimator=clf,
X=X_pca,
X=X,
y=y,
scoring=cfg.decoding_metric,
cv=cv,
Expand Down Expand Up @@ -326,14 +313,11 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
# Crop data to the time window of interest
if tmax is not None: # avoid warnings about outside the interval
tmax = min(tmax, epochs_filt.times[-1])
epochs_filt.crop(tmin, tmax)
X = epochs_filt.get_data(picks="data") # omit bad channels
X_pca = pca.transform(X)
del X

X = epochs_filt.crop(tmin, tmax).get_data()
del epochs_filt
cv_scores = cross_val_score(
estimator=clf,
X=X_pca,
X=X,
y=y,
scoring=cfg.decoding_metric,
cv=cv,
Expand Down Expand Up @@ -454,11 +438,13 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
results = all_csp_tf_results[contrast]
mean_crossval_scores = list()
tmin, tmax, fmin, fmax = list(), list(), list(), list()
mean_crossval_scores.extend(results["mean_crossval_score"].ravel())
tmin.extend(results["t_min"].ravel())
tmax.extend(results["t_max"].ravel())
fmin.extend(results["f_min"].ravel())
fmax.extend(results["f_max"].ravel())
mean_crossval_scores.extend(
results["mean_crossval_score"].to_numpy().ravel()
)
tmin.extend(results["t_min"].to_numpy().ravel())
tmax.extend(results["t_max"].to_numpy().ravel())
fmin.extend(results["f_min"].to_numpy().ravel())
fmax.extend(results["f_max"].to_numpy().ravel())
mean_crossval_scores = np.array(mean_crossval_scores, float)
fig, ax = plt.subplots(constrained_layout=True)
# XXX Add support for more metrics
Expand Down
5 changes: 0 additions & 5 deletions mne_bids_pipeline/tests/configs/config_ERP_CORE.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@
}

eeg_reference = ["P9", "P10"]
ica_n_components = 30 - len(eeg_reference)
epochs_tmin = -0.6
epochs_tmax = 0.4
baseline = (-0.4, -0.2)
Expand Down Expand Up @@ -180,7 +179,6 @@
}

eeg_reference = ["P9", "P10"]
ica_n_components = 30 - len(eeg_reference)
epochs_tmin = -0.8
epochs_tmax = 0.2
baseline = (None, -0.6)
Expand All @@ -193,7 +191,6 @@
}

eeg_reference = ["P9", "P10"]
ica_n_components = 30 - len(eeg_reference)
epochs_tmin = -0.2
epochs_tmax = 0.8
baseline = (None, 0)
Expand All @@ -214,7 +211,6 @@
}

eeg_reference = ["P9", "P10"]
ica_n_components = 30 - len(eeg_reference)
epochs_tmin = -0.2
epochs_tmax = 0.8
baseline = (None, 0)
Expand Down Expand Up @@ -281,7 +277,6 @@
}

eeg_reference = ["P9", "P10"]
ica_n_components = 30 - len(eeg_reference)
epochs_tmin = -0.2
epochs_tmax = 0.8
baseline = (None, 0)
Expand Down
1 change: 0 additions & 1 deletion mne_bids_pipeline/tests/configs/config_ds003392.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
ica_max_iterations = 1000
ica_l_freq = 1.0
ica_n_components = 0.99
ica_reject_components = "auto"

# Epochs
epochs_tmin = -0.2
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ addopts = "-ra -vv --tb=short --cov=mne_bids_pipeline --cov-report= --junit-xml=
testpaths = ["mne_bids_pipeline"]
junit_family = "xunit2"

[tool.ruff]
[tool.ruff.lint]
select = ["A", "B006", "D", "E", "F", "I", "W", "UP"]
exclude = ["**/freesurfer/contrib", "dist/", "build/"]
ignore = [
Expand All @@ -128,5 +128,5 @@ ignore = [
"UP035", # Import Iterable from collections.abc
]

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

0 comments on commit 9f8b170

Please sign in to comment.