# Source code for dpeeg.models.MSVTNet

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

from ..tools.docs import fill_doc


# Public names exported by a star-import of this module.
__all__ = ["MSVTNet", "JointCrossEntoryLoss"]


class TSConv(nn.Sequential):
    """Sequential convolution stack used as one feature-extraction branch.

    The stack is built from three stages, registered in this order:

    1. a ``(1, C1)`` convolution from 1 to ``F`` feature maps + batch norm;
    2. a grouped ``(nCh, 1)`` convolution expanding to ``F * D`` maps
       + batch norm + ELU + ``(1, P1)`` average pooling + dropout;
    3. a per-map ``(1, C2)`` convolution + batch norm + ELU
       + ``(1, P2)`` average pooling + dropout.

    Parameters
    ----------
    nCh : int
        Height of the kernel of the second convolution.
    F : int
        Number of output maps of the first convolution.
    C1 : int
        Kernel width of the first convolution.
    C2 : int
        Kernel width of the third convolution.
    D : int
        Multiplier applied to ``F`` by the second convolution.
    P1 : int
        Width of the first average pooling.
    P2 : int
        Width of the second average pooling.
    Pc : float
        Dropout probability used after both pooling layers.
    """

    def __init__(self, nCh, F, C1, C2, D, P1, P2, Pc) -> None:
        width = F * D

        # Stage 1: convolution along the last axis on the single input map.
        stem = [
            nn.Conv2d(1, F, (1, C1), padding="same", bias=False),
            nn.BatchNorm2d(F),
        ]
        # Stage 2: grouped convolution that spans ``nCh`` rows.
        collapse = [
            nn.Conv2d(F, width, (nCh, 1), groups=F, bias=False),
            nn.BatchNorm2d(width),
            nn.ELU(),
            nn.AvgPool2d((1, P1)),
            nn.Dropout(Pc),
        ]
        # Stage 3: one filter per feature map along the last axis.
        refine = [
            nn.Conv2d(
                width, width, (1, C2), padding="same", groups=width, bias=False
            ),
            nn.BatchNorm2d(width),
            nn.ELU(),
            nn.AvgPool2d((1, P2)),
            nn.Dropout(Pc),
        ]
        super().__init__(*stem, *collapse, *refine)


class PositionalEncoding(nn.Module):
    """Learnable additive positional embedding.

    Holds a zero-initialised parameter of shape ``(1, seq_len, d_model)``
    that is added (with broadcasting over the first axis) to the input.

    Parameters
    ----------
    seq_len : int
        Length of the second axis of the embedding.
    d_model : int
        Length of the last axis of the embedding.
    """

    def __init__(self, seq_len, d_model) -> None:
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model
        self.pe = nn.Parameter(torch.zeros(1, seq_len, d_model))

    def forward(self, x):
        """Return ``x`` plus the positional embedding.

        Parameters
        ----------
        x : Tensor
            Input broadcastable against ``(1, seq_len, d_model)``.

        Returns
        -------
        Tensor
            A new tensor; ``x`` itself is left untouched.
        """
        # Out-of-place add: an in-place ``x += self.pe`` would silently modify
        # the caller's tensor and raises a RuntimeError when ``x`` is a leaf
        # tensor that requires grad.
        return x + self.pe


class Transformer(nn.Module):
    """Pre-norm transformer encoder that summarises a sequence via a class token.

    A learnable class token is prepended to the input sequence, a learnable
    positional embedding of length ``seq_len + 1`` is added, dropout is
    applied, and the result is passed through a stack of encoder layers
    followed by a final layer norm. Only the encoded class token is returned.

    Parameters
    ----------
    seq_len : int
        Length of the input sequence (without the class token).
    d_model : int
        Feature size of every token.
    nhead : int
        Number of attention heads in each encoder layer.
    ff_ratio : int
        The feed-forward width of each layer is ``d_model * ff_ratio``.
    Pt : float
        Dropout probability, used both on the embedded tokens and inside the
        encoder layers.
    num_layers : int
        Number of stacked encoder layers.
    """

    def __init__(
        self,
        seq_len,
        d_model,
        nhead,
        ff_ratio,
        Pt=0.5,
        num_layers=4,
    ) -> None:
        super().__init__()
        # One extra position is reserved for the class token.
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embedding = PositionalEncoding(seq_len + 1, d_model)
        self.dropout = nn.Dropout(Pt)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * ff_ratio,
            dropout=Pt,
            batch_first=True,
            norm_first=True,
        )
        self.trans = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            norm=nn.LayerNorm(d_model),
        )

    def forward(self, x):
        """Encode ``x`` and return the class-token representation.

        Parameters
        ----------
        x : Tensor
            Token sequence, shape `(batch_size, seq_len, d_model)`.

        Returns
        -------
        Tensor
            Encoded class token, shape `(batch_size, d_model)`.
        """
        batch_size = x.size(0)
        cls_tokens = self.cls_embedding.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, x), dim=1)
        tokens = self.dropout(self.pos_embedding(tokens))
        encoded = self.trans(tokens)
        # The class token sits at position 0 of the sequence axis.
        return encoded[:, 0]


