Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add interpolate_to method #13044

Merged
merged 47 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
824457d
working implem
antoinecollas Dec 30, 2024
5e5f16b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2024
4236efb
pre-commit
antoinecollas Dec 30, 2024
36981ca
Merge branch 'interpolate_to' of https://github.com/antoinecollas/mne…
antoinecollas Dec 30, 2024
e7cefd8
Merge branch 'main' into interpolate_to
antoinecollas Dec 30, 2024
3a0383a
minor changes
antoinecollas Dec 30, 2024
3ea4cad
fix nested imports
antoinecollas Dec 30, 2024
29f7ce4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2024
8a05713
fix comment
antoinecollas Dec 30, 2024
de0ddc4
add test
antoinecollas Dec 30, 2024
254fdac
fix docstring
antoinecollas Dec 30, 2024
4bf3f1d
fix docstring
antoinecollas Dec 30, 2024
2f27599
rm plt.tight_layout
antoinecollas Dec 30, 2024
872ba1d
taller figure
antoinecollas Dec 30, 2024
045496c
[autofix.ci] apply automated fixes
autofix-ci[bot] Jan 1, 2025
568656d
Merge branch 'main' into interpolate_to
antoinecollas Jan 1, 2025
6fb3ae5
fix figure layout
antoinecollas Jan 3, 2025
dd093a8
improve test
antoinecollas Jan 3, 2025
9551016
simplify getting original data
antoinecollas Jan 3, 2025
fa326ad
simplify setting interpolated data
antoinecollas Jan 3, 2025
1a5ca89
merge two lines
antoinecollas Jan 3, 2025
4333f08
add spline method
antoinecollas Jan 6, 2025
4e25616
add splive vs mne to doc
antoinecollas Jan 6, 2025
a249321
keep all modalities in test
antoinecollas Jan 6, 2025
fb8804f
use self.info instead of old_info
antoinecollas Jan 6, 2025
7ff7f51
fix info when MNE interpolation
antoinecollas Jan 6, 2025
1334fb3
fix origin in spline method
antoinecollas Jan 6, 2025
046f7f4
Merge branch 'main' of https://github.com/mne-tools/mne-python into i…
antoinecollas Feb 5, 2025
45b8e28
keep all original channels (expect eeg)
antoinecollas Feb 6, 2025
5cf7569
doc
antoinecollas Feb 6, 2025
9bc0e59
test epochs and evoked
antoinecollas Feb 6, 2025
2db82cc
fix epochs
antoinecollas Feb 6, 2025
93c19be
Merge branch 'main' of https://github.com/mne-tools/mne-python into i…
antoinecollas Feb 6, 2025
3b52d4d
ylim
antoinecollas Feb 6, 2025
a670566
Update interpolate_to.py
antoinecollas Feb 6, 2025
e711218
review
antoinecollas Feb 9, 2025
056de49
vectorize
antoinecollas Feb 9, 2025
24979fa
comments about loss of infos
antoinecollas Feb 9, 2025
e93fa5e
add test on odering
antoinecollas Feb 9, 2025
67d9e85
Merge branch 'main' into interpolate_to
antoinecollas Feb 9, 2025
b399feb
import
antoinecollas Feb 10, 2025
47ec126
Merge branch 'main' into interpolate_to
antoinecollas Feb 10, 2025
3aaa870
FIX: Changelog
larsoner Feb 10, 2025
68a0592
Apply suggestions from code review
larsoner Feb 10, 2025
16c648f
FIX: Flake
larsoner Feb 10, 2025
ee32071
FIX: Name
larsoner Feb 10, 2025
5460dbb
FIX: Dim
larsoner Feb 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -2514,3 +2514,11 @@ @article{OyamaEtAl2015
year = {2015},
pages = {24--36},
}

@inproceedings{MellotEtAl2024,
title = {Physics-informed and Unsupervised Riemannian Domain Adaptation for Machine Learning on Heterogeneous EEG Datasets},
author = {Mellot, Apolline and Collas, Antoine and Chevallier, Sylvain and Engemann, Denis and Gramfort, Alexandre},
booktitle = {Proceedings of the 32nd European Signal Processing Conference (EUSIPCO)},
year = {2024},
address = {Lyon, France}
}
81 changes: 81 additions & 0 deletions examples/preprocessing/interpolate_to.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
.. _ex-interpolate-to-any-montage:

======================================================
Interpolate EEG data to any montage
======================================================

This example demonstrates how to interpolate EEG channels to match a given montage.
This can be useful for standardizing
EEG channel layouts across different datasets (see :footcite:`MellotEtAl2024`).

- Using the field interpolation for EEG data.
- Using the target montage "biosemi16".

