-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Within Session splitter #664
base: develop
Are you sure you want to change the base?
Changes from 57 commits
bacedc5
419b2ca
d6e795d
140670c
d724674
a278026
300a6b9
7cb79f6
55db70f
2851a15
c73dd1a
2b0e735
cf4b709
177bf65
a6b5772
26b13d5
e6661c4
430e3a8
698e539
0fff053
98d12ac
558d27b
34ea645
b435bf8
d8f26a3
eaf0fb9
e5159f2
516a5e8
b29ecd2
37cff03
88ee910
ea9cc59
819c4ff
65c305e
f1ad587
485e7a5
3f3742f
c181c59
8f034c8
fbef726
837c061
34822e9
39e92e5
612c6a6
590edb1
b151d61
602ccd5
74cf246
c85928d
f3020ee
c628fd7
83e425f
a785ef7
686572e
8f8ada9
b87cf25
42e6b7b
f66878a
71cebb1
a948175
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why not use only the other version? |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from sklearn.model_selection import BaseCrossValidator | ||
|
||
from moabb.evaluations.utils import sort_group | ||
|
||
|
||
class PseudoOnlineSplit(BaseCrossValidator):
    """Pseudo-online split for evaluation test data.

    It takes into account the time sequence in which the test data was
    obtained, and uses the first run, or the first ``calib_size`` trials, as
    calibration data and the rest as evaluation data. Calibration data is
    important in contexts where data alignment or filtering is applied to the
    training data.

    OBS: Be careful! Since this inference split is based on the time
    disposition of the obtained data, if your data is not organized by time
    but by another parameter, such as class, you may want to be extra careful
    when using this split.

    Parameters
    ----------
    calib_size: int
        Size of calibration set, used if there is just one run.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from moabb.evaluations.splitters import WithinSessionSplitter
    >>> from moabb.evaluations.metasplitters import PseudoOnlineSplit
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [8, 9], [5, 4], [2, 5], [1, 7]])
    >>> y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
    >>> subjects = np.array([1, 1, 1, 1, 1, 1, 1, 1])
    >>> sessions = np.array([0, 0, 0, 0, 1, 1, 1, 1])
    >>> runs = np.array(['0', '0', '1', '1', '0', '0', '1', '1'])
    >>> metadata = pd.DataFrame(data={'subject': subjects, 'session': sessions, 'run':runs})
    >>> posplit = PseudoOnlineSplit
    >>> csubj = WithinSessionSplitter(cv=posplit, calib_size=1, custom_cv=True)
    >>> posplit().get_n_splits(metadata)
    2
    >>> for i, (train_index, test_index) in enumerate(csubj.split(y, metadata)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Calibration: index={train_index}, group={subjects[train_index]}, sessions={sessions[train_index]}, runs={runs[train_index]}")
    ...     print(f"  Test:  index={test_index}, group={subjects[test_index]}, sessions={sessions[test_index]}, runs={runs[test_index]}")
    Fold 0:
      Calibration: index=[4, 5], group=[1 1], sessions=[1 1], runs=['0' '0']
      Test:  index=[6, 7], group=[1 1], sessions=[1 1], runs=['1' '1']
    Fold 1:
      Calibration: index=[0, 1], group=[1 1], sessions=[0 0], runs=['0' '0']
      Test:  index=[2, 3], group=[1 1], sessions=[0 0], runs=['1' '1']
    """

    def __init__(self, calib_size: int = None):
        self.calib_size = calib_size

    def get_n_splits(self, metadata):
        """Return the number of (subject, session) groups in ``metadata``.

        One split is produced per group, so this is also the number of
        folds yielded by :meth:`split` when metadata is provided.
        """
        return metadata.groupby(["subject", "session"]).ngroups

    def split(self, indices, y, metadata=None):
        """Yield ``(calibration, test)`` index lists.

        Parameters
        ----------
        indices : sequence
            Indices to split; only used when ``metadata`` is None.
        y : array-like
            Labels. Not used to build the split.
        metadata : pandas.DataFrame, optional
            Must contain ``subject``, ``session`` and ``run`` columns. When
            given, one split is yielded per (subject, session) group and the
            returned values are labels of ``metadata.index``.

        Yields
        ------
        calib_ix : list
            Indices of the calibration samples.
        test_ix : list
            Indices of the evaluation samples.

        Raises
        ------
        ValueError
            If a group has a single run (or no metadata is given) and
            ``calib_size`` is None.
        """
        if metadata is None:
            if self.calib_size is None:
                raise ValueError("Need to provide calibration size.")
            calib_size = self.calib_size
            yield list(indices[:calib_size]), list(indices[calib_size:])
            return

        for _, group in metadata.groupby(["subject", "session"]):
            runs = group.run.unique()
            if len(runs) > 1:
                # Sort the runs to guarantee they are in the right order,
                # then take the first run as calibration.
                first_run = next(iter(sort_group(runs)))
                is_calib = group["run"] == first_run
                calib_ix = group[is_calib].index
                test_ix = group[~is_calib].index
            else:
                if self.calib_size is None:
                    raise ValueError(
                        "Data contains just one run. Need to provide calibration size."
                    )
                # Take the first ``calib_size`` samples as calibration
                calib_size = self.calib_size
                calib_ix = group[:calib_size].index
                test_ix = group[calib_size:].index

            # Always (calibration, test), whichever branch produced them.
            yield list(calib_ix), list(test_ix)
