"""EEG Conformer model (``dpeeg.models.EEGConformer``)."""

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor, fill
from einops import rearrange
from einops.layers.torch import Rearrange

from ..tools.docs import fill_doc


__all__ = ["EEGConformer"]


class PatchEmbedding(nn.Module):
    """Convolutional front-end that turns an EEG trial into a token sequence.

    A shallow conv stack (temporal conv, spatial conv across the ``nCh``
    electrodes, batch norm, ELU, average pooling, dropout) is followed by a
    1x1 convolution to ``emb_size`` channels and a rearrange that flattens
    the remaining spatial grid into a token axis.

    Parameters
    ----------
    nCh : int
        Number of EEG channels; used as the height of the spatial kernel.
    emb_size : int
        Number of output channels of the 1x1 projection, i.e. the token
        embedding size.

    Notes
    -----
    Assuming an input of shape ``(batch, 1, nCh, nTime)``, the output is
    ``(batch, n_tokens, emb_size)`` with
    ``n_tokens = (nTime - 24 - 75) // 15 + 1`` (no padding anywhere).
    """

    def __init__(self, nCh, emb_size=40):
        super().__init__()

        n_filters = 40  # width of the shallow conv stack, independent of emb_size
        temporal_conv = nn.Conv2d(1, n_filters, kernel_size=(1, 25), stride=(1, 1))
        spatial_conv = nn.Conv2d(
            n_filters, n_filters, kernel_size=(nCh, 1), stride=(1, 1)
        )
        self.shallownet = nn.Sequential(
            temporal_conv,
            spatial_conv,
            nn.BatchNorm2d(n_filters),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15)),
            nn.Dropout(0.5),
        )

        # 1x1 conv maps the filters to the embedding size, then every (h, w)
        # position becomes one token carrying an ``emb_size`` feature vector.
        self.projection = nn.Sequential(
            nn.Conv2d(n_filters, emb_size, kernel_size=(1, 1), stride=(1, 1)),
            Rearrange("b e (h) (w) -> b (h w) e"),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x`` into a ``(batch, tokens, emb_size)`` token sequence."""
        features = self.shallownet(x)
        return self.projection(features)


class MultiHeadAttention(nn.Module):
    """Multi-head dot-product self-attention over a token sequence.

    Parameters
    ----------
    emb_size : int
        Token embedding size; it is split evenly across the heads.
    num_heads : int
        Number of attention heads. Must be positive and divide ``emb_size``.
    dropout : float
        Dropout rate applied to the attention weights.

    Raises
    ------
    ValueError
        If ``num_heads`` is not positive or does not divide ``emb_size``.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        # Validate eagerly: otherwise a bad configuration only surfaces later
        # as an einops shape error deep inside ``forward``.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}.")
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})."
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        # Creation order (keys, queries, values, projection) is kept so that
        # seeded parameter initialisation is unchanged.
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor) -> Tensor:
        """Apply self-attention to ``x``.

        Parameters
        ----------
        x : Tensor
            Token sequence of shape ``(batch, tokens, emb_size)``.

        Returns
        -------
        Tensor
            Attended sequence with the same shape as ``x``.
        """
        # Split the embedding into ``num_heads`` chunks: (b, n, h*d) -> (b, h, n, d).
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # Pairwise query/key similarity per head: (b, h, n_query, n_key).
        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)

        # NOTE: this scales by sqrt(emb_size), not by the per-head
        # sqrt(emb_size // num_heads) of standard scaled dot-product attention.
        # It is left unchanged so existing model behaviour and trained weights
        # are preserved.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # Weighted sum of values, then merge heads back: (b, h, n, d) -> (b, n, h*d).
        out = torch.einsum("bhal, bhlv -> bhav ", att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    """Wrap a module with a residual (skip) connection: ``fn(x) + x``.

    Parameters
    ----------
    fn : nn.Module
        Module applied to the input; its output must be broadcast-compatible
        with the input for the addition.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``.

        Extra keyword arguments are forwarded to ``fn`` unchanged.
        """
        # Out-of-place addition on purpose. If ``fn`` hands back its input
        # tensor (e.g. ``nn.Identity``), an in-place ``+=`` would silently
        # modify the caller's tensor, and it raises for leaf tensors that
        # require grad.
        return self.fn(x, **kwargs) + x


class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: expand, GELU, dropout, project back.

    Parameters
    ----------
    emb_size : int
        Input and output feature size.
    expansion : int
        Hidden layer width as a multiple of ``emb_size``.
    drop_p : float
        Dropout rate applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        expand = nn.Linear(emb_size, hidden_size)
        activation = nn.GELU()
        dropout = nn.Dropout(drop_p)
        contract = nn.Linear(hidden_size, emb_size)
        super().__init__(expand, activation, dropout, contract)


class TransformerEncoderBlock(nn.Sequential):
    """One transformer encoder layer with LayerNorm applied before each sub-layer.

    The layer consists of two residual branches, applied in order:
    ``LayerNorm -> MultiHeadAttention -> Dropout`` and
    ``LayerNorm -> FeedForwardBlock -> Dropout``.

    Parameters
    ----------
    emb_size : int
        Token embedding size.
    num_heads : int
        Number of attention heads.
    drop_p : float
        Dropout rate used inside attention and after each branch.
    forward_expansion : int
        Hidden width multiplier of the feed-forward block.
    forward_drop_p : float
        Dropout rate inside the feed-forward block.
    """

    def __init__(
        self,
        emb_size,
        num_heads=10,
        drop_p=0.5,
        forward_expansion=4,
        forward_drop_p=0.5,
    ):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p
            ),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )


class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` identical-configuration encoder blocks applied in sequence.

    Parameters
    ----------
    depth : int
        Number of ``TransformerEncoderBlock`` layers.
    emb_size, num_heads, drop_p, forward_expansion, forward_drop_p
        Forwarded unchanged to every ``TransformerEncoderBlock``.
    """

    def __init__(
        self, depth, emb_size, num_heads, drop_p, forward_expansion, forward_drop_p
    ):
        layers = []
        for _ in range(depth):
            layers.append(
                TransformerEncoderBlock(
                    emb_size,
                    num_heads=num_heads,
                    drop_p=drop_p,
                    forward_expansion=forward_expansion,
                    forward_drop_p=forward_drop_p,
                )
            )
        super().__init__(*layers)


class ClassificationHead(nn.Sequential):
    """Fully-connected classifier producing log class probabilities.

    The input is flattened, then passed through ``Linear -> ELU -> Dropout``
    twice (widths 256 and 32, dropout 0.5 and 0.3), then a final ``Linear``
    to ``nCls`` outputs and a ``LogSoftmax`` over dim 1.

    Parameters
    ----------
    in_features : int
        Number of features after flattening the input.
    nCls : int
        Number of output classes.
    """

    def __init__(self, in_features, nCls):
        hidden_widths = (256, 32)
        dropout_rates = (0.5, 0.3)

        layers = [nn.Flatten()]
        width = in_features
        for hidden, rate in zip(hidden_widths, dropout_rates):
            layers += [nn.Linear(width, hidden), nn.ELU(), nn.Dropout(rate)]
            width = hidden
        layers += [nn.Linear(width, nCls), nn.LogSoftmax(dim=1)]

        super().__init__(*layers)


@fill_doc
class EEGConformer(nn.Module):
    """EEG Conformer: Convolutional Transformer for EEG Decoding and
    Visualization.

    EEG Conformer [1]_ is proposed to encapsulate local and global features in
    a unified EEG classification framework. The architecture comprises three
    components: a convolution module, a self-attention module, and a
    fully-connected classifier. In the convolution module, the raw
    two-dimensional EEG trials are taken as input, and temporal and spatial
    convolutional layers are applied along the time dimension and the
    electrode channel dimension, respectively. Then, an average pooling layer
    is used to suppress noise interference while improving generalization.
    Next, the spatial-temporal representation obtained by the convolution
    module is fed into the self-attention module. The self-attention module
    further extracts long-term temporal features by measuring the global
    correlations between different time positions in the feature maps.
    Finally, a compact classifier consisting of several fully-connected layers
    outputs the decoding results.

    Parameters
    ----------
    %(nCh)s
    %(nTime)s
    %(nCls)s
    emb_size : int
        Embedding layer size.
    depth : int
        Depth of transformer encoder.
    num_heads : int
        Number of attention heads. Must divide ``emb_size``.
    drop_p : float
        Dropout rate of transformer encoder.
    forward_expansion : int
        The expansion factor of the fully connected feed-forward layer.
    forward_drop_p : float
        Dropout rate of fully connected feed-forward layer.

    References
    ----------
    .. [1] Y. Song, Q. Zheng, B. Liu, and X. Gao, “EEG conformer:
       Convolutional transformer for EEG decoding and visualization,” IEEE
       Transactions on Neural Systems and Rehabilitation Engineering, vol. 31,
       pp. 710–719, 2022.
    """

    def __init__(
        self,
        nCh: int,
        nTime: int,
        nCls: int,
        emb_size: int = 40,
        depth: int = 6,
        num_heads: int = 10,
        drop_p: float = 0.5,
        forward_expansion: int = 4,
        forward_drop_p: float = 0.5,
    ) -> None:
        super().__init__()
        self.nCh = nCh
        self.nTime = nTime

        self.patch_embedding = PatchEmbedding(nCh, emb_size)
        self.transformer_encoder = TransformerEncoder(
            depth, emb_size, num_heads, drop_p, forward_expansion, forward_drop_p
        )
        # The classifier's input width depends on nTime through the conv and
        # pooling arithmetic, so it is measured with a dummy forward pass.
        in_features = self._forward_transformer().size(1)
        self.classification_head = ClassificationHead(in_features, nCls)

    def _forward_transformer(self) -> torch.Tensor:
        """Run one dummy trial through the feature extractor to size the head.

        Returns
        -------
        Tensor
            Flattened encoder output of shape ``(1, n_features)``.
        """
        # Probe in eval mode and without autograd. In train mode, BatchNorm
        # would fold the dummy input into its running statistics, and Dropout
        # would draw from the global RNG. Zeros instead of random data keep
        # the global RNG untouched as well; only the output shape matters.
        saved_modes = [(module, module.training) for module in self.modules()]
        param = next(self.patch_embedding.parameters())
        self.eval()
        try:
            with torch.no_grad():
                x = torch.zeros(
                    1, 1, self.nCh, self.nTime, device=param.device, dtype=param.dtype
                )
                x = self.patch_embedding(x)
                x = self.transformer_encoder(x)
        finally:
            # Restore each submodule's own train/eval flag.
            for module, mode in saved_modes:
                module.training = mode
        return x.flatten(start_dim=1, end_dim=-1)

    def forward(self, x: Tensor) -> Tensor:
        """Process input EEG data and produce the decoded results.

        Parameters
        ----------
        x : Tensor
            Input EEG data, shape `(batch_size, 1, nCh, nTime)`.

        Returns
        -------
        cls_log_prob : Tensor
            Log class probabilities (log-softmax output), shape
            `(batch_size, nCls)`.
        """
        out = self.patch_embedding(x)
        out = self.transformer_encoder(out)
        return self.classification_head(out)