In this example, the data from the original EEG channels will be
interpolated onto the positions defined by the "biosemi16" montage.
"""

# Authors: Antoine Collas <[email protected]>
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import matplotlib.pyplot as plt

import mne
from mne.channels import make_standard_montage
from mne.datasets import sample

print(__doc__)
ylim = (-10, 10)

# %%
# Load EEG data
data_path = sample.data_path()
eeg_file_path = data_path / "MEG" / "sample" / "sample_audvis-ave.fif"
evoked = mne.read_evokeds(eeg_file_path, condition="Left Auditory", baseline=(None, 0))

# Select only EEG channels
evoked.pick("eeg")

# Plot the original EEG layout
evoked.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim))

# %%
# Define the target montage
standard_montage = make_standard_montage("biosemi16")

# %%
# Use interpolate_to to project EEG data to the standard montage
evoked_interpolated_spline = evoked.copy().interpolate_to(
standard_montage, method="spline"
)

# Plot the interpolated EEG layout
evoked_interpolated_spline.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim))

# %%
# Use interpolate_to to project EEG data to the standard montage
evoked_interpolated_mne = evoked.copy().interpolate_to(standard_montage, method="MNE")

# Plot the interpolated EEG layout
evoked_interpolated_mne.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim))

# %%
# Comparing before and after interpolation
fig, axs = plt.subplots(3, 1, figsize=(8, 6), constrained_layout=True)
evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False, ylim=dict(eeg=ylim))
axs[0].set_title("Original EEG Layout")
evoked_interpolated_spline.plot(
exclude=[], picks="eeg", axes=axs[1], show=False, ylim=dict(eeg=ylim)
)
axs[1].set_title("Interpolated to Standard 1020 Montage using spline interpolation")
evoked_interpolated_mne.plot(
exclude=[], picks="eeg", axes=axs[2], show=False, ylim=dict(eeg=ylim)
)
axs[2].set_title("Interpolated to Standard 1020 Montage using MNE interpolation")

# %%
# References
# ----------
# .. footbibliography::
153 changes: 153 additions & 0 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,159 @@ def interpolate_bads(

return self

def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0):
"""Interpolate EEG data onto a new montage.

.. warning::
Be careful, only EEG channels are interpolated. Other channel types are
not interpolated.

Parameters
----------
sensors : DigMontage
The target montage containing channel positions to interpolate onto.
origin : array-like, shape (3,) | str
Origin of the sphere in the head coordinate frame and in meters.
Can be ``'auto'`` (default), which means a head-digitization-based
origin fit.
method : str
Method to use for EEG channels.
Supported methods are 'spline' (default) and 'MNE'.
reg : float
The regularization parameter for the interpolation method
(only used when the method is 'spline').

Returns
-------
inst : instance of Raw, Epochs, or Evoked
The instance with updated channel locations and data.

Notes
-----
This method is useful for standardizing EEG layouts across datasets.