class ClsHead(nn.Sequential):
    """Classification head: flatten -> linear -> log-softmax.

    Parameters
    ----------
    linear_in : int
        Number of features per sample after flattening.
    nCls : int
        Number of output classes.
    """

    def __init__(self, linear_in, nCls):
        flatten = nn.Flatten()
        project = nn.Linear(linear_in, nCls)
        # Log-probabilities over the class axis.
        log_prob = nn.LogSoftmax(dim=1)
        super().__init__(flatten, project, log_prob)


@fill_doc
class MSVTNet(nn.Module):
    """MSVTNet: Multi-Scale Vision Transformer Neural Network for EEG-Based
    Motor Imagery Decoding (MSVTNet).

    MSVTNet [1]_ effectively integrates the strengths of convolutional neural
    networks (CNNs) in extracting local features with the global feature
    extraction capabilities of Transformers. Specifically, to optimize
    classification features, a multi-branch CNN with different scales is
    designed to capture local spatiotemporal features, along with a
    Transformer to jointly model global and local spatiotemporal correlations
    features. Additionally, auxiliary branch loss (ABL) is leveraged for
    intermediate supervision, ensuring effective integration of CNNs and
    Transformers.

    Parameters
    ----------
    %(nCh)s
    %(nTime)s
    %(nCls)s
    F : list of int
        Number of temporal filters per branch.
    C1 : list of int
        The convolution kernel size of each branch temporal filter. Must have
        the same length as ``F``.
    C2 : int
        Depthwise convolution kernel size.
    D : int
        Depth of depthwise convolution.
    P1 : int
        The first pooling kernel size.
    P2 : int
        The second pooling kernel size.
    Pc : float
        Dropout rate of multi-branch convolutional module.
    nhead : int
        Number of multi-head attention.
    ff_ratio : int
        The expansion factor of the fully connected feed-forward layer.
    Pt : float
        Dropout rate of transformer encoder.
    layers : int
        Number of transformer encoder layers.
    b_preds : bool
        If ``True``, return the prediction for each branch.

    References
    ----------
    .. [1] K. Liu et al., "MSVTNet: Multi-Scale Vision Transformer Neural
        Network for EEG-Based Motor Imagery Decoding," in IEEE Journal of
        Biomedical and Health Informatics, doi: 10.1109/JBHI.2024.3450753.
    """

    # NOTE: the list defaults of ``F`` and ``C1`` are only read, never mutated.
    def __init__(
        self,
        nCh: int,
        nTime: int,
        nCls: int,
        F: list[int] = [9, 9, 9, 9],
        C1: list[int] = [15, 31, 63, 125],
        C2: int = 15,
        D: int = 2,
        P1: int = 8,
        P2: int = 7,
        Pc: float = 0.3,
        nhead: int = 8,
        ff_ratio: int = 1,
        Pt: float = 0.5,
        layers: int = 2,
        b_preds: bool = True,
    ) -> None:
        super().__init__()
        self.nCh = nCh
        self.nTime = nTime
        self.b_preds = b_preds

        assert len(F) == len(C1), "The length of F and C1 should be equal."

        # One convolutional branch per (filter count, kernel size) pair; each
        # branch output is rearranged to (batch, time, features).
        self.mstsconv = nn.ModuleList(
            [
                nn.Sequential(
                    TSConv(nCh, n_filters, kernel, C2, D, P1, P2, Pc),
                    Rearrange("b d 1 t -> b t d"),
                )
                for n_filters, kernel in zip(F, C1)
            ]
        )

        # Layer sizes below are discovered by pushing a dummy input through
        # the modules built so far.
        branch_feats = self._forward_flatten(cat=False)
        self.branch_head = nn.ModuleList(
            [ClsHead(feat.shape[1], nCls) for feat in branch_feats]
        )

        seq_len, d_model = self._forward_mstsconv().shape[1:3]  # type: ignore
        self.transformer = Transformer(seq_len, d_model, nhead, ff_ratio, Pt, layers)

        linear_in = self._forward_flatten().shape[1]  # type: ignore
        self.last_head = ClsHead(linear_in, nCls)

    @torch.no_grad()  # shape probe only: no autograd graph is needed
    def _forward_mstsconv(self, cat=True):
        """Run a dummy ``(1, 1, nCh, nTime)`` input through every branch.

        Returns the list of per-branch outputs, or their concatenation along
        the feature axis when ``cat`` is ``True``.
        """
        dummy = torch.randn(1, 1, self.nCh, self.nTime)
        feats = [tsconv(dummy) for tsconv in self.mstsconv]
        if cat:
            return torch.cat(feats, dim=2)
        return feats

    @torch.no_grad()  # shape probe only: no autograd graph is needed
    def _forward_flatten(self, cat=True):
        """Return flattened dummy features used to size the linear heads.

        With ``cat=True`` the concatenated branch features are passed through
        the transformer first; otherwise a list with one flattened tensor per
        branch is returned.
        """
        feats = self._forward_mstsconv(cat)
        if cat:
            return self.transformer(feats).flatten(start_dim=1, end_dim=-1)
        return [feat.flatten(start_dim=1, end_dim=-1) for feat in feats]

    def forward(self, x):
        """Forward pass function that processes the input EEG data and
        produces the decoded results.

        Parameters
        ----------
        x : Tensor
            Input EEG data, shape `(batch_size, 1, nCh, nTime)`.

        Returns
        -------
        cls_prob : Tensor
            Predicted class log-probability (output of ``LogSoftmax``), shape
            `(batch_size, nCls)`.
        branch_cls_prob : list of Tensor
            If ``b_preds=True``, return the class log-probability prediction
            for each branch.
        """
        feats = [tsconv(x) for tsconv in self.mstsconv]

        # The branch heads are only evaluated when their output is returned.
        branch_out = None
        if self.b_preds:
            branch_out = [
                head(feat) for head, feat in zip(self.branch_head, feats)
            ]

        fused = torch.cat(feats, dim=2)
        out = self.last_head(self.transformer(fused))

        if self.b_preds:
            return out, branch_out
        return out
