# Source code for dpeeg.models.LightConvNet

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..tools.docs import fill_doc


# Only the full model is exported via ``from ... import *``; the helper
# layers defined in this module are not listed here.
__all__ = ["LightConvNet"]


class LogVarLayer(nn.Module):
    """Log-variance pooling along a single dimension.

    Computes ``log(clamp(x.var(dim), 1e-6, 1e6))`` using PyTorch's default
    ``Tensor.var`` estimator; the reduced dimension is dropped from the
    output. The clamp keeps the logarithm finite for constant
    (zero-variance) inputs and caps very large variances.

    Parameters
    ----------
    dim : int
        Dimension over which the variance is computed.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        variance = x.var(dim=self.dim)
        bounded = variance.clamp(min=1e-6, max=1e6)
        return bounded.log()


class LightweightConv1d(nn.Module):
    """Depthwise 1-D convolution whose kernels are shared across heads.

    The input channels are folded into groups of ``heads`` consecutive
    channels, so channel ``c`` is convolved with kernel ``weight[c % heads]``.
    Only ``heads`` kernels are learned regardless of the channel count.

    Parameters
    ----------
    input_size : int
        Expected number of input (and output) channels; used to size ``bias``.
    kernel_size : int
        Length of each convolution kernel.
    padding : int
        Zero-padding added to both ends of the time axis.
    heads : int
        Number of distinct kernels. The channel count of the input must be
        divisible by ``heads``.
    weight_softmax : bool
        If True, normalize each kernel with a softmax over its taps before
        the convolution.
    bias : bool
        If True, add a learnable per-channel bias of shape ``(input_size,)``,
        initialized to zero.

    Notes
    -----
    Input shape: ``(batch_size, channels, timesteps)``.
    Output shape: ``(batch_size, channels, timesteps + 2 * padding
    - kernel_size + 1)``.

    Attributes
    ----------
    weight : nn.Parameter
        Kernels of shape ``(heads, 1, kernel_size)``, Xavier-uniform
        initialized.
    bias : nn.Parameter or None
        Per-channel bias of shape ``(input_size,)``, or None.
    """

    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding=0,
        heads=1,
        weight_softmax=False,
        bias=False,
    ):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.heads = heads
        self.padding = padding
        self.weight_softmax = weight_softmax
        # Allocated uninitialized; filled in by init_parameters() below.
        self.weight = nn.Parameter(torch.empty(heads, 1, kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.empty(input_size))
        else:
            self.bias = None

        self.init_parameters()

    def init_parameters(self):
        """Xavier-uniform init for the kernels; zero init for the bias."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)

    def forward(self, input):
        """Convolve each channel with its head's kernel.

        Raises
        ------
        ValueError
            If the channel count is not divisible by ``heads``. Otherwise
            the fold below would group channels from different samples
            together and assign heads inconsistently across the batch.
        """
        B, C, T = input.size()
        H = self.heads
        if C % H != 0:
            raise ValueError(
                f"number of channels ({C}) must be divisible by heads ({H})"
            )

        weight = self.weight
        if self.weight_softmax:
            weight = F.softmax(weight, dim=-1)

        # Fold groups of H consecutive channels into the batch axis so a
        # grouped conv with H groups applies kernel c % H to channel c.
        # ``reshape`` (unlike ``view``) also accepts non-contiguous input.
        folded = input.reshape(-1, H, T)
        output = F.conv1d(folded, weight, padding=self.padding, groups=H)
        output = output.reshape(B, C, -1)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)

        return output


@fill_doc
class LightConvNet(nn.Module):
    """A Temporal Dependency Learning CNN With Attention Mechanism for MI-EEG
    Decoding (LightConvNet).

    LightConvNet [1]_ first applies a spatial convolution to learn spatial
    and spectral information from multi-view EEG data that has been
    preprocessed with a filter bank. It then segments the output time series
    into a series of non-overlapping time windows. A temporal variance layer
    extracts a discriminative feature from each window, capturing MI-related
    patterns at different stages of the MI task. A temporal attention module
    then learns temporal dependencies among the features of the different
    windows. It weights each window's features according to their
    contribution to the final decoding and fuses them into more
    discriminative features. Finally, the fused features are used for
    classification.

    Parameters
    ----------
    %(nCh)s
    %(nTime)s
    %(nCls)s
    bands : int
        Number of filter-bank bands, i.e. the second dimension of the input.
    embed_dim : int
        Number of spatial filters. Must be divisible by ``heads``.
    win_len : int
        Length (in samples) of each non-overlapping time window. Must be
        between 1 and ``nTime``; ``nTime`` is expected to be a multiple of
        ``win_len``.
    heads : int
        Number of attention heads.
    weight_softmax : bool
        If True, normalize the temporal attention weights with a softmax
        before the convolution.
    bias : bool
        If True, the temporal attention convolution uses a learnable bias.

    Raises
    ------
    ValueError
        If ``win_len`` is outside ``[1, nTime]``, or if ``heads`` is not
        positive or does not divide ``embed_dim``.

    References
    ----------
    .. [1] "A Temporal Dependency Learning CNN With Attention Mechanism for
       MI-EEG Decoding." IEEE Xplore. Accessed: Oct. 20, 2023. [Online].
       Available: https://ieeexplore.ieee.org/document/10196350
    """

    def __init__(
        self,
        nCh: int,
        nTime: int,
        nCls: int,
        bands: int = 9,
        embed_dim: int = 64,
        win_len: int = 250,
        heads: int = 8,
        weight_softmax: bool = True,
        bias: bool = False,
    ):
        super().__init__()
        # Fail early with a clear message instead of a deep crash inside
        # weight init (kernel size nTime // win_len == 0) or the conv.
        if win_len < 1 or nTime < win_len:
            raise ValueError(
                f"win_len ({win_len}) must be between 1 and nTime ({nTime})"
            )
        if heads < 1 or embed_dim % heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by heads ({heads})"
            )
        self.win_len = win_len

        # NOTE: the attribute name ``spacial_block`` (sic) is kept on purpose;
        # renaming it would change the module's state_dict keys.
        self.spacial_block = nn.Sequential(
            nn.Conv2d(bands, embed_dim, (nCh, 1)),
            nn.BatchNorm2d(embed_dim),
            nn.ELU(),
        )
        self.temporal_block = LogVarLayer(dim=3)
        # The kernel spans every window, so a full-length input collapses to
        # a single time step per spatial filter.
        self.conv = LightweightConv1d(
            embed_dim,
            nTime // win_len,
            heads=heads,
            weight_softmax=weight_softmax,
            bias=bias,
        )
        self.classify = nn.Sequential(
            nn.Linear(embed_dim, nCls),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Forward pass function that processes the input EEG data and
        produces the decoded results.

        Parameters
        ----------
        x : Tensor
            Input EEG data, shape `(batch_size, bands, nCh, nTime)`.

        Returns
        -------
        cls_prob : Tensor
            Predicted class log-probabilities (``LogSoftmax`` output),
            shape `(batch_size, nCls)`.
        """
        # (B, bands, nCh, T) -> (B, embed_dim, 1, T)
        out = self.spacial_block(x)
        # Split time into windows: (B, embed_dim, T // win_len, win_len)
        out = out.reshape([*out.shape[0:2], -1, self.win_len])
        # Log-variance per window: (B, embed_dim, n_windows)
        out = self.temporal_block(out)
        # Convolution across windows with kernel nTime // win_len; yields one
        # step per filter when the input length equals nTime.
        out = self.conv(out)
        out = out.view(out.size(0), -1)
        return self.classify(out)