# Authors: SheepTAO <sheeptao@outlook.com>
# License: MIT
# Copyright the dpeeg contributors.
from abc import abstractmethod
from mne.utils import verbose, logger
from ..datasets.base import _BaseData, EEGData, SplitEEGData
from .base import Transforms
from ..utils import DPEEG_SEED, get_init_args
from .functions import segmentation_and_reconstruction_time, slide_win, crop
class Augmentation(Transforms):
    """Base class for data augmentation transforms.

    Subclasses implement :meth:`_apply_aug`, which modifies a single
    :class:`EEGData` in place. By default, subclasses augment the `edata` and
    `label` entries; make sure these entries are valid before applying the
    transform.

    Parameters
    ----------
    repr : str, optional
        Representation string forwarded to :class:`Transforms`.
    only_train : bool
        If True, data augmentation is performed only on the training set.
    strict : bool
        If False, allow the input data to be unsplit. In this case, data
        augmentation will be applied to all data.
    """

    def __init__(
        self, repr: str | None = None, only_train: bool = True, strict: bool = True
    ) -> None:
        super().__init__(repr)
        self.strict = strict
        self.only_train = only_train

    def _apply(self, eegdata: _BaseData) -> _BaseData:
        """Run :meth:`_apply_aug` on every eligible sub-dataset of `eegdata`.

        Unsplit input is accepted only when `strict` is False; every
        sub-dataset is then augmented with the mode string ``"None"``. For
        split input, non-training splits are skipped when `only_train` is
        True. The (in-place modified) input object is returned.
        """
        is_split = isinstance(eegdata, SplitEEGData)
        if self.strict and not is_split:
            raise TypeError(
                "The input must have been split, or `strict` is set to False."
            )

        for egd, mode in eegdata._datas():
            if not is_split:
                # Unsplit data has no meaningful split name.
                self._apply_aug(egd, "None")
            elif mode == "train" or not self.only_train:
                self._apply_aug(egd, mode)
        return eegdata

    @abstractmethod
    def _apply_aug(self, egd: EEGData, mode: str):
        """Augment ``egd`` in place; ``mode`` is the split name (``"None"`` if unsplit)."""
[docs]
class SegRecTime(Augmentation):
    """Segmentation and reconstruction (S&R) augmentation in the time domain.

    Trials are segmented along the temporal dimension according to their
    class, and the segments are then randomly spliced back together to form
    new artificial trials [1]_. By default, the `edata` and `label` entries
    are augmented; make sure they are available.

    Parameters
    ----------
    samples : int
        Number of consecutive samples per segment, e.g. 125 samples of 250 Hz
        data yield 0.5 s segments.
    multiply : float
        Expansion factor relative to the original data; 1 means the amount of
        data is doubled.
    only_train : bool
        If True, data augmentation is performed only on the training set.
    strict : bool
        If False, allow the input data to be unsplit. In this case, data
        augmentation will be applied to all data.
    shuffle : bool
        Whether or not to shuffle the augmented data.
    seed : int
        Random seed forwarded to ``segmentation_and_reconstruction_time``.

    Returns
    -------
    data : eegdata or dataset
        Transformed eegdata.

    References
    ----------
    .. [1] F. Lotte, “Signal processing approaches to minimize or suppress
       calibration time in oscillatory activity-based brain–computer
       interfaces,” Proceedings of the IEEE, vol. 103, no. 6,
       pp. 871–890, 2015.

    Notes
    -----
    Only the `edata` and `label` entries of the eegdata are augmented; all
    other values are left unchanged. If other values are derived from
    `edata`, pay attention to the order of the transformations.

    Examples
    --------
    >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10),
    ...                         label=np.random.randint(0, 3, 16))
    >>> split_eegdata = dpeeg.SplitEEGData(eegdata.copy(), eegdata.copy())
    >>> transforms.SegRecTime(2, 3)(split_eegdata, verbose=False)
    Train: [edata=(64, 3, 10), label=(64,)]
    Test : [edata=(16, 3, 10), label=(16,)]
    """

    def __init__(
        self,
        samples: int,
        multiply: float = 1.0,
        only_train: bool = True,
        strict: bool = True,
        shuffle: bool = True,
        seed: int = DPEEG_SEED,
    ) -> None:
        # ``locals()`` must be captured before any other local name is bound,
        # so that only the constructor arguments end up in the repr.
        super().__init__(
            get_init_args(self, locals(), format="rp"),
            only_train=only_train,
            strict=strict,
        )
        self.samples, self.multiply = samples, multiply
        self.shuffle, self.seed = shuffle, seed

    def _apply_aug(self, egd: EEGData, mode: str):
        """Replace ``edata``/``label`` of ``egd`` with their S&R-augmented versions."""
        options = dict(
            samples=self.samples,
            multiply=self.multiply,
            shuffle=self.shuffle,
            seed=self.seed,
        )
        new_edata, new_label = segmentation_and_reconstruction_time(
            data=egd["edata"], label=egd["label"], **options
        )
        egd["edata"] = new_edata
        egd["label"] = new_label
