Classifier

class dpeeg.trainer.Classifier(model: Module, loss_fn: str | Type[Module] = 'NLLLoss', loss_fn_args: dict | None = None, optimizer: str | Type[Optimizer] = 'Adam', optimizer_args: dict | None = None, lr: float = 0.001, lr_sch: str | Type[LRScheduler] | None = None, lr_sch_args: dict | None = None, grad_acc: int = 1, batch_size: int = 32, max_epochs: int = 1000, no_increase_epochs: int = 50, var_check: Literal['train_loss', 'train_inacc'] = 'train_loss', load_best_state: bool = True, nGPU: int = 0, seed: int = 42, keep_data_gpu: bool = True, depth: int | None = None, data_size: tuple | list | None = None, verbose: int | str = 'INFO')

Trainer for classification models.

Early stopping and model selection rely on the training set alone, so no validation set is required: the monitored training-set metric (see var_check) determines when training stops and which model is selected.
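A minimal, standard-library-only sketch of such a patience-based rule follows. It assumes a lower-is-better training metric (such as the training loss) and treats any strictly lower value as an improvement; the function name `early_stop` is hypothetical, and this is an illustration rather than dpeeg's implementation, whose exact "relative improvement" test may differ. The `patience` argument plays the role of no_increase_epochs.

```python
from typing import Sequence, Tuple


def early_stop(values: Sequence[float], patience: int) -> Tuple[int, int]:
    """Return (best_epoch, last_epoch) for a lower-is-better training metric.

    Hypothetical sketch: values[i] is the monitored training metric after
    epoch i (e.g. training loss). Training stops once `patience` consecutive
    epochs pass without a new best (strictly lower) value.
    """
    if patience < 1:
        raise ValueError("patience must be >= 1")
    best_value = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, value in enumerate(values):
        if value < best_value:  # new best: remember it and reset the counter
            best_value = value
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch  # stop early
    # All epochs used (the max_epochs limit) without triggering early stopping.
    return best_epoch, len(values) - 1


train_loss = [1.0, 0.8, 0.85, 0.7, 0.75, 0.72, 0.71]
print(early_stop(train_loss, patience=3))  # (3, 6)
```

In this example training would stop after epoch 6, and with load_best_state=True the parameters from epoch 3 would be the ones evaluated.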

Parameters:
  • model (Module) – A torch.nn.Module subclass that defines the forward method. The first value returned by forward is taken as the prediction.

  • loss_fn (str, Type[Module]) – Name of a loss class in torch.nn to use for training. Alternatively, pass a custom loss class (a Module subclass, not an instance); it is instantiated as loss_fn(**loss_fn_args).

  • loss_fn_args (dict, optional) – Additional keyword arguments passed to the loss function's constructor.

  • optimizer (str, Type[Optimizer]) – Name of an optimizer class in torch.optim to use for training. Alternatively, pass a custom optimizer class (an Optimizer subclass, not an instance); it is instantiated as optimizer(model, lr=lr, **optimizer_args).

  • optimizer_args (dict, optional) – Additional keyword arguments passed to the optimizer's constructor.

  • lr (float) – Learning rate.

  • lr_sch (str, Type[LRScheduler], optional) – Name of a learning-rate scheduler class in torch.optim.lr_scheduler to use for training. Alternatively, pass a custom scheduler class (an LRScheduler subclass, not an instance); it is instantiated as lr_sch(optimizer, **lr_sch_args).

  • lr_sch_args (dict, optional) – Additional keyword arguments passed to the learning-rate scheduler's constructor.

  • grad_acc (int) – Number of gradient accumulation steps, i.e. how many mini-batches contribute gradients to each optimizer step.

  • batch_size (int) – Mini-batch size.

  • max_epochs (int) – Maximum number of epochs in training.

  • no_increase_epochs (int) – Early-stopping patience: the maximum number of consecutive epochs without relative improvement in the monitored training-set metric (loss or accuracy, as selected by var_check).

  • var_check (str) – Training-set metric used to determine the best model, which is then evaluated on the test set: 'train_loss' (training loss) or 'train_inacc' (training inaccuracy).

  • load_best_state (bool) – If True, the parameters of the best model (as determined by var_check) are used for evaluation.

  • nGPU (int) – ID of the GPU used for training. If no GPU is available, the CPU is used.

  • seed (int) – Random seed, for reproducibility.

  • keep_data_gpu (bool) – Keep the dataset on the GPU to avoid the cost of repeated host-to-device transfers. Set this according to the available GPU memory.

  • data_size (tuple, list, optional) – Input shape. If given, the structure of the network is displayed for inputs of this shape.

  • depth (int, optional) – Depth of nested layers to display in the network structure summary.

  • verbose (int, str) – Console log level (default 'INFO'). Mainly useful for debugging.
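The loss_fn, optimizer and lr_sch parameters share one convention: either a class name looked up in a PyTorch module, or a class (not an instance) that is instantiated with the documented arguments. The hypothetical helper `build` below sketches that convention with the standard library only; a `types.SimpleNamespace` stands in for torch.nn and `WeightedLoss` is an illustrative stand-in class, so this is not dpeeg's internal code.

```python
import types
from typing import Any, Optional, Union


def build(spec: Union[str, type], namespace: Any, *args: Any,
          kwargs: Optional[dict] = None) -> Any:
    """Instantiate `spec` as spec(*args, **kwargs).

    Hypothetical helper: a string is looked up by name in `namespace`
    (for example torch.nn), a class is used as given, and an instance
    is rejected because the class itself must be passed.
    """
    if isinstance(spec, str):
        cls = getattr(namespace, spec)  # AttributeError for an unknown name
    elif isinstance(spec, type):
        cls = spec
    else:
        raise TypeError(f"expected a class or a class name, got {spec!r}")
    return cls(*args, **(kwargs or {}))


class WeightedLoss:
    """Stand-in for a custom loss class (hypothetical)."""

    def __init__(self, weight: float = 1.0) -> None:
        self.weight = weight


# Stand-in namespace playing the role of torch.nn in this sketch.
fake_nn = types.SimpleNamespace(WeightedLoss=WeightedLoss)

by_name = build("WeightedLoss", fake_nn, kwargs={"weight": 0.5})
by_class = build(WeightedLoss, fake_nn, kwargs={"weight": 2.0})
print(type(by_name).__name__, by_name.weight, by_class.weight)
# WeightedLoss 0.5 2.0
```

With PyTorch installed, `build("NLLLoss", torch.nn)` would return a `torch.nn.NLLLoss` instance, and a scheduler would follow the same pattern as `build(lr_sch, torch.optim.lr_scheduler, optimizer, kwargs=lr_sch_args)`.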

fit(trainset: EEGData, testset: EEGData, log_dir: str) → dict[str, dict[str, Tensor]]

Train the model on trainset and evaluate it on testset.
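During training, grad_acc determines how often the optimizer updates the weights. The sketch below shows the common gradient-accumulation schedule: gradients from grad_acc consecutive mini-batches are accumulated before one optimizer step, so each full step reflects batch_size × grad_acc samples. The helper name `optimizer_step_batches` is hypothetical, and stepping on a final partial group at the end of an epoch is an assumption of this sketch, not documented dpeeg behavior.

```python
from typing import List


def optimizer_step_batches(num_batches: int, grad_acc: int) -> List[int]:
    """Return the mini-batch indices after which optimizer.step() would run.

    Hypothetical sketch of gradient accumulation: gradients from `grad_acc`
    consecutive mini-batches are accumulated, then one optimizer step is
    taken. A final partial group at the end of the epoch also triggers a
    step (an assumption of this sketch).
    """
    if grad_acc < 1:
        raise ValueError("grad_acc must be >= 1")
    steps = []
    for batch_idx in range(num_batches):
        end_of_group = (batch_idx + 1) % grad_acc == 0
        last_batch = batch_idx == num_batches - 1
        if end_of_group or last_batch:
            steps.append(batch_idx)
    return steps


# 10 mini-batches with grad_acc=4: steps after batches 3 and 7, plus the partial group ending at 9.
print(optimizer_step_batches(10, 4))  # [3, 7, 9]
print(optimizer_step_batches(3, 1))   # [0, 1, 2]
```

In a PyTorch loop this schedule typically appears as calling `(loss / grad_acc).backward()` on every mini-batch and `optimizer.step()` followed by `optimizer.zero_grad()` only at these indices; dividing by grad_acc keeps the accumulated gradient on the scale of a single-batch average.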

Parameters:
  • trainset (EEGData) – Dataset used for training.

  • testset (EEGData) – Dataset used to evaluate the model.

  • log_dir (str) – Path of the directory in which the training log is saved.

Returns:

result – Results on the training set and the test set, including the true labels, predicted labels, and accuracy.

Return type:

dict
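The accuracy in the result is, under the standard definition, the fraction of samples whose predicted label equals the true label. The sketch below computes it and builds a nested dictionary shaped like dict[str, dict[str, Tensor]]; the key names "train", "test", "label", "pred" and "acc" are hypothetical (the actual keys are not specified here), and plain lists and floats stand in for Tensors so the example runs with the standard library only.

```python
from typing import Dict, Sequence


def accuracy(true_labels: Sequence[int], pred_labels: Sequence[int]) -> float:
    """Fraction of samples whose predicted label equals the true label."""
    if len(true_labels) != len(pred_labels):
        raise ValueError("label sequences must have the same length")
    if not true_labels:
        raise ValueError("no samples")
    correct = sum(t == p for t, p in zip(true_labels, pred_labels))
    return correct / len(true_labels)


# Hypothetical layout mirroring dict[str, dict[str, Tensor]]; the key names
# are illustrative, and lists/floats stand in for Tensors.
result: Dict[str, Dict[str, object]] = {}
for split, (true, pred) in {
    "train": ([0, 1, 1, 0], [0, 1, 1, 1]),
    "test": ([1, 0, 1], [1, 0, 0]),
}.items():
    result[split] = {"label": true, "pred": pred, "acc": accuracy(true, pred)}

print(result["train"]["acc"], round(result["test"]["acc"], 3))  # 0.75 0.667
```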