import torch
import torch.nn as nn
from .utils import Conv2dWithNorm, LinearWithNorm
from ..tools.docs import fill_doc
# Public API of this module: only ``ShallowConvNet`` is exported by
# ``from <module> import *``; other names defined here are not listed.
__all__ = ["ShallowConvNet"]
class Lambda(nn.Module):
    """Module adapter around an arbitrary callable.

    Lets a plain function (for example ``torch.square``) be placed inside
    containers such as ``nn.Sequential``.

    Parameters
    ----------
    func : callable
        Callable invoked with the forward input; its result is returned
        unchanged.
    """

    def __init__(self, func) -> None:
        super().__init__()
        self.func = func

    def forward(self, x):
        """Call the wrapped callable on ``x`` and return whatever it yields."""
        wrapped = self.func
        result = wrapped(x)
        return result
@fill_doc
class ShallowConvNet(nn.Module):
    """Deep Learning With Convolutional Neural Networks for EEG Decoding and
    Visualization (ShallowConvNet).

    Shallow ConvNet [1]_, inspired by the FBCSP pipeline, is specifically
    tailored to decode band power features. The transformations performed by
    the shallow ConvNet are similar to the transformations of FBCSP.
    Concretely, the first two layers of the shallow ConvNet perform a temporal
    convolution and a spatial filter, as in the deep ConvNet. These steps are
    analogous to the bandpass and CSP spatial filter steps in FBCSP. In
    contrast to the deep ConvNet, the temporal convolution of the shallow
    ConvNet had a larger kernel size, allowing a larger range of
    transformations in this layer (smaller kernel sizes for the shallow ConvNet
    led to lower accuracies in preliminary experiments on the training set).
    After the temporal convolution and the spatial filter of the shallow
    ConvNet, a squaring nonlinearity, a mean pooling layer and a logarithmic
    activation function followed; together these steps are analogous to the
    trial log-variance computation in FBCSP. In contrast to FBCSP, the shallow
    ConvNet embeds all the computational steps in a single network, and thus
    all steps can be optimized jointly. Also, due to having several pooling
    regions within one trial, the shallow ConvNet can learn a temporal
    structure of the band power changes within the trial.

    Parameters
    ----------
    %(nCh)s
    %(nTime)s
    %(nCls)s
    F : int
        The number of convolution channels. Default is 40.
    C : int
        Temporal convolution kernel size. Default is 14.
    P : int
        Pooling kernel size. Default is 35.
    S : int
        Pooling layer stride size. Default is 7.
    dropout : float
        Dropout rate. Default is 0.5.

    Notes
    -----
    The classification head ends with ``nn.LogSoftmax``, so ``forward``
    returns log-probabilities rather than probabilities.

    References
    ----------
    .. [1] R. T. Schirrmeister et al., “Deep learning with convolutional neural
       networks for EEG decoding and visualization,” Human Brain Mapping,
       vol. 38, no. 11, pp.5391-5420, 2017, doi: 10.1002/hbm.23730.
    """

    def __init__(
        self,
        nCh: int,
        nTime: int,
        nCls: int,
        F: int = 40,
        C: int = 14,
        P: int = 35,
        S: int = 7,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()

        # Kept on the instance because ``forward_flatten`` needs them to
        # build its dummy input.
        self.nCh = nCh
        self.nTime = nTime

        self.conv = nn.Sequential(
            # Temporal convolution: kernel (1, C) slides along the last axis.
            Conv2dWithNorm(1, F, (1, C), max_norm=2, bias=False),
            # Spatial filter: kernel (nCh, 1) spans the whole electrode axis.
            # ``groups=F`` presumably makes it depthwise (one spatial filter
            # per temporal feature map) -- confirm in Conv2dWithNorm.
            Conv2dWithNorm(F, F, (nCh, 1), max_norm=2, bias=False, groups=F),
            nn.BatchNorm2d(F),
            # square -> mean pool -> log, i.e. a log mean power per window.
            Lambda(torch.square),
            nn.AvgPool2d((1, P), stride=(1, S)),
            # NOTE(review): torch.log returns -inf for a pooled value of
            # exactly 0; consider clamping if NaN/inf losses are observed.
            Lambda(torch.log),
        )

        # Infer the flattened feature size from a dummy forward pass instead
        # of re-deriving the convolution/pooling output arithmetic by hand.
        linear_in = self.forward_flatten().shape[1]

        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            LinearWithNorm(linear_in, nCls, max_norm=0.5),
            nn.LogSoftmax(dim=1),
        )

    def forward_flatten(self):
        """Run a dummy sample through the convolutional stack and flatten it.

        Used by ``__init__`` to determine the input size of the linear layer.
        The probe runs in evaluation mode and without gradient tracking, so
        it neither updates the BatchNorm running statistics nor builds an
        autograd graph. The previous training/eval mode of every submodule of
        ``self.conv`` is restored afterwards.

        Returns
        -------
        x : Tensor
            Flattened features of one random sample, shape
            `(1, n_features)`.
        """
        # Create the dummy input where the weights live so the method also
        # works after the model has been moved to another device or dtype.
        ref = next(self.conv.parameters(), None)
        if ref is None:
            dummy = torch.rand(1, 1, self.nCh, self.nTime)
        else:
            dummy = torch.rand(
                1, 1, self.nCh, self.nTime, device=ref.device, dtype=ref.dtype
            )

        # Remember per-module modes; a blanket ``train()`` afterwards could
        # undo a deliberate ``eval()`` on an individual submodule.
        saved_modes = [(m, m.training) for m in self.conv.modules()]
        self.conv.eval()
        try:
            with torch.no_grad():
                x = self.conv(dummy)
        finally:
            for module, mode in saved_modes:
                module.training = mode

        return torch.flatten(x, 1, -1)

    def forward(self, x):
        """Forward pass function that processes the input EEG data and produces
        the decoded results.

        Parameters
        ----------
        x : Tensor
            Input EEG data, shape `(batch_size, 1, nCh, nTime)`. The second
            axis is expected to have size 1: the first convolution is built
            with ``1`` as its first argument and the shape probe in
            ``forward_flatten`` feeds a `(1, 1, nCh, nTime)` tensor
            (assuming ``Conv2dWithNorm`` follows the ``nn.Conv2d`` argument
            order).

        Returns
        -------
        cls_prob : Tensor
            Log-probabilities over classes (output of ``nn.LogSoftmax``
            along dim 1), expected shape `(batch_size, nCls)`. Apply
            ``torch.exp`` to obtain probabilities.
        """
        x = self.conv(x)
        x = self.head(x)
        return x