# Authors: SheepTAO <sheeptao@outlook.com>
# License: MIT
# Copyright the dpeeg contributors.
from abc import abstractmethod
from mne.utils import verbose, logger
from ..datasets.base import _BaseData, EEGData, SplitEEGData
from .base import Transforms
from ..utils import DPEEG_SEED, get_init_args
from ..tools.docs import fill_doc
from .functions import (
segmentation_and_reconstruction_time,
slide_win,
crop,
gaussian_noise_time,
)
class Augmentation(Transforms):
    """Base class for data augmentation transforms.

    Subclasses implement :meth:`_apply_aug`, which modifies a single
    ``EEGData`` in place. By default the ``edata`` and ``label`` entries are
    the ones augmented; callers should verify that the augmented data is
    still valid.

    Parameters
    ----------
    repr : str, optional
        Forwarded unchanged to ``Transforms.__init__``. Subclasses pass the
        string produced by ``get_init_args``.
    only_train : bool
        When the input is a ``SplitEEGData``, augment only the ``"train"``
        subset and leave every other subset untouched.
    strict : bool
        If True, any input that is not a ``SplitEEGData`` is rejected with a
        ``TypeError``. If False, every ``EEGData`` contained in such an input
        is augmented with ``mode="None"`` and ``only_train`` is ignored.
    """

    def __init__(
        self, repr: str | None = None, only_train: bool = True, strict: bool = True
    ) -> None:
        super().__init__(repr)
        self.only_train = only_train
        self.strict = strict

    def _apply(self, eegdata: _BaseData) -> _BaseData:
        """Run :meth:`_apply_aug` on the relevant parts of ``eegdata``.

        The augmentation happens in place; the very object that was passed in
        is returned.

        Raises
        ------
        TypeError
            If ``eegdata`` is not a ``SplitEEGData`` while ``strict`` is True.
        """
        if isinstance(eegdata, SplitEEGData):
            for sub_data, split_mode in eegdata._datas():
                # Non-training subsets are skipped only when requested.
                if self.only_train and split_mode != "train":
                    continue
                self._apply_aug(sub_data, split_mode)
            return eegdata

        if self.strict:
            raise TypeError(
                "The input must have been split, or `strict` is set to False."
            )

        # Unsplit input in non-strict mode: there is no split name, so the
        # literal string "None" is passed as the mode.
        for sub_data, _ in eegdata._datas():
            self._apply_aug(sub_data, "None")
        return eegdata

    @abstractmethod
    def _apply_aug(self, egd: EEGData, mode: str):
        """Augment ``egd`` in place.

        ``mode`` is the split name yielded by ``_datas()`` (for example
        ``"train"``), or the string ``"None"`` when the input was not split.
        The return value is ignored by :meth:`_apply`.
        """
[docs]
@fill_doc
class SegRecTime(Augmentation):
    """Segmentation and reconstruction (S&R) in the time domain.

    The original EEG signals are grouped by class and cut into segments
    along the temporal axis. New trials are then built by randomly splicing
    those segments back together [1]_. By default ``edata`` and ``label``
    are augmented; make sure both are available.

    Parameters
    ----------
    samples : int
        Number of consecutive samples per segment. For example, 125 splits
        250 Hz data into 0.5 s segments.
    multiply : float
        Amount of new data to generate relative to the input: ``multiply``
        times as many new trials as the input contains, so 1 doubles the data.
    %(aug_only_train)s
    %(aug_strict)s
    seed : int
        Seed used to create the numpy random number generator.

    Returns
    -------
    data : eegdata or dataset
        Transformed eegdata.

    Notes
    -----
    Only ``edata`` and ``label`` are replaced; every other value stored in
    the eegdata is left as is. If some values are derived from ``edata``,
    pay attention to the order in which transforms are applied.

    References
    ----------
    .. [1] F. Lotte, “Signal processing approaches to minimize or suppress
       calibration time in oscillatory activity-based brain-computer
       interfaces,” Proceedings of the IEEE, vol. 103, no. 6,
       pp. 871-890, 2015.

    Examples
    --------
    >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10),
    ...                         label=np.random.randint(0, 3, 16))
    >>> split_eegdata = dpeeg.SplitEEGData(eegdata.copy(), eegdata.copy())
    >>> transforms.SegRecTime(2, 3)(split_eegdata, verbose=False)
    Train: [edata=(64, 3, 10), label=(64,)]
    Test : [edata=(16, 3, 10), label=(16,)]
    """

    def __init__(
        self,
        samples: int,
        multiply: float = 1.0,
        only_train: bool = True,
        strict: bool = True,
        seed: int = DPEEG_SEED,
    ) -> None:
        # ``locals()`` is handed to ``get_init_args``; keeping this call first
        # means only the constructor arguments are visible to it.
        super().__init__(
            repr=get_init_args(self, locals(), format="rp"),
            only_train=only_train,
            strict=strict,
        )
        self.samples = samples
        self.multiply = multiply
        self.seed = seed

    def _apply_aug(self, egd: EEGData, mode: str):
        # ``mode`` is unused: every subset that reaches this point gets the
        # same S&R treatment.
        new_edata, new_label = segmentation_and_reconstruction_time(
            data=egd["edata"],
            label=egd["label"],
            samples=self.samples,
            multiply=self.multiply,
            seed=self.seed,
        )
        egd["edata"] = new_edata
        egd["label"] = new_label
