MSVTNet

class dpeeg.models.MSVTNet.MSVTNet(nCh: int, nTime: int, nCls: int, F: list[int] = [9, 9, 9, 9], C1: list[int] = [15, 31, 63, 125], C2: int = 15, D: int = 2, P1: int = 8, P2: int = 7, Pc: float = 0.3, nhead: int = 8, ff_ratio: int = 1, Pt: float = 0.5, layers: int = 2, b_preds: bool = True)

MSVTNet: Multi-Scale Vision Transformer Neural Network for EEG-Based Motor Imagery Decoding (MSVTNet).

MSVTNet [1] combines the strength of convolutional neural networks (CNNs) in extracting local features with the ability of Transformers to capture global features. To obtain discriminative classification features, a multi-branch CNN with kernels at different scales extracts local spatiotemporal features, and a Transformer then jointly models the global and local spatiotemporal correlations among them. In addition, an auxiliary branch loss (ABL) provides intermediate supervision on the branches, which helps integrate the CNN and Transformer components effectively.
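The plain-Python sketch below traces shapes through the multi-branch CNN front end described above. It is not the dpeeg implementation: the function name `trace_token_shape` is hypothetical, and it assumes that every convolution uses 'same' padding (so the kernel sizes in `C1` leave the time length unchanged), that a depthwise spatial convolution collapses the electrode axis and multiplies each branch's filter count by `D`, that two average-pooling stages shorten the time axis by floor division with `P1` and then `P2`, and that branch outputs are concatenated along the feature axis before entering the Transformer.

```python
def trace_token_shape(nTime, F, C1, D, P1, P2):
    """Hypothetical shape trace for the multi-branch CNN front end.

    Returns the per-branch (features, time) shapes and the
    (sequence length, embedding width) of the token sequence that
    would be fed to the Transformer, under the assumptions stated above.
    """
    if len(F) != len(C1):
        raise ValueError("F and C1 must have one entry per branch")
    # Assumed 'same' padding: temporal kernels in C1 do not change the length,
    # so only the two pooling stages shorten the time axis.
    seq_len = nTime // P1 // P2
    # Assumed depthwise spatial conv: each branch ends with F[i] * D features.
    branch_shapes = [(f * D, seq_len) for f in F]
    # Assumed concatenation of all branches along the feature axis.
    d_model = sum(f * D for f in F)
    return branch_shapes, (seq_len, d_model)


branches, tokens = trace_token_shape(
    nTime=1000, F=[9, 9, 9, 9], C1=[15, 31, 63, 125], D=2, P1=8, P2=7
)
print(branches)  # [(18, 17), (18, 17), (18, 17), (18, 17)]
print(tokens)    # (17, 72)
```

Under these assumptions the default hyperparameters turn a 1000-sample trial into 17 tokens of width 72, and 72 divides evenly by the default `nhead=8`.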

Parameters:
  • nCh (int) – Number of electrode channels.

  • nTime (int) – Number of time samples per trial. For example, a 4-second trial sampled at 250 Hz has nTime = 1000.

  • nCls (int) – Number of classification categories.

  • F (list of int) – Number of temporal filters per branch.

  • C1 (list of int) – Kernel size of the temporal convolution in each branch.

  • C2 (int) – Depthwise convolution kernel size.

  • D (int) – Depth multiplier of the depthwise convolution.

  • P1 (int) – Kernel size of the first pooling layer.

  • P2 (int) – Kernel size of the second pooling layer.

  • Pc (float) – Dropout rate of the multi-branch convolutional module.

  • nhead (int) – Number of heads in multi-head attention.

  • ff_ratio (int) – Expansion factor of the fully connected feed-forward layer in the Transformer encoder.

  • Pt (float) – Dropout rate of the Transformer encoder.

  • layers (int) – Number of Transformer encoder layers.

  • b_preds (bool) – If True, forward also returns the prediction of each branch.
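Because the branch lists and the attention settings interact, a quick pre-flight check can catch inconsistent configurations before the model is built. The sketch below is hypothetical (`msvtnet_config_problems` is not part of dpeeg) and encodes assumptions rather than documented constraints: one kernel size in `C1` per entry of `F`, a Transformer width of `sum(F) * D` that must be divisible by `nhead` (as PyTorch's multi-head attention requires of its embedding size), at least one token left after pooling by `P1` and `P2`, and dropout rates in [0, 1).

```python
def msvtnet_config_problems(nTime, F, C1, D=2, P1=8, P2=7,
                            Pc=0.3, nhead=8, Pt=0.5):
    """Return a list of problems found in a hypothetical MSVTNet config.

    These checks are assumptions about the architecture, not the
    library's own validation.
    """
    problems = []
    if len(F) != len(C1):
        problems.append("F and C1 must have the same length (one per branch)")
    # Assumed Transformer width: all branch features concatenated.
    d_model = sum(F) * D
    if d_model % nhead != 0:
        problems.append(f"sum(F) * D = {d_model} is not divisible by nhead = {nhead}")
    # Assumed floor pooling with kernels P1 and then P2 along time.
    if nTime // P1 // P2 < 1:
        problems.append("nTime is too short for the pooling sizes P1 and P2")
    for name, rate in (("Pc", Pc), ("Pt", Pt)):
        if not 0.0 <= rate < 1.0:
            problems.append(f"{name} = {rate} is not a valid dropout rate")
    return problems


# 4 s of EEG at 250 Hz gives nTime = 1000 samples.
nTime = int(4 * 250)
print(msvtnet_config_problems(nTime, [9, 9, 9, 9], [15, 31, 63, 125]))  # []
print(msvtnet_config_problems(nTime, [8, 8, 8], [15, 31, 63], nhead=5))
# ['sum(F) * D = 48 is not divisible by nhead = 5']
```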

References

forward(x)

Run the forward pass: decode the input EEG data into class predictions.

Parameters:

x (Tensor) – Input EEG data, shape (batch_size, bands, nCh, nTime).

Returns:

  • cls_prob (Tensor) – Predicted class probability, shape (batch_size, nCls).

  • branch_cls_prob (list of Tensor) – Class prediction probabilities of each branch; returned only if b_preds=True.

class dpeeg.models.MSVTNet.JointCrossEntoryLoss(lamd: float = 0.6)

Auxiliary branch loss.

The parameters of MSVTNet are learned under the joint supervision of the model prediction loss and the auxiliary branch loss:

\[ \begin{aligned}\mathcal{L} &= \lambda\mathcal{L}_c + (1-\lambda)\sum_{b=1}^{B}\mathcal{L}_b\\ \mathcal{L}_{c/b} &= \mathrm{CrossEntropyLoss}(\hat{y}_{c/b}, y)\end{aligned} \]

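where \(\lambda\in(0, 1]\) is the ratio factor that weights the model prediction loss \(\mathcal{L}_c\) against the intermediate supervision from the branch losses \(\mathcal{L}_b\), \(B\) is the number of branches, and \(\hat{y}\) denotes the predicted class probabilities (of the whole model for \(\mathcal{L}_c\), of branch \(b\) for \(\mathcal{L}_b\)), which are compared against the true label \(y\).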
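As a worked illustration of the formula, the plain-Python sketch below computes the joint loss from per-trial class probabilities. It assumes mean reduction over the batch and inputs that are already probabilities, as the documentation describes them; the helper names `cross_entropy` and `joint_cross_entropy` are hypothetical, and the dpeeg class itself operates on tensors rather than Python lists.

```python
import math


def cross_entropy(probs, labels):
    """Hypothetical helper: mean negative log-likelihood of the true labels."""
    return -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)


def joint_cross_entropy(cls_prob, branch_probs, labels, lamd=0.6):
    """Hypothetical sketch of L = lamd * L_c + (1 - lamd) * sum_b L_b.

    cls_prob: per-trial class probabilities of the final prediction.
    branch_probs: one such list per branch (B entries).
    """
    if not 0.0 < lamd <= 1.0:
        raise ValueError("lamd must lie in (0, 1]")
    loss_c = cross_entropy(cls_prob, labels)  # model prediction loss L_c
    loss_b = sum(cross_entropy(bp, labels) for bp in branch_probs)  # sum of L_b
    return lamd * loss_c + (1.0 - lamd) * loss_b


# Two trials, two classes, two branches, all predicting 50/50.
labels = [0, 1]
cls_prob = [[0.5, 0.5], [0.5, 0.5]]
branch_probs = [cls_prob, cls_prob]
loss = joint_cross_entropy(cls_prob, branch_probs, labels)
print(round(loss, 4))  # 0.9704, i.e. 0.6 * ln 2 + 0.4 * 2 * ln 2 = 1.4 * ln 2
```

With `lamd=1.0` the branch term vanishes and the expression reduces to ordinary cross-entropy on the final prediction; smaller values shift more weight onto the intermediate supervision of the branches.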

Parameters:

lamd (float) – Ratio factor \(\lambda\) of ABL, i.e. the weight of the model prediction loss.

forward(out, label)

Compute the joint loss from the model and branch prediction probabilities.

Parameters:
  • out (tuple of Tensor) – Model prediction probabilities followed by the branch prediction probabilities, e.g. the output of MSVTNet with b_preds=True.

  • label (Tensor) – Ground-truth class labels.

Returns:

loss – Joint loss that supports backpropagation.

Return type:

Tensor