SplitTrainTest
- class dpeeg.transforms.SplitTrainTest(test_size: float = 0.25, cross: bool = False, train_sessions: list[str] | None = None, test_sessions: list[str] | None = None, keys: list[str] | None = None, seed: int = 42)
Split the data into training and testing sets.
The behavior depends on the type of the input data. Inputs that are already SplitEEGData are passed through without processing. MultiSessEEGData inputs can be split by session, or merged and then split. EEGData inputs are split proportionally.
- Parameters:
test_size (float) – The proportion of the data assigned to the test set. The split is stratified by default.
cross (bool) – True indicates that data from multiple sessions will be split into training and test sets by session, working in conjunction with train_sessions and test_sessions. False indicates that data from multiple sessions will be merged and then split into training and test sets, working with train_sessions. These parameters take effect only when the input data type is MultiSessEEGData; they are ignored for other types.
train_sessions (list of str, optional) – Sessions to be used as the training set. If cross=False, train_sessions names the sessions to be merged and then split (if None, all sessions are used). If cross=True, train_sessions must be specified and names the sessions to be used as the training set.
test_sessions (list of str, optional) – Sessions to be used as the test set. If cross=False, this parameter is ignored. If cross=True, test_sessions names the sessions to be used as the test set (if None, the complement of train_sessions is used).
keys (list of str, optional) – The keys of the eegdata values to be split. If None, all data will be split, in which case the caller must ensure that the data samples are uniform. Ignored when cross=True.
seed (int) – Random seed used when splitting.
- Returns:
split_eegdata – Split eegdata or dataset.
- Return type:
Examples
Split the eegdata:
>>> import numpy as np
>>> import dpeeg
>>> from dpeeg import transforms
>>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10),
...                         label=np.random.randint(0, 3, 16))
>>> transforms.SplitTrainTest(0.2)(eegdata, verbose=False)
Train: [edata=(12, 3, 10), label=(12,)]
Test : [edata=(4, 3, 10), label=(4,)]
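To make the idea of a stratified proportional split concrete, here is a minimal standard-library sketch that works on label indices only. It is not dpeeg's implementation: the helper name stratified_split and the choice to round the test count up within each class are assumptions made for illustration.

```python
import math
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.25, seed=42):
    """Return (train_idx, test_idx), keeping class proportions roughly equal.

    Hypothetical helper for illustration; not part of dpeeg.
    """
    rng = random.Random(seed)

    # Group trial indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, test_idx = [], []
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        # Assumption: the per-class test count is rounded up.
        n_test = math.ceil(len(idxs) * test_size)
        test_idx.extend(idxs[:n_test])
        train_idx.extend(idxs[n_test:])
    return sorted(train_idx), sorted(test_idx)


labels = [0] * 8 + [1] * 4 + [2] * 4  # 16 trials, 3 classes
train_idx, test_idx = stratified_split(labels, test_size=0.25)
print(len(train_idx), len(test_idx))
```

Each class contributes a quarter of its trials to the test set, so the class balance of the original labels carries over to both subsets. Fixing the seed makes the shuffle, and therefore the split, reproducible.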
Split eegdata across multiple sessions:
>>> multi_sess_eegdata = dpeeg.MultiSessEEGData(
...     [eegdata.copy() for _ in range(4)])
>>> transforms.SplitTrainTest(
...     cross=True,
...     train_sessions=[f"session_{i + 1}" for i in range(2)],
...     test_sessions=["session_4"]
... )(multi_sess_eegdata, verbose=False)
Train: [edata=(32, 3, 10), label=(32,)]
Test : [edata=(16, 3, 10), label=(16,)]
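The session selection described for cross=True can be sketched in plain Python. This is only an illustration of the documented rules (train_sessions is required, and test_sessions defaults to the complement); the function name resolve_cross_sessions and the specific exceptions raised are hypothetical and may differ from what dpeeg actually does.

```python
def resolve_cross_sessions(all_sessions, train_sessions, test_sessions=None):
    """Choose train/test session names for a cross-session split.

    Hypothetical helper for illustration; not part of dpeeg.
    """
    if not train_sessions:
        raise ValueError("train_sessions must be specified when cross=True")

    requested = set(train_sessions) | set(test_sessions or [])
    unknown = requested - set(all_sessions)
    if unknown:
        raise KeyError(f"unknown sessions: {sorted(unknown)}")

    if test_sessions is None:
        # Default: every session that is not used for training.
        test_sessions = [s for s in all_sessions if s not in train_sessions]
    return list(train_sessions), list(test_sessions)


# Session names follow the pattern used in the example above.
sessions = [f"session_{i + 1}" for i in range(4)]
trials_per_session = 16

train, test = resolve_cross_sessions(
    sessions, ["session_1", "session_2"], ["session_4"])
print(len(train) * trials_per_session, len(test) * trials_per_session)  # 32 16

train, test = resolve_cross_sessions(sessions, ["session_1", "session_2"])
print(test)  # ['session_3', 'session_4']
```

With two training sessions and one test session of 16 trials each, the counts match the 32 and 16 samples reported in the output above. Note that session_3 is simply left out when test_sessions is given explicitly.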
Or merge the data from multiple sessions and then split it:
>>> transforms.SplitTrainTest(0.5)(multi_sess_eegdata, verbose=False)
Train: [edata=(32, 3, 10), label=(32,)]
Test : [edata=(32, 3, 10), label=(32,)]
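For cross=False, the documented behavior is to merge the selected sessions (all of them when train_sessions is None) before splitting. The sketch below shows only the merging step and the resulting sample counts, using lists of labels in place of EEG arrays. The helper merge_sessions and the dictionary layout are assumptions for illustration, not dpeeg's internal representation.

```python
def merge_sessions(session_labels, train_sessions=None):
    """Concatenate labels from the selected sessions (all sessions if None).

    Hypothetical helper for illustration; not part of dpeeg.
    """
    if train_sessions is None:
        names = list(session_labels)
    else:
        names = list(train_sessions)

    merged = []
    for name in names:
        merged.extend(session_labels[name])
    return merged


# Four sessions with 16 trials each, mirroring the example above.
session_labels = {
    f"session_{i + 1}": [0, 1, 2, 0] * 4 for i in range(4)
}

merged_all = merge_sessions(session_labels)
merged_two = merge_sessions(session_labels, ["session_1", "session_3"])

test_size = 0.5
n_test = round(len(merged_all) * test_size)
print(len(merged_all) - n_test, n_test)  # 32 32
print(len(merged_two))                   # 32
```

Merging all four sessions yields 64 trials, so a test_size of 0.5 gives the 32/32 split shown in the output above. Restricting train_sessions shrinks the pool that is split, rather than designating a training set as it does when cross=True.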