[docs]
class SlideWinAug(Augmentation):
    """Sliding window data augmentation.

    The training subset is expanded by cutting every trial into sliding
    windows. Every other subset (e.g. the test set) is cropped to the single
    time window ``[tmin, tmax)``. By default ``edata`` and ``label`` are
    augmented; make sure both are available.

    Parameters
    ----------
    win : int
        Size of the sliding window, in samples. Must be positive.
    overlap : int
        Number of samples shared by adjacent sliding windows. Must be smaller
        than ``win`` so that consecutive windows advance by a positive stride.
    tmin : int
        Start of the crop applied to non-training subsets, in samples.
    tmax : int, optional
        End (exclusive) of the crop applied to non-training subsets, in
        samples. Defaults to ``tmin + win`` so that the cropped trials have
        the same length as the training windows.

    Returns
    -------
    data : eegdata or dataset
        Transformed eegdata.

    Raises
    ------
    ValueError
        If ``win`` is not positive or ``overlap`` is not smaller than ``win``.

    Notes
    -----
    This transform always requires a split input and always processes every
    subset: ``only_train`` is fixed to False and ``strict`` to True.

    Only ``edata`` and ``label`` are replaced; every other value stored in
    the eegdata is left as is. If some values are derived from ``edata``,
    pay attention to the order in which transforms are applied.

    Examples
    --------
    >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10),
    ...                         label=np.random.randint(0, 3, 16))
    >>> split_eegdata = dpeeg.SplitEEGData(eegdata.copy(), eegdata.copy())
    >>> transforms.SlideWinAug(2)(split_eegdata, verbose=False)
    Train: [edata=(80, 3, 2), label=(80,)]
    Test : [edata=(16, 3, 2), label=(16,)]
    """

    def __init__(
        self,
        win: int,
        overlap: int = 0,
        tmin: int = 0,
        tmax: int | None = None,
    ) -> None:
        # Adjacent windows share ``overlap`` samples, so they start
        # ``win - overlap`` samples apart; that stride must be positive.
        # Checked here so a bad configuration fails at construction time.
        if win <= 0:
            raise ValueError(f"`win` must be a positive integer, got {win}.")
        if overlap >= win:
            raise ValueError(
                f"`overlap` ({overlap}) must be smaller than `win` ({win})."
            )
        super().__init__(
            get_init_args(self, locals(), format="rp"),
            only_train=False,
            strict=True,
        )
        self.win = win
        self.overlap = overlap
        self.tmin = tmin
        self.tmax = tmin + win if tmax is None else tmax
        # Training windows are ``win`` samples long while evaluation crops are
        # ``tmax - tmin`` samples long; a mismatch means the two subsets end
        # up with different time lengths.
        if self.tmax - self.tmin != self.win:
            logger.warning(
                "SlideWinAug: crop length tmax - tmin = %d differs from the "
                "sliding window length win = %d.",
                self.tmax - self.tmin,
                self.win,
            )

    def _apply_aug(self, egd: EEGData, mode: str):
        if mode == "train":
            # Each training trial becomes several windows; ``slide_win``
            # returns labels aligned with the new windows.
            egd["edata"], egd["label"] = slide_win(
                data=egd["edata"],
                win=self.win,
                overlap=self.overlap,
                label=egd["label"],
            )
        else:
            # Other subsets keep one window per trial, so labels are
            # unchanged.
            egd["edata"] = crop(
                data=egd["edata"],
                tmin=self.tmin,
                tmax=self.tmax,
                include_tmax=False,
            )
[docs]
@fill_doc
class GaussTime(Augmentation):
    """Add zero-mean Gaussian white noise to every channel.

    Gaussian white noise with a mean of 0 is added directly to the raw EEG
    signal to produce new data [1]_. By default ``edata`` and ``label`` are
    augmented; make sure both are available.

    Parameters
    ----------
    std : float
        Standard deviation of the additive noise.
    %(aug_only_train)s
    %(aug_strict)s
    seed : int
        Seed used to create the numpy random number generator.

    Returns
    -------
    data : eegdata or dataset
        Transformed eegdata.

    Notes
    -----
    Only ``edata`` and ``label`` are replaced; every other value stored in
    the eegdata is left as is. If some values are derived from ``edata``,
    pay attention to the order in which transforms are applied.

    References
    ----------
    .. [1] Wang, F., Zhong, S. H., Peng, J., Jiang, J., & Liu, Y. (2018). Data
       augmentation for eeg-based emotion recognition with deep convolutional
       neural networks. In International Conference on Multimedia Modeling
       (pp. 82-93).
    """

    def __init__(
        self,
        std: float,
        only_train: bool = True,
        strict: bool = True,
        seed: int = DPEEG_SEED,
    ) -> None:
        # ``locals()`` is handed to ``get_init_args``; keeping this call first
        # means only the constructor arguments are visible to it.
        super().__init__(
            repr=get_init_args(self, locals(), format="rp"),
            only_train=only_train,
            strict=strict,
        )
        self.std = std
        self.seed = seed

    def _apply_aug(self, egd: EEGData, mode: str):
        # ``mode`` is unused: every subset that reaches this point receives
        # the same zero-mean noise augmentation.
        noisy_edata, noisy_label = gaussian_noise_time(
            data=egd["edata"],
            label=egd["label"],
            mean=0,
            std=self.std,
            seed=self.seed,
        )
        egd["edata"] = noisy_edata
        egd["label"] = noisy_label