# Source code for dpeeg.models.FBCNet

import torch
from torch import nn

from .utils import Conv2dWithNorm, LinearWithNorm
from ..tools.docs import fill_doc


__all__ = ["FBCNet"]


class swish(nn.Module):
    """Swish activation, computing ``x * sigmoid(x)`` element-wise.

    Mathematically the same function as ``torch.nn.SiLU``; it has no
    learnable parameters, so the constructor inherited from ``nn.Module``
    is all that is needed.
    """

    def forward(self, x):
        # Gate each element by its own sigmoid.
        gate = torch.sigmoid(x)
        return gate * x


class LogVarLayer(nn.Module):
    """Log-variance reduction along a single dimension.

    Computes ``log(clamp(var(x, dim), 1e-6, 1e6))``. ``Tensor.var`` is called
    with its default settings and ``keepdim=True``, so the reduced
    dimension is kept with size 1. The clamp keeps the logarithm finite for
    (near-)constant inputs and bounds very large variances.

    Parameters
    ----------
    dim : int
        Dimension over which the variance is computed.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        variance = x.var(dim=self.dim, keepdim=True)
        return variance.clamp(min=1e-6, max=1e6).log()


@fill_doc
class FBCNet(nn.Module):
    """FBCNet: A Multi-view Convolutional Neural Network for Brain-Computer
    Interface (FBCNet).

    FBCNet [1]_ employs a multi-view data representation followed by spatial
    filtering to extract spectro-spatially discriminative features. In FBCNet,
    a novel Variance layer is proposed to effectively aggregate the EEG
    time-domain information.

    Parameters
    ----------
    %(nCh)s
    %(nTime)s
    %(nCls)s
    bands : int
        The filter dimension of the input multi-view data (dimension 1 of the
        input tensor).
    m : int
        Number of spatial filters per band.
    stride : int
        Number of non-overlapping time windows the signal is split into
        before the variance layer. Must be positive and evenly divide `nTime`.

    Raises
    ------
    ValueError
        If `stride` is not positive or does not evenly divide `nTime`.

    References
    ----------
    .. [1] R. Mane et al., “FBCNet: A Multi-view Convolutional Neural Network
       for Brain-Computer Interface,” arXiv.org. Accessed: Mar. 09, 2023.
       [Online]. Available: https://arxiv.org/abs/2104.01233v1
    """

    def __init__(
        self,
        nCh: int,
        nTime: int,
        nCls: int,
        bands: int = 9,
        m: int = 32,
        stride: int = 4,
    ) -> None:
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and a bad stride would otherwise only surface later
        # as an opaque reshape error in ``forward`` (or a ZeroDivisionError
        # for stride == 0).
        if stride <= 0 or nTime % stride:
            raise ValueError(
                f"nTime must be divisible by stride and stride must be "
                f"positive (got nTime={nTime}, stride={stride})."
            )
        self.stride = stride

        # Spatial convolution block. With groups=bands, each input band gets
        # its own m filters whose (nCh, 1) kernel spans all EEG channels at
        # a single time step (assuming Conv2dWithNorm follows nn.Conv2d
        # semantics for these arguments).
        self.scb = nn.Sequential(
            Conv2dWithNorm(
                bands,
                m * bands,
                (nCh, 1),
                groups=bands,
                do_weight_norm=True,
                max_norm=2,
                padding=0,
            ),
            nn.BatchNorm2d(m * bands),
            swish(),
        )
        # Log-variance over the samples inside each time window (last axis).
        self.temporal_layer = LogVarLayer(dim=3)
        # One log-variance feature per (filter, window) pair feeds the
        # classifier; the head emits log-probabilities.
        self.head = nn.Sequential(
            nn.Flatten(),
            LinearWithNorm(
                m * bands * stride, nCls, do_weight_norm=True, max_norm=0.5
            ),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Forward pass function that processes the input EEG data and
        produces the decoded results.

        Parameters
        ----------
        x : Tensor
            Input EEG data, shape `(batch_size, bands, nCh, nTime)`.

        Returns
        -------
        cls_prob : Tensor
            Predicted class log-probabilities (output of ``LogSoftmax``),
            shape `(batch_size, nCls)`.
        """
        # Spatial filtering; with the kernel spanning all nCh channels and no
        # padding the height collapses to 1:
        # (B, bands, nCh, nTime) -> (B, m * bands, 1, nTime)
        x = self.scb(x)
        # Split the time axis into `stride` contiguous windows:
        # -> (B, m * bands, stride, nTime // stride)
        x = x.reshape([*x.shape[:2], self.stride, x.shape[3] // self.stride])
        # Log-variance within each window: -> (B, m * bands, stride, 1)
        x = self.temporal_layer(x)
        return self.head(x)