[docs]
class SlideWinAug(Augmentation):
    """Sliding window data augmentation.

    Sliding windows are applied to the training set, while every other split
    (e.g. the test set) is cropped to the single window ``[tmin, tmax)``.
    By default, augmentation is performed on `edata` and `label`. Ensure the
    availability of the data.

    Parameters
    ----------
    win : int
        The size of the sliding window, in samples. Must be positive.
    overlap : int
        The amount of overlap between adjacent sliding windows, in samples.
        Must be smaller than ``win``, otherwise the window step would not be
        positive.
    tmin : int
        Start of the window cropped from non-training splits, in samples.
    tmax : int, optional
        End (exclusive) of the window cropped from non-training splits, in
        samples. Defaults to ``tmin + win``, i.e. one window length from
        ``tmin``.

    Returns
    -------
    data : eegdata or dataset
        Transformed eegdata.

    Raises
    ------
    ValueError
        If ``win`` is not positive or ``overlap`` is not smaller than ``win``.

    Notes
    -----
    Data augmentation is only applied to the `edata` and `label` within the
    eegdata, with other values remaining unchanged. If there are derived values
    based on the `edata`, attention should be paid to the order of
    transformations.

    The input must be split (``strict=True``) and all splits are processed
    (``only_train=False``): the training split is windowed and every other
    split is cropped. If an explicit ``tmax`` is given with
    ``tmax - tmin != win``, the cropped splits will not have the same length
    as the training windows.

    Examples
    --------
    >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10),
    ...                         label=np.random.randint(0, 3, 16))
    >>> split_eegdata = dpeeg.SplitEEGData(eegdata.copy(), eegdata.copy())
    >>> transforms.SlideWinAug(2)(split_eegdata, verbose=False)
    Train: [edata=(80, 3, 2), label=(80,)]
    Test : [edata=(16, 3, 2), label=(16,)]
    """

    def __init__(
        self,
        win: int,
        overlap: int = 0,
        tmin: int = 0,
        tmax: int | None = None,
    ) -> None:
        # Capture ``locals()`` first so the recorded init args contain only
        # the constructor parameters.
        super().__init__(get_init_args(self, locals(), format="rp"), False, True)
        # Reject parameters that would give a zero or negative window step
        # (``win - overlap``) before they reach ``slide_win``.
        if win <= 0:
            raise ValueError(f"`win` must be a positive integer, got {win}.")
        if overlap >= win:
            raise ValueError(
                f"`overlap` ({overlap}) must be smaller than `win` ({win})."
            )
        self.win = win
        self.overlap = overlap
        self.tmin = tmin
        self.tmax = tmin + win if tmax is None else tmax

    def _apply_aug(self, egd: EEGData, mode: str):
        """Window the training split; crop ``[tmin, tmax)`` from any other split."""
        if mode == "train":
            egd["edata"], egd["label"] = slide_win(
                data=egd["edata"],
                win=self.win,
                overlap=self.overlap,
                label=egd["label"],
            )
        else:
            # Labels are untouched: cropping keeps one window per trial.
            egd["edata"] = crop(
                data=egd["edata"],
                tmin=self.tmin,
                tmax=self.tmax,
                include_tmax=False,
            )