tomMoral marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
|
||
from sklearn.model_selection import BaseCrossValidator, StratifiedKFold | ||
from sklearn.utils import check_random_state | ||
|
||
from moabb.evaluations.metasplitters import PseudoOnlineSplit | ||
|
||
|
||
class WithinSessionSplitter(BaseCrossValidator):
    """Data splitter for within session evaluation.

    Within-session evaluation uses k-fold cross_validation to determine train
    and test sets for each subject in each session. This splitter
    assumes that all data from all subjects is already known and loaded.

    .. image:: images/withinsess.png
        :alt: The schematic diagram of the WithinSession split
        :align: center

    Parameters
    ----------
    cv: cross-validation class, default=StratifiedKFold
        Inner cross-validation strategy for splitting the sessions. It is
        passed as a class (not an instance) and instantiated for every
        (subject, session). Be careful, if PseudoOnlineSplit is used, it will
        return calibration and test indexes.
    custom_cv: bool, default=False
        Indicates if you are using PseudoOnlineSplit as cv strategy
    n_folds : int, default=5
        Number of folds. Must be at least 2.
    random_state: int, RandomState instance or None, default=42
        Controls the randomness of splits. Only used when `shuffle` is True.
        Pass an int for reproducible output across multiple function calls.
    shuffle : bool, default=True
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.
    calib_size: int, default=None
        Size of calibration set if custom_cv==True

    Examples
    -----------

    >>> import pandas as pd
    >>> import numpy as np
    >>> from moabb.evaluations.splitters import WithinSessionSplitter
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [1,4], [7, 4], [5, 8], [0,3], [2,4]])
    >>> y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
    >>> subjects = np.array([1, 1, 1, 1, 1, 1, 1, 1])
    >>> sessions = np.array(['T', 'T', 'T', 'T', 'E', 'E', 'E', 'E'])
    >>> metadata = pd.DataFrame(data={'subject': subjects, 'session': sessions})
    >>> csess = WithinSessionSplitter(n_folds=2)
    >>> csess.get_n_splits(metadata)
    4
    >>> for i, (train_index, test_index) in enumerate(csess.split(y, metadata)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}, group={subjects[train_index]}, session={sessions[train_index]}")
    ...     print(f"  Test:  index={test_index}, group={subjects[test_index]}, sessions={sessions[test_index]}")
    Fold 0:
      Train: index=[4 7], group=[1 1], session=['E' 'E']
      Test:  index=[5 6], group=[1 1], sessions=['E' 'E']
    Fold 1:
      Train: index=[5 6], group=[1 1], session=['E' 'E']
      Test:  index=[4 7], group=[1 1], sessions=['E' 'E']
    Fold 2:
      Train: index=[2 3], group=[1 1], session=['T' 'T']
      Test:  index=[0 1], group=[1 1], sessions=['T' 'T']
    Fold 3:
      Train: index=[0 1], group=[1 1], session=['T' 'T']
      Test:  index=[2 3], group=[1 1], sessions=['T' 'T']
    """

    def __init__(
        self,
        cv=StratifiedKFold,
        custom_cv=False,
        n_folds: int = 5,
        random_state: int = 42,
        shuffle: bool = True,
        calib_size: int = None,
    ):
        self.n_folds = n_folds
        self.shuffle = shuffle
        # NOTE(review): scikit-learn convention is for ``__init__`` to store
        # parameters untouched. Here a single RandomState is built once and
        # consumed by every ``split`` call, so repeated calls on the same
        # object draw different per-session seeds. Kept as is to preserve the
        # public ``random_state`` attribute; revisit for reproducibility.
        self.random_state = check_random_state(random_state) if shuffle else None
        self.cv = cv
        self.calib_size = calib_size
        self.custom_cv = custom_cv

    def get_n_splits(self, metadata):
        """Return the total number of splits generated for ``metadata``.

        With a custom cv the count is delegated to an instance of that cv;
        otherwise it is ``n_folds`` times the number of (subject, session)
        groups.
        """
        if self.custom_cv:
            # ``cv`` is stored as a class: instantiate it before asking for
            # the number of splits, otherwise ``metadata`` is bound as self.
            return self.cv(calib_size=self.calib_size).get_n_splits(metadata)
        return self.n_folds * metadata.groupby(["subject", "session"]).ngroups

    def split(self, y, metadata, **kwargs):
        """Yield index pairs for every (subject, session) in ``metadata``.

        Parameters
        ----------
        y : array-like
            Labels, indexed with the values of ``metadata.index``.
        metadata : pandas.DataFrame
            Must contain ``subject`` and ``session`` columns (and ``run``
            when PseudoOnlineSplit is used).

        Yields
        ------
        first, second : array-like
            ``(train, test)`` indices for a standard cv, or
            ``(calibration, test)`` indices when ``cv`` is PseudoOnlineSplit.
        """
        all_index = metadata.index.values
        subjects = metadata["subject"].unique()

        # Decide once which kind of inner splitter is in use, instead of
        # building a throwaway instance for every session.
        pseudo_online = isinstance(self.cv, type) and issubclass(
            self.cv, PseudoOnlineSplit
        )

        # Shuffle subjects if required
        if self.shuffle:
            self.random_state.shuffle(subjects)

        for subject in subjects:
            subject_mask = metadata.subject == subject
            subject_indices = all_index[subject_mask]
            subject_metadata = metadata[subject_mask]
            sessions = subject_metadata.session.unique()

            # Shuffle sessions if required
            if self.shuffle:
                self.random_state.shuffle(sessions)

            for session in sessions:
                session_mask = subject_metadata.session == session
                indices = subject_indices[session_mask]
                group_y = y[indices]

                if pseudo_online:
                    # Custom splitter: yields calibration/test indices as is
                    splitter = self.cv(calib_size=self.calib_size)
                    yield from splitter.split(
                        indices, group_y, subject_metadata[session_mask]
                    )
                else:
                    # Standard CV like StratifiedKFold. A seed is drawn only
                    # when shuffling: ``self.random_state`` is None otherwise,
                    # and scikit-learn rejects a seed with ``shuffle=False``.
                    seed = (
                        self.random_state.randint(0, 2**10) if self.shuffle else None
                    )
                    splitter = self.cv(
                        n_splits=self.n_folds,
                        shuffle=self.shuffle,
                        random_state=seed,
                    )
                    for train_ix, test_ix in splitter.split(indices, group_y):
                        yield indices[train_ix], indices[test_ix]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import numpy as np | ||
