MSVTNet
- class dpeeg.models.MSVTNet.MSVTNet(nCh: int, nTime: int, nCls: int, F: list[int] = [9, 9, 9, 9], C1: list[int] = [15, 31, 63, 125], C2: int = 15, D: int = 2, P1: int = 8, P2: int = 7, Pc: float = 0.3, nhead: int = 8, ff_ratio: int = 1, Pt: float = 0.5, layers: int = 2, b_preds: bool = True)
MSVTNet: Multi-Scale Vision Transformer Neural Network for EEG-Based Motor Imagery Decoding (MSVTNet).
MSVTNet [1] combines the strength of convolutional neural networks (CNNs) in extracting local features with the global modeling capability of Transformers. Specifically, a multi-branch CNN with temporal kernels of different scales captures local spatiotemporal features, and a Transformer then jointly models the global and local spatiotemporal correlations among these features. In addition, an auxiliary branch loss (ABL) provides intermediate supervision, which helps integrate the CNN and Transformer components effectively.
- Parameters:
nCh (int) – Number of electrode channels.
nTime (int) – Number of time samples per trial. For example, a 4-second trial sampled at 250 Hz has 1000 samples.
nCls (int) – Number of classification categories.
F (list of int) – Number of temporal filters in each branch.
C1 (list of int) – Temporal convolution kernel size of each branch.
C2 (int) – Depthwise convolution kernel size.
D (int) – Depth multiplier of the depthwise convolution.
P1 (int) – Kernel size of the first pooling layer.
P2 (int) – Kernel size of the second pooling layer.
Pc (float) – Dropout rate of the multi-branch convolutional module.
nhead (int) – Number of attention heads in multi-head attention.
ff_ratio (int) – Expansion factor of the fully connected feed-forward layer.
Pt (float) – Dropout rate of the Transformer encoder.
layers (int) – Number of Transformer encoder layers.
b_preds (bool) – If True, also return the prediction of each branch.
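As a rough check on how these parameters interact, the sketch below estimates the size of the token sequence fed to the Transformer. It is pure Python and rests on assumptions not stated in this documentation: an EEGNet-style branch (temporal convolution with 'same' padding, a depthwise spatial convolution that collapses the nCh electrodes, then average pooling with kernels P1 and P2, remainder dropped), branch outputs concatenated along the feature axis, and no class token counted. The helper name `msvtnet_token_shape` is hypothetical. The divisibility check reflects PyTorch's `nn.MultiheadAttention`, which requires the embedding size to be a multiple of the number of heads.

```python
def msvtnet_token_shape(n_time, F=(9, 9, 9, 9), C1=(15, 31, 63, 125),
                        D=2, P1=8, P2=7, nhead=8):
    """Estimate (seq_len, d_model) of the Transformer input (hypothetical helper).

    Assumes each branch preserves the time length through its convolutions
    ('same' padding), collapses the electrodes with a depthwise spatial
    convolution, and shortens time only via average pooling with kernels
    P1 and P2 (stride equal to kernel, remainder dropped).
    """
    if len(F) != len(C1):
        raise ValueError("F and C1 must have one entry per branch")
    # Time axis after the two pooling layers (identical for every branch).
    seq_len = n_time // P1 // P2
    # Each branch yields F_i * D feature maps; branches are concatenated.
    d_model = sum(f * D for f in F)
    # Multi-head attention splits d_model evenly across the heads.
    if d_model % nhead != 0:
        raise ValueError(f"d_model={d_model} is not divisible by nhead={nhead}")
    return seq_len, d_model


# 4-second trials sampled at 250 Hz -> nTime = 1000 samples.
n_time = 4 * 250
print(msvtnet_token_shape(n_time))  # (17, 72)
```

Under these assumptions the kernel sizes C1 change what each branch sees but not the output shape, whereas F, D, P1, and P2 do change it.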
References
- forward(x)
Forward pass: decodes a batch of input EEG data into class predictions.
- Parameters:
x (Tensor) – Input EEG data, shape (batch_size, bands, nCh, nTime).
- Returns:
cls_prob (Tensor) – Predicted class probability, shape (batch_size, nCls).
branch_cls_prob (list of Tensor) – If b_preds=True, the class prediction probabilities of each branch.
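The sketch below shows one way a caller might consume these outputs. It assumes that a model built with b_preds=True returns the tuple (cls_prob, branch_cls_prob) and otherwise returns cls_prob alone; nested Python lists stand in for tensors, and `predicted_labels` is a hypothetical helper. With real tensors the per-trial decision is simply `cls_prob.argmax(dim=1)`.

```python
def argmax(row):
    """Index of the largest score in a list of class scores."""
    return max(range(len(row)), key=row.__getitem__)


def predicted_labels(output, b_preds=True):
    """Convert MSVTNet-style outputs into class labels (hypothetical helper).

    Assumes a (cls_prob, branch_cls_prob) tuple when b_preds=True and a bare
    cls_prob otherwise. Returns (labels, branch_labels); branch_labels is
    None when no branch predictions are present.
    """
    if b_preds:
        cls_prob, branch_cls_prob = output
        branch_labels = [[argmax(row) for row in branch] for branch in branch_cls_prob]
    else:
        cls_prob, branch_labels = output, None
    labels = [argmax(row) for row in cls_prob]
    return labels, branch_labels


cls_prob = [[0.1, 0.9], [0.8, 0.2]]           # 2 trials, 2 classes
branch_cls_prob = [[[0.6, 0.4], [0.3, 0.7]]]  # 1 branch
print(predicted_labels((cls_prob, branch_cls_prob)))  # ([1, 0], [[0, 1]])
```

Branch predictions are mainly of interest during training, where they feed the auxiliary branch loss; at inference only cls_prob is needed.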
- class dpeeg.models.MSVTNet.JointCrossEntoryLoss(lamd: float = 0.6)
Auxiliary branch loss.
The parameters of MSVTNet are learned under the joint supervision of the model prediction loss and the auxiliary branch loss:
\[ \begin{aligned}\mathcal{L} &= \lambda\mathcal{L}_c + (1-\lambda)\sum_{b=1}^{B}\mathcal{L}_b,\\ \mathcal{L}_{c/b} &= \mathrm{CrossEntropyLoss}(\hat{y}),\end{aligned} \]
where \(\lambda\in(0, 1]\) is the ratio factor for intermediate supervision of the model, \(B\) is the number of branches, and \(\mathcal{L}_c\) and \(\mathcal{L}_b\) are the cross-entropy losses of the final prediction and of the \(b\)-th branch prediction, respectively.
- Parameters:
lamd (float) – Ratio factor \(\lambda\) of the ABL, in (0, 1].
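A minimal pure-Python sketch of the ABL formula, useful for checking the weighting by hand. It assumes each prediction is given as unnormalized logits (for log-probabilities the same code reduces to the negative log-likelihood) and averages the cross-entropy over the batch. It is not the dpeeg implementation, and `joint_cross_entropy` is a hypothetical name.

```python
import math


def cross_entropy(logits, target):
    """Mean cross-entropy of a batch of logits against integer class targets."""
    total = 0.0
    for row, y in zip(logits, target):
        # log-sum-exp with the max subtracted for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[y]
    return total / len(target)


def joint_cross_entropy(main_logits, branch_logits, target, lamd=0.6):
    """L = lamd * L_c + (1 - lamd) * sum_b L_b (hypothetical helper)."""
    if not 0.0 < lamd <= 1.0:
        raise ValueError("lamd must lie in (0, 1]")
    loss_c = cross_entropy(main_logits, target)                     # final prediction
    loss_b = sum(cross_entropy(b, target) for b in branch_logits)   # summed over branches
    return lamd * loss_c + (1.0 - lamd) * loss_b


uniform = [[0.0, 0.0, 0.0, 0.0]] * 2  # two trials, four classes, no preference
loss = joint_cross_entropy(uniform, [uniform] * 4, target=[0, 3], lamd=0.6)
print(round(loss, 4))  # 3.0498, i.e. 2.2 * ln 4
```

With lamd=1 the branch terms vanish and only the final prediction is supervised. Because the branch losses are summed rather than averaged, their combined weight grows with the number of branches B.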