From 824457d30c34c798d4850c090b085ecfee03e635 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 20:17:29 +0100 Subject: [PATCH 01/40] working implem --- examples/preprocessing/interpolate_to.py | 63 +++++++++++++++++ mne/channels/channels.py | 88 ++++++++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 examples/preprocessing/interpolate_to.py diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py new file mode 100644 index 00000000000..ee834380ae4 --- /dev/null +++ b/examples/preprocessing/interpolate_to.py @@ -0,0 +1,63 @@ +""" +.. _ex-interpolate-to-any-montage: + +====================================================== +Interpolate EEG data to a any montage +====================================================== + +This example demonstrates how to interpolate EEG channels to match a given +montage using the :func:`interpolate_to` method. This can be useful for standardizing +EEG channel layouts across different datasets. + +- Using the MNE method for interpolation. +- The target montage will be the standard "standard_1020" montage. + +In this example, the data from the original EEG channels will be interpolated onto the positions defined by the "standard_1020" montage. +""" + +# Authors: Antoine Collas +# License: BSD-3-Clause + +import mne +from mne.datasets import sample +from mne.channels import make_standard_montage +import matplotlib.pyplot as plt + +print(__doc__) + +# %% +# Load EEG data +data_path = sample.data_path() +eeg_file_path = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" +evoked = mne.read_evokeds(eeg_file_path, condition="Left Auditory", baseline=(None, 0)) + +# Select only EEG channels +evoked.pick("eeg") + +# Plot the original EEG layout +evoked.plot(exclude=[], picks="eeg") + +# %% +# Define the target montage +standard_montage = make_standard_montage('standard_1020') + +# %% +# Use interpolate_to to project EEG data to the standard montage +evoked_interpolated = evoked.copy().interpolate_to(standard_montage) + +# Plot the interpolated EEG layout +evoked_interpolated.plot(exclude=[], picks="eeg") + +# %% +# Comparing before and after interpolation +fig, axs = plt.subplots(2, 1, figsize=(8, 6)) +evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False) +axs[0].set_title("Original EEG Layout") +evoked_interpolated.plot(exclude=[], picks="eeg", axes=axs[1], show=False) +axs[1].set_title("Interpolated to Standard 1020 Montage") +plt.tight_layout() + +# %% +# References +# ---------- +# .. footbibliography:: diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 8fbff33c13e..368a4f73044 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -950,6 +950,94 @@ def interpolate_bads( return self + def interpolate_to(self, montage, method='MNE', reg=0.0): + """Interpolate data onto a new montage. + + Parameters + ---------- + montage : DigMontage + The target montage containing channel positions to interpolate onto. + method : str + The interpolation method to use. 'MNE' by default. + reg : float + The regularization parameter for the interpolation method (if applicable). + + Returns + ------- + inst : instance of Raw, Epochs, or Evoked + The instance with updated channel locations and data. + """ + import numpy as np + import mne + from mne import pick_types + from mne.forward._field_interpolation import _map_meg_or_eeg_channels + + # Ensure data is loaded + _check_preload(self, "interpolation") + + # Extract positions and data for EEG channels + picks_from = pick_types(self.info, meg=False, eeg=True, exclude=[]) + if len(picks_from) == 0: + raise ValueError("No EEG channels available for interpolation.") + + if hasattr(self, '_data'): + data_orig = self._data[picks_from] + else: + # If epochs-like data, for simplicity take the mean across epochs + data_orig = self.get_data()[:, picks_from, :].mean(axis=0) + + # Get target positions from the montage + ch_pos = montage.get_positions()['ch_pos'] + target_ch_names = list(ch_pos.keys()) + if len(target_ch_names) == 0: + raise ValueError("The provided montage does not contain any channel positions.") + + # Create a new info structure using MNE public API + sfreq = self.info['sfreq'] + ch_types = ['eeg'] * len(target_ch_names) + new_info = mne.create_info(ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types) + new_info.set_montage(montage) + + # Create a simple old_info + sfreq = self.info['sfreq'] + ch_names = self.info['ch_names'] + ch_types = ['eeg'] * len(ch_names) + old_info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) + old_info.set_montage(self.info.get_montage()) + + # Compute mapping from current montage to target montage + mapping = _map_meg_or_eeg_channels(old_info, new_info, mode='accurate', origin='auto') + + # Apply the interpolation mapping + D_new = mapping.dot(data_orig) + + # Update bad channels + new_bads = [ch for ch in self.info['bads'] if ch in target_ch_names] + new_info['bads'] = new_bads + + # Update the instance's info and data + self.info = new_info + if hasattr(self, '_data'): + if self._data.ndim == 2: + # Raw-like: directly assign the new data + self._data = D_new + else: + # Epochs-like + n_epochs, _, n_times = self._data.shape + new_data = np.zeros((n_epochs, len(target_ch_names), n_times), dtype=self._data.dtype) + for e in range(n_epochs): + epoch_data_orig = self._data[e, picks_from, :] + new_data[e, :, :] = mapping.dot(epoch_data_orig) + self._data = new_data + else: + # Evoked-like data + if hasattr(self, 'data'): + self.data = D_new + else: + raise NotImplementedError("This method requires preloaded data.") + + return self + @verbose def rename_channels(info, mapping, allow_duplicates=False, *, verbose=None): From 5e5f16bb6bad334a47a38ba8cfc6b3028abd342d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 19:28:25 +0000 Subject: [PATCH 02/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/preprocessing/interpolate_to.py | 7 ++-- mne/channels/channels.py | 41 +++++++++++++++--------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index ee834380ae4..1710479e0da 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -18,10 +18,11 @@ # Authors: Antoine Collas # License: BSD-3-Clause +import matplotlib.pyplot as plt + import mne -from mne.datasets import sample from mne.channels import make_standard_montage -import matplotlib.pyplot as plt +from mne.datasets import sample print(__doc__) @@ -39,7 +40,7 @@ # %% # Define the target montage -standard_montage = make_standard_montage('standard_1020') +standard_montage = make_standard_montage("standard_1020") # %% # Use interpolate_to to project EEG data to the standard montage diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 368a4f73044..6d8db1d7bb9 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -950,7 +950,7 @@ def interpolate_bads( return self - def interpolate_to(self, montage, method='MNE', reg=0.0): + def interpolate_to(self, montage, method="MNE", reg=0.0): """Interpolate data onto a new montage. Parameters @@ -968,6 +968,7 @@ def interpolate_to(self, montage, method='MNE', reg=0.0): The instance with updated channel locations and data. """ import numpy as np + import mne from mne import pick_types from mne.forward._field_interpolation import _map_meg_or_eeg_channels @@ -980,58 +981,66 @@ def interpolate_to(self, montage, method='MNE', reg=0.0): if len(picks_from) == 0: raise ValueError("No EEG channels available for interpolation.") - if hasattr(self, '_data'): + if hasattr(self, "_data"): data_orig = self._data[picks_from] else: # If epochs-like data, for simplicity take the mean across epochs data_orig = self.get_data()[:, picks_from, :].mean(axis=0) # Get target positions from the montage - ch_pos = montage.get_positions()['ch_pos'] + ch_pos = montage.get_positions()["ch_pos"] target_ch_names = list(ch_pos.keys()) if len(target_ch_names) == 0: - raise ValueError("The provided montage does not contain any channel positions.") + raise ValueError( + "The provided montage does not contain any channel positions." + ) # Create a new info structure using MNE public API - sfreq = self.info['sfreq'] - ch_types = ['eeg'] * len(target_ch_names) - new_info = mne.create_info(ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types) + sfreq = self.info["sfreq"] + ch_types = ["eeg"] * len(target_ch_names) + new_info = mne.create_info( + ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types + ) new_info.set_montage(montage) # Create a simple old_info - sfreq = self.info['sfreq'] - ch_names = self.info['ch_names'] - ch_types = ['eeg'] * len(ch_names) + sfreq = self.info["sfreq"] + ch_names = self.info["ch_names"] + ch_types = ["eeg"] * len(ch_names) old_info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) old_info.set_montage(self.info.get_montage()) # Compute mapping from current montage to target montage - mapping = _map_meg_or_eeg_channels(old_info, new_info, mode='accurate', origin='auto') + mapping = _map_meg_or_eeg_channels( + old_info, new_info, mode="accurate", origin="auto" + ) # Apply the interpolation mapping D_new = mapping.dot(data_orig) # Update bad channels - new_bads = [ch for ch in self.info['bads'] if ch in target_ch_names] - new_info['bads'] = new_bads + new_bads = [ch for ch in self.info["bads"] if ch in target_ch_names] + new_info["bads"] = new_bads # Update the instance's info and data self.info = new_info - if hasattr(self, '_data'): + if hasattr(self, "_data"): if self._data.ndim == 2: # Raw-like: directly assign the new data self._data = D_new else: # Epochs-like n_epochs, _, n_times = self._data.shape - new_data = np.zeros((n_epochs, len(target_ch_names), n_times), dtype=self._data.dtype) + new_data = np.zeros( + (n_epochs, len(target_ch_names), n_times), dtype=self._data.dtype + ) for e in range(n_epochs): epoch_data_orig = self._data[e, picks_from, :] new_data[e, :, :] = mapping.dot(epoch_data_orig) self._data = new_data else: # Evoked-like data - if hasattr(self, 'data'): + if hasattr(self, "data"): self.data = D_new else: raise NotImplementedError("This method requires preloaded data.") From 4236efbbf8801913d8cc0ca71fb78193a9a03df0 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 20:35:47 +0100 Subject: [PATCH 03/40] pre-commit --- examples/preprocessing/interpolate_to.py | 10 +++--- mne/channels/channels.py | 41 +++++++++++++++--------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index ee834380ae4..23d545e0596 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -12,16 +12,18 @@ - Using the MNE method for interpolation. - The target montage will be the standard "standard_1020" montage. -In this example, the data from the original EEG channels will be interpolated onto the positions defined by the "standard_1020" montage. +In this example, the data from the original EEG channels will be +interpolated onto the positions defined by the "standard_1020" montage. """ # Authors: Antoine Collas # License: BSD-3-Clause +import matplotlib.pyplot as plt + import mne -from mne.datasets import sample from mne.channels import make_standard_montage -import matplotlib.pyplot as plt +from mne.datasets import sample print(__doc__) @@ -39,7 +41,7 @@ # %% # Define the target montage -standard_montage = make_standard_montage('standard_1020') +standard_montage = make_standard_montage("standard_1020") # %% # Use interpolate_to to project EEG data to the standard montage diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 368a4f73044..6d8db1d7bb9 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -950,7 +950,7 @@ def interpolate_bads( return self - def interpolate_to(self, montage, method='MNE', reg=0.0): + def interpolate_to(self, montage, method="MNE", reg=0.0): """Interpolate data onto a new montage. Parameters @@ -968,6 +968,7 @@ def interpolate_to(self, montage, method='MNE', reg=0.0): The instance with updated channel locations and data. """ import numpy as np + import mne from mne import pick_types from mne.forward._field_interpolation import _map_meg_or_eeg_channels @@ -980,58 +981,66 @@ def interpolate_to(self, montage, method='MNE', reg=0.0): if len(picks_from) == 0: raise ValueError("No EEG channels available for interpolation.") - if hasattr(self, '_data'): + if hasattr(self, "_data"): data_orig = self._data[picks_from] else: # If epochs-like data, for simplicity take the mean across epochs data_orig = self.get_data()[:, picks_from, :].mean(axis=0) # Get target positions from the montage - ch_pos = montage.get_positions()['ch_pos'] + ch_pos = montage.get_positions()["ch_pos"] target_ch_names = list(ch_pos.keys()) if len(target_ch_names) == 0: - raise ValueError("The provided montage does not contain any channel positions.") + raise ValueError( + "The provided montage does not contain any channel positions." + ) # Create a new info structure using MNE public API - sfreq = self.info['sfreq'] - ch_types = ['eeg'] * len(target_ch_names) - new_info = mne.create_info(ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types) + sfreq = self.info["sfreq"] + ch_types = ["eeg"] * len(target_ch_names) + new_info = mne.create_info( + ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types + ) new_info.set_montage(montage) # Create a simple old_info - sfreq = self.info['sfreq'] - ch_names = self.info['ch_names'] - ch_types = ['eeg'] * len(ch_names) + sfreq = self.info["sfreq"] + ch_names = self.info["ch_names"] + ch_types = ["eeg"] * len(ch_names) old_info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) old_info.set_montage(self.info.get_montage()) # Compute mapping from current montage to target montage - mapping = _map_meg_or_eeg_channels(old_info, new_info, mode='accurate', origin='auto') + mapping = _map_meg_or_eeg_channels( + old_info, new_info, mode="accurate", origin="auto" + ) # Apply the interpolation mapping D_new = mapping.dot(data_orig) # Update bad channels - new_bads = [ch for ch in self.info['bads'] if ch in target_ch_names] - new_info['bads'] = new_bads + new_bads = [ch for ch in self.info["bads"] if ch in target_ch_names] + new_info["bads"] = new_bads # Update the instance's info and data self.info = new_info - if hasattr(self, '_data'): + if hasattr(self, "_data"): if self._data.ndim == 2: # Raw-like: directly assign the new data self._data = D_new else: # Epochs-like n_epochs, _, n_times = self._data.shape - new_data = np.zeros((n_epochs, len(target_ch_names), n_times), dtype=self._data.dtype) + new_data = np.zeros( + (n_epochs, len(target_ch_names), n_times), dtype=self._data.dtype + ) for e in range(n_epochs): epoch_data_orig = self._data[e, picks_from, :] new_data[e, :, :] = mapping.dot(epoch_data_orig) self._data = new_data else: # Evoked-like data - if hasattr(self, 'data'): + if hasattr(self, "data"): self.data = D_new else: raise NotImplementedError("This method requires preloaded data.") From 3a0383a2eca9cee8a2aa392fcedde788ce1e0bc1 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 21:49:31 +0100 Subject: [PATCH 04/40] minor changes --- doc/references.bib | 8 ++++++++ examples/preprocessing/interpolate_to.py | 12 ++++++------ mne/channels/channels.py | 2 ++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/doc/references.bib b/doc/references.bib index a129d2f46a2..f82a9f6ac68 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -2504,3 +2504,11 @@ @article{OyamaEtAl2015 year = {2015}, pages = {24--36}, } + +@inproceedings{MellotEtAl2024, + title = {Physics-informed and Unsupervised Riemannian Domain Adaptation for Machine Learning on Heterogeneous EEG Datasets}, + author = {Mellot, Apolline and Collas, Antoine and Chevallier, Sylvain and Engemann, Denis and Gramfort, Alexandre}, + booktitle = {Proceedings of the 32nd European Signal Processing Conference (EUSIPCO)}, + year = {2024}, + address = {Lyon, France} +} diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index 23d545e0596..fc1069cdb4e 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -2,18 +2,18 @@ .. _ex-interpolate-to-any-montage: ====================================================== -Interpolate EEG data to a any montage +Interpolate EEG data to any montage ====================================================== This example demonstrates how to interpolate EEG channels to match a given montage using the :func:`interpolate_to` method. This can be useful for standardizing -EEG channel layouts across different datasets. +EEG channel layouts across different datasets (see :footcite:`MellotEtAl2024`). -- Using the MNE method for interpolation. -- The target montage will be the standard "standard_1020" montage. +- Using the field interpolation for EEG data. +- Using the target montage "biosemi16". In this example, the data from the original EEG channels will be -interpolated onto the positions defined by the "standard_1020" montage. +interpolated onto the positions defined by the "biosemi16" montage. """ # Authors: Antoine Collas @@ -41,7 +41,7 @@ # %% # Define the target montage -standard_montage = make_standard_montage("standard_1020") +standard_montage = make_standard_montage("biosemi16") # %% # Use interpolate_to to project EEG data to the standard montage diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 6d8db1d7bb9..266129d24a2 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -966,6 +966,8 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): ------- inst : instance of Raw, Epochs, or Evoked The instance with updated channel locations and data. + + .. versionadded:: 1.10.0 """ import numpy as np From 3ea4cad114d6143ecb584e4a234766102197056b Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 22:05:06 +0100 Subject: [PATCH 05/40] fix nested imports --- mne/channels/channels.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 266129d24a2..33d640fcfd4 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -969,11 +969,7 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): .. versionadded:: 1.10.0 """ - import numpy as np - - import mne - from mne import pick_types - from mne.forward._field_interpolation import _map_meg_or_eeg_channels + from ..forward._field_interpolation import _map_meg_or_eeg_channels # Ensure data is loaded _check_preload(self, "interpolation") @@ -1000,7 +996,7 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): # Create a new info structure using MNE public API sfreq = self.info["sfreq"] ch_types = ["eeg"] * len(target_ch_names) - new_info = mne.create_info( + new_info = create_info( ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types ) new_info.set_montage(montage) @@ -1009,7 +1005,7 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): sfreq = self.info["sfreq"] ch_names = self.info["ch_names"] ch_types = ["eeg"] * len(ch_names) - old_info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) + old_info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) old_info.set_montage(self.info.get_montage()) # Compute mapping from current montage to target montage From 29f7ce4d1a383d6b55bb7575a35b4b3b3e187bc3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 21:06:08 +0000 Subject: [PATCH 06/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/channels/channels.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 33d640fcfd4..f64eef22cd1 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -996,9 +996,7 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): # Create a new info structure using MNE public API sfreq = self.info["sfreq"] ch_types = ["eeg"] * len(target_ch_names) - new_info = create_info( - ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types - ) + new_info = create_info(ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types) new_info.set_montage(montage) # Create a simple old_info From 8a05713955a0525147ef344305317dfaa1184f5a Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 22:10:06 +0100 Subject: [PATCH 07/40] fix comment --- mne/channels/channels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index f64eef22cd1..df9ae8f0db4 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -993,7 +993,7 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): "The provided montage does not contain any channel positions." ) - # Create a new info structure using MNE public API + # Create a new info structure sfreq = self.info["sfreq"] ch_types = ["eeg"] * len(target_ch_names) new_info = create_info(ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types) From de0ddc490728ebe4a985bc884f5947c7f38ff9e9 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 22:29:47 +0100 Subject: [PATCH 08/40] add test --- mne/channels/tests/test_interpolation.py | 37 +++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index a881a41edcc..43fee4266d5 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -12,7 +12,7 @@ from mne import Epochs, pick_channels, pick_types, read_events from mne._fiff.constants import FIFF from mne._fiff.proj import _has_eeg_average_ref_proj -from mne.channels import make_dig_montage +from mne.channels import make_dig_montage, make_standard_montage from mne.channels.interpolation import _make_interpolation_matrix from mne.datasets import testing from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx @@ -439,3 +439,38 @@ def test_method_str(): raw.interpolate_bads(method="spline") raw.pick("eeg", exclude=()) raw.interpolate_bads(method="spline") + + +@pytest.mark.parametrize("montage_name", ["biosemi16", "standard_1020"]) +def test_interpolate_to_eeg(montage_name): + """Test the interpolate_to method for EEG.""" + # Load EEG data + raw, _ = _load_data("eeg") + + # Select only EEG channels + raw.pick("eeg") + + # Load data + raw.load_data() + + # Create a target montage + montage = make_standard_montage(montage_name) + + # Copy the raw object and apply interpolation + raw_interpolated = raw.copy().interpolate_to(montage) + + # Check if channel names match the target montage + assert set(raw_interpolated.info["ch_names"]) == set(montage.ch_names) + + # Check if the data was interpolated correctly + assert raw_interpolated.get_data().shape[0] == len(montage.ch_names) + + # Ensure original data is not altered + assert raw.info["ch_names"] != raw_interpolated.info["ch_names"] + + # Validate that bad channels are carried over + raw.info["bads"] = [raw.info["ch_names"][0]] + raw_interpolated = raw.copy().interpolate_to(montage) + assert raw_interpolated.info["bads"] == [ + ch for ch in raw.info["bads"] if ch in montage.ch_names + ] From 254fdac48fb7566c89e0b7dc1abbe0aa57526138 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 22:53:04 +0100 Subject: [PATCH 09/40] fix docstring --- examples/preprocessing/interpolate_to.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index fc1069cdb4e..4d86310e1db 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -5,8 +5,8 @@ Interpolate EEG data to any montage ====================================================== -This example demonstrates how to interpolate EEG channels to match a given -montage using the :func:`interpolate_to` method. This can be useful for standardizing +This example demonstrates how to interpolate EEG channels to match a given montage. +This can be useful for standardizing EEG channel layouts across different datasets (see :footcite:`MellotEtAl2024`). - Using the field interpolation for EEG data. From 4bf3f1d18148cd8c54a74f510bb9f72c8ccdfa5e Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 23:09:10 +0100 Subject: [PATCH 10/40] fix docstring --- mne/channels/channels.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index df9ae8f0db4..ce290a33363 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -967,6 +967,10 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): inst : instance of Raw, Epochs, or Evoked The instance with updated channel locations and data. + Notes + ----- + This method is useful for standardizing EEG layouts across datasets. + .. versionadded:: 1.10.0 """ from ..forward._field_interpolation import _map_meg_or_eeg_channels From 2f27599f313c355e8f414105fb8c6f43c50a8971 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 23:09:52 +0100 Subject: [PATCH 11/40] rm plt.tight_layout --- examples/preprocessing/interpolate_to.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index 4d86310e1db..ed25adcd3ac 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -57,7 +57,6 @@ axs[0].set_title("Original EEG Layout") evoked_interpolated.plot(exclude=[], picks="eeg", axes=axs[1], show=False) axs[1].set_title("Interpolated to Standard 1020 Montage") -plt.tight_layout() # %% # References From 872ba1df96ea712a6957afe77838c5e58b0f6842 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 30 Dec 2024 23:26:43 +0100 Subject: [PATCH 12/40] taller figure --- examples/preprocessing/interpolate_to.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index ed25adcd3ac..5830ba5d666 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -52,7 +52,7 @@ # %% # Comparing before and after interpolation -fig, axs = plt.subplots(2, 1, figsize=(8, 6)) +fig, axs = plt.subplots(2, 1, figsize=(8, 8)) evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False) axs[0].set_title("Original EEG Layout") evoked_interpolated.plot(exclude=[], picks="eeg", axes=axs[1], show=False) From 045496c971f86fb38307a4edfe537be2f8ba2e4a Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Wed, 1 Jan 2025 11:35:35 +0000 Subject: [PATCH 13/40] [autofix.ci] apply automated fixes --- examples/preprocessing/interpolate_to.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index 5830ba5d666..4413f85a453 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -18,6 +18,7 @@ # Authors: Antoine Collas # License: BSD-3-Clause +# Copyright the MNE-Python contributors. import matplotlib.pyplot as plt From 6fb3ae5cccd466b2ef59b3f600d0f7cdb1dc8729 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Fri, 3 Jan 2025 10:08:25 +0100 Subject: [PATCH 14/40] fix figure layout --- examples/preprocessing/interpolate_to.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index 4413f85a453..f3127d122a0 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -53,7 +53,7 @@ # %% # Comparing before and after interpolation -fig, axs = plt.subplots(2, 1, figsize=(8, 8)) +fig, axs = plt.subplots(2, 1, figsize=(8, 6), constrained_layout=True) evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False) axs[0].set_title("Original EEG Layout") evoked_interpolated.plot(exclude=[], picks="eeg", axes=axs[1], show=False) From dd093a88c6d68fda3376c6937b92611b63dee68e Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Fri, 3 Jan 2025 10:13:55 +0100 Subject: [PATCH 15/40] improve test --- mne/channels/tests/test_interpolation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 43fee4266d5..f9b8177f9a6 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -463,10 +463,11 @@ def test_interpolate_to_eeg(montage_name): assert set(raw_interpolated.info["ch_names"]) == set(montage.ch_names) # Check if the data was interpolated correctly - assert raw_interpolated.get_data().shape[0] == len(montage.ch_names) + assert raw_interpolated.get_data().shape == (len(montage.ch_names), raw.n_times) # Ensure original data is not altered assert raw.info["ch_names"] != raw_interpolated.info["ch_names"] + assert raw.get_data().shape == (len(raw.info["ch_names"]), raw.n_times) # Validate that bad channels are carried over raw.info["bads"] = [raw.info["ch_names"][0]] From 95510166cc3712b1ad66b861323f8fd4dcdfc728 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Fri, 3 Jan 2025 10:22:14 +0100 Subject: [PATCH 16/40] simplify getting original data --- mne/channels/channels.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index ce290a33363..cac64b3e633 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -983,11 +983,8 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): if len(picks_from) == 0: raise ValueError("No EEG channels available for interpolation.") - if hasattr(self, "_data"): - data_orig = self._data[picks_from] - else: - # If epochs-like data, for simplicity take the mean across epochs - data_orig = self.get_data()[:, picks_from, :].mean(axis=0) + # Get original data + data_orig = self.get_data(picks=picks_from) # Get target positions from the montage ch_pos = montage.get_positions()["ch_pos"] From fa326ad0fee1f785e49a069e7a4c61ee044ee84e Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Fri, 3 Jan 2025 10:29:24 +0100 Subject: [PATCH 17/40] simplify setting interpolated data --- mne/channels/channels.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index cac64b3e633..555d611a450 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -1013,7 +1013,7 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): ) # Apply the interpolation mapping - D_new = mapping.dot(data_orig) + data_interp = mapping.dot(data_orig) # Update bad channels new_bads = [ch for ch in self.info["bads"] if ch in target_ch_names] @@ -1021,26 +1021,7 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): # Update the instance's info and data self.info = new_info - if hasattr(self, "_data"): - if self._data.ndim == 2: - # Raw-like: directly assign the new data - self._data = D_new - else: - # Epochs-like - n_epochs, _, n_times = self._data.shape - new_data = np.zeros( - (n_epochs, len(target_ch_names), n_times), dtype=self._data.dtype - ) - for e in range(n_epochs): - epoch_data_orig = self._data[e, picks_from, :] - new_data[e, :, :] = mapping.dot(epoch_data_orig) - self._data = new_data - else: - # Evoked-like data - if hasattr(self, "data"): - self.data = D_new - else: - raise NotImplementedError("This method requires preloaded data.") + self._data = data_interp return self From 1a5ca892d8a2f74c318aba489f164de7e26eacc9 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Fri, 3 Jan 2025 10:37:23 +0100 Subject: [PATCH 18/40] merge two lines --- mne/channels/channels.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 555d611a450..e3e34f46326 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -1016,8 +1016,7 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): data_interp = mapping.dot(data_orig) # Update bad channels - new_bads = [ch for ch in self.info["bads"] if ch in target_ch_names] - new_info["bads"] = new_bads + new_info["bads"] = [ch for ch in self.info["bads"] if ch in target_ch_names] # Update the instance's info and data self.info = new_info From 4333f0898fc4609543ff85bc4a0812114cb71324 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 6 Jan 2025 14:49:31 +0100 Subject: [PATCH 19/40] add spline method --- mne/channels/channels.py | 49 ++++++++++++++++-------- mne/channels/tests/test_interpolation.py | 7 ++-- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index e3e34f46326..96a94e29000 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -950,15 +950,21 @@ def interpolate_bads( return self - def interpolate_to(self, montage, method="MNE", reg=0.0): - """Interpolate data onto a new montage. + def interpolate_to(self, montage, method="spline", reg=0.0): + """Interpolate EEG data onto a new montage. Parameters ---------- montage : DigMontage The target montage containing channel positions to interpolate onto. method : str - The interpolation method to use. 'MNE' by default. + Method to use for EEG channels. + Supported methods are 'spline' (default) and 'MNE'. + + .. warning:: + Be careful, only EEG channels are interpolated. Other channel types are + not interpolated. + reg : float The regularization parameter for the interpolation method (if applicable). @@ -974,17 +980,7 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): .. versionadded:: 1.10.0 """ from ..forward._field_interpolation import _map_meg_or_eeg_channels - - # Ensure data is loaded - _check_preload(self, "interpolation") - - # Extract positions and data for EEG channels - picks_from = pick_types(self.info, meg=False, eeg=True, exclude=[]) - if len(picks_from) == 0: - raise ValueError("No EEG channels available for interpolation.") - - # Get original data - data_orig = self.get_data(picks=picks_from) + from .interpolation import _make_interpolation_matrix # Get target positions from the montage ch_pos = montage.get_positions()["ch_pos"] @@ -994,6 +990,17 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): "The provided montage does not contain any channel positions." ) + # Check the method is valid + _check_option("method", method, ["spline", "MNE"]) + + # Ensure data is loaded + _check_preload(self, "interpolation") + + # Extract positions and data for EEG channels + picks_from = pick_types(self.info, meg=False, eeg=True, exclude=[]) + if len(picks_from) == 0: + raise ValueError("No EEG channels available for interpolation.") + # Create a new info structure sfreq = self.info["sfreq"] ch_types = ["eeg"] * len(target_ch_names) @@ -1008,11 +1015,19 @@ def interpolate_to(self, montage, method="MNE", reg=0.0): old_info.set_montage(self.info.get_montage()) # Compute mapping from current montage to target montage - mapping = _map_meg_or_eeg_channels( - old_info, new_info, mode="accurate", origin="auto" - ) + if method == "spline": + pos_from = np.array( + [self.info["chs"][idx]["loc"][:3] for idx in picks_from] + ) + pos_to = np.stack(list(ch_pos.values()), axis=0) + mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg) + elif method == "MNE": + mapping = _map_meg_or_eeg_channels( + old_info, new_info, mode="accurate", origin="auto" + ) # Apply the interpolation mapping + data_orig = self.get_data(picks=picks_from) data_interp = mapping.dot(data_orig) # Update bad channels diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index f9b8177f9a6..14f813977e1 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -442,7 +442,8 @@ def test_method_str(): @pytest.mark.parametrize("montage_name", ["biosemi16", "standard_1020"]) -def test_interpolate_to_eeg(montage_name): +@pytest.mark.parametrize("method", ["spline", "MNE"]) +def test_interpolate_to_eeg(montage_name, method): """Test the interpolate_to method for EEG.""" # Load EEG data raw, _ = _load_data("eeg") @@ -457,7 +458,7 @@ def test_interpolate_to_eeg(montage_name): montage = make_standard_montage(montage_name) # Copy the raw object and apply interpolation - raw_interpolated = raw.copy().interpolate_to(montage) + raw_interpolated = raw.copy().interpolate_to(montage, method=method) # Check if channel names match the target montage assert set(raw_interpolated.info["ch_names"]) == set(montage.ch_names) @@ -471,7 +472,7 @@ def test_interpolate_to_eeg(montage_name): # Validate that bad channels are carried over raw.info["bads"] = [raw.info["ch_names"][0]] - raw_interpolated = raw.copy().interpolate_to(montage) + raw_interpolated = raw.copy().interpolate_to(montage, method=method) assert raw_interpolated.info["bads"] == [ ch for ch in raw.info["bads"] if ch in montage.ch_names ] From 4e256164f4381205e76f9299f806fdb62d33c148 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 6 Jan 2025 14:59:28 +0100 Subject: [PATCH 20/40] add splive vs mne to doc --- examples/preprocessing/interpolate_to.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index f3127d122a0..50cd31ca0af 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -46,18 +46,29 @@ # %% # Use interpolate_to to project EEG data to the standard montage -evoked_interpolated = evoked.copy().interpolate_to(standard_montage) +evoked_interpolated_spline = evoked.copy().interpolate_to( + standard_montage, method="spline" +) # Plot the interpolated EEG layout -evoked_interpolated.plot(exclude=[], picks="eeg") +evoked_interpolated_spline.plot(exclude=[], picks="eeg") + +# %% +# Use interpolate_to to project EEG data to the standard montage +evoked_interpolated_mne = evoked.copy().interpolate_to(standard_montage, method="MNE") + +# Plot the interpolated EEG layout +evoked_interpolated_mne.plot(exclude=[], picks="eeg") # %% # Comparing before and after interpolation -fig, axs = plt.subplots(2, 1, figsize=(8, 6), constrained_layout=True) +fig, axs = plt.subplots(3, 1, figsize=(8, 6), constrained_layout=True) evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False) axs[0].set_title("Original EEG Layout") -evoked_interpolated.plot(exclude=[], picks="eeg", axes=axs[1], show=False) -axs[1].set_title("Interpolated to Standard 1020 Montage") +evoked_interpolated_spline.plot(exclude=[], picks="eeg", axes=axs[1], show=False) +axs[1].set_title("Interpolated to Standard 1020 Montage using spline interpolation") +evoked_interpolated_mne.plot(exclude=[], picks="eeg", axes=axs[2], show=False) +axs[2].set_title("Interpolated to Standard 1020 Montage using MNE interpolation") # %% # References From a249321596b79bead0cfad78f33a32e118956fcf Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 6 Jan 2025 16:59:07 +0100 Subject: [PATCH 21/40] keep all modalities in test --- mne/channels/tests/test_interpolation.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 14f813977e1..a2e5c2dc786 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -446,10 +446,9 @@ def test_method_str(): def test_interpolate_to_eeg(montage_name, method): """Test the interpolate_to method for EEG.""" # Load EEG data - raw, _ = _load_data("eeg") - - # Select only EEG channels - raw.pick("eeg") + raw, epochs_eeg = _load_data("eeg") + epochs_eeg = epochs_eeg.copy() + assert not _has_eeg_average_ref_proj(epochs_eeg.info) # Load data raw.load_data() From fb8804f1742ca0fad63fb3f12f05685f8e84eda2 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 6 Jan 2025 16:59:27 +0100 Subject: [PATCH 22/40] use self.info instead of old_info --- mne/channels/channels.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 96a94e29000..be6bc772795 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -1007,13 +1007,6 @@ def interpolate_to(self, montage, method="spline", reg=0.0): new_info = create_info(ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types) new_info.set_montage(montage) - # Create a simple old_info - sfreq = self.info["sfreq"] - ch_names = self.info["ch_names"] - ch_types = ["eeg"] * len(ch_names) - old_info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) - old_info.set_montage(self.info.get_montage()) - # Compute mapping from current montage to target montage if method == "spline": pos_from = np.array( @@ -1023,7 +1016,7 @@ def interpolate_to(self, montage, method="spline", reg=0.0): mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg) elif method == "MNE": mapping = _map_meg_or_eeg_channels( - old_info, new_info, mode="accurate", origin="auto" + self.info, new_info, mode="accurate", origin="auto" ) # Apply the interpolation mapping From 7ff7f519838ea00dab56c68dce6154ccebaa4ebe Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 6 Jan 2025 17:02:44 +0100 Subject: [PATCH 23/40] fix info when MNE interpolation --- mne/channels/channels.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index be6bc772795..202f335d448 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -1015,8 +1015,9 @@ def interpolate_to(self, montage, method="spline", reg=0.0): pos_to = np.stack(list(ch_pos.values()), axis=0) mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg) elif method == "MNE": + info_eeg = pick_info(self.info, picks_from) mapping = _map_meg_or_eeg_channels( - self.info, new_info, mode="accurate", origin="auto" + info_eeg, new_info, mode="accurate", origin="auto" ) # Apply the interpolation mapping From 1334fb3c23bafeb51f45d980b2535bb3febe2cb6 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 6 Jan 2025 17:16:42 +0100 Subject: [PATCH 24/40] fix origin in spline method --- mne/channels/channels.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 202f335d448..4f404bf32d3 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -950,13 +950,17 @@ def interpolate_bads( return self - def interpolate_to(self, montage, method="spline", reg=0.0): + def interpolate_to(self, montage, origin="auto", method="spline", reg=0.0): """Interpolate EEG data onto a new montage. Parameters ---------- montage : DigMontage The target montage containing channel positions to interpolate onto. + origin : array-like, shape (3,) | str + Origin of the sphere in the head coordinate frame and in meters. + Can be ``'auto'`` (default), which means a head-digitization-based + origin fit. method : str Method to use for EEG channels. Supported methods are 'spline' (default) and 'MNE'. @@ -1009,11 +1013,29 @@ def interpolate_to(self, montage, method="spline", reg=0.0): # Compute mapping from current montage to target montage if method == "spline": - pos_from = np.array( - [self.info["chs"][idx]["loc"][:3] for idx in picks_from] - ) + # pos_from = np.array( + # [self.info["chs"][idx]["loc"][:3] for idx in picks_from] + # ) + + origin = _check_origin(origin, self.info) + pos_from = self.info._get_channel_positions(picks_from) + pos_from = pos_from - origin pos_to = np.stack(list(ch_pos.values()), axis=0) + + def _check_pos_sphere(pos): + distance = np.linalg.norm(pos, axis=-1) + distance = np.mean(distance / np.mean(distance)) + if np.abs(1.0 - distance) > 0.1: + warn( + "Your spherical fit is poor, interpolation results are " + "likely to be inaccurate." + ) + + _check_pos_sphere(pos_from) + _check_pos_sphere(pos_to) + mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg) + elif method == "MNE": info_eeg = pick_info(self.info, picks_from) mapping = _map_meg_or_eeg_channels( From 45b8e286bbee62c7dd0a06255ac93849ee9b34f3 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Thu, 6 Feb 2025 12:59:03 +0100 Subject: [PATCH 25/40] keep all original channels (expect eeg) --- mne/channels/channels.py | 152 +++++++++++++++-------- mne/channels/tests/test_interpolation.py | 15 ++- 2 files changed, 109 insertions(+), 58 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 96e7b1f83b6..03467ed902b 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -960,12 +960,16 @@ def interpolate_bads( return self - def interpolate_to(self, montage, origin="auto", method="spline", reg=0.0): + def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0): """Interpolate EEG data onto a new montage. + .. warning:: + Be careful, only EEG channels are interpolated. Other channel types are + not interpolated. + Parameters ---------- - montage : DigMontage + sensors : DigMontage The target montage containing channel positions to interpolate onto. origin : array-like, shape (3,) | str Origin of the sphere in the head coordinate frame and in meters. @@ -975,10 +979,6 @@ def interpolate_to(self, montage, origin="auto", method="spline", reg=0.0): Method to use for EEG channels. Supported methods are 'spline' (default) and 'MNE'. - .. warning:: - Be careful, only EEG channels are interpolated. Other channel types are - not interpolated. - reg : float The regularization parameter for the interpolation method (if applicable). @@ -993,77 +993,125 @@ def interpolate_to(self, montage, origin="auto", method="spline", reg=0.0): .. versionadded:: 1.10.0 """ + from .._fiff.proj import _has_eeg_average_ref_proj + from ..epochs import EpochsArray + from ..evoked import EvokedArray from ..forward._field_interpolation import _map_meg_or_eeg_channels + from ..io import RawArray from .interpolation import _make_interpolation_matrix + # Check that the method option is valid. + _check_option("method", method, ["spline", "MNE"]) + if sensors is None: + raise ValueError("A sensors configuration must be provided.") + # Get target positions from the montage - ch_pos = montage.get_positions()["ch_pos"] + ch_pos = sensors.get_positions().get("ch_pos", {}) target_ch_names = list(ch_pos.keys()) - if len(target_ch_names) == 0: + if not target_ch_names: raise ValueError( - "The provided montage does not contain any channel positions." + "The provided sensors configuration has no channel positions." ) - # Check the method is valid - _check_option("method", method, ["spline", "MNE"]) + # Get original channel order + orig_names = self.info["ch_names"] - # Ensure data is loaded - _check_preload(self, "interpolation") + # Identify EEG channel + picks_good_eeg = pick_types(self.info, meg=False, eeg=True, exclude="bads") + if len(picks_good_eeg) == 0: + raise ValueError("No good EEG channels available for interpolation.") + # Also get the full list of EEG channel indices (including bad channels) + picks_remove_eeg = pick_types(self.info, meg=False, eeg=True, exclude=[]) + eeg_names_orig = [orig_names[i] for i in picks_remove_eeg] - # Extract positions and data for EEG channels - picks_from = pick_types(self.info, meg=False, eeg=True, exclude=[]) - if len(picks_from) == 0: - raise ValueError("No EEG channels available for interpolation.") + # Identify non-EEG channels in original order + non_eeg_names_ordered = [ch for ch in orig_names if ch not in eeg_names_orig] - # Create a new info structure + # Create destination info for new EEG channels sfreq = self.info["sfreq"] - ch_types = ["eeg"] * len(target_ch_names) - new_info = create_info(ch_names=target_ch_names, sfreq=sfreq, ch_types=ch_types) - new_info.set_montage(montage) + info_interp = create_info( + ch_names=target_ch_names, + sfreq=sfreq, + ch_types=["eeg"] * len(target_ch_names), + ) + info_interp.set_montage(sensors) + info_interp["bads"] = [ch for ch in self.info["bads"] if ch in target_ch_names] + # Do not assign "projs" directly. - # Compute mapping from current montage to target montage + # Compute the interpolation mapping if method == "spline": - # pos_from = np.array( - # [self.info["chs"][idx]["loc"][:3] for idx in picks_from] - # ) - - origin = _check_origin(origin, self.info) - pos_from = self.info._get_channel_positions(picks_from) - pos_from = pos_from - origin + origin_val = _check_origin(origin, self.info) + pos_from = self.info._get_channel_positions(picks_good_eeg) - origin_val pos_to = np.stack(list(ch_pos.values()), axis=0) def _check_pos_sphere(pos): - distance = np.linalg.norm(pos, axis=-1) - distance = np.mean(distance / np.mean(distance)) - if np.abs(1.0 - distance) > 0.1: - warn( - "Your spherical fit is poor, interpolation results are " - "likely to be inaccurate." - ) + d = np.linalg.norm(pos, axis=-1) + d_norm = np.mean(d / np.mean(d)) + if np.abs(1.0 - d_norm) > 0.1: + warn("Your spherical fit is poor; interpolation may be inaccurate.") _check_pos_sphere(pos_from) _check_pos_sphere(pos_to) - mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg) elif method == "MNE": - info_eeg = pick_info(self.info, picks_from) + info_eeg = pick_info(self.info, picks_good_eeg) + # If the original info has an average EEG reference projector but + # the destination info does not, + # update info_interp via a temporary RawArray. + if _has_eeg_average_ref_proj(self.info) and not _has_eeg_average_ref_proj( + info_interp + ): + # Create dummy data: shape (n_channels, 1) + temp_data = np.zeros((len(info_interp["ch_names"]), 1)) + temp_raw = RawArray(temp_data, info_interp, first_samp=0) + # Using the public API, add an average reference projector. + temp_raw.set_eeg_reference( + ref_channels="average", projection=True, verbose=False + ) + # Extract the updated info. + info_interp = temp_raw.info mapping = _map_meg_or_eeg_channels( - info_eeg, new_info, mode="accurate", origin="auto" + info_eeg, info_interp, mode="accurate", origin=origin ) - - # Apply the interpolation mapping - data_orig = self.get_data(picks=picks_from) - data_interp = mapping.dot(data_orig) - - # Update bad channels - new_info["bads"] = [ch for ch in self.info["bads"] if ch in target_ch_names] - - # Update the instance's info and data - self.info = new_info - self._data = data_interp - - return self + else: + raise ValueError("Unsupported interpolation method.") + + # Interpolate EEG data + data_good = self.get_data(picks=picks_good_eeg) + data_interp = mapping.dot(data_good) + + # Create a new instance for the interpolated EEG channels + if hasattr(self, "first_samp"): # assume Raw if first_samp exists. + inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp) + elif hasattr( + self, "drop_bad_epochs" + ): # assume Epochs if drop_bad_epochs exists. + inst_interp = EpochsArray(data_interp, info_interp) + else: + inst_interp = EvokedArray(data_interp, info_interp) + + # Merge only if non-EEG channels exist + if not non_eeg_names_ordered: + return inst_interp + + inst_non_eeg = self.copy().pick(non_eeg_names_ordered).load_data() + inst_out = inst_non_eeg.add_channels([inst_interp], force_update_info=True) + + # Reorder channels + # Insert the entire new EEG block at the position of the first EEG channel. + new_order = [] + inserted = False + for ch in orig_names: + if ch in eeg_names_orig: + if not inserted: + new_order.extend(info_interp["ch_names"]) + inserted = True + # Skip original EEG channels. + else: + new_order.append(ch) + inst_out.reorder_channels(new_order) + return inst_out @verbose diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index a2e5c2dc786..5f57eb76d45 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -453,6 +453,9 @@ def test_interpolate_to_eeg(montage_name, method): # Load data raw.load_data() + # Get EEG channels + original_eeg_chan = [raw.ch_names[i] for i in pick_types(raw.info, eeg=True)] + # Create a target montage montage = make_standard_montage(montage_name) @@ -460,18 +463,18 @@ def test_interpolate_to_eeg(montage_name, method): raw_interpolated = raw.copy().interpolate_to(montage, method=method) # Check if channel names match the target montage - assert set(raw_interpolated.info["ch_names"]) == set(montage.ch_names) + assert set(montage.ch_names).issubset(set(raw_interpolated.info["ch_names"])) # Check if the data was interpolated correctly - assert raw_interpolated.get_data().shape == (len(montage.ch_names), raw.n_times) + n_ch = len(raw.info["ch_names"]) - len(original_eeg_chan) + len(montage.ch_names) + assert raw_interpolated.get_data().shape == (n_ch, raw.n_times) # Ensure original data is not altered assert raw.info["ch_names"] != raw_interpolated.info["ch_names"] assert raw.get_data().shape == (len(raw.info["ch_names"]), raw.n_times) # Validate that bad channels are carried over - raw.info["bads"] = [raw.info["ch_names"][0]] + bads = [raw.info["ch_names"][0]] + raw.info["bads"] = bads raw_interpolated = raw.copy().interpolate_to(montage, method=method) - assert raw_interpolated.info["bads"] == [ - ch for ch in raw.info["bads"] if ch in montage.ch_names - ] + assert raw_interpolated.info["bads"] == bads From 5cf75694a09a28a8b297761dee5916ad049c21d6 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Thu, 6 Feb 2025 13:02:19 +0100 Subject: [PATCH 26/40] doc --- mne/channels/channels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 03467ed902b..8ef49f67197 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -978,9 +978,9 @@ def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0): method : str Method to use for EEG channels. Supported methods are 'spline' (default) and 'MNE'. - reg : float - The regularization parameter for the interpolation method (if applicable). + The regularization parameter for the interpolation method + (only used when the method is 'spline'). Returns ------- From 9bc0e5921c87ba9388e13eadcd8c409145d6e6ab Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Thu, 6 Feb 2025 13:44:02 +0100 Subject: [PATCH 27/40] test epochs and evoked --- mne/channels/channels.py | 13 +++--- mne/channels/tests/test_interpolation.py | 59 +++++++++++++++--------- 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 8ef49f67197..c2df699605c 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -994,10 +994,11 @@ def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0): .. versionadded:: 1.10.0 """ from .._fiff.proj import _has_eeg_average_ref_proj - from ..epochs import EpochsArray - from ..evoked import EvokedArray + from ..epochs import BaseEpochs, EpochsArray + from ..evoked import Evoked, EvokedArray from ..forward._field_interpolation import _map_meg_or_eeg_channels from ..io import RawArray + from ..io.base import BaseRaw from .interpolation import _make_interpolation_matrix # Check that the method option is valid. @@ -1082,13 +1083,11 @@ def _check_pos_sphere(pos): data_interp = mapping.dot(data_good) # Create a new instance for the interpolated EEG channels - if hasattr(self, "first_samp"): # assume Raw if first_samp exists. + if isinstance(self, BaseRaw): inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp) - elif hasattr( - self, "drop_bad_epochs" - ): # assume Epochs if drop_bad_epochs exists. + elif isinstance(self, BaseEpochs): inst_interp = EpochsArray(data_interp, info_interp) - else: + elif isinstance(self, Evoked): inst_interp = EvokedArray(data_interp, info_interp) # Merge only if non-EEG channels exist diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 5f57eb76d45..2bb23002918 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -443,38 +443,51 @@ def test_method_str(): @pytest.mark.parametrize("montage_name", ["biosemi16", "standard_1020"]) @pytest.mark.parametrize("method", ["spline", "MNE"]) -def test_interpolate_to_eeg(montage_name, method): - """Test the interpolate_to method for EEG.""" +@pytest.mark.parametrize("data_type", ["raw", "epochs", "evoked"]) +def test_interpolate_to_eeg(montage_name, method, data_type): + """Test the interpolate_to method for EEG for raw, epochs, and evoked.""" # Load EEG data raw, epochs_eeg = _load_data("eeg") epochs_eeg = epochs_eeg.copy() assert not _has_eeg_average_ref_proj(epochs_eeg.info) - # Load data + # Load data for raw raw.load_data() - # Get EEG channels + # Get the original EEG channel names from raw (used later for shape checking) original_eeg_chan = [raw.ch_names[i] for i in pick_types(raw.info, eeg=True)] # Create a target montage montage = make_standard_montage(montage_name) - # Copy the raw object and apply interpolation - raw_interpolated = raw.copy().interpolate_to(montage, method=method) - - # Check if channel names match the target montage - assert set(montage.ch_names).issubset(set(raw_interpolated.info["ch_names"])) - - # Check if the data was interpolated correctly - n_ch = len(raw.info["ch_names"]) - len(original_eeg_chan) + len(montage.ch_names) - assert raw_interpolated.get_data().shape == (n_ch, raw.n_times) - - # Ensure original data is not altered - assert raw.info["ch_names"] != raw_interpolated.info["ch_names"] - assert raw.get_data().shape == (len(raw.info["ch_names"]), raw.n_times) - - # Validate that bad channels are carried over - bads = [raw.info["ch_names"][0]] - raw.info["bads"] = bads - raw_interpolated = raw.copy().interpolate_to(montage, method=method) - assert raw_interpolated.info["bads"] == bads + # Prepare data to interpolate to + if data_type == "raw": + inst = raw.copy() + elif data_type == "epochs": + inst = epochs_eeg.copy() + elif data_type == "evoked": + inst = epochs_eeg.average() + shape = list(inst._data.shape) + + # Interpolate + inst_interp = inst.copy().interpolate_to(montage, method=method) + + # Check that the new channel names include the montage channels. + assert set(montage.ch_names).issubset(set(inst_interp.info["ch_names"])) + + # Check that the original container's channel ordering is changed. + assert inst.info["ch_names"] != inst_interp.info["ch_names"] + + # Check that the data shape is as expected. + orig_total = len(inst.info["ch_names"]) + n_eeg_orig = len(original_eeg_chan) + new_nchan_expected = orig_total - n_eeg_orig + len(montage.ch_names) + expected_shape = tuple([new_nchan_expected] + shape[1:]) + assert inst_interp._data.shape == expected_shape + + # Validate that bad channels are carried over. + # Mark the first channel as bad. + bads = [inst.info["ch_names"][0]] + inst.info["bads"] = bads + inst_interp = inst.copy().interpolate_to(montage, method=method) + assert inst_interp.info["bads"] == bads From 2db82cc5d87ffd08c04065363a2b57d43c2fdecb Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Thu, 6 Feb 2025 16:44:42 +0100 Subject: [PATCH 28/40] fix epochs --- mne/channels/channels.py | 1 + mne/channels/tests/test_interpolation.py | 33 +++++++++++++++--------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index c2df699605c..fdfeaaa7d8c 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -1086,6 +1086,7 @@ def _check_pos_sphere(pos): if isinstance(self, BaseRaw): inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp) elif isinstance(self, BaseEpochs): + data_interp = np.transpose(data_interp, (1, 0, 2)) inst_interp = EpochsArray(data_interp, info_interp) elif isinstance(self, Evoked): inst_interp = EvokedArray(data_interp, info_interp) diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 2bb23002918..10418df56ea 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -449,14 +449,10 @@ def test_interpolate_to_eeg(montage_name, method, data_type): # Load EEG data raw, epochs_eeg = _load_data("eeg") epochs_eeg = epochs_eeg.copy() - assert not _has_eeg_average_ref_proj(epochs_eeg.info) # Load data for raw raw.load_data() - # Get the original EEG channel names from raw (used later for shape checking) - original_eeg_chan = [raw.ch_names[i] for i in pick_types(raw.info, eeg=True)] - # Create a target montage montage = make_standard_montage(montage_name) @@ -468,6 +464,8 @@ def test_interpolate_to_eeg(montage_name, method, data_type): elif data_type == "evoked": inst = epochs_eeg.average() shape = list(inst._data.shape) + orig_total = len(inst.info["ch_names"]) + n_eeg_orig = len(pick_types(inst.info, eeg=True)) # Interpolate inst_interp = inst.copy().interpolate_to(montage, method=method) @@ -479,15 +477,26 @@ def test_interpolate_to_eeg(montage_name, method, data_type): assert inst.info["ch_names"] != inst_interp.info["ch_names"] # Check that the data shape is as expected. - orig_total = len(inst.info["ch_names"]) - n_eeg_orig = len(original_eeg_chan) new_nchan_expected = orig_total - n_eeg_orig + len(montage.ch_names) - expected_shape = tuple([new_nchan_expected] + shape[1:]) + if len(shape) == 2: + expected_shape = tuple([new_nchan_expected, shape[1]]) + elif len(shape) == 3: + expected_shape = tuple([shape[0], new_nchan_expected, shape[2]]) assert inst_interp._data.shape == expected_shape # Validate that bad channels are carried over. - # Mark the first channel as bad. - bads = [inst.info["ch_names"][0]] - inst.info["bads"] = bads - inst_interp = inst.copy().interpolate_to(montage, method=method) - assert inst_interp.info["bads"] == bads + # Mark the first non eeg channel as bad + bads = None + i = 0 + all_ch = inst_interp.info["ch_names"] + eeg_ch = pick_types(inst_interp.info, eeg=True) + eeg_ch = [all_ch[i] for i in eeg_ch] + while (i < len(all_ch)) and (bads is None): + ch = all_ch[i] + if ch not in eeg_ch: + bads = [ch] + i += 1 + if bads is not None: + inst.info["bads"] = bads + inst_interp = inst.copy().interpolate_to(montage, method=method) + assert inst_interp.info["bads"] == bads From 3b52d4d6ddd8f5026f8f06979b1d1d242b0ad59e Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Thu, 6 Feb 2025 17:16:24 +0100 Subject: [PATCH 29/40] ylim --- examples/preprocessing/interpolate_to.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index 50cd31ca0af..f1de9187fb3 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -27,6 +27,7 @@ from mne.datasets import sample print(__doc__) +ylim = (-10, 10) # %% # Load EEG data @@ -38,7 +39,7 @@ evoked.pick("eeg") # Plot the original EEG layout -evoked.plot(exclude=[], picks="eeg") +evoked.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim)) # %% # Define the target montage @@ -51,23 +52,28 @@ ) # Plot the interpolated EEG layout -evoked_interpolated_spline.plot(exclude=[], picks="eeg") +evoked_interpolated_spline.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim)) # %% # Use interpolate_to to project EEG data to the standard montage evoked_interpolated_mne = evoked.copy().interpolate_to(standard_montage, method="MNE") # Plot the interpolated EEG layout -evoked_interpolated_mne.plot(exclude=[], picks="eeg") +evoked_interpolated_mne.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim)) # %% # Comparing before and after interpolation fig, axs = plt.subplots(3, 1, figsize=(8, 6), constrained_layout=True) -evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False) +evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False, ylim=dict(eeg=ylim)) axs[0].set_title("Original EEG Layout") -evoked_interpolated_spline.plot(exclude=[], picks="eeg", axes=axs[1], show=False) +evoked_interpolated_spline.plot( + exclude=[], picks="eeg", axes=axs[1], show=False, ylim=dict(eeg=ylim) +) axs[1].set_title("Interpolated to Standard 1020 Montage using spline interpolation") evoked_interpolated_mne.plot(exclude=[], picks="eeg", axes=axs[2], show=False) +evoked_interpolated_mne.plot( + exclude=[], picks="eeg", axes=axs[2], show=False, ylim=dict(eeg=ylim) +) axs[2].set_title("Interpolated to Standard 1020 Montage using MNE interpolation") # %% From a670566496243f2fbaddc0bc40c8321810a6d1aa Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Thu, 6 Feb 2025 17:16:37 +0100 Subject: [PATCH 30/40] Update interpolate_to.py --- examples/preprocessing/interpolate_to.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py index f1de9187fb3..b97a7251cbb 100644 --- a/examples/preprocessing/interpolate_to.py +++ b/examples/preprocessing/interpolate_to.py @@ -70,7 +70,6 @@ exclude=[], picks="eeg", axes=axs[1], show=False, ylim=dict(eeg=ylim) ) axs[1].set_title("Interpolated to Standard 1020 Montage using spline interpolation") -evoked_interpolated_mne.plot(exclude=[], picks="eeg", axes=axs[2], show=False) evoked_interpolated_mne.plot( exclude=[], picks="eeg", axes=axs[2], show=False, ylim=dict(eeg=ylim) ) From e7112184a3173f492b45829d84e7a605e8968c3f Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Sun, 9 Feb 2025 17:41:09 +0100 Subject: [PATCH 31/40] review --- mne/channels/channels.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index fdfeaaa7d8c..5121c15ae43 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -999,12 +999,12 @@ def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0): from ..forward._field_interpolation import _map_meg_or_eeg_channels from ..io import RawArray from ..io.base import BaseRaw + from . import DigMontage from .interpolation import _make_interpolation_matrix # Check that the method option is valid. _check_option("method", method, ["spline", "MNE"]) - if sensors is None: - raise ValueError("A sensors configuration must be provided.") + _validate_type(sensors, DigMontage, "sensors") # Get target positions from the montage ch_pos = sensors.get_positions().get("ch_pos", {}) @@ -1055,7 +1055,8 @@ def _check_pos_sphere(pos): _check_pos_sphere(pos_to) mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg) - elif method == "MNE": + else: + assert method == "MNE" info_eeg = pick_info(self.info, picks_good_eeg) # If the original info has an average EEG reference projector but # the destination info does not, @@ -1075,20 +1076,18 @@ def _check_pos_sphere(pos): mapping = _map_meg_or_eeg_channels( info_eeg, info_interp, mode="accurate", origin=origin ) - else: - raise ValueError("Unsupported interpolation method.") # Interpolate EEG data data_good = self.get_data(picks=picks_good_eeg) - data_interp = mapping.dot(data_good) + data_interp = mapping @ data_good # Create a new instance for the interpolated EEG channels if isinstance(self, BaseRaw): inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp) elif isinstance(self, BaseEpochs): - data_interp = np.transpose(data_interp, (1, 0, 2)) inst_interp = EpochsArray(data_interp, info_interp) - elif isinstance(self, Evoked): + else: + assert isinstance(self, Evoked) inst_interp = EvokedArray(data_interp, info_interp) # Merge only if non-EEG channels exist From 056de4953202e2ec1127a1c65a95e611ec5fa289 Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Sun, 9 Feb 2025 18:18:17 +0100 Subject: [PATCH 32/40] vectorize --- mne/channels/channels.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 5121c15ae43..3e5bd736d3d 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -1099,16 +1099,17 @@ def _check_pos_sphere(pos): # Reorder channels # Insert the entire new EEG block at the position of the first EEG channel. - new_order = [] - inserted = False - for ch in orig_names: - if ch in eeg_names_orig: - if not inserted: - new_order.extend(info_interp["ch_names"]) - inserted = True - # Skip original EEG channels. - else: - new_order.append(ch) + orig_names_arr = np.array(orig_names) + mask_eeg = np.isin(orig_names_arr, eeg_names_orig) + if mask_eeg.any(): + first_eeg_index = np.where(mask_eeg)[0][0] + pre = orig_names_arr[:first_eeg_index] + new_eeg = np.array(info_interp["ch_names"]) + post = orig_names_arr[first_eeg_index:] + post = post[~np.isin(orig_names_arr[first_eeg_index:], eeg_names_orig)] + new_order = np.concatenate((pre, new_eeg, post)).tolist() + else: + new_order = orig_names inst_out.reorder_channels(new_order) return inst_out From 24979fa4ab5f69ea89d3d2f1f0eae1a75dc7ec9b Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Sun, 9 Feb 2025 18:49:44 +0100 Subject: [PATCH 33/40] comments about loss of infos --- mne/channels/channels.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 3e5bd736d3d..ecc9e1ba8ef 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -990,6 +990,7 @@ def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0): Notes ----- This method is useful for standardizing EEG layouts across datasets. + However, some attributes may be lost after interpolation. .. versionadded:: 1.10.0 """ @@ -1082,6 +1083,9 @@ def _check_pos_sphere(pos): data_interp = mapping @ data_good # Create a new instance for the interpolated EEG channels + # TODO: Creating a new instance leads to a loss of information. + # We should consider updating the existing instance in the future + # by 1) drop channels, 2) add channels, 3) re-order channels. if isinstance(self, BaseRaw): inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp) elif isinstance(self, BaseEpochs): From e93fa5e1fd5bebe960fbef262eb39bc58ba00c6c Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Sun, 9 Feb 2025 19:35:29 +0100 Subject: [PATCH 34/40] add test on odering --- mne/channels/tests/test_interpolation.py | 56 +++++++++++++++++++++--- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 10418df56ea..3de6c37f7c8 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -467,23 +467,69 @@ def test_interpolate_to_eeg(montage_name, method, data_type): orig_total = len(inst.info["ch_names"]) n_eeg_orig = len(pick_types(inst.info, eeg=True)) - # Interpolate + # Assert first and last channels are not EEG + if data_type == "raw": + ch_types = inst.get_channel_types() + assert ch_types[0] != "eeg" + assert ch_types[-1] != "eeg" + + # Record the names and data of the first and last channels. + if data_type == "raw": + first_name = inst.info["ch_names"][0] + last_name = inst.info["ch_names"][-1] + if inst._data.ndim == 2: + data_first = inst._data[0].copy() + data_last = inst._data[-1].copy() + elif inst._data.ndim == 3: + data_first = inst._data[:, 0, :].copy() + data_last = inst._data[:, -1, :].copy() + + # Interpolate the EEG channels. inst_interp = inst.copy().interpolate_to(montage, method=method) # Check that the new channel names include the montage channels. assert set(montage.ch_names).issubset(set(inst_interp.info["ch_names"])) - - # Check that the original container's channel ordering is changed. + # Check that the overall channel order is changed. assert inst.info["ch_names"] != inst_interp.info["ch_names"] # Check that the data shape is as expected. new_nchan_expected = orig_total - n_eeg_orig + len(montage.ch_names) if len(shape) == 2: - expected_shape = tuple([new_nchan_expected, shape[1]]) + expected_shape = (new_nchan_expected, shape[1]) elif len(shape) == 3: - expected_shape = tuple([shape[0], new_nchan_expected, shape[2]]) + expected_shape = (shape[0], new_nchan_expected, shape[2]) assert inst_interp._data.shape == expected_shape + # Verify that the first and last channels retain their positions. + if data_type == "raw": + assert inst_interp.info["ch_names"][0] == first_name + assert inst_interp.info["ch_names"][-1] == last_name + + # Verify that the data for the first and last channels is unchanged. + if data_type == "raw": + if inst_interp._data.ndim == 2: + np.testing.assert_allclose( + inst_interp._data[0], + data_first, + err_msg="Data for the first non-EEG channel has changed.", + ) + np.testing.assert_allclose( + inst_interp._data[-1], + data_last, + err_msg="Data for the last non-EEG channel has changed.", + ) + elif inst_interp._data.ndim == 3: + np.testing.assert_allclose( + inst_interp._data[:, 0, :], + data_first, + err_msg="Data for the first non-EEG channel has changed.", + ) + np.testing.assert_allclose( + inst_interp._data[:, -1, :], + data_last, + err_msg="Data for the last non-EEG channel has changed.", + ) + # Validate that bad channels are carried over. # Mark the first non eeg channel as bad bads = None From b399feb033902cd756d2e3f771c719594371ee7f Mon Sep 17 00:00:00 2001 From: Antoine Collas Date: Mon, 10 Feb 2025 17:52:59 +0100 Subject: [PATCH 35/40] import --- mne/channels/channels.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index ecc9e1ba8ef..d0e57eecb5f 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -41,7 +41,7 @@ pick_info, pick_types, ) -from .._fiff.proj import setup_proj +from .._fiff.proj import _has_eeg_average_ref_proj, setup_proj from .._fiff.reference import add_reference_channels, set_eeg_reference from .._fiff.tag import _rename_list from ..bem import _check_origin @@ -994,14 +994,13 @@ def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0): .. versionadded:: 1.10.0 """ - from .._fiff.proj import _has_eeg_average_ref_proj from ..epochs import BaseEpochs, EpochsArray from ..evoked import Evoked, EvokedArray from ..forward._field_interpolation import _map_meg_or_eeg_channels from ..io import RawArray from ..io.base import BaseRaw - from . import DigMontage from .interpolation import _make_interpolation_matrix + from .montage import DigMontage # Check that the method option is valid. _check_option("method", method, ["spline", "MNE"]) From 3aaa870ac8003822bda2bd18961951df5d96cfe0 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 10 Feb 2025 12:17:42 -0500 Subject: [PATCH 36/40] FIX: Changelog --- doc/changes/devel/13044.newfeature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 doc/changes/devel/13044.newfeature.rst diff --git a/doc/changes/devel/13044.newfeature.rst b/doc/changes/devel/13044.newfeature.rst new file mode 100644 index 00000000000..9633aba66b9 --- /dev/null +++ b/doc/changes/devel/13044.newfeature.rst @@ -0,0 +1 @@ +Add :meth:`mne.Evoked.interpolate_to` to allow interpolating EEG data to other montages, by :newcontrib:`Antoine Collas`. From 68a05926ae77f0fb3e82873f54560aa3d82c3b18 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 10 Feb 2025 12:19:12 -0500 Subject: [PATCH 37/40] Apply suggestions from code review --- mne/channels/tests/test_interpolation.py | 61 ++++++++---------------- 1 file changed, 20 insertions(+), 41 deletions(-) diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 3de6c37f7c8..769015c993f 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -477,12 +477,8 @@ def test_interpolate_to_eeg(montage_name, method, data_type): if data_type == "raw": first_name = inst.info["ch_names"][0] last_name = inst.info["ch_names"][-1] - if inst._data.ndim == 2: - data_first = inst._data[0].copy() - data_last = inst._data[-1].copy() - elif inst._data.ndim == 3: - data_first = inst._data[:, 0, :].copy() - data_last = inst._data[:, -1, :].copy() + data_first = inst._data[..., 0, :].copy() + data_last = inst._data[..., -1, :].copy() # Interpolate the EEG channels. inst_interp = inst.copy().interpolate_to(montage, method=method) @@ -494,10 +490,9 @@ def test_interpolate_to_eeg(montage_name, method, data_type): # Check that the data shape is as expected. new_nchan_expected = orig_total - n_eeg_orig + len(montage.ch_names) - if len(shape) == 2: - expected_shape = (new_nchan_expected, shape[1]) - elif len(shape) == 3: - expected_shape = (shape[0], new_nchan_expected, shape[2]) + expected_shape = (new_nchan_expected, shape[1]) + if len(shape) == 3: + expected_shape = (shape[0],) + expected_shape assert inst_interp._data.shape == expected_shape # Verify that the first and last channels retain their positions. @@ -507,28 +502,16 @@ def test_interpolate_to_eeg(montage_name, method, data_type): # Verify that the data for the first and last channels is unchanged. if data_type == "raw": - if inst_interp._data.ndim == 2: - np.testing.assert_allclose( - inst_interp._data[0], - data_first, - err_msg="Data for the first non-EEG channel has changed.", - ) - np.testing.assert_allclose( - inst_interp._data[-1], - data_last, - err_msg="Data for the last non-EEG channel has changed.", - ) - elif inst_interp._data.ndim == 3: - np.testing.assert_allclose( - inst_interp._data[:, 0, :], - data_first, - err_msg="Data for the first non-EEG channel has changed.", - ) - np.testing.assert_allclose( - inst_interp._data[:, -1, :], - data_last, - err_msg="Data for the last non-EEG channel has changed.", - ) + np.testing.assert_allclose( + inst_interp._data[..., 0, :], + data_first, + err_msg="Data for the first non-EEG channel has changed.", + ) + np.testing.assert_allclose( + inst_interp._data[..., -1, :], + data_last, + err_msg="Data for the last non-EEG channel has changed.", + ) # Validate that bad channels are carried over. # Mark the first non eeg channel as bad @@ -537,12 +520,8 @@ def test_interpolate_to_eeg(montage_name, method, data_type): all_ch = inst_interp.info["ch_names"] eeg_ch = pick_types(inst_interp.info, eeg=True) eeg_ch = [all_ch[i] for i in eeg_ch] - while (i < len(all_ch)) and (bads is None): - ch = all_ch[i] - if ch not in eeg_ch: - bads = [ch] - i += 1 - if bads is not None: - inst.info["bads"] = bads - inst_interp = inst.copy().interpolate_to(montage, method=method) - assert inst_interp.info["bads"] == bads + # just the first non-EEG channel (if available) + bads = [ch for ch in all_ch if ch not in eeg_ch][:1] + inst.info["bads"] = bads + inst_interp = inst.copy().interpolate_to(montage, method=method) + assert inst_interp.info["bads"] == bads From 16c648f8202f8105b690d49ea2abcc6e8f1680cf Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 10 Feb 2025 12:20:08 -0500 Subject: [PATCH 38/40] FIX: Flake --- mne/channels/tests/test_interpolation.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 769015c993f..d22a1ffab2e 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -515,12 +515,8 @@ def test_interpolate_to_eeg(montage_name, method, data_type): # Validate that bad channels are carried over. # Mark the first non eeg channel as bad - bads = None - i = 0 all_ch = inst_interp.info["ch_names"] - eeg_ch = pick_types(inst_interp.info, eeg=True) - eeg_ch = [all_ch[i] for i in eeg_ch] - # just the first non-EEG channel (if available) + eeg_ch = [all_ch[i] for i in pick_types(inst_interp.info, eeg=True)] bads = [ch for ch in all_ch if ch not in eeg_ch][:1] inst.info["bads"] = bads inst_interp = inst.copy().interpolate_to(montage, method=method) From ee32071cf56ddb2cd260e11f4e658986196e7c57 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 10 Feb 2025 12:27:25 -0500 Subject: [PATCH 39/40] FIX: Name --- doc/changes/names.inc | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 18753a0c872..282fa8341a0 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -24,6 +24,7 @@ .. _Anna Padee: https://github.com/apadee/ .. _Annalisa Pascarella: https://www.iac.cnr.it/personale/annalisa-pascarella .. _Anne-Sophie Dubarry: https://github.com/annesodub +.. _Antoine Collas: https://www.antoinecollas.fr .. _Antoine Gauthier: https://github.com/Okamille .. _Antti Rantala: https://github.com/Odingod .. _Apoorva Karekal: https://github.com/apoorva6262 From 5460dbbdfaaab409c936c443fe7891b26421c322 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 10 Feb 2025 12:51:11 -0500 Subject: [PATCH 40/40] FIX: Dim --- mne/channels/tests/test_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index d22a1ffab2e..62c7d79e3eb 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -490,7 +490,7 @@ def test_interpolate_to_eeg(montage_name, method, data_type): # Check that the data shape is as expected. new_nchan_expected = orig_total - n_eeg_orig + len(montage.ch_names) - expected_shape = (new_nchan_expected, shape[1]) + expected_shape = (new_nchan_expected, shape[-1]) if len(shape) == 3: expected_shape = (shape[0],) + expected_shape assert inst_interp._data.shape == expected_shape