split_subjects
- class dpeeg.transforms.split_subjects(eegdata: BaseDataset, test_subjects: list[int] | None = None, ret_eegdata: bool = True, verbose=None)
Split the dataset by subjects.
Splitting the dataset at the subject level differs from the
SplitTrainTest transformation. The former splits the data of all subjects in the entire dataset, assigning whole subjects to the training or test set (similar to a cross-subject setting), while the latter splits the data within each subject. When ret_eegdata=True, the data of the different subjects are converted with ToEEGData.
- Parameters:
eegdata (EEG Dataset) – Input EEG dataset.
test_subjects (list of int, optional) – The subjects assigned to the test set; the remaining subjects form the training set. If None, the data of all subjects in the dataset are merged.
ret_eegdata (bool) – If True, convert the merged inter-subject data into the EEGData type. If False, return a labeled MultiSessEEGData. For specific usage, refer to the examples.
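The subject-level split described above can be sketched in plain NumPy. This is a minimal illustration, not dpeeg's implementation: the dict-of-sessions layout and the helper names merge_sessions and split_by_subjects_sketch are hypothetical, and the returned (edata, label) arrays merely stand in for EEGData objects. The shapes mirror the examples: three subjects with 16 trials per session, the third subject having two sessions.

```python
from __future__ import annotations

import numpy as np

# Hypothetical layout (NOT dpeeg's data structures): subject id -> list of
# sessions, each session an (edata, label) pair of NumPy arrays.
Session = tuple[np.ndarray, np.ndarray]


def merge_sessions(sessions: list[Session]) -> Session:
    """Concatenate several sessions along the first (trial) axis."""
    edata = np.concatenate([s[0] for s in sessions], axis=0)
    label = np.concatenate([s[1] for s in sessions], axis=0)
    return edata, label


def split_by_subjects_sketch(
    dataset: dict[int, list[Session]],
    test_subjects: list[int] | None = None,
) -> Session | dict[str, Session]:
    """Cross-subject split: each subject goes wholly to train or to test."""
    if test_subjects is None:
        # No test subjects given: pool every session of every subject.
        return merge_sessions([s for sess in dataset.values() for s in sess])
    train = [s for sub, sess in dataset.items() if sub not in test_subjects for s in sess]
    test = [s for sub, sess in dataset.items() if sub in test_subjects for s in sess]
    return {"train": merge_sessions(train), "test": merge_sessions(test)}


# 16 trials x 3 channels x 10 samples per session, as in the examples.
rng = np.random.default_rng(0)
session = (rng.standard_normal((16, 3, 10)), rng.integers(0, 3, 16))
dataset = {1: [session], 2: [session], 3: [session, session]}

merged = split_by_subjects_sketch(dataset)
print(merged[0].shape)  # (64, 3, 10)

split = split_by_subjects_sketch(dataset, [1, 2])
print(split["train"][0].shape, split["test"][0].shape)  # (32, 3, 10) (32, 3, 10)
```

Because whole subjects are moved, the trial counts of the two sets depend on how many sessions each subject contributes, which is why the split with subjects 1 and 2 in the test set happens to be balanced here.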
Attention
Since this transformation changes the structure of the entire dataset, it cannot be used with Sequential. It is typically applied at the beginning or end of dataset preprocessing.
Examples
Since the input is a dataset type, first define a dataset with 3 subjects (the third one has two sessions). Here, the data of all subjects are merged:
>>> import numpy as np
>>> import dpeeg
>>> from dpeeg import transforms
>>> from dpeeg.datasets import EEGDataset
>>>
>>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10),
...                         label=np.random.randint(0, 3, 16))
>>> multi_sess_eegdata = dpeeg.MultiSessEEGData([eegdata, eegdata])
>>> eegdataset = EEGDataset([eegdata, eegdata, multi_sess_eegdata])
>>> transforms.split_subjects(eegdataset, verbose=False)
[edata=(64, 3, 10), label=(64,)]
Set the test_subjects parameter to split the dataset (here, subjects 1 and 2 form the test set):
>>> transforms.split_subjects(eegdataset, [1, 2], verbose=False)
Train: [edata=(32, 3, 10), label=(32,)]
Test : [edata=(32, 3, 10), label=(32,)]
Set ret_eegdata=False to return the labeled MultiSessEEGData:
>>> transforms.split_subjects(eegdataset, [3], False, verbose=False)
{'train_1': [edata=(16, 3, 10), label=(16,)], 'train_2': [edata=(16, 3, 10), label=(16,)], 'test_3_session_1': [edata=(16, 3, 10), label=(16,)], 'test_3_session_2': [edata=(16, 3, 10), label=(16,)]}
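The keys returned with ret_eegdata=False appear to follow a simple naming rule. The sketch below reproduces the keys of the last example. The rule is inferred from that output only and is an assumption: a subject stored as a single EEGData gets one key, <split>_<subject>, while a multi-session subject gets one key per session, <split>_<subject>_session_<k>. The helper label_keys_sketch is hypothetical and not part of dpeeg.

```python
from __future__ import annotations


def label_keys_sketch(
    subjects: dict[int, int | None], test_subjects: list[int]
) -> list[str]:
    """Build keys like those returned with ret_eegdata=False (inferred rule).

    subjects maps subject id -> number of sessions, or None when the subject
    is stored as a single EEGData object rather than MultiSessEEGData.
    """
    keys = []
    for sub, n_sessions in subjects.items():
        prefix = "test" if sub in test_subjects else "train"
        if n_sessions is None:
            # Single EEGData: one key for the whole subject.
            keys.append(f"{prefix}_{sub}")
        else:
            # Multi-session data (assumed rule): one key per session, from 1.
            keys.extend(f"{prefix}_{sub}_session_{k}" for k in range(1, n_sessions + 1))
    return keys


print(label_keys_sketch({1: None, 2: None, 3: 2}, test_subjects=[3]))
# ['train_1', 'train_2', 'test_3_session_1', 'test_3_session_2']
```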