class JointCrossEntoryLoss(nn.Module):
    r"""Auxiliary branch loss.

    Combines the loss of the final model prediction with the summed losses of
    the individual branch predictions:

    .. math::

        \mathcal{L}=\lambda\mathcal{L}_c+(1-\lambda)\sum_{b=1}^{B}\mathcal{L}_b

        \mathcal{L}_{c/b}=\mathrm{Cross Entropy Loss}(\hat{y})

    where :math:`\lambda\in(0, 1]` is the ratio factor for intermediate
    supervision of the model.

    Parameters
    ----------
    lamd : float
        Ratio factor of ABL.
    """

    def __init__(self, lamd: float = 0.6) -> None:
        super().__init__()
        self.lamd = lamd

    def forward(self, out, label):
        """Compute the weighted sum of the final and branch losses.

        Each term is evaluated with ``nll_loss``, so the predictions are
        expected to be log-probabilities.

        Parameters
        ----------
        out : tuple of Tensor
            ``out[0]`` is the model prediction and ``out[1]`` is the sequence
            of branch predictions.
        label : Tensor
            True label.

        Returns
        -------
        loss : Tensor
            Loss with gradient.
        """
        final_pred = out[0]
        branch_preds = out[1]

        main_term = F.nll_loss(final_pred, label)
        # One scalar loss per branch, gathered into a 1-D tensor.
        aux_terms = torch.stack([F.nll_loss(pred, label) for pred in branch_preds])

        return self.lamd * main_term + (1 - self.lamd) * aux_terms.sum()