SplitTrainTest

class dpeeg.transforms.SplitTrainTest(test_size: float = 0.25, cross: bool = False, train_sessions: list[str] | None = None, test_sessions: list[str] | None = None, keys: list[str] | None = None, seed: int = 42)

Split the data into training and testing sets.

Each input type is handled differently. Inputs that are already SplitEEGData are passed through without further processing. For MultiSessEEGData, the data can either be split by session or have its sessions merged and then split. For EEGData, the samples are split proportionally.
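For EEGData input, a stratified proportional split can be sketched in plain NumPy. This is not dpeeg's implementation: the dict standing in for EEGData and the `stratified_split` helper are hypothetical, and rounding the test count up (as scikit-learn's `train_test_split` does) is an assumption that happens to match the 12/4 split in the first example. One shared index set is applied to every array, which is why all split values must have the same number of samples.

```python
import math

import numpy as np


def stratified_split(data, test_size=0.25, seed=42):
    """Hypothetical sketch: split every array in ``data`` with one shared,
    label-stratified index set.

    ``data`` is a plain dict standing in for EEGData, e.g.
    ``{"edata": ..., "label": ...}``; all arrays share the first (sample) axis.
    """
    labels = np.asarray(data["label"])
    n = len(labels)
    if any(len(v) != n for v in data.values()):
        raise ValueError("all values must have the same number of samples")

    n_test = math.ceil(test_size * n)  # assumption: round the test count up
    classes, counts = np.unique(labels, return_counts=True)

    # Largest-remainder allocation of n_test across classes, proportional to class size.
    quotas = n_test * counts / n
    per_class = np.floor(quotas).astype(int)
    leftover = n_test - per_class.sum()
    order = np.argsort(-(quotas - per_class), kind="stable")
    per_class[order[:leftover]] += 1

    # Draw each class's test samples at random, reproducibly via the seed.
    rng = np.random.default_rng(seed)
    test_mask = np.zeros(n, dtype=bool)
    for cls, k in zip(classes, per_class):
        members = np.flatnonzero(labels == cls)
        test_mask[rng.choice(members, size=k, replace=False)] = True

    train = {key: np.asarray(v)[~test_mask] for key, v in data.items()}
    test = {key: np.asarray(v)[test_mask] for key, v in data.items()}
    return train, test


data = {"edata": np.zeros((16, 3, 10)), "label": np.arange(16) % 3}
train, test = stratified_split(data, test_size=0.2)
print(train["edata"].shape, test["edata"].shape)  # (12, 3, 10) (4, 3, 10)
```

With 16 samples and `test_size=0.2`, rounding 3.2 up gives 4 test samples, and the class counts (6, 5, 5) receive 2, 1 and 1 of them.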

Parameters:
  • test_size (float) – The proportion of samples assigned to the test set. By default, the split is stratified, so class proportions are preserved in both sets.

  • cross (bool) – If True, whole sessions are assigned to the training and test sets according to train_sessions and test_sessions. If False, the sessions selected by train_sessions are merged and then split into training and test sets according to test_size. These parameters only take effect when the input is MultiSessEEGData; they are ignored for other input types.

  • train_sessions (list of str, optional) – Sessions used for the training set. If cross=False, train_sessions lists the sessions to merge and split (if None, all sessions are used). If cross=True, train_sessions must be specified and lists the sessions that form the training set.

  • test_sessions (list of str, optional) – Sessions used for the test set. Ignored if cross=False. If cross=True, test_sessions lists the sessions that form the test set (if None, all sessions not in train_sessions are used).

  • keys (list of str, optional) – The keys of the EEGData values to split. If None, all values are split, so every value must have the same number of samples. Ignored when cross=True.

  • seed (int) – Random seed used for splitting.
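The rules for cross, train_sessions and test_sessions can be summarized as a small decision function. This is a hypothetical sketch of the documented behavior, not dpeeg's code; `resolve_sessions` and its return convention are invented for illustration, and session names follow the `session_1`, `session_2`, ... pattern used in the examples.

```python
def resolve_sessions(sessions, cross=False, train_sessions=None, test_sessions=None):
    """Hypothetical helper mirroring the documented session rules.

    Returns (train, test). With cross=False, ``train`` lists the sessions to
    merge and then split by test_size, and ``test`` is None.
    """
    if not cross:
        # test_sessions is ignored; None means "use every session".
        pool = list(sessions) if train_sessions is None else list(train_sessions)
        return pool, None
    if train_sessions is None:
        raise ValueError("train_sessions must be specified when cross=True")
    if test_sessions is None:
        # Default test set: the complement of train_sessions.
        test_sessions = [s for s in sessions if s not in train_sessions]
    return list(train_sessions), list(test_sessions)


names = [f"session_{i + 1}" for i in range(4)]
print(resolve_sessions(names, cross=True, train_sessions=["session_1", "session_2"]))
# (['session_1', 'session_2'], ['session_3', 'session_4'])
```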

Returns:

split_eegdata – The split EEGData or dataset.

Return type:

eegdata or dataset (split)

Examples

Split an EEGData object:

>>> import numpy as np
>>> import dpeeg
>>> from dpeeg import transforms
>>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10),
...                         label=np.random.randint(0, 3, 16))
>>> transforms.SplitTrainTest(0.2)(eegdata, verbose=False)
Train: [edata=(12, 3, 10), label=(12,)]
Test : [edata=(4, 3, 10), label=(4,)]

Split multi-session data by session (cross=True):

>>> multi_sess_eegdata = dpeeg.MultiSessEEGData(
...     [eegdata.copy() for _ in range(4)])
>>> transforms.SplitTrainTest(
...     cross=True,
...     train_sessions=[f"session_{i + 1}" for i in range(2)],
...     test_sessions=["session_4"]
... )(multi_sess_eegdata, verbose=False)
Train: [edata=(32, 3, 10), label=(32,)]
Test : [edata=(16, 3, 10), label=(16,)]

Or merge all sessions and split the pooled data (cross=False):

>>> transforms.SplitTrainTest(0.5)(multi_sess_eegdata, verbose=False)
Train: [edata=(32, 3, 10), label=(32,)]
Test : [edata=(32, 3, 10), label=(32,)]
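The shapes in the two multi-session examples follow from stacking sessions along the sample axis. The sketch below reproduces that arithmetic with plain NumPy; the dict-of-dicts layout and the `concat_sessions` helper are hypothetical stand-ins for MultiSessEEGData, not dpeeg internals.

```python
import numpy as np


def concat_sessions(sessions, names):
    """Hypothetical helper: stack the selected sessions along the sample axis (axis 0)."""
    keys = sessions[names[0]].keys()
    return {k: np.concatenate([sessions[n][k] for n in names], axis=0) for k in keys}


# Four identical sessions of 16 trials each, as in the examples above.
rng = np.random.default_rng(0)
one = {"edata": rng.standard_normal((16, 3, 10)), "label": rng.integers(0, 3, 16)}
sessions = {f"session_{i + 1}": {k: v.copy() for k, v in one.items()} for i in range(4)}

# cross=True: whole sessions go to one side (session_3 is simply unused here).
cross_train = concat_sessions(sessions, ["session_1", "session_2"])
cross_test = concat_sessions(sessions, ["session_4"])
print(cross_train["edata"].shape, cross_test["edata"].shape)  # (32, 3, 10) (16, 3, 10)

# cross=False: pool every session first; test_size=0.5 then leaves 32 per side.
merged = concat_sessions(sessions, list(sessions))
print(merged["edata"].shape)  # (64, 3, 10)
```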