From 8eaa521270c17096586e45211f4df8a9f6bb57be Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Mon, 10 Feb 2025 15:12:39 +0000 Subject: [PATCH 1/2] Add support for n-dimensional arrays in `_tfr_from_mt` (#13104) Co-authored-by: Eric Larson --- mne/time_frequency/tfr.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index f4a01e87895..0c8bb0f4fb0 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -4291,19 +4291,20 @@ def _tfr_from_mt(x_mt, weights): Parameters ---------- - x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times) + x_mt : array, shape (..., n_tapers, n_freqs, n_times) The complex-valued multitaper coefficients. weights : array, shape (n_tapers, n_freqs) The weights to use to combine the tapered estimates. Returns ------- - tfr : array, shape (n_channels, n_freqs, n_times) + tfr : array, shape (..., n_freqs, n_times) The time-frequency power estimates. """ - weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims + # add singleton dim for time and any dims preceding the tapers + weights = weights[..., np.newaxis] tfr = weights * x_mt tfr *= tfr.conj() - tfr = tfr.real.sum(axis=1) - tfr *= 2 / (weights * weights.conj()).real.sum(axis=1) + tfr = tfr.real.sum(axis=-3) + tfr *= 2 / (weights * weights.conj()).real.sum(axis=-3) return tfr From 64ed25561841e6d91413564222e498c44244ad94 Mon Sep 17 00:00:00 2001 From: Stefan Appelhoff Date: Mon, 10 Feb 2025 17:38:04 +0100 Subject: [PATCH 2/2] add overwrite and verbose params to info.save (#13107) Co-authored-by: Eric Larson --- doc/changes/devel/13107.newfeature.rst | 1 + mne/_fiff/meas_info.py | 13 +++++++++++-- mne/_fiff/tests/test_meas_info.py | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 doc/changes/devel/13107.newfeature.rst diff --git a/doc/changes/devel/13107.newfeature.rst b/doc/changes/devel/13107.newfeature.rst new file mode 100644 index 00000000000..a19381fbdb2 --- /dev/null +++ b/doc/changes/devel/13107.newfeature.rst @@ -0,0 +1 @@ +The :meth:`mne.Info.save` method now has an ``overwrite`` and a ``verbose`` parameter, by `Stefan Appelhoff`_. diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 51612824a6a..28f0629c323 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -1935,15 +1935,24 @@ def _repr_html_(self): info_template = _get_html_template("repr", "info.html.jinja") return info_template.render(info=self) - def save(self, fname): + @verbose + def save(self, fname, *, overwrite=False, verbose=None): """Write measurement info in fif file. Parameters ---------- fname : path-like The name of the file. Should end by ``'-info.fif'``. + %(overwrite)s + + .. versionadded:: 1.10 + %(verbose)s + + See Also + -------- + mne.io.write_info """ - write_info(fname, self) + write_info(fname, self, overwrite=overwrite) def _simplify_info(info, *, keep=()): diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index a38ecaade50..9c0639a830c 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -975,7 +975,7 @@ def test_field_round_trip(tmp_path): meas_date=_stamp_to_dt((1, 2)), ) fname = tmp_path / "temp-info.fif" - write_info(fname, info) + info.save(fname) info_read = read_info(fname) assert_object_equal(info, info_read) with pytest.raises(TypeError, match="datetime"):