.. versionadded:: 1.10.0
"""
from .._fiff.proj import _has_eeg_average_ref_proj
from ..epochs import BaseEpochs, EpochsArray
from ..evoked import Evoked, EvokedArray
from ..forward._field_interpolation import _map_meg_or_eeg_channels
from ..io import RawArray
from ..io.base import BaseRaw
from .interpolation import _make_interpolation_matrix

# Check that the method option is valid.
_check_option("method", method, ["spline", "MNE"])
if sensors is None:
raise ValueError("A sensors configuration must be provided.")
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved

# Get target positions from the montage
ch_pos = sensors.get_positions().get("ch_pos", {})
target_ch_names = list(ch_pos.keys())
if not target_ch_names:
raise ValueError(
"The provided sensors configuration has no channel positions."
)

# Get original channel order
orig_names = self.info["ch_names"]

# Identify EEG channel
picks_good_eeg = pick_types(self.info, meg=False, eeg=True, exclude="bads")
if len(picks_good_eeg) == 0:
raise ValueError("No good EEG channels available for interpolation.")
# Also get the full list of EEG channel indices (including bad channels)
picks_remove_eeg = pick_types(self.info, meg=False, eeg=True, exclude=[])
eeg_names_orig = [orig_names[i] for i in picks_remove_eeg]

# Identify non-EEG channels in original order
non_eeg_names_ordered = [ch for ch in orig_names if ch not in eeg_names_orig]

# Create destination info for new EEG channels
sfreq = self.info["sfreq"]
info_interp = create_info(
ch_names=target_ch_names,
sfreq=sfreq,
ch_types=["eeg"] * len(target_ch_names),
)
info_interp.set_montage(sensors)
info_interp["bads"] = [ch for ch in self.info["bads"] if ch in target_ch_names]
# Do not assign "projs" directly.

# Compute the interpolation mapping
if method == "spline":
origin_val = _check_origin(origin, self.info)
pos_from = self.info._get_channel_positions(picks_good_eeg) - origin_val
pos_to = np.stack(list(ch_pos.values()), axis=0)

def _check_pos_sphere(pos):
d = np.linalg.norm(pos, axis=-1)
d_norm = np.mean(d / np.mean(d))
if np.abs(1.0 - d_norm) > 0.1:
warn("Your spherical fit is poor; interpolation may be inaccurate.")

_check_pos_sphere(pos_from)
_check_pos_sphere(pos_to)
mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg)

elif method == "MNE":
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved
info_eeg = pick_info(self.info, picks_good_eeg)
# If the original info has an average EEG reference projector but
# the destination info does not,
# update info_interp via a temporary RawArray.
if _has_eeg_average_ref_proj(self.info) and not _has_eeg_average_ref_proj(
info_interp
):
# Create dummy data: shape (n_channels, 1)
temp_data = np.zeros((len(info_interp["ch_names"]), 1))
temp_raw = RawArray(temp_data, info_interp, first_samp=0)
# Using the public API, add an average reference projector.
temp_raw.set_eeg_reference(
ref_channels="average", projection=True, verbose=False
)
# Extract the updated info.
info_interp = temp_raw.info
mapping = _map_meg_or_eeg_channels(
info_eeg, info_interp, mode="accurate", origin=origin
)
else:
raise ValueError("Unsupported interpolation method.")
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved

# Interpolate EEG data
data_good = self.get_data(picks=picks_good_eeg)
data_interp = mapping.dot(data_good)
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved

# Create a new instance for the interpolated EEG channels
if isinstance(self, BaseRaw):
inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp)
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(self, BaseEpochs):
data_interp = np.transpose(data_interp, (1, 0, 2))
inst_interp = EpochsArray(data_interp, info_interp)
elif isinstance(self, Evoked):
inst_interp = EvokedArray(data_interp, info_interp)
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved

# Merge only if non-EEG channels exist
if not non_eeg_names_ordered:
return inst_interp

inst_non_eeg = self.copy().pick(non_eeg_names_ordered).load_data()
inst_out = inst_non_eeg.add_channels([inst_interp], force_update_info=True)

# Reorder channels
# Insert the entire new EEG block at the position of the first EEG channel.
new_order = []
inserted = False
for ch in orig_names:
if ch in eeg_names_orig:
if not inserted:
new_order.extend(info_interp["ch_names"])
inserted = True
# Skip original EEG channels.
else:
new_order.append(ch)
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved
inst_out.reorder_channels(new_order)
return inst_out


@verbose
def rename_channels(info, mapping, allow_duplicates=False, *, verbose=None):
Expand Down
63 changes: 62 additions & 1 deletion mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mne import Epochs, pick_channels, pick_types, read_events
from mne._fiff.constants import FIFF
from mne._fiff.proj import _has_eeg_average_ref_proj
from mne.channels import make_dig_montage
from mne.channels import make_dig_montage, make_standard_montage
from mne.channels.interpolation import _make_interpolation_matrix
from mne.datasets import testing
from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx
Expand Down Expand Up @@ -439,3 +439,64 @@ def test_method_str():
raw.interpolate_bads(method="spline")
raw.pick("eeg", exclude=())
raw.interpolate_bads(method="spline")


@pytest.mark.parametrize("montage_name", ["biosemi16", "standard_1020"])
antoinecollas marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("method", ["spline", "MNE"])
@pytest.mark.parametrize("data_type", ["raw", "epochs", "evoked"])
def test_interpolate_to_eeg(montage_name, method, data_type):
"""Test the interpolate_to method for EEG for raw, epochs, and evoked."""
# Load EEG data
raw, epochs_eeg = _load_data("eeg")
epochs_eeg = epochs_eeg.copy()

# Load data for raw
raw.load_data()

# Create a target montage
montage = make_standard_montage(montage_name)

# Prepare data to interpolate to
if data_type == "raw":
inst = raw.copy()
elif data_type == "epochs":
inst = epochs_eeg.copy()
elif data_type == "evoked":
inst = epochs_eeg.average()
shape = list(inst._data.shape)
orig_total = len(inst.info["ch_names"])
n_eeg_orig = len(pick_types(inst.info, eeg=True))

# Interpolate
inst_interp = inst.copy().interpolate_to(montage, method=method)

# Check that the new channel names include the montage channels.
assert set(montage.ch_names).issubset(set(inst_interp.info["ch_names"]))

# Check that the original container's channel ordering is changed.
assert inst.info["ch_names"] != inst_interp.info["ch_names"]

# Check that the data shape is as expected.
new_nchan_expected = orig_total - n_eeg_orig + len(montage.ch_names)
if len(shape) == 2:
expected_shape = tuple([new_nchan_expected, shape[1]])
elif len(shape) == 3:
expected_shape = tuple([shape[0], new_nchan_expected, shape[2]])
assert inst_interp._data.shape == expected_shape

# Validate that bad channels are carried over.
# Mark the first non eeg channel as bad
bads = None
i = 0
all_ch = inst_interp.info["ch_names"]
eeg_ch = pick_types(inst_interp.info, eeg=True)
eeg_ch = [all_ch[i] for i in eeg_ch]
while (i < len(all_ch)) and (bads is None):
ch = all_ch[i]
if ch not in eeg_ch:
bads = [ch]
i += 1
if bads is not None:
inst.info["bads"] = bads
inst_interp = inst.copy().interpolate_to(montage, method=method)
assert inst_interp.info["bads"] == bads
larsoner marked this conversation as resolved.
Show resolved Hide resolved
Loading