# Source code for dpeeg.models.utils

# Authors: SheepTAO <sheeptao@outlook.com>

# License: MIT
# Copyright the dpeeg contributors.

import torch
import torch.nn as nn


class Conv2dWithNorm(nn.Conv2d):
    """2-D convolution whose filters are max-norm constrained before each forward.

    When ``do_weight_norm`` is true, every sub-tensor of ``weight`` along
    dimension 0 (one per output channel) is rescaled with :func:`torch.renorm`
    so that its ``p``-norm does not exceed ``max_norm``. The constraint is
    applied on every call to :meth:`forward`, in training and evaluation mode
    alike.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to :class:`torch.nn.Conv2d`.
    do_weight_norm : bool
        Whether to apply the max-norm constraint. Default ``True``.
    max_norm : float
        Upper bound on the ``p``-norm of each output-channel filter.
        Default ``1.0``.
    p : float
        Order of the norm used by :func:`torch.renorm`. Default ``2``.
    """

    def __init__(
        self,
        *args,
        do_weight_norm: bool = True,
        max_norm: float = 1.0,
        p: float = 2,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.p = p
        self.max_norm = max_norm
        self.do_weight_norm = do_weight_norm

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Clip per-filter weight norms (if enabled), then convolve ``input``."""
        if self.do_weight_norm:
            # Operating on ``.data`` keeps the constraint out of the autograd
            # graph: it rescales the stored weights rather than being a
            # differentiable step of the forward computation.
            self.weight.data = torch.renorm(
                self.weight.data, self.p, 0, self.max_norm
            )
        return super().forward(input)

    def extra_repr(self) -> str:
        """Extend ``nn.Conv2d``'s summary with the norm settings when enabled.

        Hooking ``extra_repr`` instead of splicing text into ``__repr__`` keeps
        the printed module well-formed even when it has registered children.
        """
        base = super().extra_repr()
        if self.do_weight_norm:
            return f"{base}, max_norm={self.max_norm}, p={self.p}"
        return base
class LinearWithNorm(nn.Linear):
    """Fully connected layer whose weight rows are max-norm constrained.

    When ``do_weight_norm`` is true, every sub-tensor of ``weight`` along
    dimension 0 (one per output feature) is rescaled with :func:`torch.renorm`
    so that its ``p``-norm does not exceed ``max_norm``. The constraint is
    applied on every call to :meth:`forward`, in training and evaluation mode
    alike.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to :class:`torch.nn.Linear`.
    do_weight_norm : bool
        Whether to apply the max-norm constraint. Default ``True``.
    max_norm : float
        Upper bound on the ``p``-norm of each output feature's weight row.
        Default ``1.0``.
    p : float
        Order of the norm used by :func:`torch.renorm`. Default ``2``.
    """

    def __init__(
        self,
        *args,
        do_weight_norm: bool = True,
        max_norm: float = 1.0,
        p: float = 2,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.p = p
        self.max_norm = max_norm
        self.do_weight_norm = do_weight_norm

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Clip per-row weight norms (if enabled), then apply the linear map."""
        if self.do_weight_norm:
            # Operating on ``.data`` keeps the constraint out of the autograd
            # graph: it rescales the stored weights rather than being a
            # differentiable step of the forward computation.
            self.weight.data = torch.renorm(
                self.weight.data, self.p, 0, self.max_norm
            )
        return super().forward(input)

    def extra_repr(self) -> str:
        """Extend ``nn.Linear``'s summary with the norm settings when enabled.

        Hooking ``extra_repr`` instead of splicing text into ``__repr__`` keeps
        the printed module well-formed even when it has registered children.
        """
        base = super().extra_repr()
        if self.do_weight_norm:
            return f"{base}, max_norm={self.max_norm}, p={self.p}"
        return base