split_subjects

class dpeeg.transforms.split_subjects(eegdata: BaseDataset, test_subjects: list[int] | None = None, ret_eegdata: bool = True, verbose=None)

Split the dataset by subjects.

Splitting the dataset at the subject level differs from the SplitTrainTest transformation. split_subjects divides the dataset by whole subjects (as in cross-subject evaluation), whereas SplitTrainTest splits the data within each subject. When ret_eegdata=True, the data of the different subjects are converted through ToEEGData.
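The difference between the two strategies can be sketched with plain NumPy, independently of dpeeg. The dictionary layout and the helper names `split_by_subject` and `split_within_subject` are hypothetical and only illustrate the idea; they are not the library's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy dataset: subject id -> (trials, labels).
# Each subject has 16 trials of shape (3 channels, 10 samples).
dataset = {
    sid: (rng.standard_normal((16, 3, 10)), rng.integers(0, 3, 16))
    for sid in (1, 2, 3)
}


def _merge(dataset, subjects):
    """Concatenate trials and labels of the given subjects."""
    data = np.concatenate([dataset[s][0] for s in subjects])
    labels = np.concatenate([dataset[s][1] for s in subjects])
    return data, labels


def split_by_subject(dataset, test_subjects):
    """Subject-level split: whole subjects go to either train or test."""
    train_subjects = [s for s in dataset if s not in test_subjects]
    return _merge(dataset, train_subjects), _merge(dataset, test_subjects)


def split_within_subject(dataset, test_size=0.25):
    """Per-subject split: every subject contributes to both train and test."""
    result = {}
    for sid, (data, labels) in dataset.items():
        n_test = int(len(data) * test_size)
        result[sid] = {
            "train": (data[n_test:], labels[n_test:]),
            "test": (data[:n_test], labels[:n_test]),
        }
    return result


(train_x, train_y), (test_x, test_y) = split_by_subject(dataset, test_subjects=[3])
print(train_x.shape, test_x.shape)  # (32, 3, 10) (16, 3, 10)

per_subject = split_within_subject(dataset)
print(per_subject[1]["train"][0].shape)  # (12, 3, 10)
```

In the subject-level split, no trial from subject 3 appears in the training set, which is what makes the evaluation cross-subject. In the per-subject split, each subject appears on both sides.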

Parameters:
  • eegdata (EEG Dataset) – Input EEG dataset.

  • test_subjects (list of int, optional) – The list of subjects in the test set. If None, the data of all subjects in the dataset are merged.

  • ret_eegdata (bool) – If True, transform the merged inter-subject data into the EEGData type. If False, return a labeled MultiSessEEGData. See the examples for specific usage.

Attention

Because this transformation changes the structure of the entire dataset, it cannot be used with Sequential. It is typically applied at the beginning or end of dataset preprocessing.

Examples

Since the input is a dataset type, first define a dataset with 3 subjects, the last of which has two sessions. Here, the data of all subjects are merged:

>>> import numpy as np
>>> import dpeeg
>>> from dpeeg import transforms
>>> from dpeeg.datasets import EEGDataset
>>>
>>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10),
...                         label=np.random.randint(0, 3, 16))
>>> multi_sess_eegdata = dpeeg.MultiSessEEGData([eegdata, eegdata])
>>> eegdataset = EEGDataset([eegdata, eegdata, multi_sess_eegdata])
>>> transforms.split_subjects(eegdataset, verbose=False)
[edata=(64, 3, 10), label=(64,)]

Set the test_subjects parameter to split the dataset:

>>> transforms.split_subjects(eegdataset, [1, 2], verbose=False)
Train: [edata=(32, 3, 10), label=(32,)]
Test : [edata=(32, 3, 10), label=(32,)]

Set ret_eegdata=False to return the labeled MultiSessEEGData:

>>> transforms.split_subjects(eegdataset, [3], False, verbose=False)
{'train_1': [edata=(16, 3, 10), label=(16,)],
 'train_2': [edata=(16, 3, 10), label=(16,)],
 'test_3_session_1': [edata=(16, 3, 10), label=(16,)],
 'test_3_session_2': [edata=(16, 3, 10), label=(16,)]}
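The key layout in the output above can be mimicked with a small stand-alone sketch. The naming scheme is inferred from that output only, and the helper `label_subjects` and the `subjects` dictionary are hypothetical; how dpeeg actually builds the keys may differ.

```python
import numpy as np

# Hypothetical stand-in for the example dataset: subjects 1 and 2 hold a
# single recording, subject 3 holds two sessions.
trials = np.zeros((16, 3, 10))
subjects = {1: [trials], 2: [trials], 3: [trials, trials]}


def label_subjects(subjects, test_subjects):
    """Build a flat dict keyed like the output above (assumed scheme)."""
    labeled = {}
    for sid, sessions in subjects.items():
        prefix = "test" if sid in test_subjects else "train"
        if len(sessions) == 1:
            # Single-recording subjects get a key such as 'train_1'.
            labeled[f"{prefix}_{sid}"] = sessions[0]
        else:
            # Multi-session subjects get one key per session.
            for idx, sess in enumerate(sessions, start=1):
                labeled[f"{prefix}_{sid}_session_{idx}"] = sess
    return labeled


labeled = label_subjects(subjects, test_subjects=[3])
print(list(labeled))
# ['train_1', 'train_2', 'test_3_session_1', 'test_3_session_2']
```

Keeping the train/test prefix and the subject number in each key lets later steps recover which subject and session a block of trials came from without merging them.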