LightConvNet

class dpeeg.models.LightConvNet.LightConvNet(nCh: int, nTime: int, nCls: int, bands: int = 9, embed_dim: int = 64, win_len: int = 250, heads: int = 8, weight_softmax: bool = True, bias: bool = False)

A Temporal Dependency Learning CNN With Attention Mechanism for MI-EEG Decoding (LightConvNet).

LightConvNet [1] first applies a spatial convolution to multi-view EEG data (the raw signals preprocessed with a filter bank) to learn spatial and spectral information. It then segments the resulting time series into non-overlapping time windows. A temporal variance layer extracts a discriminative feature from each window, capturing motor-imagery (MI)-related patterns at different stages of the MI task. Next, a temporal attention module learns the temporal dependencies among the features from different windows. It weights each window's features according to their contribution to decoding performance and fuses them into a more discriminative representation. Finally, the fused features are used for classification.
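The windowing and temporal variance steps described above can be illustrated with a small NumPy sketch. It assumes the spatially filtered signal has shape `(batch, embed_dim, nTime)` and that the variance is the plain population variance (`ddof=0`) of each window. The helper name `window_variance` is hypothetical and is not part of dpeeg, and the library's own layer may differ (for example, by using an unbiased estimate or a log transform).

```python
import numpy as np


def window_variance(x: np.ndarray, win_len: int) -> np.ndarray:
    """Hypothetical helper: variance of each non-overlapping time window.

    x: (..., n_time) spatially filtered signals, e.g. (batch, embed_dim, n_time).
    Returns an array of shape (..., n_time // win_len). Trailing samples that
    do not fill a whole window are dropped (an assumption of this sketch).
    """
    n_win = x.shape[-1] // win_len
    if n_win == 0:
        raise ValueError("win_len is longer than the signal")
    x = x[..., : n_win * win_len]
    # Split the time axis into (n_win, win_len) without overlap.
    windows = x.reshape(*x.shape[:-1], n_win, win_len)
    # Population variance over the samples inside each window.
    return windows.var(axis=-1)


rng = np.random.default_rng(0)
z = rng.standard_normal((2, 64, 1000))  # (batch, embed_dim, nTime)
print(window_variance(z, win_len=250).shape)  # (2, 64, 4)
```

With the defaults `nTime = 1000` and `win_len = 250`, each of the `embed_dim` spatial filters yields four window features that the attention module can then weigh against each other.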

Parameters:
  • nCh (int) – Number of electrode channels.

  • nTime (int) – Number of time samples per trial. For example, 4 seconds of data sampled at 250 Hz gives 1000.

  • nCls (int) – Number of classification categories.

  • bands (int) – Number of filter-bank bands, i.e. the size of the band (view) dimension of the multi-view input.

  • embed_dim (int) – Number of spatial filters.

  • win_len (int) – Length of each non-overlapping time window, in samples.

  • heads (int) – Number of attention heads.

  • weight_softmax (bool) – If True, normalize the convolution weights with softmax before applying the convolution.

  • bias (bool) – Whether to add a learnable bias.
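The `heads` and `weight_softmax` parameters resemble those of a lightweight convolution (Wu et al., 2019). In that scheme, channels are split into `heads` groups, and each group shares one kernel that can be softmax-normalized. The NumPy sketch below shows how such a module could fuse window features, under several assumptions that are not confirmed by this page:

- `embed_dim` is divisible by `heads`.
- The kernel spans all windows, so the module outputs a single fused vector.
- The bias is omitted.

The names `softmax` and `fuse_windows` are hypothetical helpers, not dpeeg functions.

```python
import numpy as np


def softmax(w: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    e = np.exp(w - w.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def fuse_windows(feats: np.ndarray, kernel: np.ndarray,
                 weight_softmax: bool = True) -> np.ndarray:
    """Hypothetical head-shared weighting of per-window features.

    feats:  (embed_dim, n_win) window features, e.g. from a variance layer.
    kernel: (heads, n_win) raw learnable weights, one row per head.
    Returns an (embed_dim,) fused feature vector.
    """
    embed_dim, n_win = feats.shape
    heads, k = kernel.shape
    # Assumption of this sketch: channels are split evenly across heads.
    if embed_dim % heads != 0:
        raise ValueError("embed_dim must be divisible by heads")
    if k != n_win:
        raise ValueError("kernel must span all windows in this sketch")
    # With softmax, each head's window weights are positive and sum to 1.
    w = softmax(kernel, axis=1) if weight_softmax else kernel
    # Channel c uses the kernel of head c // (embed_dim // heads).
    w_per_channel = np.repeat(w, embed_dim // heads, axis=0)  # (embed_dim, n_win)
    # Weighted sum over the window axis.
    return (feats * w_per_channel).sum(axis=1)


rng = np.random.default_rng(0)
feats = rng.standard_normal((64, 4))  # embed_dim=64, nTime // win_len = 4
kernel = rng.standard_normal((8, 4))  # heads=8
print(fuse_windows(feats, kernel).shape)  # (64,)
```

Because softmax makes each head's weights a convex combination, a zero kernel reduces to averaging the windows. Larger weights let a head emphasize the windows it finds most informative.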

References

forward(x)

Run the forward pass on a batch of filter-bank EEG data and return the predicted class probabilities.

Parameters:

x (Tensor) – Input EEG data, shape (batch_size, bands, nCh, nTime).

Returns:

cls_prob – Predicted class probability, shape (batch_size, nCls).

Return type:

Tensor
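`forward` expects input already split into `bands` frequency views, shaped `(batch_size, bands, nCh, nTime)`. The NumPy sketch below builds an array with that layout. It uses a crude FFT brick-wall mask as a stand-in band-pass filter, which is not necessarily the preprocessing dpeeg uses; a real pipeline would more likely use proper IIR or FIR filters. The nine 4 Hz-wide bands spanning 4-40 Hz (matching the default `bands=9`) are an illustrative assumption. The helper name `filter_bank` is hypothetical.

```python
from typing import Sequence, Tuple

import numpy as np


def filter_bank(eeg: np.ndarray, fs: float,
                bands: Sequence[Tuple[float, float]]) -> np.ndarray:
    """Hypothetical helper: stack band-limited copies of `eeg`.

    eeg:   (batch, nCh, nTime) raw trials.
    bands: (low_hz, high_hz) pass bands.
    Returns (batch, len(bands), nCh, nTime), the layout forward() expects.
    """
    n_time = eeg.shape[-1]
    spectrum = np.fft.rfft(eeg, axis=-1)
    freqs = np.fft.rfftfreq(n_time, d=1.0 / fs)
    views = []
    for low, high in bands:
        # Crude brick-wall band-pass: zero every bin outside [low, high).
        mask = (freqs >= low) & (freqs < high)
        views.append(np.fft.irfft(spectrum * mask, n=n_time, axis=-1))
    # Put the band (view) axis right after the batch axis.
    return np.stack(views, axis=1)


fs = 250.0
bands = [(f, f + 4) for f in range(4, 40, 4)]  # 9 bands: 4-8, ..., 36-40 Hz
eeg = np.random.default_rng(0).standard_normal((8, 22, 1000))
views = filter_bank(eeg, fs, bands)
print(views.shape)  # (8, 9, 22, 1000)
```

Since `x` is documented as a `Tensor`, the array could be converted with `torch.from_numpy(views).float()` before being passed to `forward`.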