import torch
import torch.nn as nn
from torch.optim.adamw import AdamW
from torch.nn.init import trunc_normal_, constant_
from ..tools.docs import fill_doc
__all__ = ["IFNet", "IFNetAdamW"]
class InterFre(nn.Module):
    """Fuse several frequency-branch tensors into one activation map.

    The branches are added element-wise and the result is passed through a
    GELU non-linearity.
    """

    def forward(self, x):
        """Add the tensors in ``x`` together and apply GELU.

        Parameters
        ----------
        x : iterable of Tensor
            Branch outputs to be combined; they are added with ``+``, so
            their shapes must be compatible for addition.

        Returns
        -------
        Tensor
            GELU of the element-wise sum of all tensors in ``x``.
        """
        # Start from 0 and accumulate left-to-right, exactly as ``sum`` does.
        fused = 0
        for branch in x:
            fused = fused + branch
        activation = nn.GELU()
        return activation(fused)
[docs]
@fill_doc
class IFNet(nn.Module):
    """IFNet: An Interactive Frequency Convolutional Neural Network for
    Enhancing Motor Imagery Decoding From EEG (IFNet).

    Motivated by cross-frequency coupling and its link to behavioral tasks,
    IFNet [1]_ models interactions between frequency bands to strengthen the
    representation of motor imagery (MI) characteristics. Spectro-spatial
    features are first extracted separately for the low- and high-frequency
    bands. The interplay between the bands is then learned through an
    element-wise addition followed by temporal average pooling. Together with
    repeated trial augmentation as a regularizer, IFNet produces
    spectro-spatiotemporally robust features for the final MI classification.

    Parameters
    ----------
    %(nCh)s
    %(nTime)s
    %(nCls)s
    F : int
        Number of spectro-spatial filters.
    C : int
        Spectro-spatial filter kernel size. The kernel size is halved
        (integer division) for each successive frequency domain.
    radix : int
        Number of cross-frequency domains.
    P : int
        Pooling kernel size.
    dropout : float
        Dropout rate.

    References
    ----------
    .. [1] J. Wang, L. Yao and Y. Wang, "IFNet: An Interactive Frequency
       Convolutional Neural Network for Enhancing Motor Imagery Decoding from
       EEG," in IEEE Transactions on Neural Systems and Rehabilitation
       Engineering, doi: 10.1109/TNSRE.2023.3257319.
    """

    def __init__(
        self,
        nCh: int,
        nTime: int,
        nCls: int,
        F: int = 64,
        C: int = 63,
        radix: int = 2,
        P: int = 125,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()

        self.F = F
        self.mF = F * radix

        # Grouped pointwise convolution: one group per frequency domain, each
        # mapping its own nCh input channels to F feature maps.
        spatial_conv = nn.Conv1d(
            in_channels=nCh * radix,
            out_channels=self.mF,
            kernel_size=1,
            bias=False,
            groups=radix,
        )
        self.sConv = nn.Sequential(spatial_conv, nn.BatchNorm1d(self.mF))

        # One depthwise temporal convolution per frequency domain; the i-th
        # branch uses the base kernel size halved i times.
        kernel_sizes = [C // 2**i for i in range(radix)]
        self.tConv = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv1d(
                        in_channels=F,
                        out_channels=F,
                        kernel_size=k,
                        stride=1,
                        padding=k // 2,
                        groups=F,
                        bias=False,
                    ),
                    nn.BatchNorm1d(F),
                )
                for k in kernel_sizes
            ]
        )

        self.interFre = InterFre()
        self.downSamp = nn.Sequential(nn.AvgPool1d(P), nn.Dropout(dropout))

        # Classifier head over the flattened (F, nTime // P) feature map.
        n_features = int(F * (nTime // P))
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_features, nCls),
            nn.LogSoftmax(dim=1),
        )

        self.apply(self.initParms)

    def initParms(self, m):
        """Initialize the parameters of a single sub-module.

        Linear and convolution weights are drawn from a truncated normal
        distribution with ``std=0.01``; normalization weights are set to 1.
        Any bias that exists is set to 0. Other module types are left as is.

        Parameters
        ----------
        m : nn.Module
            Module visited by ``nn.Module.apply``.
        """
        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.01)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d)):
            if m.weight is not None:
                constant_(m.weight, 1.0)
        else:
            return

        if m.bias is not None:
            constant_(m.bias, 0)

    def forward(self, x):
        """Run the network on a batch of EEG data.

        Parameters
        ----------
        x : Tensor
            Input EEG data, shape `(batch_size, nCh * radix, nTime)`.

        Returns
        -------
        cls_prob : Tensor
            Per-class log-probabilities (the head ends in ``LogSoftmax``),
            shape `(batch_size, nCls)`.
        """
        spatial = self.sConv(x)
        # Split the channel axis back into one chunk of F maps per domain.
        bands = torch.split(spatial, self.F, dim=1)
        filtered = [branch(band) for branch, band in zip(self.tConv, bands)]
        fused = self.interFre(filtered)
        pooled = self.downSamp(fused)
        return self.fc(pooled)
[docs]
class IFNetAdamW(AdamW):
    """Customized AdamW Optimizer for IFNet.

    Trainable parameters are split into two groups: parameters that are
    one-dimensional or whose name ends with ``.bias`` get ``weight_decay``
    fixed to 0, while all remaining parameters use the optimizer-wide
    settings.

    Parameters
    ----------
    net : IFNet
        IFNet model instance.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``AdamW``.
    """

    def __init__(self, net: nn.Module, **kwargs) -> None:
        decayed = []
        exempt = []
        for name, param in net.named_parameters():
            # Frozen parameters are not handed to the optimizer at all.
            if param.requires_grad:
                skip_decay = name.endswith(".bias") or param.dim() == 1
                bucket = exempt if skip_decay else decayed
                bucket.append(param)

        groups = [
            {"params": decayed},
            {"params": exempt, "weight_decay": 0},
        ]
        super().__init__(groups, **kwargs)