Skip to content

Commit

Permalink
Modified cvep_pomdp.py and other files to work with the public releas…
Browse files Browse the repository at this point in the history
…e of the cvep dataset
  • Loading branch information
Hororohoruru committed Jun 2, 2023
1 parent dd17040 commit e755925
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
6 changes: 3 additions & 3 deletions pomdp_bci/config/cvep.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"n_subs": 10,
"sub_list": ["juan", "ludo", "simon", "noemie", "emilie", "marcel", "max", "quentin", "rebai", "felix"],
"n_subs": 11,
"excluded_subs": [2],
"ch_slice": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23", "24", "25", "26", "27", "28", "29", "30", "31", "32"],
"sfreq": 500,
"max_augmentation": 3
}
}
11 changes: 6 additions & 5 deletions pomdp_bci/cvep_pomdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
get_code_prediction, make_preds_accumul_aggresive


def fit_clf(win_data, win_labels):
def fit_clf(win_data, win_labels, sfreq, algo='EEGnet_patchembeddingdilation'):
"""
Return a fit classifier using the selected data and architecture
Expand Down Expand Up @@ -58,7 +58,7 @@ def fit_clf(win_data, win_labels):
win_labels = np.squeeze(win_labels[index])

# Initialize NN
win_samples = int(code_win_len * params['sfreq'])
win_samples = int(code_win_len * sfreq)
n_channels = win_data.shape[1] # Number of channels in the data (for channels last)

if algo == 'EEGnet_patchembeddingdilation':
Expand Down Expand Up @@ -125,12 +125,13 @@ def fit_clf(win_data, win_labels):

for dataset, algo in itertools.product(datasets, algos):
# Load dataset-specific parameters
with open(f'{config/dataset}.json', 'r') as dataset_params:
with open(f'config/{dataset}.json', 'r') as dataset_params:
params = json.loads(dataset_params.read())

downsample = int(params['sfreq'] / 250) # After downsampling, sfreq should be 250Hz

sub_list = params['sub_list']
excluded_subs = params['excluded_subs']
sub_list = [sub_n for sub_n in range(params['n_subs']) if sub_n not in excluded_subs]
score_dict = {sub: {} for sub in sub_list}
metadata_dict = {}

Expand Down Expand Up @@ -191,7 +192,7 @@ def fit_clf(win_data, win_labels):
data_cv /= cal_std + 1e-8

# Fit
pomdp_clf = fit_clf(data_cal, labels_cal, algorithm=algo)
pomdp_clf = fit_clf(data_cal, labels_cal, sfreq=params['sfreq'], algo=algo)
print('Data was fit')

# Predict the codes of validation data (split in 10 to avoid OOM from the GPU)
Expand Down
4 changes: 1 addition & 3 deletions pomdp_bci/utils/utils_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def load_data(subject, dataset, eeg_path=None, ch_keep=[]):
event_id = nakanishi.event_id

elif dataset == 'cvep':
filename = f"{subject}_mseqwhite.set"
filename = f"P{subject}_whitemseq.set"

file_path = os.path.join(eeg_path, filename)
raw = mne.io.read_raw_eeglab(file_path, preload=True, verbose=False)
Expand Down Expand Up @@ -114,8 +114,6 @@ def load_data(subject, dataset, eeg_path=None, ch_keep=[]):

# CVEP needs the montage manually set
if dataset == 'cvep':
montage = get_liveamp_montage(eeg_path)
raw.set_montage(montage)
raw = raw.drop_channels(['21', '10'])

mne.set_eeg_reference(raw, 'average', copy=False, verbose=False)
Expand Down

0 comments on commit e755925

Please sign in to comment.