import pytest | ||
from sklearn.model_selection import StratifiedKFold | ||
from sklearn.utils import check_random_state | ||
|
||
from moabb.datasets.fake import FakeDataset | ||
from moabb.evaluations.splitters import WithinSessionSplitter | ||
from moabb.paradigms.motor_imagery import FakeImageryParadigm | ||
|
||
|
||
# Module-level fixtures shared by every test below: a fake two-class
# (left hand / right hand) dataset with three subjects and a fixed seed,
# and the matching fake imagery paradigm used to load its data.
dataset = FakeDataset(["left_hand", "right_hand"], n_subjects=3, seed=12)
paradigm = FakeImageryParadigm()
|
||
|
||
def eval_split_within_session(shuffle, random_state):
    """Reference split, as done for the Within Session evaluation.

    Loads each subject of the module-level ``dataset`` in turn and, for each
    of its sessions, runs a 5-fold ``StratifiedKFold`` on that session only.

    Yields
    ------
    X_train, X_test
        The session data selected by each fold.
    """
    rng = check_random_state(random_state) if shuffle else None
    for subject in dataset.subject_list:
        X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject])
        for session in np.unique(metadata.session):
            in_session = metadata.session == session
            X_sess, y_sess, meta_sess = X[in_session], y[in_session], metadata[in_session]
            kfold = StratifiedKFold(n_splits=5, shuffle=shuffle, random_state=rng)
            # The first argument only provides the number of samples;
            # stratification is driven by the labels.
            for train, test in kfold.split(meta_sess, y_sess):
                yield X_sess[train], X_sess[test]
|
||
|
||
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("random_state", [0, 42])
def test_within_session(shuffle, random_state):
    """Compare WithinSessionSplitter folds with the reference evaluation split."""
    X, y, metadata = paradigm.get_data(dataset=dataset)
    # TODO(review): also check that the split is the same when the data of
    # only one or a few subjects is loaded.

    splitter = WithinSessionSplitter(
        n_folds=5, shuffle=shuffle, random_state=random_state
    )
    expected_folds = eval_split_within_session(
        shuffle=shuffle, random_state=random_state
    )
    actual_folds = splitter.split(y, metadata)

    for (expected_train, expected_test), (train_ix, test_ix) in zip(
        expected_folds, actual_folds
    ):
        # Data selected by the splitter must match the reference fold
        assert np.array_equal(X[train_ix], expected_train)
        assert np.array_equal(X[test_ix], expected_test)
|
||
|
||
def test_is_shuffling():
    """Check that shuffled and unshuffled splitters produce different folds."""
    X, y, metadata = paradigm.get_data(dataset=dataset)

    split = WithinSessionSplitter(n_folds=5, shuffle=False)
    split_shuffle = WithinSessionSplitter(n_folds=5, shuffle=True, random_state=3)

    for (train, test), (train_shuffle, test_shuffle) in zip(
        split.split(y, metadata), split_shuffle.split(y, metadata)
    ):
        # Shuffling must change the indices of every fold
        assert not np.array_equal(train, train_shuffle)
        assert not np.array_equal(test, test_shuffle)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better represent only one session @brunaafl