Within Session splitter #664

Open
wants to merge 60 commits into
base: develop
Changes from 57 commits
bacedc5
Creating new splitters and base evaluation
brunaafl Jun 6, 2024
419b2ca
Adding metasplitters
brunaafl Jun 7, 2024
d6e795d
Fixing LazyEvaluation
brunaafl Jun 10, 2024
140670c
Merge branch 'NeuroTechX:develop' into eval_splitters
brunaafl Jun 10, 2024
d724674
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
a278026
More optimized version of TimeSeriesSplit
brunaafl Jun 10, 2024
300a6b9
More optimized version of TimeSeriesSplit
brunaafl Jun 10, 2024
7cb79f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
55db70f
Addressing some comments: documentation, types, inconsistencies
brunaafl Jun 10, 2024
2851a15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
c73dd1a
Addressing some comments: optimizing code, adjusts
brunaafl Jun 12, 2024
2b0e735
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2024
cf4b709
Adding examples
brunaafl Jun 26, 2024
177bf65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 26, 2024
a6b5772
Adding: Pytests for evaluation splitters, and examples for meta split…
brunaafl Aug 15, 2024
26b13d5
Changing: name of TimeSeriesSplit to PseudoOnlineSplit
brunaafl Sep 30, 2024
e6661c4
Merge branch 'develop' into eval_splitters
brunaafl Sep 30, 2024
430e3a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
698e539
Fixing pre-commit
brunaafl Sep 30, 2024
0fff053
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Sep 30, 2024
98d12ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
558d27b
Adding some tests for metasplitters
brunaafl Oct 1, 2024
34ea645
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
b435bf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
d8f26a3
Fixing pre-commit
brunaafl Oct 1, 2024
eaf0fb9
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
e5159f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
516a5e8
Fixing pre-commit
brunaafl Oct 1, 2024
b29ecd2
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
37cff03
Fix example SamplerSplit
brunaafl Oct 17, 2024
88ee910
Add shuffle and random_state parameters to WithinSession
brunaafl Oct 18, 2024
ea9cc59
Change nomenclature of variables
brunaafl Oct 18, 2024
819c4ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
65c305e
Merge branch 'develop' into within_session
bruAristimunha Oct 18, 2024
f1ad587
FIX: fixing the whats_new.rst file
bruAristimunha Oct 18, 2024
485e7a5
EHN: playing a little
bruAristimunha Oct 18, 2024
3f3742f
FIX: fixing the import and docs/docstring
bruAristimunha Oct 18, 2024
c181c59
FIX: fixing the import and docs/docstring
bruAristimunha Oct 18, 2024
8f034c8
FIX: fixing the import and docs/docstring
bruAristimunha Oct 18, 2024
fbef726
FIX: removing cross-session and cross-subject
bruAristimunha Oct 18, 2024
837c061
FIX: focus only in the within-session
bruAristimunha Oct 18, 2024
34822e9
Merge branch 'develop' into within_session
bruAristimunha Oct 19, 2024
39e92e5
Fix test
brunaafl Oct 19, 2024
612c6a6
Merge remote-tracking branch 'origin/within_session' into within_session
brunaafl Oct 19, 2024
590edb1
[FIX] I think it is fixed.
bruAristimunha Oct 23, 2024
b151d61
[FIX] shuffle everything
bruAristimunha Oct 23, 2024
602ccd5
Merge remote-tracking branch 'origin/within_session' into within_session
brunaafl Oct 25, 2024
74cf246
Changing WithinSession image
brunaafl Oct 25, 2024
c85928d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
f3020ee
Update moabb/evaluations/splitters.py
brunaafl Nov 25, 2024
c628fd7
Update moabb/evaluations/splitters.py
brunaafl Nov 25, 2024
83e425f
Update moabb/evaluations/splitters.py
brunaafl Nov 25, 2024
a785ef7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2024
686572e
Merge branch 'develop' into within_session
bruAristimunha Nov 26, 2024
8f8ada9
Adding possibility of passing a specific cv to do inner cv
brunaafl Nov 29, 2024
b87cf25
Merge remote-tracking branch 'origin/within_session' into within_session
brunaafl Nov 29, 2024
42e6b7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2024
f66878a
Changing metasplitter behaviour to have the same behavior as other cr…
brunaafl Dec 8, 2024
71cebb1
Merge remote-tracking branch 'origin/within_session' into within_session
brunaafl Dec 8, 2024
a948175
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2024
7 changes: 7 additions & 0 deletions docs/source/evaluations.rst
@@ -19,6 +19,13 @@ Evaluations
CrossSubjectEvaluation


.. autosummary::
    :toctree: generated/
    :template: class.rst

    WithinSessionSplitter


------------
Base & Utils
------------
Binary file added docs/source/images/withinsess.pdf
Binary file not shown.
Binary file added docs/source/images/withinsess.png
Collaborator

Better to represent only one session, @brunaafl.
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
@@ -17,6 +17,7 @@ Develop branch

Enhancements
~~~~~~~~~~~~
- Adding :class:`moabb.evaluations.splitters.WithinSessionSplitter` (:gh:`664` by `Bruna Lopes`_)

- Update version of pyRiemann to 0.7 (:gh:`671` by `Gregoire Cattan`_)
- Add columns definitions in the datasets doc (:gh:`672` by `Pierre Guetschel`_)
Binary file added images/withinsess.png
Collaborator

Why not use only the other version?
1 change: 1 addition & 0 deletions moabb/evaluations/__init__.py
@@ -9,4 +9,5 @@
    CrossSubjectEvaluation,
    WithinSessionEvaluation,
)
from .splitters import WithinSessionSplitter
from .utils import create_save_path, save_model_cv, save_model_list
87 changes: 87 additions & 0 deletions moabb/evaluations/metasplitters.py
@@ -0,0 +1,87 @@
from sklearn.model_selection import BaseCrossValidator

from moabb.evaluations.utils import sort_group


class PseudoOnlineSplit(BaseCrossValidator):
    """Pseudo-online split for evaluation test data.

    It takes into account the time sequence for obtaining the test data, and uses the first run,
    or the first `calib_size` trials, as calibration data, and the rest as evaluation data.
    Calibration data is important in contexts where data alignment or filtering is applied to
    training data.

    OBS: Be careful! Since this inference split is based on the time ordering of the acquired data,
    if your data is organized not by time but by another parameter, such as class, you may want to
    be extra careful when using this split.

    Parameters
    ----------
    calib_size: int
        Size of calibration set, used if there is just one run.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from moabb.evaluations.splitters import WithinSessionSplitter
    >>> from moabb.evaluations.metasplitters import PseudoOnlineSplit
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [8, 9], [5, 4], [2, 5], [1, 7]])
    >>> y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
    >>> subjects = np.array([1, 1, 1, 1, 1, 1, 1, 1])
    >>> sessions = np.array([0, 0, 0, 0, 1, 1, 1, 1])
    >>> runs = np.array(['0', '0', '1', '1', '0', '0', '1', '1'])
    >>> metadata = pd.DataFrame(data={'subject': subjects, 'session': sessions, 'run':runs})
    >>> posplit = PseudoOnlineSplit
    >>> csubj = WithinSessionSplitter(cv=posplit, calib_size=1, custom_cv=True)
    >>> posplit().get_n_splits(metadata)
    2
    >>> for i, (train_index, test_index) in enumerate(csubj.split(y, metadata)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Calibration: index={train_index}, group={subjects[train_index]}, sessions={sessions[train_index]}, runs={runs[train_index]}")
    ...     print(f"  Test:  index={test_index}, group={subjects[test_index]}, sessions={sessions[test_index]}, runs={runs[test_index]}")
    Fold 0:
      Calibration: index=[6, 7], group=[1 1], sessions=[1 1], runs=['1' '1']
      Test:  index=[4, 5], group=[1 1], sessions=[1 1], runs=['0' '0']
    Fold 1:
      Calibration: index=[2, 3], group=[1 1], sessions=[0 0], runs=['1' '1']
      Test:  index=[0, 1], group=[1 1], sessions=[0 0], runs=['0' '0']
    """

    def __init__(self, calib_size: int = None):
        self.calib_size = calib_size

    def get_n_splits(self, metadata):
        return len(metadata.groupby(["subject", "session"]))

    def split(self, indices, y, metadata=None):

        if metadata is not None:
            for _, group in metadata.groupby(["subject", "session"]):
                runs = group.run.unique()
                if len(runs) > 1:
brunaafl marked this conversation as resolved.
                    # To guarantee that the runs are in the right order
                    runs = sort_group(runs)
                    for run in runs:
                        test_ix = group[group["run"] != run].index
                        calib_ix = group[group["run"] == run].index
                        # Note: this branch yields (test_ix, calib_ix), whereas the
                        # single-run branch below yields (calib_ix, test_ix).
                        yield list(test_ix), list(calib_ix)
                        break  # Take the first run as calibration
                else:
                    if self.calib_size is None:
                        raise ValueError(
                            "Data contains just one run. Need to provide calibration size."
                        )
                    # Take first #calib_size samples as calibration
                    calib_size = self.calib_size
                    calib_ix = group[:calib_size].index
                    test_ix = group[calib_size:].index
brunaafl marked this conversation as resolved.

                    yield list(calib_ix), list(test_ix)

        else:
brunaafl marked this conversation as resolved.
            if self.calib_size is None:
                raise ValueError("Need to provide calibration size.")
            calib_size = self.calib_size
            yield list(indices[:calib_size]), list(indices[calib_size:])
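The calibration/test logic described in the `PseudoOnlineSplit` docstring can be sketched for a single session as below. This is a minimal, dependency-free illustration and not the PR's implementation: the function name `pseudo_online_indices` is hypothetical, it works on positions rather than DataFrame index labels, and it assumes run labels are plain numeric strings such as `'0'` and `'1'`.

```python
from typing import List, Optional, Sequence, Tuple


def pseudo_online_indices(
    runs: Sequence[str], calib_size: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Return (calibration, test) positions for the trials of one session.

    Hypothetical helper: `runs` holds the run label of each trial, in
    acquisition order, and labels are assumed to be numeric strings.
    """
    unique_runs = list(dict.fromkeys(runs))  # unique labels, order preserved
    if len(unique_runs) > 1:
        # Several runs: the run with the smallest number is the calibration set
        first_run = min(unique_runs, key=int)
        calib = [i for i, run in enumerate(runs) if run == first_run]
        test = [i for i, run in enumerate(runs) if run != first_run]
        return calib, test
    if calib_size is None:
        raise ValueError("Data contains just one run. Need to provide calibration size.")
    # Single run: the first `calib_size` trials are the calibration set
    positions = list(range(len(runs)))
    return positions[:calib_size], positions[calib_size:]
```

With two runs the first run becomes the calibration set; with a single run, `calib_size` decides where the session is cut.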
133 changes: 133 additions & 0 deletions moabb/evaluations/splitters.py
tomMoral marked this conversation as resolved.
@@ -0,0 +1,133 @@

from sklearn.model_selection import BaseCrossValidator, StratifiedKFold
from sklearn.utils import check_random_state

from moabb.evaluations.metasplitters import PseudoOnlineSplit


class WithinSessionSplitter(BaseCrossValidator):
    """Data splitter for within session evaluation.

    Within-session evaluation uses k-fold cross-validation to determine train
    and test sets for each subject in each session. This splitter
    assumes that all data from all subjects is already known and loaded.

    .. image:: images/withinsess.png
        :alt: The schematic diagram of the WithinSession split
        :align: center


    Parameters
    ----------
    n_folds : int, default=5
        Number of folds. Must be at least 2.
    random_state: int, RandomState instance or None, default=42
        Controls the randomness of splits. Only used when `shuffle` is True.
        Pass an int for reproducible output across multiple function calls.
    shuffle : bool, default=True
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.
    custom_cv: bool, default=False
        Indicates whether PseudoOnlineSplit is used as the cv strategy.
    calib_size: int, default=None
        Size of the calibration set if custom_cv==True.
    cv: cross-validation object, default=StratifiedKFold
        Inner cross-validation strategy for splitting the sessions. Be careful: if
        PseudoOnlineSplit is used, it will return calibration and test indexes.


    Examples
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>> from moabb.evaluations.splitters import WithinSessionSplitter
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [1,4], [7, 4], [5, 8], [0,3], [2,4]])
    >>> y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
    >>> subjects = np.array([1, 1, 1, 1, 1, 1, 1, 1])
    >>> sessions = np.array(['T', 'T', 'T', 'T', 'E', 'E', 'E', 'E'])
    >>> metadata = pd.DataFrame(data={'subject': subjects, 'session': sessions})
    >>> csess = WithinSessionSplitter(n_folds=2)
    >>> csess.get_n_splits(metadata)
    4
    >>> for i, (train_index, test_index) in enumerate(csess.split(y, metadata)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}, group={subjects[train_index]}, session={sessions[train_index]}")
    ...     print(f"  Test:  index={test_index}, group={subjects[test_index]}, sessions={sessions[test_index]}")
    Fold 0:
      Train: index=[4 7], group=[1 1], session=['E' 'E']
      Test:  index=[5 6], group=[1 1], sessions=['E' 'E']
    Fold 1:
      Train: index=[5 6], group=[1 1], session=['E' 'E']
      Test:  index=[4 7], group=[1 1], sessions=['E' 'E']
    Fold 2:
      Train: index=[2 3], group=[1 1], session=['T' 'T']
      Test:  index=[0 1], group=[1 1], sessions=['T' 'T']
    Fold 3:
      Train: index=[0 1], group=[1 1], session=['T' 'T']
      Test:  index=[2 3], group=[1 1], sessions=['T' 'T']
    """

    def __init__(
        self,
        cv=StratifiedKFold,
Collaborator

In the scikit-learn framework, cv is a cross-validator object, not a class. I think it would be best to stick to that. This would avoid instantiating it during the split call. You can have split=None by default and instantiate cv=StratifiedKFold() in the __init__ method.
Also, you can check the cv argument with sklearn's check_cv.

Collaborator Author

I understand the concern, but I'm a bit unsure how to implement it in the case shuffle=True, since I'm defining a different seed for each (subject, session). Is the suggestion to instantiate cv in the init method in case split is not needed, and keep how it is being done otherwise?

Collaborator

Indeed, it's not easy. But you could, for example, make a wrapper around StratifiedKFold which would instantiate a different cv with a different seed for each subject/session.

Also, I just noticed that at the moment the seeds for each subject/session are chosen at random. We will not be able to have reproducible results this way. Instead, you could add a parameter global_seed to your wrapper and use, for each cv, random_state = global_seed + 10000*subject_number + session_number (it's safe to say we will never have 10000 sessions) if global_seed is an integer, and None otherwise.
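The seeding rule proposed in this comment can be sketched as follows. This is only an illustration of the arithmetic, not code from the PR: the names `session_seed` and `shuffled_positions` are hypothetical, sessions are assumed to be addressable by an integer number, and the standard-library `random.Random` stands in for the inner cross-validator's shuffling.

```python
import random
from typing import List, Optional


def session_seed(
    global_seed: Optional[int], subject_number: int, session_number: int
) -> Optional[int]:
    """Derive a deterministic seed for one (subject, session) pair.

    Hypothetical helper following the formula in the review comment:
    global_seed + 10000 * subject_number + session_number, or None when
    no global seed is given.
    """
    if global_seed is None:
        return None
    if not 0 <= session_number < 10000:
        raise ValueError("session_number must be in [0, 10000)")
    return global_seed + 10000 * subject_number + session_number


def shuffled_positions(n_trials: int, seed: Optional[int]) -> List[int]:
    """Shuffle trial positions with a generator seeded for this session."""
    positions = list(range(n_trials))
    random.Random(seed).shuffle(positions)
    return positions
```

Because the seed depends only on the global seed and the (subject, session) pair, the shuffle for a given session does not change with the order in which, or with which other, subjects are processed. In a wrapper around `StratifiedKFold`, the derived integer would presumably be passed as `random_state`.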

        custom_cv=False,
        n_folds: int = 5,
        random_state: int = 42,
        shuffle: bool = True,
        calib_size: int = None,
    ):
        self.n_folds = n_folds
        self.shuffle = shuffle
        self.random_state = check_random_state(random_state) if shuffle else None
        self.cv = cv
        self.calib_size = calib_size
        self.custom_cv = custom_cv

    def get_n_splits(self, metadata):
        num_sessions_subjects = metadata.groupby(["subject", "session"]).ngroups
        return (
            # `cv` is stored as a class, so instantiate it before calling the method
            self.cv().get_n_splits(metadata)
            if self.custom_cv
            else self.n_folds * num_sessions_subjects
        )

    def split(self, y, metadata, **kwargs):
        all_index = metadata.index.values
        subjects = metadata["subject"].unique()

        # Shuffle subjects if required
        if self.shuffle:
            self.random_state.shuffle(subjects)

        for i, subject in enumerate(subjects):
            subject_mask = metadata.subject == subject
            subject_indices = all_index[subject_mask]
            subject_metadata = metadata[subject_mask]
            sessions = subject_metadata.session.unique()

            # Shuffle sessions if required
            if self.shuffle:
                self.random_state.shuffle(sessions)

            for j, session in enumerate(sessions):
                session_mask = subject_metadata.session == session
                indices = subject_indices[session_mask]
                group_y = y[indices]

                # Handle custom splitter (`cv` is a class, so no instance is needed for the check)
                if issubclass(self.cv, PseudoOnlineSplit):
                    splitter = self.cv(calib_size=self.calib_size)
                    for calib_ix, test_ix in splitter.split(
                        indices, group_y, subject_metadata[session_mask]
                    ):
                        yield calib_ix, test_ix
brunaafl marked this conversation as resolved.
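                else:
                    # Handle standard CV like StratifiedKFold
                    splitter = self.cv(
                        n_splits=self.n_folds,
                        shuffle=self.shuffle,
                        # self.random_state is None when shuffle=False, and
                        # StratifiedKFold requires random_state=None in that case
                        random_state=(
                            self.random_state.randint(0, 2**10)
                            if self.shuffle
                            else None
                        ),
                    )
                    for train_ix, test_ix in splitter.split(indices, group_y):
                        yield indices[train_ix], indices[test_ix]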
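The grouping that `WithinSessionSplitter` performs, and the fold count returned by `get_n_splits` in the non-custom case (`n_folds` times the number of (subject, session) pairs), can be illustrated with the sketch below. It is a simplified stand-in, not the PR's code: the name `within_session_folds` is hypothetical, and it uses a plain, unshuffled, non-stratified k-fold instead of `StratifiedKFold` so that it needs only the standard library.

```python
from collections import defaultdict
from typing import Hashable, Iterator, List, Sequence, Tuple


def within_session_folds(
    subjects: Sequence[Hashable], sessions: Sequence[Hashable], n_folds: int = 2
) -> Iterator[Tuple[Tuple[Hashable, Hashable], List[int], List[int]]]:
    """Yield ((subject, session), train, test) with folds kept inside one session.

    Hypothetical sketch: plain contiguous k-fold, no stratification or shuffling.
    """
    # Collect trial positions per (subject, session), in order of appearance
    groups = defaultdict(list)
    for position, key in enumerate(zip(subjects, sessions)):
        groups[key].append(position)

    for key, positions in groups.items():
        if len(positions) < n_folds:
            raise ValueError(f"Group {key} has fewer trials than n_folds")
        # Fold sizes differ by at most one trial
        base, extra = divmod(len(positions), n_folds)
        start = 0
        for fold in range(n_folds):
            stop = start + base + (1 if fold < extra else 0)
            test = positions[start:stop]
            train = positions[:start] + positions[stop:]
            yield key, train, test
            start = stop
```

For one subject with two sessions of four trials each and `n_folds=2`, this produces four folds, and every train/test pair stays within a single session.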
13 changes: 13 additions & 0 deletions moabb/evaluations/utils.py
@@ -1,9 +1,11 @@
from __future__ import annotations

import re
from pathlib import Path
from pickle import HIGHEST_PROTOCOL, dump
from typing import Sequence

import numpy as np
from numpy import argmax
from sklearn.pipeline import Pipeline

@@ -222,6 +224,17 @@ def create_save_path(
print("No hdf5_path provided, models will not be saved.")


def sort_group(groups):
    """Sort run names by their leading integer (e.g. "0", "1train", "10")."""
    runs_sort = []
    pattern = r"([0-9]+)(|[a-zA-Z]+[a-zA-Z0-9]*)"
    for group in groups:
        # Leading digits give the run index; the optional alphanumeric suffix is ignored
        index, _ = re.fullmatch(pattern, group).groups()
        runs_sort.append(int(index))
    sorted_ix = np.argsort(runs_sort)
    return groups[sorted_ix]


def _convert_sklearn_params_to_optuna(param_grid: dict) -> dict:
    """
    Function to convert the parameter in Optuna format. This function will
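To make the run-name pattern used by `sort_group` concrete, here is a small standard-library sketch of the same idea. It is not the PR's function: `run_index` and `sort_runs` are hypothetical names, the sketch works on plain lists instead of NumPy arrays, and, unlike the diff above, it raises an explicit `ValueError` when a name does not match the pattern.

```python
import re
from typing import List, Sequence

# Same pattern as in the diff: leading digits, then an optional suffix
# that starts with a letter.
RUN_PATTERN = re.compile(r"([0-9]+)(|[a-zA-Z]+[a-zA-Z0-9]*)")


def run_index(run: str) -> int:
    """Extract the leading integer from a run name such as '0' or '1train'."""
    match = RUN_PATTERN.fullmatch(run)
    if match is None:
        raise ValueError(f"Unexpected run name: {run!r}")
    return int(match.group(1))


def sort_runs(runs: Sequence[str]) -> List[str]:
    """Sort run names numerically rather than lexicographically."""
    return sorted(runs, key=run_index)
```

Sorting by the extracted integer matters once there are ten or more runs, since a plain string sort would place `'10'` before `'2'`.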
58 changes: 58 additions & 0 deletions moabb/tests/splits.py
@@ -0,0 +1,58 @@
import numpy as np
import pytest
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import check_random_state

from moabb.datasets.fake import FakeDataset
from moabb.evaluations.splitters import WithinSessionSplitter
from moabb.paradigms.motor_imagery import FakeImageryParadigm


dataset = FakeDataset(["left_hand", "right_hand"], n_subjects=3, seed=12)
paradigm = FakeImageryParadigm()


# Split done for the Within Session evaluation
def eval_split_within_session(shuffle, random_state):
    random_state = check_random_state(random_state) if shuffle else None
    for subject in dataset.subject_list:
        X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject])
        sessions = metadata.session
        for session in np.unique(sessions):
            ix = sessions == session
            cv = StratifiedKFold(n_splits=5, shuffle=shuffle, random_state=random_state)
            # Restrict the data and labels to the current session
            X_, y_ = X[ix], y[ix]
            for train, test in cv.split(X_, y_):
                yield X_[train], X_[test]


@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("random_state", [0, 42])
def test_within_session(shuffle, random_state):
    X, y, metadata = paradigm.get_data(dataset=dataset)
Collaborator

I think it is important to check whether the split is the same when we load the data of only one or a few subjects, paradigm.get_data(dataset=dataset, subjects=[m, n...])


    split = WithinSessionSplitter(n_folds=5, shuffle=shuffle, random_state=random_state)

    for (X_train_t, X_test_t), (train, test) in zip(
        eval_split_within_session(shuffle=shuffle, random_state=random_state),
        split.split(y, metadata),
    ):
        X_train, X_test = X[train], X[test]

        # Check if the output is the same as the input
        assert np.array_equal(X_train, X_train_t)
        assert np.array_equal(X_test, X_test_t)


def test_is_shuffling():
    X, y, metadata = paradigm.get_data(dataset=dataset)

    split = WithinSessionSplitter(n_folds=5, shuffle=False)
    split_shuffle = WithinSessionSplitter(n_folds=5, shuffle=True, random_state=3)

    for (train, test), (train_shuffle, test_shuffle) in zip(
        split.split(y, metadata), split_shuffle.split(y, metadata)
    ):
        # Check that shuffling changes the split
        assert not np.array_equal(train, train_shuffle)
        assert not np.array_equal(test, test_shuffle)