"""Source code for dpeeg.models.EEGNet."""

import torch
import torch.nn as nn

from .utils import Conv2dWithNorm, LinearWithNorm
from ..tools.docs import fill_doc


__all__ = ["EEGNet"]


@fill_doc
class EEGNet(nn.Module):
    """EEGNet: A Compact Convolutional Neural Network for EEG-based
    Brain-Computer Interfaces (EEGNet).

    EEGNet [1]_ is a compact convolutional neural network for EEG-based BCIs.
    EEGNet starts with a temporal convolution to learn frequency filters, then
    uses a depthwise convolution, connected to each feature map individually,
    to learn frequency-specific spatial filters. The separable convolution is
    a combination of a depthwise convolution, which learns a temporal summary
    for each feature map individually, followed by a pointwise convolution,
    which learns how to optimally mix the feature maps together.

    Parameters
    ----------
    %(nCh)s
    %(nTime)s
    %(nCls)s
    F1 : int
        Number of temporal filters.
    C1 : int
        Temporal convolution kernel size. The padding is ``C1 // 2``, so an
        odd value keeps the time dimension unchanged.
    D : int
        Depth multiplier of the depthwise (spatial) convolution; that layer
        outputs ``D * F1`` feature maps.
    F2 : int
        Number of pointwise filters, i.e. the number of output feature maps
        of the separable convolution.
    C2 : int
        Separable (depthwise temporal) convolution kernel size.
    P1 : int
        The first pooling kernel size.
    P2 : int
        The second pooling kernel size.
    dropout : float
        Dropout rate.

    References
    ----------
    .. [1] V. J. Lawhern, A. J. Solon, N. R. Waytowich, S. M. Gordon, C. P.
       Hung, and B. J. Lance, “EEGNet: a compact convolutional neural network
       for EEG-based brain–computer interfaces,” J. Neural Eng., vol. 15,
       no. 5, p. 056013, Jul. 2018, doi: 10.1088/1741-2552/aace8c.
    """

    def __init__(
        self,
        nCh: int,
        nTime: int,
        nCls: int,
        F1: int = 8,
        C1: int = 63,
        D: int = 2,
        F2: int = 16,
        C2: int = 15,
        P1: int = 8,
        P2: int = 16,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()

        # Temporal convolution: the (1, C1) kernel slides along time only,
        # producing F1 feature maps from the single input plane.
        self.filter = nn.Sequential(
            nn.Conv2d(1, F1, (1, C1), padding=(0, C1 // 2), bias=False),
            nn.BatchNorm2d(F1),
        )
        # Depthwise spatial convolution: the unpadded (nCh, 1) kernel with
        # groups=F1 collapses the channel axis to height 1, separately for
        # each temporal filter (D maps per filter). Conv2dWithNorm presumably
        # enforces a max-norm weight constraint (max_norm=1) -- see .utils.
        self.depthwise_conv = nn.Sequential(
            Conv2dWithNorm(F1, D * F1, (nCh, 1), groups=F1, bias=False, max_norm=1),
            nn.BatchNorm2d(D * F1),
            nn.ELU(),
            nn.AvgPool2d((1, P1)),
            nn.Dropout(dropout),
        )
        # Separable convolution: a per-map temporal convolution
        # (groups=D * F1) followed by a 1x1 pointwise convolution that mixes
        # the D * F1 maps into F2 maps.
        self.separable_conv = nn.Sequential(
            nn.Conv2d(
                D * F1, D * F1, (1, C2), padding=(0, C2 // 2), groups=D * F1, bias=False
            ),
            nn.Conv2d(D * F1, F2, 1, bias=False),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, P2)),
            nn.Dropout(dropout),
        )
        self.flatten = nn.Flatten()
        # NOTE: a max-norm constrained classifier,
        # LinearWithNorm(in_features, nCls, bias=True, max_norm=0.25), was
        # tried here; experimental results showed a performance degradation,
        # so a plain nn.Linear is used instead.
        self.fc = nn.Sequential(
            nn.Linear(self.get_size(nCh, nTime), nCls, bias=True),
            nn.LogSoftmax(dim=1),
        )

    def get_size(self, nCh: int, nTime: int) -> int:
        """Return the number of flattened features fed to the classifier.

        A dummy input of shape ``(1, 1, nCh, nTime)`` is passed through the
        feature-extraction blocks (``filter``, ``depthwise_conv`` and
        ``separable_conv``), and the width of the flattened output is measured.

        The probe runs without autograd and with every submodule in
        evaluation mode. As a result, the BatchNorm running statistics are not
        updated by the dummy data, and Dropout draws no random numbers. Each
        submodule's original train/eval flag is restored afterwards.

        Parameters
        ----------
        nCh : int
            Number of EEG channels of the probe input.
        nTime : int
            Number of time samples of the probe input.

        Returns
        -------
        size : int
            Number of features after flattening.
        """
        # Remember each submodule's mode so a caller's mixed train/eval setup
        # is left exactly as it was found.
        modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                x = torch.zeros(1, 1, nCh, nTime)
                out = self.filter(x)
                out = self.depthwise_conv(out)
                out = self.separable_conv(out)
                return self.flatten(out).size(1)
        finally:
            for module, mode in modes:
                module.training = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass function that processes the input EEG data and
        produces the decoded results.

        Parameters
        ----------
        x : Tensor
            Input EEG data, shape `(batch_size, 1, nCh, nTime)`.

        Returns
        -------
        cls_log_prob : Tensor
            Predicted class log-probabilities (output of ``LogSoftmax``),
            shape `(batch_size, nCls)`. Apply ``exp`` to obtain probabilities.
        """
        out = self.filter(x)
        out = self.depthwise_conv(out)
        out = self.separable_conv(out)
        out = self.flatten(out)
        return self.fc(out)