Source code for dpeeg.transforms.trans

# Authors: SheepTAO <sheeptao@outlook.com>

# License: MIT
# Copyright the dpeeg contributors.

from collections.abc import Callable

import numpy as np
from numpy import ndarray
from mne.utils import verbose, logger

from .base import Transforms, TransformsEGD
from ..datasets.base import EEGData
from ..utils import DPEEG_SEED, get_init_args
from .functions import (
    crop,
    slide_win,
    cheby2_filter,
    label_mapping,
    pick_label,
)


[docs] class Identity(Transforms): """Placeholder identity operator.""" def __init__(self) -> None: super().__init__("Identity()") def _apply(self, eegdata): return eegdata
[docs] class Crop(TransformsEGD): """Crop a time interval. Crop the eeg signal in terms of time. Default is `edata`. Parameters ---------- tmin : int Start time of selection in sampling points. tmax : int End time of selection in sampling points. None means use the full time. include_tmax : bool If `False`, exclude tmax. Returns ------- data : eegdata or dataset Transformed eegdata. Examples -------- >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10), ... label=np.random.randint(0, 3, 16)) >>> transforms.Crop(4, 9)(eegdata, verbose=False) [edata=(16, 3, 5), label=(16,)] """ def __init__( self, tmin: int = 0, tmax: int | None = None, include_tmax: bool = False, ) -> None: super().__init__(get_init_args(self, locals(), format="rp")) self.tmin = tmin self.tmax = tmax self.include_tmax = include_tmax def _apply_egd(self, egd: EEGData, key: str | None): egd["edata"] = crop( data=egd["edata"], tmin=self.tmin, tmax=self.tmax, include_tmax=self.include_tmax, )
[docs] class SlideWin(TransformsEGD): """Apply a sliding window to the dataset. This transform is only splits the time series (dim = -1) through the sliding window operation on the original dataset. If the time axis is not divisible by the sliding window, the last remaining time data will be discarded. Applied to `edata` and `label` by default. Parameters ---------- win : int The size of the sliding window. overlap : int The amount of overlap between adjacent sliding windows. Returns ------- data : eegdata or dataset Transformed eegdata. Examples -------- >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10), ... label=np.random.randint(0, 3, 16)) >>> transforms.SlideWin(3, 1)(eegdata, verbose=False) [edata=(64, 3, 3), label=(64,)] """ def __init__(self, win: int, overlap: int = 0) -> None: super().__init__(get_init_args(self, locals(), format="rp")) self.win = win self.overlap = overlap def _apply_egd(self, egd: EEGData, key: str | None): egd["edata"], egd["label"] = slide_win( data=egd["edata"], win=self.win, overlap=self.overlap, label=egd["label"], )
[docs] class Unsqueeze(TransformsEGD): """Insert a dimension on the data. This transform is usually used to insert a empty dimension on signals. Parameters ---------- key : str The key of the eegdata value to be transformed. dim : int Position in the expanded dim where the new dim is placed. Returns ------- data : eegdata or dataset Transformed eegdata. Examples -------- >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10), ... label=np.random.randint(0, 3, 16)) >>> transforms.Unsqueeze(dim=2)(eegdata, verbose=False) [edata=(16, 3, 1, 10), label=(16,)] """ def __init__(self, key: str = "edata", dim: int = 1) -> None: super().__init__(get_init_args(self, locals(), format="rp")) self.key = key self.dim = dim def _apply_egd(self, egd: EEGData, key: str | None): egd[self.key] = np.expand_dims(egd[self.key], self.dim)
[docs] class Squeeze(TransformsEGD): """Remove a dimension on the data. Parameters ---------- key : str The key of the eegdata value to be transformed. dim : int Selects a subset of the entries of length one in the shape. If a dim is selected with shape entry greater than one, an error is raised. Returns ------- data : eegdata or dataset Transformed eegdata. Examples -------- >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 1, 3, 10), ... label=np.random.randint(0, 3, 16)) >>> transforms.Squeeze()(eegdata, verbose=False) [edata=(16, 3, 10), label=(16,)] """ def __init__(self, key: str = "edata", dim: int = 1) -> None: super().__init__(get_init_args(self, locals(), format="rp")) self.key = key self.dim = dim def _apply_egd(self, egd: EEGData, key: str | None): egd[self.key] = np.squeeze(egd[self.key], self.dim)
[docs] class Transpose(TransformsEGD): """Data dims transposed. By default, the EEG data (``edata``) of eegdata are transposed. Parameters ---------- dims : tuple or list of int, optinal A tuple or list contains a permutation of [0,1,...,N-1] where N is the number of dims of the key values. The ``i``^th dim of the value will correspond to the axis numbered ``dims[i]`` of the input. If not specified, reverse the dims order by default. key : str The key of the eegdata value to be transformed. Examples -------- >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 1, 3, 10), ... label=np.random.randint(0, 3, 16)) >>> transforms.Transpose()(eegdata, verbose=False) [edata=(10, 3, 16), label=(16,)] """ def __init__(self, dims: list[int] | None = None, key: str = "edata"): super().__init__(get_init_args(self, locals(), "rp")) self.dims = dims self.key = key def _apply_egd(self, egd: EEGData, key: str | None): egd[self.key] = np.transpose(egd[self.key], self.dims)
[docs] class FilterBank(TransformsEGD): """Filter Bank. EEG data will be filtered according to different filtering frequencies and finally concatenated together. eg.`(Batch, ...) -> (Batch, F, ...)` if the number of filter banks exceeds 1, `(Batch, ...) -> (Batch, ...)` if the filter has only one. By default, filtering is performed on `edata`, please ensure the availability of the data. Related references include [1]_ and [2]_. Parameters ---------- freq : float EEG data sampling frequency. filter_bank : multiple 2 float of list The low-pass and high-pass cutoff frequencies for each filter set. transition_bandwidth : float The bandwidth (in hertz) of the transition region of the frequency response from the passband to the stopband. gstop : float The minimum attenuation in the stopband (dB). gpass : float The maximum loss in the passband (dB). Returns ------- data : eegdata or dataset Transformed eegdata. References ---------- .. [1] R. Mane, E. Chew, K. Chua, K. K. Ang, N. Robinson, A. P. Vinod, S.-W. Lee, and C. Guan, “FBCNet: A multi-view convolutional neural network for brain-computer interface,” arXiv preprint arXiv:2104.01233, 2021. .. [2] X. Ma, W. Chen, Z. Pei, J. Liu, B. Huang, and J. Chen, “A temporal dependency learning CNN with attention mechanism for MI-EEG decoding,” IEEE Transactions on Neural Systems and Rehabilitation Engineering, 2023. Examples -------- >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10), ... label=np.random.randint(0, 3, 16)) >>> transforms.FilterBank(250)(eegdata, verbose=False) [edata=(16, 9, 3, 10), label=(16,)] """ def __init__( self, freq: float, filter_bank: list = [ [4, 8], [8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 32], [32, 36], [36, 40], ], transition_bandwidth: float = 2.0, gstop: float = 30, gpass: float = 3, ) -> None: super().__init__(get_init_args(self, locals(), format="rp")) self.freq = freq self.filter_bank = self._check_filter_bank(filter_bank) self.transition_bandwidth = transition_bandwidth self.gpass = gpass self.gstop = gstop self.bank_len = len(self.filter_bank) def _check_filter_bank(self, fb): if not isinstance(fb, list): raise TypeError(f"filter_bank must be a list, not {type(fb)}.") for f in fb: if len(f) != 2: raise ValueError( "The filter should be of two variables low pass and high " "pass cutoff frequency." ) return fb def _apply_egd(self, egd: EEGData, key: str | None): trials = egd.trials() data = np.empty((trials, self.bank_len, *egd["edata"].shape[1:])) for i, cutoff in enumerate(self.filter_bank): filter_data = cheby2_filter( data=egd["edata"], freq=self.freq, l_freq=cutoff[0], h_freq=cutoff[1], transition_bandwidth=self.transition_bandwidth, gpass=self.gpass, gstop=self.gstop, ) data[:, i] = filter_data if self.bank_len == 1: data = np.squeeze(data, 1) egd["edata"] = data
[docs] class ApplyFunc(TransformsEGD): """Apply a custom function to data. Parameters ---------- func : Callable Transformation data callback function. The first parameter of the function must be `EEGData`. keys : list of str, optional The key of the eegdata to be transformed, if required. Applies to all eegdata by default. **kwargs : dict, optional Additional arguments for callback function, if required. Returns ------- data : eegdata or dataset Transformed eegdata. Examples -------- If you want to pass a function with parameters, such as you want to use `np.expand_dims()` with `axis` parameter, you can do as follows: >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10), ... label=np.random.randint(0, 3, 16)) >>> def expand_dim(data, dim=1): ... data["edata"] = np.expand_dims(data["edata"], dim) >>> transforms.ApplyFunc(expand_dim, dim=0)(eegdata, verbose=False) [edata=(1, 16, 3, 10), label=(16,)] >>> split_eegdata = dpeeg.SplitEEGData(eegdata, eegdata.copy()) >>> transforms.ApplyFunc(expand_dim, ["train"])(split_eegdata, verbose=False) Train: [edata=(1, 1, 16, 3, 10), label=(16,)] Test : [edata=(1, 16, 3, 10), label=(16,)] """ def __init__( self, func: Callable, keys: list[str] | None = None, **kwargs, ) -> None: super().__init__(get_init_args(self, locals(), format="rp")) self.func = func self.keys = keys self.kwargs = kwargs def _apply_egd(self, egd: EEGData, key: str | None): if (self.keys is None) or (key in self.keys): self.func(egd, **self.kwargs)
[docs] class LabelMapping(TransformsEGD): """Rearrange the original label according to mapping rules. Parameters ---------- mapping : ndarray (2, label_num), optional Label mapping relationship. The first row is the original label, and the second row is the mapped label. If ``None``, the label will be reordered in ascending order starting from zero. order : bool Force the new labels to start incrementing from zero. Returns ------- data : eegdata or dataset Transformed eegdata. Examples -------- >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10), ... label=np.random.randint(0, 3, 16)) >>> eegdata['label'] array([3, 2, 2, 2, 3, 2, 4, 3, 4, 3, 3, 2, 4, 4, 2, 3]) Merge labels as needed: >>> transforms.LabelMapping( ... np.array([[2, 3, 4], [0, 0, 1]]) ... )(eegdata, verbose=False) >>> eegdata["label"] array([0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0]) """ def __init__(self, mapping: ndarray | None = None, order: bool = True) -> None: super().__init__(get_init_args(self, locals(), format="rp")) self.mapping = mapping self.order = order def _apply_egd(self, egd: EEGData, key: str | None): egd["label"] = label_mapping( label=egd["label"], mapping=self.mapping, order=self.order, )
[docs] class PickLabel(TransformsEGD): """Pick a subset of data. Pick the required labels and data from the dataset and re-label them. Parameters ---------- pick : ndarray Label to include. keys : list of str, optional The key of the eegdata value to be transformed, if required. Applies to all eegdata by default. order : bool If `True`, relabel the selected labels. shuffle : bool Whether or not to shuffle the data after picking. seed : int Controls the shuffling applied to the data after picking. Returns ------- data : eegdata or dataset Transformed eegdata. Examples -------- >>> eegdata = dpeeg.EEGData(edata=np.random.randn(16, 3, 10), ... label=np.random.randint(0, 3, 16)) array([1, 2, 0, 2, 1, 2, 0, 1, 0, 0, 0, 1, 2, 1, 0, 0]) >>> transforms.PickLabel(np.array([1, 2]))(eegdata, verbose=False) array([1, 0, 1, 0, 1, 0, 0, 0, 1]) If some values do not need to be transformed, they can be excluded by the `keys` parameter: >>> eegdata = dpeeg.EEGData( ... edata=np.random.randn(16, 3, 10), ... label=np.random.randint(0, 3, 16), ... adj=np.random.randn(16, 3, 3), ... pcc=np.random.randn(16, 3, 3), ... ) >>> transforms.PickLabel( ... np.array([0, 1]), keys=["edata", "adj"] ... )(eegdata, verbose=False) [edata=(12, 3, 10), label=(12,), adj=(12, 3, 3), pcc=(16, 3, 3)] """ def __init__( self, pick: ndarray, keys: list[str] | None = None, order: bool = True, shuffle: bool = True, seed: int = DPEEG_SEED, ) -> None: super().__init__(get_init_args(self, locals(), format="rp")) self.pick = pick self.keys = keys self.order = order self.shuffle = shuffle self.seed = seed def _apply_egd(self, egd: EEGData, key: str | None): label = egd["label"] keys, values = [], [] for key, value in egd.items(): if (key != "label") and ((self.keys is None) or (key in self.keys)): keys.append(key) values.append(value) data, label = pick_label( *values, label=label, pick=self.pick, order=self.order, shuffle=self.shuffle, seed=self.seed, ) egd["label"] = label for i, key in enumerate(keys): egd[key] = data[i]