Source code for dpeeg.models.DeepConvNet

import torch
import torch.nn as nn

from .utils import Conv2dWithNorm, LinearWithNorm
from ..tools.docs import fill_doc


__all__ = ["DeepConvNet"]


@fill_doc
class DeepConvNet(nn.Module):
    """Deep Learning With Convolutional Neural Networks for EEG Decoding and
    Visualization (Deep ConvNet).

    Deep ConvNet [1]_ had four convolution-max-pooling blocks, with a special
    first block designed to handle EEG input, followed by three standard
    convolution-max-pooling blocks and a dense softmax classification layer.

    The first convolutional block was split into two layers in order to
    better handle the large number of input channels (one input channel per
    electrode, compared to three input channels, one per color, in RGB
    images). In the first layer, each filter performs a convolution over
    time, and in the second layer, each filter performs a spatial filtering
    with weights for all possible pairs of electrodes with filters of the
    preceding temporal convolution. Note that as there is no activation
    function in between the two layers, they could in principle be combined
    into one layer. Using two layers however implicitly regularizes the
    overall convolution by forcing a separation of the linear transformation
    into a combination of a temporal convolution and a spatial filter.

    Parameters
    ----------
    %(nCh)s
    %(nTime)s
    %(nCls)s
    dropout : float
        Dropout rate, applied at the start of each of the three standard
        convolution-max-pooling blocks.

    References
    ----------
    .. [1] R. T. Schirrmeister et al., “Deep learning with convolutional
       neural networks for EEG decoding and visualization,” Human Brain
       Mapping, vol. 38, no. 11, pp. 5391-5420, 2017,
       doi: 10.1002/hbm.23730.
    """

    def __init__(
        self, nCh: int, nTime: int, nCls: int, dropout: float = 0.25
    ) -> None:
        super().__init__()
        self.nCh = nCh
        self.nTime = nTime

        # Every convolution uses a (1, 10) kernel, i.e. it spans time only.
        kernel_size = [1, 10]
        # Number of feature maps produced by each of the four blocks.
        filter_layer = [25, 50, 100, 200]
        first_f = filter_layer[0]

        # Block 1: temporal convolution, then a spatial convolution spanning
        # all nCh electrodes. The spatial conv has no bias because the
        # following BatchNorm would cancel it anyway.
        first_layer = nn.Sequential(
            Conv2dWithNorm(1, first_f, kernel_size, max_norm=2),
            Conv2dWithNorm(first_f, first_f, (nCh, 1), bias=False, max_norm=2),
            nn.BatchNorm2d(first_f),
            nn.ELU(),
            nn.MaxPool2d((1, 3)),
        )

        # Blocks 2-4: dropout -> temporal conv -> BatchNorm -> ELU -> pool,
        # mapping consecutive entries of filter_layer (25->50->100->200).
        middle_layer = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Dropout(dropout),
                    Conv2dWithNorm(in_f, out_f, kernel_size),
                    nn.BatchNorm2d(out_f),
                    nn.ELU(),
                    nn.MaxPool2d((1, 3)),
                )
                for in_f, out_f in zip(filter_layer, filter_layer[1:])
            ]
        )

        self.conv_layer = nn.Sequential(first_layer, middle_layer)

        # The classifier input size depends on nCh/nTime, so it is measured
        # by probing conv_layer with a dummy input.
        linear_in = self._forward_flatten().shape[1]
        self.head = nn.Sequential(
            nn.Flatten(),
            LinearWithNorm(linear_in, nCls, max_norm=0.5),
            nn.LogSoftmax(dim=1),
        )

    def _forward_flatten(self):
        """Pass a random dummy input of shape ``(1, 1, nCh, nTime)`` through
        ``conv_layer`` and return the output flattened to ``(1, features)``.

        The probe runs without gradient tracking and with ``conv_layer`` in
        eval mode. This keeps the random data from being written into the
        BatchNorm running statistics and keeps Dropout inactive. It also stops
        BatchNorm's training-mode check from rejecting short inputs whose
        feature map shrinks to one value per channel. The previous
        training/eval mode is restored afterwards, even if the probe raises.
        """
        was_training = self.conv_layer.training
        self.conv_layer.eval()
        try:
            with torch.no_grad():
                x = torch.rand(1, 1, self.nCh, self.nTime)
                x = self.conv_layer(x)
        finally:
            self.conv_layer.train(was_training)
        return torch.flatten(x, start_dim=1, end_dim=-1)

    def forward(self, x):
        """Forward pass function that processes the input EEG data and
        produces the decoded results.

        Parameters
        ----------
        x : Tensor
            Input EEG data, shape `(batch_size, 1, nCh, nTime)`.

        Returns
        -------
        cls_prob : Tensor
            Predicted class log-probabilities (the head ends in
            ``LogSoftmax``), shape `(batch_size, nCls)`.
        """
        x = self.conv_layer(x)
        x = self.head(x)
        return x